In [1]:
import copy
import os
import re
from glob import glob

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

## Sequential Dataset

In [2]:
class sequentialDataset(torch.utils.data.Dataset):
    """Dataset over serialized snapshots stored one per file in a directory.

    Files are ordered by a natural sort of their base names (``g2.pt`` before
    ``g10.pt``), so numbered snapshots keep their numeric order. Downstream
    training splits train/validation by index, which makes this ordering
    significant. For zero-padded names the order is identical to plain
    lexicographic sorting.

    Args:
        fp: directory containing the files.
        file_type: filename suffix to match (default ``".pt"``).
    """

    def __init__(self, fp, file_type=".pt") -> None:
        super().__init__()
        self.data_files = sorted(
            glob(os.path.join(fp, f"*{file_type}")), key=self._natural_key
        )

    @staticmethod
    def _natural_key(path):
        """Sort key comparing digit runs in the base name numerically."""
        name = os.path.basename(path)
        # re.split with a capturing group alternates text (even index) and
        # digit runs (odd index), so element-wise comparisons stay same-typed.
        parts = re.split(r"(\d+)", name)
        # The raw name is a tie-breaker (e.g. "01.pt" vs "1.pt") for a stable order.
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], name

    def __getitem__(self, idx):
        """Load file ``idx`` and return the first element of the saved object.

        Raises IndexError when ``idx`` is out of range, which also terminates
        plain ``for`` iteration over the dataset.

        NOTE(review): ``torch.load`` unpickles the file, so only load trusted
        data. Newer PyTorch releases default ``weights_only=True``, which may
        reject non-tensor objects. TODO: confirm this works for the stored
        graph type.
        """
        return torch.load(self.data_files[idx])[0]

    # Legacy alias kept for backward compatibility. ``__getindex__`` is not a
    # Python protocol method; indexing always goes through ``__getitem__``.
    __getindex__ = __getitem__

    def __len__(self):
        """Number of matched files."""
        return len(self.data_files)

## Model

In [3]:
from torch_geometric.nn import GCNConv
class GNN(torch.nn.Module):
    """One GCN layer followed by ReLU and a linear head, applied per node.

    Args:
        input_size: width that node features are zero-padded to before the
            graph convolution (``GCNConv.in_channels``).
        feature_size: hidden width produced by the graph convolution.
        output_size: number of outputs per node from the final linear layer.
    """

    def __init__(self, input_size, feature_size, output_size):
        super().__init__()
        self.conv = GCNConv(in_channels=input_size, out_channels=feature_size)
        self.activation = nn.ReLU()
        self.fc = nn.Linear(in_features=feature_size, out_features=output_size)

    def forward(self, data):
        """Compute per-node outputs for one graph.

        Args:
            data: graph object exposing ``x`` (node features, 2-D),
                ``edge_index`` and ``edge_attr`` (one weight per edge).

        Returns:
            Tensor of shape ``(num_nodes, output_size)``.

        Raises:
            ValueError: if the graph has more feature columns than
                ``input_size``, so padding would be negative.
        """
        # L2-normalise each feature column across all nodes of the graph.
        node_attr = F.normalize(data.x.float(), dim=0)
        num_pad = self.conv.in_channels - node_attr.shape[1]
        if num_pad < 0:
            raise ValueError(
                f"graph has {node_attr.shape[1]} node features but the model "
                f"accepts at most {self.conv.in_channels}"
            )
        if num_pad > 0:
            # new_zeros keeps the padding on the features' device and dtype, so
            # the model also works after being moved to a GPU.
            padding = node_attr.new_zeros((node_attr.shape[0], num_pad))
            node_attr = torch.cat((node_attr, padding), dim=-1)
        edge_index = data.edge_index.long()
        # GCNConv documents edge_weight as a 1-D tensor of length num_edges;
        # normalisation is over all edges, as before.
        edge_weight = F.normalize(data.edge_attr.float().reshape(-1), dim=0)
        x = self.conv(x=node_attr, edge_index=edge_index, edge_weight=edge_weight)
        x = self.activation(x)
        return self.fc(x)

## Trainer

In [4]:
from sklearn import metrics
class Trainer:
    """Fit a model on an ordered dataset with a hold-out tail and report metrics.

    The first ``val_idx`` items of ``dataset`` are used for training and the
    remaining items for validation. There is no shuffling, so the split follows
    dataset order. After training, the weights from the epoch with the lowest
    validation loss are restored. ``_benchmark`` then prints binary
    classification metrics obtained by thresholding predictions and targets.

    Args:
        model: ``nn.Module`` called as ``model(data)``.
        dataset: indexable, sized collection of graph objects exposing ``x``,
            ``edge_index``, ``edge_attr`` and ``y``.
        args: namespace providing ``device``, ``num_epochs``, ``learning_rate``,
            ``val_size`` (fraction held out for validation) and ``threshold``.
    """

    def __init__(self, model, dataset, args):
        self.device = args.device
        self.model = model.to(self.device)
        self.dataset = dataset
        self.epochs = args.num_epochs
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.learning_rate)
        self.criterion = nn.MSELoss().to(self.device)
        # Items [0, val_idx) are trained on; items [val_idx, len) validate.
        self.val_idx = int(len(self.dataset) * (1 - args.val_size))
        self.best_model_weights = self._snapshot_weights()
        self.best_epoch = 0
        self.best_val_loss = float('inf')
        self.threshold = args.threshold

    def _snapshot_weights(self):
        """Return an independent copy of the model's current state dict.

        ``state_dict()`` returns tensors that share storage with the live
        parameters. Without copying, later optimizer steps would silently
        overwrite the saved "best" weights.
        """
        return copy.deepcopy(self.model.state_dict())

    def train(self):
        """Run the training loop, restore the best weights and print metrics.

        Returns:
            ``self.model`` with the weights from the epoch with the lowest
            validation loss loaded.
        """
        num_val = len(self.dataset) - self.val_idx
        for epoch in range(1, self.epochs + 1):
            train_loss = 0.0
            val_loss = 0.0
            for i in range(len(self.dataset)):
                data = self.dataset[i]
                if i < self.val_idx:
                    # Running mean of the per-item loss over the training split.
                    train_loss += self._train_step(self.model, data) / self.val_idx
                else:
                    val_loss += self._val_step(self.model, data) / num_val
            if epoch % 20 == 0 or epoch == self.epochs:
                print(f"""
                    epoch {epoch}:
                        train loss: {train_loss},
                        val loss: {val_loss}
                """)
            if self.best_val_loss > val_loss:
                self.best_val_loss = val_loss
                self.best_epoch = epoch
                self.best_model_weights = self._snapshot_weights()
        self.model.load_state_dict(self.best_model_weights)
        print(
            f"""
            best model loss is:
                val loss: {self.best_val_loss} @ epoch: {self.best_epoch}
            """
        )
        self._benchmark()
        return self.model

    def _train_step(self, model, data):
        """Run one optimisation step on ``data`` and return its scalar loss."""
        model.train()
        self.optimizer.zero_grad()
        logits, target = self._shared_step(model, data)
        loss = self.criterion(logits, target)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _val_step(self, model, data):
        """Return the scalar loss on ``data`` without building an autograd graph."""
        model.eval()
        with torch.no_grad():
            logits, target = self._shared_step(model, data)
            loss = self.criterion(logits, target)
        return loss.item()

    def _shared_step(self, model, data):
        """Move ``data`` to the device in place and return (squeezed output, target)."""
        data.x = data.x.to(self.device)
        data.edge_index = data.edge_index.to(self.device)
        data.edge_attr = data.edge_attr.to(self.device)
        target = data.y.float().squeeze().to(self.device)
        logits = model(data).squeeze()
        return logits, target

    def _binarize(self, values):
        """Map continuous values to 0/1 using ``value > self.threshold``."""
        return (values > self.threshold).astype(int)

    def _benchmark(self):
        """Print thresholded classification metrics for both splits."""
        train_preds, train_trues, val_preds, val_trues = [], [], [], []
        self.model.eval()
        with torch.no_grad():
            for i in range(len(self.dataset)):
                logits, target = self._shared_step(self.model, self.dataset[i])
                pred = logits.detach().cpu().numpy()
                target = target.cpu().numpy()
                if i < self.val_idx:
                    train_preds.append(pred)
                    train_trues.append(target)
                else:
                    val_preds.append(pred)
                    val_trues.append(target)

        train_preds = self._binarize(np.hstack(train_preds))
        train_trues = self._binarize(np.hstack(train_trues))
        val_preds = self._binarize(np.hstack(val_preds))
        val_trues = self._binarize(np.hstack(val_trues))

        print(
            f"""
                best model performance is:
                    train acc: {metrics.accuracy_score(train_trues, train_preds)}
                    val acc: {metrics.accuracy_score(val_trues, val_preds)}

                    train f1 score {metrics.f1_score(train_trues, train_preds)}
                    val f1 score {metrics.f1_score(val_trues, val_preds)}

                    train precision score {metrics.precision_score(train_trues, train_preds)}
                    val precision score {metrics.precision_score(val_trues, val_preds)}

                    train recall score {metrics.recall_score(train_trues, train_preds)}
                    val recall score {metrics.recall_score(val_trues, val_preds)}

                    num of pos prediction in training set {train_preds[train_preds == 1].shape[0]}
                    num of neg prediction in training set {train_preds[train_preds == 0].shape[0]}
                    num of pos prediction in val set {val_preds[val_preds == 1].shape[0]}
                    num of neg prediction in val set {val_preds[val_preds == 0].shape[0]}
            """
        )
        print(
            metrics.classification_report(val_trues, val_preds)
            )


In [5]:
import argparse

# Run configuration: tune these values rather than the code below.
GRAPH_PATH = "../data/processed/twitter"
SEED = 0
INPUT_SIZE = 70    # node features are zero-padded up to this width in GNN.forward
HIDDEN_SIZE = 32
OUTPUT_SIZE = 1

graph_path = GRAPH_PATH
dataset = sequentialDataset(graph_path)

# Seed before building the model so weight initialisation is reproducible
# across Restart & Run All.
torch.manual_seed(SEED)
model = GNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)

args = argparse.Namespace(
    num_epochs=500,
    learning_rate=2e-5,
    device="cpu",
    val_size=.2,       # fraction of the (ordered) dataset held out at the end
    threshold=0.035,   # cut-off that turns regression outputs into 0/1 labels
)

In [6]:
# Build the trainer from the model, dataset and run configuration defined above.
trainer = Trainer(model=model, dataset=dataset, args=args)

In [7]:
# Trainer.train() returns trainer.model (the same object passed in) after
# loading the weights recorded at the best validation epoch.
model = trainer.train()


                    epoch 20:
                        train loss: 0.01609255376388319,
                        val loss: 0.014974718680605292
                

                    epoch 40:
                        train loss: 0.010830949089722708,
                        val loss: 0.01091243140399456
                

                    epoch 60:
                        train loss: 0.009835411590756848,
                        val loss: 0.0103581000585109
                

                    epoch 80:
                        train loss: 0.009516938822343946,
                        val loss: 0.010200228542089462
                

                    epoch 100:
                        train loss: 0.00923016874003224,
                        val loss: 0.010020856163464487
                

                    epoch 120:
                        train loss: 0.008941071981098503,
                        val loss: 0.00982963724527508
                

                    epoch 140:
      