#### Dataset Creation

In [1]:
import random
import torch
import numpy as np

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [2]:
from data_loading.models_dataset import ModelDataset

config_params = dict(
    timeout = 120,
    min_enr = 1.2,
    min_edges = 10
)
model_dataset = ModelDataset('ecore_555', reload=False, **config_params)
# dataset = ModelDataset('modelset', reload=False, remove_duplicates=True, **config_params)
# dataset = ModelDataset('mar-ecore-github', reload=True, **config_params)

Loading ecore_555 from pickle
Loaded ecore_555 with 281 graphs
Loaded ecore_555 with 281 graphs
Graphs: 281


In [3]:
from data_loading.graph_dataset import GraphNodeDataset

graph_data_params = dict(
    distance=1,
    reload=False,
    use_embeddings=True,
    embed_model_name='bert-base-uncased',
    ckpt='results/ecore_555/edge_cls/checkpoint-7260'
)

graph_dataset = GraphNodeDataset(model_dataset, **graph_data_params)
# modelset_graph_dataset = GraphEdgeDataset(modelset, **graph_data_params)
# mar_graph_dataset = GraphEdgeDataset(mar, **graph_data_params)

Creating graphs:   0%|          | 0/281 [00:00<?, ?it/s]

Processing graphs:   0%|          | 0/281 [00:00<?, ?it/s]

Train Node classes: {0: 7089, 1: 2248}
Test Node classes: {0: 1921, 1: 560}


In [4]:
# from data_loading.graph_dataset import GraphEdgeDataset

# graph_data_params = dict(
#     distance=1,
#     reload=False,
#     add_negative_train_samples=True,
#     neg_sampling_ratio=1,
#     use_edge_types=False,
#     use_embeddings=True,
#     embed_model_name='bert-base-uncased',
#     ckpt='results/ecore_555/edge_cls/checkpoint-7260'
# )

# graph_dataset = GraphEdgeDataset(model_dataset, **graph_data_params)
# modelset_graph_dataset = GraphEdgeDataset(modelset, **graph_data_params)
# mar_graph_dataset = GraphEdgeDataset(mar, **graph_data_params)

In [5]:
num_classes = graph_dataset.num_node_classes
print(num_classes)

2


In [6]:
from utils import get_device
from models.gnn_layers import GNNConv


device = get_device()


gnn_conv_model = GNNConv(
    model_name='SAGEConv',
    input_dim=768,
    hidden_dim=128,
    out_dim=128,
    num_layers=3,
    num_heads=4,
    residual=True,
    l_norm=False,
    dropout=0.3,
    aggregation='sum'
)

gnn_conv_model.to(device)

GNNConv(
  (conv_layers): ModuleList(
    (0): SAGEConv(768, 128, aggr=SumAggregation())
    (1-2): 2 x SAGEConv(128, 128, aggr=SumAggregation())
  )
  (activation): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
)

In [7]:
from models.gnn_layers import NodeClassifer


mlp_predictor = NodeClassifer(
    input_dim=128,
    hidden_dim=128,
    num_layers=3, 
    num_classes=num_classes,
    bias=True,
)
mlp_predictor.to(device)

NodeClassifer(
  (layers): ModuleList(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=128, out_features=2, bias=True)
  )
)

In [42]:
from random import shuffle
from torch_geometric.loader import DataLoader
import torch
from collections import defaultdict
from typing import List
import pandas as pd
from sklearn.metrics import (
    balanced_accuracy_score,
    f1_score, 
    recall_score, 
    accuracy_score
)

from itertools import chain
from tqdm.auto import tqdm
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import Adam



class Trainer:
    """
    Trainer class for GNN Link Prediction
    This class is used to train the GNN model for the link prediction task
    The model is trained to predict the link between two nodes
    """
    def __init__(
            self, 
            model: GNNConv, 
            predictor: NodeClassifer, 
            dataset: List[GraphNodeDataset],
            lr=1e-3,
            num_epochs=100,
            batch_size=32
        ) -> None:
        self.model = model
        self.predictor = predictor
        self.model.to(device)
        self.predictor.to(device)
        

        dataset = [g.data for g in dataset]
        shuffle(dataset)

        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self.optimizer = Adam(chain(model.parameters(), predictor.parameters()), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs)
        
        self.edge2index = lambda g: torch.stack(list(g.edges())).contiguous()
        self.results = list()
        self.criterion = nn.CrossEntropyLoss()

        self.num_epochs = num_epochs

        print("GNN Trainer initialized.")



    def train(self):
        self.model.train()
        self.predictor.train()

        all_preds, all_labels = list(), list()
        epoch_loss = 0
        epoch_metrics = defaultdict(float)
        # for i, data in tqdm(enumerate(self.dataloader), desc=f"Training batches", total=len(self.dataloader)):
        for data in self.dataloader:
            self.optimizer.zero_grad()
            self.model.zero_grad()
            self.predictor.zero_grad()
            
            h = self.get_logits(data.x, data.edge_index)
            scores = self.get_prediction_score(h)[data.train_node_idx]
            labels = data.node_classes[data.train_node_idx]
            loss = self.compute_loss(scores, labels)
            
            all_preds.append(scores.detach())
            all_labels.append(labels)

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            epoch_loss += loss.item()
                        
        
        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        epoch_metrics = self.compute_metrics(all_preds, all_labels)
        epoch_metrics['loss'] = epoch_loss        
        epoch_metrics['phase'] = 'train'


    def test(self):
        self.model.eval()
        self.predictor.eval()
        all_preds, all_labels = list(), list()
        with torch.no_grad():
            epoch_loss = 0
            epoch_metrics = defaultdict(float)
            # for _, data in tqdm(enumerate(self.dataloader), desc=f"Evaluating batches", total=len(self.dataloader)):
            for data in self.dataloader:
                h = self.get_logits(data.x, data.edge_index)
                scores = self.get_prediction_score(h)[data.test_node_idx]
                labels = data.node_classes[data.test_node_idx]

                loss = self.compute_loss(scores, labels)
                epoch_loss += loss.item()


                all_preds.append(scores.detach())
                all_labels.append(labels)
                
                
                

            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            epoch_metrics = self.compute_metrics(all_preds, all_labels)
            
            epoch_metrics['loss'] = epoch_loss
            epoch_metrics['phase'] = 'test'
            # print(f"Epoch Test Loss: {epoch_loss}\nTest Accuracy: {epoch_acc}\nTest F1: {epoch_f1}")
            self.results.append(epoch_metrics)

            print(f"Epoch: {len(self.results)}\n{epoch_metrics}")
            

    def get_logits(self, x, edge_index):
        edge_index = edge_index.to(device)
        x = x.to(device)
        h = self.model(x, edge_index)
        return h
    

    def get_prediction_score(self, h):
        h = h.to(device)
        prediction_score = self.predictor(h)
        return prediction_score

    def compute_loss(self, scores, labels):
        loss = self.criterion(scores, labels.to(device))
        return loss
    

    def compute_metrics(self, scores, labels):
        preds = torch.argmax(scores, dim=-1)
        f1 = f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='weighted')
        accuracy = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())
        recall = recall_score(labels.cpu().numpy(), preds.cpu().numpy(), average='weighted')

        balanced_accuracy = balanced_accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())

        return {
            'f1-score': f1,
            'balanced_accuracy': balanced_accuracy,
            'recall': recall,
            'accuracy': accuracy,
        }
    
    def plot_metrics(self):
        results = pd.DataFrame(self.results)
        df = pd.DataFrame(results, index=range(1, len(results)+1))
        df['epoch'] = df.index

        columns = [c for c in df.columns if c not in ['epoch', 'phase']]
        df.loc[df['phase'] == 'test'].plot(x='epoch', y=columns, kind='line')


    def run(self):
        for _ in tqdm(range(self.num_epochs), desc="Running Epochs"):
            self.train()
            self.test()
        


In [43]:
trainer = Trainer(
    gnn_conv_model,
    mlp_predictor,
    graph_dataset,
    lr=1e-3,
    num_epochs=100,
    batch_size=32
)

GNN Trainer initialized.


In [44]:
trainer.run()

Running Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1
{'f1-score': 0.6702892804990056, 'balanced_accuracy': 0.5, 'recall': 0.7702539298669892, 'accuracy': 0.7702539298669892, 'loss': 6.0398798286914825, 'phase': 'test'}
Epoch: 2
{'f1-score': 0.6585581355685857, 'balanced_accuracy': 0.5105171856722999, 'recall': 0.7230955259975816, 'accuracy': 0.7230955259975816, 'loss': 5.343740016222, 'phase': 'test'}
Epoch: 3
{'f1-score': 0.6817785163041368, 'balanced_accuracy': 0.5125048190327585, 'recall': 0.7335751713018944, 'accuracy': 0.7335751713018944, 'loss': 5.470100373029709, 'phase': 'test'}
Epoch: 4
{'f1-score': 0.670141756306479, 'balanced_accuracy': 0.5057504492326401, 'recall': 0.7678355501813785, 'accuracy': 0.7678355501813785, 'loss': 4.972208917140961, 'phase': 'test'}
Epoch: 5
{'f1-score': 0.671578438532021, 'balanced_accuracy': 0.5057912415463722, 'recall': 0.7420395002015316, 'accuracy': 0.7420395002015316, 'loss': 5.223889917135239, 'phase': 'test'}
Epoch: 6
{'f1-score': 0.6681040958133532, 'balanced_accuracy': 0.500344906