# Notebook for the experiments
This notebook contains the following experiments:
* GRAFF + link prediction

The main tools used are [PyTorch](https://pytorch.org/) (1.13.0), [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) (1.5.10), [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) (2.3.0) and [Weights & Biases](https://wandb.ai/).

### Requirements to run the notebook

In [1]:
# !pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# !pip install pytorch-lightning==1.5.10
# !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
# !pip install torch_geometric
# !pip install wandb
# !pip install ogb

## Importing the libraries

In [2]:
######## IMPORT EXTERNAL FILES ###########
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
import torch.nn as nn

import torch_geometric
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger

import wandb
######### IMPORT INTERNAL FILES ###########
import sys
sys.path.append("../../src")
from GRAFF import *
from config import *

  from .autonotebook import tqdm as notebook_tqdm


Link prediction features initialized.....


### System configuration

In [3]:
# Use the GPU when PyTorch can see one; Lightning is given the matching GPU count.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_gpus = int(device == 'cuda')

# Authenticate with Weights & Biases only when logging is enabled in the config.
if wb:
    wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdifra00[0m ([33mdeepl_wizards[0m). Use [1m`wandb login --relogin`[0m to force relogin


## PyTorch Lightning DataModule (Link Prediction)

In [4]:
class DataModuleLP(pl.LightningDataModule):
    """LightningDataModule serving one full graph per split for link prediction.

    Each split is wrapped in a one-element list, so every dataloader yields a single
    batch holding the whole graph.

    Args:
        train_set, val_set, test_set: graph objects for each split. The code reads
            ``edge_index`` on the train split and ``edge_label_index`` on val/test
            (presumably produced by ``RandomLinkSplit`` — verify upstream).
        neg_edges: pre-sampled negative edges, sliced column-wise (``[:, a:b]``);
            presumably a ``(2, num_neg)`` tensor — TODO confirm. The first columns go to
            training, the following ones to the evaluation split.
        mode: ``"hp"`` (train on train, validate on val) or ``"test"``
            (train on train + val, validate on test).
        batch_size: forwarded to the dataloaders.

    Raises:
        ValueError: if ``mode`` is not ``"hp"`` or ``"test"``.
    """

    _MODES = ("hp", "test")

    def __init__(self, train_set, val_set, test_set, neg_edges, mode, batch_size):
        # LightningDataModule.__init__ sets up internal state the Trainer relies on.
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.mode = mode
        self.batch_size = batch_size
        self.train_set, self.val_set, self.test_set = train_set, val_set, test_set
        self.neg_edges = neg_edges
        # Makes setup('fit') idempotent: the notebook calls it manually and the Trainer
        # may call it again, which would otherwise append the validation edges twice.
        self._fit_prepared = False

    def setup(self, stage=None):
        if stage == 'fit':
            self._setup_fit()
        elif stage == 'test':
            self._setup_eval_negatives()

    def _setup_fit(self):
        """Split training edges into message-passing and supervision edges (runs once)."""
        if self._fit_prepared:
            return
        if self.mode == 'test':
            # For the test phase, after the hp tuning, train and val are unified.
            self.train_set.edge_index = torch.concat(
                (self.train_set.edge_index, self.val_set.edge_label_index), dim=-1)

        split = int(0.8 * self.train_set.edge_index.shape[1])
        # The first 80% of the edges are used in the forward (message-passing) pass.
        self.train_set.pos_forward_pass = self.train_set.edge_index[:, :split]
        # The remaining 20% are the positive targets of the prediction.
        self.train_set.pos_masked_edges = self.train_set.edge_index[:, split:]
        # The same number of negatives is taken from the start of the negative pool.
        self.train_set.neg_edges = self.neg_edges[:, :self.train_set.pos_masked_edges.shape[1]]
        self._fit_prepared = True

    def _setup_eval_negatives(self):
        """Attach negatives to the evaluation split, disjoint from the training ones."""
        if not self._fit_prepared:
            raise RuntimeError("setup(stage='fit') must be called before setup(stage='test')")
        eval_set = self._eval_set()
        start = self.train_set.pos_masked_edges.shape[1]
        stop = start + eval_set.edge_label_index.shape[1]
        # One negative per positive evaluation edge, right after the training negatives.
        eval_set.neg_edges = self.neg_edges[:, start:stop]

    def _eval_set(self):
        """Split used for validation: val in 'hp' mode, test in 'test' mode."""
        return self.val_set if self.mode == 'hp' else self.test_set

    def train_dataloader(self, *args, **kwargs):
        return DataLoader([self.train_set], batch_size=self.batch_size, shuffle=False)

    def val_dataloader(self, *args, **kwargs):
        return DataLoader([self._eval_set()], batch_size=self.batch_size, shuffle=False)


In [5]:
# Load the saved train/val/test splits and the negative-edge pool from the `dataset_name` directory.
# NOTE: torch.load unpickles the files — only load data you trust.
train_data, val_data, test_data, negative_edges = (
    torch.load(dataset_name + file_suffix)
    for file_suffix in ("/train_data.pt", "/val_data.pt", "/test_data.pt", "/negatives.pt")
)

In [6]:
# 'hp'   -> hyperparameter selection: train on train, validate on val.
# 'test' -> final run: train on train + val, validate on test.
mode = 'test'

# The DataModule writes attributes onto its splits in setup(), so hand it clones
# and keep the loaded splits untouched.
dataM = DataModuleLP(
    train_data.clone(),
    val_data.clone(),
    test_data.clone(),
    negative_edges,
    mode=mode,
    batch_size=batch_size,
)
dataM.setup(stage='fit')   # training edge split (must run first)
dataM.setup(stage='test')  # evaluation negatives, offset after the training ones

### PyTorch Lightning Callbacks

In [7]:

class Get_Metrics(Callback):
    """Average the per-step metrics collected by the module and log them once per epoch.

    Reads ``pl_module.train_prop`` / ``pl_module.test_prop`` (dicts of lists under the
    keys ``'loss'`` and ``'HR@100'``, filled by the module's training/validation steps),
    logs their means, then resets the lists for the next epoch.
    """

    # (attribute on the module, metric key, logged metric name), in logging order.
    _METRICS = (
        ('train_prop', 'loss', 'Loss on train'),
        ('train_prop', 'HR@100', 'HR@100 on train'),
        ('test_prop', 'loss', 'Loss on test'),
        ('test_prop', 'HR@100', 'HR@100 on test'),
    )

    def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        # The sanity-check validation steps append to test_prop with the untrained model;
        # drop them so they are not averaged into the first epoch's test metrics.
        pl_module.test_prop['loss'] = []
        pl_module.test_prop['HR@100'] = []

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        # Compute and log the epoch means.
        for prop_name, key, log_name in self._METRICS:
            values = getattr(pl_module, prop_name)[key]
            # Skip metrics with no recorded steps (e.g. validation did not run this
            # epoch) instead of dividing by zero.
            if values:
                pl_module.log(name=log_name, value=sum(values) / len(values),
                              on_epoch=True, prog_bar=True, logger=True)

        # Re-initialize the metrics.
        for prop_name in ('train_prop', 'test_prop'):
            prop = getattr(pl_module, prop_name)
            prop['loss'] = []
            prop['HR@100'] = []

## PyTorch Lightning Training Module (Link Prediction)

In [8]:
class TrainingModule(pl.LightningModule):
    """LightningModule training a link-prediction model with a binary cross-entropy loss.

    The wrapped ``model`` is called as ``model(batch, train=...)`` and must return
    ``(pos_pred, neg_pred, negatives)``. ``pos_pred`` / ``neg_pred`` are used as
    probabilities inside ``log`` (assumed to lie in [0, 1] — TODO confirm the model
    ends with a sigmoid). ``negatives`` are the scores ranked against ``pos_pred`` by
    ``evaluate``. Per-step loss and HR@100 are stored in ``train_prop`` / ``test_prop``
    for an epoch-end callback to average and reset.

    Args:
        model: the link-prediction network (moved to the module-level ``device``).
        lr: Adam learning rate.
        wd: Adam weight decay.
    """

    def __init__(self, model, lr, wd):
        super().__init__()
        self.model = model.to(device)
        self.lr = lr
        self.wd = wd

        self.train_prop = {'loss': [], 'HR@100': []}
        self.test_prop = {'loss': [], 'HR@100': []}

    @staticmethod
    def _bce_loss(pos_pred, neg_pred):
        """BCE pushing positives towards 1 and negatives towards 0 (1e-15 avoids log(0))."""
        return -torch.log(pos_pred + 1e-15).mean() - torch.log(1 - neg_pred + 1e-15).mean()

    def _shared_step(self, batch, train, store):
        """Run the model, compute loss and HR@100, record both in ``store``; return the loss."""
        pos_pred, neg_pred, negatives = self.model(batch, train=train)
        hit_rate = evaluate(pos_pred, negatives)
        loss = self._bce_loss(pos_pred, neg_pred)
        # Record a detached copy: keeping the live tensor would retain its whole
        # autograd graph until the callback clears the list.
        store['loss'].append(loss.detach())
        store['HR@100'].append(hit_rate)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, train=True, store=self.train_prop)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, train=False, store=self.test_prop)

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.wd)
        return self.optimizer


def evaluate(pos_pred, negatives, k=100):
    """Hit rate at ``k``: fraction of positive scores ranking in the top-``k`` of all scores.

    Positive and negative scores are pooled. The ``k`` highest are selected, and the
    result is the number of positives among them divided by the number of positives.

    NOTE(review): this differs from the OGB ``Hits@K`` definition (fraction of positives
    scored above the k-th best *negative*). Here the result is capped at
    ``k / len(pos_pred)``. Confirm this is the intended metric.

    Args:
        pos_pred: scores of the positive edges, shape ``(P,)`` or ``(P, 1)``.
        negatives: scores of the negative edges, with the same trailing shape as ``pos_pred``.
        k: cut-off rank (defaults to 100, matching the ``HR@100`` metric name).

    Returns:
        float in [0, 1]; 0.0 when there are no positive scores.
    """
    num_pos = pos_pred.shape[0]
    if num_pos == 0:
        return 0.0
    scores = torch.cat((pos_pred, negatives), dim=0).reshape(-1)
    # Clamp k so that a small candidate pool does not make topk raise.
    top_indices = torch.topk(scores, min(k, scores.numel())).indices
    # Positives occupy indices [0, num_pos) of the pooled tensor, and topk indices are unique.
    hits = int((top_indices < num_pos).sum())
    return hits / num_pos

In [9]:

# Outside hyperparameter search ('hp' enables a grid search on a wide set of
# hyperparameters), build one model from the config values and wrap it for Lightning.
if mode != 'hp':
    model = PhysicsGNN_LP(dataset, hidden_dim, output_dim, num_layers,
                          mlp_layer, link_bias, dropout)
    pl_training_module = TrainingModule(model, lr, wd)
 


### Hyperparameters Tuning

In [10]:
def sweep_train(config=None):
    """Train one model for a single W&B sweep run.

    Intended to be called by ``wandb.agent``: the sweep controller fills
    ``wandb.config`` with this run's hyperparameters (``hidden_dim``, ``output_dim``,
    ``num_layers``, ``mlp_layer``, ``link_bias``, ``dropout``, ``step``, ``lr``, ``wd``).

    Args:
        config: optional initial config forwarded to ``wandb.init``.
    """
    with wandb.init(config=config):
        config = wandb.config
        model = PhysicsGNN_LP(dataset, config.hidden_dim, config.output_dim,
                              config.num_layers, config.mlp_layer, config.link_bias, config.dropout,
                              step=config.step)
        pl_training_module = TrainingModule(model, config.lr, config.wd)
        exp_name = "Sweep_LinkPred"
        # NOTE(review): a wandb run is already active at this point, so the logger
        # presumably attaches to it. The global `hyperparameters` passed as config are
        # not this run's sweep values — verify what ends up in the dashboard.
        wandb_logger = WandbLogger(
            project=project_name, name=exp_name, config=hyperparameters)
        trainer = pl.Trainer(
            max_epochs=epochs,  # maximum number of epochs.
            gpus=num_gpus,  # the number of gpus we have at our disposal.
            default_root_dir="",
            callbacks=[Get_Metrics(), EarlyStopping('Loss on test', mode='min', patience=15)],
            logger=wandb_logger,
        )
        # Reuses the module-level `dataM`; its `mode` decides which split is validated on.
        trainer.fit(model=pl_training_module, datamodule=dataM)


if mode == 'hp':
    if not wb:
        # Without W&B there is no sweep: build a single model from the config values,
        # to be trained by the final cell.
        model = PhysicsGNN_LP(dataset, hidden_dim, output_dim, num_layers, mlp_layer, link_bias, dropout)
        pl_training_module = TrainingModule(model, lr, wd)
    else:
        import pprint

        pprint.pprint(sweep_config)

        sweep_id = wandb.sweep(sweep_config, project=project_name)
        # The agent runs `sweep_train` once per sampled configuration, at most 500 times.
        wandb.agent(sweep_id, sweep_train, count=500)

        wandb.finish()

In [11]:


# Create the W&B logger only when logging is enabled.
if wb:
    # NOTE(review): the run name says "Node_class" although this notebook does link prediction.
    exp_name = "Node_class_lr: " + \
        str(hyperparameters['learning rate']) + \
        '_wd: ' + str(hyperparameters['weight decay'])
    description = ' initial tests'
    exp_name += description
    wandb_logger = WandbLogger(
        project=project_name, name=exp_name, config=hyperparameters)


# Early stopping monitors the epoch-level 'Loss on test' metric logged by Get_Metrics.
trainer = pl.Trainer(
    max_epochs=epochs,  # maximum number of epochs.
    gpus=num_gpus,  # the number of gpus we have at our disposal.
    default_root_dir="",
    callbacks=[Get_Metrics(), EarlyStopping('Loss on test', mode='min', patience=15)],
    logger=wandb_logger if wb else None,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# After a W&B sweep (mode 'hp' with wb enabled) every run was already trained inside
# `sweep_train`, and no module-level `pl_training_module` exists, so skip the fit.
if not (mode == 'hp' and wb):
    trainer.fit(model=pl_training_module, datamodule=dataM)
if wb:
    wandb.finish()