# Notebook to construct and train a GraphSAGE model on a corpus of British Folk/Traditional melodies

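This notebook defines a GraphSAGE node classifier, tunes its hyperparameters with a grid search, and then trains and tests a final model that assigns each melody to a date class.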

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

print(torch.__version__)
print(torch.version.cuda)

from torch_geometric.data import Batch
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader, LinkNeighborLoader
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import pickle
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sn

2.1.0
None


### Load the data:

(The train, validation and test splits are the same as those used for the other GNN experiments on this dataset, so that the results can be compared.)

In [2]:
# Location of the pickled graph object; later cells read its features and
# its train/val/test masks.
# NOTE(review): pickle.load can run arbitrary code - only open trusted files.
data_path = 'split_data/10_class_50yrN_split.pkl' # Change as needed
with open(data_path, 'rb') as split_file:
    tune_data = pickle.load(split_file)

### Define Model:

In [54]:
class GraphSAGE(torch.nn.Module):
    '''Two-layer GraphSAGE node classifier.

    Architecture: SAGEConv(dim_in -> dim_h) -> ReLU -> dropout ->
    SAGEConv(dim_h -> dim_h // 2) -> ReLU -> dropout ->
    Linear(dim_h // 2 -> num_classes).

    Inputs:
    - dim_in (int): Size of each input node feature vector
    - dim_h (int): Width of the first hidden layer (the second is dim_h // 2)
    - num_classes (int): Number of output classes
    - dropout (float): Dropout probability applied after each SAGEConv layer
      (Default = 0.5)
    '''

    # Define layers:
    def __init__(self, dim_in, dim_h, num_classes, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_h // 2)
        self.linear = torch.nn.Linear(dim_h // 2, num_classes)

    # Define forward function:
    def forward(self, x, edge_index):
        '''Return raw class scores (logits) with num_classes columns.

        No softmax is applied here; the training cells pair this model with
        nn.CrossEntropyLoss, which expects unnormalised scores.
        '''
        h = F.relu(self.sage1(x, edge_index))
        # Dropout is only active in training mode (self.training is True).
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.relu(self.sage2(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)

        return self.linear(h)

### Train and test functions:

In [55]:
def train_model(model, train_loader, val_loader, optimiser, criterion, n_epochs, saved_model_path, print_metrics=True, save_best_model=False):
    '''
    Train a GraphSAGE model in mini-batches, validating every 5th epoch, and
    plot training/validation accuracy at the end.

    Inputs:
    - model (torch.nn.Module): GraphSAGE model, called as model(x, edge_index)
    - train_loader (pytorch_geometric.loader.NeighborLoader): Training batches;
      each batch must expose x, edge_index, y and train_mask
    - val_loader (pytorch_geometric.loader.NeighborLoader): Validation batches;
      each batch must expose x, edge_index, y and val_mask
    - optimiser (torch.optim): Optimiser built over model.parameters()
    - criterion (torch.nn): Loss function
    - n_epochs (int): Index of the last epoch. Epochs 0..n_epochs inclusive
      are run, i.e. n_epochs + 1 passes over train_loader
    - saved_model_path (str): File the state_dict is written to when
      save_best_model is True
    - print_metrics (bool): Whether to print metrics at every validation step
      (Default = True)
    - save_best_model (bool): Whether to save the weights each time the
      validation accuracy reaches a new peak (Default = False)

    Returns:
    - average_train_loss (float): Training loss of the last epoch, averaged
      over training batches
    - average_val_loss (float): Validation loss of the last validation pass,
      averaged over validation batches
    - peak_val_acc (float): Peak validation accuracy reached during training

    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    eval_every = 5  # validate on epochs 0, 5, 10, ...
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)

    peak_val_acc = 0.0
    train_acc_history = []
    val_acc_history = []
    epoch_history = []

    # Training loop:
    for epoch in range(n_epochs + 1):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        # Train in mini-batches:
        for batch in train_loader:
            batch = batch.to(device)
            optimiser.zero_grad()

            # Feed data through model:
            out = model(batch.x, batch.edge_index)

            # Calc. loss:
            # NOTE(review): loss and accuracy use every node in the batch whose
            # mask is set, not only the nodes the loader was given as input -
            # confirm this is intended.
            loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
            loss.backward()
            optimiser.step()
            train_loss += loss.item()

            # Calculate train accuracy
            pred = out.argmax(dim=1)
            correct = pred[batch.train_mask] == batch.y[batch.train_mask]
            train_correct += correct.sum().item()
            train_total += batch.train_mask.sum().item()

        average_train_loss = train_loss / num_train_batches if num_train_batches else 0.0
        train_accuracy = train_correct / train_total if train_total else 0.0

        if epoch % eval_every != 0:
            continue

        # Validation pass:
        model.eval()
        val_loss = 0.0
        val_correct = 0
        num_val_nodes = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                out = model(batch.x, batch.edge_index)
                loss = criterion(out[batch.val_mask], batch.y[batch.val_mask])
                val_loss += loss.item()

                # Calculate validation accuracy
                pred = out.argmax(dim=1)
                correct = pred[batch.val_mask] == batch.y[batch.val_mask]
                val_correct += correct.sum().item()
                num_val_nodes += batch.val_mask.sum().item()

        # Average over the validation batches (not the training batches).
        average_val_loss = val_loss / num_val_batches if num_val_batches else 0.0
        val_accuracy = val_correct / num_val_nodes if num_val_nodes else 0.0

        # Record accuracies for the plot
        train_acc_history.append(train_accuracy)
        val_acc_history.append(val_accuracy)
        epoch_history.append(epoch)

        if val_accuracy > peak_val_acc:
            peak_val_acc = val_accuracy

            # Save best model:
            if save_best_model:
                torch.save(model.state_dict(), saved_model_path)
                print("Model saved")

        if print_metrics:
            print('Epoch: ', epoch)
            print('Train loss: ', average_train_loss)
            print('Train accuracy: ', train_accuracy)
            print('Val loss: ', average_val_loss)
            print('Val accuracy: ', val_accuracy)

    # Plot accuracies (values are fractions in [0, 1], not percentages):
    plt.figure(figsize=(10, 5))
    plt.plot(epoch_history, train_acc_history, label='Training Accuracy')
    plt.plot(epoch_history, val_acc_history, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy of GraphSAGE Model')
    plt.grid(True)
    plt.legend()
    plt.show()


    return average_train_loss, average_val_loss, peak_val_acc

In [56]:
def test_model(model, data, data_loader, print_metrics=True):
    ''' 
    Function to train a GraphSAGE model.

    Inputs:
    - model (torch.nn.Module): GraphSAGE model
    - data_loader (pytorch_geometric.loader.NeighborLoader)
    - print_metrics (bool): Whether to print testing results (Default = True)

    Returns:
    - test_accuracy (float): Test accuracy
    
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Test init:
    model.eval()

    test_accuracy = 0
    test_correct = 0
    test_total = 0
    num_batches = len(data_loader)
    all_preds = []
    all_labels = []

    # Test in mini-batches:
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)

            # Feed data through model:
            out = model(batch.x, batch.edge_index)

            # Calculate test accuracy
            pred = out.argmax(dim=1)
            correct = pred[batch.test_mask] == batch.y[batch.test_mask]
            test_correct += correct.sum().item()
            test_total += batch.test_mask.sum().item()

            # Append for confusion matrix
            all_preds.append(pred[batch.test_mask].cpu().numpy())
            all_labels.append(batch.y[batch.test_mask].cpu().numpy())

    
    test_accuracy = test_correct / test_total

    if print_metrics:    
        print('Test accuracy: ', test_accuracy)

        # Convert prediction and label tensors to numpy arrays
        preds = np.concatenate(all_preds)
        labels = np.concatenate(all_labels)

        # Build confusion matrix:
        if num_classes == 4:
            classes = ('1700', '1750', '1800', '1850')
        if num_classes == 5:
            classes = ('1650', '1700', '1750', '1800', '1850') 
        if num_classes == 10:
            classes = ('1650', '1675', '1700', '1725', '1750', '1775', '1800',  '1825', '1850', '1875')
        cf_matrix = confusion_matrix(labels, preds) # build confusion matrix
        cf_matrix_df = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index = [i for i in classes],
                        columns = [i for i in classes]) # normalise, convert to cols, store in dataframe 
        plt.figure(figsize = (12,7))
        plt.title('Folk-Song Dating Confusion Matrix')
        sn.heatmap(cf_matrix_df, annot=True)


    return test_accuracy

### Grid Search:

In [18]:
### Grid search over neighbour-sample sizes, hidden dimension and learning rate:

dim_h_grid = [16, 32, 64, 128]
num_neighbours_1 = [15, 30]
num_neighbours_2 = [15, 30]
lr_grid = [0.001, 0.01]

# Settings shared by every run:
dim_in = tune_data.num_features
epochs = 500
saved_model_path = 'None'  # never written: save_best_model keeps its default (False)
# NOTE(review): data_path names a 10-class split, yet 5 classes are used here - confirm.
num_classes = 5
report_template = 'num_neighbours_1: {num_neighbours}, num_neighbours_2: {num_neighbours_second}, dim_h: {dim_h}, lr: {lr}, achieved peak val_acc: {peak_val_acc}'

# Maps (num_neighbours, num_neighbours_second, dim_h, lr) -> peak validation accuracy
results_dict = {}

for num_neighbours in num_neighbours_1:
    for num_neighbours_second in num_neighbours_2:

        # The loaders depend only on the two neighbour-sample sizes, so they are
        # built once per pair and reused for every (dim_h, lr) combination.
        train_loader = NeighborLoader(
            tune_data,
            num_neighbors=[num_neighbours, num_neighbours_second],
            batch_size=800,
            input_nodes=tune_data.train_mask,
        )
        val_loader = NeighborLoader(
            tune_data,
            num_neighbors=[num_neighbours, num_neighbours_second],
            batch_size=800,
            input_nodes=tune_data.val_mask,
        )

        for num_dims in dim_h_grid:
            dim_h = num_dims

            for lr in lr_grid:
                # Fresh model, optimiser and loss for every configuration:
                model = GraphSAGE(dim_in, dim_h, num_classes)
                optimiser = torch.optim.Adam(model.parameters(), lr=lr)
                criterion = nn.CrossEntropyLoss()

                # Train (per-epoch metric printing switched off):
                _, _, peak_val_acc = train_model(model, train_loader, val_loader, optimiser, criterion, epochs, saved_model_path, False)

                results_dict[(num_neighbours, num_neighbours_second, num_dims, lr)] = peak_val_acc

                print(report_template.format(
                    num_neighbours=num_neighbours,
                    num_neighbours_second=num_neighbours_second,
                    dim_h=dim_h,
                    lr=lr,
                    peak_val_acc=peak_val_acc,
                ))

# Print best: rank configurations by peak validation accuracy, highest first
accuracy_lst = sorted(results_dict.items(), key=lambda entry: entry[1], reverse=True)
top_3 = accuracy_lst[:3]

print(top_3)


num_neighbours_1: 5, num_neighbours_2: 5, dim_h: 16, lr: 0.001, achieved peak val_acc: 0.2733333333333333
num_neighbours_1: 5, num_neighbours_2: 5, dim_h: 16, lr: 0.01, achieved peak val_acc: 0.10135135135135136
num_neighbours_1: 5, num_neighbours_2: 5, dim_h: 32, lr: 0.001, achieved peak val_acc: 0.2080536912751678
num_neighbours_1: 5, num_neighbours_2: 5, dim_h: 32, lr: 0.01, achieved peak val_acc: 0.1
num_neighbours_1: 5, num_neighbours_2: 5, dim_h: 64, lr: 0.001, achieved peak val_acc: 0.22297297297297297
num_neighbours_1: 5, num_neighbours_2: 5, dim_h: 64, lr: 0.01, achieved peak val_acc: 0.24
num_neighbours_1: 5, num_neighbours_2: 5, dim_h: 128, lr: 0.001, achieved peak val_acc: 0.26
num_neighbours_1: 5, num_neighbours_2: 5, dim_h: 128, lr: 0.01, achieved peak val_acc: 0.26174496644295303
num_neighbours_1: 5, num_neighbours_2: 15, dim_h: 16, lr: 0.001, achieved peak val_acc: 0.21333333333333335
num_neighbours_1: 5, num_neighbours_2: 15, dim_h: 16, lr: 0.01, achieved peak val_acc:

### Final Training:

In [78]:
# Parameters
dim_in = tune_data.num_features
dim_h = 128
num_classes = 10
saved_model_path = 'run5_SAGE_best_model_10Classes_50yrN'


def _split_loader(split_mask):
    '''Build a NeighborLoader over tune_data for the nodes selected by split_mask.'''
    return NeighborLoader(
        tune_data,
        num_neighbors=[15, 15], # n neighbours of a node and n neighbours for each of these neighbours
        batch_size=800,
        input_nodes=split_mask,
    )


# One loader per split; they differ only in the mask passed as input_nodes.
train_loader = _split_loader(tune_data.train_mask)
val_loader = _split_loader(tune_data.val_mask)
test_loader = _split_loader(tune_data.test_mask)

# model
model = GraphSAGE(dim_in, dim_h, num_classes)

# Optimiser and loss criterion:
optimiser = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [79]:
# Final training run: metrics are printed at every validation step and the
# weights are saved to saved_model_path whenever validation accuracy peaks.
train_loss, val_loss, peak_val_acc = train_model(
    model,
    train_loader,
    val_loader,
    optimiser,
    criterion,
    n_epochs=500,
    saved_model_path=saved_model_path,
    print_metrics=True,
    save_best_model=True,
)

Model saved
Epoch:  0
Train loss:  8.422938346862793
Train accuracy:  0.10714285714285714
Val loss:  3.1857831478118896
Val accuracy:  0.21333333333333335
Epoch:  5
Train loss:  4.159163951873779
Train accuracy:  0.11714285714285715
Val loss:  2.4459190368652344
Val accuracy:  0.08
Epoch:  10
Train loss:  2.776630401611328
Train accuracy:  0.13428571428571429
Val loss:  2.263192653656006
Val accuracy:  0.12666666666666668
Epoch:  15
Train loss:  2.422349214553833
Train accuracy:  0.11714285714285715
Val loss:  2.288421154022217
Val accuracy:  0.1
Epoch:  20
Train loss:  2.339447498321533
Train accuracy:  0.09428571428571429
Val loss:  2.298563003540039
Val accuracy:  0.11333333333333333
Epoch:  25
Train loss:  2.304237127304077
Train accuracy:  0.11428571428571428
Val loss:  2.300551652908325
Val accuracy:  0.11333333333333333
Epoch:  30
Train loss:  2.3089160919189453
Train accuracy:  0.10285714285714286
Val loss:  2.300790309906006
Val accuracy:  0.11333333333333333
Epoch:  35
Train 

In [80]:
# Display the peak validation accuracy returned by the training run.
peak_val_acc

0.3

In [81]:
# Rebuild the architecture with the hyper-parameters used for training, then
# restore the best weights that train_model saved.
model_path = saved_model_path

dim_in = tune_data.num_features
dim_h = 128
num_classes = 10
model = GraphSAGE(dim_in, dim_h, num_classes)

# map_location='cpu' lets a checkpoint written from a CUDA model load on a
# machine without a GPU; test_model moves the model to the available device.
model.load_state_dict(torch.load(model_path, map_location='cpu'))

<All keys matched successfully>

In [None]:
# Evaluate several times and average the accuracies.
# (Presumably because the loader samples neighbours at random, so a single
# pass is not deterministic - confirm.)
n_test_runs = 5
test_accuracies = []
for run_idx in range(n_test_runs):
    # print_metrics is passed by name: the fourth positional slot of test_model
    # is print_metrics, not the number of classes.
    test_accuracy = test_model(model, tune_data, test_loader, print_metrics=True)
    test_accuracies.append(test_accuracy)

avg_accuracy = np.mean(test_accuracies)
print(avg_accuracy)