In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
from utils.Config import Config

with open("config.json", "r") as f:
    cfg_json = json.load(f)
    cfg = Config(**cfg_json)

In [3]:
import pickle as pkl
from dataset.dataset import SaladsDataset

with open("../data/pickles/breakfast_unified.pkl", "rb") as f:
    dataset = pkl.load(f)

salads_dataset = SaladsDataset(dataset['target'], dataset['stochastic'])

In [4]:
from sklearn.model_selection import train_test_split

train_dataset, test_dataset = train_test_split(salads_dataset, train_size=0.7, shuffle=True, random_state=42)

In [5]:
from utils.pm_utils import discover_dk_process, remove_duplicates_dataset
from utils.graph_utils import prepare_process_model_for_gnn

dk_process_model, dk_init_marking, dk_final_marking = discover_dk_process(train_dataset, cfg, preprocess=remove_duplicates_dataset)

In [6]:
pm_nx_data = prepare_process_model_for_gnn(dk_process_model, dk_init_marking, dk_final_marking, cfg)

  data_dict[key] = torch.as_tensor(value)


In [7]:
from utils.graph_utils import get_process_model_petri_net_transition_matrix
import torch

rg_nx, rg_transition_matrix = get_process_model_petri_net_transition_matrix(dk_process_model,
                                                                                        dk_init_marking,
                                                                                        dk_final_marking)
rg_transition_matrix = torch.tensor(rg_transition_matrix, device=cfg.device).unsqueeze(0).float()

In [14]:
from modules.GraphEncoder import GraphEncoder

g_enc = GraphEncoder(pm_nx_data.num_nodes, 128, 128, 128, pooling=None, conv_type="gin").to(cfg.device).float()
g_enc_add = GraphEncoder(pm_nx_data.num_nodes, 128, 128, 128).to(cfg.device).float()

In [15]:
pm_nx_data = pm_nx_data.to(cfg.device)

In [16]:
out_g = g_enc(pm_nx_data)

In [17]:
out_g_add = g_enc_add(pm_nx_data)

In [18]:
out_g.shape

torch.Size([248, 128])

In [17]:
from modules.GPSGraphEncoder import GPSGraphEncoder

gps_enc = GPSGraphEncoder(pm_nx_data.num_nodes).to(cfg.device)

In [18]:
gps_enc(pm_nx_data)

tensor([[ 1.1455e-01, -2.3044e-01,  7.7353e-02,  1.5637e-01, -3.8996e-02,
          2.3642e-01,  1.8649e-01, -2.4842e-01,  2.7754e-01, -2.4390e-01,
          2.5757e-02, -2.3198e-01, -1.1258e-01,  1.3795e-01, -3.1045e-02,
          2.1474e-02,  3.7689e-03, -2.3639e-01,  2.1930e-01, -3.2777e-02,
          1.4851e-01,  3.0859e-01,  5.6461e-02, -1.6508e-01,  2.8071e-01,
         -3.2719e-01,  4.0509e-03, -2.7674e-01,  3.3085e-01, -7.0256e-02,
          3.0164e-01, -1.2166e-01,  7.3045e-02, -3.4861e-02, -6.5103e-02,
         -1.3883e-01,  5.1055e-01, -4.5663e-01, -1.1720e-01,  1.9582e-02,
         -1.0648e-01,  1.6981e-01, -5.0277e-02,  2.4448e-02,  8.5197e-02,
         -1.2556e-01, -1.7314e-01, -7.2981e-02,  3.5583e-02,  4.9703e-02,
          1.7977e-01, -1.4681e-01,  1.2579e-01,  2.7148e-02, -7.6920e-02,
         -1.6628e-01,  3.1316e-01, -1.0636e-01, -2.2127e-01,  3.1380e-01,
          2.9516e-02, -1.3896e-01,  7.5745e-02, -2.2781e-01,  3.4043e-01,
         -4.2164e-02, -3.1808e-01, -1.

In [12]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=8
    )

dummy_x, dummy_y = (w.permute(0, 2, 1).to(cfg.device).float() for w in next(iter(train_loader)))

In [13]:
from ddpm.ddpm_multinomial import Diffusion

diffuser = Diffusion(noise_steps=cfg.num_timesteps, device=cfg.device)

In [19]:
from denoisers.ConditionalUnetGraphDenoiser import ConditionalUnetGraphDenoiser

graph_denoiser_atn = ConditionalUnetGraphDenoiser(in_ch=cfg.num_classes, out_ch=cfg.num_classes, max_input_dim=salads_dataset.sequence_length, graph_data=pm_nx_data, num_nodes=pm_nx_data.num_nodes, embedding_dim=128, hidden_dim=128, pooling=None).to(cfg.device)

In [20]:
graph_denoiser = ConditionalUnetGraphDenoiser(in_ch=cfg.num_classes, out_ch=cfg.num_classes, max_input_dim=salads_dataset.sequence_length, graph_data=pm_nx_data, num_nodes=pm_nx_data.num_nodes, embedding_dim=128, hidden_dim=128, pooling='mean').to(cfg.device)

In [21]:
from denoisers.ConditionalUnetMatrixDenoiser import ConditionalUnetMatrixDenoiser

matrix_denoiser = ConditionalUnetMatrixDenoiser(in_ch=cfg.num_classes, out_ch=cfg.num_classes,
                                                 max_input_dim=salads_dataset.sequence_length,
                                                 transition_dim=rg_transition_matrix.shape[-1],
                                                 gamma=cfg.gamma,
                                                 matrix_out_channels=rg_transition_matrix.shape[0],
                                                 device=cfg.device).to(cfg.device).float()

In [22]:
sum(p.numel() for p in matrix_denoiser.parameters() if p.requires_grad)

46505010

In [23]:
sum(p.numel() for p in graph_denoiser.parameters() if p.requires_grad)

18676465

In [24]:
sum(p.numel() for p in graph_denoiser_atn.parameters() if p.requires_grad)

20758769

In [27]:
t = diffuser.sample_timesteps(dummy_x.shape[0]).to(cfg.device)

In [28]:
graph_denoiser(dummy_x, t, None, None)[0].shape

RuntimeError: shape '[3, 64, -1]' is invalid for input of size 4672