# Data preprocessing
---

## imports

In [None]:
import torch
from torch_geometric.datasets import Reddit, Amazon
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
import random
from torch.functional import F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.loader import NeighborLoader
from models import *
from torch_geometric.data import Data

# DATA PREPARATION

In [2]:
# Load the Amazon "Computers" dataset (stored under data/Amazon) and take its first element.
amazon_computers_dataset = Amazon(name='Computers', root='data/Amazon')
data = amazon_computers_dataset[0]
# Store the number of distinct values found in `data.y` on the graph object itself.
data.num_classes = data.y.unique().size(0)

In [None]:
# Print a short summary of the loaded graph: its repr, sizes, and directedness.
print(f"data {data!s}")
print(f"num nodes {data.num_nodes!s}")
print(f"Num edges {data.num_edges!s}")
print(f"num features {data.num_features!s}")
print(f"is undirected {data.is_undirected()!s}")
print(f"is directed {data.is_directed()!s}")

In [4]:
# plots for data visualization and exploration
# G = to_networkx(data, to_undirected=True)
# pos = nx.spring_layout(G)
# plt.figure(figsize=(8, 8))
# nx.draw(G, pos, node_size=10, width=0.5)
# plt.show()

In [5]:
def create_masks(data, train_ratio, val_ratio):
    """
    Attach random, mutually exclusive train/val/test node masks to ``data``.

    The node indices ``0 .. data.num_nodes - 1`` are shuffled with NumPy's
    global random state (seed it with ``np.random.seed`` for a reproducible
    split). The first ``int(train_ratio * num_nodes)`` shuffled nodes become
    the training set, the next ``int(val_ratio * num_nodes)`` the validation
    set, and every remaining node the test set.

    :param data: Object exposing ``num_nodes``; it is modified in place and
        receives the boolean tensors ``train_mask``, ``val_mask`` and
        ``test_mask`` (each of length ``num_nodes``).
    :param train_ratio: Fraction of nodes used for training, in ``[0, 1]``.
    :param val_ratio: Fraction of nodes used for validation, in ``[0, 1]``.

    :raises ValueError: If a ratio lies outside ``[0, 1]`` or the two ratios
        sum to more than 1.
    :return: None
    """
    if not 0.0 <= train_ratio <= 1.0 or not 0.0 <= val_ratio <= 1.0:
        raise ValueError(
            f"train_ratio and val_ratio must lie in [0, 1], "
            f"got {train_ratio} and {val_ratio}"
        )
    # Small tolerance so that pairs such as 0.7 + 0.3 are not rejected
    # because of floating-point rounding.
    if train_ratio + val_ratio > 1.0 + 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, "
            f"got {train_ratio + val_ratio}"
        )

    num_nodes = data.num_nodes
    indices = list(range(num_nodes))
    np.random.shuffle(indices)
    # An explicit long tensor keeps indexing well-defined even for empty slices.
    shuffled = torch.tensor(indices, dtype=torch.long)

    # Boundaries of the three consecutive slices of the shuffled order.
    train_end = int(train_ratio * num_nodes)
    val_end = train_end + int(val_ratio * num_nodes)

    def _mask_for(selected):
        """Return a boolean mask of length num_nodes that is True at ``selected``."""
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[selected] = True
        return mask

    # Assign the custom masks to the dataset
    data.train_mask = _mask_for(shuffled[:train_end])
    data.val_mask = _mask_for(shuffled[train_end:val_end])
    data.test_mask = _mask_for(shuffled[val_end:])


def train_test_split_graph(data: Data, train_ratio: float, val_ratio: float, batch_size: int,
                           num_neighbors=None, val_batch_size: int = 4096 * 2):
    """
    Split the graph data into train, validation, and test sets
    :param data: The graph data; ``create_masks`` adds ``train_mask``,
        ``val_mask`` and ``test_mask`` to it in place
    :param train_ratio: The ratio of the training set
    :param val_ratio: The ratio of the validation set
    :param batch_size: The number of seed nodes per mini-batch for the
        training and test loaders
    :param num_neighbors: Neighbors sampled per layer, passed to
        ``NeighborLoader``; defaults to ``[30, 30]``
    :param val_batch_size: The number of seed nodes per mini-batch for the
        validation loader; defaults to ``8192``

    :return: The train, validation, and test data loaders
    """
    if num_neighbors is None:
        num_neighbors = [30, 30]

    # Call the function to create masks
    create_masks(data, train_ratio, val_ratio)

    def create_neighbor_loader(mask, loader_batch_size, shuffle=False):
        # Seed nodes of the loader are the node indices selected by the mask.
        seed_nodes = mask.nonzero(as_tuple=False).view(-1)
        return NeighborLoader(
            data,
            num_neighbors=list(num_neighbors),
            batch_size=loader_batch_size,
            input_nodes=seed_nodes,
            shuffle=shuffle,
        )

    # Only the training loader is shuffled, so each epoch visits the seed
    # nodes in a different order; evaluation loaders keep a fixed order.
    train_loader = create_neighbor_loader(data.train_mask, batch_size, shuffle=True)
    val_loader = create_neighbor_loader(data.val_mask, val_batch_size)
    test_loader = create_neighbor_loader(data.test_mask, batch_size)

    return train_loader, val_loader, test_loader

# TRAINING (NeighborSampler-style loops; `train_epoch` is redefined in the second TRAINING section below)

In [6]:
def train_epoch(model, optimizer, data_loader, device, data):
    """Run one optimisation pass over NeighborSampler-style mini-batches.

    Every item yielded by ``data_loader`` is unpacked as
    ``(batch_size, n_id, adjs)``. Features for all nodes in ``n_id`` are fed
    to the model together with the ``edge_index`` of the first entry of
    ``adjs``; the NLL loss is taken over the first ``batch_size`` rows only.

    Returns the mean of the per-batch losses.

    NOTE(review): a later cell defines another ``train_epoch`` with a
    different signature, which rebinds this name once that cell runs.
    """
    model.train()
    running_loss = 0

    for num_targets, node_ids, adjacency in data_loader:
        # Sampled adjacency structures go to the compute device first.
        adjacency = [block.to(device) for block in adjacency]

        # Features of every sampled node (targets and their neighbours).
        features = data.x[node_ids].to(device)
        scores = model(features, adjacency[0].edge_index)

        # Only the leading `num_targets` rows belong to the target nodes.
        target_ids = node_ids[:num_targets]
        targets = data.y[target_ids].to(device)
        batch_loss = F.nll_loss(scores[:num_targets], targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(data_loader)


@torch.no_grad()
def evaluate(model, data_loader, device, data):
    """Return the accuracy over NeighborSampler-style mini-batches.

    Every item yielded by ``data_loader`` is unpacked as
    ``(batch_size, n_id, adjs)``. Predictions are the arg-max over dim 1 of
    the first ``batch_size`` output rows, and they are compared with the
    labels of the matching entries of ``n_id``. The number of correct
    predictions is divided by ``len(data_loader.dataset)``.
    """
    model.eval()
    hits = 0

    for num_targets, node_ids, adjacency in data_loader:
        adjacency = [block.to(device) for block in adjacency]

        # Forward pass over every sampled node in the mini-batch.
        features = data.x[node_ids].to(device)
        scores = model(features, adjacency[0].edge_index)

        # Keep only the rows that belong to the target nodes.
        predicted = torch.max(scores[:num_targets], dim=1).indices
        expected = data.y[node_ids[:num_targets]].to(device)
        hits += (predicted == expected).sum().item()

    return hits / len(data_loader.dataset)


# METRICS

In [7]:
from sklearn.metrics import f1_score, balanced_accuracy_score


def accuracy(predictions, labels):
    """Return the fraction of rows whose arg-max over dim 1 equals the label."""
    predicted_classes = predictions.argmax(dim=1)
    num_correct = predicted_classes.eq(labels).sum().item()
    return num_correct / labels.size(0)


def f1(predictions, labels):
    """Return the weighted-average F1 score of the arg-max predictions."""
    y_true = labels.cpu().numpy()
    y_pred = predictions.argmax(dim=1).cpu().numpy()
    return f1_score(y_true, y_pred, average='weighted')
  
def balanced_accuracy(predictions, labels):
    """Return scikit-learn's balanced accuracy of the arg-max predictions."""
    y_true = labels.cpu().numpy()
    y_pred = predictions.argmax(dim=1).cpu().numpy()
    return balanced_accuracy_score(y_true, y_pred)

# TRAINING

In [8]:
def train_epoch(model, optimizer, loss_fn, train_loader, device, metrics):
    """
    Train the model for one pass over ``train_loader``.

    Mini-batches that expose a ``batch_size`` attribute (as ``NeighborLoader``
    mini-batches do) are restricted to their first ``batch_size`` rows before
    the loss and metrics are computed. Those rows are the seed nodes the
    loader was asked to iterate over; the remaining rows are sampled
    neighbours, which may belong to other splits and must not contribute
    labels to training. Batches without ``batch_size`` are used in full.

    :param model: Module called as ``model(batch.x, batch.edge_index)`` and
        returning a pair whose first element holds the class scores
    :param optimizer: Optimizer stepping the model parameters
    :param loss_fn: Callable ``loss_fn(scores, labels)`` returning a scalar tensor
    :param train_loader: Iterable of mini-batches providing ``to(device)``,
        ``x``, ``edge_index`` and ``y``
    :param device: Device every mini-batch is moved to
    :param metrics: Mapping ``name -> fn(predictions, labels)``

    :return: ``(avg_loss, avg_metrics)``: the mean of the per-batch losses and
        a dict with each metric evaluated once over all collected predictions
    """
    model.train()
    all_preds = []
    all_labels = []
    losses = []

    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        # The model returns a pair; only the class scores are needed here.
        out, _ = model(batch.x, batch.edge_index)
        targets = batch.y

        # Keep the seed nodes only (see docstring).
        num_seed_nodes = getattr(batch, "batch_size", None)
        if num_seed_nodes is not None:
            out = out[:num_seed_nodes]
            targets = targets[:num_seed_nodes]

        loss = loss_fn(out, targets)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        all_preds.append(out.detach().cpu())
        all_labels.append(targets.detach().cpu())

    # Concatenate all predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Compute metrics on the entire dataset
    avg_metrics = {metric_name: metric_fn(all_preds, all_labels)
                   for metric_name, metric_fn in metrics.items()}
    avg_loss = np.mean(losses)

    return avg_loss, avg_metrics

@torch.no_grad()
def validate(model, loss_fn, val_loader, device, metrics):
    """
    Evaluate the model over ``val_loader`` without updating its parameters.

    Mini-batches that expose a ``batch_size`` attribute (as ``NeighborLoader``
    mini-batches do) are restricted to their first ``batch_size`` rows before
    the loss and metrics are computed. Those rows are the seed nodes of the
    split being evaluated; the remaining rows are sampled neighbours that may
    belong to other splits. Batches without ``batch_size`` are used in full.

    :param model: Module called as ``model(batch.x, batch.edge_index)`` and
        returning a pair whose first element holds the class scores
    :param loss_fn: Callable ``loss_fn(scores, labels)`` returning a scalar tensor
    :param val_loader: Iterable of mini-batches providing ``to(device)``,
        ``x``, ``edge_index`` and ``y``
    :param device: Device every mini-batch is moved to
    :param metrics: Mapping ``name -> fn(predictions, labels)``

    :return: ``(avg_loss, avg_metrics)``: the mean of the per-batch losses and
        a dict with each metric evaluated once over all collected predictions
    """
    model.eval()
    all_preds = []
    all_labels = []
    losses = []

    for batch in val_loader:
        batch = batch.to(device)

        # The model returns a pair; only the class scores are needed here.
        out, _ = model(batch.x, batch.edge_index)
        targets = batch.y

        # Keep the seed nodes only (see docstring).
        num_seed_nodes = getattr(batch, "batch_size", None)
        if num_seed_nodes is not None:
            out = out[:num_seed_nodes]
            targets = targets[:num_seed_nodes]

        loss = loss_fn(out, targets)

        losses.append(loss.item())
        all_preds.append(out.detach().cpu())
        all_labels.append(targets.detach().cpu())

    # Concatenate all predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Compute metrics on the entire dataset
    avg_metrics = {metric_name: metric_fn(all_preds, all_labels)
                   for metric_name, metric_fn in metrics.items()}
    avg_loss = np.mean(losses)

    return avg_loss, avg_metrics

def training_loop(model, optimizer, loss_fn, train_loader, val_loader, num_epochs, device, metrics):
    """
    Alternate one training pass and one validation pass for ``num_epochs`` epochs.

    After every epoch the losses and each metric are appended to the
    histories and a one-line summary is printed.

    :return: ``(model, train_losses, val_losses, train_metrics_history,
        val_metrics_history)``; the two histories map each metric name to a
        list with one entry per epoch
    """
    print("Starting training")
    train_losses = []
    val_losses = []
    train_metrics_history = {name: [] for name in metrics}
    val_metrics_history = {name: [] for name in metrics}

    for epoch in range(1, num_epochs + 1):
        train_loss, train_metrics = train_epoch(
            model, optimizer, loss_fn, train_loader, device, metrics)
        val_loss, val_metrics = validate(
            model, loss_fn, val_loader, device, metrics)

        # Record this epoch's results.
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        summary_parts = []
        for name in metrics:
            train_value = train_metrics[name]
            val_value = val_metrics[name]
            train_metrics_history[name].append(train_value)
            val_metrics_history[name].append(val_value)
            summary_parts.append(
                f'{name}: {train_value:.3f} (train), {val_value:.3f} (val)')

        # Report losses followed by every metric on a single line.
        metrics_str = ', '.join(summary_parts)
        print(
            f"Epoch {epoch}/{num_epochs}: "
            f"Loss: {train_loss:.3f} (train), {val_loss:.3f} (val), "
            f"{metrics_str}"
        )

    return model, train_losses, val_losses, train_metrics_history, val_metrics_history

# EXPERIMENT

In [9]:
# Experiment configuration: compute device, split ratios, learning rate and batch size.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
data = data.to(device)
train_ratio, val_ratio = 0.8, 0.1
learning_rate = 0.01
batch_size = 64

# Build the three loaders from a random node split of `data`.
train_loader, val_loader, test_loader = train_test_split_graph(
    data, train_ratio, val_ratio, batch_size,
)

In [10]:
def gnn_experiment(model, train_loader, val_loader, num_epochs=10, lr=None):
    """
    Train ``model`` with Adam and cross-entropy loss and collect its histories.

    Reads the module-level ``device`` and, when ``lr`` is not given, the
    module-level ``learning_rate``.

    :param model: The network to train; it is moved to ``device`` first
    :param train_loader: Loader of training mini-batches
    :param val_loader: Loader of validation mini-batches
    :param num_epochs: Number of epochs to train for (default 10)
    :param lr: Adam learning rate; ``None`` uses the global ``learning_rate``

    :return: Dict with the keys ``model``, ``train_losses``, ``val_losses``,
        ``train_metrics_history`` and ``val_metrics_history``
    """
    print(model.__class__.__name__)

    # Batches are moved to `device` during training, so the parameters must
    # live there too; do this before the optimizer captures them.
    model = model.to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate if lr is None else lr)

    loss_fn = torch.nn.CrossEntropyLoss()

    metrics = {
        'accuracy': accuracy,
        'f1': f1,
        'balanced_accuracy': balanced_accuracy,
    }

    # Train the model
    model, train_losses, val_losses, train_metrics_history, val_metrics_history = training_loop(
        model, optimizer, loss_fn, train_loader, val_loader,
        num_epochs=num_epochs, device=device, metrics=metrics
    )

    return {
        'model': model,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_metrics_history': train_metrics_history,
        'val_metrics_history': val_metrics_history
    }

In [11]:
# Build every model on the same device the mini-batches are moved to; a model
# left on the CPU fails with a device mismatch once training runs on CUDA.
gcn = GCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes).to(device)
gat = GAT(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes, num_heads=8).to(device)
gin = GIN(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes).to(device)
gsage = GraphSAGE(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes).to(device)

In [None]:

# warnings.filterwarnings("ignore")
# Train the GCN, then persist the returned result dictionary with torch.save.
gcn_results = gnn_experiment(gcn, train_loader=train_loader, val_loader=val_loader)
# gnn_experiment(gsage, train_loader, val_loader)

torch.save(gcn_results, 'gcn_results_benchmark.pt')

In [None]:
# Train the GAT, then persist the returned result dictionary with torch.save.
gat_results = gnn_experiment(gat, train_loader=train_loader, val_loader=val_loader)
torch.save(gat_results, 'gat_results_benchmark.pt')