# This notebook uses jet data to classify constituent particles (hadrons / non-hadrons) using a GNN

## Imports

In [4]:
# torch imports
import torch
import torch.utils.data as data
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch_geometric.nn as tgnn
from torch_geometric.loader import DataLoader

from torch_scatter import scatter

In [5]:
# other imports
import os
import uproot
import vector
import awkward as ak
vector.register_awkward()
import pandas as pd
import numpy as np
from tqdm import tqdm

In [6]:
# set up multithreading
# sets the number of threads PyTorch uses for intra-op parallelism on the CPU
torch.set_num_threads(16)

## Define dataset loader

In [7]:
class JetToPC(data.Dataset):
    """
    Dataset that serves every jet of a ROOT tree as a k-nearest-neighbour graph.

    The whole tree is read into memory once, at construction time. Indexing the
    dataset builds a ``torch_geometric.data.Data`` object for the requested jet with

    * ``x``           -- particle features, one column per entry of ``part_feats``
                         (NaNs replaced by 0); one row per particle, assuming each
                         particle branch stores one value per particle,
    * ``edge_index``  -- kNN graph built in the (``part_deta``, ``part_dphi``) plane,
    * ``edge_deltaR`` -- per-edge distance ``hypot(d_eta, d_phi)``, shape ``(E, 1)``,
    * ``label``       -- boolean per-particle target: neutral OR charged hadron,
    * ``global_data`` -- jet-level features, shape ``(1, len(jet_feats))``,
    * ``seq_length``  -- number of particles, taken from ``jet_nparticles``.
    """

    def __init__(
        self,
        dataset_path:str,       # path to dataset
        part_feats:list,        # list of particle features to use in training
        jet_feats:list,         # list of global jet features to use in training
        tree_name:str='tree',   # name of input data tree
        k:int = 5,              # parameter for knn
    ) -> None:

        super().__init__()

        # read the complete tree into memory; items are sliced from it on demand
        self.dataset = uproot.open(dataset_path)
        self.tree = self.dataset[tree_name].arrays()
        self.num_entries = self.dataset[tree_name].num_entries

        self.part_feats = part_feats
        self.jet_feats = jet_feats

        self.k = k


    def transform_jet_to_point_cloud(self, idx:int) -> "Data":
        """
        Build the graph for the jet stored at position ``idx`` of the tree.

        Returns a ``torch_geometric.data.Data`` object carrying the attributes
        listed in the class docstring.
        """

        npart = self.tree['jet_nparticles'].to_numpy()[idx]  # get number of particles

        # setup feature lists: stack per-feature arrays, then transpose so that features are columns
        part_feat_list = np.stack([self.tree[part_feat][idx].to_numpy() for part_feat in self.part_feats]).T
        jet_feat_list = np.stack([self.tree[jet_feat].to_numpy()[idx:idx+1] for jet_feat in self.jet_feats]).T

        # nan check
        part_feat_list[np.isnan(part_feat_list)] = 0
        jet_feat_list[np.isnan(jet_feat_list)] = 0

        # define target class
        part_is_neutral_hadron = self.tree['part_isNeutralHadron'][idx].to_numpy()
        part_is_charged_hadron = self.tree['part_isChargedHadron'][idx].to_numpy()
        part_is_hadron = np.logical_or(part_is_neutral_hadron, part_is_charged_hadron)

        # set up knn graph
        part_eta = torch.tensor(self.tree['part_deta'][idx].to_numpy())
        part_phi = torch.tensor(self.tree['part_dphi'][idx].to_numpy())
        eta_phi_pos = torch.stack([part_eta, part_phi], dim=-1)

        edge_index = tgnn.pool.knn_graph(x=eta_phi_pos, k=self.k) # this is saved as src, dst

        # set up edge distances (or edge weights)
        src, dst = edge_index
        part_del_eta = part_eta[dst] - part_eta[src]
        part_del_phi = part_phi[dst] - part_phi[src]

        part_del_R = torch.hypot(part_del_eta, part_del_phi).view(-1, 1)

        # set up the data
        # (named `graph` so that it does not shadow the module alias `data` = torch.utils.data)
        graph = Data(x=torch.tensor(part_feat_list), edge_index=edge_index, edge_deltaR=part_del_R)
        graph.label = torch.tensor(part_is_hadron)
        graph.global_data = torch.tensor(jet_feat_list)
        graph.seq_length = torch.tensor(int(npart))

        return graph

    def __len__(self) -> int:
        # Number of jets (tree entries) in the dataset
        return self.num_entries

    def __getitem__(self, idx:int) -> "Data":
        # Return the graph of the idx-th jet of the dataset
        return self.transform_jet_to_point_cloud(idx)


## MLP builder

In [8]:
def BuildMLP(
    in_features,
    outputsize,
    features,
    add_batch_norm=False,
    add_activation=None
):
    """
    Build a generic multi-layer perceptron as an ``nn.Sequential``.

    Parameters
    ----------
    in_features : int
        Width of the input vectors.
    outputsize : int
        Width of the output vectors.
    features : sequence of int
        Widths of the hidden layers. Each hidden layer is ``Linear -> ReLU``.
        An empty sequence yields a single ``Linear(in_features, outputsize)``.
    add_batch_norm : bool
        If True, a ``BatchNorm1d`` is inserted in front of every hidden
        ``Linear`` except the first one.
    add_activation : nn.Module or None
        Optional module appended after the output layer.

    Returns
    -------
    nn.Sequential
    """
    # consecutive pairs of this list are the (in, out) widths of the hidden linears
    widths = [in_features, *features]

    layers = []
    for layer_i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
        # the raw network input is never batch-normalised, only hidden activations are
        if add_batch_norm and layer_i > 0:
            layers.append(nn.BatchNorm1d(n_in))

        layers.append(nn.Linear(n_in, n_out))
        layers.append(nn.ReLU())

    # output layer (no ReLU, so that the caller chooses the final activation)
    layers.append(nn.Linear(widths[-1], outputsize))

    if add_activation is not None:
        layers.append(add_activation)

    return nn.Sequential(*layers)

## Set up Models

In [9]:
class EdgeModel(nn.Module):
    """
    Edge-update block: an MLP applied to the concatenation of the two endpoint
    node vectors, the current edge attribute and the global vector of the graph
    that the edge belongs to.
    """

    def __init__(
        self,
        input_edge_dim:int,
        output_edge_dim:int,
        node_dim:int,
        global_dim:int,
        features:list
    ):
        super().__init__()

        # width of [src | dst | edge_attr | u] as assembled in forward()
        mlp_input_width = 2*node_dim + global_dim + input_edge_dim
        self.edge_mlp = BuildMLP(
            in_features=mlp_input_width,
            outputsize=output_edge_dim,
            features=features
        )

    def forward(self, src, dst, edge_attr, u, edge_batch):
        # pick, for every edge, the global vector of its graph
        per_edge_global = u[edge_batch]
        stacked = torch.cat((src, dst, edge_attr, per_edge_global), dim=1)
        return self.edge_mlp(stacked)

In [10]:
class NodeModel(torch.nn.Module):
    """
    Node-update block: an MLP applied to the concatenation of the node vector,
    the mean of the attributes of the edges indexed to that node by
    ``edge_index[1]``, and the global vector of the graph the node belongs to.
    """

    def __init__(
        self,
        input_edge_dim:int,
        input_node_dim:int,
        output_node_dim:int,
        input_global_dim:int,
        features:list
    ):
        super(NodeModel, self).__init__()

        self.node_mlp = BuildMLP(
            in_features=input_edge_dim+input_node_dim+input_global_dim,
            outputsize=output_node_dim,
            features=features
        )

    def forward(self, x, edge_index, edge_attr, u, batch):
        # only the second row of edge_index is needed for the aggregation
        col = edge_index[1]

        # mean edge attribute per node; dim_size keeps a (zero) row for nodes without such edges
        aggregated = scatter(edge_attr, col, dim=0, dim_size=x.size(0), reduce='mean')

        out = torch.cat([x, aggregated, u[batch]], dim=1)
        return self.node_mlp(out)

In [11]:
class GlobalModel(torch.nn.Module):
    """
    Global-update block: an MLP applied to the concatenation of the current
    global vector, the per-graph mean of the node vectors and the per-graph
    mean of the edge attributes.
    """

    def __init__(
        self,
        input_edge_dim:int,
        input_node_dim:int,
        input_global_dim:int,
        output_global_dim:int,
        features:list
    ):
        super(GlobalModel, self).__init__()

        self.global_mlp = BuildMLP(
            in_features=input_edge_dim+input_node_dim+input_global_dim,
            outputsize=output_global_dim,
            features=features
        )

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        src_idx = edge_index[0]
        num_graphs = u.size(0)

        # dim_size pins both aggregates to B rows. Without it scatter sizes its
        # output from the largest index present, so a trailing graph without
        # edges would produce fewer rows than `u` and torch.cat would fail.
        node_mean = scatter(x, batch, dim=0, dim_size=num_graphs, reduce='mean')
        edge_mean = scatter(edge_attr, batch[src_idx], dim=0, dim_size=num_graphs, reduce='mean')

        out = torch.cat([u, node_mean, edge_mean], dim=1)
        return self.global_mlp(out)

## Set up a single GNN layer

In [12]:
def BuildGNNLayer(
    n_edge_input:int,
    n_edge_hidden:list,
    n_edge_output:int,
    n_node_input:int,
    n_node_hidden:list,
    n_node_output:int,
    n_global_input:int,
    n_global_hidden:list,
    n_global_output:int
):
    """
    Assemble one message-passing layer (a ``MetaLayer``) from an edge, a node
    and a global sub-network.

    The node network consumes the edge network's *output* width, and the global
    network consumes the *output* widths of both the edge and the node network;
    all three still receive the layer's *input* global width.
    """
    return tgnn.MetaLayer(
        edge_model=EdgeModel(
            input_edge_dim=n_edge_input,
            output_edge_dim=n_edge_output,
            node_dim=n_node_input,
            global_dim=n_global_input,
            features=n_edge_hidden,
        ),
        node_model=NodeModel(
            input_edge_dim=n_edge_output,
            input_node_dim=n_node_input,
            output_node_dim=n_node_output,
            input_global_dim=n_global_input,
            features=n_node_hidden,
        ),
        global_model=GlobalModel(
            input_edge_dim=n_edge_output,
            input_node_dim=n_node_output,
            input_global_dim=n_global_input,
            output_global_dim=n_global_output,
            features=n_global_hidden,
        ),
    )

## Set up the GNN model

In [13]:
class GNNModel(torch.nn.Module):
    """
    Stack of GNN layers built with ``BuildGNNLayer``, followed by a log-softmax
    over the node outputs of the last layer.

    One layer is created per entry of ``edge_hidden_features``. Layer ``i`` maps
    the edge / node / global widths at position ``i`` of
    ``[input_width, *output_features]`` to the widths at position ``i + 1``.

    ``nLayers``, ``normalization`` and ``pool`` are accepted but not used by
    this implementation.
    """

    def __init__(
        self,
        nLayers:int,
        edge_input_features:int,
        node_input_features:int,
        global_input_features:int,
        edge_hidden_features:list,
        node_hidden_features:list,
        global_hidden_features:list,
        edge_output_features:list,
        node_output_features:list,
        global_output_features:list,
        normalization:str = '',
        pool:str = 'mean'
    ) -> None:

        super(GNNModel, self).__init__()

        # Build fresh width lists instead of inserting into the arguments:
        # mutating the caller's lists made a second construction with the same
        # lists start from already-extended (wrong) widths.
        edge_dims = [edge_input_features, *edge_output_features]
        node_dims = [node_input_features, *node_output_features]
        global_dims = [global_input_features, *global_output_features]

        # layer loop
        layers = [
            BuildGNNLayer(
                edge_dims[i],
                edge_hidden_features[i],
                edge_dims[i+1],
                node_dims[i],
                node_hidden_features[i],
                node_dims[i+1],
                global_dims[i],
                global_hidden_features[i],
                global_dims[i+1]
            )
            for i in range(len(edge_hidden_features))
        ]

        self.sequential = nn.ModuleList(layers)


    def forward(
        self,
        x:torch.Tensor,
        edge_attr:torch.Tensor,
        u:torch.Tensor,
        edge_index:torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        """Run all layers and return per-node log-probabilities (log-softmax over the last dim)."""

        # note the argument order expected by each layer: edge_index comes before edge_attr
        for layer in self.sequential:
            x, edge_attr, u = layer(x, edge_index, edge_attr, u, batch)

        return F.log_softmax(x, dim=-1)

## Define the loss function

In [14]:
# define the loss function
class CustomLoss(nn.Module):
    """
    Negative log-likelihood averaged over all nodes (particles) of a batch.

    ``outputs`` are per-node log-probabilities, ``labels`` the per-node class
    (any dtype castable to int64, e.g. bool) and ``batch`` the per-node graph
    index; only the length of ``batch`` is used, as the normalisation.

    Weighting each jet's mean NLL by its number of nodes and dividing by the
    total node count is the same as summing the NLL over all nodes and dividing
    by the node count, so no per-jet loop is needed.
    """

    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, outputs, labels, batch):
        # .long() keeps the labels on their current device; casting through
        # torch.LongTensor would force them onto the CPU and break GPU training
        summed_nll = F.nll_loss(outputs, labels.long(), reduction='sum')

        return summed_nll / batch.shape[0]

## Train and Test routines

In [15]:
# training routine
def Train(model, device, train_loader, optimizer, loss_fn):
    """
    Run one optimisation pass over ``train_loader``.

    Returns the sum over batches of ``loss * number_of_nodes_in_batch``,
    i.e. an un-normalised epoch loss; the caller is responsible for averaging.
    """
    # set the train mode for the model
    model.train()

    running_loss = 0.

    # input batchwise data
    with tqdm(train_loader, ascii=True) as progress:
        for graphs in progress:
            # move only the tensors that are actually used onto the device
            x = graphs.x.to(device)
            edge_index = graphs.edge_index.to(device)
            edge_attr = graphs.edge_deltaR.to(device)
            label = graphs.label.to(device)
            u = graphs.global_data.to(device)
            batch = graphs.batch.to(device)

            optimizer.zero_grad()
            prediction = model(x, edge_attr, u, edge_index, batch)
            step_loss = loss_fn(prediction, label, batch)
            step_loss.backward()
            optimizer.step()

            # weight the batch loss by the number of nodes in the batch
            running_loss += step_loss.item()*x.size(0)

    return running_loss
    

In [16]:
# testing routine
def Test(model, device, test_loader, loss_fn):
    """
    Evaluate ``model`` on ``test_loader`` without updating any weights.

    Returns the sum over batches of ``loss * number_of_nodes_in_batch``,
    i.e. an un-normalised loss; the caller is responsible for averaging.
    """

    test_loss_ep = 0.

    model.eval()
    # no_grad: evaluation needs no autograd graph, which saves memory and time
    with torch.no_grad():
        # iterate over the loader passed in, not over the global train_loader
        with tqdm(test_loader, ascii=True) as tq:
            for dl in tq:
                x, edge_index, edge_attr, label, u, batch = (
                    dl.x.to(device),
                    dl.edge_index.to(device),
                    dl.edge_deltaR.to(device),
                    dl.label.to(device),
                    dl.global_data.to(device),
                    dl.batch.to(device)
                )

                output = model(x, edge_attr, u, edge_index, batch)
                loss = loss_fn(output, label, batch)

                # weight the batch loss by the number of nodes in the batch
                test_loss_ep += loss.item()*x.size(0)

    return test_loss_ep

## Define model inputs

In [17]:
# per-particle branches of the input tree; JetToPC stacks them column-wise into the node features `x`
particle_features = [
    'part_px',
    'part_py',
    'part_pz',
    'part_energy',
    'part_deta',
    'part_dphi',
    'part_d0val',
    'part_d0err',
    'part_dzval',
    'part_dzerr',
    'part_charge',
]

In [18]:
# per-jet branches of the input tree; JetToPC stores them as the graph-level `global_data`
jet_features = [
    'jet_pt',
    'jet_eta',
    'jet_phi',
    'jet_energy',
    'jet_tau1',
    'jet_tau2',
    'jet_tau3',
    'jet_tau4',
]

In [19]:
# hidden-layer widths of the edge / node / global MLPs: one inner list per GNN layer
edge_hidden_features = [[4, 10, 5],[5, 10, 5],[5, 10, 5]]
node_hidden_features = [[20, 20, 20, 20],[20, 20, 20, 20],[20, 20, 20, 20]]
global_hidden_features = [[8, 8, 8], [8, 8, 8], [8, 8, 8]]

# output width of the edge / node / global MLP of each GNN layer
# the last node width (2) is the width of the log-softmax scores returned by GNNModel
edge_output_features = [10, 10, 10]
node_output_features = [10, 10, 2]
global_output_features = [8, 8, 8]

## Runtime settings

In [26]:
# device onto which Train/Test move every batch
device = 'cpu' # alternatives: 'cuda', 'mps'

# dataset path (ROOT file, opened with uproot inside JetToPC)
dataset_path = '../data/JetClass_example_100k.root'

## Instantiate data loader, model, optimiser

In [27]:
# number of subprocesses to use for data loading
num_workers = 5

# how many samples per batch to load
batch_size = 5

# fractions of the full dataset used for training and validation;
# the remainder (here 0.2) becomes the test set
train_size, valid_size = 0.6, 0.2

# fixed seed so that the train / valid / test split survives a kernel restart
split_seed = 42

# create train, test and valid split
jet_dataset = JetToPC(dataset_path, particle_features, jet_features)
num_train = len(jet_dataset)  # NOTE: total number of jets, not only the training part
indices = list(range(num_train))
np.random.default_rng(split_seed).shuffle(indices)
train_split = int(np.floor(train_size * num_train))
valid_split = int(np.floor(valid_size * num_train))
train_index = indices[0:train_split]
valid_index = indices[train_split:train_split + valid_split]
test_index = indices[train_split + valid_split:]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_index)
valid_sampler = SubsetRandomSampler(valid_index)
test_sampler = SubsetRandomSampler(test_index)

# prepare data loaders
train_loader = DataLoader(
    dataset=jet_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=train_sampler,
)
valid_loader = DataLoader(
    dataset=jet_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=valid_sampler
)
test_loader = DataLoader(
    dataset=jet_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=test_sampler
)

In [28]:
# instantiate model
# The width lists are passed as copies so that this cell gives the same model
# every time it is re-run, even if the constructor modifies its list arguments.
# The model is moved to `device` because Train/Test move every batch there.
model = GNNModel(
            nLayers=len(edge_output_features),
            edge_input_features=1,
            node_input_features=len(particle_features),
            global_input_features=len(jet_features),
            edge_hidden_features=edge_hidden_features,
            node_hidden_features=node_hidden_features,
            global_hidden_features=global_hidden_features,
            edge_output_features=list(edge_output_features),
            node_output_features=list(node_output_features),
            global_output_features=list(global_output_features)
        ).to(device)

In [29]:
# instantiate optimiser
# Adam over all model parameters with a fixed learning rate of 1e-3
optimizer = torch.optim.Adam(model.parameters(),lr = 0.001)

In [30]:
# instantiate loss function
# called as loss_fn(output, label, batch) inside Train and Test
loss_fn = CustomLoss()

## Start the training

In [None]:
# number of epochs to train the model
n_epochs = 20

# file that receives the weights of the best model (lowest validation loss)
model_path = '../models/hadron_id.pt'
# torch.save does not create missing directories
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# initialize tracker for minimum validation loss
valid_loss_min = np.inf  # set initial "min" to infinity (the alias np.Inf was removed in NumPy 2.0)

for epoch in range(n_epochs):
    # monitor losses

    train_loss = Train(model, device, train_loader, optimizer, loss_fn)

    valid_loss = Test(model, device, valid_loader, loss_fn)

    # print training/validation statistics
    # calculate average loss over an epoch
    # NOTE: Train/Test return node-weighted sums, so dividing by the number of
    # jets gives the mean summed loss per jet rather than the mean loss per particle
    train_loss = train_loss / len(train_loader.sampler)
    valid_loss = valid_loss / len(valid_loader.sampler)

    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch+1,
        train_loss,
        valid_loss
        ))

    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(model.state_dict(), model_path)
        valid_loss_min = valid_loss

 23%|#####################7                                                                        | 2780/12000 [01:43<06:13, 24.70it/s]