In [1]:
import warnings
# NOTE(review): this silences *every* warning category for the whole kernel
# session (including deprecation warnings from torch/scipy); consider narrowing
# it to the specific category/module that is noisy once it is identified.
warnings.filterwarnings('ignore')

import pickle
import numpy as np
import scipy.sparse as sp
from scipy.sparse import load_npz

import torch

from cell.utils import link_prediction_performance
from cell.cell import Cell, EdgeOverlapCriterion, LinkPredictionCriterion
from cell.graph_statistics import compute_graph_statistics

# Load the training graph and the validation/test edges (same split as in the paper)

In [2]:
# Training adjacency matrix (sparse) and the fixed validation/test edge split.
train_graph = load_npz('./data/CORA-ML_train.npz')

# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load split files that come from a trusted source.
with open('./data/link_prediction.p', 'rb') as split_file:
    (val_ones,
     val_zeros,
     test_ones,
     test_zeros) = pickle.load(split_file)

In [11]:
# Inspect the loaded training adjacency matrix (shape, dtype, number of stored entries).
train_graph

<2810x2810 sparse matrix of type '<class 'numpy.float64'>'
	with 13566 stored elements in Compressed Sparse Row format>

In [16]:
# Spot check of how the validation positives are laid out: print the entry
# halfway through the array next to the first entry. For the shipped split
# they are reverses of each other, which suggests each edge is stored in both
# directions (only one pair is inspected here — a spot check, not a proof).
mirror_idx = len(val_ones) // 2
print(val_ones[mirror_idx], val_ones[0])

[1914 2693] [2693 1914]


## Edge overlap criterion

In [3]:
# Initialize the model with the edge-overlap (EO) criterion. Judging by the log
# below, the callback reports edge overlap every 10 steps and training stops once
# it exceeds the 0.5 limit.
# Seed the RNGs first so model initialization, training and the sampled graph
# (and hence the statistics below) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
model = Cell(A=train_graph,
             H=9,  # passed through to Cell; presumably the rank of its low-rank factorization — TODO confirm
             callbacks=[EdgeOverlapCriterion(invoke_every=10, edge_overlap_limit=.5)])

In [4]:
# Train for up to 200 steps with Adam; the edge-overlap callback may end training earlier.
model.train(
    steps=200,
    optimizer_fn=torch.optim.Adam,
    optimizer_args=dict(lr=0.1, weight_decay=1e-7),
)

Step:  10/200 Loss: 6.00380 Edge-Overlap: 0.035 Total-Time: 3
Step:  20/200 Loss: 4.07511 Edge-Overlap: 0.281 Total-Time: 6
Step:  30/200 Loss: 3.32519 Edge-Overlap: 0.442 Total-Time: 10
Step:  40/200 Loss: 2.96738 Edge-Overlap: 0.551 Total-Time: 13


In [5]:
# Sample a synthetic graph from the model trained above.
generated_graph = model.sample_graph()

In [6]:
# Summary statistics of the generated graph; compare with the training graph's statistics in the next cell.
compute_graph_statistics(generated_graph)

{'d_max': 199.0,
 'd_min': 1.0,
 'd': 4.8277580071174375,
 'LCC': 2798,
 'wedge_count': 80653.0,
 'claw_count': 1748949.0,
 'triangle_count': 1312,
 'square_count': 6372.0,
 'power_law_exp': 1.8271905683271248,
 'gini': 0.4563266025646071,
 'rel_edge_distr_entropy': 0.9497564252141653,
 'assortativity': -0.0732222705079979,
 'clustering_coefficient': 0.04880165647899024,
 'cpl': 5.270692451767936}

In [7]:
# The same statistics for the training graph, serving as the reference for the generated graph.
compute_graph_statistics(train_graph)

{'d_max': 238.0,
 'd_min': 1.0,
 'd': 4.8277580071174375,
 'LCC': 2810,
 'wedge_count': 101747.0,
 'claw_count': 3033514.0,
 'triangle_count': 2802,
 'square_count': 14268.0,
 'power_law_exp': 1.8550648593086239,
 'gini': 0.4825742921255409,
 'rel_edge_distr_entropy': 0.9406652031225717,
 'assortativity': -0.07626405450439543,
 'clustering_coefficient': 0.08261668648707088,
 'cpl': 5.630006245811316}

## Link-prediction (validation) criterion

In [8]:
# Initialize the model with the link-prediction (LP) criterion. Per the training
# log below, it scores the validation edges (ROC-AUC / average precision) every
# 2 steps; `max_patience=3` presumably bounds how many evaluations without
# improvement are tolerated before stopping — confirm in the callback.
# Re-seed so this run is reproducible independently of the EO run above.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
model = Cell(A=train_graph,
             H=9,
             callbacks=[LinkPredictionCriterion(invoke_every=2,
                                                val_ones=val_ones,
                                                val_zeros=val_zeros,
                                                max_patience=3)])

In [17]:
# Sizes of the positive/negative validation and test edge sets.
# Guard the invariant the evaluation relies on: each split pairs as many
# non-edges (zeros) as edges (ones).
split_sizes = (len(val_ones), len(val_zeros), len(test_ones), len(test_zeros))
assert split_sizes[0] == split_sizes[1], 'validation split is not balanced'
assert split_sizes[2] == split_sizes[3], 'test split is not balanced'
print(*split_sizes)

1596 1596 800 800


In [9]:
# Train for up to 200 steps with Adam; the link-prediction callback may end training earlier.
model.train(
    steps=200,
    optimizer_fn=torch.optim.Adam,
    optimizer_args=dict(lr=0.1, weight_decay=1e-6),
)

Step:   2/200 Loss: 7.92516 ROC-AUC Score: 0.579 Average Precision: 0.572 Total-Time: 0
Step:   4/200 Loss: 7.71899 ROC-AUC Score: 0.640 Average Precision: 0.642 Total-Time: 1
Step:   6/200 Loss: 7.29670 ROC-AUC Score: 0.681 Average Precision: 0.694 Total-Time: 1
Step:   8/200 Loss: 6.69369 ROC-AUC Score: 0.722 Average Precision: 0.741 Total-Time: 2
Step:  10/200 Loss: 6.03576 ROC-AUC Score: 0.778 Average Precision: 0.798 Total-Time: 2
Step:  12/200 Loss: 5.45817 ROC-AUC Score: 0.830 Average Precision: 0.848 Total-Time: 3
Step:  14/200 Loss: 5.03277 ROC-AUC Score: 0.866 Average Precision: 0.881 Total-Time: 3
Step:  16/200 Loss: 4.69655 ROC-AUC Score: 0.882 Average Precision: 0.898 Total-Time: 4
Step:  18/200 Loss: 4.38538 ROC-AUC Score: 0.891 Average Precision: 0.906 Total-Time: 4
Step:  20/200 Loss: 4.12517 ROC-AUC Score: 0.899 Average Precision: 0.911 Total-Time: 4
Step:  22/200 Loss: 3.91924 ROC-AUC Score: 0.908 Average Precision: 0.918 Total-Time: 5
Step:  24/200 Loss: 3.73268 ROC-

#### Test-set link prediction performance (ROC-AUC score and average precision)

In [10]:
# Evaluate on the held-out *test* edges: the function's keyword arguments are
# named `val_ones` / `val_zeros`, but here they receive the test split.
# NOTE(review): this reads the private attribute `_scores_matrix`; whether it
# holds the best-validation or the last-step scores after early stopping is not
# visible here — confirm in Cell.
roc_auc, avg_precision = link_prediction_performance(
    scores_matrix=model._scores_matrix,
    val_ones=test_ones,
    val_zeros=test_zeros,
)
roc_auc, avg_precision

(0.9347875, 0.9434454020897184)