In [1]:
import torch
import numpy as np
from utils.dataset import SimpleIterDataset
import os

In [2]:
data_config_name = 'data/ak8_points_pf_sv_mass_regression.yaml'
import networks.particle_net_pf_sv_mass_regression as network_module
export_onnx = "../sonic-models/models/particlenet_AK8_MassRegression/1/model.onnx"

In [3]:
model_state_dict = data_config_name.replace('.yaml','.pt')
jit_model_save = data_config_name.replace('.yaml','_ragged_gpu_jit.pt')
onnx_model = data_config_name.replace('.yaml','.onnx')

In [4]:
data_config = SimpleIterDataset([], data_config_name, for_training=False).config

In [5]:
data_config

<utils.data.config.DataConfig at 0x10a7f03a0>

In [6]:
model, model_info = network_module.get_model(data_config, for_inference=True)

In [7]:
model = torch.jit.script(model)

In [8]:
model_info

{'input_names': ['pf_points',
  'pf_features',
  'pf_mask',
  'sv_points',
  'sv_features',
  'sv_mask',
  'batch_shapes_pf_points',
  'batch_shapes_pf_features',
  'batch_shapes_pf_mask',
  'batch_shapes_sv_points',
  'batch_shapes_sv_features',
  'batch_shapes_sv_mask'],
 'input_shapes': {'pf_points': (200,),
  'pf_features': (2500,),
  'pf_mask': (100,),
  'sv_points': (20,),
  'sv_features': (110,),
  'sv_mask': (10,),
  'batch_shapes_pf_points': (1, 2),
  'batch_shapes_pf_features': (1, 2),
  'batch_shapes_pf_mask': (1, 2),
  'batch_shapes_sv_points': (1, 2),
  'batch_shapes_sv_features': (1, 2),
  'batch_shapes_sv_mask': (1, 2)},
 'output_names': ['output'],
 'dynamic_axes': {'pf_points': {0: 'n_pf'},
  'pf_features': {0: 'n_pf'},
  'pf_mask': {0: 'n_pf'},
  'sv_points': {0: 'n_sv'},
  'sv_features': {0: 'n_sv'},
  'sv_mask': {0: 'n_sv'},
  'batch_shapes_pf_points': {0: 'N'},
  'batch_shapes_pf_features': {0: 'N'},
  'batch_shapes_pf_mask': {0: 'N'},
  'batch_shapes_sv_points': {

In [9]:
model.to('cuda')

RecursiveScriptModule(
  original_name=ParticleNetTagger
  (pf_conv): RecursiveScriptModule(
    original_name=FeatureConv
    (conv): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=BatchNorm1d)
      (1): RecursiveScriptModule(original_name=Conv1d)
      (2): RecursiveScriptModule(original_name=BatchNorm1d)
      (3): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (sv_conv): RecursiveScriptModule(
    original_name=FeatureConv
    (conv): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=BatchNorm1d)
      (1): RecursiveScriptModule(original_name=Conv1d)
      (2): RecursiveScriptModule(original_name=BatchNorm1d)
      (3): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (pn): RecursiveScriptModule(
    original_name=ParticleNet
    (edge_convs): RecursiveScriptModule(
      original_name=ModuleList
      (0): RecursiveScriptModule(
        original_name=EdgeConvBlock


In [10]:
model.load_state_dict(torch.load(model_state_dict, map_location=torch.device('cuda')))

In [11]:
model.eval()

RecursiveScriptModule(
  original_name=ParticleNetTagger
  (pf_conv): RecursiveScriptModule(
    original_name=FeatureConv
    (conv): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=BatchNorm1d)
      (1): RecursiveScriptModule(original_name=Conv1d)
      (2): RecursiveScriptModule(original_name=BatchNorm1d)
      (3): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (sv_conv): RecursiveScriptModule(
    original_name=FeatureConv
    (conv): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=BatchNorm1d)
      (1): RecursiveScriptModule(original_name=Conv1d)
      (2): RecursiveScriptModule(original_name=BatchNorm1d)
      (3): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (pn): RecursiveScriptModule(
    original_name=ParticleNet
    (edge_convs): RecursiveScriptModule(
      original_name=ModuleList
      (0): RecursiveScriptModule(
        original_name=EdgeConvBlock


In [12]:
print(model)

RecursiveScriptModule(
  original_name=ParticleNetTagger
  (pf_conv): RecursiveScriptModule(
    original_name=FeatureConv
    (conv): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=BatchNorm1d)
      (1): RecursiveScriptModule(original_name=Conv1d)
      (2): RecursiveScriptModule(original_name=BatchNorm1d)
      (3): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (sv_conv): RecursiveScriptModule(
    original_name=FeatureConv
    (conv): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=BatchNorm1d)
      (1): RecursiveScriptModule(original_name=Conv1d)
      (2): RecursiveScriptModule(original_name=BatchNorm1d)
      (3): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (pn): RecursiveScriptModule(
    original_name=ParticleNet
    (edge_convs): RecursiveScriptModule(
      original_name=ModuleList
      (0): RecursiveScriptModule(
        original_name=EdgeConvBlock


In [13]:
from train import onnx

In [14]:
model = model.cpu()

In [15]:
model.eval()

RecursiveScriptModule(
  original_name=ParticleNetTagger
  (pf_conv): RecursiveScriptModule(
    original_name=FeatureConv
    (conv): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=BatchNorm1d)
      (1): RecursiveScriptModule(original_name=Conv1d)
      (2): RecursiveScriptModule(original_name=BatchNorm1d)
      (3): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (sv_conv): RecursiveScriptModule(
    original_name=FeatureConv
    (conv): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=BatchNorm1d)
      (1): RecursiveScriptModule(original_name=Conv1d)
      (2): RecursiveScriptModule(original_name=BatchNorm1d)
      (3): RecursiveScriptModule(original_name=ReLU)
    )
  )
  (pn): RecursiveScriptModule(
    original_name=ParticleNet
    (edge_convs): RecursiveScriptModule(
      original_name=ModuleList
      (0): RecursiveScriptModule(
        original_name=EdgeConvBlock


In [16]:
    inputs = tuple(
        torch.ones(model_info['input_shapes'][k], dtype=torch.float32) if ('batch_shapes_' not in k) else (torch.tensor([[len(data_config.input_dicts[k.replace('batch_shapes_', '')]), data_config.input_length[k.replace('batch_shapes_', '')]]], dtype=torch.int32)) for k in model_info['input_names'])

In [17]:
inputs

(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1.]),
 tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 

In [18]:
print (model_info['input_names'])

['pf_points', 'pf_features', 'pf_mask', 'sv_points', 'sv_features', 'sv_mask', 'batch_shapes_pf_points', 'batch_shapes_pf_features', 'batch_shapes_pf_mask', 'batch_shapes_sv_points', 'batch_shapes_sv_features', 'batch_shapes_sv_mask']


In [19]:
torch.onnx.export(model, inputs, export_onnx,
                  input_names=model_info['input_names'],
                  output_names=model_info['output_names'],
                  dynamic_axes=model_info.get('dynamic_axes', None),
                  opset_version=13)



In [20]:
model(*inputs)

tensor([[1.]], grad_fn=<SoftmaxBackward0>)