In [1]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import scipy.sparse as sp
from scipy.sparse import load_npz

import torch

from cell.utils import train_val_test_split, link_prediction_performance
from cell.cell import Cell, EdgeOverlapCriterion, LinkPredictionCriterion
from cell.graph_statistics import compute_graph_statistics

## Load Graph and Split Edges into Train / Validation / Test

In [2]:
# Load the observed CORA-ML graph (scipy sparse matrix stored as .npz;
# presumably its adjacency matrix — it is passed to Cell as `A` below).
A_obs = load_npz('./data/CORA-ML.npz')

# Fractions of edges held out for validation and for the final test.
val_share = 0.1
test_share = 0.05
seed = 10

# Returns the training graph plus held-out validation/test sets
# (presumably *_ones = held-out edges, *_zeros = sampled non-edges).
train_graph, val_ones, val_zeros, test_ones, test_zeros = train_val_test_split(A_obs,
                                                                               val_share,
                                                                               test_share,
                                                                               seed)

# `seed` above only controls the edge split. Also seed the global numpy and
# torch generators so that model initialisation, training and graph sampling
# in later cells are reproducible under Restart & Run All. This is done after
# the split so the split itself is unaffected.
np.random.seed(seed)
torch.manual_seed(seed)

## Edge Overlap Criterion

In [3]:
# Cell model fitted on the training graph. The edge-overlap (EO) callback runs
# every 10 steps; judging by the log of the next cell, training stops once the
# edge overlap passes the 0.5 limit.
model = Cell(
    A=train_graph,
    H=9,
    callbacks=[
        EdgeOverlapCriterion(edge_overlap_limit=0.5, invoke_every=10),
    ],
)

In [4]:
# Train with Adam for at most 200 steps; the EO callback may end training
# earlier (the logged run stopped at step 40).
adam_settings = dict(lr=0.1, weight_decay=1e-7)
model.train(
    steps=200,
    optimizer_fn=torch.optim.Adam,
    optimizer_args=adam_settings,
)

Step:  10/200 Loss: 5.96087 Edge-Overlap: 0.033 Total-Time: 4
Step:  20/200 Loss: 4.05810 Edge-Overlap: 0.284 Total-Time: 9
Step:  30/200 Loss: 3.31627 Edge-Overlap: 0.440 Total-Time: 13
Step:  40/200 Loss: 2.97212 Edge-Overlap: 0.548 Total-Time: 18


In [5]:
# Draw a synthetic graph from the trained EO-criterion model.
# NOTE(review): sampling is presumably stochastic; re-running this cell may change the statistics below.
generated_graph = model.sample_graph()

In [6]:
# Summary statistics of the sampled graph, for comparison with the training graph.
generated_stats = compute_graph_statistics(generated_graph)
generated_stats

{'d_max': 171.0,
 'd_min': 1.0,
 'd': 4.8284697508896794,
 'LCC': 2808,
 'wedge_count': 77574.0,
 'claw_count': 1355036.0,
 'triangle_count': 1468,
 'square_count': 7467.0,
 'power_law_exp': 1.8207420386565856,
 'gini': 0.4509338489558854,
 'rel_edge_distr_entropy': 0.9506692508260353,
 'assortativity': -0.07828604309967814,
 'clustering_coefficient': 0.0032500981523738114,
 'cpl': 5.229567709346975}

In [7]:
# The same statistics for the training graph, the reference the sample should resemble.
train_stats = compute_graph_statistics(train_graph)
train_stats

{'d_max': 209.0,
 'd_min': 1.0,
 'd': 4.8284697508896794,
 'LCC': 2810,
 'wedge_count': 95802.0,
 'claw_count': 2301166.0,
 'triangle_count': 2867,
 'square_count': 14969.0,
 'power_law_exp': 1.8585544293441962,
 'gini': 0.4855507830860135,
 'rel_edge_distr_entropy': 0.9408725499709523,
 'assortativity': -0.07557187056300232,
 'clustering_coefficient': 0.0037376703810155375,
 'cpl': 5.672324214617732}

## Link-Prediction Criterion (early stopping on the validation edges)

In [8]:
# Fresh Cell model whose training is monitored by link-prediction performance
# on the validation sets. The callback runs every 2 steps with a patience of 3
# evaluations. This rebinds `model`; the test evaluation at the end uses it.
lp_criterion = LinkPredictionCriterion(
    val_ones=val_ones,
    val_zeros=val_zeros,
    invoke_every=2,
    max_patience=3,
)
model = Cell(A=train_graph, H=9, callbacks=[lp_criterion])

In [9]:
# Train with Adam for at most 200 steps; the LP callback may stop training early.
# NOTE(review): weight_decay is 1e-6 here vs 1e-7 in the EO run — confirm this is intentional.
lp_adam_settings = dict(lr=0.1, weight_decay=1e-6)
model.train(
    steps=200,
    optimizer_fn=torch.optim.Adam,
    optimizer_args=lp_adam_settings,
)

Step:   2/200 Loss: 7.92695 ROC-AUC Score: 0.607 Average Precision: 0.606 Total-Time: 0
Step:   4/200 Loss: 7.72123 ROC-AUC Score: 0.683 Average Precision: 0.686 Total-Time: 1
Step:   6/200 Loss: 7.29660 ROC-AUC Score: 0.727 Average Precision: 0.733 Total-Time: 2
Step:   8/200 Loss: 6.68226 ROC-AUC Score: 0.764 Average Precision: 0.773 Total-Time: 2
Step:  10/200 Loss: 6.00534 ROC-AUC Score: 0.803 Average Precision: 0.815 Total-Time: 3
Step:  12/200 Loss: 5.41419 ROC-AUC Score: 0.841 Average Precision: 0.852 Total-Time: 4
Step:  14/200 Loss: 4.97864 ROC-AUC Score: 0.868 Average Precision: 0.878 Total-Time: 4
Step:  16/200 Loss: 4.65078 ROC-AUC Score: 0.881 Average Precision: 0.891 Total-Time: 5
Step:  18/200 Loss: 4.35771 ROC-AUC Score: 0.887 Average Precision: 0.898 Total-Time: 6
Step:  20/200 Loss: 4.10243 ROC-AUC Score: 0.894 Average Precision: 0.906 Total-Time: 6
Step:  22/200 Loss: 3.90202 ROC-AUC Score: 0.903 Average Precision: 0.915 Total-Time: 7
Step:  24/200 Loss: 3.72894 ROC-

In [10]:
# Final evaluation on the held-out *test* sets. The function's keyword names
# are val_ones/val_zeros, but here they intentionally receive the test sets.
# The result order presumably matches the callback log above (ROC-AUC, then AP).
# NOTE(review): `_scores_matrix` is a private attribute of Cell — confirm there
# is no public accessor before relying on it.
test_roc_auc, test_avg_precision = link_prediction_performance(
    scores_matrix=model._scores_matrix,
    val_ones=test_ones,
    val_zeros=test_zeros,
)
test_roc_auc, test_avg_precision

(0.945433759838192, 0.945153138360737)