In [1]:
# System imports
import os
import sys
from pprint import pprint as pp
from time import time as tt

# External imports
import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data import DataLoader
import torch.nn as nn
from torch_scatter import scatter_add
import torch.nn.functional as F

import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from IPython.display import clear_output
from IPython.display import HTML, display

%matplotlib inline

sys.path.append("..")

# Get rid of RuntimeWarnings, gross
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

from lightning_modules.GNN.utils import make_mlp

device = "cuda" if torch.cuda.is_available() else "cpu"

## Load Data

In [2]:
input_dir= "/global/cscratch1/sd/danieltm/ExaTrkX/trackml-codalab/embedding_processed/1_pt_cut_endcaps_unweighted_augmented/train"

In [3]:
num_events = 10
all_events = os.listdir(input_dir)
loaded_events = [torch.load(os.path.join(input_dir,event)) for event in all_events[:num_events]]

In [4]:
train_loader = DataLoader(loaded_events, batch_size=1, shuffle=True)

## Model Definitions

In [64]:
class MPNN_Network(nn.Module):
    """
    A message-passing graph network which takes a graph with:
    - bi-directional edges
    - node features, no edge features

    and applies the following modules:
    - a graph encoder (no message passing)
    - recurrent edge and node networks
    - an edge classifier
    """

    def __init__(self, input_dim, hidden_node_dim, in_layers, node_layers, edge_layers,
                 n_graph_iters=1, layer_norm=True):
        super(MPNN_Network, self).__init__()
        self.n_graph_iters = n_graph_iters

        # The node encoder transforms input node features to the hidden space
        self.node_encoder = make_mlp(input_dim, [hidden_node_dim]*in_layers)

        # The edge network computes new edge features from connected nodes
        # self.edge_network = make_mlp(2*hidden_node_dim,
        #                              [hidden_edge_dim]*edge_layers,
        #                              layer_norm=layer_norm)

        # The node network computes new node features
        self.node_network = make_mlp(2*hidden_node_dim,
                                     [hidden_node_dim]*node_layers,
                                     layer_norm=layer_norm)

        # The edge classifier computes final edge scores
        self.edge_classifier = make_mlp(2*hidden_node_dim,
                                        [hidden_node_dim, hidden_node_dim, 1],
                                        output_activation=None)
        
#         self.conv1 = GCNConv(input_dim, hidden_node_dim).jittable()
#         self.conv2 = GCNConv(hidden_node_dim, hidden_node_dim).jittable()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        start, end = edge_index[0], edge_index[1]

        # Encode the graph features into the hidden space
        x = self.node_encoder(x)

        src = x[end]
        index = start.unsqueeze(-1)
        in_messages = torch.zeros(x.shape, dtype=src.dtype).scatter_add(0, index.repeat((1,src.shape[1])), src) 
        
        src = x[start]
        index = end.unsqueeze(-1)
        out_messages = torch.zeros(x.shape, dtype=src.dtype).scatter_add(0, index.repeat((1,src.shape[1])), src) 
        
        aggr_messages = in_messages + out_messages
        
        #     # Compute new node features
        node_inputs = torch.cat([x, aggr_messages], dim=1)
        x = self.node_network(node_inputs)


        # Compute final edge scores; use original edge directions only
        clf_inputs = torch.cat([x[start], x[end]], dim=1)
        return self.edge_classifier(clf_inputs).squeeze(-1)

class Simple_Network(nn.Module):
    """
    A message-passing graph network which takes a graph with:
    - bi-directional edges
    - node features, no edge features

    and applies the following modules:
    - a graph encoder (no message passing)
    - recurrent edge and node networks
    - an edge classifier
    """

    def __init__(self, input_dim, hidden_node_dim, in_layers, node_layers, edge_layers,
                 n_graph_iters=1, layer_norm=True):
        super(Simple_Network, self).__init__()
        self.n_graph_iters = n_graph_iters

        # The node encoder transforms input node features to the hidden space
        self.node_encoder = make_mlp(input_dim, [hidden_node_dim]*in_layers)

        # The edge network computes new edge features from connected nodes
        # self.edge_network = make_mlp(2*hidden_node_dim,
        #                              [hidden_edge_dim]*edge_layers,
        #                              layer_norm=layer_norm)

        # The node network computes new node features
        self.node_network = make_mlp(hidden_node_dim,
                                     [hidden_node_dim]*node_layers,
                                     layer_norm=layer_norm)

        # The edge classifier computes final edge scores
        self.edge_classifier = make_mlp(2*hidden_node_dim,
                                        [hidden_node_dim, hidden_node_dim, 1],
                                        output_activation=None)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        start, end = edge_index[0], edge_index[1]

        # Encode the graph features into the hidden space
        x = self.node_encoder(x)

        # Loop over graph iterations
        for i in range(self.n_graph_iters):

            # Previous hidden state
            x0 = x

            # Sum edge features coming into each node
            # aggr_messages = scatter_add(x[end], start, dim=0, dim_size=x.shape[0]) + scatter_add(x[start], end, dim=0, dim_size=x.shape[0])

            # Compute new node features
            # node_inputs = torch.cat([x, aggr_messages], dim=1)
            x = self.node_network(x)

            # Residual connection
            x = x + x0

        # Compute final edge scores; use original edge directions only
        clf_inputs = torch.cat([x[start], x[end]], dim=1)
        return self.edge_classifier(clf_inputs).squeeze(-1)

## Training

In [58]:
def train(model, train_loader, optimizer):
    correct = 0
    total = 0
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        data = batch.to(device)
        pred = model(data.x, data.edge_index)
        loss = F.binary_cross_entropy_with_logits(pred.float(), data.y.float(), pos_weight=torch.tensor(weight))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += ((pred > 0.5) == (data.y > 0.5)).sum().item()
        total += len(pred)
    acc = correct/total
    return acc, total_loss

In [59]:
t_loss_v = []
t_acc_v = []
v_loss_v = []
v_acc_v = []
ep = 0

In [60]:
weight = 2
m_configs = {"input_dim": 3, "hidden_node_dim": 64, "in_layers": 3, "node_layers": 3, "edge_layers": 3, "n_graph_iters": 8, "layer_norm": True}
mpnn_model = MPNN_Network(**m_configs).to(device)
optimizer = torch.optim.Adam(mpnn_model.parameters(), lr=0.001, weight_decay=1e-3, amsgrad=True)
for epoch in range(10):
    ep += 1  
    mpnn_model.train()
    acc, total_loss = train(mpnn_model, train_loader, optimizer)
    t_loss_v.append(total_loss)
    t_acc_v.append(acc)


    print('Epoch: {}, Accuracy: {:.4f}'.format(ep, acc))

Epoch: 1, Accuracy: 0.8287
Epoch: 2, Accuracy: 0.8287
Epoch: 3, Accuracy: 0.8287
Epoch: 4, Accuracy: 0.8287
Epoch: 5, Accuracy: 0.8295
Epoch: 6, Accuracy: 0.8310
Epoch: 7, Accuracy: 0.8309
Epoch: 8, Accuracy: 0.8309
Epoch: 9, Accuracy: 0.8309
Epoch: 10, Accuracy: 0.8307


In [24]:
weight = 2
m_configs = {"input_dim": 3, "hidden_node_dim": 64, "in_layers": 3, "node_layers": 3, "edge_layers": 3, "n_graph_iters": 8, "layer_norm": True}
simple_model = Simple_Network(**m_configs).to(device)
optimizer = torch.optim.Adam(simple_model.parameters(), lr=0.001, weight_decay=1e-3, amsgrad=True)
for epoch in range(10):
    ep += 1  
    simple_model.train()
    acc, total_loss = train(simple_model, train_loader, optimizer)
    t_loss_v.append(total_loss)
    t_acc_v.append(acc)

    print('Epoch: {}, Accuracy: {:.4f}'.format(ep, acc))

Epoch: 2, Accuracy: 0.8221
Epoch: 3, Accuracy: 0.8280
Epoch: 4, Accuracy: 0.8291
Epoch: 5, Accuracy: 0.8257
Epoch: 6, Accuracy: 0.8403


KeyboardInterrupt: 

## Scatter_add testing

In [21]:
nodes = torch.rand((10,2))
edges = torch.randint(10, (2, 6))

In [22]:
nodes

tensor([[0.3532, 0.5300],
        [0.3755, 0.8037],
        [0.3916, 0.2947],
        [0.1680, 0.2666],
        [0.1674, 0.5123],
        [0.0464, 0.4204],
        [0.2450, 0.2764],
        [0.2899, 0.7279],
        [0.2734, 0.2199],
        [0.5983, 0.8264]])

In [23]:
edges

tensor([[7, 6, 4, 1, 8, 3],
        [2, 2, 9, 6, 9, 7]])

In [55]:
src = nodes[edges[0]]
index = edges[1].unsqueeze(-1)
aggr_messages = torch.zeros(nodes.shape, dtype=src.dtype).to(device).scatter_add(0, index, src) 

In [30]:
src

tensor([[0.2899, 0.7279],
        [0.2450, 0.2764],
        [0.1674, 0.5123],
        [0.3755, 0.8037],
        [0.2734, 0.2199],
        [0.1680, 0.2666]])

In [42]:
index

tensor([[7],
        [6],
        [4],
        [1],
        [8],
        [3]])

In [56]:
torch.zeros(nodes.shape, dtype=src.dtype).to(device).scatter_add(0, index.repeat((1,src.shape[1])), src) 

tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.5350, 1.0043],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.3755, 0.8037],
        [0.1680, 0.2666],
        [0.0000, 0.0000],
        [0.4408, 0.7322]])

In [43]:
index.repeat((1,2))

tensor([[7, 7],
        [6, 6],
        [4, 4],
        [1, 1],
        [8, 8],
        [3, 3]])

In [32]:
aggr_messages

tensor([[0.0000, 0.0000],
        [0.3755, 0.0000],
        [0.0000, 0.0000],
        [0.1680, 0.0000],
        [0.1674, 0.0000],
        [0.0000, 0.0000],
        [0.2450, 0.0000],
        [0.2899, 0.0000],
        [0.2734, 0.0000],
        [0.0000, 0.0000]])

In [52]:
scatter_add(nodes[edges[0]], edges[1], dim=0)

tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.5350, 1.0043],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.3755, 0.8037],
        [0.1680, 0.2666],
        [0.0000, 0.0000],
        [0.4408, 0.7322]])

## Onnx Testing

In [65]:
script_module = torch.jit.script(MPNN_Network(**m_configs))

In [66]:
script_module

RecursiveScriptModule(
  original_name=MPNN_Network
  (node_encoder): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Linear)
    (1): RecursiveScriptModule(original_name=ReLU)
    (2): RecursiveScriptModule(original_name=Linear)
    (3): RecursiveScriptModule(original_name=ReLU)
    (4): RecursiveScriptModule(original_name=Linear)
    (5): RecursiveScriptModule(original_name=ReLU)
  )
  (node_network): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Linear)
    (1): RecursiveScriptModule(original_name=LayerNorm)
    (2): RecursiveScriptModule(original_name=ReLU)
    (3): RecursiveScriptModule(original_name=Linear)
    (4): RecursiveScriptModule(original_name=LayerNorm)
    (5): RecursiveScriptModule(original_name=ReLU)
    (6): RecursiveScriptModule(original_name=Linear)
    (7): RecursiveScriptModule(original_name=LayerNorm)
    (8): RecursiveScriptModule(original_name=ReLU)
  )
  (edge_

In [67]:
traced_script_module = torch.jit.trace(mpnn_model, input_data)

In [68]:
traced_script_module

MPNN_Network(
  original_name=MPNN_Network
  (node_encoder): Sequential(
    original_name=Sequential
    (0): Linear(original_name=Linear)
    (1): ReLU(original_name=ReLU)
    (2): Linear(original_name=Linear)
    (3): ReLU(original_name=ReLU)
    (4): Linear(original_name=Linear)
    (5): ReLU(original_name=ReLU)
  )
  (node_network): Sequential(
    original_name=Sequential
    (0): Linear(original_name=Linear)
    (1): LayerNorm(original_name=LayerNorm)
    (2): ReLU(original_name=ReLU)
    (3): Linear(original_name=Linear)
    (4): LayerNorm(original_name=LayerNorm)
    (5): ReLU(original_name=ReLU)
    (6): Linear(original_name=Linear)
    (7): LayerNorm(original_name=LayerNorm)
    (8): ReLU(original_name=ReLU)
  )
  (edge_classifier): Sequential(
    original_name=Sequential
    (0): Linear(original_name=Linear)
    (1): ReLU(original_name=ReLU)
    (2): Linear(original_name=Linear)
    (3): ReLU(original_name=ReLU)
    (4): Linear(original_name=Linear)
  )
)

In [70]:
traced_script_module(input_data[0], input_data[1])

tensor([-1.7495, -1.7519, -1.4925,  ...,  0.0104,  0.2420, -1.3107],
       grad_fn=<SqueezeBackward1>)

## TensorRT Testing

In [71]:
example_data = loaded_events[0]
input_data = (example_data.x, example_data.edge_index)

In [72]:
ONNX_FILE_PATH = "simple_model.onnx"
torch.onnx.export(mpnn_model, input_data, ONNX_FILE_PATH, input_names=["input"],
                  output_names=["output"], export_params=True)

  "Passing an tensor of different rank in execution will be incorrect.")
  "Passing an tensor of different rank in execution will be incorrect.")
  "intended to be used with dynamic input shapes, please use opset version 11 to export the model.")


In [73]:
import tensorrt

In [74]:
tensorrt.__version__

'7.2.2.3'

In [77]:
assert tensorrt.Builder(tensorrt.Logger())

TypeError: pybind11::init(): factory function returned nullptr