In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell runs, so edits to the project's .py files are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Step up one directory at a time until the working-directory path no longer
# contains the substring 'notebooks' (this matches anywhere in the path, not
# just the last component). Presumably this lets the `src.*` imports below
# resolve from the repository root — verify against the project layout.
while 'notebooks' in os.getcwd():
    os.chdir(os.pardir)

import torch
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from sklearn.metrics import roc_auc_score
import logging

from src.models import GraphSAGE, LinkPredictor
from src.data.gamma.arxiv import load_data, get_val_test_edges, prepare_adjencency, get_edge_index_from_adjencency
from src.train.gamma import GammaGraphSage

In [3]:
# Configure the root logger: INFO and above, each record rendered as
# "<timestamp> - <LEVEL> : <message>" with a second-resolution timestamp.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s : %(message)s',
)

In [4]:
# Sanity check: display whether PyTorch reports CUDA as available in this kernel.
torch.cuda.is_available()

True

In [5]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Bare expression so the notebook displays the selected device.
device

device(type='cuda', index=0)

## Data Loading

In [6]:
# Load the dataset object through the project helper.
data = load_data()

# The helper returns the (possibly modified) data object plus four edge sets:
# positive and negative edges for validation and for test.
# NOTE(review): remove_from_data=True presumably strips the held-out edges from
# `data` so they are not seen during training — confirm in get_val_test_edges.
(
    data,
    edges_val,
    edges_test,
    neg_edges_val,
    neg_edges_test,
) = get_val_test_edges(data, remove_from_data=True, device=device)

# Build the adjacency representation (symmetric variant requested), then derive
# the edge index from it on the selected device.
data = prepare_adjencency(data, to_symmetric=True)
edge_index = get_edge_index_from_adjencency(data, device)

## Training

In [None]:
# Number of independent training runs; each run builds a fresh model.
NUM_RUNS = 30

for run in range(NUM_RUNS):
    # A new model is constructed for every run; `run` is forwarded to it
    # (the "Run: 0000" prefix in the log output suggests it labels the run —
    # TODO confirm in GammaGraphSage).
    gamma = GammaGraphSage(device, data.num_nodes, run=run)

    # Rebinding `gamma` above presumably drops the previous run's model, so
    # release PyTorch's cached CUDA memory before training starts.
    torch.cuda.empty_cache()

    gamma.train(
        edge_index,
        edges_val,
        edges_test,
        neg_edges_val,
        neg_edges_test,
        data.adj_t,
        data.y,
    )

2022-04-24 20:02:27 - INFO : Run: 0000, Epoch: 0001, Train Loss: 1.3983, Valid loss: 1.1462, Test loss: 1.1484, Train AUC: 0.5829, Valid AUC: 0.5827, Test AUC: 0.5777
2022-04-24 20:02:34 - INFO : Run: 0000, Epoch: 0002, Train Loss: 1.0874, Valid loss: 1.0919, Test loss: 1.0924, Train AUC: 0.6073, Valid AUC: 0.6037, Test AUC: 0.5977
2022-04-24 20:02:40 - INFO : Run: 0000, Epoch: 0003, Train Loss: 1.0461, Valid loss: 1.0605, Test loss: 1.0612, Train AUC: 0.6244, Valid AUC: 0.6181, Test AUC: 0.6125
2022-04-24 20:02:47 - INFO : Run: 0000, Epoch: 0004, Train Loss: 1.0143, Valid loss: 1.0323, Test loss: 1.0333, Train AUC: 0.6375, Valid AUC: 0.6283, Test AUC: 0.6240
2022-04-24 20:02:53 - INFO : Run: 0000, Epoch: 0005, Train Loss: 0.9843, Valid loss: 1.0050, Test loss: 1.0059, Train AUC: 0.6438, Valid AUC: 0.6347, Test AUC: 0.6291
2022-04-24 20:02:59 - INFO : Run: 0000, Epoch: 0006, Train Loss: 0.9552, Valid loss: 0.9782, Test loss: 0.9796, Train AUC: 0.6515, Valid AUC: 0.6406, Test AUC: 0.635

In [10]:
# Manual epoch loop for the current `gamma` model: train every epoch and
# evaluate/print metrics on every `eval_steps`-th epoch.
# NOTE(review): `epochs`, `batch_size` and `eval_steps` are not assigned in any
# cell visible here — define them in a config cell so the notebook survives
# Restart Kernel -> Run All.
for epoch in range(1, 1 + epochs):
    loss_train = gamma.train_epoch(edge_index, batch_size, data.adj_t)

    # Skip evaluation except on every `eval_steps`-th epoch.
    if epoch % eval_steps != 0:
        continue

    loss_val, loss_test, auc_train, auc_val, auc_test = gamma.eval(
        edge_index,
        edges_val,
        edges_test,
        neg_edges_val,
        neg_edges_test,
        data.adj_t,
        data.y,
    )

    print(f'Epoch: {epoch:04d}, '
          f'Train Loss: {loss_train:.4f}, '
          f'Valid loss: {loss_val:.4f}, '
          f'Test loss: {loss_test:.4f}, '
          f'Train AUC: {auc_train:.4f}, '
          f'Valid AUC: {auc_val:.4f}, '
          f'Test AUC: {auc_test:.4f}')

Epoch: 0005, Train Loss: 0.9839, Valid loss: 1.0055, Test loss: 1.0055, Train AUC: 0.6452, Valid AUC: 0.6355, Test AUC: 0.6301
Epoch: 0010, Train Loss: 0.8508, Valid loss: 0.8845, Test loss: 0.8822, Train AUC: 0.6698, Valid AUC: 0.6610, Test AUC: 0.6532
Epoch: 0015, Train Loss: 0.7463, Valid loss: 0.7905, Test loss: 0.7880, Train AUC: 0.6799, Valid AUC: 0.6703, Test AUC: 0.6639
Epoch: 0020, Train Loss: 0.6621, Valid loss: 0.7172, Test loss: 0.7143, Train AUC: 0.6793, Valid AUC: 0.6713, Test AUC: 0.6645
Epoch: 0025, Train Loss: 0.5935, Valid loss: 0.6580, Test loss: 0.6554, Train AUC: 0.6762, Valid AUC: 0.6685, Test AUC: 0.6616
Epoch: 0030, Train Loss: 0.5353, Valid loss: 0.6104, Test loss: 0.6080, Train AUC: 0.6701, Valid AUC: 0.6631, Test AUC: 0.6566
Epoch: 0035, Train Loss: 0.4864, Valid loss: 0.5699, Test loss: 0.5677, Train AUC: 0.6664, Valid AUC: 0.6598, Test AUC: 0.6542
Epoch: 0040, Train Loss: 0.4462, Valid loss: 0.5385, Test loss: 0.5357, Train AUC: 0.6641, Valid AUC: 0.6575, T