In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import git

import uproot as ut
import awkward as ak
import numpy as np
import math
import vector
import sympy as sp

import re
from tqdm import tqdm
import timeit

sys.path.append( git.Repo('.', search_parent_directories=True).working_tree_dir )
from utils import *

import utils.torchUtils as gnn

In [2]:
import torch, torch_geometric
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_networkx
import networkx as nx

from torch_geometric.typing import Adj, PairTensor
from typing import Callable, Optional, Union, Tuple
from torch.nn import Linear, Module
from torch.nn.functional import softmax, relu, sigmoid

In [3]:
def get_pool_weight(pool_tru, nclusters):
    """Per-cluster weights for a pooling-assignment target.

    Cluster 0 is the "fake" cluster (label == 0); clusters 1..nclusters-1 are
    "true" clusters (label > 0). Each weight is the total node count minus the
    number of nodes in that cluster's category, so every true cluster shares
    the same weight (per-true-cluster counts are not used).

    Args:
        pool_tru: tensor of per-node cluster labels; its first dimension is
            taken as the node count.
        nclusters: number of clusters.

    Returns:
        Float tensor ``[N - n_fake, N - n_true, ..., N - n_true]`` of length
        ``max(nclusters, 1)``.
    """
    n_nodes = pool_tru.shape[0]
    n_fake = int((pool_tru == 0).sum())
    n_true = int((pool_tru > 0).sum())
    # One entry for the fake cluster followed by one per true cluster.
    per_cluster_counts = [n_fake] + [n_true] * (nclusters - 1)
    return n_nodes - torch.Tensor(per_cluster_counts)


class PoolTruth(BaseTransform):
    """Attach hierarchical pooling targets and their weights to a graph.

    Starting from ``data.node_id``, each level maps labels with
    ``(label + 1) // 2``: label 0 stays 0 and labels 2k-1 and 2k merge into k.
    Three levels are written in order (h, then y, then x), each derived from
    the previous one, together with weights from ``get_pool_weight``.
    The input ``data`` is modified in place and returned.
    """

    def __call__(self, data : Data) -> Data:
        # (label attribute, weight attribute, number of clusters), finest first.
        levels = (
            ('h_pool_tru', 'h_pool_weight', 5),
            ('y_pool_tru', 'y_pool_weight', 3),
            ('x_pool_tru', 'x_pool_weight', 2),
        )
        labels = data.node_id
        for label_attr, weight_attr, nclusters in levels:
            labels = (labels + 1) // 2
            setattr(data, label_attr, labels)
            setattr(data, weight_attr, get_pool_weight(labels, nclusters))
        return data

In [4]:
from torch_geometric.loader import DataLoader, DenseDataLoader

def load_dataset(fn='data/MX_1200_MY_500-training', template=None, shuffle=False,
                 max_events=3000, batch_size=1, num_workers=8):
    """Build train/validation/test DataLoaders from a ``gnn.Dataset``.

    The first ``max_events`` graphs are split twice with
    ``gnn.train_test_split``: first into (training, testing) with 0.33, then
    the training part into (training, validation) with 0.5.

    Args:
        fn: dataset path passed to ``gnn.Dataset``.
        template: optional dataset whose ``transform`` is reused for the loaded
            dataset. With the default ``None``, ``transform=None`` is passed to
            ``gnn.Dataset`` rather than failing on ``None.transform``.
        shuffle: forwarded to all three DataLoaders.
        max_events: number of leading graphs used (``dataset[:max_events]``).
        batch_size: forwarded to all three DataLoaders.
        num_workers: forwarded to all three DataLoaders.

    Returns:
        ``(trainloader, validloader, testloader)``.
    """
    transform = template.transform if template is not None else None
    dataset = gnn.Dataset(fn, transform=transform)
    # NOTE(review): 0.33 / 0.5 are presumably the fraction sent to the second
    # returned split -- confirm against gnn.train_test_split.
    training, testing = gnn.train_test_split(dataset[:max_events], 0.33)
    training, validation = gnn.train_test_split(training, 0.5)

    def _make_loader(split):
        # Same loader settings for every split; only the data differs.
        return DataLoader(split, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers)

    return _make_loader(training), _make_loader(validation), _make_loader(testing)

# load_dataset reads only the template's `transform` (PoolTruth) and applies it
# to the real dataset, so the pooling targets are attached on access.
template = gnn.Dataset(
    'data/template',
    make_template=True,
    transform=PoolTruth(),
)
trainloader, validloader, testloader = load_dataset(template=template)

In [5]:
# Underlying dataset of the training split (bypasses batching).
dataset = trainloader.dataset

In [6]:
# A single graph from the training split, used for the checks below.
data = dataset[0]

In [7]:
from torch.nn.functional import one_hot

# Node-to-cluster assignment matrix S (one row per node, one column per cluster).
# num_classes is passed explicitly: without it one_hot infers the width from
# max(label) + 1, so a graph with no node in the highest cluster would give a
# narrower matrix. 5 matches the nclusters used for the h level in PoolTruth.
h_pool_s = one_hot(data.h_pool_tru, num_classes=5).float()

In [8]:
# Sum of node features within each h-level cluster: S^T x, one row per cluster.
h_pool_s.T @ data.x

tensor([[-1.7700, -1.8723,  3.9155,  4.0423, -4.1032],
        [ 2.4032,  1.2530,  2.3653, -0.2414,  1.1859],
        [ 0.3197,  0.4655,  0.6710, -0.5141, -1.0215],
        [-0.0967,  0.7818, -0.3043, -0.2889,  2.0713],
        [-0.5433, -0.6029, -0.3427,  1.4341, -0.8838]])

In [9]:
# Inspect the one-hot assignment matrix (rows = nodes, columns = clusters).
h_pool_s

tensor([[0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]])

In [10]:
# Random permutation matrix P: rows of the identity in shuffled order.
# Seed the global RNG first so the permutation -- and any randomly initialised
# layers created afterwards -- are reproducible on Restart & Run All.
torch.manual_seed(0)
perm = torch.eye(data.num_nodes)[torch.randperm(data.num_nodes)]

In [14]:
# Permute node features: x' = P x.
perm_x = torch.matmul(perm, data.x)

# Edge features are handled as a dense (N, N, F) tensor. This requires exactly
# N*N edges listed in a fixed row- (or column-) major pair order -- TODO
# confirm against the dataset builder. Guard the count explicitly:
# reshape(N, N, -1) would otherwise silently succeed with a wrong feature
# width whenever E*F happens to be divisible by N*N.
assert data.edge_attr.shape[0] == data.num_nodes**2, "expected N*N edges (dense edge layout)"
edge_attr = data.edge_attr.reshape(data.num_nodes, data.num_nodes, -1).movedim(2,0)
# Permute both node axes per feature channel: E' = P E P^T.
edge_attr = torch.matmul(perm, torch.matmul(edge_attr, perm.T)).movedim(0,2)
perm_edge_attr = edge_attr.reshape(data.num_nodes**2, -1)

In [15]:
# Single message-passing layer under test.
# NOTE(review): data.x has 5 features here (the pooled S^T x output is 5x5), so
# 14 is presumably the size of a concatenated per-edge input (e.g. x_i, x_j and
# edge features) -- confirm against utils.torchUtils.layers.GCNConvMSG and
# derive it from data instead of hard-coding.
conv = gnn.layers.GCNConvMSG(14, 2)

In [16]:
# Run the same layer on the original inputs and on the node-permuted features.
# edge_index is deliberately left unpermuted: the permutation is carried by
# perm_x / perm_edge_attr, so a permutation-equivariant layer should return
# outputs related by the same permutation P.
x_0, edge_attr_0 = conv(data.x, data.edge_index, data.edge_attr)
perm_x_0, perm_edge_attr_0 = conv(perm_x, data.edge_index, perm_edge_attr)

In [17]:
# Node-level equivariance check: undo the permutation on the permuted run
# (P^T x'_0) and compare with the unpermuted run. Differences should be at
# float32 round-off level, so assert instead of reading the matrix by eye.
unperm_x_0 = perm.T @ perm_x_0
node_diff = x_0 - unperm_x_0
assert torch.allclose(x_0, unperm_x_0, atol=1e-5), (
    f"node outputs are not permutation-equivariant "
    f"(max |diff| = {node_diff.abs().max().item():.3e})"
)
node_diff

tensor([[ 0.0000e+00,  0.0000e+00],
        [ 9.5367e-07,  1.1921e-07],
        [ 2.3842e-07,  0.0000e+00],
        [ 2.3842e-07, -2.3842e-07],
        [ 9.5367e-07, -2.3842e-07],
        [ 4.7684e-07,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [-9.5367e-07, -7.1526e-07],
        [ 9.5367e-07,  0.0000e+00]], grad_fn=<SubBackward0>)

In [18]:
# One output row per edge; the dense (N, N, C) reshape in the next check needs N*N rows.
edge_attr_0.shape

torch.Size([100, 2])

In [19]:
# Edge-level equivariance check: view both edge outputs as dense (N, N, C)
# tensors and undo the permutation on both node axes (P^T E' P) for every
# output channel. Channel 0 is displayed; the assertion covers all channels.
n_nodes = x_0.shape[0]
edge_out = edge_attr_0.reshape(n_nodes, n_nodes, -1)
perm_edge_out = perm_edge_attr_0.reshape(n_nodes, n_nodes, -1)
unperm_edge_out = (perm.T @ perm_edge_out.movedim(2, 0) @ perm).movedim(0, 2)
edge_diff = edge_out - unperm_edge_out
assert torch.allclose(edge_out, unperm_edge_out, atol=1e-5), (
    f"edge outputs are not permutation-equivariant "
    f"(max |diff| = {edge_diff.abs().max().item():.3e})"
)
edge_diff[:, :, 0]

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  9.5367e-07,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -9.5367e-07,  9.5367e-07],
        [ 0.0000e+00,  1.9073e-06,  9.5367e-07,  9.5367e-07,  1.9073e-06,
          1.9073e-06,  9.5367e-07,  9.5367e-07,  0.0000e+00,  1.9073e-06],
        [ 0.0000e+00,  9.5367e-07,  4.7684e-07,  9.5367e-07,  9.5367e-07,
          9.5367e-07,  4.7684e-07,  0.0000e+00, -9.5367e-07,  9.5367e-07],
        [ 0.0000e+00,  9.5367e-07,  9.5367e-07,  4.7684e-07,  1.9073e-06,
          9.5367e-07,  4.7684e-07,  9.5367e-07, -9.5367e-07,  9.5367e-07],
        [ 9.5367e-07,  1.9073e-06,  9.5367e-07,  1.9073e-06,  1.9073e-06,
          1.9073e-06,  9.5367e-07,  0.0000e+00,  0.0000e+00,  3.8147e-06],
        [ 0.0000e+00,  1.9073e-06,  9.5367e-07,  9.5367e-07,  1.9073e-06,
          9.5367e-07,  9.5367e-07,  9.5367e-07,  0.0000e+00,  9.5367e-07],
        [ 0.0000e+00,  9.5367e-07,  4.7684e-07,  4.7684e-07,  0.0000e+00,
          9.5367e-07,  0.0000e+0