In [13]:
%load_ext autoreload
%autoreload 2

import os
import sys
import git

import uproot as ut
import awkward as ak
import numpy as np
import math
import vector
import sympy as sp

import re
from tqdm import tqdm
import timeit
import torch

sys.path.append( git.Repo('.', search_parent_directories=True).working_tree_dir )
from utils import *

import utils.torchUtils as gnn

plt.style.use('science')
plt.rcParams["figure.figsize"] = (10,10)
plt.rcParams['font.size'] =  15

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
tree = Tree(fc.passthrough.TESTING_MX_700_MY_400)

In [15]:
ak.argmax(tree.n_jet)

1662

In [16]:
testing = gnn.Dataset('data/signal-testing',transform=gnn.to_uptri_graph)
len(testing)

100000

In [17]:
import torch.nn.functional as F
from torch_geometric.nn import Linear, GCNConv
from utils.torchUtils.layers import EdgeConv, EdgeConvONNX

class GCN(torch.nn.Module):
    def __init__(self,nn1_out=32,nn2_out=64,for_onnx=False):
        super().__init__()
        
        EdgeConvLayer = EdgeConv
        if for_onnx: EdgeConvLayer = EdgeConvONNX
        
        nn1 = torch.nn.Sequential(
            Linear(2*testing.num_node_features +
                   testing.num_edge_features, nn1_out),
            torch.nn.ELU()
        )

        self.conv1 = EdgeConvLayer(nn1, edge_aggr=None, return_with_edges=True)

        nn2 = torch.nn.Sequential(
            Linear(5*nn1_out, nn2_out),
            torch.nn.ELU()
        )

        self.conv2 = EdgeConvLayer(nn2, edge_aggr=None, return_with_edges=True)

        self.edge_seq = torch.nn.Sequential(
            Linear(3*nn2_out, 2),
        )

        self.node_seq = torch.nn.Sequential(
            Linear(nn2_out, 2),
        )

    def forward(self, x, edge_index, edge_x):
        # x, edge_index = data.x, data.edge_index
        
        
        x, edge_x = self.conv1(x, edge_index, edge_x)
        x, edge_x = self.conv2(x, edge_index, edge_x)
        x, edge_x = self.node_seq(x), self.edge_seq(edge_x)

        return F.log_softmax(x, dim=1), F.log_softmax(edge_x, dim=1)

In [18]:
# model = gnn.GCN.load_from_checkpoint('models/graph_classifier/lightning_logs/version_1/checkpoints/epoch=19-step=31999.ckpt',dataset=testing)
# model_for_onnx = gnn.GCN.load_from_checkpoint('models/graph_classifier/lightning_logs/version_1/checkpoints/epoch=19-step=31999.ckpt',dataset=testing,for_onnx=False)


In [25]:
model = GCN()
model_for_onnx = GCN()

In [26]:
def get_inputs(graph,pad_nodes=None):
    node_x = graph.x 
    edge_index = graph.edge_index 
    edge_x = graph.edge_attr 
    
    if pad_nodes: 
        node_x = F.pad(node_x,(0,0,0,pad_nodes-graph.num_nodes))
    return (node_x,edge_index,edge_x)

In [27]:
def compare_outputs(outputs_1,outputs_2):
    with torch.no_grad():
        for out0,out1 in zip(outputs_1,outputs_2):
            print(out0.shape,out1.shape)
            print( ((out0-out1)**2).sum() )

In [28]:
input_values = get_inputs(testing[1662])
input_names = ['node_x','edge_index','edge_x']
output_names = ['node_y','edge_y']
list(map(lambda t:t.shape,input_values))


[torch.Size([11, 5]), torch.Size([2, 55]), torch.Size([55, 1])]

In [30]:
org_output = model(*input_values)
new_output = model_for_onnx(*input_values)

compare_outputs(org_output,new_output)

torch.Size([11, 2]) torch.Size([11, 2])
tensor(0.0517)
torch.Size([55, 2]) torch.Size([55, 2])
tensor(0.3866)


In [31]:
org_output[0],new_output[0]

(tensor([[-0.7587, -0.6317],
         [-0.7718, -0.6202],
         [-0.7373, -0.6509],
         [-0.7238, -0.6634],
         [-0.7542, -0.6356],
         [-0.6956, -0.6907],
         [-0.6560, -0.7318],
         [-0.7134, -0.6733],
         [-0.6413, -0.7478],
         [-0.6842, -0.7022],
         [-0.6464, -0.7422]], grad_fn=<LogSoftmaxBackward>),
 tensor([[-0.7228, -0.6643],
         [-0.6952, -0.6911],
         [-0.7135, -0.6732],
         [-0.7113, -0.6753],
         [-0.7158, -0.6710],
         [-0.7563, -0.6337],
         [-0.7391, -0.6492],
         [-0.6819, -0.7045],
         [-0.7014, -0.6850],
         [-0.6978, -0.6886],
         [-0.6889, -0.6974]], grad_fn=<LogSoftmaxBackward>))

In [35]:
import torch.onnx

for param in model.parameters():
    param.requires_grad = False
with torch.no_grad():
    torch.onnx.export(model,
                    input_values,
                    "gnn-model.onnx",
                    input_names=input_names,
                    output_names=output_names,
                    opset_version=12,
                    dynamic_axes={
                        'node_x': {0: 'n_nodes'},
                        'edge_index': {1: 'n_edges'},
                        'edge_x': {0: 'n_edges'},
                        'node_y': {0: 'n_nodes'},
                        'edge_y': {0: 'n_edges'}}
                    )


RuntimeError: ONNX export failed on an operator with unrecognized namespace torch_scatter::scatter_max. If you are trying to export a custom operator, make sure you registered it with the right domain and version.

In [34]:
import onnx
onnx_model = onnx.load('gnn-model.onnx')
onnx.checker.check_model(onnx_model)


ModuleNotFoundError: No module named 'onnx'

In [72]:
import onnxruntime as ort

ort_sess = ort.InferenceSession('gnn-model.onnx')

In [80]:
g = testing[0]
print(g)
input_values = get_inputs(g,pad_nodes=20)
output_values = model(*input_values)
onnx_output = ort_sess.run(None,{name:value.numpy() for name,value in zip(input_names,input_values)})


Data(x=[7, 5], edge_index=[2, 21], edge_attr=[21, 1], y=[7], edge_y=[21])
torch.Size([20, 2]) (20, 2)
tensor(0.)
torch.Size([21, 2]) (21, 2)
tensor(2.4869e-14)


In [83]:
np.exp(onnx_output[0])

array([[0.51562   , 0.48437998],
       [0.5112753 , 0.4887247 ],
       [0.50532067, 0.4946793 ],
       [0.5060546 , 0.49394533],
       [0.5042418 , 0.49575827],
       [0.5060317 , 0.4939682 ],
       [0.4981752 , 0.5018248 ],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998],
       [0.51562   , 0.48437998]], dtype=float32)