In [None]:
# Switch to the project directory: the relative paths used further down
# ('src', 'models/...', 'onnx/...') are resolved against the working directory.
# NOTE: this cell is not idempotent -- every re-run moves one more level up,
# so execute it exactly once per kernel session.
%cd ..
# After this cell the working directory should be the project root (../pdi).

In [None]:
import sys
import os
# Absolute location of the `src` directory, resolved against the current
# working directory.
module_path = os.path.abspath('src')

# Register it on the import path only when it is absent, so re-running this
# cell never appends a duplicate entry.
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
from pdi.constants import (
    PARTICLES_DICT,
    TARGET_CODES
)
from pdi.data.constants import GROUP_ID_KEY

In [None]:
import torch
import torch.nn as nn
from pdi.data.preparation import FeatureSetPreparation
from pdi.models import AttentionModel
from pdi.data.types import Split
import onnx

# Everything in this notebook runs on the CPU.
device = torch.device("cpu")


# Build only the training dataloader; the one-element unpacking fails loudly
# if anything other than exactly one loader is returned.
# NOTE(review): the positional arguments 1 and 0 are presumably batch size and
# worker count -- confirm against FeatureSetPreparation.prepare_dataloaders.
data_preparation = FeatureSetPreparation()
(train_loader, ) = data_preparation.prepare_dataloaders(1, 0, [Split.TRAIN])

# Pull the first batch: the model input, an ignored middle element and a
# mapping from which the group id is looked up.
first_batch = next(iter(train_loader))
input_data, _, data_dict = first_batch
gid = data_dict.get(GROUP_ID_KEY)
print(input_data)

# Example input placed on the export device; the export and test cells below
# reuse it.
dummy_input = input_data.to(device)
print(dummy_input.shape)

In [None]:
# Names given to the input and output tensors of the exported ONNX graphs.
input_name, output_name = 'input', 'output'

In [None]:
from pdi.data.config import MODEL_NAME

# Export one ONNX file per target particle into onnx/Proposed/<MODEL_NAME>/.
model_dir = MODEL_NAME
os.makedirs(f"onnx/Proposed/{model_dir}", exist_ok=True)
# Only this subset of codes is exported here; iterate over TARGET_CODES
# instead to export every target.
for target_code in [211, 2212, 321]:
    # A "-" in the code (negative values) is replaced by "0" in the file name.
    name_code = str(target_code).replace("-", "0")
    load_path = f"models/Proposed/{model_dir}/{PARTICLES_DICT[target_code]}.pt"
    export_path = f"onnx/Proposed/{model_dir}/attention_model_{name_code}.onnx"

    # map_location remaps every stored tensor onto `device`, so a checkpoint
    # written from a GPU still loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    saved_model = torch.load(load_path, map_location=device)
    model = AttentionModel(*saved_model["model_args"]).to(device)
    model.thres = saved_model["model_thres"]
    model.load_state_dict(saved_model["state_dict"])

    # Append a sigmoid to the model output and switch the whole stack to
    # inference mode explicitly, so train-only behaviour (e.g. dropout) is off
    # during tracing regardless of the exporter's defaults.
    model_with_sigmoid = nn.Sequential(model, nn.Sigmoid()).eval()

    # Axis 0 of the input is exported as a dynamic (batch) dimension.
    # NOTE(review): the output axis is not declared dynamic -- confirm whether
    # downstream consumers need it declared as well.
    torch.onnx.export(model_with_sigmoid, dummy_input, export_path,
                      export_params=True,
                      opset_version=14,
                      do_constant_folding=True,
                      input_names=[input_name],
                      output_names=[output_name],
                      dynamic_axes={input_name: {0: 'batch size'}})

    # Reload the written file and run the ONNX structural checker on it.
    onnx_model = onnx.load(export_path)
    onnx.checker.check_model(onnx_model)

#### Test: run the last exported model with ONNX Runtime

In [None]:
import onnxruntime as rt
# `export_path` is the variable left over from the export loop, i.e. the path
# of the last model that was exported.
print(export_path)
# Name the execution provider explicitly: some ONNX Runtime releases refuse to
# create a session without `providers` when the installed build offers more
# than the CPU provider, and the export above was done on the CPU anyway.
sess = rt.InferenceSession(export_path, providers=["CPUExecutionProvider"])

In [None]:
# A seeded local generator makes this smoke test reproducible without touching
# the global RNG state.
rng = torch.Generator().manual_seed(0)
# Two random rows with the same feature width as the export-time example input,
# to exercise the dynamic batch axis of the exported graph.
test_input = torch.rand(2, dummy_input.shape[1], generator=rng)
print(test_input)

# Read the input name from the session into its own variable, so the
# `input_name` export setting defined earlier is not overwritten.
sess_input_name = sess.get_inputs()[0].name

# Passing None as the first argument requests all model outputs.
res = sess.run(None, {sess_input_name: test_input.cpu().detach().numpy()})
print(res)