In [None]:
# Move the kernel's working directory one level up; the relative paths used in
# the cells below (TorchModels/, OnnxModels/, samples/, Engine/) are resolved
# against it.
# NOTE(review): this is a relative change, so re-running the cell without
# restarting the kernel moves up one more level each time.
%cd "../"

In [2]:
import torch
import onnxruntime as ort
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import zoom
from networks.UNet_plusplus.UNet_plusplus import UNet_plusplus
from networks.RotCAtt.RotCAtt import RotCAtt
from networks.RotCAtt.config import get_config

### Convert

In [3]:
# Trained model
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
# NOTE(review): the loaded object is called directly below, so the file appears
# to hold a fully pickled model rather than a state_dict; recent PyTorch
# releases default to weights_only=True, which rejects such files -- confirm
# against the installed torch version.
model = torch.load("TorchModels/model2.pth")
# Switch to inference mode so the smoke-test forward pass below does not apply
# dropout or update normalization running statistics before the export.
model.eval()
# `input` shadows the builtin, but the export cell reads the example tensor
# under this name, so it is kept.
input = torch.rand(3, 1, 128, 128).to(torch.float32).cuda()
# Sanity-check forward pass; gradients are not needed for it.
with torch.no_grad():
    output = model(input)

In [None]:
# Untrained model (just architecture)
# Alternative to the trained-model cell: builds RotCAtt from its default config
# on the GPU, leaving `model` / `input` / `output` bound for the export cell.
model = RotCAtt(get_config()).cuda()
# Inference mode, matching the trained-model cell: no dropout and no updates to
# normalization running statistics during the forward pass below.
model.eval()
# `input` shadows the builtin, but the export cell reads the example tensor
# under this name, so it is kept.
input = torch.rand(3, 1, 128, 128).to(torch.float32).cuda()
# Sanity-check forward pass; gradients are not needed for it.
with torch.no_grad():
    output = model(input)

In [4]:
# convert
# Export whichever `model` is currently bound (trained or untrained cell above)
# to ONNX, using `input` as the example tensor. Axis 0 of both the graph input
# and the graph output is declared dynamic under the name "batch_size".
onnx_path = "OnnxModels/onnx_model4.onnx"
batch_axes = {
    "input": {0: "batch_size"},
    "output": {0: "batch_size"},
}
torch.onnx.export(
    model,
    input,
    onnx_path,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=batch_axes,
)

### Inference

In [None]:
# Environment check: report the version of the TensorRT Python package that the
# kernel's environment provides.
import tensorrt

trt_version = tensorrt.__version__
print(trt_version)

In [None]:
# Environment check: display the filesystem path the `tensorrt` package was
# imported from (requires the cell that imports tensorrt to have run).
tensorrt.__file__

In [None]:
# Input and preprocessing
image = np.load('samples/0001_0170.npy')
img_size = 128
x, y = image.shape  # unpacking requires a 2-D array
# Start from the raw image so `inputs` is always defined, and resize when
# EITHER side differs from the target (nearest-neighbour, order=0).
inputs = image
if x != img_size or y != img_size:
    inputs = zoom(image, (img_size / x, img_size / y), order=0)

# Add leading batch and channel axes: (H, W) -> (1, 1, H, W).
# NOTE(review): ONNX Runtime requires the array dtype to match the model's
# declared input type (the export cell used a float32 example) -- confirm the
# .npy file's dtype, since order=0 zoom does not change it.
inputs = inputs[np.newaxis, np.newaxis, :, :]

# Inference
# Providers are listed in priority order: TensorRT first, then CUDA.
providers = [
    ('TensorrtExecutionProvider', {
        'device_id': 0,                       # Select GPU to execute
        'trt_max_workspace_size': 2147483648, # Set GPU memory usage limit (2 GiB)
        'trt_fp16_enable': True,              # Enable FP16 precision for faster inference
        'trt_engine_cache_enable': True,
        'trt_engine_cache_path': 'Engine/onnx_model_sim_engine_2',
        'trt_engine_hw_compatible' : True
    }),
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
        'gpu_mem_limit': 2 * 1024 * 1024 * 1024,
        'cudnn_conv_algo_search': 'EXHAUSTIVE',
        'do_copy_in_default_stream': True,
    })
]

# NOTE(review): this loads onnx_model_sim2.onnx, not the onnx_model4.onnx file
# written by the export cell -- presumably produced by a separate step; verify.
ort_session = ort.InferenceSession("OnnxModels/onnx_model_sim2.onnx", providers=providers)
inp = {ort_session.get_inputs()[0].name: inputs}
ort_outputs = ort_session.run(None, inp)

# Arg-max over axis 1 of the first model output, then show the first batch item.
# NOTE(review): axis 1 is presumably the class/channel axis -- confirm.
out = np.argmax(ort_outputs[0], axis=1)
plt.imshow(out[0])
plt.show()