In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from gmdn_dataset import TUDatasetInterface, BarabasiAlbertDataset, ErdosRenyiDataset
import json
import os
import os.path as osp
from gmdn import GMDN
from pydgn.experiment.experiment import Experiment
from pydgn.experiment.util import s2c
from torch.distributions import Binomial, Independent, Categorical, MixtureSameFamily, Normal
from torch_geometric.utils import *
from networkx import *

In [None]:
# --- Dataset configuration -------------------------------------------------
data_root = 'DATA/'
size = 100  # NOTE(review): not referenced anywhere else in the visible notebook
dataset_name = 'alchemy_full'
# Dotted class path; it is resolved with s2c() when the data provider is built.
dataset_class = 'data.dataset.TUDatasetInterface'
# Use `dataset_name` instead of repeating the literal, so the dataset loaded
# here cannot drift from the RESULTS/ and OUTPUTS/ paths derived from that name.
d = TUDatasetInterface(data_root, dataset_name, use_node_attr=True)
# Targets of the whole dataset; indexed further down as y[:, component_index].
y = d.data.y
len(d)

In [None]:
def _load_winner(family):
    """Load the winning configuration and best checkpoint of one model family.

    Args:
        family: sub-folder of ``RESULTS/`` to read from ('MDN', 'DGN' or 'GMDN').

    Returns:
        A tuple ``(config_path, ckpt_path, config, model_state)`` where `config`
        is the ``'config'`` entry of the winner JSON file and `model_state` is
        the ``'model_state'`` entry of the checkpoint.
    """
    config_path = f'RESULTS/{family}/GMDN_{dataset_name}_SupervisedTask/MODEL_ASSESSMENT/OUTER_FOLD_1/MODEL_SELECTION/winner_config.json'
    ckpt_path = f'RESULTS/{family}/GMDN_{dataset_name}_SupervisedTask/MODEL_ASSESSMENT/OUTER_FOLD_1/final_run1/best_checkpoint.pth'
    # Context manager: the file handle is closed as soon as the JSON is parsed
    # (a bare json.load(open(...)) leaves that to the garbage collector).
    with open(config_path, 'r') as config_file:
        config = json.load(config_file)['config']
    # Deserialize onto the CPU so loading does not depend on the GPU the
    # checkpoint was saved from; the model is moved to its device afterwards.
    model_state = torch.load(ckpt_path, map_location='cpu')['model_state']
    return config_path, ckpt_path, config, model_state


mdn_config_path, mdn_ckpt_path, mdn_config, mdn_ckpt = _load_winner('MDN')
dgn_config_path, dgn_ckpt_path, dgn_config, dgn_ckpt = _load_winner('DGN')
gmdn_config_path, gmdn_ckpt_path, gmdn_config, gmdn_ckpt = _load_winner('GMDN')

In [None]:
# One Experiment wrapper per model family, each paired with its own output
# folder; they are built in the order MDN, DGN, GMDN.
mdn_exp, dgn_exp, gmdn_exp = (
    Experiment(winner_config, output_dir)
    for winner_config, output_dir in (
        (mdn_config, f'OUTPUTS/MDN_TESTS_{dataset_name}'),
        (dgn_config, f'OUTPUTS/DGN_TESTS_{dataset_name}'),
        (gmdn_config, f'OUTPUTS/GMDN_TESTS_{dataset_name}'),
    )
)

In [None]:
# A batch as large as the whole dataset: every split fits in a single batch,
# which is what the inference cell relies on (it only consumes one batch).
batch_size = len(d)
# Keep the loader order deterministic so predictions can be indexed by position.
shuffle = False
# Fall back to the CPU when no CUDA device is available instead of failing
# later on the first `.to(device)` call.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dataset_getter_class = s2c('pydgn.data.provider.DataProvider')
# The splits file is derived from `dataset_name` so it always matches the
# dataset configured above (same value as the previously hard-coded path).
dataset_getter = dataset_getter_class(data_root,
                                      'SPLITS',
                                      f'SPLITS/{dataset_name}/{dataset_name}_outer1_inner1.splits',
                                      s2c(dataset_class),
                                      dataset_name,
                                      1, # outer_folds
                                      1, # inner folds
                                      2, # num_workers
                                      True)  # pin memory
# Select the (only) outer and inner fold.
dataset_getter.set_outer_k(0)
dataset_getter.set_inner_k(0)



# Feature/target dimensionalities reported by the provider; the models are
# built from these three numbers.
dim_node_features, dim_edge_features, dim_target = (
    dataset_getter.get_dim_node_features(),
    dataset_getter.get_dim_edge_features(),
    dataset_getter.get_dim_target(),
)

# Instantiate the Dataset Loaders (train, validation, test) with shared settings
train_loader, val_loader, test_loader = (
    make_loader(batch_size=batch_size, shuffle=shuffle)
    for make_loader in (
        dataset_getter.get_outer_train,
        dataset_getter.get_outer_val,
        dataset_getter.get_outer_test,
    )
)

In [None]:
# Sanity check: shape of the `x` tensor stored in the provider's dataset
# (presumably the node-feature matrix of all graphs -- confirm).
dataset_getter.dataset.data.x.shape

In [None]:
# Instantiate the three models, restore their trained weights, move them to
# `device` and switch them to evaluation mode.
# eval() matters because this notebook only runs inference: without it, layers
# such as dropout or batch normalisation (if the configurations use any) would
# keep their training-time behaviour and the plotted densities would not be
# the ones of the trained model.
mdn_model = mdn_exp.create_supervised_model(dim_node_features, dim_edge_features, dim_target)
mdn_model.load_state_dict(mdn_ckpt)
mdn_model.to(device)
mdn_model.eval()
dgn_model = dgn_exp.create_supervised_model(dim_node_features, dim_edge_features, dim_target)
dgn_model.load_state_dict(dgn_ckpt)
dgn_model.to(device)
dgn_model.eval()
gmdn_model = gmdn_exp.create_supervised_model(dim_node_features, dim_edge_features, dim_target)
gmdn_model.load_state_dict(gmdn_ckpt)
gmdn_model.to(device)
gmdn_model.eval()

In [None]:
def _first_batch_outputs(loader):
    """Run the MDN, DGN and GMDN models on the first batch of `loader`.

    Only the first batch is consumed: the loaders are created with
    ``batch_size = len(d)``, so one batch already spans the whole split.

    Args:
        loader: one of the train/validation/test loaders.

    Returns:
        A tuple ``(mdn_output, dgn_output, gmdn_output)`` computed without
        gradient tracking.

    Raises:
        ValueError: if the loader yields no batch. Previously the result
            variables were silently left undefined, or stale from an earlier
            execution of the cell.
    """
    with torch.no_grad():
        for batch in loader:
            # Rebind the result of .to() so the code does not rely on the
            # transfer happening in place.
            batch = batch.to(device)
            return mdn_model(batch), dgn_model(batch), gmdn_model(batch)
    raise ValueError('loader yielded no batches; cannot run inference')


mdn_train_res, dgn_train_res, gmdn_train_res = _first_batch_outputs(train_loader)
mdn_val_res, dgn_val_res, gmdn_val_res = _first_batch_outputs(val_loader)
mdn_test_res, dgn_test_res, gmdn_test_res = _first_batch_outputs(test_loader)

In [None]:
# 0-based index of the target dimension under study: gen_probs uses it to pick
# the column of `y` and the matching slice of the predicted mixture parameters.
component_index = 0  # first component (index 0)
#component_index = 4

def gen_probs(res, component_index, targets=None, num_points=1000):
    """Evaluate the predicted mixture density of one target on a regular grid.

    Args:
        res: model output; ``res[5]`` must hold ``(params, weights)`` where
            ``params[0]`` / ``params[1]`` are the per-expert location / scale
            tensors and `weights` the mixing weights. The tensors are
            presumably shaped (num_graphs, num_experts, num_targets) and
            (num_graphs, num_experts) -- confirm against the model.
        component_index: index of the target dimension to evaluate.
        targets: 2-D array/tensor of ground-truth targets whose column
            `component_index` defines the range of the grid. Defaults to the
            notebook-level tensor `y`; pass it explicitly to stay independent
            of later cells that rebind `y`.
        num_points: number of grid points (default 1000, as before).

    Returns:
        ``(probs, x)``: `probs` is a tensor with one row per graph and one
        column per grid point holding the mixture pdf; `x` is the numpy grid.
    """
    if targets is None:
        # Backward-compatible default: the target matrix defined at the top of
        # the notebook.
        targets = y

    params, weights = res[5]
    weights = weights.cpu()
    # Keep only the requested target dimension, as an event of size 1.
    mu = params[0].cpu()[:, :, component_index].unsqueeze(2)
    var = params[1].cpu()[:, :, component_index].unsqueeze(2)

    mix = Categorical(weights)
    # NOTE(review): `var` is passed as the *scale* (standard deviation) of
    # Normal -- confirm the model emits a standard deviation, not a variance.
    g = Independent(Normal(mu, var), 1)
    mm = MixtureSameFamily(mix, g)

    column = targets[:, component_index]
    x = np.linspace(float(column.min()), float(column.max()), num_points)

    num_graphs = mu.shape[0]
    # One grid value at a time keeps peak memory at O(num_graphs * num_experts)
    # rather than materialising the whole (grid x graphs x experts) tensor.
    probs = [
        mm.log_prob(torch.full((num_graphs, 1), float(value), dtype=torch.float32)).exp()
        for value in x
    ]
    probs = torch.stack(probs, dim=1)

    return probs, x

In [None]:
# Densities of the selected target for every model on each split. The grid `x`
# is taken from the MDN call of each split; the other grids are discarded.
(mdn_train_probs, x), (dgn_train_probs, _), (gmdn_train_probs, _) = (
    gen_probs(split_res, component_index)
    for split_res in (mdn_train_res, dgn_train_res, gmdn_train_res)
)

(mdn_val_probs, x), (dgn_val_probs, _), (gmdn_val_probs, _) = (
    gen_probs(split_res, component_index)
    for split_res in (mdn_val_res, dgn_val_res, gmdn_val_res)
)

(mdn_test_probs, x), (dgn_test_probs, _), (gmdn_test_probs, _) = (
    gen_probs(split_res, component_index)
    for split_res in (mdn_test_res, dgn_test_res, gmdn_test_res)
)

In [None]:
# Stack the per-split densities row-wise, in train -> validation -> test order.
mdn_probs = torch.cat([mdn_train_probs, mdn_val_probs, mdn_test_probs])
dgn_probs = torch.cat([dgn_train_probs, dgn_val_probs, dgn_test_probs])
gmdn_probs = torch.cat([gmdn_train_probs, gmdn_val_probs, gmdn_test_probs])

In [None]:
# Disabled code, kept inside a string literal so that it does not run: it plots
# the MDN/DGN/GMDN densities of the first 1024 training examples one by one.
'''
for i in range(1024):
    print(i)
    plt.figure()
    plt.plot(x, mdn_train_probs[i].numpy(), linestyle=':', linewidth=2, label='MDN')
    plt.plot(x, dgn_train_probs[i].numpy(), linestyle='--', linewidth=2, label='DGN')
    plt.plot(x, gmdn_train_probs[i].numpy(), linewidth=2, label='GMDN')
    plt.xlabel('output')
    plt.ylabel('pdf')
    plt.rcParams.update({'font.size': 12})
    plt.tight_layout()
    plt.legend()
    #plt.savefig(f'OUTPUTS/alchemy_study_component{component_index}_ex{i}.eps')
    plt.show()
'''

In [None]:
# Set the font size *before* creating any figure: rcParams only affect artists
# created afterwards, so updating them after xlabel/ylabel (as done before)
# left the first figure with inconsistent font sizes.
plt.rcParams.update({'font.size': 12})

# Browse the predicted densities example by example. The upper bound is clipped
# to the number of available rows so the loop cannot index past the data.
for i in range(511, min(2048, len(mdn_probs))):#[40, 44, 63, 842, 874, 548, 213, 511]:
    print(i)
    fig = plt.figure()
    plt.plot(x, mdn_probs[i].numpy(), linestyle=':', linewidth=2, label='MDN')
    plt.plot(x, dgn_probs[i].numpy(), linestyle='--', linewidth=2, label='DGN')
    plt.plot(x, gmdn_probs[i].numpy(), linewidth=2, label='GMDN')
    plt.xlabel('output')
    plt.ylabel('pdf')
    plt.tight_layout()
    plt.legend()
    #plt.savefig(f'OUTPUTS/alchemy_study_component{component_index}_ex{i}.eps')
    plt.show()
    # Release the figure explicitly: ~1500 open figures would otherwise pile up
    # in memory on backends that do not close them on show().
    plt.close(fig)

Interesting example indices for component 0 (indices into the concatenated train/val/test predictions plotted above):

40, 44, 63, 842, 874, 548 (there are many more in between, and for other components)

## RAND STUDY

In [None]:
# Per-dimension extrema of the targets; the arg-max/arg-min indices returned
# alongside the values are discarded.
max_val, _ = torch.max(y, dim=0)
min_val, _ = torch.min(y, dim=0)

In [None]:
# Largest observed value of each target dimension.
max_val

In [None]:
# Smallest observed value of each target dimension.
min_val

In [None]:
# Density of a baseline that is uniform and independent in every target
# dimension: the product of 1 / (observed range) over the dimensions.
py = (1 / (max_val - min_val)).prod(dim=0)
py

In [None]:
# log p(y), where p(y) = prod_i p(y_i) and every p(y_i) is a uniform
# distribution over the observed range of dimension i.
py.log()

In [None]:
# Same quantity computed as a sum of per-dimension log-densities.
u = 1 / (max_val - min_val)
torch.sum(torch.log(u))

In [None]:
# Log-density of the uniform baseline, one value per target dimension.
torch.log(u)

In [None]:
# Flatten the training batches into a list of individual graphs. The list is
# rebuilt from scratch each time this cell runs.
data = [graph for b in train_loader for graph in b.to_data_list()]

In [None]:
# Append the validation graphs and then the test graphs, so that `data` follows
# the same train -> validation -> test order as the concatenated predictions.
# NOTE: this cell appends to `data`; running it twice duplicates the graphs.
for b in val_loader:
    data += b.to_data_list()
for b in test_loader:
    data += b.to_data_list()

In [None]:
# Convert every graph to an undirected networkx graph, copying the node
# attribute 'x' so that it can be compared during the isomorphism test.
nx_data = [
    to_networkx(graph, node_attrs=['x'], to_undirected=True)
    for graph in data
]

In [None]:
# Number of converted graphs (train + validation + test).
len(nx_data)

In [None]:
# Per-graph targets, aligned with `data` / `nx_data`.
# NOTE(review): this rebinds `y` from the dataset-wide target tensor to a plain
# list; cells above that index `y` as a tensor no longer work after this one
# has run unless the notebook is re-executed from the top.
y = [graph.y for graph in data]

In [None]:
# The 3D information must be excluded from the graph isomorphism test, so
# inspect the feature vector of one node to see where it sits
# (node_match_callback skips the first three entries).
print(nx_data[0].nodes[1]['x'])

In [None]:
def node_match_callback(n1, n2):
    """Tell whether two nodes carry the same features, ignoring the first three.

    Args:
        n1, n2: node attribute dictionaries holding the feature sequence
            under the key 'x'.

    Returns:
        The result of comparing the two feature sequences from position 3 on.
    """
    features_a = n1['x']
    features_b = n2['x']
    return features_a[3:] == features_b[3:]

In [None]:
# Reference example, indexed in the train -> validation -> test ordering.
ref_index = 804#842
ref_graph = nx_data[ref_index]

# Indices of all other graphs that are isomorphic to the reference one when
# node features are compared through node_match_callback.
iso_graphs = [
    idx
    for idx, candidate in enumerate(nx_data)
    if idx != ref_index
    and is_isomorphic(candidate, ref_graph, node_match=node_match_callback)
]

# Value of the studied target for each isomorphic graph.
iso_targets = np.array([y[idx][:, component_index].item() for idx in iso_graphs])
print(iso_graphs)

In [None]:
# Font size first: rcParams only apply to artists created after the update.
plt.rcParams.update({'font.size': 12})

fig = plt.figure()
# One vertical line per ground-truth target of the reference graph and of the
# graphs isomorphic to it.
for i in [ref_index] + iso_graphs:
    # `i` indexes the train -> validation -> test concatenation (`data`,
    # `nx_data`, `*_probs`), NOT the original dataset order, so the target must
    # be read from `data`. The column is the studied component rather than a
    # hard-coded 0, matching the densities and the output file name.
    target = data[i].y[0, component_index].item()
    plt.axvline(x=target, color='#D2E3F0', linestyle='-', alpha=1)
plt.plot(x, mdn_probs[ref_index].numpy(), linestyle=':', color='C0', linewidth=2, label='MDN')
plt.plot(x, dgn_probs[ref_index].numpy(), linestyle='--', color='C1', linewidth=2, label='DGN')
plt.plot(x, gmdn_probs[ref_index].numpy(), linewidth=2, color='C2', label='GMDN')
plt.xlabel('output')
plt.ylabel('pdf')
plt.tight_layout()
plt.legend()
plt.savefig(f'OUTPUTS/alchemy_study_component{component_index}_ex{ref_index}.eps')
plt.show()
plt.close(fig)