In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys
import git

import uproot as ut
import awkward as ak
import numpy as np
import math
import vector
import sympy as sp

import re
from tqdm import tqdm
import timeit

sys.path.append( git.Repo('.', search_parent_directories=True).working_tree_dir )
from utils import *

import utils.torchUtils as gnn

In [3]:
import torch, torch_geometric
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_networkx
import networkx as nx

from torch_geometric.typing import Adj, PairTensor
from typing import Callable, Optional, Union, Tuple
from torch.nn import Linear, Module
from torch.nn.functional import softmax, relu, sigmoid

In [4]:
class PairJets:
    def __call__(self, data : Data) -> Data:
        uptri_mask = data.edge_index[0] < data.edge_index[1]
        dijet_features = torch.nn.functional.pad(data.edge_attr[uptri_mask], (1,0))
        dijet_i, dijet_j = data.edge_index[:,uptri_mask]
        dijet_y = data.edge_y[uptri_mask]
        dijet_id = data.edge_id[uptri_mask]

        node_type = data.get('node_type', torch.zeros(data.num_nodes))
        data.node_type = torch.cat([node_type, torch.full_like(dijet_i, 1)]).long()
       
        data.x = torch.cat([data.x,dijet_features])
        data.y = torch.cat([data.y, dijet_y])
        data.node_id = torch.cat([data.node_id, dijet_id])

        dijet_idx = torch.where(data.node_type == 1)[0]
        data.edge_index = torch.stack([torch.cat([dijet_i, dijet_idx, dijet_j, dijet_idx]), torch.cat([dijet_idx, dijet_j, dijet_idx, dijet_i])])
        data.edge_attr = torch.ones_like(data.edge_index[0]).reshape(-1, 1)
        return data

In [5]:
class PairHiggs:
    def __call__(self, data: Data) -> Data:
        n_higgs = (data.node_type == 1).sum()
        n_dihiggs = n_higgs*(n_higgs-1)//2

        dihiggs_features = torch.zeros(n_dihiggs, data.num_node_features)

        dijets = torch.where(data.node_type == 1)[0]
        dihiggs_i = torch.repeat_interleave(dijets, n_higgs)
        dihiggs_j = torch.repeat_interleave(dijets[None], n_higgs, dim=0).reshape(-1)
        uptri_mask = dihiggs_i < dihiggs_j

        dihiggs_i = dihiggs_i[uptri_mask]
        dihiggs_j = dihiggs_j[uptri_mask]

        dihiggs_1 = (data.node_id[dihiggs_i] == 1) & (data.node_id[dihiggs_j] == 2)
        dihiggs_2 = (data.node_id[dihiggs_i] == 3) & (data.node_id[dihiggs_j] == 4)

        data.node_type = torch.cat([data.node_type, torch.full_like(dihiggs_i, 2)]).long()
        data.node_id = torch.cat([data.node_id, 1*dihiggs_1 + 2*dihiggs_2]).long()

        data.x = torch.cat([data.x, dihiggs_features])
        data.y = torch.cat([data.y, 1*(dihiggs_1 | dihiggs_2)]).long()

        dihiggs_idx = torch.where(data.node_type == 2)[0]
        data.edge_index = torch.stack([torch.cat([data.edge_index[0], dihiggs_i,   dihiggs_idx, dihiggs_j,   dihiggs_idx]), 
                                    torch.cat([data.edge_index[1], dihiggs_idx, dihiggs_j,   dihiggs_idx, dihiggs_i])])

        data.edge_attr = torch.ones_like(data.edge_index[0]).reshape(-1,1)
        return data

In [6]:
class PairYs:
    def __call__(self, data: Data) -> Data:
        n_ys = (data.node_type == 2).sum()
        n_diys = n_ys*(n_ys-1)//2

        diys_features = torch.zeros(n_diys, data.num_node_features)

        dihiggs = torch.where(data.node_type == 2)[0]
        diys_i = torch.repeat_interleave(dihiggs, n_ys)
        diys_j = torch.repeat_interleave(dihiggs[None], n_ys, dim=0).reshape(-1)
        uptri_mask = diys_i < diys_j

        diys_i = diys_i[uptri_mask]
        diys_j = diys_j[uptri_mask]

        diys_1 = (data.node_id[diys_i] == 1) & (data.node_id[diys_j] == 2)

        data.node_type = torch.cat([data.node_type, torch.full_like(diys_i, 3)]).long()
        data.node_id = torch.cat([data.node_id, 1*diys_1]).long()


        data.x = torch.cat([data.x, diys_features])
        data.y = torch.cat([data.y, 1*(diys_1)]).long()

        diys_idx = torch.where(data.node_type == 3)[0]
        data.edge_index = torch.stack([torch.cat([data.edge_index[0], diys_i,   diys_idx, diys_j,   diys_idx]), 
                                       torch.cat([data.edge_index[1], diys_idx, diys_j,   diys_idx, diys_i])])

        data.edge_attr = torch.ones_like(data.edge_index[0]).reshape(-1,1)
        return data

In [7]:
class NodeClassWeight:
    def __call__(self, data: Data) -> Data:
        true_mask = data.y == 1
        true_weight = (data.num_nodes - true_mask.sum())/data.num_nodes
        fake_weight = 1 - true_weight
        data.weight = torch.where(true_mask, true_weight, fake_weight)
        return data

In [18]:
from torch_geometric.loader import DataLoader, DenseDataLoader

def load_dataset(fn='data/MX_1200_MY_500-training', template=None, shuffle=False):
    dataset = gnn.Dataset(fn,transform=template.transform)
    training, testing = gnn.train_test_split(dataset[:3000], 0.33)
    training, validation = gnn.train_test_split(training, 0.5)

    batch_size = 50
    trainloader = DataLoader(training, batch_size=batch_size, shuffle=shuffle, num_workers=8)
    validloader = DataLoader(validation, batch_size=batch_size, shuffle=shuffle, num_workers=8)
    testloader = DataLoader(testing, batch_size=batch_size, shuffle=shuffle, num_workers=8)

    return trainloader, validloader, testloader

template = gnn.Dataset('data/template',make_template=True, transform=gnn.Transform(PairJets(), NodeClassWeight()))
trainloader, validloader, testloader = load_dataset(template=template)

In [19]:
class GoldenModule(Module):
    def __init__(self, n_in_node=5, n_in_edge=4, layers=32):
        super().__init__()
        
        self.conv1 = gnn.layers.GCNConvMSG(n_in_node=n_in_node, n_in_edge=n_in_edge, n_out=layers)
        self.relu1 = gnn.layers.GCNRelu()
        self.norm1 = gnn.layers.GCNNormalize()

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: Tensor) -> Tensor:
        x, edge_attr = self.conv1(x, edge_index, edge_attr)
        x, edge_attr = self.relu1(x, edge_index, edge_attr)
        x, edge_attr = self.norm1(x, edge_index, edge_attr)
        return x, edge_attr

In [37]:
class NodeClassLinear(Module):
    def __init__(self, n_in, n_out, n_classes):
        super().__init__()

        self.in_features, self.out_features = n_in, n_out
        self.n_classes = n_classes
        self.linear = [ torch.nn.Linear(n_in, n_out) for _ in range(n_classes) ]            
    def __call__(self, x : Tensor, classes : Tensor) -> Tensor:
        x = [ linear(x[classes == i]) for i, linear in enumerate(self.linear) ]
        return torch.cat(x) 

In [38]:
from pytorch_lightning import LightningModule
from torch.nn.functional import nll_loss, log_softmax, binary_cross_entropy
from torchmetrics.functional import accuracy, auroc

class Model(LightningModule):
    def __init__(self,n_in_node=5, n_in_edge=1):
        super().__init__()
        self.norm = gnn.layers.GCNNormalize()

        self.golden1 = GoldenModule(n_in_node=n_in_node, n_in_edge=n_in_edge, layers=32)
        self.linear1 = NodeClassLinear(32, 64, 2)
        self.golden2 = GoldenModule(n_in_node=64, n_in_edge=32, layers=128)
        self.linear2 = NodeClassLinear( 64, 32, 2)

        self.linear3 = NodeClassLinear( 32,  1, 2)

    def forward(self, data : Data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x, edge_attr = self.golden1(x, edge_index, edge_attr)
        x = self.linear1(x, data.node_type)

        x, edge_attr = self.golden2(x, edge_index, edge_attr)
        x = self.linear2(x, data.node_type)

        x = self.linear3(x, data.node_type)

        return softmax(x, dim=-1)[:,1]

    def step(self, batch, batch_idx, tag):
        o = self(batch)
        y = batch.y
        loss = binary_cross_entropy(o, y.float(), batch.weight)
        node_auroc = auroc(o, y)
        self.log(f'{tag}/node_auroc',node_auroc)
        self.log(f'{tag}/loss',loss)
        return loss
    def training_step(self, batch, batch_idx): return self.step(batch ,batch_idx, 'train')
    def validation_step(self, batch, batch_idx): return self.step(batch ,batch_idx, 'valid')
    def test_step(self, batch, batch_idx):  return self.step(batch ,batch_idx, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [39]:
data = trainloader.dataset[0]

In [40]:
model = Model()

In [41]:
from pytorch_lightning import Trainer

trainer = Trainer(gpus=1, max_epochs=10)
trainer.fit(model, trainloader, validloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type            | Params
--------------------------------------------
0 | norm    | GCNNormalize    | 0     
1 | golden1 | GoldenModule    | 384   
2 | linear1 | NodeClassLinear | 0     
3 | golden2 | GoldenModule    | 20.6 K
4 | linear2 | NodeClassLinear | 0     
5 | linear3 | NodeClassLinear | 0     
--------------------------------------------
21.0 K    Trainable params
0         Non-trainable params
21.0 K    Total params
0.084     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking arugment for argument mat1 in method wrapper_addmm)

In [36]:
trainer.test(model, testloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 100%|██████████| 20/20 [00:01<00:00, 18.97it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/loss': 0.21773627400398254, 'test/node_auroc': 0.5057910680770874}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 20/20 [00:01<00:00, 13.93it/s]


[{'test/node_auroc': 0.5057910680770874, 'test/loss': 0.21773627400398254}]

In [31]:
model(testloader.dataset[0])

tensor([0.2043, 0.2238, 0.2403, 0.2031, 0.1956, 0.2434, 0.2203, 0.2313, 0.1749,
        0.1914, 0.1813, 0.1733, 0.1978, 0.1908, 0.1956, 0.1811, 0.1759, 0.1822,
        0.1890, 0.1953, 0.1979, 0.1713, 0.1875, 0.1917, 0.1998, 0.2033, 0.1700,
        0.1926, 0.1887, 0.1942, 0.1873, 0.1849, 0.1889, 0.2039, 0.2081, 0.1938],
       grad_fn=<SelectBackward>)