# GNN
---

## Import

In [23]:
import torch
import matplotlib.pyplot as plt
#import networkx as nx
import numpy as np
import pandas as pd
from torch_geometric.loader import DataLoader
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Data
from models import *
from umap import UMAP
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

np.random.seed(0)

## Read, understand and prepare the data

In [24]:
# node_data = pd.read_parquet('data/amazon_product_data_word2vec.parquet')
data = torch.load('data/amazon_product_data_sum.pt')
data.num_classes = data.y.unique().shape[0]

  data = torch.load('data/amazon_product_data_sum.pt')


### Main info

In [25]:
print("data", data)
print("num nodes", data.num_nodes)
print("Num edges", data.num_edges)
print("num node features", data.num_node_features)
print("is undirected", data.is_undirected())
print("is directed", data.is_directed())
print("num edge features", data.num_edge_features)
print('num classes', data.num_classes)

data Data(x=[729819, 300], edge_index=[2, 680548], y=[729819], num_classes=10)
num nodes 729819
Num edges 680548
num node features 300
is undirected False
is directed True
num edge features 0
num classes 10


In [26]:
# value_counts = node_data['main_category'].value_counts()

# # plot a bar chart of the main categories
# plt.figure(figsize=(15, 6))
# plt.bar(value_counts.index, value_counts.values)
# plt.xticks(rotation=90)
# plt.title('Main Category Distribution')
# plt.show()

### Visualization

In [27]:
def visualize(h, color):
    z = UMAP(n_components=2).fit_transform(h.detach().cpu().numpy())

    plt.figure(figsize=(10, 10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

In [28]:
# def visualize(h, color):
#     z = UMAP(n_components=2).fit_transform(h.detach().cpu().numpy())

#     plt.figure(figsize=(10, 10))
#     plt.xticks([])
#     plt.yticks([])

#     plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
#     plt.show()

### Preparation for model training

In [29]:
def create_masks(data, train_ratio, val_ratio):
  num_nodes = data.num_nodes
  indices = list(range(num_nodes))
  np.random.shuffle(indices)
  
  # Create masks
  train_mask = torch.zeros(num_nodes, dtype=torch.bool)
  val_mask = torch.zeros(num_nodes, dtype=torch.bool)
  test_mask = torch.zeros(num_nodes, dtype=torch.bool)
  
  # Set proportions for train, val, and test
  train_end = int(train_ratio * num_nodes)
  val_end = train_end + int(val_ratio * num_nodes)
  train_mask[indices[:train_end]] = True
  val_mask[indices[train_end:val_end]] = True
  test_mask[indices[val_end:]] = True
  
  # Assign the custom masks to the dataset
  data.train_mask = train_mask
  data.val_mask = val_mask
  data.test_mask = test_mask


def train_test_split_graph(data: Data, train_ratio: float, val_ratio: float, batch_size: int):
    """
    Split the graph data into train, validation, and test sets
    :param data: The graph data
    :param train_ratio: The ratio of the training set
    :param val_ratio: The ratio of the validation set
    :param batch_size: The batch size
    
    :return: The train, validation, and test data loaders
    """

    # Call the function to create masks
    create_masks(data, train_ratio, val_ratio)

    # train, validation, and test node indices based on the masks
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    val_idx = data.val_mask.nonzero(as_tuple=False).view(-1)
    test_idx = data.test_mask.nonzero(as_tuple=False).view(-1)
    
    def create_data_loader(data, indices):
        return DataLoader(data[indices], batch_size=batch_size, shuffle=True)
      
    def create_neighbor_loader(data, indices, batch_size=batch_size):
        return NeighborLoader(data, num_neighbors=[30] * 2, batch_size=batch_size, input_nodes=indices)

    # create the data loaders
    train_loader = create_neighbor_loader(data, train_idx)
    val_loader = create_neighbor_loader(data, val_idx, batch_size=4096*2)
    test_loader = create_neighbor_loader(data, test_idx)

    return train_loader, val_loader, test_loader

## Model training

### Training functions

In [30]:
from sklearn.metrics import f1_score, balanced_accuracy_score


def accuracy(predictions, labels):
    correct = (predictions.argmax(dim=1) == labels).sum().item()
    total = labels.size(0)
    return correct / total


def f1(predictions, labels):
    preds = predictions.argmax(dim=1).cpu().numpy()
    labels = labels.cpu().numpy()
    return f1_score(labels, preds, average='weighted')
  
def balanced_accuracy(predictions, labels):
    preds = predictions.argmax(dim=1).cpu().numpy()
    labels = labels.cpu().numpy()
    return balanced_accuracy_score(labels, preds)

In [31]:
# print the size of the validation set
train_loader, val_loader, test_loader = train_test_split_graph(data, 0.8, 0.1, 4096)
print("Validation set size", len(val_loader.dataset))

Validation set size 72981




In [32]:
def train_epoch(model, optimizer, loss_fn, train_loader, device, metrics):
    model.train()
    all_preds = []
    all_labels = []
    losses = []

    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        out, h = model(batch.x, batch.edge_index)  # Receive only one output
        loss = loss_fn(out, batch.y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        all_preds.append(out.detach().cpu())
        all_labels.append(batch.y.detach().cpu())
        
    # Concatenate all predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Compute metrics on the entire dataset
    avg_metrics = {metric_name: metric_fn(all_preds, all_labels) 
                  for metric_name, metric_fn in metrics.items()}
    avg_loss = np.mean(losses)

    return avg_loss, avg_metrics

@torch.no_grad()
def validate(model, loss_fn, val_loader, device, metrics):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    losses = []

    for batch in val_loader:
        batch = batch.to(device)
        out, h = model(batch.x, batch.edge_index)  # Receive only one output
        loss = loss_fn(out, batch.y)

        losses.append(loss.item())
        all_preds.append(out.detach().cpu())
        all_labels.append(batch.y.detach().cpu())
        
    # with torch.no_grad():
    #   _, h = model(val_loader.data.x, val_loader.data.edge_index)
    #   visualize(h, val_loader.data.y)

    # Concatenate all predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Compute metrics on the entire dataset
    avg_metrics = {metric_name: metric_fn(all_preds, all_labels) 
                  for metric_name, metric_fn in metrics.items()}
    avg_loss = np.mean(losses)

    return avg_loss, avg_metrics

def training_loop(model, optimizer, loss_fn, train_loader, val_loader, num_epochs, device, metrics):
    print("Starting training")
    train_losses, val_losses = [], []
    train_metrics_history = {metric_name: [] for metric_name in metrics}
    val_metrics_history = {metric_name: [] for metric_name in metrics}

    for epoch in range(1, num_epochs + 1):
        # Training
        train_loss, train_metrics = train_epoch(
            model, optimizer, loss_fn, train_loader, device, metrics)
        # Validation
        val_loss, val_metrics = validate(
            model, loss_fn, val_loader, device, metrics)
        
        # Logging results
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        for metric_name in metrics:
            train_metrics_history[metric_name].append(train_metrics[metric_name])
            val_metrics_history[metric_name].append(val_metrics[metric_name])

        # Print metrics
        metrics_str = ', '.join(
            [f'{metric_name}: {train_metrics[metric_name]:.3f} (train), {val_metrics[metric_name]:.3f} (val)'
             for metric_name in metrics])
        print(
            f"Epoch {epoch}/{num_epochs}: "
            f"Loss: {train_loss:.3f} (train), {val_loss:.3f} (val), "
            f"{metrics_str}"
        )

    return model, train_losses, val_losses, train_metrics_history, val_metrics_history

### Actual training

## Experiment

In [33]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
train_ratio = 0.8
val_ratio = 0.1
learning_rate = 0.01  
batch_size = 64
  
train_loader, val_loader, test_loader = train_test_split_graph(
      data, train_ratio, val_ratio, batch_size)

In [34]:
def gnn_experiment(model, train_loader, val_loader):
  print(model.__class__.__name__)
  
  optimizer = torch.optim.Adam(
      model.parameters(), lr=learning_rate)
  
  loss_fn = torch.nn.CrossEntropyLoss()
  
  metrics = {
    'accuracy': accuracy,
    'f1': f1,
    'balanced_accuracy': balanced_accuracy,
  }
  
  # Train the model
  model, train_losses, val_losses, train_metrics_history, val_metrics_history = training_loop(
    model, optimizer, loss_fn, train_loader, val_loader, num_epochs=10, device=device, metrics=metrics
  )
  
  return {
    'model': model,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_metrics_history': train_metrics_history,
    'val_metrics_history': val_metrics_history
  }

In [35]:
gcn = GCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes).to(device)
gat = GAT(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes, num_heads=8)
gin = GIN(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes)
gsage = GraphSAGE(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes)

In [43]:

# warnings.filterwarnings("ignore")
gcn_results = gnn_experiment(gat, train_loader, val_loader)
# gat_results = gnn_experiment(gat, train_loader, val_loader)
#gnn_experiment(gin, train_loader, val_loader)
#gnn_experiment(gsage, train_loader, val_loader)

GAT
Starting training




Epoch 1/10: Loss: 0.511 (train), 0.236 (val), accuracy: 0.832 (train), 0.930 (val), f1: 0.832 (train), 0.930 (val), balanced_accuracy: 0.824 (train), 0.927 (val)
Epoch 2/10: Loss: 0.157 (train), 0.150 (val), accuracy: 0.956 (train), 0.957 (val), f1: 0.956 (train), 0.957 (val), balanced_accuracy: 0.954 (train), 0.956 (val)
Epoch 3/10: Loss: 0.109 (train), 0.125 (val), accuracy: 0.970 (train), 0.965 (val), f1: 0.970 (train), 0.965 (val), balanced_accuracy: 0.969 (train), 0.964 (val)
Epoch 4/10: Loss: 0.092 (train), 0.131 (val), accuracy: 0.975 (train), 0.963 (val), f1: 0.975 (train), 0.963 (val), balanced_accuracy: 0.974 (train), 0.961 (val)
Epoch 5/10: Loss: 0.082 (train), 0.108 (val), accuracy: 0.978 (train), 0.970 (val), f1: 0.978 (train), 0.970 (val), balanced_accuracy: 0.977 (train), 0.969 (val)
Epoch 6/10: Loss: 0.075 (train), 0.123 (val), accuracy: 0.979 (train), 0.965 (val), f1: 0.979 (train), 0.965 (val), balanced_accuracy: 0.979 (train), 0.965 (val)
Epoch 7/10: Loss: 0.071 (tra

In [None]:
# plot all the metrics in a single plot
def plot_metrics(results):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    for ax, (metric_name, metric_values) in zip(axs, results['train_metrics_history'].items()):
        ax.plot(metric_values, label='train')
        ax.plot(results['val_metrics_history'][metric_name], label='val')
        ax.set_title(metric_name)
        ax.legend()

    plt.show()

In [37]:
# def gnn_experiment(model, train_loader, val_loader):
  
#   optimizer = torch.optim.Adam(
#       model.parameters(), lr=learning_rate)
  
#   loss_fn = torch.nn.CrossEntropyLoss()
  
#   metrics = {
#     'accuracy': accuracy,
#     'f1': f1,
#     'balanced_accuracy': balanced_accuracy,
#   }
  
#   # Train the model
#   model, train_losses, val_losses, train_metrics_history, val_metrics_history = training_loop(
#     model, optimizer, loss_fn, train_loader, val_loader, num_epochs=1, device=device, metrics=metrics
#   )
  
#   return {
#     'model': model,
#     'train_losses': train_losses,
#     'val_losses': val_losses,
#     'train_metrics_history': train_metrics_history,
#     'val_metrics_history': val_metrics_history
#   }

In [38]:
# gcn = GCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes).to(device)
# #gat = GAT(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes, num_heads=8)
# #gin = GIN(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes)
# #gsage = GraphSAGE(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes)

In [39]:
# # train the models and get the results
# #for model in [gcn, gat, gin, gsage]:
# results = gnn_experiment(gcn, train_loader, val_loader) 
#  # torch.save(results, f'{model.__class__.__name__}_results.pt')

In [40]:
# # put all the results in a dataframe and display it as a table where it is easy to compare the models
# results = []

# for model in [gcn, gat, gin, gsage]:
#   model_name = model.__class__.__name__
#   results.append(torch.load(f'output/{model_name}_results.pt'))
  
# results_df = pd.DataFrame(results)
# results_df['model'] = ['GCN', 'GAT', 'GIN', 'GraphSAGE']
# results_df = results_df.set_index('model')
# results_df

## Evaluation

In [41]:
def confusion_matrix(predictions, labels):
    # visualize the confusion matrix with dusplay matrix from sklearn
    
    preds = predictions.argmax(dim=1).cpu().numpy()
    labels = labels.cpu().numpy()
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(10, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

In [42]:
# test the confusion matrix with the GCN results
confusion_matrix(gcn_results['model'](data.x, data.edge_index)[0], data.y)

TypeError: argmax() got an unexpected keyword argument 'dim'