## Required packages

In [1]:
import numpy as np
import os
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import add_self_loops
from torch_geometric.transforms import ToUndirected
from torch.nn import BatchNorm1d
from torch.optim.lr_scheduler import ExponentialLR

In [2]:
# Ask PyTorch how many CUDA devices this process can see (0 on a CPU-only node)
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")

Number of available GPUs: 4


## Required data generated by GNNonCalo_Scaling_DataPreparation.ipynb

In [3]:
# Load the training inputs and the categorised test edges. Every file is read
# whole into memory and closed straight away.
DATA_DIR = "/storage/mxg1065/MultiClassGNN"


def readH5Dataset(path, datasetName):
    """Return the full contents of dataset ``datasetName`` in the HDF5 file ``path``.

    The ``with`` block closes the file even if the read fails, and ``hf[name]``
    raises a KeyError naming the missing dataset (``hf.get(name)`` would return
    None and only fail later with an unhelpful TypeError).
    """
    with h5py.File(path, "r") as hf:
        return hf[datasetName][:]


def readTestEdges(name):
    """Read a test-edge array; these files store it under a dataset named like the file."""
    return readH5Dataset(f"{DATA_DIR}/{name}.hdf5", name)


# Training inputs (the "70evs" files)
multi_cellFeaturesScaled = readH5Dataset(f"{DATA_DIR}/multi_cellFeaturesScaled_train_70evs.hdf5", "multi_cellFeatures_trainS")
multi_train_edge_source_BD = readH5Dataset(f"{DATA_DIR}/train_edge_source_BD_70evs.hdf5", "train_edge_source_BD")
multi_train_edge_dest_BD = readH5Dataset(f"{DATA_DIR}/train_edge_dest_BD_70evs.hdf5", "train_edge_dest_BD")
multi_train_edge_source_noBD = readH5Dataset(f"{DATA_DIR}/train_edge_source_noBD_70evs.hdf5", "train_edge_source_noBD")
multi_train_edge_dest_noBD = readH5Dataset(f"{DATA_DIR}/train_edge_dest_noBD_70evs.hdf5", "train_edge_dest_noBD")

# NOTE(review): the labels come from a 100-event file outside DATA_DIR, while the
# features and edges above cover 70 events (compare the shapes printed below).
# Graphs are paired with label rows by event position; confirm that label row i
# belongs to the same event as feature row i.
multi_truth_label_train = readH5Dataset("/storage/mxg1065/neighborLabels100Events.hdf5", "neighborLabels100Events")
# Alternative label source:
# multi_truth_label_train = readH5Dataset(f"{DATA_DIR}/totalTrainingLabelsRandom.hdf5", "totalTrainingLabelsRandom")

# Test edges: true-true, lone-lone, lone-cluster, cluster-lone and
# cluster-cluster, each as source/destination arrays in BD and noBD form.
multi_test_edge_source_true_BD = readTestEdges("multi_test_edge_source_true_BD")
multi_test_edge_dest_true_BD = readTestEdges("multi_test_edge_dest_true_BD")
multi_test_edge_source_true_noBD = readTestEdges("multi_test_edge_source_true_noBD")
multi_test_edge_dest_true_noBD = readTestEdges("multi_test_edge_dest_true_noBD")

multi_test_edge_source_bkg_lone_BD = readTestEdges("multi_test_edge_source_bkg_lone_BD")
multi_test_edge_dest_bkg_lone_BD = readTestEdges("multi_test_edge_dest_bkg_lone_BD")
multi_test_edge_source_bkg_lone_noBD = readTestEdges("multi_test_edge_source_bkg_lone_noBD")
multi_test_edge_dest_bkg_lone_noBD = readTestEdges("multi_test_edge_dest_bkg_lone_noBD")

multi_test_edge_source_bkg_lone_cluster_BD = readTestEdges("multi_test_edge_source_bkg_lone_cluster_BD")
multi_test_edge_dest_bkg_lone_cluster_BD = readTestEdges("multi_test_edge_dest_bkg_lone_cluster_BD")
multi_test_edge_source_bkg_lone_cluster_noBD = readTestEdges("multi_test_edge_source_bkg_lone_cluster_noBD")
multi_test_edge_dest_bkg_lone_cluster_noBD = readTestEdges("multi_test_edge_dest_bkg_lone_cluster_noBD")

multi_test_edge_source_bkg_cluster_lone_BD = readTestEdges("multi_test_edge_source_bkg_cluster_lone_BD")
multi_test_edge_dest_bkg_cluster_lone_BD = readTestEdges("multi_test_edge_dest_bkg_cluster_lone_BD")
multi_test_edge_source_bkg_cluster_lone_noBD = readTestEdges("multi_test_edge_source_bkg_cluster_lone_noBD")
multi_test_edge_dest_bkg_cluster_lone_noBD = readTestEdges("multi_test_edge_dest_bkg_cluster_lone_noBD")

multi_test_edge_source_bkg_cluster_cluster_BD = readTestEdges("multi_test_edge_source_bkg_cluster_cluster_BD")
multi_test_edge_dest_bkg_cluster_cluster_BD = readTestEdges("multi_test_edge_dest_bkg_cluster_cluster_BD")
multi_test_edge_source_bkg_cluster_cluster_noBD = readTestEdges("multi_test_edge_source_bkg_cluster_cluster_noBD")
multi_test_edge_dest_bkg_cluster_cluster_noBD = readTestEdges("multi_test_edge_dest_bkg_cluster_cluster_noBD")

In [4]:
# Display the shape of the scaled cell-feature array
np.shape(multi_cellFeaturesScaled)

(70, 187652, 8)

In [5]:
# Copy the scaled cell features into a float32 tensor; the last line displays its shape
x = torch.tensor(multi_cellFeaturesScaled, dtype=torch.float32)
x.size()

torch.Size([70, 187652, 8])

In [6]:
# Display the shape of the bi-directional training edge sources
np.shape(multi_train_edge_source_BD)

(70, 2500484)

In [7]:
# Display the shape of the bi-directional true-edge test sources
np.shape(multi_test_edge_source_true_BD)

(30, 66000)

## Preparing the Training Set and Test Set of Edges for GNN

In [8]:
def createEdgeIndexTensor(source, dest):
    """Pair per-event source/destination node arrays into edge-index tensors.

    Args:
        source: array-like of node indices, shape (n_events, n_edges).
        dest: array-like with the same shape as ``source``.

    Returns:
        LongTensor of shape (n_events, 2, n_edges); ``result[i, 0]`` holds the
        sources and ``result[i, 1]`` the destinations of event ``i``.
    """
    # Stack in numpy first: torch.tensor() on a Python list of ndarrays is very
    # slow and emits a UserWarning saying so. Stacking on axis=1 yields the
    # (events, 2, edges) layout directly, so no permute is needed.
    stacked = np.stack([np.asarray(source), np.asarray(dest)], axis=1)
    return torch.tensor(stacked, dtype=torch.long)

In [9]:
# Edge-index tensors for the training graphs, bi-directional (BD) and
# uni-directional (noBD)
trainingEdgeIndexBD = createEdgeIndexTensor(multi_train_edge_source_BD, multi_train_edge_dest_BD)
trainingEdgeIndexNoBD = createEdgeIndexTensor(multi_train_edge_source_noBD, multi_train_edge_dest_noBD)

# (source, destination) arrays of the test edges keyed by edge type:
# tt = true-true, ll = lone-lone, lc = lone-cluster, cl = cluster-lone,
# cc = cluster-cluster; each type comes in a BD and a noBD variant.
edgeIndexData = {
    "ttBD": (multi_test_edge_source_true_BD, multi_test_edge_dest_true_BD),
    "ttNoBD": (multi_test_edge_source_true_noBD, multi_test_edge_dest_true_noBD),
    "llBD": (multi_test_edge_source_bkg_lone_BD, multi_test_edge_dest_bkg_lone_BD),
    "llNoBD": (multi_test_edge_source_bkg_lone_noBD, multi_test_edge_dest_bkg_lone_noBD),
    "lcBD": (multi_test_edge_source_bkg_lone_cluster_BD, multi_test_edge_dest_bkg_lone_cluster_BD),
    "lcNoBD": (multi_test_edge_source_bkg_lone_cluster_noBD, multi_test_edge_dest_bkg_lone_cluster_noBD),
    "clBD": (multi_test_edge_source_bkg_cluster_lone_BD, multi_test_edge_dest_bkg_cluster_lone_BD),
    "clNoBD": (multi_test_edge_source_bkg_cluster_lone_noBD, multi_test_edge_dest_bkg_cluster_lone_noBD),
    "ccBD": (multi_test_edge_source_bkg_cluster_cluster_BD, multi_test_edge_dest_bkg_cluster_cluster_BD),
    "ccNoBD": (multi_test_edge_source_bkg_cluster_cluster_noBD, multi_test_edge_dest_bkg_cluster_cluster_noBD),
}

# Convert every (source, dest) pair into an edge-index tensor, keeping the keys
edgeIndexTensors = dict(
    (edgeType, createEdgeIndexTensor(*pair)) for edgeType, pair in edgeIndexData.items()
)

  edgeIndex = torch.tensor([source, dest], dtype=torch.long)


In [10]:
# Show the BD and noBD training edge-index shapes, one per line
print(trainingEdgeIndexBD.shape, trainingEdgeIndexNoBD.shape, sep="\n")

torch.Size([70, 2, 2500484])
torch.Size([70, 2, 1250242])


## Preparing the label (true/fake) tensor

In [11]:
# Insert a singleton axis at position 1 (equivalent to np.expand_dims(..., axis=1)),
# so the label array becomes three-dimensional with one row per event
# (see the shape displayed below).
trainingTruthLabels = multi_truth_label_train[:, np.newaxis]

In [12]:
# Display the shape of the expanded label array
np.shape(trainingTruthLabels)

(100, 1, 1250242)

In [13]:
# Copy the labels into a torch tensor (dtype is inferred from the numpy array);
# the last line displays its shape
y_train = torch.tensor(trainingTruthLabels)
y_train.size()

torch.Size([100, 1, 1250242])

## Data customization specific to PyTorch

In [14]:
# A minimal map-style Dataset over an in-memory list. A DataLoader wraps it to
# handle batching (and, if requested, shuffling and parallel loading) during
# training and testing.
class CustomDataset(torch.utils.data.Dataset):
    """Serve items straight from a Python list.

    torch.utils.data.Dataset is abstract: subclasses supply ``__getitem__`` and,
    for use with the default sampler, ``__len__``. Both delegate to the list.

    Args:
        dataList: the items to serve; stored by reference, not copied.
    """

    def __init__(self, dataList):
        self.dataList = dataList

    def __len__(self):
        """Number of items in the wrapped list."""
        items = self.dataList
        return len(items)

    def __getitem__(self, idx):
        """Item at position ``idx`` of the wrapped list."""
        items = self.dataList
        return items[idx]

In [15]:
# Build one homogeneous graph per event: all nodes are detector cells and all
# edges are cell-to-cell connections.
def createDataList(edgeIndexBD, edgeIndexNoBD, x, y=None):
    """Return a list of torch_geometric ``Data`` graphs, one per event.

    Args:
        edgeIndexBD: per-event edge indices used as ``edge_index``. Self-loops
            are added and the graph is passed through ``ToUndirected``.
        edgeIndexNoBD: per-event edge indices stored as ``edge_index_out``
            (the edges the model classifies).
        x: per-event node-feature tensors, indexed like ``edgeIndexBD``.
        y: per-event label tensors, indexed like ``edgeIndexBD``. Defaults to
            the notebook-global ``y_train``, which is what this function always
            used before the parameter existed.

    Returns:
        List of ``Data`` objects with attributes ``x``, ``edge_index``,
        ``edge_index_out`` and ``y``.

    NOTE(review): with the default ``y``, graphs built from *test* edges also
    carry rows of the training labels, whose length need not match the test
    edges (compare the shapes printed above). Pass ``y`` explicitly rather than
    relying on the global.
    """
    labels = y_train if y is None else y
    to_undirected = ToUndirected()  # one transform instance reused for every graph

    dataList = []
    for i in range(len(edgeIndexBD)):
        edge_index, _ = add_self_loops(edgeIndexBD[i])
        data = Data(x=x[i], edge_index=edge_index, edge_index_out=edgeIndexNoBD[i], y=labels[i])
        dataList.append(to_undirected(data))
    return dataList

# Collate function for the DataLoaders: instead of merging graphs into one
# big batch it returns parallel lists, one entry per graph.
def collateData(dataList, is_training=False):
    """Split a list of graphs into per-attribute lists.

    Args:
        dataList: graphs exposing ``x``, ``edge_index``, ``edge_index_out``
            and, when ``is_training`` is True, ``y``.
        is_training: also return the labels ``y``.

    Returns:
        ``(xs, edge_indices, edge_indices_out)``, plus ``ys`` as a fourth list
        when ``is_training`` is True.
    """
    fields = ["x", "edge_index", "edge_index_out"]
    if is_training:
        fields.append("y")
    return tuple([getattr(graph, field) for graph in dataList] for field in fields)

In [16]:
# One list of graphs per edge category. Each createDataList call pairs the BD
# edge indices with the matching noBD edge indices of the same category.
dataListTraining = createDataList(trainingEdgeIndexBD, trainingEdgeIndexNoBD, x)  # Training Edges

# True-True, Lone-Lone, Lone-Cluster, Cluster-Lone and Cluster-Cluster test edges
dataListTT, dataListLL, dataListLC, dataListCL, dataListCC = [
    createDataList(edgeIndexTensors[bdKey], edgeIndexTensors[noBdKey], x)
    for bdKey, noBdKey in (
        ("ttBD", "ttNoBD"),
        ("llBD", "llNoBD"),
        ("lcBD", "lcNoBD"),
        ("clBD", "clNoBD"),
        ("ccBD", "ccNoBD"),
    )
]

In [17]:
batchSize = 5
dataListMapping = {
    "train": dataListTraining,  # Training Edges
    "tt": dataListTT,           # True-True Edges
    "ll": dataListLL,           # Lone-Lone Edges
    "lc": dataListLC,           # Lone-Cluster Edges
    "cl": dataListCL,           # Cluster-Lone Edges
    "cc": dataListCC            # Cluster-Cluster Edges
}

# One Dataset and one DataLoader per split. Only the training loader's collate
# function also returns the labels (is_training=True).
# NOTE(review): no loader sets shuffle=True, so graphs are visited in the same
# order every epoch.
dataSets = {key: CustomDataset(dataList) for key, dataList in dataListMapping.items()}
dataLoaders = {}
for key, dataset in dataSets.items():
    # The flag is bound through a default argument: a closure over the loop
    # variable would be late-bound and every loader would see the last value.
    dataLoaders[key] = torch.utils.data.DataLoader(
        dataset,
        batch_size=batchSize,
        collate_fn=lambda batch, training=(key == "train"): collateData(batch, is_training=training),
    )

In [18]:
# Pool all four background edge types (ll, lc, cl, cc) into a single list
dataListTotalBkg = [*dataListLL, *dataListLC, *dataListCL, *dataListCC]


def collateTotalBkg(dataListTotalBkg):
    """Collate unlabelled graphs into (features, edge_index, edge_index_out) lists.

    The parameter is the batch handed over by the DataLoader; it shadows the
    module-level list of the same name.
    """
    return collateData(dataListTotalBkg, is_training=False)


customDatasetTotalBkg = CustomDataset(dataListTotalBkg)
dataLoaderTotalBkg = torch.utils.data.DataLoader(customDatasetTotalBkg, batch_size=batchSize, collate_fn=collateTotalBkg)

## Multi-Edge Classifier Model

In [38]:
class MultiEdgeClassifier(nn.Module):
    """GCN node encoder followed by a linear head on concatenated edge endpoints.

    Layers, in order:
      Linear(input_dim -> hidden_dim)
      GCNConv(hidden_dim -> 128) -> BatchNorm1d -> ReLU
      (num_layers - 1) x [GCNConv(-> 64) -> BatchNorm1d -> ReLU]
      for every edge (u, v) in ``edge_index_out``: Linear([h_u, h_v] -> output_dim)

    Args:
        input_dim: number of node features.
        hidden_dim: width of the node-embedding layer.
        output_dim: number of edge classes; raw logits are returned (no softmax).
        num_layers: number of GCN layers. The first 128-channel layer is always
            built, so values below 1 behave like 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(MultiEdgeClassifier, self).__init__()

        # Node embedding layer
        self.node_embedding = nn.Linear(input_dim, hidden_dim)

        # Graph convolutions and their batch-norm layers, applied pairwise
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # First GCN layer: 128 output channels
        self.convs.append(GCNConv(hidden_dim, 128))
        self.bns.append(BatchNorm1d(128))
        final_channels = 128

        # Every additional layer maps to 64 channels
        for _ in range(1, num_layers):
            self.convs.append(GCNConv(final_channels, 64))
            self.bns.append(BatchNorm1d(64))
            final_channels = 64

        # The head sees both endpoint embeddings concatenated, so its input
        # width is twice the last layer's width: 128 for num_layers >= 2 and
        # 256 for num_layers == 1 (a hard-coded 128 crashed the one-layer case).
        self.fc = nn.Linear(2 * final_channels, output_dim)

    def forward(self, x, edge_index, edge_index_out):
        """Return logits of shape (number of edges in ``edge_index_out``, output_dim).

        Args:
            x: node-feature matrix with ``input_dim`` columns.
            edge_index: edges used for message passing by the GCN layers.
            edge_index_out: 2 x E edges to classify; row 0 holds the source
                nodes and row 1 the destination nodes.
        """
        h = self.node_embedding(x)
        for conv, bn in zip(self.convs, self.bns):
            h = torch.relu(bn(conv(h, edge_index)))

        # Edge representation: source and destination embeddings side by side
        edge_rep = torch.cat([h[edge_index_out[0]], h[edge_index_out[1]]], dim=1)

        # Logits only; CrossEntropyLoss applies the softmax internally
        return self.fc(edge_rep)

# Use the first CUDA device when one is visible, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyper-parameters
input_dim = 8    # node features per cell (last dimension of x)
hidden_dim = 256
output_dim = 5   # number of edge classes
num_layers = 5   # number of GCN layers
model = MultiEdgeClassifier(input_dim, hidden_dim, output_dim, num_layers)

# CrossEntropyLoss expects raw logits (it applies log-softmax itself), which is
# why the model has no softmax layer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Each scheduler.step() multiplies the learning rate by 0.99
scheduler = ExponentialLR(optimizer, gamma=0.99)

In [45]:
def trainModel(model, device, data_loader, optimizer, criterion):
    """Train ``model`` for one epoch and return the mean batch loss.

    Each batch is ``(xs, edge_indices, edge_indices_out, ys)``: parallel lists
    with one entry per graph, as produced by ``collateData(..., is_training=True)``.
    Every graph is run through the model separately. The per-graph losses are
    averaged, and one backward pass plus one optimizer step is taken per batch.

    Args:
        model: called as ``model(x, edge_index, edge_index_out)``, returns logits.
        device: device the model and batch tensors are moved to.
        data_loader: iterable of training batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, class_targets)``.

    Returns:
        0-dim tensor on ``device``, detached from autograd: the mean over
        batches of the per-batch mean loss.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    model.to(device)
    batchLosses = []

    for batch_x, batch_edge_index, batch_edge_index_out, batch_y in data_loader:
        # All graphs share the same node count, so their features stack
        batch_x = torch.stack(batch_x).to(device)

        # Gradients must not accumulate across batches
        optimizer.zero_grad()

        graphLosses = []
        for i in range(len(batch_edge_index)):
            logits = model(batch_x[i], batch_edge_index[i].to(device),
                           batch_edge_index_out[i].to(device)).float()
            # CrossEntropyLoss needs int64 class indices
            target = batch_y[i].long().to(device)
            graphLosses.append(criterion(logits.squeeze(), target.squeeze()))

        batchLoss = sum(graphLosses) / len(graphLosses)
        batchLoss.backward()
        optimizer.step()

        # Detach so the recorded value no longer references this batch's
        # autograd graph, and the returned epoch loss carries no graph.
        batchLosses.append(batchLoss.detach())

    if not batchLosses:
        raise ValueError("data_loader yielded no batches")
    return sum(batchLosses) / len(batchLosses)

In [46]:
# Background loaders keyed by the integer class label testModel gives their
# edges (true edges get label 1 inside testModel).
# NOTE(review): 'cl' -> 2 and 'lc' -> 3; confirm this matches the label
# encoding used for the training data.
data_loader_bkg_dict = {
    label: dataLoaders[key]
    for label, key in ((0, 'll'), (2, 'cl'), (3, 'lc'), (4, 'cc'))
}

def testModel(model, device, data_loader_true, data_loader_bkg_dict):
    all_scores = []
    true_labels = []
    
    with torch.no_grad():
        model.eval()
        model.to(device)
        
        # Loop over true edges (positive class)
        for batch_x, batch_edge_index, batch_edge_index_out in data_loader_true:
            batch_x = torch.stack(batch_x).to(device)
            batch_edge_index = [edge_index.to(device) for edge_index in batch_edge_index]
            batch_edge_index_out = [edge_index.to(device) for edge_index in batch_edge_index_out]

            for i in range(len(batch_edge_index)):
                test_edge_scores = model(batch_x[i], batch_edge_index[i], batch_edge_index_out[i])
                
                # Assign the true label for true edges (positive class is 1 for true edges)
                true_label = torch.ones(test_edge_scores.size(0), dtype=torch.long, device=device)
                
                # Append scores and true labels for this batch
                all_scores.append(test_edge_scores)
                true_labels.append(true_label)
        
        # Loop over background edges (negative class)
        for background_type, data_loader_bkg in data_loader_bkg_dict.items():
            background_label = background_type  # This is your label for the current background type

            for batch_x, batch_edge_index, batch_edge_index_out in data_loader_bkg:
                batch_x = torch.stack(batch_x).to(device)
                batch_edge_index = [edge_index.to(device) for edge_index in batch_edge_index]
                batch_edge_index_out = [edge_index.to(device) for edge_index in batch_edge_index_out]

                for i in range(len(batch_edge_index)):
                    test_edge_scores = model(batch_x[i], batch_edge_index[i], batch_edge_index_out[i])
                    
                    # Assign the true label for the current background type
                    true_label = torch.full((test_edge_scores.size(0),), background_label, dtype=torch.long, device=device)
                    
                    # Append scores and true labels for this batch
                    all_scores.append(test_edge_scores)
                    true_labels.append(true_label)

    # Concatenate all scores and labels from different batches
    all_scores = torch.cat(all_scores, dim=0).cpu().numpy()
    true_labels = torch.cat(true_labels, dim=0).cpu().numpy()

    return all_scores, true_labels

In [47]:
def lossForTrainingAndTesting(model, loader, loss_fn, optimizer, training, device):
    """Run one pass over ``loader`` and return the mean per-graph loss.

    Args:
        model: called as ``model(x, edge_index, edge_index_out)``, returns logits.
        loader: yields ``(xs, edge_indices, edge_indices_out)`` batches, or
            ``(xs, edge_indices, edge_indices_out, ys)`` when labels are present
            (the layouts produced by ``collateData``).
        loss_fn: criterion taking ``(logits, class_targets)``.
        optimizer: used only when ``training`` is True.
        training: True puts the model in train mode and, for EVERY graph, runs
            backward and ``optimizer.step()``, so the call itself trains the
            model. False puts the model in eval mode and evaluates it under
            ``torch.no_grad()``.
        device: device the batch tensors are moved to (the model is not moved).

    Returns:
        Mean loss over all labelled graphs as a float. Returns None when no
        labels were seen (an unlabelled loader in evaluation mode) or the
        loader was empty.

    Raises:
        ValueError: if ``training`` is True and a batch carries no labels.
    """
    model.train(training)  # train(False) is eval mode

    total_loss = 0.0
    num_graphs = 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            if len(batch) < 4:
                if training:
                    raise ValueError("training requires batches of (x, edge_index, edge_index_out, y)")
                # Unlabelled evaluation batch: there is no target to compute a loss against
                continue

            batch_x, batch_edge_index, batch_edge_index_out, batch_y = batch
            batch_x = torch.stack(batch_x).to(device)

            for i in range(len(batch_edge_index)):
                edge_scores = model(batch_x[i], batch_edge_index[i].to(device),
                                    batch_edge_index_out[i].to(device))
                # CrossEntropyLoss needs a flat vector of int64 class indices;
                # each graph's labels are stored with a leading axis of size 1.
                target = batch_y[i].reshape(-1).long().to(device)
                loss = loss_fn(edge_scores, target)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                total_loss += loss.item()
                num_graphs += 1

    return total_loss / num_graphs if num_graphs else None

In [None]:
# Define the data-saving function
def saveDataH5(fileName, dataDict, compressionType='gzip'):
    """Write each array of ``dataDict`` to a new HDF5 file as its own dataset.

    Args:
        fileName: output path; an existing file is overwritten (mode 'w').
        dataDict: mapping of dataset name -> array-like.
        compressionType: compression filter passed to every ``create_dataset``.
    """
    with h5py.File(fileName, 'w') as file:
        for key, data in dataDict.items():
            file.create_dataset(key, data=data, compression=compressionType)


def _edgeLossSumsByClass(logits, labels, trueClass=1):
    """Split per-edge cross-entropy by whether the edge label equals ``trueClass``.

    Returns ``(true_sum, true_count, other_sum, other_count)`` as Python numbers,
    so that several graphs can be accumulated before averaging.
    """
    labels = labels.reshape(-1).long()
    perEdge = F.cross_entropy(logits.float(), labels, reduction="none")
    isTrue = labels == trueClass
    return (perEdge[isTrue].sum().item(), int(isTrue.sum().item()),
            perEdge[~isTrue].sum().item(), int((~isTrue).sum().item()))


def _safeMean(total, count):
    """``total / count``, or NaN when nothing was accumulated."""
    return total / count if count else float("nan")


def _trainingLossByClass(model, device, loader, trueClass=1):
    """Mean per-edge loss on a labelled loader, for true-class edges and for all other edges.

    Runs in eval mode under ``torch.no_grad()``, so the weights are not updated.
    """
    model.eval()
    trueSum = bkgSum = 0.0
    trueCount = bkgCount = 0
    with torch.no_grad():
        for batch_x, batch_edge_index, batch_edge_index_out, batch_y in loader:
            for i in range(len(batch_edge_index)):
                logits = model(batch_x[i].to(device), batch_edge_index[i].to(device),
                               batch_edge_index_out[i].to(device))
                ts, tc, bs, bc = _edgeLossSumsByClass(logits, batch_y[i].to(device), trueClass)
                trueSum += ts
                trueCount += tc
                bkgSum += bs
                bkgCount += bc
    return _safeMean(trueSum, trueCount), _safeMean(bkgSum, bkgCount)


num_epochs = 500
lossPerEpoch = []
scores = []
truth_labels = []
avgLoss_TrueTrain = []
avgLoss_TrueTest = []
avgLoss_BkgTrain = []
avgLoss_BkgTest = []

for epoch in range(num_epochs):
    # One optimisation pass over the training set
    total_loss_per_epoch = trainModel(model, device, dataLoaders['train'], optimizer, criterion)
    lossPerEpoch.append(total_loss_per_epoch.cpu().detach().numpy())

    # Update learning rate
    scheduler.step()

    # Scores and labels for all test edges (true = 1, background = 0/2/3/4)
    epoch_scores, epoch_true_labels = testModel(model, device, dataLoaders['tt'], data_loader_bkg_dict)

    # Loss bookkeeping: mean per-edge cross-entropy of true (label 1) edges and
    # of all other edges. These evaluations never call optimizer.step();
    # lossForTrainingAndTesting(..., training=True) would run a further full
    # training pass each time it was called.
    avgLossTrueTrain, avgLossBkgTrain = _trainingLossByClass(model, device, dataLoaders['train'])
    trueSum, trueCount, bkgSum, bkgCount = _edgeLossSumsByClass(
        torch.from_numpy(epoch_scores), torch.from_numpy(epoch_true_labels))
    avgLossTrueTest = _safeMean(trueSum, trueCount)
    avgLossBkgTest = _safeMean(bkgSum, bkgCount)

    # Store results.
    # NOTE(review): every epoch's full test-score array is kept in memory and
    # re-saved at each checkpoint; watch memory use for long runs.
    scores.append(epoch_scores)
    truth_labels.append(epoch_true_labels)
    avgLoss_TrueTrain.append(avgLossTrueTrain)
    avgLoss_TrueTest.append(avgLossTrueTest)
    avgLoss_BkgTrain.append(avgLossBkgTrain)
    avgLoss_BkgTest.append(avgLossBkgTest)

    # Print the loss for the current epoch
    print(f"Epoch: {epoch+1} | Total Loss Per Epoch: {total_loss_per_epoch.item():.4f}")

    # Save the model and data every 100 epochs
    if (epoch + 1) % 100 == 0:
        # Save model checkpoint (create the directory on first use)
        checkpoint_path = f"/storage/mxg1065/MultiClassGNN/modelCheckpoints/model_{num_layers}_layers_epoch_{epoch+1}.pth"
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model checkpoint saved for epoch {epoch+1}")

        # Save current training metrics
        data_path = f"/storage/mxg1065/MultiClassGNN/data_with_{num_layers}_layers_epoch_{epoch+1}.h5"
        saveDataH5(
            data_path,
            {
                "lossData": np.array(lossPerEpoch, dtype=np.float32),
                "scores": np.array(scores, dtype=np.float32),
                "truth_labels": np.array(truth_labels, dtype=np.int32),
                "avgLoss_TrueTrain": np.array(avgLoss_TrueTrain, dtype=np.float32),
                "avgLoss_TrueTest": np.array(avgLoss_TrueTest, dtype=np.float32),
                "avgLoss_BkgTrain": np.array(avgLoss_BkgTrain, dtype=np.float32),
                "avgLoss_BkgTest": np.array(avgLoss_BkgTest, dtype=np.float32)
            },
            compressionType='gzip'
        )
        print(f"Training data saved for epoch {epoch+1}")

### Resuming training from a saved checkpoint epoch

In [None]:
start_epoch = 300
num_epochs = 500


def _edgeLossSumsByClass(logits, labels, trueClass=1):
    """Split per-edge cross-entropy by whether the edge label equals ``trueClass``.

    Returns ``(true_sum, true_count, other_sum, other_count)`` as Python numbers.
    """
    labels = labels.reshape(-1).long()
    perEdge = F.cross_entropy(logits.float(), labels, reduction="none")
    isTrue = labels == trueClass
    return (perEdge[isTrue].sum().item(), int(isTrue.sum().item()),
            perEdge[~isTrue].sum().item(), int((~isTrue).sum().item()))


def _safeMean(total, count):
    """``total / count``, or NaN when nothing was accumulated."""
    return total / count if count else float("nan")


def _trainingLossByClass(model, device, loader, trueClass=1):
    """Mean per-edge loss on a labelled loader, for true-class edges and for all other edges.

    Runs in eval mode under ``torch.no_grad()``, so the weights are not updated.
    """
    model.eval()
    trueSum = bkgSum = 0.0
    trueCount = bkgCount = 0
    with torch.no_grad():
        for batch_x, batch_edge_index, batch_edge_index_out, batch_y in loader:
            for i in range(len(batch_edge_index)):
                logits = model(batch_x[i].to(device), batch_edge_index[i].to(device),
                               batch_edge_index_out[i].to(device))
                ts, tc, bs, bc = _edgeLossSumsByClass(logits, batch_y[i].to(device), trueClass)
                trueSum += ts
                trueCount += tc
                bkgSum += bs
                bkgCount += bc
    return _safeMean(trueSum, trueCount), _safeMean(bkgSum, bkgCount)


# Keep the metrics of epochs 1..start_epoch if the main training loop ran in
# this kernel. On a fresh kernel the lists start empty, and the saved arrays
# then only cover epochs start_epoch+1..num_epochs.
for _name in ("lossPerEpoch", "scores", "truth_labels", "avgLoss_TrueTrain",
              "avgLoss_TrueTest", "avgLoss_BkgTrain", "avgLoss_BkgTest"):
    globals()[_name] = list(globals().get(_name, []))[:start_epoch]

# Load the checkpoint the main training loop wrote at `start_epoch`.
# map_location lets a checkpoint saved on a GPU load on the current device.
checkpoint_path = f"/storage/mxg1065/MultiClassGNN/modelCheckpoints/model_{num_layers}_layers_epoch_{start_epoch}.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)

# Only the model weights are checkpointed. Put the learning rate back where the
# ExponentialLR schedule had it after `start_epoch` steps. Adam's moment
# estimates are not part of the checkpoint.
for group in optimizer.param_groups:
    group["lr"] = group["initial_lr"] * scheduler.gamma ** start_epoch
scheduler.last_epoch = start_epoch

# Resume training
for epoch in range(start_epoch, num_epochs):
    # One optimisation pass over the training set
    total_loss_per_epoch = trainModel(model, device, dataLoaders['train'], optimizer, criterion)
    # Store a plain float; converting the list to an array at the end then
    # works whether earlier entries came from this cell or the main loop.
    lossPerEpoch.append(total_loss_per_epoch.item())

    # Update learning rate
    scheduler.step()

    # Scores and labels for all test edges (true = 1, background = 0/2/3/4)
    epoch_scores, epoch_true_labels = testModel(model, device, dataLoaders['tt'], data_loader_bkg_dict)

    # Loss bookkeeping without weight updates: mean per-edge cross-entropy of
    # true (label 1) edges and of all other edges.
    avgLossTrueTrain, avgLossBkgTrain = _trainingLossByClass(model, device, dataLoaders['train'])
    trueSum, trueCount, bkgSum, bkgCount = _edgeLossSumsByClass(
        torch.from_numpy(epoch_scores), torch.from_numpy(epoch_true_labels))
    avgLossTrueTest = _safeMean(trueSum, trueCount)
    avgLossBkgTest = _safeMean(bkgSum, bkgCount)

    # Store results
    scores.append(epoch_scores)
    truth_labels.append(epoch_true_labels)
    avgLoss_TrueTrain.append(avgLossTrueTrain)
    avgLoss_TrueTest.append(avgLossTrueTest)
    avgLoss_BkgTrain.append(avgLossBkgTrain)
    avgLoss_BkgTest.append(avgLossBkgTest)

    # Print the loss for the current epoch
    print(f"Epoch: {epoch+1} | Total Loss Per Epoch: {total_loss_per_epoch.item():.4f}")

    # Checkpoint every 100 epochs (epochs 400 and 500 when resuming from 300)
    if (epoch + 1) % 100 == 0:
        checkpoint_path = f"/storage/mxg1065/MultiClassGNN/modelCheckpoints/model_{num_layers}_layers_epoch_{epoch+1}.pth"
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model checkpoint saved for epoch {epoch+1}")

# Convert lists to numpy arrays for saving
lossPerEpoch = np.array(lossPerEpoch, dtype=np.float32)
scores = np.array(scores, dtype=np.float32)
truth_labels = np.array(truth_labels, dtype=np.int32)
avgLoss_TrueTrain = np.array(avgLoss_TrueTrain, dtype=np.float32)
avgLoss_TrueTest = np.array(avgLoss_TrueTest, dtype=np.float32)
avgLoss_BkgTrain = np.array(avgLoss_BkgTrain, dtype=np.float32)
avgLoss_BkgTest = np.array(avgLoss_BkgTest, dtype=np.float32)

# Save the data (written directly so this cell does not need the training cell's helpers)
with h5py.File("/storage/mxg1065/MultiClassGNN/data_five_layers.h5", "w") as hf:
    for key, arr in {"lossData": lossPerEpoch,
                     "scores": scores,
                     'truth_labels': truth_labels,
                     "avgLoss_TrueTrain": avgLoss_TrueTrain,
                     "avgLoss_TrueTest": avgLoss_TrueTest,
                     "avgLoss_BkgTrain": avgLoss_BkgTrain,
                     "avgLoss_BkgTest": avgLoss_BkgTest}.items():
        hf.create_dataset(key, data=arr, compression='gzip')
