In [1]:
import pickle as pkl
import torch
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from scipy.stats import entropy
import streamlit as st
import pickle as pkl

import torch
from sklearn.model_selection import train_test_split

from denoisers.ConditionalUnetDenoiser import ConditionalUnetDenoiser
from denoisers.ConditionalUnetMatrixDenoiser import ConditionalUnetMatrixDenoiser
from utils.graph_utils import get_process_model_reachability_graph_transition_matrix, get_process_model_petri_net_transition_matrix
from utils.pm_utils import discover_dk_process, remove_duplicates_dataset, pad_to_multiple_of_n
from utils.Config import Config
import plotly.express as px
import plotly.graph_objects as go
from dataset.dataset import SaladsDataset
from ddpm.ddpm_multinomial import Diffusion
import os
import json
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from utils.pm_utils import conformance_measure
import numpy as np
from scipy.stats import wasserstein_distance
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score

In [2]:
with open("../data/pickles/50_salads_unified.pkl", "rb") as f:
    salads_data = pkl.load(f)
salads_dataset = SaladsDataset(salads_data['target'], salads_data['stochastic'])
salads_train, salads_test = train_test_split(salads_dataset, train_size=0.75, shuffle=True, random_state=42)

In [5]:
data_paths = ["50_salads_synth_from_det_{:.3f}.pkl".format(x) for x in [0.68, 0.72, 0.75, 0.8, 0.83]]
datas = []
for data_path in data_paths:
    with open(f"../data/synthetic/{data_path}", "rb") as f:
        data = pkl.load(f)
        datas.append(data)

In [6]:
datas.append(salads_data)
augmented_data = {"target": [], "stochastic": []}
for data in datas:
    augmented_data["target"].extend(data["target"])
    augmented_data["stochastic"].extend(data["stochastic"])

In [7]:
with open("../data/pickles/50_salads_aug.pkl", "wb") as f:
    pkl.dump(augmented_data, f)

In [3]:
def load_experiment_config(target_dir):
    config_path = os.path.join(target_dir, "cfg.json")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            return Config(**json.load(f))
    else:
        st.warning("Configuration file not found.")
        return None
target_dir = r"D:\Projects\trace-denoise\final_runs\50_salads_unified_rg"
cfg = load_experiment_config(target_dir)

In [5]:
dk_process_model, dk_init_marking, dk_final_marking = discover_dk_process(salads_train, cfg,
                                                                          preprocess=remove_duplicates_dataset)

In [19]:
from pm4py.objects.petri_net.utils import reachability_graph
import networkx as nx

rg = reachability_graph.construct_reachability_graph(dk_process_model, dk_init_marking)

rg_nx = nx.MultiDiGraph()

for state in rg.states:
    rg_nx.add_node(state.name)

transition_names = {tuple(s.strip(" '") for s in transition.name.strip("()").split(","))[1] for transition in
                    rg.transitions}
transition_name_index = {name: idx for idx, name in enumerate(sorted(transition_names))}

In [9]:
rg.transitions

{(38f24f27-ee1a-4e7e-aa0d-88e055902a52, 'cut_cheese'),
 (skip_3, None),
 (skip_14, None),
 (skip_3, None),
 (skip_3, None),
 (skip_20, None),
 (skip_27, None),
 (7653de03-a2bc-4041-9155-a8ddef1c9eb2, 'add_oil'),
 (skip_8, None),
 (skip_19, None),
 (skip_19, None),
 (d1ab66f1-bb11-4263-a263-2d6ea71d302a, 'add_vinegar'),
 (d1ab66f1-bb11-4263-a263-2d6ea71d302a, 'add_vinegar'),
 (7653de03-a2bc-4041-9155-a8ddef1c9eb2, 'add_oil'),
 (d1ab66f1-bb11-4263-a263-2d6ea71d302a, 'add_vinegar'),
 (b9ee192d-6f83-402b-8f06-b1b529d19e4a, 'place_cucumber_into_bowl'),
 (init_loop_12, None),
 (skip_26, None),
 (skip_27, None),
 (38f24f27-ee1a-4e7e-aa0d-88e055902a52, 'cut_cheese'),
 (c5f7565e-94cf-4fa9-b0bc-879dbc339397, 'add_dressing'),
 (d1ab66f1-bb11-4263-a263-2d6ea71d302a, 'add_vinegar'),
 (c5f7565e-94cf-4fa9-b0bc-879dbc339397, 'add_dressing'),
 (c5f7565e-94cf-4fa9-b0bc-879dbc339397, 'add_dressing'),
 (d1ab66f1-bb11-4263-a263-2d6ea71d302a, 'add_vinegar'),
 (skip_3, None),
 (38f24f27-ee1a-4e7e-aa0d-88e055

In [20]:
for transition in rg.transitions:
    transition_name = tuple(s.strip(" '") for s in transition.name.strip("()").split(","))
    # if (transition.from_state.name, transition.to_state.name) in rg_nx.edges():
    #     print(f"edge already exists: {transition.from_state.name} -> {transition.to_state.name}")
    #     rg_nx[transition.from_state.name][transition.to_state.name]['label'] = (*rg_nx[transition.from_state.name][transition.to_state.name]['label'], transition_name)
    # else:
    rg_nx.add_edge(
        transition.from_state.name,
        transition.to_state.name,
        label=transition_name
    )

In [23]:
# Dictionary to collect edges for each directed node pair
directed_pairs = {}

# Iterate over all edges in the MultiDiGraph
for u, v, key, data in rg_nx.edges(data=True, keys=True):
    # Use tuple (u, v) as the key to group edges
    directed_pairs.setdefault((u, v), []).append(data.get("label"))

# Process each pair to check if there are multiple edges; then print the labels.
for (u, v), labels in directed_pairs.items():
    if len(labels) > 1:
        print(f"Directed pair ({u}, {v}) has {len(labels)} edges with labels: {labels}")

Directed pair (p_111p_231p_321p_51p_81p_91, p_121p_231p_321p_51p_81p_91) has 2 edges with labels: [('skip_3', 'None'), ('c5f7565e-94cf-4fa9-b0bc-879dbc339397', 'add_dressing')]
Directed pair (p_111p_231p_321p_51p_81p_91, p_111p_231p_331p_51p_81p_91) has 3 edges with labels: [('4b26f01c-29f2-499b-a0fe-83320723afb4', 'mix_dressing'), ('ef0f8f7b-273c-4e5f-a1b5-e5ed6d8d0222', 'mix_ingredients'), ('79d5922e-d2cb-4f6e-9c66-7b2a1022e16c', 'place_tomato_into_bowl')]
Directed pair (p_111p_221p_281p_51p_71p_91, p_121p_221p_281p_51p_71p_91) has 2 edges with labels: [('skip_3', 'None'), ('c5f7565e-94cf-4fa9-b0bc-879dbc339397', 'add_dressing')]
Directed pair (p_101p_111p_221p_261p_61p_71, p_101p_121p_221p_261p_61p_71) has 2 edges with labels: [('c5f7565e-94cf-4fa9-b0bc-879dbc339397', 'add_dressing'), ('skip_3', 'None')]
Directed pair (p_111p_211p_331p_51p_81p_91, p_121p_211p_331p_51p_81p_91) has 2 edges with labels: [('skip_3', 'None'), ('c5f7565e-94cf-4fa9-b0bc-879dbc339397', 'add_dressing')]
Dire

In [24]:
nodes = sorted(rg_nx.nodes())
num_transitions = len(transition_names)
num_nodes = len(nodes)
transition_matrix = np.zeros((num_transitions, num_nodes, num_nodes), dtype=int)

for edge in rg_nx.edges(data=True):
    from_node = nodes.index(edge[0])
    to_node = nodes.index(edge[1])
    transition_name = edge[2]['label'][1]
    if transition_name in transition_name_index:
        transition_idx = transition_name_index[transition_name]
        transition_matrix[transition_idx, from_node, to_node] = 1
    else:
        raise RuntimeError(f"somehow, transition: {transition_name} was encountered but not indexed")

In [31]:
for i in range(num_nodes):
    for j in range(num_nodes):
        count = 0
        for k in range(num_transitions):
            if transition_matrix[k, i, j] == 1:
                count += 1
            if count > 3:
                print(f"Transition {k} from {i} to {j}")

Transition 15 from 16 to 24
Transition 16 from 16 to 24
Transition 17 from 16 to 24
Transition 18 from 16 to 24
Transition 19 from 16 to 24
Transition 20 from 16 to 24
Transition 15 from 17 to 25
Transition 16 from 17 to 25
Transition 17 from 17 to 25
Transition 18 from 17 to 25
Transition 19 from 17 to 25
Transition 20 from 17 to 25
Transition 15 from 18 to 26
Transition 16 from 18 to 26
Transition 17 from 18 to 26
Transition 18 from 18 to 26
Transition 19 from 18 to 26
Transition 20 from 18 to 26
Transition 15 from 19 to 27
Transition 16 from 19 to 27
Transition 17 from 19 to 27
Transition 18 from 19 to 27
Transition 19 from 19 to 27
Transition 20 from 19 to 27
Transition 15 from 172 to 180
Transition 16 from 172 to 180
Transition 17 from 172 to 180
Transition 18 from 172 to 180
Transition 19 from 172 to 180
Transition 20 from 172 to 180
Transition 15 from 173 to 181
Transition 16 from 173 to 181
Transition 17 from 173 to 181
Transition 18 from 173 to 181
Transition 19 from 173 to 18