# Link Prediction for Heterogeneous graph

## 0. Enviroment setup

In [1]:
import os
import shutil
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm

from sklearn.metrics import roc_auc_score, f1_score
from torch_geometric.utils import negative_sampling
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
from torch_geometric.loader import LinkNeighborLoader

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

In [2]:
from models import *
from dataloaders import *

In [3]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [4]:
# Setup device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

## 1. Data Pre-processing

In [5]:
data = getHiveDataset(
    'dataset/hive/v3/nodes.csv',
    'dataset/hive/v3/edges.csv'
)
data = T.ToUndirected()(data)
data

HeteroData(
  user={ x=[12585, 6] },
  link={ x=[23612, 1] },
  (user, upvote, link)={
    edge_index=[2, 95390],
    edge_attr=[95390, 2],
    y=[95390],
  },
  (user, post, link)={
    edge_index=[2, 19181],
    edge_attr=[19181, 1],
    y=[19181],
  },
  (user, downvote, link)={
    edge_index=[2, 2243],
    edge_attr=[2243, 2],
    y=[2243],
  },
  (link, rev_upvote, user)={
    edge_index=[2, 95390],
    edge_attr=[95390, 2],
    y=[95390],
  },
  (link, rev_post, user)={
    edge_index=[2, 19181],
    edge_attr=[19181, 1],
    y=[19181],
  },
  (link, rev_downvote, user)={
    edge_index=[2, 2243],
    edge_attr=[2243, 2],
    y=[2243],
  }
)

In [6]:
edge_types, rev_edge_types = [data.edge_types[:3], data.edge_types[3:]]

In [7]:
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3, 
    add_negative_train_samples=False,
    # neg_sampling_ratio=2.0,
    edge_types=edge_types,
    rev_edge_types=rev_edge_types, 
)

train_data, val_data, test_data = transform(data)

## 2. Training

In [8]:
model = GraphSAGE(
    in_channels=-1, 
    out_channels=128, 
    hidden_channels=[64, 128, 256, 256, 512], 
    metadata=data.metadata(), 
    aggr_scheme='mean',
)


In [9]:
lr = 1e-3
optim = torch.optim.Adam(model.parameters(), lr=lr)
model.set_optimizer(optim)

wandb_logger = WandbLogger(
    project="LinkPrediction",
    log_model='all',
    save_dir='results/log',
    name='graphsage',
    config={
        "learning_rate": lr,
        "architecture": GraphSAGE,
        "epoch": 100
    },
    entity='ssc_project'
)

In [12]:
# for edge_types, rev_edge_types in edges:
for edge_type in edge_types:
    loss_checkpoint_callback = ModelCheckpoint(
        monitor=f'{edge_type[1]}_val_loss',
        dirpath='results/checkpoints/graphsage/loss',
        filename=f'LinkPred-{{epoch:02d}}-{{{edge_type[1]}_val_loss:.2f}}',
        save_top_k=3,
        save_last=True,
        mode='min',
        every_n_epochs=1
    )
    roc_auc_checkpoint_callback = ModelCheckpoint(
        monitor=f'{edge_type[1]}_val_roc_auc',
        dirpath='results/checkpoints/graphsage/roc_auc',
        filename=f'LinkPred-{{epoch:02d}}-{{{edge_type[1]}_val_roc_auc:.2f}}',
        save_top_k=3,
        save_last=True,
        mode='max',
        every_n_epochs=1
    )
    trainer = L.Trainer(
        max_epochs=100,
        check_val_every_n_epoch=1,
        callbacks=[
            loss_checkpoint_callback, 
            roc_auc_checkpoint_callback,
        ],
        logger=wandb_logger
    )
    
    train_batch_size=128
    val_batch_size=val_data[edge_type]['edge_label'].size(0)
    model.set_trainval_info(edge_type, train_batch_size, val_batch_size)
    train_edge_label_index = train_data[edge_type].edge_label_index
    train_edge_label = train_data[edge_type].edge_label
    train_loader = LinkNeighborLoader(
        data=train_data,
        num_neighbors=[20, 20],
        neg_sampling_ratio=2.0,
        edge_label_index=(edge_type, train_edge_label_index),
        edge_label=train_edge_label,
        batch_size=train_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=8,
    )

    val_edge_label_index = val_data[edge_type].edge_label_index
    val_edge_label = val_data[edge_type].edge_label
    val_loader = LinkNeighborLoader(
        data=val_data,
        num_neighbors=[20, 20],
        neg_sampling_ratio=None,
        edge_label_index=(edge_type, val_edge_label_index),
        edge_label=val_edge_label,
        batch_size=val_batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=8,
    )
    
    trainer.fit(model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/hoangtran/miniconda3/envs/graph/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /media/hoangtran/T7/graph-anomaly-detection/link_prediction/results/checkpoints/graphsage/loss exists and is not empty.
/home/hoangtran/miniconda3/envs/graph/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /media/hoangtran/T7/graph-anomaly-detection/link_prediction/results/checkpoints/graphsage/roc_auc exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type              | Params
---------------------------------------------
0 | layers | GraphModule       | 3.6 M 
1 | crit   | BCEWithLogitsLoss | 0     
---------------------------------------------
3.6 M     Trainable params
0         Non-trainable params
3.6 M     To

Sanity Checking: |                                     | 0/? [00:00<?, ?it/s]

Training: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/hoangtran/miniconda3/envs/graph/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /media/hoangtran/T7/graph-anomaly-detection/link_prediction/results/checkpoints/graphsage/loss exists and is not empty.
/home/hoangtran/miniconda3/envs/graph/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /media/hoangtran/T7/graph-anomaly-detection/link_prediction/results/checkpoints/graphsage/roc_auc exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type              | Params
---------------------------------------------
0 | layers | GraphModule       | 3.6 M 
1 | crit   | BCEWithLogitsLoss | 0     
---------------------------------------------
3.6 M     Trainable p

Sanity Checking: |                                     | 0/? [00:00<?, ?it/s]

/home/hoangtran/miniconda3/envs/graph/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (36) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

Validation: |                                          | 0/? [00:00<?, ?it/s]

OSError: [Errno 28] No space left on device: '/media/hoangtran/T7/graph-anomaly-detection/link_prediction/results/checkpoints/graphsage/loss/last-v8.ckpt' -> '/home/hoangtran/.local/share/wandb/artifacts/staging/tmpa_1iuoah'

wandb: ERROR Error uploading "/home/hoangtran/.local/share/wandb/artifacts/staging/tmpwbbsqqqf": FileNotFoundError, [Errno 2] No such file or directory: '/home/hoangtran/.local/share/wandb/artifacts/staging/tmpwbbsqqqf'
wandb: ERROR Uploading artifact file failed. Artifact won't be committed.
wandb: ERROR Error uploading "/home/hoangtran/.local/share/wandb/artifacts/staging/tmpf7_uvc16": FileNotFoundError, [Errno 2] No such file or directory: '/home/hoangtran/.local/share/wandb/artifacts/staging/tmpf7_uvc16'
wandb: ERROR Uploading artifact file failed. Artifact won't be committed.
wandb: ERROR Error uploading "/home/hoangtran/.local/share/wandb/artifacts/staging/tmp02mkzmlq": FileNotFoundError, [Errno 2] No such file or directory: '/home/hoangtran/.local/share/wandb/artifacts/staging/tmp02mkzmlq'
wandb: ERROR Uploading artifact file failed. Artifact won't be committed.
wandb: ERROR Error uploading "/home/hoangtran/.local/share/wandb/artifacts/staging/tmpr5z46itg": FileNotFoundError, [E