In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import git

import uproot as ut
import awkward as ak
import numpy as np
import math
import vector
import sympy as sp

import re
from tqdm import tqdm
import timeit

sys.path.append( git.Repo('.', search_parent_directories=True).working_tree_dir )
from utils import *

import utils.torchUtils as gnn

In [2]:
from torch_geometric import transforms

In [52]:
import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform

In [99]:
class MotherNode(BaseTransform):
    """Graph transform that appends one extra "mother" node to a graph.

    The new node receives index ``num_nodes`` (the first free index) and is
    joined, in both directions, to every existing node whose ``node_type``
    equals ``connects``. Tensor attributes of the graph are padded so they
    stay aligned with the enlarged node and edge sets, and ``edge_y`` is set
    to 1 on every edge that joins a node labelled ``y`` to a node labelled
    with one of the two labels in ``pairs``.

    Args:
        level: Value written to ``node_type`` for the new node and to
            ``edge_type`` for every edge added by this transform.
        y: Fill value for the new entry of ``data.y`` (used when ``y`` is not
            already padded as an edge or node attribute); also the "mother"
            label used when flagging edges in ``edge_y``.
        pairs: Exactly two labels ``(y1, y2)`` identifying the "daughter"
            nodes of this mother node.
        connects: ``node_type`` value of the nodes the new node is wired to.

    Raises:
        ValueError: if ``pairs`` does not contain exactly two labels.
    """

    def __init__(self, level, y=-1, pairs=(), connects=-1):
        # Immutable default + explicit validation: the previous ``pairs=[]``
        # default could never be unpacked into two labels and failed with an
        # opaque "not enough values to unpack" ValueError.
        pairs = tuple(pairs)
        if len(pairs) != 2:
            raise ValueError(
                f"pairs must contain exactly two labels (y1, y2), got {len(pairs)}"
            )

        self.level = level
        self.y = y
        self.y1, self.y2 = pairs
        self.connects = connects

    def __call__(self, data: Data) -> Data:
        num_nodes, (row, col) = data.num_nodes, data.edge_index
        edge_type = data.get('edge_type', torch.zeros_like(row))
        # Create the fallback on the same device as the edges so the stored
        # ``node_type`` does not end up on CPU for a graph living elsewhere.
        node_type = data.get('node_type', torch.zeros(num_nodes, device=row.device))

        # Indices of the existing nodes the new node will be attached to.
        connected_nodes = node_type == self.connects
        arange = torch.arange(num_nodes, device=row.device)[connected_nodes]
        num_connected = len(arange)

        # ``num_nodes`` is the index of the node being added; append the edges
        # (selected -> new) followed by (new -> selected).
        full = row.new_full((num_connected, ), num_nodes)
        row = torch.cat([row, arange, full], dim=0)
        col = torch.cat([col, full, arange], dim=0)
        edge_index = torch.stack([row, col], dim=0)

        # Both directions of every new edge carry this node's level.
        new_type = edge_type.new_full((num_connected, ), self.level)
        edge_type = torch.cat([edge_type, new_type, new_type], dim=0)

        new_type = node_type.new_full((1, ), self.level)
        node_type = torch.cat([node_type, new_type], dim=0)

        # Pad the remaining tensor attributes. ``edge_index``/``edge_type``
        # are skipped here and assigned explicitly further down.
        # NOTE(review): ``data`` is modified while it is being iterated, so
        # how later keys are classified by ``is_edge_attr``/``is_node_attr``
        # may depend on key order and on the installed PyG version - confirm
        # before reordering anything in this loop.
        for key, value in data.items():
            if key == 'edge_index' or key == 'edge_type':
                continue

            if isinstance(value, Tensor):
                dim = data.__cat_dim__(key, value)
                size = list(value.size())

                fill_value = None
                if data.is_edge_attr(key):
                    # One zero entry per new directed edge.
                    size[dim] = 2 * num_connected
                    fill_value = 0.
                elif data.is_node_attr(key):
                    # One zero entry for the new node.
                    size[dim] = 1
                    fill_value = 0.
                elif key == 'y':
                    # The new node's label.
                    size[dim] = 1
                    fill_value = self.y

                if fill_value is not None:
                    new_value = value.new_full(size, fill_value)
                    data[key] = torch.cat([value, new_value], dim=dim)

        # ``row``/``col`` now include the new edges; look up the label at both
        # endpoints of every edge once.
        src_y, dst_y = data.y[row], data.y[col]

        def _joins(a, b):
            # True where an edge links labels ``a`` and ``b`` in either direction.
            return ((src_y == a) & (dst_y == b)) | ((src_y == b) & (dst_y == a))

        decays = _joins(self.y, self.y1) | _joins(self.y, self.y2)
        # Assumes ``data.edge_y`` exists and was padded by the loop above so it
        # has one entry per edge of the enlarged graph - TODO confirm.
        data.edge_y = torch.where(decays, 1, data.edge_y)
        data.edge_index = edge_index
        data.edge_type = edge_type
        data.node_type = node_type
        if 'num_nodes' in data:
            data.num_nodes = data.num_nodes + 1

        return data


In [233]:
class XYY_YToHH_8b(BaseTransform):
    """Composite transform that adds four level-1 mother nodes to a graph.

    Each mother node carries its own label (9 to 12), is associated with a
    pair of daughter labels (1-2, 3-4, 5-6, 7-8) and is wired to the nodes
    whose ``node_type`` is 0. The individual ``MotherNode`` transforms are
    chained through ``gnn.Transform``.
    """

    def __init__(self):
        # (mother label, daughter labels) for every level-1 node.
        level1_spec = (
            (9, [1, 2]),
            (10, [3, 4]),
            (11, [5, 6]),
            (12, [7, 8]),
        )
        mothers = [
            MotherNode(1, y=label, pairs=daughters, connects=0)
            for label, daughters in level1_spec
        ]

        # Higher levels are currently disabled:
        #   MotherNode(2, y=13, pairs=[9,10], connects=1)
        #   MotherNode(2, y=14, pairs=[11,12], connects=1)
        #   MotherNode(3, y=15, pairs=[13,14], connects=2)
        self.feynman = gnn.Transform(*mothers)

    def __call__(self, data : Data) -> Data:
        # Delegate to the composed chain of MotherNode transforms.
        chain = self.feynman
        return chain(data)

In [234]:
class SimplifyTruth(BaseTransform):
    """Reduce the truth labels ``y`` and ``edge_y`` to 0/1 flags.

    Every entry greater than zero becomes 1; everything else becomes 0.
    """

    def __call__(self, data : Data) -> Data:
        for attr in ('y', 'edge_y'):
            is_positive = getattr(data, attr) > 0
            # Multiplying by 1 turns the boolean mask into integers.
            setattr(data, attr, 1*is_positive)
        return data

In [235]:
# Template dataset: its transform first adds the mother nodes
# (XYY_YToHH_8b) and then reduces the truth labels to 0/1 (SimplifyTruth).
template = gnn.Dataset(
    'data/template',
    make_template=True,
    scale='raw',
    transform=gnn.Transform(XYY_YToHH_8b(), SimplifyTruth()),
)

In [236]:
# Training sample, processed with exactly the same transform as the template.
dataset = gnn.Dataset(
    'data/MX_700_MY_300-training',
    transform=template.transform,
)

In [237]:

from torch_geometric.loader import DataLoader

# Split the dataset twice so that both the training and the validation
# subsets stay small (the recorded run printed 449 and 444 graphs).
# NOTE(review): which side of the split receives the 0.99 fraction depends on
# gnn.train_test_split - confirm against its implementation.
training,validation = gnn.train_test_split(dataset,0.99)
print(len(training))

validation, _ = gnn.train_test_split(validation, 0.99)
print(len(validation))

trainloader = DataLoader(training,batch_size=100,shuffle=True,num_workers=16)
# Validation data is iterated in a fixed order: shuffling brings no benefit
# for evaluation and makes per-batch validation results non-reproducible.
validloader = DataLoader(validation,batch_size=100,shuffle=False,num_workers=16)


449
444


In [238]:
# Build the network from the template dataset.
# NOTE(review): presumably the template supplies feature dimensions/scaling
# to GoldenGCN - confirm in utils.torchUtils.
model = gnn.GoldenGCN(template)

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

# Lightning trainer: at most 10 epochs, on one GPU when the project config
# enables it and on CPU otherwise. Early stopping is currently disabled.
trainer = pl.Trainer(
    gpus=int(bool(gnn.config.useGPU)),
    max_epochs=10,
    # callbacks=[EarlyStopping(monitor="val_loss")],
)

In [None]:
# Train the model on the training loader, validating on the validation loader.
fit = trainer.fit(model, trainloader, validloader)


  | Name         | Type          | Params
-----------------------------------------------
0 | conv1        | GCNConvMSG    | 480   
1 | relu1        | GCNRelu       | 0     
2 | conv2        | GCNConvMSG    | 20.6 K
3 | relu2        | GCNRelu       | 0     
4 | node_linear1 | NodeLinear    | 8.3 K 
5 | edge_linear1 | EdgeLinear    | 24.6 K
6 | relu3        | GCNRelu       | 0     
7 | node_linear2 | NodeLinear    | 130   
8 | edge_linear2 | EdgeLinear    | 130   
9 | log_softmax  | GCNLogSoftmax | 0     
-----------------------------------------------
54.2 K    Trainable params
0         Non-trainable params
54.2 K    Total params
0.217     Total estimated model params size (MB)


Epoch 9: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s, loss=32.1, v_num=1]   


In [207]:
# Pick a single graph (index 3) from the dataset for inspection.
graph = dataset[3]

In [208]:
# Run the trained model on the selected graph; predict returns a pair that is
# unpacked into node-level and edge-level outputs.
node_o, edge_o = model.predict(graph)

In [224]:
# Indices of the nodes whose node_type is 1 and 2 respectively, i.e. the
# nodes added by MotherNode transforms of level 1 and level 2.
h_nodes = torch.nonzero(graph.node_type == 1, as_tuple=True)[0]
y_nodes = torch.nonzero(graph.node_type == 2, as_tuple=True)[0]

In [231]:
# Mask of the edges for which exactly one endpoint is the first level-1 node.
h1_edge_mask = torch.eq(graph.edge_index, h_nodes[0]).sum(dim=0).eq(1)

In [232]:
# Show the edge outputs for the edges attached to the first level-1 node.
edge_o[h1_edge_mask]

tensor([0.4153, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 0.0000, 0.9884, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 0.5932, 1.0000, 0.5932, 1.0000])

In [194]:
# Display the selected graph.
# NOTE(review): the recorded output below is a boolean tensor rather than a
# graph summary, so it looks stale (left over from an earlier version of this
# cell) - re-run to refresh.
graph

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False, False, False, False, False, False, False, False, 

In [175]:
# Inspect the node labels of the first graph after the dataset transform
# (mother nodes appended, labels reduced to 0/1 by SimplifyTruth).
dataset[0].y

tensor([1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])