In [None]:
# Importing dependencies
%matplotlib inline

from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import pandas as pd
import torch
from sklearn import metrics
from tqdm import tqdm
import assignment3_utils as utils
from placeholder import PlaceHolder, to_dense, er_validation_step
import placeholder
import importlib
import torch.nn as nn
from torch_geometric.nn.models import MLP, GIN, GAT
from torch_geometric.data import Batch, Data
from torch_geometric.utils import dense_to_sparse
from itertools import combinations
from assignment3_utils import EigenFeatures
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool as gap
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_batch
import random
from transformer_model import GraphTransformer
from assignment3_utils import NodeCycleFeatures
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import hmean
import pickle
from scipy.stats import ks_2samp


# Seed the random number generators through the course helper so that a fresh
# "Restart & Run All" starts from the same state.
utils.seed_everything(0)

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
# Change the index in "cuda:0" to target a different GPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

## *Loading total Dataset* ##


In [None]:
# Load the complete dataset from the local .npy file (path is relative to the notebook).
# NOTE(review): presumably a stack of dense adjacency matrices, one per graph —
# it is later indexed along axis 0 and each entry is passed to nx.from_numpy_matrix; confirm.
total_dataset_np = np.load("./qm9_skeleton_9_nodes_345.npy")

## *Utils for Forward Process* ##

In [None]:
# Number of diffusion steps and the noise schedule derived from it
# (schedule helpers live in assignment3_utils).
T = 1000
all_betas = utils.get_betas(timesteps=T, s=0.008)
all_alphas = utils.get_alphas(timesteps=T)
all_alphas_bar = utils.get_alphas_bar(all_alphas)

# Categorical distribution over graph sizes: entry k is the probability of
# sampling a graph with k nodes. All the mass sits on 9 nodes.
n_nodes_dist = torch.zeros(10, dtype=torch.float64)
n_nodes_dist[9] = 1.0

def get_Q_t(betas: torch.Tensor, noise_dist: torch.Tensor, e_class=2) -> torch.Tensor:
    """
    Build the one-step edge transition matrices Q_t = (1 - beta_t * sum(m)) * I + beta_t * 1 m^T.

    Input:
        - betas, (bs, ) -> one beta per sample (any shape with bs elements is flattened)
        - noise_dist, (e_class, ) or (e_class, 1) -> the noise distribution m at step $t=1,000$,
          here we take by default (0.5, 0.5)
        - e_class, number of classes for edges
    Output:
        - q_e, (bs, e_class, e_class), default floating dtype, on the device of `betas`;
          q_e[b, i, j] is the probability of moving from class i to class j
    Raises:
        - ValueError if `noise_dist` does not contain exactly `e_class` entries
        - AssertionError if a row of the result does not sum to 1 (e.g. `noise_dist` is not normalised)
    """
    dtype = torch.get_default_dtype()
    betas = betas.reshape(-1).to(dtype)
    noise = noise_dist.reshape(-1).to(device=betas.device, dtype=dtype)
    if noise.numel() != e_class:
        raise ValueError(f"noise_dist must have {e_class} entries, got {noise.numel()}")

    identity = torch.eye(e_class, dtype=dtype, device=betas.device)
    # Probability mass that stays on the current class, one scalar per sample.
    stay = (1 - betas * noise.sum()).view(-1, 1, 1)
    # Every row of the noise part equals beta_t * m; broadcasting replaces the per-sample loop.
    q_e = stay * identity + betas.view(-1, 1, 1) * noise.view(1, 1, e_class)

    assert ((q_e.sum(dim=2) - 1.0).abs() < 1e-4).all()  # ensure each row of q_e represents a distribution
    return q_e

def get_Q_t_bar(alphas_bar: torch.Tensor, noise_dist: torch.Tensor, e_class: int =2) -> torch.Tensor:
    """
    Build the cumulative edge transition matrices Q_bar_t = alpha_bar_t * I + (1 - alpha_bar_t) * 1 m^T.

    Input:
        - alphas_bar, (bs, ) -> one alpha_bar per sample (any shape with bs elements is flattened)
        - noise_dist, (e_class, ) or (e_class, 1) -> the noise distribution m at step $t=1,000$,
          here we take by default (0.5, 0.5)
        - e_class, number of classes for edges
    Output:
        - q_e, (bs, e_class, e_class), default floating dtype, on the device of `alphas_bar`;
          q_e[b, i, j] is the probability of ending in class j when starting from class i
    Raises:
        - ValueError if `noise_dist` does not contain exactly `e_class` entries
        - AssertionError if a row of the result does not sum to 1 (e.g. `noise_dist` is not normalised)
    """
    dtype = torch.get_default_dtype()
    alphas_bar = alphas_bar.reshape(-1).to(dtype)
    noise = noise_dist.reshape(-1).to(device=alphas_bar.device, dtype=dtype)
    if noise.numel() != e_class:
        raise ValueError(f"noise_dist must have {e_class} entries, got {noise.numel()}")

    identity = torch.eye(e_class, dtype=dtype, device=alphas_bar.device)
    keep = alphas_bar.view(-1, 1, 1)
    # Every row of the noise part equals (1 - alpha_bar_t) * m; broadcasting replaces the per-sample loop.
    q_e = keep * identity + (1 - keep) * noise.view(1, 1, e_class)

    assert ((q_e.sum(dim=2) - 1.0).abs() < 1e-4).all()  # ensure each row of q_e represents a distribution
    return q_e

def corrupt_edges(E: torch.Tensor, t_int: torch.Tensor, noise_dist: torch.Tensor, node_mask: torch.Tensor):
    """
    Sample noisy edges E_t ~ q(E_t | E_0) for every graph of the batch.

    Input:
        - E, (bs, n, n, de) -> clean edges; assumed one-hot over the de edge classes (TODO confirm with caller)
        - t_int, (bs, 1) -> integer timestep of each graph, used to index the global `all_alphas_bar`
        - noise_dist, (de, ) -> limit noise distribution
        - node_mask, (bs, n) bool -> True for real nodes
    Output:
        - E_t, (bs, n, n, de) float one-hot, symmetric in the two node axes

    Entries that touch a padded node and the diagonal are sampled from a uniform
    distribution; they are expected to be masked out by the caller.
    """
    bs = E.size(0)
    n = E.size(1)  # number of nodes
    de = len(noise_dist)  # number of edge classes, i.e., 2

    idx = t_int.to('cpu').squeeze(-1)
    Qt_bar = get_Q_t_bar(all_alphas_bar[idx], noise_dist, de)  # (bs, de, de)

    # Batched E_0 @ Q_bar_t: the (de, de) matrix of each graph is broadcast over its n x n edges.
    probE = torch.matmul(E.to(Qt_bar.dtype), Qt_bar.unsqueeze(1))  # (bs, n, n, de)

    # Padded entries and the diagonal get a dummy uniform distribution so that
    # torch.multinomial always receives valid rows.
    inverse_edge_mask = ~(node_mask.unsqueeze(1) & node_mask.unsqueeze(2))
    probE[inverse_edge_mask] = 1 / de
    diag_idx = torch.arange(n)
    probE[:, diag_idx, diag_idx] = 1 / de

    E_t = torch.multinomial(probE.reshape(bs * n * n, de), 1).reshape(bs, n, n)

    # Keep the upper triangle (diagonal included) and mirror it to make the sample symmetric.
    # The batch axis is addressed explicitly: a bare .squeeze() would also drop it when bs == 1.
    upper_triangle = torch.triu(E_t, diagonal=0)
    diagonal = torch.diag_embed(upper_triangle.diagonal(dim1=-2, dim2=-1))
    E_t = upper_triangle + upper_triangle.transpose(1, 2) - diagonal

    return torch.nn.functional.one_hot(E_t, de).float()

def apply_noise(holder, T, noise_dist, node_mask):
    """
    Draw one timestep per graph and return the correspondingly noised batch.

    Input:
        - holder, PlaceHolder with clean edges in `holder.E`
        - T, largest timestep; timesteps are drawn uniformly from {1, ..., T}
        - noise_dist, limit noise distribution forwarded to `corrupt_edges`
        - node_mask, node mask forwarded to `corrupt_edges` and used to mask the result
    Output:
        - (noisy PlaceHolder masked with node_mask, t_int of shape (bs, 1))
    """
    batch_size = holder.E.size(0)
    t_int = torch.randint(1, T + 1, size=(batch_size, 1))
    noisy_edges = corrupt_edges(holder.E, t_int, noise_dist, node_mask)
    noisy_holder = PlaceHolder(X=holder.X, E=noisy_edges, y=holder.y)
    return noisy_holder.mask(node_mask), t_int

def seed_torch(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and switch cuDNN to deterministic mode."""
    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current GPU, and every GPU
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Deterministic cuDNN kernels, no auto-tuning
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


## *Utils for Denoiser Model (graph transformer is in transformer_model.py)* ##


In [None]:
def collapse_placeholder(holder):
    """Return a new PlaceHolder whose E holds the argmax over the last axis of `holder.E`; X and y are passed through."""
    edge_labels = holder.E.argmax(dim=-1)
    return PlaceHolder(X=holder.X, E=edge_labels, y=holder.y)

def holder_to_data_batch(holder, node_mask):
    """
    Convert a dense, already masked PlaceHolder into a torch_geometric Batch.

    Input:
        - holder, PlaceHolder with X (bs, n, dx), E (bs, n, n, de) and y (bs, dy) — WARNING: must already be masked
        - node_mask, (bs, n) bool; the real nodes of every graph are assumed to come first
          (the adjacency is cropped with `E[:n_nodes, :n_nodes]`)
    Output:
        - Batch with one Data object per graph (x, edge_index, edge_attr, y)
    """
    data_list = []
    holder = collapse_placeholder(holder)

    for graph_idx in range(holder.X.size(0)):
        this_node_mask = node_mask[graph_idx]
        n_nodes = int(this_node_mask.sum())
        X = holder.X[graph_idx].squeeze()
        X = X[this_node_mask]
        E = holder.E[graph_idx]
        edge_index, edge_attr = dense_to_sparse(adj=E[:n_nodes, :n_nodes])
        y = holder.y[graph_idx]
        data = Data(x=X, edge_index=edge_index, edge_attr=edge_attr, y=y)
        data.validate(raise_on_error=True)
        data_list.append(data)

    # Collate once, after the loop: building the Batch inside the loop redid the
    # work for every graph and left the result undefined for an empty holder.
    return Batch.from_data_list(data_list)

class ER_prob_model():
    """Erdos-Renyi baseline: a categorical distribution over graph sizes plus a single edge probability."""

    def __init__(self) -> None:
        pass

    def fit(self, dataset):
        """Estimate `self.p` and `self.n_dist` from a list of graphs and print both."""
        self.p = self.estimate_p(dataset)
        self.n_dist = self.estimate_n_cat_dist(dataset)
        print(f"Estimated p: {self.p}")
        print(f"Estimated n categorical distribution: {self.n_dist}")

    def estimate_n_cat_dist(self, dataset: List[nx.Graph]) -> Dict[int, float]:
        """Return a dictionary mapping each observed number of nodes to its relative frequency in the dataset."""
        counts = {}
        for graph in dataset:
            size = graph.number_of_nodes()
            counts[size] = counts.get(size, 0) + 1
        n_graphs = len(dataset)
        return {size: count / n_graphs for size, count in counts.items()}

    def estimate_p(self, dataset: List[nx.Graph]) -> float:
        """Return the edge probability: total number of edges divided by total number of node pairs."""
        total_edges = 0
        total_pairs = 0
        for graph in dataset:
            n = graph.number_of_nodes()
            total_edges += graph.number_of_edges()
            total_pairs += n * (n - 1) / 2
        return total_edges / total_pairs

    def sample_ER_graphs(self, n_graphs: int) -> List[nx.Graph]:
        """Sample n_graphs ER graphs using the fitted parameters `self.n_dist` and `self.p` (see `fit`)."""
        sizes = list(self.n_dist.keys())
        size_probs = list(self.n_dist.values())
        samples = []
        for _ in range(n_graphs):
            n = np.random.choice(sizes, p=size_probs)
            graph = nx.Graph()
            graph.add_nodes_from(range(n))
            # One Bernoulli draw per unordered node pair, visited in (i, j), i < j order.
            for i, j in combinations(range(n), 2):
                if np.random.rand() < self.p:
                    graph.add_edge(i, j)
            samples.append(graph)
        return samples
    
def generate_ER_graph(n: int, p_er: float) -> nx.Graph:
    """Sample one Erdos-Renyi graph with nodes 0..n-1 where each node pair is linked with probability p_er."""
    er_graph = nx.Graph()
    er_graph.add_nodes_from(range(n))
    # One Bernoulli draw per unordered node pair, visited in (u, v), u < v order.
    for u, v in combinations(range(n), 2):
        if np.random.rand() < p_er:
            er_graph.add_edge(u, v)
    return er_graph

class SimpleModel(nn.Module):
    """
    Edge denoiser: a GIN encodes the nodes of the noisy graph and an MLP scores every node pair.

    Args:
        input_dims: number of node features fed to the GNN
        num_GNN_layers: number of GIN layers
        hidden_dims: width of the GNN layers and of the node embeddings
        num_MLP_layers: number of MLP layers
        hidden_MLP_dims: width of the hidden MLP layers
        output_dims: number of edge classes predicted per node pair
    """

    def __init__(self,
                 input_dims: int,
                 num_GNN_layers: int,
                 hidden_dims: int,
                 num_MLP_layers: int,
                 hidden_MLP_dims: int,
                 output_dims: int):
        super().__init__()
        self.gnn = GIN(
            in_channels=input_dims,
            out_channels=hidden_dims,
            num_layers=num_GNN_layers,
            hidden_channels=hidden_dims,
            train_eps=False
        )
        self.mlp = MLP(
            in_channels=hidden_dims,
            hidden_channels=hidden_MLP_dims,
            out_channels=output_dims,
            num_layers=num_MLP_layers,
            activation_layer=torch.nn.ReLU()
        )

    def forward(self, holder: PlaceHolder, node_mask: torch.Tensor) -> PlaceHolder:
        """
        Predict edge logits for a masked, noisy batch.

        Input:
            - holder, masked PlaceHolder (X: (bs, n, dx), E: (bs, n, n, de), y: (bs, dy))
            - node_mask, (bs, n) bool, real nodes first
        Output:
            - PlaceHolder with X copied from the input, E = symmetrised logits (bs, n, n, output_dims), y = None
        """
        max_n_nodes = holder.E.shape[1]
        bs = holder.X.size(0)
        data = holder_to_data_batch(holder, node_mask)

        # NOTE: the timestep stored in data.y is deliberately not fed to the GNN.
        # The previous code appended it as an extra column and sliced it off again
        # before the GNN call, so dropping that round trip keeps the inputs identical.
        x, edge_index = data.x, data.edge_index
        if x.dim() == 1:
            x = x.unsqueeze(-1)  # single node feature, i.e. no eigenfeatures
        node_embeddings = self.gnn(x, edge_index)

        # Scatter the flat (total_nodes, d) embeddings back into a padded (bs, n, d) tensor.
        # new_zeros keeps the buffer on the same device/dtype as the embeddings.
        padded_node_embeddings = node_embeddings.new_zeros(bs, max_n_nodes, node_embeddings.size(-1))
        offset = 0
        for graph_idx, n_nodes in enumerate(node_mask.sum(dim=1).tolist()):
            padded_node_embeddings[graph_idx, :n_nodes] = node_embeddings[offset:offset + n_nodes]
            offset += n_nodes

        # Pair feature of (i, j) is the element-wise product of the two node embeddings.
        pair_features = padded_node_embeddings.unsqueeze(2) * padded_node_embeddings.unsqueeze(1)
        output_E = self.mlp(pair_features.reshape(-1, pair_features.size(-1)))
        # Infer the class axis from the MLP instead of hard-coding 2 edge classes.
        output_E = output_E.view(bs, max_n_nodes, max_n_nodes, -1)
        output_E = utils.symmetrize(output_E)

        return PlaceHolder(X=holder.X, E=output_E, y=None).mask(node_mask)
    
class SimpleModelGCN(nn.Module):
    """
    Edge denoiser: a stack of GCNConv layers encodes the nodes of the noisy graph and an MLP scores every node pair.

    Args:
        input_dims: number of node features fed to the first GCN layer
        num_GNN_layers: number of GCNConv layers
        hidden_dims: width of the GCN layers and of the node embeddings
        num_MLP_layers: number of MLP layers
        hidden_MLP_dims: width of the hidden MLP layers
        output_dims: number of edge classes predicted per node pair
    """

    def __init__(self,
                 input_dims: int,
                 num_GNN_layers: int,
                 hidden_dims: int,
                 num_MLP_layers: int,
                 hidden_MLP_dims: int,
                 output_dims: int):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(input_dims, hidden_dims))
        for _ in range(num_GNN_layers - 1):
            self.convs.append(GCNConv(hidden_dims, hidden_dims))
        self.mlp = MLP(
            in_channels=hidden_dims,
            hidden_channels=hidden_MLP_dims,
            out_channels=output_dims,
            num_layers=num_MLP_layers,
            activation_layer=torch.nn.ReLU()
        )

    def forward(self, holder: PlaceHolder, node_mask: torch.Tensor) -> PlaceHolder:
        """
        Predict edge logits for a masked, noisy batch.

        Input:
            - holder, masked PlaceHolder (X: (bs, n, dx), E: (bs, n, n, de), y: (bs, dy))
            - node_mask, (bs, n) bool, real nodes first
        Output:
            - PlaceHolder with X copied from the input, E = symmetrised logits (bs, n, n, output_dims), y = None
        """
        max_n_nodes = holder.E.shape[1]
        bs = holder.X.size(0)
        data = holder_to_data_batch(holder, node_mask)

        # NOTE: the timestep stored in data.y is deliberately not fed to the GNN.
        # The previous code appended it as an extra column and sliced it off again
        # before the first convolution, so dropping that round trip keeps the inputs identical.
        x, edge_index = data.x, data.edge_index
        if x.dim() == 1:
            x = x.unsqueeze(-1)  # single node feature, i.e. no eigenfeatures

        node_embeddings = x
        for conv in self.convs:
            node_embeddings = conv(node_embeddings, edge_index).relu()

        # Scatter the flat (total_nodes, d) embeddings back into a padded (bs, n, d) tensor.
        # new_zeros keeps the buffer on the same device/dtype as the embeddings.
        padded_node_embeddings = node_embeddings.new_zeros(bs, max_n_nodes, node_embeddings.size(-1))
        offset = 0
        for graph_idx, n_nodes in enumerate(node_mask.sum(dim=1).tolist()):
            padded_node_embeddings[graph_idx, :n_nodes] = node_embeddings[offset:offset + n_nodes]
            offset += n_nodes

        # Pair feature of (i, j) is the element-wise product of the two node embeddings.
        pair_features = padded_node_embeddings.unsqueeze(2) * padded_node_embeddings.unsqueeze(1)
        output_E = self.mlp(pair_features.reshape(-1, pair_features.size(-1)))
        # Infer the class axis from the MLP instead of hard-coding 2 edge classes.
        output_E = output_E.view(bs, max_n_nodes, max_n_nodes, -1)
        output_E = utils.symmetrize(output_E)

        return PlaceHolder(X=holder.X, E=output_E, y=None).mask(node_mask)

class SimpleModelGAT(nn.Module):
    """
    Edge denoiser: a stack of GATConv layers encodes the nodes of the noisy graph and an MLP scores every node pair.

    Args:
        input_dims: number of node features fed to the first GAT layer
        num_GNN_layers: number of GATConv layers
        hidden_dims: per-head width of the GAT layers (node embeddings have hidden_dims * heads channels)
        num_MLP_layers: number of MLP layers
        hidden_MLP_dims: width of the hidden MLP layers
        output_dims: number of edge classes predicted per node pair
        heads: number of attention heads of every GAT layer
    """

    def __init__(self,
                 input_dims: int,
                 num_GNN_layers: int,
                 hidden_dims: int,
                 num_MLP_layers: int,
                 hidden_MLP_dims: int,
                 output_dims: int,
                 heads: int = 1):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GATConv(input_dims, hidden_dims, heads=heads))
        for _ in range(num_GNN_layers - 1):
            self.convs.append(GATConv(hidden_dims * heads, hidden_dims, heads=heads))
        self.mlp = MLP(
            in_channels=hidden_dims * heads,
            hidden_channels=hidden_MLP_dims,
            out_channels=output_dims,
            num_layers=num_MLP_layers,
            activation_layer=torch.nn.ReLU()
        )

    def forward(self, holder: PlaceHolder, node_mask: torch.Tensor) -> PlaceHolder:
        """
        Predict edge logits for a masked, noisy batch.

        Input:
            - holder, masked PlaceHolder (X: (bs, n, dx), E: (bs, n, n, de), y: (bs, dy))
            - node_mask, (bs, n) bool, real nodes first
        Output:
            - PlaceHolder with X copied from the input, E = symmetrised logits (bs, n, n, output_dims), y = None
        """
        max_n_nodes = holder.E.shape[1]
        bs = holder.X.size(0)
        data = holder_to_data_batch(holder, node_mask)

        # NOTE: the timestep stored in data.y is deliberately not fed to the GNN.
        # The previous code appended it as an extra column and sliced it off again
        # before the first convolution, so dropping that round trip keeps the inputs identical.
        x, edge_index = data.x, data.edge_index
        if x.dim() == 1:
            x = x.unsqueeze(-1)  # single node feature, i.e. no eigenfeatures

        node_embeddings = x
        for conv in self.convs:
            node_embeddings = conv(node_embeddings, edge_index).relu()

        # Scatter the flat (total_nodes, d) embeddings back into a padded (bs, n, d) tensor.
        # new_zeros keeps the buffer on the same device/dtype as the embeddings.
        padded_node_embeddings = node_embeddings.new_zeros(bs, max_n_nodes, node_embeddings.size(-1))
        offset = 0
        for graph_idx, n_nodes in enumerate(node_mask.sum(dim=1).tolist()):
            padded_node_embeddings[graph_idx, :n_nodes] = node_embeddings[offset:offset + n_nodes]
            offset += n_nodes

        # Pair feature of (i, j) is the element-wise product of the two node embeddings.
        pair_features = padded_node_embeddings.unsqueeze(2) * padded_node_embeddings.unsqueeze(1)
        output_E = self.mlp(pair_features.reshape(-1, pair_features.size(-1)))
        # Infer the class axis from the MLP instead of hard-coding 2 edge classes.
        output_E = output_E.view(bs, max_n_nodes, max_n_nodes, -1)
        output_E = utils.symmetrize(output_E)

        return PlaceHolder(X=holder.X, E=output_E, y=None).mask(node_mask)



## *Utils for Training* ##

In [None]:
def train_model(seed, model, train_dataloader, val_dataloader, n_epochs, noise_dist, T, eigen_feats,lr=2e-4):
    """
    Train the denoiser with Adam on a cross-entropy loss and plot the learning curves.

    Input:
        - seed, forwarded to `seed_torch` before training
        - model, denoiser called as `model(noise_holder, node_mask)`; expected to already live on `device`
        - train_dataloader / val_dataloader, loaders whose batches are accepted by `to_dense`
        - n_epochs, number of passes over train_dataloader
        - noise_dist, limit noise distribution of the forward process
        - T, number of diffusion steps
        - eigen_feats, callable computing extra features from (E, node_mask), or None to disable
          both the eigen and the cycle features
        - lr, Adam learning rate
    Output:
        - (model, loss_list, val_loss_list, er_train_loss, er_val_loss), the two lists holding one value per epoch

    NOTE(review): the baseline is evaluated with the notebook-level global `er_model`,
    which must be defined (and fitted) before this function is called.
    """
    # Training parameters
    seed_torch(seed)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    val_loss_list = []
    loss_list = []
    epoch_loss = []  # stays empty when n_epochs == 0
    val_metric = torch.nn.CrossEntropyLoss()
    train_loss = torch.nn.CrossEntropyLoss()

    # Iterate over the batches
    for epoch_idx in tqdm(range(n_epochs)):
        model.train()
        epoch_loss = []
        for batch in train_dataloader:

            # Access the batch data
            unmasked_holder, node_mask = to_dense(batch)
            noise_holder, t_int = apply_noise(unmasked_holder, T, noise_dist, node_mask)
            target_holder = unmasked_holder.mask(node_mask)
            # Prepare data for inference
            noise_holder.y = t_int.float() / T
            noise_holder = noise_holder.to(device)
            node_mask = node_mask.to(device)
            target_holder = target_holder.to(device)

            if eigen_feats is not None:
                add_feats = eigen_feats(noise_holder.E, node_mask)
                cycle_features = NodeCycleFeatures()
                x_cycles, y_cycles = cycle_features(noise_holder,node_mask)

                # Node-level extras are appended to X, graph-level extras to y.
                noise_holder.X = torch.cat((noise_holder.X, x_cycles,add_feats[2], add_feats[3]), -1)
                noise_holder.y = torch.cat((noise_holder.y, y_cycles,add_feats[0], add_feats[1]), -1)

            pred = model(noise_holder, node_mask)
            E_true = target_holder.E.reshape(-1, target_holder.E.shape[-1])
            E_pred = pred.E.reshape(-1, pred.E.shape[-1])
            loss = train_loss(E_pred, E_true)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss = loss.detach().cpu().item()
            epoch_loss.append(loss)

        loss_list.append(np.mean(epoch_loss))
        val_loss = validation_step(val_dataloader, val_metric, model, noise_dist, T, eigen_feats)

        val_loss_list.append(val_loss)

    # Baseline
    er_train_loss = er_validation_step(train_dataloader,val_metric, er_model)
    er_train_loss_arr = er_train_loss * np.ones(n_epochs)
    er_val_loss = er_validation_step(val_dataloader,val_metric, er_model)
    er_val_loss_arr = er_val_loss * np.ones(n_epochs)
    if epoch_loss:
        # loss of the last batch of the last epoch
        print(epoch_loss[-1])

    plt.plot(loss_list, label="train_loss")
    # Trailing moving average. The window is capped at the number of epochs:
    # with mode='valid' a kernel longer than the signal yields rescaled partial sums, not averages.
    window_size = min(10, len(loss_list))
    if window_size > 0:
        smooth_loss = np.convolve(loss_list, np.ones(window_size)/window_size, mode='valid')
        # Value k averages epochs k-window_size+1..k, so it is drawn at epoch k.
        plt.plot(np.arange(window_size - 1, len(loss_list)), smooth_loss, label="smooth train")
    plt.plot(val_loss_list, label="val_loss")
    plt.plot(er_train_loss_arr, label="er_train_loss")
    plt.plot(er_val_loss_arr, label="er_val_loss")
    plt.xlabel("epoch")
    plt.ylabel("cross-entropy loss")
    plt.legend()
    plt.show()

    return model, loss_list, val_loss_list, er_train_loss, er_val_loss

def validation_step(val_dataloader, val_metric, model, noise_dist, T, eigen_feats):
    """
    Compute the mean validation loss of the denoiser over freshly noised validation batches.

    Input:
        - val_dataloader, loader whose batches are accepted by `to_dense`
        - val_metric, loss called as `val_metric(E_pred, E_true)` on flattened edges
        - model, denoiser called as `model(noise_holder, node_mask)`; it is switched to eval mode and left there
        - noise_dist, limit noise distribution of the forward process
        - T, number of diffusion steps
        - eigen_feats, callable computing extra features from (E, node_mask), or None to disable
          both the eigen and the cycle features
    Output:
        - mean of the per-batch losses (unweighted by batch size)
    """
    model.eval()

    # Get val loss
    loss_list = []
    # No parameter is updated here, so skip building the autograd graph.
    with torch.no_grad():
        for batch in val_dataloader:
            # Access the batch data
            unmasked_holder, node_mask = to_dense(batch)
            noise_holder, t_int = apply_noise(unmasked_holder, T, noise_dist, node_mask)
            target_holder = unmasked_holder.mask(node_mask)

            # Prepare data for inference
            noise_holder.y = t_int.float() / T
            noise_holder = noise_holder.to(device)
            node_mask = node_mask.to(device)
            target_holder = target_holder.to(device)

            # Same extra features as in training
            if eigen_feats is not None:
                add_feats = eigen_feats(noise_holder.E, node_mask)
                cycle_features = NodeCycleFeatures()
                x_cycles, y_cycles = cycle_features(noise_holder,node_mask)
                noise_holder.X = torch.cat((noise_holder.X, x_cycles,add_feats[2], add_feats[3]), -1)
                noise_holder.y = torch.cat((noise_holder.y, y_cycles,add_feats[0], add_feats[1]), -1)

            pred = model(noise_holder, node_mask)
            E_true = target_holder.E.reshape(-1, target_holder.E.shape[-1])
            E_pred = pred.E.reshape(-1, pred.E.shape[-1])
            val_loss = val_metric(E_pred, E_true)

            loss_list.append(val_loss.detach().cpu().numpy())

    val_loss = sum(loss_list)/len(loss_list)

    return val_loss

def sample_graph(seed, model, noise_dist, T, n_graphs, num_timesteps_to_save=5, eigen_feats=None):
    """
    Generate graphs by running the reverse diffusion process from pure noise down to t = 0.

    Input:
        - seed, forwarded to `seed_torch`
        - model, trained denoiser called as `model(holder, node_mask)`
        - noise_dist, (de, ) limit noise distribution used to draw the initial edges
        - T, number of diffusion steps
        - n_graphs, number of graphs to generate
        - num_timesteps_to_save, number of evenly spaced timesteps in [0, T] whose state is kept
        - eigen_feats, callable computing extra features from (E, node_mask), or None to disable
          both the eigen and the cycle features
    Output:
        - (holder at t = 0, list of saved intermediate holders starting with the initial noise)

    Graph sizes are drawn from the notebook-level `n_nodes_dist`; the noise schedule is
    read from the notebook-level `all_betas` / `all_alphas_bar`.
    """
    seed_torch(seed)
    max_n_nodes = len(n_nodes_dist) - 1
    de = len(noise_dist)

    n_nodes = torch.multinomial(n_nodes_dist, n_graphs, replacement=True)
    # Prefix mask: the first n_nodes[i] positions of row i are real nodes.
    node_mask = torch.arange(max_n_nodes).unsqueeze(0) < n_nodes.unsqueeze(1)
    # Keep the mask on the same device as the holder, as the training loop does.
    node_mask = node_mask.to(device)

    # sample a purely noised graph at step T
    limit_dist = noise_dist.repeat(n_graphs * max_n_nodes * max_n_nodes, 1)
    limit_E = limit_dist.multinomial(1, replacement=True)
    limit_E = limit_E.reshape(n_graphs, max_n_nodes, max_n_nodes)
    limit_E = torch.nn.functional.one_hot(limit_E, de).float()
    limit_E = utils.symmetrize(limit_E)

    holder = PlaceHolder(X=torch.ones(n_graphs, max_n_nodes, 1),
                         E=limit_E,
                         y=torch.Tensor([1]).unsqueeze(-1).repeat(n_graphs, 1),).to(device).mask(node_mask)

    saving_steps = torch.linspace(0, T, num_timesteps_to_save).round()
    # NOTE(review): assumes PlaceHolder.to returns a copy; if it moves tensors in place,
    # `holder` itself ends up on the CPU here — confirm against placeholder.py.
    saved_holders = [holder.to('cpu')]  # save initial holder
    print("Saving timesteps: ", saving_steps)
    model.eval()

    with torch.no_grad():
        for t in tqdm(range(T, 0, -1)):
            steps = torch.Tensor([t]).unsqueeze(-1).repeat(n_graphs, 1).to(device)
            steps_next = steps - 1
            holder.y = steps / T

            if eigen_feats is not None:
                add_feats = eigen_feats(holder.E, node_mask)
                cycle_features = NodeCycleFeatures()
                x_cycles, y_cycles = cycle_features(holder,node_mask)
                holder.X = torch.cat((holder.X, x_cycles,add_feats[2], add_feats[3]), -1)
                holder.y = torch.cat((holder.y, y_cycles,add_feats[0], add_feats[1]), -1)

            pred_G0 = model(holder, node_mask)
            pred_E = torch.softmax(pred_G0.E, dim=-1)
            shape_E = pred_E.shape
            E_t = holder.E.to(torch.float32)              # bs, n, n, de

            # transition matrices used in the reverse process, each of size (bs, de, de)
            Qtb = get_Q_t_bar(all_alphas_bar[steps[:, 0].long().to('cpu')], noise_dist, e_class=de).to(device)
            Qsb = get_Q_t_bar(all_alphas_bar[steps_next[:, 0].long().to('cpu')], noise_dist, e_class=de).to(device)
            Qt = get_Q_t(all_betas[steps[:, 0].long().to('cpu')], noise_dist, e_class=de).to(device)

            # posterior computation
            E_t = E_t.flatten(start_dim=1, end_dim=-2).to(torch.float32)            # bs x N x dt

            Qt_T = Qt.transpose(-1, -2)                 # bs, dt, d_t-1
            left_term = E_t @ Qt_T                      # bs, N, d_t-1
            left_term = left_term.unsqueeze(dim=2)      # bs, N, 1, d_t-1

            right_term = Qsb.unsqueeze(1)               # bs, 1, d0, d_t-1
            numerator = left_term * right_term          # bs, N, d0, d_t-1

            X_t_transposed = E_t.transpose(-1, -2)      # bs, dt, N

            prod = Qtb @ X_t_transposed                 # bs, d0, N
            prod = prod.transpose(-1, -2)               # bs, N, d0
            denominator = prod.unsqueeze(-1)            # bs, N, d0, 1
            denominator[denominator == 0] = 1e-6        # avoid division by zero

            p_s_and_t_given_0_E = numerator / denominator

            # noise back: marginalise the posterior over the predicted clean edges
            pred_E = pred_E.reshape((n_graphs, -1, pred_E.shape[-1]))
            weighted_E = pred_E.unsqueeze(-1) * p_s_and_t_given_0_E        # bs, N, d0, d_t-1
            unnormalized_prob_E = weighted_E.sum(dim=-2)
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
            prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
            prob_E = prob_E.reshape(n_graphs, max_n_nodes, max_n_nodes, pred_E.shape[-1])

            # sample at timestep t-1
            E_s = prob_E.reshape(-1, de).multinomial(1, replacement=True)
            E_s = torch.nn.functional.one_hot(E_s, de).float().reshape(shape_E)
            E_s = utils.symmetrize(E_s)

            assert (E_s == torch.transpose(E_s, 1, 2)).all()
            # The new sample must have the same layout as the edges it replaces.
            assert E_s.shape == holder.E.shape
            holder = PlaceHolder(X=torch.ones(n_graphs, max_n_nodes, 1).to(device),
                                E=E_s, y=steps_next/T).to(device).mask(node_mask)
            # Save graphs
            if t-1 in saving_steps:
                saved_holders.append(holder.to('cpu'))

    return holder, saved_holders


## *Utils for Evaluation* ##


In [None]:
def get_degree_hist_array(nx_dataset: List[nx.Graph], max_degree: int =20) -> np.ndarray:
    """Return an array of shape (len(nx_dataset), max_degree) whose i-th row is the node-degree histogram of the i-th graph (unit-width bins from 0 to max_degree)."""
    bin_edges = np.arange(max_degree + 1)
    histograms = np.zeros((len(nx_dataset), max_degree))
    for row, graph in enumerate(nx_dataset):
        degrees = np.array([degree for _, degree in graph.degree()])
        histograms[row] = np.histogram(degrees, bins=bin_edges)[0]
    return histograms

def get_clustering_hist_array(nx_dataset, num_bins=10) -> np.ndarray:
    """Return an array of shape (len(nx_dataset), num_bins) whose i-th row is the histogram of node clustering coefficients of the i-th graph (equal-width bins on [0, 1])."""
    bin_edges = np.linspace(0, 1, num_bins + 1)
    histograms = np.zeros((len(nx_dataset), num_bins))
    for row, graph in enumerate(nx_dataset):
        coefficients = np.array(list(nx.clustering(graph).values()))
        histograms[row] = np.histogram(coefficients, bins=bin_edges)[0]
    return histograms

def is_valid_graph(g: nx.Graph) -> bool:
    """
    Check whether a graph is valid.

    A graph is valid when it is non-empty, has exactly one connected component,
    every node has between 1 and 4 incident edges, and it is planar.
    """
    # nx.is_connected raises on a graph without nodes; such a graph is simply not valid.
    if g.number_of_nodes() == 0:
        return False

    # Check if the graph has exactly one connected component
    if not nx.is_connected(g):
        return False

    # Check if every node has an edge number between 1 and 4
    if any(not (1 <= degree <= 4) for _, degree in g.degree()):
        return False

    # Check if the graph is planar
    is_planar, _ = nx.check_planarity(g)
    return bool(is_planar)

def count_valid_graphs(nx_dataset: List[nx.Graph]) -> int:
    """Return how many graphs of the list pass `is_valid_graph`."""
    return sum(1 for graph in nx_dataset if is_valid_graph(graph))

def graph_to_canonical_form(graph: nx.Graph) -> Tuple:
    """
    Turn a graph into a hashable tuple built from its sorted adjacency list.

    The result is a sorted tuple of (node, sorted tuple of neighbours) pairs.
    It depends on the node labels, so two graphs compare equal only when they
    have the same labelled edge set, not merely when they are isomorphic.
    """
    neighbours_by_node = nx.to_dict_of_lists(graph)
    return tuple(sorted(
        (node, tuple(sorted(neighbours)))
        for node, neighbours in neighbours_by_node.items()
    ))

def count_unique_and_novel_graphs(graph_list: List[nx.Graph], train_set: List[nx.Graph]) -> tuple:
    """Count distinct graphs in ``graph_list`` and how many of them are absent from ``train_set``.

    Graphs are compared through ``graph_to_canonical_form``, i.e. two graphs are
    considered the same when their canonical forms are equal.

    Args:
        graph_list: graphs to evaluate (typically generated samples).
        train_set: reference graphs (typically the training split).

    Returns:
        A pair ``(unique_graphs, novel_graphs)`` of ints: the number of distinct
        canonical forms in ``graph_list`` and the number of those that do not
        occur in ``train_set``.
    """
    # The return annotation is the builtin ``tuple`` so the definition does not
    # depend on ``typing.Tuple`` being imported.
    known_forms = set(map(graph_to_canonical_form, train_set))
    generated_forms = set(map(graph_to_canonical_form, graph_list))
    return len(generated_forms), len(generated_forms - known_forms)



## *Training the models and generating graphs* ##

In [None]:
# Result containers filled by the trial loop below.
# Every entry is keyed "<config>_<trial index>", e.g. "simple_gcn_eigen_3".
graphs_to_eval = {}
train_loss_list = {}
val_loss_list = {}
er_train_loss_list = {}
er_val_loss_list = {}

trial_num = 5

SIMPLE_LR = 2e-3       # learning rate of every SimpleModel / SimpleModelGCN / SimpleModelGAT run
GT_LR = 2e-4           # learning rate of every GraphTransformer run
GT_EPOCHS = 5          # GraphTransformer runs are always trained for 5 epochs
PLAIN_INPUT_DIMS = 1   # model input width without eigenfeatures
EIGEN_INPUT_DIMS = 7   # model input width when EigenFeatures('all') are used
# NOTE(review): presumably the empirical edge density of the dataset - confirm where it comes from.
MARGINAL_EDGE_PROB = 0.30574074074074076


def count_n_node(dataset: List[nx.Graph]) -> np.ndarray:
    """Output must be a numpy array of shape (max_n_nodes + 1,) containing the probability of each number of nodes in the dataset, where max_n_nodes is the highest number of nodes found for a single graph in the dataset."""
    max_n_nodes = max([g.number_of_nodes() for g in dataset])
    prob = np.zeros(max_n_nodes + 1)
    for g in dataset:
        prob[g.number_of_nodes()] += 1
    prob /= len(dataset)
    return prob


def _simple_args(input_dims):
    """Constructor kwargs shared by SimpleModel, SimpleModelGCN and SimpleModelGAT."""
    return {'num_MLP_layers': 2,
            'num_GNN_layers': 2,
            'input_dims': input_dims,
            'hidden_dims': 100,
            'hidden_MLP_dims': 100,
            'output_dims': 2}


def _gt_args(feature_dims):
    """Constructor kwargs of GraphTransformer; ``feature_dims`` is the X and y input width."""
    return {'n_layers': 4,
            'n_head': 8,
            'input_dims': {'X': feature_dims, 'E': 2, 'y': feature_dims},
            'hidden_dims': {'X': 256, 'E': 256, 'y': 256, 'dx': 64, 'de': 64, 'dy': 64},
            'output_dims': {'X': 1, 'E': 2, 'y': 1}}


def _train_and_sample(tag, model, eigen_feats, *, n_graphs, seed, train_dataloader, val_dataloader,
                      noise_dist, n_epochs, lr, set_ylim, keep_er_losses=False):
    """Train ``model``, record its learning curves, then sample graphs from it.

    Side effects: stores the curves in ``train_loss_list`` / ``val_loss_list``
    (and the ER baseline losses in ``er_*_loss_list`` when ``keep_er_losses``),
    and the sampled graphs in ``graphs_to_eval``, all under the key ``tag``.

    Args:
        tag: result key, "<config>_<trial index>".
        model: freshly constructed model, already moved to ``device``.
        eigen_feats: feature extractor handed to both training and sampling, or None.
        n_graphs: number of graphs to sample after training.
        seed, train_dataloader, val_dataloader, noise_dist, n_epochs, lr:
            forwarded unchanged to ``train_model`` / ``sample_graph``.
        set_ylim: when True, call ``plt.ylim(0.090, 0.12)`` right before training.
        keep_er_losses: when True, also keep the ER losses returned by ``train_model``.

    Returns:
        ``(trained_model, sampled_graphs, intermediate_graphs)``.
    """
    if set_ylim:
        plt.ylim(0.090, 0.12)
    model, train_loss, val_loss, er_train_loss, er_val_loss = train_model(
        seed, model, train_dataloader, val_dataloader, n_epochs, noise_dist, T, eigen_feats, lr)
    train_loss_list[tag] = train_loss
    val_loss_list[tag] = val_loss
    if keep_er_losses:
        er_train_loss_list[tag] = er_train_loss
        er_val_loss_list[tag] = er_val_loss
    graphs, intermediate_graphs = sample_graph(
        seed, model, noise_dist, T, n_graphs, num_timesteps_to_save=5, eigen_feats=eigen_feats)
    graphs_to_eval[tag] = graphs.to_nx_graph_list()
    return model, graphs, intermediate_graphs


for i in range(trial_num):
    seed = i
    utils.seed_everything(seed)
    noise_dist = torch.Tensor([0.5, 0.5])

    # Random 250-graph training split; every remaining graph is held out.
    train_indices = np.random.choice(len(total_dataset_np), 250, replace=False)
    train_dataset_np = total_dataset_np[train_indices]
    test_dateset_np = total_dataset_np[np.delete(np.arange(len(total_dataset_np)), train_indices)]
    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is the
    # replacement and is also available in NetworkX 2.x.
    nx_dataset = {}
    nx_dataset["train"] = [nx.from_numpy_array(data) for data in train_dataset_np]
    nx_dataset["test"] = [nx.from_numpy_array(data) for data in test_dateset_np]
    train_dataset = nx_dataset["train"]

    graphs_to_eval[f"train_{i}"] = train_dataset
    graphs_to_eval[f"test_{i}"] = nx_dataset["test"]
    num_graphs_to_generate = 20
    bs = 32
    train_dataloader = utils.nx_list_to_dataloader(nx_dataset["train"], bs=bs)
    # NOTE(review): the validation and the test loader are built from the same held-out graphs.
    val_dataloader = utils.nx_list_to_dataloader(nx_dataset["test"], bs=bs)
    test_dataloader = utils.nx_list_to_dataloader(nx_dataset["test"], bs=bs)

    # n_nodes_dist, max_n_nodes and er_model stay module-level names: they are not
    # passed to train_model / sample_graph, which may read them as globals - confirm.
    n_nodes_dist = torch.tensor(count_n_node(nx_dataset['train']))
    max_n_nodes = len(n_nodes_dist) - 1

    er_model = ER_prob_model()
    er_model.fit(train_dataset)
    graphs_to_eval[f"ER_{i}"] = er_model.sample_ER_graphs(num_graphs_to_generate)

    shared = dict(seed=seed, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

    # ------------------------------------------------------------------
    # Uniform noise [0.5, 0.5], without eigenfeatures
    # ------------------------------------------------------------------
    # NOTE(review): this section trains for 20 epochs while the three later
    # sections use 100 - confirm this is intended.
    n_epochs = 20
    simple_run = dict(shared, noise_dist=noise_dist, n_epochs=n_epochs, lr=SIMPLE_LR, set_ylim=True)
    gt_run = dict(shared, noise_dist=noise_dist, n_epochs=GT_EPOCHS, lr=GT_LR, set_ylim=False)

    # NOTE(review): these three models sample 20 graphs, every later model samples
    # len(test) graphs (see the reassignment below) - confirm this is intended.
    simple_model, simple_model_graphs, simple_model_intermediate_graphs = _train_and_sample(
        tag=f"simple_{i}", model=SimpleModel(**_simple_args(PLAIN_INPUT_DIMS)).to(device), eigen_feats=None,
        n_graphs=num_graphs_to_generate, keep_er_losses=True, **simple_run)
    simple_model_gcn, simple_model_gcn_graphs, simple_model_gcn_intermediate_graphs = _train_and_sample(
        tag=f"simple_gcn_{i}", model=SimpleModelGCN(**_simple_args(PLAIN_INPUT_DIMS)).to(device), eigen_feats=None,
        n_graphs=num_graphs_to_generate, **simple_run)
    simple_model_gat, simple_model_gat_graphs, simple_model_gat_intermediate_graphs = _train_and_sample(
        tag=f"simple_gat_{i}", model=SimpleModelGAT(**_simple_args(PLAIN_INPUT_DIMS)).to(device), eigen_feats=None,
        n_graphs=num_graphs_to_generate, **simple_run)

    num_graphs_to_generate = len(nx_dataset["test"])
    gt_model, gt_model_graphs, gt_model_intermediate_graphs = _train_and_sample(
        tag=f"gt_{i}", model=GraphTransformer(**_gt_args(PLAIN_INPUT_DIMS)).to(device), eigen_feats=None,
        n_graphs=num_graphs_to_generate, **gt_run)

    # ------------------------------------------------------------------
    # Uniform noise [0.5, 0.5], with eigenfeatures
    # ------------------------------------------------------------------
    noise_dist = torch.Tensor([0.5, 0.5])
    n_epochs = 100
    simple_run = dict(shared, noise_dist=noise_dist, n_epochs=n_epochs, lr=SIMPLE_LR, set_ylim=True)
    gt_run = dict(shared, noise_dist=noise_dist, n_epochs=GT_EPOCHS, lr=GT_LR, set_ylim=False)

    simple_model_eigen, simple_model_eigen_graphs, simple_model_eigen_intermediate_graphs = _train_and_sample(
        tag=f"simple_eigen_{i}", model=SimpleModel(**_simple_args(EIGEN_INPUT_DIMS)).to(device),
        eigen_feats=utils.EigenFeatures('all'), n_graphs=num_graphs_to_generate, **simple_run)
    simple_model_gcn_eigen, simple_model_gcn_eigen_graphs, simple_model_gcn_eigen_intermediate_graphs = _train_and_sample(
        tag=f"simple_gcn_eigen_{i}", model=SimpleModelGCN(**_simple_args(EIGEN_INPUT_DIMS)).to(device),
        eigen_feats=utils.EigenFeatures('all'), n_graphs=num_graphs_to_generate, **simple_run)
    simple_model_gat_eigen, simple_model_gat_eigen_graphs, simple_model_gat_eigen_intermediate_graphs = _train_and_sample(
        tag=f"simple_gat_eigen_{i}", model=SimpleModelGAT(**_simple_args(EIGEN_INPUT_DIMS)).to(device),
        eigen_feats=utils.EigenFeatures('all'), n_graphs=num_graphs_to_generate, **simple_run)
    # Keyword order mirrors the original construction order (feature extractor, then model).
    gt_pp_model, gt_pp_model_graphs, gt_pp_model_intermediate_graphs = _train_and_sample(
        tag=f"gt_eigen_{i}", eigen_feats=utils.EigenFeatures('all'),
        model=GraphTransformer(**_gt_args(EIGEN_INPUT_DIMS)).to(device),
        n_graphs=num_graphs_to_generate, **gt_run)

    # ------------------------------------------------------------------
    # Marginal noise, without eigenfeatures
    # ------------------------------------------------------------------
    noise_dist = torch.Tensor([1 - MARGINAL_EDGE_PROB, MARGINAL_EDGE_PROB])
    n_epochs = 100
    simple_run = dict(shared, noise_dist=noise_dist, n_epochs=n_epochs, lr=SIMPLE_LR, set_ylim=True)
    gt_run = dict(shared, noise_dist=noise_dist, n_epochs=GT_EPOCHS, lr=GT_LR, set_ylim=False)

    simple_model_marginal, simple_model_marginal_graphs, simple_model_marginal_intermediate_graphs = _train_and_sample(
        tag=f"simple_marginal_{i}", model=SimpleModel(**_simple_args(PLAIN_INPUT_DIMS)).to(device), eigen_feats=None,
        n_graphs=num_graphs_to_generate, **simple_run)
    simple_model_gcn_marginal, simple_model_gcn_marginal_graphs, simple_model_gcn_marginal_intermediate_graphs = _train_and_sample(
        tag=f"simple_gcn_marginal_{i}", model=SimpleModelGCN(**_simple_args(PLAIN_INPUT_DIMS)).to(device), eigen_feats=None,
        n_graphs=num_graphs_to_generate, **simple_run)
    simple_model_gat_marginal, simple_model_gat_marginal_graphs, simple_model_gat_marginal_intermediate_graphs = _train_and_sample(
        tag=f"simple_gat_marginal_{i}", model=SimpleModelGAT(**_simple_args(PLAIN_INPUT_DIMS)).to(device), eigen_feats=None,
        n_graphs=num_graphs_to_generate, **simple_run)

    num_graphs_to_generate = len(nx_dataset["test"])
    gt_model_marginal, gt_model_marginal_graphs, gt_model_marginal_intermediate_graphs = _train_and_sample(
        tag=f"gt_marginal_{i}", model=GraphTransformer(**_gt_args(PLAIN_INPUT_DIMS)).to(device), eigen_feats=None,
        n_graphs=num_graphs_to_generate, **gt_run)

    # ------------------------------------------------------------------
    # Marginal noise, with eigenfeatures (no plt.ylim call in this section)
    # ------------------------------------------------------------------
    noise_dist = torch.Tensor([1 - MARGINAL_EDGE_PROB, MARGINAL_EDGE_PROB])
    n_epochs = 100
    simple_run = dict(shared, noise_dist=noise_dist, n_epochs=n_epochs, lr=SIMPLE_LR, set_ylim=False)
    gt_run = dict(shared, noise_dist=noise_dist, n_epochs=GT_EPOCHS, lr=GT_LR, set_ylim=False)

    simple_model_eigen_marginal, simple_model_eigen_marginal_graphs, simple_model_eigen_marginal_intermediate_graphs = _train_and_sample(
        tag=f"simple_eigen_marginal_{i}", model=SimpleModel(**_simple_args(EIGEN_INPUT_DIMS)).to(device),
        eigen_feats=utils.EigenFeatures('all'), n_graphs=num_graphs_to_generate, **simple_run)
    simple_model_gcn_eigen_marginal, simple_model_gcn_eigen_marginal_graphs, simple_model_gcn_eigen_marginal_intermediate_graphs = _train_and_sample(
        tag=f"simple_gcn_eigen_marginal_{i}", model=SimpleModelGCN(**_simple_args(EIGEN_INPUT_DIMS)).to(device),
        eigen_feats=utils.EigenFeatures('all'), n_graphs=num_graphs_to_generate, **simple_run)
    simple_model_gat_eigen_marginal, simple_model_gat_eigen_marginal_graphs, simple_model_gat_eigen_marginal_intermediate_graphs = _train_and_sample(
        tag=f"simple_gat_eigen_marginal_{i}", model=SimpleModelGAT(**_simple_args(EIGEN_INPUT_DIMS)).to(device),
        eigen_feats=utils.EigenFeatures('all'), n_graphs=num_graphs_to_generate, **simple_run)
    # Keyword order mirrors the original construction order (feature extractor, then model).
    gt_pp_model_marginal, gt_pp_model_marginal_graphs, gt_pp_model_marginal_intermediate_graphs = _train_and_sample(
        tag=f"gt_eigen_marginal_{i}", eigen_feats=utils.EigenFeatures('all'),
        model=GraphTransformer(**_gt_args(EIGEN_INPUT_DIMS)).to(device),
        n_graphs=num_graphs_to_generate, **gt_run)
        
        



# Persist the generated graphs and the learning curves of this run, one pickle file per container.
for pickle_path, payload in (
    ('graphs_dict.pkl', graphs_to_eval),
    ('train_loss_dict.pkl', train_loss_list),
    ('val_loss_dict.pkl', val_loss_list),
    ('er_train_loss_dict.pkl', er_train_loss_list),
    ('er_val_loss_dict.pkl', er_val_loss_list),
):
    with open(pickle_path, 'wb') as f:
        pickle.dump(payload, f)

## *Plotting Learning Curves* ##

In [None]:
#Load saved learning curve and generated graphs

def load_pickle(file_path):
    """Load a pickled object, falling back to an empty dict when the file is missing.

    Args:
        file_path: path of the pickle file to read.

    Returns:
        The unpickled object, or ``{}`` if ``file_path`` does not exist. In that
        case a warning is emitted so that a missing results file is noticed here
        rather than as an unrelated ``KeyError`` in a later cell.
    """
    import warnings  # local import: not part of the notebook's import cell

    try:
        with open(file_path, 'rb') as f:
            # Unpickling can execute arbitrary code: only load files produced by this notebook.
            return pickle.load(f)
    except FileNotFoundError:
        warnings.warn(f"{file_path} not found; using an empty dict instead", stacklevel=2)
        return {}


# Merged result files, loaded in the same order as the names they are bound to.
(graphs_dict,
 train_loss_dict,
 val_loss_dict,
 er_train_loss_dict,
 er_val_loss_dict) = map(load_pickle, (
    'graphs_dict_merged.pkl',
    'train_loss_dict_merged.pkl',
    'val_loss_dict_merged.pkl',
    'er_train_loss_dict_merged.pkl',
    'er_val_loss_dict_merged.pkl',
))

In [None]:
# Per-configuration learning-curve statistics across trials.
average_train_loss = {}
average_val_loss = {}
std_train_loss = {}
std_val_loss = {}
average_er_train_loss = {}
average_er_val_loss = {}

# Configurations to aggregate; results are stored under "<config>_<trial index>".
configs = ["simple", "simple_eigen","simple_eigen_cycles", "simple_marginal", "simple_eigen_marginal", "simple_eigen_marginal_cycles", "simple_gcn", "simple_gcn_eigen", "simple_gcn_marginal", "simple_gcn_eigen_marginal","simple_gat", "simple_gat_eigen", "simple_gat_marginal", "simple_gat_eigen_marginal","gt", "gt_eigen", "gt_eigen_cycles","gt_marginal", "gt_eigen_marginal","gt_eigen_marginal_cycles"]

# Prefer the results loaded from the merged pickle files; when a file was missing
# (load_pickle returned an empty dict) fall back to the results of the training
# cell of the current session.
_train_losses = train_loss_dict if train_loss_dict else train_loss_list
_val_losses = val_loss_dict if val_loss_dict else val_loss_list
_er_train_losses = er_train_loss_dict if er_train_loss_dict else er_train_loss_list
_er_val_losses = er_val_loss_dict if er_val_loss_dict else er_val_loss_list


def _collect_trials(results, config):
    """Return ``results`` of ``config`` for trials 0..trial_num-1.

    Keys are looked up as "<config>_<i>" first and as "<config><i>" (no
    underscore) second, so both naming schemes used in this notebook work.
    """
    collected = []
    for i in range(trial_num):
        key = f"{config}_{i}"
        collected.append(results[key] if key in results else results[f"{config}{i}"])
    return collected


for config in configs:
    if config == "simple":
        # The ER baseline is a single number per trial; its mean is repeated 100
        # times so it can be drawn as a horizontal line over up to 100 epochs.
        average_er_train_loss[config] = [np.mean(_collect_trials(_er_train_losses, config))] * 100
        average_er_val_loss[config] = [np.mean(_collect_trials(_er_val_losses, config))] * 100

    train_curves = _collect_trials(_train_losses, config)
    val_curves = _collect_trials(_val_losses, config)
    # axis=0: average over trials, keeping one value per epoch.
    average_train_loss[config] = np.mean(train_curves, axis=0)
    average_val_loss[config] = np.mean(val_curves, axis=0)
    std_train_loss[config] = np.std(train_curves, axis=0)
    std_val_loss[config] = np.std(val_curves, axis=0)
    

# Plotting function
def plot_losses(config, xlim=(0, 20), ylim=(0.1, 0.14)):
    """Plot the mean train/validation loss of ``config`` with a +/- one std band.

    Reads the module-level dictionaries ``average_train_loss``,
    ``average_val_loss``, ``std_train_loss``, ``std_val_loss`` and the ER
    baselines ``average_er_train_loss["simple"]`` / ``average_er_val_loss["simple"]``.

    Args:
        config: configuration name, a key of the dictionaries above.
        xlim: (min, max) epoch range shown on the x axis.
        ylim: (min, max) loss range shown on the y axis.
    """
    train_mean, train_std = average_train_loss[config], std_train_loss[config]
    val_mean, val_std = average_val_loss[config], std_val_loss[config]
    n_points = len(train_mean)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_mean, label="train_loss", color='blue')
    ax.fill_between(range(n_points), train_mean - train_std, train_mean + train_std,
                    color='blue', alpha=0.3)
    ax.plot(val_mean, label="val_loss", color='green')
    ax.fill_between(range(len(val_mean)), val_mean - val_std, val_mean + val_std,
                    color='green', alpha=0.3)

    # NOTE(review): the ER train baseline is drawn 0.001 below its computed value,
    # presumably so it is not hidden behind the ER validation line - confirm,
    # since the plotted value is not the measured one.
    ax.plot(np.asarray(average_er_train_loss["simple"][:n_points]) - 0.001, label="er_train_loss", color='red')
    ax.plot(average_er_val_loss["simple"][:n_points], label="er_val_loss", color='orange')

    ax.legend()
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    # The title names the configuration actually plotted.
    ax.set_title(f'Training and Validation Loss with Variations for {config}')
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    plt.show()

# Plot losses for each configuration
for config in configs:
    plot_losses(config)


## *Evaluating generated graphs* ##

In [None]:
# MMD of the clustering-coefficient and degree distributions

dict_stat_fn = {
    "degree": get_degree_hist_array,
    "clustering": get_clustering_hist_array,
}

# One list of per-trial MMD values per (statistic, configuration) pair.
mmd_dict = {f"degree_{config}": [] for config in configs}
mmd_dict.update({f"clustering_{config}": [] for config in configs})


for i in range(trial_num):
    for configuration in configs:
        # The training cell stores graphs under "<config>_<i>"; the key without
        # underscore is kept as a fallback for result files using that scheme.
        graphs_key = f"{configuration}_{i}"
        if graphs_key not in graphs_dict:
            graphs_key = f"{configuration}{i}"
        generated_graphs = graphs_dict[graphs_key]

        final_results = utils.QuantitativeResults(dict_stat_fn=dict_stat_fn, train_dataset=graphs_dict[f"train_{i}"], test_dataset=graphs_dict[f"test_{i}"])
        final_results.add_results(configuration, generated_graphs)
        results_df = final_results.show_table()
        # Drop the reference column, keeping only the columns of generated graphs.
        mmd_values = results_df.loc[:, results_df.columns != 'ref']
        mmd_dict[f"degree_{configuration}"].append(mmd_values.loc["degree", configuration])
        mmd_dict[f"clustering_{configuration}"].append(mmd_values.loc["clustering", configuration])

# Replace each per-trial list by its harmonic mean over trials and report it.
for config in configs:
    mmd_dict[f"degree_{config}"] = hmean(np.array(mmd_dict[f"degree_{config}"]))
    mmd_dict[f"clustering_{config}"] = hmean(np.array(mmd_dict[f"clustering_{config}"]))
    print(f"{config} degree: {mmd_dict[f'degree_{config}']:.2f}")
    print(f"{config} clustering: {mmd_dict[f'clustering_{config}']:.2f}\n")

In [None]:
# Kolmogorov-Smirnov statistics of eigenvalue distribution

def _flat_eigen_feature(graphs, batch_size):
    """Return the flattened eigen-feature values of ``graphs`` as one 1-D array.

    The feature is element ``[1]`` of what ``EigenFeatures(mode='all')`` returns
    for a dense batch (presumably the eigenvalue features - confirm against
    assignment3_utils). Values of all batches are concatenated, so the result
    covers every graph even when ``graphs`` does not fit into a single batch.
    """
    eigen_feats = EigenFeatures(mode='all')
    chunks = []
    for batch in utils.nx_list_to_dataloader(graphs, bs=batch_size):
        unmasked_holder, node_mask = to_dense(batch)
        feature = eigen_feats(unmasked_holder.E, node_mask)[1]
        chunks.append(np.asarray(feature.reshape(-1)))
    return np.concatenate(chunks)


ks_eigenvalue_dict = {f"{config}": [] for config in configs}
for i in range(trial_num):
    # The training-set feature does not depend on the configuration: compute it once per trial.
    eigen_feature_train = _flat_eigen_feature(graphs_dict[f"train_{i}"], batch_size=250)
    for configuration in configs:
        # The training cell stores graphs under "<config>_<i>"; the key without
        # underscore is kept as a fallback for result files using that scheme.
        graphs_key = f"{configuration}_{i}"
        if graphs_key not in graphs_dict:
            graphs_key = f"{configuration}{i}"
        eigen_feature_generated = _flat_eigen_feature(graphs_dict[graphs_key], batch_size=100)

        # Kolmogorov-Smirnov test
        ks_stat, ks_p_value = ks_2samp(eigen_feature_train, eigen_feature_generated)
        ks_eigenvalue_dict[configuration].append(ks_stat)


# Harmonic mean of the KS statistic over trials, per configuration.
for config in configs:
    ks_eigenvalue_dict[config] = np.array(ks_eigenvalue_dict[config])
    harmonic_mean = hmean(ks_eigenvalue_dict[config])
    print(f"{config} ks: {harmonic_mean:.2f}\n")


In [None]:
# Validity, uniqueness and novelty

# Per-configuration lists with one entry per trial.
valid = {f"{config}": [] for config in configs}
unique = {f"{config}": [] for config in configs}
novel = {f"{config}": [] for config in configs}
n_generated = {f"{config}": [] for config in configs}


for i in range(trial_num):
    train_graphs = graphs_dict[f"train_{i}"]
    for configuration in configs:
        # The training cell stores graphs under "<config>_<i>"; the key without
        # underscore is kept as a fallback for result files using that scheme.
        graphs_key = f"{configuration}_{i}"
        if graphs_key not in graphs_dict:
            graphs_key = f"{configuration}{i}"
        generated_graphs = graphs_dict[graphs_key]

        # One call yields both the unique and the novel count.
        n_unique, n_novel = count_unique_and_novel_graphs(generated_graphs, train_graphs)
        valid[configuration].append(count_valid_graphs(generated_graphs))
        unique[configuration].append(n_unique)
        novel[configuration].append(n_novel)
        n_generated[configuration].append(len(generated_graphs))


for config in configs:
    valid[config] = np.array(valid[config])
    unique[config] = np.array(unique[config])
    novel[config] = np.array(novel[config])
    # Normalise by the number of graphs actually generated in each trial instead
    # of a fixed 100 (identical result when every trial has 100 graphs); the
    # maximum guards against a division by zero for an empty sample list.
    totals = np.maximum(np.array(n_generated[config]), 1)
    harmonic_mean_valid = hmean(valid[config] / totals)
    harmonic_mean_unique = hmean(unique[config] / totals)
    harmonic_mean_novel = hmean(novel[config] / totals)
    print(f"{config} valid: {harmonic_mean_valid:.2f}")
    print(f"{config} unique: {harmonic_mean_unique:.2f}")
    print(f"{config} novel: {harmonic_mean_novel:.2f}\n")