### Dataset for Missing Data with Uncertainty

In [None]:
import numpy as np
import torch
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
from torch.distributions import *
import pydgn
import math
from matplotlib import cm
import sys
from torch_geometric.utils import *
# plt.xkcd()

In [None]:
# Seed used to (re)initialise every random number generator in this notebook.
seed = 0

# Mixture layout: `num_features` independent features, each one described by
# its own mixture model with `num_components` components.
num_features, num_components = 1, 3

# Sampling knobs for the component parameters: half-width of the interval
# drawn around each base mean, and upper bound of the standard deviations.
mean_deviation, max_std = 30, 5

### Plot the mixture

In [None]:
# Start from the notebook-wide seed so that the feature seed drawn below is reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# Derive a secondary seed in [0, 99] and re-seed every generator with it.
feature_seed = random.randint(0, sys.maxsize) % 100
print(f'Chosen feature seed is {feature_seed}')

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(feature_seed)

# Base mean of every (component, feature) pair, uniform in [-100, 100).
mean = torch.rand(num_components, num_features)*(200) - 100
min_mean = mean - mean_deviation
max_mean = mean + mean_deviation

print(f'Chosen feature mean range is {mean}')

# Every feature gets its own mixture: one (mu, std) pair per component and feature.
# mu is uniform in [min_mean, max_mean), std is uniform in [0, max_std).
# mu is drawn before std, which fixes the order in which the RNG is consumed.
mu = torch.rand(num_components, num_features)*(max_mean-min_mean) + min_mean
std = torch.rand(num_components, num_features)*max_std

### Generate graphs with a fixed number of communities (5)

In [None]:
import community as community_louvain
import networkx as nx
from numgraph.distributions import *
from numgraph.utils import *
from numpy.random import default_rng

def _find_between_community_edges(g, partition):
    """Group the edges of ``g`` whose endpoints belong to different communities.

    Args:
        g: graph object whose ``edges()`` yields ``(node, node)`` pairs.
        partition: mapping from node to community identifier.

    Returns:
        dict mapping a community pair ``(ci, cj)`` to the list of edges
        ``(ni, nj)`` with ``ni`` in ``ci`` and ``nj`` in ``cj``. The pair
        follows the endpoint order reported by ``g.edges()``, so ``(ci, cj)``
        and ``(cj, ci)`` are separate keys.
    """
    crossing = {}
    for source, target in g.edges():
        pair = (partition[source], partition[target])
        if pair[0] == pair[1]:
            # both endpoints live in the same community: not a crossing edge
            continue
        crossing.setdefault(pair, []).append((source, target))
    return crossing

def _position_nodes(g, partition, **kwargs):
    """Lay out the nodes of every community independently of the others.

    Args:
        g: networkx graph.
        partition: mapping from node to community identifier.
        **kwargs: forwarded to ``nx.spring_layout`` for each community subgraph.

    Returns:
        dict mapping each node in ``partition`` to its position inside the
        spring layout of its own community subgraph.
    """
    # community id -> list of member nodes, in the iteration order of `partition`
    members = {}
    for node, community in partition.items():
        members.setdefault(community, []).append(node)

    pos = {}
    for nodes in members.values():
        pos.update(nx.spring_layout(g.subgraph(nodes), **kwargs))
    return pos

def _position_communities(g, partition, **kwargs):
    """Compute one anchor position per community and assign it to its nodes.

    A directed graph is built with one node per community and one edge per
    ordered community pair, weighted by the number of edges of ``g`` that
    connect the two communities. Its spring layout gives the anchors.

    Args:
        g: networkx graph.
        partition: mapping from node to community identifier.
        **kwargs: forwarded to ``nx.spring_layout``.

    Returns:
        dict mapping each node to the position of its community.
    """
    crossing_edges = _find_between_community_edges(g, partition)

    community_graph = nx.DiGraph()
    community_graph.add_nodes_from(set(partition.values()))
    community_graph.add_weighted_edges_from(
        (ci, cj, len(edges)) for (ci, cj), edges in crossing_edges.items()
    )

    # layout of the communities themselves
    anchors = nx.spring_layout(community_graph, **kwargs)

    # every node starts from the anchor of its community
    return {node: anchors[community] for node, community in partition.items()}

def community_layout(g, partition):
    """Return node positions that keep the nodes of a community close together.

    The position of a node is the anchor of its community (layout scale 3)
    plus its offset inside the community's own layout (layout scale 1).

    Args:
        g: networkx graph.
        partition: mapping from node to community identifier.

    Returns:
        dict mapping every node of ``g`` to its combined position.
    """
    anchors = _position_communities(g, partition, scale=3.)
    offsets = _position_nodes(g, partition, scale=1.)
    return {node: anchors[node] + offsets[node] for node in g.nodes()}

def plot_sbm(G, seed):
    """Draw ``G`` with nodes grouped and coloured by their Louvain community.

    Args:
        G: networkx graph to draw.
        seed: random state passed to the Louvain community detection.
    """
    partition = community_louvain.best_partition(G, random_state=seed)
    pos = community_layout(G, partition)
    # nx.draw assigns colours following the node order of G, so look the
    # community up per node instead of relying on the ordering of `partition`.
    node_color = [partition[node] for node in G.nodes()]
    nx.draw(G, pos, node_color=node_color, arrowstyle='-|>')
    plt.show()
    
# SBM: sample and plot one small example graph with three communities.
print('SBM')
block_size = [15, 5, 3]
# probs[i][j] is the edge probability between community i and community j.
probs = [[0.5, 0.01, 0.01], 
         [0.01, 0.5, 0.01],
         [0.01, 0.01, 0.5]]
# Forward the generator received from the SBM routine: dropping it (as the
# previous version did) leaves the intra-community edges outside the control of `seed`.
generator = lambda b, p, rng: erdos_renyi_coo(b, p, rng=rng)
e, _ = stochastic_block_model_coo(block_size, probs, generator, rng = default_rng(seed))
G = nx.from_edgelist(e)

plot_sbm(G, seed=seed)

In [None]:
from torch_geometric.data import Batch, Data
from torch_geometric.utils import *
from torch_geometric.transforms import RemoveIsolatedNodes
# Folder that receives the generated graphs; exist_ok replaces the
# check-then-create sequence with a single call.
os.makedirs('GENERATED_DATA/missing_data/', exist_ok=True)

# Applied to every generated graph in the generation loop.
transform = RemoveIsolatedNodes()

# BACKUP
# min_size = 50
# max_size = 100
# inter_max_prob = 0.01
# inter_min_prob = 0.002
# intra_max_prob = 0.25
# intra_min_prob = 0.1

# Bounds used when sampling the community sizes.
min_size = 50
max_size = 100
# Bounds of the edge probability between different communities ...
inter_max_prob = 0.01
inter_min_prob = 0.002
# ... and inside the same community.
intra_max_prob = 0.25
intra_min_prob = 0.1

num_graphs = 100
per_graph_num_samples = 100

In [None]:
# Reset the global RNGs so that block sizes, edge probabilities and Dirichlet
# draws are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

num_communities = 5
assert num_components == 3

# One Dirichlet prior over the three mixture components per community.
per_community_dirichlet = [
    Dirichlet(torch.tensor([9., 1., 1.])),
    Dirichlet(torch.tensor([1., 9., 1.])),
    Dirichlet(torch.tensor([1., 1., 9.])),
    Dirichlet(torch.tensor([2., 2., 1.])),
    Dirichlet(torch.tensor([1., 1., 2.])),
]
assert len(per_community_dirichlet) == num_communities    

plot = False

# Explicit NumPy generator so that the sampled edges are tied to `seed` too
# (the SBM example cell passes one to stochastic_block_model_coo the same way).
rng = default_rng(seed)


def generator(b, p, rng):
    """Sample one community of ``b`` nodes with edge probability ``p`` using ``rng``."""
    return erdos_renyi_coo(b, p, rng=rng)


dataset = []
for graph in range(num_graphs):
    if (graph+1)%100 == 0:
        print(f'Processed graph {graph+1}')

    # Every community gets between min_size/num_communities and
    # max_size/num_communities nodes (truncated to an integer).
    block_size = ((torch.rand(num_communities)*(max_size-min_size) + min_size)/num_communities).int()
    num_nodes = int(block_size.sum())
    # Community label of every node: community i repeated block_size[i] times.
    community_assignment = torch.repeat_interleave(torch.arange(num_communities), block_size.long())

    # Off-diagonal entries: edge probability between two different communities.
    inter_probs = torch.rand(num_communities, num_communities)*(inter_max_prob-inter_min_prob) + inter_min_prob
    inter_probs.fill_diagonal_(0.)
    # Diagonal entries: edge probability inside a community.
    intra_probs = torch.eye(num_communities)*(torch.rand(num_communities, num_communities)*(intra_max_prob-intra_min_prob) + intra_min_prob)
    probs = inter_probs + intra_probs

    e, _ = stochastic_block_model_coo(block_size.tolist(), probs.tolist(), generator, directed=False, rng=rng)

    # Register every node id before adding the edges. Building the graph from
    # the edge list alone orders the nodes by first appearance and drops the
    # isolated ones, so row k of the features/labels below would no longer
    # describe node k of the graph.
    # NOTE(review): assumes numgraph numbers the nodes block by block starting
    # from 0, which `community_assignment` relies on as well — confirm.
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(np.asarray(e).tolist())
    graph_data = from_networkx(G)

    if plot:
        plt.figure()
        plot_sbm(G, seed=seed)

    assert is_undirected(graph_data.edge_index)

    # Per-node mixing weights, sampled from the Dirichlet of the node's community
    # and stacked in the same block order as `community_assignment`.
    mixing_weights = torch.cat(
        [per_community_dirichlet[c].sample((size,)) for c, size in enumerate(block_size.tolist())],
        dim=0,
    )

    if plot:
        print(homophily(graph_data.edge_index, community_assignment))
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(mixing_weights[:,0], mixing_weights[:,1], mixing_weights[:,2])
        ax.view_init(30, 30) 

    graph_data = Data(x=mixing_weights, edge_index=graph_data.edge_index, y=community_assignment)

    # Remove isolated nodes
    graph_data = transform(graph_data)
    assert not torch.any(degree(graph_data.edge_index[1]) == 0)

    dataset.append(graph_data)

# All graphs go into a single file named after the total number of graphs.
if dataset:
    torch.save(dataset, f'GENERATED_DATA/missing_data/data_list_{num_graphs}_step1.pt')
    dataset = []

### Load the dataset and perform one step of neighborhood aggregation
#### Linearly Combine the node feature and the neighborhood contribution

In [None]:
# Re-seed every generator before sampling the node features.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

from torch_scatter import scatter_min, scatter_max, scatter_mean, scatter_sum, scatter_std

# Percentages given to the node's own mixing weights and to the
# neighbourhood-aggregated ones; they have to add up to 100.
single_node_weight, struc_weight = 0, 100
assert single_node_weight+struc_weight == 100

# The dataset folder is named after the two percentages.
dataset_name = f'Synthetic_{single_node_weight}_{struc_weight}'

print((single_node_weight/100), (struc_weight/100))

plot = False

# for i in range(num_graphs//100):
#     index = (i+1)*100
#     print(index)
    
#     dataset = torch.load(f'GENERATED_DATA/missing_data/data_list_{index}_step1.pt')
    
#     new_dataset = []
#     num_samples = 0
#     for g in dataset:
        
#         # aggregate cluster assignments for each feature --> distribution of cluster according to neighbors
#         x_mean_aggr = scatter_mean(g.x[g.edge_index[0],:], g.edge_index[1], dim=0, out=torch.zeros_like(g.x[:,:]))
#         structure_dependent_mixing_weights = x_mean_aggr
#         weights = (single_node_weight/100)*g.x + (struc_weight/100)*structure_dependent_mixing_weights

#         mix = Categorical(probs=weights)          
#         comp = Independent(Normal(loc=mu.unsqueeze(0).repeat(g.x.shape[0], 1, 1), 
#                                   scale=std.unsqueeze(0).repeat(g.x.shape[0], 1, 1)),
#                            1)
#         mm = MixtureSameFamily(mix, comp)

# #         print(mu.shape, std.shape, weights.shape)


#         if plot:
#             fig = plt.figure()
#             ax = fig.add_subplot(1, 2, 1, projection='3d')              
#             ax.scatter(g.x[:,0], g.x[:,1], g.x[:,2])
#             ax.view_init(30, 30) 

#             ax = fig.add_subplot(1, 2, 2, projection='3d')              
#             ax.scatter(weights[:,0], weights[:,1], weights[:,2])
#             ax.view_init(30, 30) 
#             plt.show()

#         for _ in range(per_graph_num_samples):

#             if (num_samples+1)%1000 == 0:
#                 print(num_samples)

#             sample = mm.sample()
            
# #             print(sample.shape)
# #             
#             g_sample = Data(x=sample, edge_index=g.edge_index.clone(), y=g.y.clone())
        

#             new_dataset.append(g_sample)
#             num_samples += 1

#     if not os.path.exists(f'GENERATED_DATA/missing_data/{dataset_name}'):
#         os.makedirs(f'GENERATED_DATA/missing_data/{dataset_name}')
        
#     torch.save(new_dataset, f'GENERATED_DATA/missing_data/{dataset_name}/data_list_{index}.pt')

### Now compute dataset statistics for each community, to see how the node distribution changed

In [None]:
from torch_geometric.utils import *
# The generation cell writes all step-1 graphs into ONE file named after
# `num_graphs`, so load exactly that file. The previous loop over
# num_graphs//100 chunks only matched it when num_graphs was exactly 100.
index = num_graphs
dataset = torch.load(f'GENERATED_DATA/missing_data/data_list_{index}_step1.pt')

# Node features (mixing weights), community labels and in-degrees of all graphs,
# concatenated along the node dimension.
all_x_pre = torch.cat([g.x for g in dataset], dim=0)
all_y_pre = torch.cat([g.y for g in dataset], dim=0)
all_degree = torch.cat([degree(g.edge_index[1], num_nodes=g.x.shape[0]) for g in dataset], dim=0)

In [None]:
from torch_geometric.utils import *
# Number of nodes and edges of every sampled graph.
all_num_nodes = []
all_num_edges = []
for i in range(num_graphs//100):
    # The sampled datasets are stored in chunks named 100, 200, ... (see the
    # commented-out writer above), so the stride is 100, not num_graphs.
    index = (i+1)*100
    dataset = torch.load(f'GENERATED_DATA/missing_data/{dataset_name}/data_list_{index}.pt')

    all_num_nodes.extend(g.x.shape[0] for g in dataset)
    all_num_edges.extend(g.edge_index.shape[1] for g in dataset)

print(np.mean(all_num_nodes), np.mean(all_num_edges), len(all_num_nodes))

In [None]:
# Histogram of the node degrees collected from the step-1 graphs,
# followed by the exact count of every distinct degree value.
plt.hist(all_degree.numpy(), bins=40)
print(torch.unique(all_degree, return_counts=True))

In [None]:
# Node features and labels of every sampled graph.
all_x = []
all_y = []
print(dataset_name)
for i in range(num_graphs//100):
    # The sampled datasets are stored in chunks named 100, 200, ... (see the
    # commented-out writer above), so the stride is 100, not num_graphs.
    index = (i+1)*100
    dataset = torch.load(f'GENERATED_DATA/missing_data/{dataset_name}/data_list_{index}.pt')

    all_x.extend(g.x for g in dataset)
    all_y.extend(g.y for g in dataset)

all_x = torch.cat(all_x, dim=0)
all_y = torch.cat(all_y, dim=0)

In [None]:
# NumPy copy-free view of the sampled node features, used for plotting.
data = all_x.numpy()
#plt.scatter(x=data[:10000, 10], y=data[:10000, 4], s=40, cmap='viridis')
#sns.kdeplot(x=data[:10000, 0], y=data[:10000, 4], s=40, cmap='viridis')
# Histogram of feature 0 over the first 20000 rows.
plt.hist(data[:20000, 0], bins=100)

## This code evaluates the entire network

In [None]:
import json
import os.path as osp
from torch_geometric.data import Batch

from pydgn.evaluation.config import Config
from pydgn.experiment.util import s2c

outer_k = 0
# The result folders are keyed by the two mixing percentages. The notebook
# defines them as `single_node_weight` and `struc_weight`; the previously used
# `per_comm_weight` is not defined anywhere and raised a NameError.
ASSESSMENT_FOLDER = f'GSPN_RESULTS/UNSUPERVISED/Synthetic_{single_node_weight}_{struc_weight}/synthetic_gaussian_nomask_2layer_SyntheticDataset/MODEL_ASSESSMENT'
OUTER_FOLD_BASE = 'OUTER_FOLD_'
SELECTION_FOLDER = 'MODEL_SELECTION'
WINNER_CONFIG = 'winner_config.json'

outer_folder = osp.join(ASSESSMENT_FOLDER, OUTER_FOLD_BASE + str(outer_k + 1))
config_fname = osp.join(outer_folder, SELECTION_FOLDER, WINNER_CONFIG)

# NOTE: this rebinds `dataset_name`, which the statistics cells above use with
# a different value; re-run the cell that defines it before re-running them.
dataset_name = f'SyntheticDataset_{single_node_weight}_{struc_weight}'
splits_filepath = f'DATA_SPLITS/{dataset_name}/SyntheticDataset_outer1_inner1.splits'
outer_folds = 1
inner_folds = 1

# Configuration selected by model selection for this outer fold.
with open(config_fname, 'r') as f:
    best_config = json.load(f)

config_with_metadata = Config(best_config['config'])

# Build the dataset getter from the class names stored in the configuration.
dataset_getter_class = s2c(config_with_metadata.dataset_getter)
dataset_getter = dataset_getter_class(config_with_metadata.data_root,
                                      splits_filepath,
                                      s2c(config_with_metadata.dataset_class),
                                      dataset_name,
                                      s2c(config_with_metadata.data_loader),
                                      config_with_metadata.data_loader_args,
                                      outer_folds,
                                      inner_folds)

dataset_getter.set_inner_k(0)
dataset_getter.set_outer_k(0)

# not really used
dataset_getter.set_exp_seed(0)

batch_size = 32
# Keep the loader order fixed: later cells concatenate per-batch results and
# compare them across models node by node.
shuffle = False

# Instantiate the Dataset
train_loader = dataset_getter.get_outer_train(batch_size=batch_size, shuffle=shuffle)
val_loader = dataset_getter.get_outer_val(batch_size=batch_size, shuffle=shuffle)
test_loader = dataset_getter.get_outer_test(batch_size=batch_size, shuffle=shuffle)

# Call this after the loaders: the datasets may need to be instantiated with additional parameters
dim_node_features = dataset_getter.get_dim_node_features()
dim_edge_features = dataset_getter.get_dim_edge_features()
dim_target = dataset_getter.get_dim_target()

In [None]:
import torch.nn as nn
from model_log import GSPN
# Use the fourth GPU when the machine has one, otherwise stay on the CPU
# instead of failing on a missing device.
device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cpu'
# `single_node_weight` replaces the undefined `per_comm_weight` (NameError).
ckpt = torch.load(f'GSPN_RESULTS/UNSUPERVISED/Synthetic_{single_node_weight}_{struc_weight}/synthetic_gaussian_nomask_2layer_SyntheticDataset/MODEL_ASSESSMENT/OUTER_FOLD_1/final_run1/best_checkpoint.pth', map_location='cpu')

model = GSPN(dim_node_features, 0, dim_node_features, None, config_with_metadata['supervised_config'])

# Restore the trained weights, then switch to evaluation mode on `device`.
model_state = ckpt['model_state']
model.load_state_dict(model_state)
model.to(device)
model.eval()

In [None]:
# Per-batch results of the model on the test set, concatenated at the end.
masked_nodes = []
# NOTE(review): nothing below fills this list, so the later
# `x[non_masked_nodes] = 0.` statements select no rows — confirm whether it
# should be derived from `masked_nodes`.
non_masked_nodes = []
x = []
x_imputed = []
edge_index = []
perc_masked_features = []
log_lik = []

curr_node_id = 0

for batch in test_loader:
    # Move data to device
    batch.to(device)

    output, embs, extra = model.forward(batch)

    # Shift the node ids so that the edges of different batches do not
    # collide once everything is concatenated.
    edge_index.append(batch.edge_index.to('cpu') + curr_node_id)
    curr_node_id += batch.num_nodes

    perc_masked_features.append(batch.perc_masked_features.to('cpu'))

    # The previous version also called .to('cpu') on `batch`, `embs`, `output`
    # and the entries of `extra` without using the results; everything that is
    # kept is moved explicitly below.
    log_lik_batch, _, _, x_batch, x_imputed_batch, masked_nodes_batch, _, _, _, _ = extra

    # Detach before storing so that the autograd graph of every batch can be
    # released instead of being kept alive for the whole test set.
    log_lik.append(log_lik_batch.detach().to('cpu'))
    x.append(x_batch.detach().to('cpu'))
    x_imputed.append(x_imputed_batch.detach().to('cpu'))
    # NOTE(review): assumes this is a boolean mask; if it holds node indices
    # it would need the same `curr_node_id` offset as the edges — confirm.
    masked_nodes.append(masked_nodes_batch.detach().to('cpu'))

log_lik = torch.cat(log_lik, dim=0)
x = torch.cat(x, dim=0)
x_imputed = torch.cat(x_imputed, dim=0)
masked_nodes = torch.cat(masked_nodes, dim=0)
edge_index = torch.cat(edge_index, dim=1)
perc_masked_features = torch.cat(perc_masked_features, dim=0)

In [None]:
# Move the model back to the CPU once the inference loop is done.
model.to('cpu')

In [None]:
## Using the model's prediction, compute MSE for missing features and bin the results according to percentage of masked features for each node

In [None]:
# Show the shape of `x` after the per-batch tensors were concatenated.
x.shape

In [None]:
# Work on copies; `x_mvi` is the input of the mean-value-imputation baseline.
x, x_imputed = x.clone(), x_imputed.clone()
x_mvi = x.clone()
degree_vec = degree(edge_index[1], num_nodes=x.shape[0])

# Ensure MSE is 0 for the non-masked nodes by zeroing both sides of the comparison.
# NOTE(review): `non_masked_nodes` is still the empty list created in the
# inference cell, so this appears to select no rows — confirm.
for _tensor in (x, x_imputed):
    _tensor[non_masked_nodes] = 0.

# mean value imputation: blank the masked entries, take the per-column mean of
# what is left and write it back into the masked entries.
x_mvi[masked_nodes] = torch.nan
column_mean = torch.nanmean(x_mvi, dim=0, keepdim=True)
mv = column_mean.repeat(x.shape[0], 1)
x_mvi[masked_nodes] = mv[masked_nodes]

# Squared error averaged over the feature dimension, one value per node.
mse_per_vertex_mvi = torch.nn.functional.mse_loss(x, x_mvi, reduction='none').mean(dim=1)
mse_per_vertex = torch.nn.functional.mse_loss(x, x_imputed, reduction='none').mean(dim=1)
mse_per_vertex.shape



In [None]:
# One boolean mask per 5%-wide bin of masked-feature percentage.
# NOTE(review): both edges are inclusive, so a node whose percentage falls
# exactly on a boundary is counted in two neighbouring bins. Kept as is so the
# bins match the ones used for the other models.
bin_masks = []
for i in range(1, 21):
    min_perc = (i-1)/20
    max_perc = i/20
    bin_masks.append(torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc)))

xs = []
ys = []
ys_mvi = []
for i, bin_mask in enumerate(bin_masks, start=1):
    # Mean log-likelihood of the model in this bin.
    bin_values = log_lik[bin_mask].mean().detach().numpy()
    # The mean-value-imputation baseline has no log-likelihood, so use its
    # log mean squared error. The previous version took the log of the
    # model's own mean log-likelihood here, which is neither a baseline
    # value nor defined for negative means.
    bin_values_mvi = mse_per_vertex_mvi[bin_mask].mean().log().detach().numpy()

    #bin_values = mse_per_vertex[bin_mask].mean().detach().numpy()
    #bin_values_mvi = mse_per_vertex_mvi[bin_mask].mean().log().detach().numpy()

    xs.append(i)
    ys.append(bin_values)
    ys_mvi.append(bin_values_mvi)

xs = torch.tensor(xs).numpy()

unique_degrees = torch.unique(degree_vec).int().numpy().tolist()

# heatmap[d, b]: mean log-likelihood of the nodes with the d-th distinct degree
# whose masking percentage falls in bin b (NaN when no node qualifies).
heatmap = torch.zeros(len(unique_degrees), 20).numpy()
for deg_id, deg in enumerate(unique_degrees):
    degree_mask = (degree_vec == deg)
    for col, bin_mask in enumerate(bin_masks):
        cell_mask = torch.logical_and(bin_mask, degree_mask)
        heatmap[deg_id, col] = log_lik[cell_mask].mean().detach().numpy()

In [None]:
# Average of all `log_lik` values collected over the test loader.
log_lik.mean()

In [None]:
# Rows are the distinct node degrees (lowest at the bottom), columns the
# masking-percentage bins.
ax = sns.heatmap(heatmap)
ax.invert_yaxis()
# Heatmap cell k spans [k, k+1], so the tick of a cell belongs at k + 0.5;
# integer positions put every label on a cell border.
ax.set_yticks(np.arange(len(unique_degrees)) + 0.5)
ax.set_yticklabels(unique_degrees)

x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]
ax.set_xticks(np.arange(len(x_ticks_labels)) + 0.5)
ax.set_xticklabels(x_ticks_labels, rotation=-30)

In [None]:
import torch.nn as nn
from model_log import GSPN
# Use the fourth GPU when the machine has one, otherwise stay on the CPU
# instead of failing on a missing device.
device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cpu'

outer_k = 0
# `single_node_weight` replaces the undefined `per_comm_weight` (NameError).
ASSESSMENT_FOLDER = f'GSPN_RESULTS/UNSUPERVISED/Synthetic_{single_node_weight}_{struc_weight}/synthetic_gaussian_nomask_SyntheticDataset/MODEL_ASSESSMENT'
ckpt = torch.load(f'{ASSESSMENT_FOLDER}/OUTER_FOLD_1/final_run1/best_checkpoint.pth', map_location='cpu')

outer_folder = osp.join(ASSESSMENT_FOLDER, OUTER_FOLD_BASE + str(outer_k + 1))
config_fname = osp.join(outer_folder, SELECTION_FOLDER, WINNER_CONFIG)

# Configuration selected by model selection for this experiment.
with open(config_fname, 'r') as f:
    best_config = json.load(f)

config_with_metadata = Config(best_config['config'])

model = GSPN(dim_node_features, 0, dim_node_features, None, config_with_metadata['supervised_config'])

# Restore the trained weights, then switch to evaluation mode on `device`.
model_state = ckpt['model_state']
model.load_state_dict(model_state)
model.to(device)
model.eval()

In [None]:
# Per-batch results of the model on the test set, concatenated at the end.
masked_nodes = []
# NOTE(review): nothing below fills this list, so the later
# `x[non_masked_nodes] = 0.` statements select no rows — confirm whether it
# should be derived from `masked_nodes`.
non_masked_nodes = []
x = []
x_imputed = []
edge_index = []
perc_masked_features = []
log_lik_gmm = []

curr_node_id = 0

for batch in test_loader:
    # Move data to device
    batch.to(device)

    output, embs, extra = model.forward(batch)

    # Shift the node ids so that the edges of different batches do not
    # collide once everything is concatenated.
    edge_index.append(batch.edge_index.to('cpu') + curr_node_id)
    curr_node_id += batch.num_nodes

    perc_masked_features.append(batch.perc_masked_features.to('cpu'))

    # The previous version also called .to('cpu') on `batch`, `embs`, `output`
    # and the entries of `extra` without using the results; everything that is
    # kept is moved explicitly below.
    log_lik_batch, _, _, x_batch, x_imputed_batch, masked_nodes_batch, _, _, _, _ = extra

    # Detach before storing so that the autograd graph of every batch can be
    # released instead of being kept alive for the whole test set.
    log_lik_gmm.append(log_lik_batch.detach().to('cpu'))
    x.append(x_batch.detach().to('cpu'))
    x_imputed.append(x_imputed_batch.detach().to('cpu'))
    # NOTE(review): assumes this is a boolean mask; if it holds node indices
    # it would need the same `curr_node_id` offset as the edges — confirm.
    masked_nodes.append(masked_nodes_batch.detach().to('cpu'))

log_lik_gmm = torch.cat(log_lik_gmm, dim=0)
x = torch.cat(x, dim=0)
x_imputed = torch.cat(x_imputed, dim=0)
masked_nodes = torch.cat(masked_nodes, dim=0)
edge_index = torch.cat(edge_index, dim=1)
perc_masked_features = torch.cat(perc_masked_features, dim=0)

In [None]:
# Move the model back to the CPU once the inference loop is done.
model.to('cpu')

In [None]:
x = x.clone()
x_imputed = x_imputed.clone()
x_mvi = x.clone()
# Recompute the degrees from the edges collected for THIS model instead of
# silently reusing the vector left behind by an earlier cell.
degree_vec = degree(edge_index[1], num_nodes=x.shape[0])

# Ensure MSE is 0 for the non-masked nodes by zeroing both sides of the comparison.
# NOTE(review): `non_masked_nodes` is still the empty list created in the
# inference cell, so this appears to select no rows — confirm.
x[non_masked_nodes] = 0.
x_imputed[non_masked_nodes] = 0.

# Squared error averaged over the feature dimension, one value per node.
mse_per_vertex_gmm = torch.nn.functional.mse_loss(x, x_imputed, reduction='none').mean(dim=1)

# One boolean mask per 5%-wide bin of masked-feature percentage.
# NOTE(review): both edges are inclusive, so a node whose percentage falls
# exactly on a boundary is counted in two neighbouring bins. Kept as is so the
# bins match the ones used for the other models.
bin_masks = []
for i in range(1, 21):
    min_perc = (i-1)/20
    max_perc = i/20
    bin_masks.append(torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc)))

xs = []
ys_gmm = []
for i, bin_mask in enumerate(bin_masks, start=1):
    # Mean log-likelihood of the model in this bin.
    bin_values_gmm = log_lik_gmm[bin_mask].mean().detach().numpy()
    #bin_values_gmm = mse_per_vertex_gmm[bin_mask].mean().detach().numpy()

    xs.append(i)
    ys_gmm.append(bin_values_gmm)

xs = torch.tensor(xs).numpy()

unique_degrees = torch.unique(degree_vec).int().numpy().tolist()

# heatmap_gmm[d, b]: mean log-likelihood of the nodes with the d-th distinct
# degree whose masking percentage falls in bin b (NaN when no node qualifies).
heatmap_gmm = torch.zeros(len(unique_degrees), 20).numpy()
for deg_id, deg in enumerate(unique_degrees):
    degree_mask = (degree_vec == deg)
    for col, bin_mask in enumerate(bin_masks):
        cell_mask = torch.logical_and(bin_mask, degree_mask)
        heatmap_gmm[deg_id, col] = log_lik_gmm[cell_mask].mean().detach().numpy()

In [None]:
x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]

# First panel: heatmap of this model on the current figure.
# Second panel: difference `heatmap - heatmap_gmm` on a new figure.
for panel_id, values in enumerate((heatmap_gmm, heatmap-heatmap_gmm)):
    if panel_id > 0:
        plt.figure()
    ax = sns.heatmap(values)
    ax.invert_yaxis()
    # Heatmap cell k spans [k, k+1], so the tick of a cell belongs at k + 0.5;
    # integer positions put every label on a cell border.
    ax.set_yticks(np.arange(len(unique_degrees)) + 0.5)
    ax.set_yticklabels(unique_degrees)
    ax.set_xticks(np.arange(len(x_ticks_labels)) + 0.5)
    ax.set_xticklabels(x_ticks_labels, rotation=-30)

In [None]:
from baseline_mask import *
# Baseline model, built with an empty configuration.
model = MeanAggregation(dim_node_features, 0, dim_node_features, None, {})
model.to(device)

# Per-batch results of the baseline on the test set, concatenated at the end.
masked_nodes = []
# NOTE(review): nothing below fills this list, so the later
# `x[non_masked_nodes] = 0.` statements select no rows — confirm whether it
# should be derived from `masked_nodes`.
non_masked_nodes = []
x = []
x_imputed = []
edge_index = []
perc_masked_features = []

curr_node_id = 0

for batch in test_loader:
    # Move data to device
    batch.to(device)

    output, embs, extra = model.forward(batch)

    # Shift the node ids so that the edges of different batches do not
    # collide once everything is concatenated.
    edge_index.append(batch.edge_index.to('cpu') + curr_node_id)
    curr_node_id += batch.num_nodes

    perc_masked_features.append(batch.perc_masked_features.to('cpu'))

    # The previous version also called .to('cpu') on `batch`, `embs`, `output`
    # and the entries of `extra` without using the results; everything that is
    # kept is moved explicitly below.
    _, _, _, x_batch, x_imputed_batch, masked_nodes_batch, _, _, _, _ = extra

    # Detach before storing so that no autograd graph is kept alive.
    x.append(x_batch.detach().to('cpu'))
    x_imputed.append(x_imputed_batch.detach().to('cpu'))
    # NOTE(review): assumes this is a boolean mask; if it holds node indices
    # it would need the same `curr_node_id` offset as the edges — confirm.
    masked_nodes.append(masked_nodes_batch.detach().to('cpu'))

x = torch.cat(x, dim=0)
x_imputed = torch.cat(x_imputed, dim=0)
masked_nodes = torch.cat(masked_nodes, dim=0)
edge_index = torch.cat(edge_index, dim=1)
perc_masked_features = torch.cat(perc_masked_features, dim=0)

In [None]:
# Move the model back to the CPU once the inference loop is done.
model.to('cpu')

In [None]:
x = x.clone()
x_imputed = x_imputed.clone()
x_mvi = x.clone()
# Recompute the degrees from the edges collected for THIS model instead of
# silently reusing the vector left behind by an earlier cell.
degree_vec = degree(edge_index[1], num_nodes=x.shape[0])

# Ensure MSE is 0 for the non-masked nodes by zeroing both sides of the comparison.
# NOTE(review): `non_masked_nodes` is still the empty list created in the
# inference cell, so this appears to select no rows — confirm.
x[non_masked_nodes] = 0.
x_imputed[non_masked_nodes] = 0.

# Squared error averaged over the feature dimension, one value per node.
mse_per_vertex_baseline = torch.nn.functional.mse_loss(x, x_imputed, reduction='none').mean(dim=1)

# One boolean mask per 5%-wide bin of masked-feature percentage.
# NOTE(review): both edges are inclusive, so a node whose percentage falls
# exactly on a boundary is counted in two neighbouring bins. Kept as is so the
# bins match the ones used for the other models.
bin_masks = []
for i in range(1, 21):
    min_perc = (i-1)/20
    max_perc = i/20
    bin_masks.append(torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc)))

xs = []
ys_baseline = []
for i, bin_mask in enumerate(bin_masks, start=1):
    bin_values_baseline = mse_per_vertex_baseline[bin_mask].mean().detach().numpy()

    xs.append(i)
    ys_baseline.append(bin_values_baseline)

xs = torch.tensor(xs).numpy()

unique_degrees = torch.unique(degree_vec).int().numpy().tolist()

# mse_heatmap_baseline[d, b]: mean per-node MSE of the nodes with the d-th
# distinct degree whose masking percentage falls in bin b (NaN when empty).
mse_heatmap_baseline = torch.zeros(len(unique_degrees), 20).numpy()
for deg_id, deg in enumerate(unique_degrees):
    degree_mask = (degree_vec == deg)
    for col, bin_mask in enumerate(bin_masks):
        cell_mask = torch.logical_and(bin_mask, degree_mask)
        mse_heatmap_baseline[deg_id, col] = mse_per_vertex_baseline[cell_mask].mean().detach().numpy()

# First two imputed features of the first 1000 nodes; detached so the tensors
# can be converted for plotting even if they carry gradients.
# NOTE(review): needs at least two feature columns — confirm for this dataset.
plt.scatter(x_imputed[:1000, 0].detach(), x_imputed[:1000, 1].detach())

In [None]:
# Overall MSE between `x` and `x_imputed` for the model evaluated last.
# reduction='mean' already returns a scalar, so the trailing .mean(dim=0)
# leaves the value unchanged.
neigh_aggr_score = torch.nn.functional.mse_loss(x, x_imputed, reduction='mean').mean(dim=0)
print(f'DGN score is {neigh_aggr_score}')

In [None]:
# Rows are the distinct node degrees (lowest at the bottom), columns the
# masking-percentage bins.
ax = sns.heatmap(mse_heatmap_baseline)
ax.invert_yaxis()
# Heatmap cell k spans [k, k+1], so the tick of a cell belongs at k + 0.5;
# integer positions put every label on a cell border.
ax.set_yticks(np.arange(len(unique_degrees)) + 0.5)
ax.set_yticklabels(unique_degrees)

x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]
ax.set_xticks(np.arange(len(x_ticks_labels)) + 0.5)
ax.set_xticklabels(x_ticks_labels, rotation=-30)

In [None]:
# Recent Matplotlib releases ship the seaborn styles under a 'seaborn-v0_8-'
# prefix and no longer accept the old name; use whichever one is installed.
style_name = 'seaborn-colorblind'
if 'seaborn-v0_8-colorblind' in plt.style.available:
    style_name = 'seaborn-v0_8-colorblind'
plt.style.use(style_name)

fig, ax = plt.subplots(1,1) 
# Two bars per masking bin: GMM on the left, BGC centred on the tick.
ax.bar(xs-0.2, ys_gmm, width=0.2, label='GMM', fill=False, hatch='')
ax.bar(xs, ys, width=0.2, label='BGC', fill=True, hatch='///')
#plt.bar(xs+0.2, ys_baseline, width=0.2, label='Baseline')

x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]
ax.set_xticks(np.arange(1, 21))
ax.set_xticklabels(x_ticks_labels, rotation=-90)
# `ys` and `ys_gmm` hold per-bin means of the log-likelihood, not MSE values.
ax.set_ylabel('Mean log-likelihood')
ax.set_xlabel('Percentage of masked features per node')
ax.legend()
plt.tight_layout()
plt.savefig('perf_vs_percentage_masking.png', dpi=350)

In [None]:
# Per-bin values plotted for the 'GMM' bars.
ys_gmm

In [None]:
# Per-bin values plotted for the 'BGC' bars.
ys

In [None]:
# Number of entries of `perc_masked_features` that fall in each 5%-wide bin
# (both bin edges inclusive).
for i in range(1, 21):
    min_perc, max_perc = (i-1)/20, i/20
    above_lower = perc_masked_features >= min_perc
    below_upper = perc_masked_features <= max_perc
    bin_mask = torch.logical_and(above_lower, below_upper)
    print(bin_mask.sum())