In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import git

import uproot as ut
import awkward as ak
import numpy as np
import math
import vector
import sympy as sp

import re
from tqdm import tqdm
import timeit

sys.path.append( git.Repo('.', search_parent_directories=True).working_tree_dir )
from utils import *

import utils.torchUtils as gnn

In [2]:
import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_networkx
import networkx as nx

from torch_geometric.typing import Adj, PairTensor
from typing import Callable, Optional, Union, Tuple
from torch.nn import Linear, Module
from torch.nn.functional import softmax, relu, sigmoid

In [3]:
def fully_connected(n_nodes, device=None):
    arange = torch.arange(n_nodes, device=device)
    rows = torch.repeat_interleave(arange, n_nodes).to(device=device)
    cols = torch.repeat_interleave(arange[None], n_nodes, dim=0).reshape(-1).to(device=device)
    return torch.stack([rows,cols]).to(device=device)

def batch_fully_connected(n_nodes, batch_size, device=None):
    edge_index = fully_connected(n_nodes, device=device)
    edge_index = torch.repeat_interleave(edge_index[None], batch_size, dim=0).to(device=device)
    edge_index = edge_index + n_nodes*torch.arange(batch_size).reshape(-1,1,1).to(device=device)
    rows, cols = edge_index[:,0], edge_index[:,1]
    return torch.stack([rows.flatten(), cols.flatten()]).to(device=device)

In [4]:
@torch.jit.script
def batched_pool(s : Tensor, batch_ptr : Tensor) -> Tuple[Tensor, Tensor]:
    n_in, n_out = s.shape
    batch_size = batch_ptr.shape[0] - 1
    batch_pool = torch.zeros(n_in, n_out*batch_size).to(s.device)
    batch_pool_mask = torch.zeros(n_in, n_out*batch_size).to(s.device)
    nodes = torch.stack([torch.arange(batch_size),batch_ptr[:-1], batch_ptr[1:]]).to(s.device)
    for node in nodes.T:
        i, lo, hi = node[0], node[1], node[2]
        batch_pool[lo:hi,n_out*i:n_out*(i+1)] = s[lo:hi]
        batch_pool_mask[lo:hi,n_out*i:n_out*(i+1)] = 1
    return batch_pool, batch_pool_mask

In [5]:
@torch.jit.script
def flatten_pool(pool_matrix : Tensor, nclusters : int, batch_ptr : Tensor) -> Tensor:
    batch_size = batch_ptr.shape[0] - 1
    flat_pool = torch.zeros(batch_ptr[-1], nclusters)
    nodes = torch.stack([torch.arange(batch_size),batch_ptr[:-1], batch_ptr[1:]])
    for node in nodes.T:
        i, lo, hi = node[0], node[1], node[2]
        flat_pool[lo:hi] = pool_matrix[lo:hi,nclusters*i:nclusters*(i+1)]
    return flat_pool

In [63]:
class DiffPool(torch.nn.Module):
    def __init__(self, n_in_node, n_in_edge, n_out,  n_out_clusters):
        super().__init__()
        self.n_in_node = n_in_node
        self.n_in_edge = n_in_edge
        self.n_out = n_out
        self.n_out_clusters = n_out_clusters
        
        self.norm = gnn.layers.GCNNormalize()
        self.node_embeding = torch.nn.Linear(n_in_node, n_out)
        self.edge_embeding = torch.nn.Linear(n_in_edge, n_out)
        self.adj_embeding = torch.nn.Linear(n_out, 2)
        self.pooling = torch.nn.Linear(n_out, n_out_clusters)

    def _graph_diff_pool(self, x, adj):
        x = self.s.T @ x

        adj = adj.movedim(2, 0)
        adj = adj @ self.s
        adj = self.s.T @ adj
        adj = adj.movedim(0, 2)
        
        edge_index = fully_connected(self.n_out_clusters, x.device)
        edge_attr = adj[edge_index[0],edge_index[1]]
        return x, edge_index, edge_attr
        
    def _batch_diff_pool(self, x, adj, batch_ptr):
        self.s, self.mask = batched_pool(self.s , batch_ptr)

        x = self.s.T @ x

        adj = adj.movedim(2, 0)
        adj = adj @ self.s
        adj = self.s.T @ adj
        adj = adj.movedim(0, 2)


        edge_index = batch_fully_connected(self.n_out_clusters, batch_ptr.shape[0]-1, x.device)
        edge_attr = adj[edge_index[0],edge_index[1]]
        return x, edge_index, edge_attr

    def _loss(self, adj):
        self.link_loss = adj - torch.matmul(self.s, self.s.T)
        self.link_loss = torch.norm(self.link_loss, p=2)
        self.link_loss = self.link_loss/adj.numel()

        self.ent_loss = (-self.s * torch.log(self.s + 1e-15)).sum(dim=-1).mean()

    def flatten_pool(self, s):
        if self.batch_ptr is None: return s
        return flatten_pool(s, self.n_out_clusters, self.batch_ptr)


    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: Tensor, batch_ptr : Optional[Tensor] = None) -> Tensor:
        x, edge_attr = self.norm(x, edge_index, edge_attr)

        edge_attr = self.edge_embeding(edge_attr)

        adj_attr = torch.sparse_coo_tensor(edge_index, edge_attr).to_dense()
        adj = softmax(self.adj_embeding(edge_attr), dim=-1)[:,1]
        adj = torch.sparse_coo_tensor(edge_index, adj).to_dense()
        
        x = self.node_embeding( torch.matmul(adj, x) )
        self.s = softmax( relu( self.pooling(x) ), dim=-1 )
        self.mask = torch.ones_like(self.s)

        if batch_ptr is None or True:
            x, edge_index, edge_attr = self._graph_diff_pool(x, adj_attr)
        else:
            x, edge_index, edge_attr = self._batch_diff_pool(x, adj_attr, batch_ptr)
        self.batch_ptr = batch_ptr

        self._loss(adj)

        return x, edge_index, edge_attr

In [64]:
class GoldenModule(Module):
    def __init__(self, n_in_node=5, n_in_edge=4, layers=32):
        super().__init__()
        
        self.conv1 = gnn.layers.GCNConvMSG(n_in_node=n_in_node, n_in_edge=n_in_edge, n_out=layers)
        self.relu1 = gnn.layers.GCNRelu()

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: Tensor) -> Tensor:
        x, edge_attr = self.conv1(x, edge_index, edge_attr)
        x, edge_attr = self.relu1(x, edge_index, edge_attr)
        return x, edge_attr

In [87]:
from pytorch_lightning import LightningModule
from torch.nn.functional import nll_loss, log_softmax
from torchmetrics.functional import accuracy, auroc

class Model(LightningModule):
    def __init__(self,n_in_node=5, n_in_edge=4):
        super().__init__()
        self.norm = gnn.layers.GCNNormalize()

        self.golden1 = GoldenModule(n_in_node=n_in_node, n_in_edge=n_in_edge, layers=32)
        self.h_pool = DiffPool(32, 3*32, 16, 5)

        self.golden2 = GoldenModule(n_in_node=16, n_in_edge=16, layers=32)
        self.y_pool = DiffPool(32, 3*32, 16, 3)

        self.golden3 = GoldenModule(n_in_node=16, n_in_edge=16, layers=32)
        self.x_pool = DiffPool(32, 3*32, 16, 2)

    def forward(self, data : Data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch_ptr = data.ptr if hasattr(data, 'ptr') else None

        x, edge_attr = self.golden1(x, edge_index, edge_attr)
        # x, edge_index, edge_attr = self.h_pool(x, edge_index, edge_attr, batch_ptr)
        # batch_ptr = self.h_pool.n_out_clusters*torch.arange(data.num_graphs) if hasattr(data, 'ptr') else None
        
        # x, edge_attr = self.golden2(x, edge_index, edge_attr)
        # x, edge_index, edge_attr = self.y_pool(x, edge_index, edge_attr, batch_ptr)
        # batch_ptr = self.y_pool.n_out_clusters*torch.arange(data.num_graphs) if hasattr(data, 'ptr') else None
        
        x, edge_attr = self.golden3(x, edge_index, edge_attr)
        x, edge_index, edge_attr = self.x_pool(x, edge_index, edge_attr, batch_ptr)

        # node_h_pool = self.h_pool.s
        # node_h_mask = self.h_pool.mask
        # node_y_pool = node_h_pool @ self.y_pool.s 
        # node_y_mask = node_h_mask @ self.y_pool.mask
        node_x_pool = self.x_pool.s 
        node_x_mask = self.x_pool.mask
        
        return (
            # node_h_pool[node_h_mask>0].reshape(-1, self.h_pool.n_out_clusters), 
            # node_y_pool[node_y_mask>0].reshape(-1, self.y_pool.n_out_clusters), 
            node_x_pool[node_x_mask>0].reshape(-1, self.x_pool.n_out_clusters), 
        )

    def step(self, batch, batch_idx, tag):
        # h_pool, y_pool, x_pool = self(batch)
        x_pool, = self(batch)

        # h_pool_loss = nll_loss( log_softmax(h_pool), batch.h_pool_tru)/5. # + self.h_pool.link_loss + self.h_pool.ent_loss
        # y_pool_loss = nll_loss( log_softmax(y_pool), batch.y_pool_tru)/3. # + self.y_pool.link_loss + self.y_pool.ent_loss
        x_pool_loss = nll_loss( log_softmax(x_pool), batch.x_pool_tru)/2. # + self.x_pool.link_loss + self.x_pool.ent_loss

        loss = x_pool_loss

        # h_auroc = auroc(h_pool[:,1], batch.x_pool_tru)
        # self.log(f'{tag}/auroc', h_auroc)


        self.log(f'{tag}/loss',loss)
        return loss
    def training_step(self, batch, batch_idx): return self.step(batch ,batch_idx, 'train')
    def validation_step(self, batch, batch_idx): return self.step(batch ,batch_idx, 'valid')
    def test_step(self, batch, batch_idx):  return self.step(batch ,batch_idx, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [88]:
from torch_geometric.loader import DataLoader, DenseDataLoader

def load_dataset(fn='data/MX_1200_MY_500-training', template=None, shuffle=False):
    dataset = gnn.Dataset(fn,transform=template.transform)
    training, testing = gnn.train_test_split(dataset[:3000], 0.33)
    training, validation = gnn.train_test_split(training, 0.5)

    batch_size = 1
    trainloader = DataLoader(training, batch_size=batch_size, shuffle=shuffle, num_workers=8)
    validloader = DataLoader(validation, batch_size=batch_size, shuffle=shuffle, num_workers=8)
    testloader = DataLoader(testing, batch_size=batch_size, shuffle=shuffle, num_workers=8)

    return trainloader, validloader, testloader

In [89]:
def get_pool_weight(pool_tru, nclusters):
    n_tot = pool_tru.shape[0]
    n_tru = (pool_tru > 0).sum()
    n_fak = (pool_tru == 0).sum()

    weight = torch.Tensor([n_fak]+[n_tru]*(nclusters-1))
    weight = n_tot - weight
    return weight


class PoolTruth(BaseTransform):
    def __call__(self, data : Data) -> Data:
        data.h_pool_tru = (data.node_id+1)//2
        data.h_pool_weight = get_pool_weight(data.h_pool_tru, 5)

        data.y_pool_tru = (data.h_pool_tru+1)//2
        data.y_pool_weight = get_pool_weight(data.y_pool_tru, 3)

        data.x_pool_tru = (data.y_pool_tru+1)//2
        data.x_pool_weight = get_pool_weight(data.x_pool_tru, 2)

        return data

In [90]:
template = gnn.Dataset('data/template',make_template=True, transform=gnn.Transform(PoolTruth()))
trainloader, validloader, testloader = load_dataset(template=template)

In [91]:
model = Model()

In [92]:
model(trainloader.dataset[0])

(tensor([[0.4844, 0.1289, 0.1289, 0.1289, 0.1289],
         [0.6348, 0.0913, 0.0913, 0.0913, 0.0913],
         [0.6236, 0.0941, 0.0941, 0.0941, 0.0941],
         [0.5667, 0.1083, 0.1083, 0.1083, 0.1083],
         [0.7602, 0.0599, 0.0599, 0.0599, 0.0599],
         [0.5948, 0.1013, 0.1013, 0.1013, 0.1013],
         [0.8289, 0.0428, 0.0428, 0.0428, 0.0428],
         [0.3959, 0.1510, 0.1510, 0.1510, 0.1510],
         [0.4888, 0.1278, 0.1278, 0.1278, 0.1278]], grad_fn=<ViewBackward>),
 tensor([[0.6841, 0.1579, 0.1579],
         [0.7763, 0.1119, 0.1119],
         [0.7694, 0.1153, 0.1153],
         [0.7346, 0.1327, 0.1327],
         [0.8531, 0.0735, 0.0735],
         [0.7518, 0.1241, 0.1241],
         [0.8952, 0.0524, 0.0524],
         [0.6299, 0.1850, 0.1850],
         [0.6868, 0.1566, 0.1566]], grad_fn=<ViewBackward>),
 tensor([[0.0743, 0.9257],
         [0.0526, 0.9474],
         [0.0542, 0.9458],
         [0.0624, 0.9376],
         [0.0345, 0.9655],
         [0.0584, 0.9416],
         [0.

In [93]:
from pytorch_lightning import Trainer

trainer = Trainer(gpus=0, max_epochs=1)
trainer.fit(model, trainloader, validloader)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name    | Type         | Params
-----------------------------------------
0 | norm    | GCNNormalize | 0     
1 | golden1 | GoldenModule | 480   
2 | h_pool  | DiffPool     | 2.2 K 
3 | golden2 | GoldenModule | 1.6 K 
4 | y_pool  | DiffPool     | 2.2 K 
5 | golden3 | GoldenModule | 1.6 K 
6 | x_pool  | DiffPool     | 2.1 K 
-----------------------------------------
10.1 K    Trainable params
0         Non-trainable params
10.1 K    Total params
0.041     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 2009/2009 [01:19<00:00, 25.26it/s, loss=0.294, v_num=73]


In [77]:
trainer.test(model, testloader)

Testing: 100%|█████████▉| 988/991 [00:28<00:00, 37.23it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/loss': 3.3573358058929443}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 991/991 [00:29<00:00, 34.00it/s]


[{'test/loss': 3.3573358058929443}]

In [86]:
model(testloader.dataset[0])

(tensor([[1.9570e-01, 1.9570e-01, 2.1720e-01, 1.9570e-01, 1.9570e-01],
         [1.9570e-01, 1.9570e-01, 2.1720e-01, 1.9570e-01, 1.9570e-01],
         [1.9570e-01, 1.9570e-01, 2.1719e-01, 1.9570e-01, 1.9570e-01],
         [1.9572e-01, 1.9572e-01, 2.1711e-01, 1.9572e-01, 1.9572e-01],
         [9.9998e-01, 3.9986e-06, 3.9986e-06, 3.9986e-06, 3.9986e-06],
         [9.9992e-01, 1.9643e-05, 1.9643e-05, 1.9643e-05, 1.9643e-05],
         [9.9999e-01, 2.2264e-06, 2.2264e-06, 2.2264e-06, 2.2264e-06],
         [9.9999e-01, 2.2861e-06, 2.2861e-06, 2.2861e-06, 2.2861e-06],
         [1.9570e-01, 1.9570e-01, 2.1719e-01, 1.9570e-01, 1.9570e-01]],
        grad_fn=<ViewBackward>),
 tensor([[0.0652, 0.0652, 0.8695],
         [0.0652, 0.0652, 0.8695],
         [0.0652, 0.0652, 0.8695],
         [0.0652, 0.0652, 0.8695],
         [0.3333, 0.3333, 0.3333],
         [0.3333, 0.3333, 0.3334],
         [0.3333, 0.3333, 0.3333],
         [0.3333, 0.3333, 0.3333],
         [0.0652, 0.0652, 0.8695]], grad_fn=<Vi