In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Climb one directory at a time until the working directory's path no longer
# mentions 'notebooks'. Presumably this lands on the project root so that the
# `src.*` imports and the relative output paths used later resolve -- confirm.
while True:
    if 'notebooks' not in os.getcwd():
        break
    os.chdir('..')

import torch
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from sklearn.metrics import roc_auc_score
import logging
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.offline as pyo
import numpy as np

from src.torch_geo_models import GraphSAGE, LinkPredictor
from src.data.gamma.arxiv import load_data, get_val_test_edges, prepare_adjencency, get_edge_index_from_adjencency
from src.train.gamma import GammaGraphSage

Using backend: pytorch


In [3]:
# Root logger: INFO level, each record prefixed with a timestamp and its level.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s : %(message)s',
)

In [4]:
# Sanity check: display whether PyTorch can see a CUDA device.
torch.cuda.is_available()

True

In [5]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# Built in a single expression so `device` is only ever a torch.device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

## Data Loading

In [6]:
# Load the graph through the project's loader.
data = load_data()

# Split off the validation/test edge sets together with their negative
# counterparts. With remove_from_data=True the held-out edges are presumably
# dropped from `data` so they cannot leak into training -- confirm in
# get_val_test_edges.
(
    data,
    edges_val,
    edges_test,
    neg_edges_val,
    neg_edges_test,
) = get_val_test_edges(data, remove_from_data=True, device=device)

# Build the adjacency (requested as symmetric), then derive from it the edge
# index that is later passed to training, using the selected device.
data = prepare_adjencency(data, to_symmetric=True)
edge_index = get_edge_index_from_adjencency(data, device)

## Training

In [None]:
# Train 30 runs; each one builds a fresh GammaGraphSage tagged with its run id.
for run in range(30):
    gamma = GammaGraphSage(device, data.num_nodes, run=run)
    # Rebinding `gamma` drops the reference to the previous run's model, so
    # ask PyTorch to release its cached CUDA memory before training again.
    torch.cuda.empty_cache()
    gamma.train(
        edge_index, edges_val, edges_test,
        neg_edges_val, neg_edges_test,
        data.adj_t, data.y,
    )

2022-05-24 22:41:53 - INFO : Run: 0000, Epoch: 0001, Train Loss: 1.3988, Valid loss: 1.1481, Test loss: 1.1483, Train AUC: 0.5862, Valid AUC: 0.5845, Test AUC: 0.5808


## Results

In [None]:
# Load the recorded metrics table through GammaGraphSage.read_metrics(),
# print its shape and preview the first rows.
metrics = GammaGraphSage.read_metrics()
print(metrics.shape)
metrics.head()

### Summarize metrics per epoch

In [None]:
# Every column from the third onward is treated as a metric column. The first
# two are presumably the 'run' and 'epoch' identifiers -- TODO confirm against
# the column order returned by read_metrics().
metrics_cols = metrics.columns[2:]
metrics_cols

In [None]:
# Train-loss entries equal to the literal string 'None' become NaN, then the
# column is cast to float. Only this column is processed: running replace()
# over the whole frame first would copy every other column for no benefit.
metrics['loss_train'] = (
    metrics['loss_train']
    .replace('None', np.nan)
    .astype(float)
)

In [None]:
# Aggregate every metric column over all rows that share an epoch.
# Missing values are deliberately left as NaN: mean/std skip NaN on their own,
# whereas a -1 sentinel filled in beforehand would be averaged in as if it
# were a real measurement and would distort both statistics.
epoch_metrics = (
    metrics
    .groupby('epoch')[metrics_cols]
    .agg(['mean', 'std'])
)
# Flatten the (metric, statistic) column pairs into '<metric>_<statistic>'.
epoch_metrics.columns = [f'{metric}_{stat}' for metric, stat in epoch_metrics.columns]
epoch_metrics.head()

In [None]:
# Show the epochs with the highest mean validation AUC first.
epoch_metrics.sort_values('auc_val_mean', ascending=False).head()

In [None]:
plt.rcParams['font.size'] = '14'
fig, ax1 = plt.subplots(1, 1, figsize=(16, 10))

ax1.set_title('Gamma metrics')

# (legend label, column infix, rows to plot) for each data split.
# The train curves skip the first aggregated row, presumably because no train
# metrics are recorded for the first logged epoch -- confirm with the trainer.
split_series = (
    ('Train', 'train', slice(1, None)),
    ('Validation', 'val', slice(None)),
    ('Test', 'test', slice(None)),
)


def _plot_metric(ax, metric, linestyle):
    """Draw mean +/- std error bars of `metric` for every split on `ax`.

    Reads the '<metric>_<split>_mean' and '<metric>_<split>_std' columns of
    `epoch_metrics`, plotting the splits in the order Train, Validation, Test.
    """
    for label, split, rows in split_series:
        ax.errorbar(
            epoch_metrics.index[rows],
            epoch_metrics[f'{metric}_{split}_mean'][rows],
            yerr=epoch_metrics[f'{metric}_{split}_std'][rows],
            ls=linestyle,
            label=label)


# Left axis: losses drawn with solid lines.
_plot_metric(ax1, 'loss', '-')
ax1.legend(loc='upper right', title='Loss', bbox_to_anchor=[0.85, 1])
ax1.set_ylabel('Loss')
ax1.set_xlabel('Epoch')

# Right axis (shares the x axis): ROC AUC drawn with dotted lines.
ax2 = ax1.twinx()
_plot_metric(ax2, 'auc', ':')
ax2.legend(loc='upper right', title='ROC AUC', bbox_to_anchor=[1, 1])
ax2.set_ylabel('ROC AUC')

# Create the output directory first so saving also works on a fresh checkout.
figure_path = 'reports/images/gamma_training_metrics.pdf'
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path)

### Best model

In [None]:
# Order all recorded rows by validation AUC, best first, and keep the top one.
ranked_by_val_auc = metrics.sort_values('auc_val', ascending=False)
best_model_metrics = ranked_by_val_auc.iloc[0]
best_model_metrics

In [None]:
# Load the model saved for the best (run, epoch) pair.
# A row extracted with .iloc[0] is a single-dtype Series, so the integer
# run/epoch ids can come back as floats (e.g. 3.0) when float columns are
# present. Cast them to plain ints before handing them to load_model.
gamma = GammaGraphSage.load_model(
    int(best_model_metrics['run']),
    int(best_model_metrics['epoch']),
    device,
    data.num_nodes)

In [None]:
# Run a forward pass of the loaded model on the edge index used for training
# and the graph adjacency, and display the result.
gamma.forward(edge_index, data.adj_t)

In [None]:
# Evaluate the loaded model with the same argument list that was given to train().
gamma.eval(
    edge_index,
    edges_val, edges_test,
    neg_edges_val, neg_edges_test,
    data.adj_t, data.y,
)

In [None]:
# Inspect which class implements the loaded model's `predictor` attribute.
type(gamma.predictor)

In [None]:
# Reload the graph and rebuild its symmetric adjacency, this time without the
# validation/test edge split applied in the Data Loading section.
data = prepare_adjencency(load_data(), to_symmetric=True)

In [None]:
# Forward pass on the freshly reloaded graph.
# NOTE(review): the Data Loading cell obtains its edge index through
# get_edge_index_from_adjencency(data, device) instead of data.edge_index;
# confirm that `edge_index` still exists on `data` after prepare_adjencency.
gamma.forward(data.edge_index, data.adj_t)

In [None]:
# Display the edge index stored on the reloaded graph object.
data.edge_index

In [None]:
# Display the number of nodes in the reloaded graph.
data.num_nodes