## Setup

### packages

In [None]:
# basic packages
import os, time, random
from itertools import product
import matplotlib.pyplot as plt

# data
import awkward as ak
import d_hep_data

# qml
import pennylane as qml
from pennylane import numpy as np

# pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# pytorch_lightning
import lightning as L
import lightning.pytorch as pl

# pytorch_geometric
import networkx as nx
import torch_geometric.nn as geom_nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing

# scikit-learn
from sklearn import metrics

# wandb
import wandb
from lightning.pytorch.loggers import WandbLogger
# log in to Weights & Biases as soon as this cell runs
# NOTE(review): this is called unconditionally, independent of the cf["wandb"] flag — confirm that is intended
wandb.login()

# reproducibility: fix one global seed through Lightning's seed_everything helper
# NOTE(review): cf["rnd_seed"] is "to be determined by for loop", so this value is presumably re-seeded per run — confirm
L.seed_everything(3020616)

# faster calculation on GPU but less precision
# ("medium" lowers the internal precision PyTorch may use for float32 matrix multiplications)
torch.set_float32_matmul_precision("medium")

In [None]:
# configuration dictionary: one global dict, filled section by section so the
# key insertion order stays the same as the order in which sections appear
cf = {}

# run bookkeeping
cf.update({
    "time":     time.strftime("%Y%m%d_%H%M%S", time.localtime()),
    "wandb":    True,
    "project":  "g_3vec_geognn",
    "rnd_seed": None,  # to be determined by for loop
})

# data information
cf.update({
    "num_events":    "50000",  # stored as a string, not an int
    "sig_channel":   "ZprimeToZhToZinvhbb",
    "bkg_channel":   "QCD_HT2000toInf",
    "jet_type":      "fatjet",
    "subjet_radius": None,  # to be determined from [0.25, 0.5, 0.75]
    "cut_limit":     (800, 1600),
    "bin":           8,
    "num_bin_data":  None,  # to be determined from [100, 200, 300]
})

# training configuration
cf.update({
    "num_train_ratio": 0.8,
    "num_test_ratio":  0.2,
    "batch_size":      64,
    "num_workers":     0,
    "max_epochs":      100,
    "accelerator":     "cpu",
    "fast_dev_run":    False,
})
# derived from the batch size, so it has to be set after the block above
cf["log_every_n_steps"] = cf["batch_size"] // 2

# model hyperparameters
cf.update({
    "loss_function": nn.BCEWithLogitsLoss(),
    "optimizer":     optim.Adam,  # the optimizer class itself, not an instance
    "learning_rate": 1E-3,
})

# 2PCNN hyperparameters
cf.update({
    "gnn_layers": None,  # to be determined by grid search
    "mlp_layers": None,  # to be determined by grid search
})

## Data Module

In [None]:
class JetGraphDataModule(pl.LightningDataModule):
    """Lightning data module that turns signal/background jet events into graphs.

    Every event becomes one ``torch_geometric.data.Data`` object whose nodes
    carry three features (pt ratio, delta eta, delta phi).  Signal graphs are
    labelled ``y=1`` and background graphs ``y=0``.  The sizes of the splits are
    read from the global configuration dictionary ``cf``.
    """

    def __init__(self, sig_events, bkg_events, link_mode, **kwargs):
        '''Build the train / test / validation graph lists.

        Add a "_" prefix if it is a fastjet feature: the events are read
        through the fields "_pt", "_delta_eta", "_delta_phi" and "pt".

        Args:
            sig_events: signal events, indexable by the field names above.
            bkg_events: background events with the same fields.
            link_mode: "fully_connected" or "top_n_closest".
            **kwargs: extra options; "top_n_closest" requires ``n``, the
                number of neighbours linked to each node.
        '''
        super().__init__()
        # jet events -> per-event node-feature tensors -> labelled graphs
        sig_events = self._pre_transformation(sig_events)
        bkg_events = self._pre_transformation(bkg_events)
        sig_graph_list = self._create_graph_list(sig_events, 1, link_mode, **kwargs)
        bkg_graph_list = self._create_graph_list(bkg_events, 0, link_mode, **kwargs)

        # count the number of training, and testing
        num_data = cf["bin"] * cf["num_bin_data"]
        assert len(sig_graph_list) >= num_data, f"sig data not enough: {len(sig_graph_list)} < {num_data}"
        assert len(bkg_graph_list) >= num_data, f"bkg data not enough: {len(bkg_graph_list)} < {num_data}"
        num_train = int(num_data * cf["num_train_ratio"])
        num_test  = int(num_data * cf["num_test_ratio"])
        print(f"DataLog: {cf['sig_channel']} has {len(sig_graph_list)} events | {cf['bkg_channel']} has {len(bkg_graph_list)} events.")
        print(f"Choose num_data for each channel to be {num_data} | Each channel  has num_train = {num_train}, num_test = {num_test}")

        # prepare dataset for dataloader
        # (use the local lists: they were never stored as attributes on self)
        train_idx = num_train
        test_idx  = num_train + num_test
        self.train_dataset = sig_graph_list[:train_idx] + bkg_graph_list[:train_idx]
        self.test_dataset  = sig_graph_list[train_idx:test_idx] + bkg_graph_list[train_idx:test_idx]
        # NOTE: validation reuses the test split; there is no separate validation set
        self.valid_dataset = self.test_dataset

    @staticmethod
    def _pre_transformation(events):
        """Convert events into a list of ``(num_nodes, 3)`` float32 tensors.

        The three columns are ``_pt / pt``, ``_delta_eta`` and ``_delta_phi``.
        Declared static because it uses no instance state; it is still
        callable as ``self._pre_transformation(events)``.
        """
        subjet_pt  = events["_pt"] / events["pt"]
        subjet_eta = events["_delta_eta"]
        subjet_phi = events["_delta_phi"]
        events     = ak.zip([subjet_pt, subjet_eta, subjet_phi]).to_list()
        # reshape(-1, 3) is a no-op for non-empty events and keeps an empty
        # event two-dimensional, with shape (0, 3)
        return [torch.tensor(event, dtype=torch.float32).reshape(-1, 3) for event in events]

    def _create_graph_list(self, events, y, link_mode, **kwargs):
        """Wrap each event tensor in a ``Data`` object and shuffle the result.

        Args:
            events: list of ``(num_nodes, 3)`` tensors from ``_pre_transformation``.
            y: label stored on every graph.
            link_mode: "fully_connected" links every ordered node pair,
                self-loops included; "top_n_closest" links each node ``i`` to
                the ``n`` other nodes with the smallest
                ``delta_eta**2 + delta_phi**2`` relative to ``i``.
            **kwargs: must contain ``n`` when ``link_mode`` is "top_n_closest".

        Returns:
            The shuffled list of graphs.

        Raises:
            ValueError: if ``link_mode`` is not one of the supported modes.
            KeyError: if "top_n_closest" is requested without ``n``.
        """
        if link_mode not in ("fully_connected", "top_n_closest"):
            raise ValueError(f"unknown link_mode: {link_mode!r}")
        n = kwargs["n"] if link_mode == "top_n_closest" else None

        # create pytorch_geometric "Data" object
        graph_list = []
        for x in events:
            x.requires_grad = False
            num_nodes = len(x)
            if link_mode == "fully_connected":
                edge_index = list(product(range(num_nodes), repeat=2))
            else:
                edge_index = []
                for i in range(num_nodes):
                    # columns 1 and 2 hold eta and phi; x[1] / x[2] would pick
                    # whole node rows instead of these feature columns
                    delta_eta = x[:, 1] - x[i, 1]
                    delta_phi = x[:, 2] - x[i, 2]
                    order = torch.argsort(delta_eta**2 + delta_phi**2).tolist()
                    # drop node i explicitly: on ties it is not guaranteed to sort first
                    neighbours = [j for j in order if j != i][:n]
                    edge_index += [[i, j] for j in neighbours]
            # explicit dtype and reshape keep a graph without edges at shape (2, 0)
            edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
            graph_list.append(Data(x=x, edge_index=edge_index, y=y))
        random.shuffle(graph_list)
        return graph_list

    def train_dataloader(self):
        """Return the shuffled loader over the training graphs."""
        return DataLoader(self.train_dataset, batch_size=cf["batch_size"], num_workers=cf["num_workers"],  shuffle=True)

    def val_dataloader(self):
        """Return the validation loader; Lightning looks this hook up by the name ``val_dataloader``."""
        return DataLoader(self.valid_dataset, batch_size=cf["batch_size"], num_workers=cf["num_workers"],  shuffle=False)

    def valid_dataloader(self):
        """Alias of ``val_dataloader`` kept for code that used the old method name."""
        return self.val_dataloader()

    def test_dataloader(self):
        """Return the loader over the test graphs."""
        return DataLoader(self.test_dataset, batch_size=cf["batch_size"], num_workers=cf["num_workers"],  shuffle=False)

## Models

In [None]:
class ClassicalMLP(nn.Module):
    """Fully connected ReLU network with ``num_layers`` hidden layers.

    ``num_layers == 0`` gives a single linear map ``in_channel -> out_channel``.
    Otherwise the network is ``in -> hidden``, followed by ``num_layers - 1``
    ``hidden -> hidden`` layers (each followed by ReLU), and a final
    ``hidden -> out`` linear layer without activation.

    Args:
        in_channel: size of the input feature vector.
        out_channel: size of the output feature vector.
        hidden_channel: width of every hidden layer.
        num_layers: number of hidden layers (non-negative).

    Raises:
        ValueError: if ``num_layers`` is negative.
    """
    def __init__(self, in_channel, out_channel, hidden_channel, num_layers):
        super().__init__()
        if num_layers < 0:
            raise ValueError(f"num_layers must be non-negative, got {num_layers}")
        if num_layers == 0:
            self.net = nn.Linear(in_channel, out_channel)
        else:
            net = [nn.Linear(in_channel, hidden_channel), nn.ReLU()]
            # num_layers - 1 extra hidden layers; with the previous
            # "num_layers - 2", num_layers=1 and num_layers=2 built the same network
            for _ in range(num_layers - 1):
                net += [nn.Linear(hidden_channel, hidden_channel), nn.ReLU()]
            net += [nn.Linear(hidden_channel, out_channel)]
            self.net = nn.Sequential(*net)

    def forward(self, x):
        """Apply the network to ``x``."""
        return self.net(x)

class QuantumMLP(nn.Module):
    """Variational quantum circuit wrapped as a torch module.

    The circuit repeats ``num_reupload`` times: angle-embed the inputs with
    Y rotations on all ``num_qubits`` wires, then apply
    ``StronglyEntanglingLayers`` with that repetition's weights.  It returns
    one Pauli expectation value per entry of ``measurements``.

    Args:
        num_qubits: number of wires of the ``default.qubit`` device.
        num_layers: second dimension of the weight tensor, consumed by
            ``StronglyEntanglingLayers``.
        num_reupload: number of embedding + entangling repetitions.
        measurements: sequence of ``[wire, name]`` pairs with ``name`` in
            {"X", "Y", "Z"}.
    """
    def __init__(self, num_qubits, num_layers, num_reupload, measurements):
        super().__init__()
        all_wires = range(num_qubits)
        # map the measurement name to the corresponding Pauli observable class
        pauli_by_name = {"X":qml.PauliX, "Y":qml.PauliY, "Z":qml.PauliZ}

        def circuit(inputs, weights):
            for rep in range(num_reupload):
                qml.AngleEmbedding(features=inputs, wires=all_wires, rotation='Y')
                qml.StronglyEntanglingLayers(weights=weights[rep], wires=all_wires)
            expectations = []
            for m in measurements:
                observable = pauli_by_name[m[1]](wires=m[0])
                expectations.append(qml.expval(observable))
            return expectations

        # bind the circuit to a simulator device (same as decorating it with @qml.qnode)
        device = qml.device('default.qubit', wires=num_qubits)
        qnode = qml.qnode(device)(circuit)

        # turn the quantum circuit into a torch layer
        shapes = {"weights":(num_reupload, num_layers, num_qubits, 3)}
        quantum_layer = qml.qnn.TorchLayer(qnode, weight_shapes=shapes)
        self.net = nn.Sequential(quantum_layer)

    def forward(self, x):
        """Run ``x`` through the quantum layer."""
        return self.net(x)

In [None]:
class Classical2PCNNForwardMP(MessagePassing):
    """Message-passing layer whose message and update functions are classical MLPs.

    ``mp_phi`` maps the concatenated features of the two edge endpoints to a
    message of size ``num_features``.  ``mp_gamma`` maps the concatenation of
    a node's own features and its aggregated messages back to ``num_features``.

    Args:
        num_features: number of features per node.
        num_layers: ``num_layers`` passed to both ``ClassicalMLP`` networks.
        aggr: aggregation scheme forwarded to ``MessagePassing``.
    """
    def __init__(self, num_features, num_layers, aggr):
        super().__init__(aggr=aggr)
        # both networks share the same shape: 2F -> (hidden 2F) -> F
        mlp_config = dict(
            in_channel=2 * num_features,
            out_channel=num_features,
            hidden_channel=2 * num_features,
            num_layers=num_layers,
        )
        self.mp_phi = ClassicalMLP(**mlp_config)
        self.mp_gamma = ClassicalMLP(**mlp_config)

    def forward(self, x, edge_index):
        """Run one round of message passing over ``edge_index``."""
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """Compute the per-edge message from the two endpoint feature vectors."""
        endpoint_pair = torch.cat([x_i, x_j], dim=-1)
        return self.mp_phi(endpoint_pair)

    def update(self, aggr_out, x):
        """Combine each node's features with its aggregated messages."""
        node_with_messages = torch.cat([x, aggr_out], dim=-1)
        return self.mp_gamma(node_with_messages)
    
class Quantum2PCQNNForwardMP(MessagePassing):
    """Message-passing layer with a quantum message function and a classical update.

    ``mp_phi`` is a ``QuantumMLP`` on ``2 * num_features`` qubits that
    measures Pauli-Z on the first ``num_features`` wires.  ``mp_gamma`` is a
    ``ClassicalMLP`` mapping the concatenation of a node's features and its
    aggregated messages to ``num_features`` outputs.

    Args:
        num_features: number of features per node.
        num_layers: ``num_layers`` passed to both sub-networks.
        num_reupload: number of data re-uploading repetitions of the circuit.
        aggr: aggregation scheme forwarded to ``MessagePassing``.
    """
    def __init__(self, num_features, num_layers, num_reupload, aggr):
        super().__init__(aggr=aggr)
        # one Pauli-Z readout per node feature, on wires 0 .. num_features - 1
        measurements = [[wire, "Z"] for wire in range(num_features)]
        num_qubits = 2 * num_features
        self.mp_phi = QuantumMLP(
            num_qubits=num_qubits,
            num_layers=num_layers,
            num_reupload=num_reupload,
            measurements=measurements,
        )
        gamma_in = 2 * len(measurements)
        self.mp_gamma = ClassicalMLP(
            in_channel=gamma_in,
            out_channel=num_features,
            hidden_channel=2 * num_features,
            num_layers=num_layers,
        )

    def forward(self, x, edge_index):
        """Run one round of message passing over ``edge_index``."""
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """Feed the concatenated endpoint features through the quantum circuit."""
        endpoint_pair = torch.cat([x_i, x_j], dim=-1)
        return self.mp_phi(endpoint_pair)

    def update(self, aggr_out, x):
        """Combine each node's features with its aggregated messages."""
        node_with_messages = torch.cat([x, aggr_out], dim=-1)
        return self.mp_gamma(node_with_messages)