In [None]:
import sys, os, json

sys.path.append(os.environ["BUILD_WORKSPACE_DIRECTORY"])
from loguru import logger as LOGGER
import torch
import tensorrt as trt
import onnx
import onnxsim
from pathlib import Path

MODEL_DIR = os.environ.get("DOORS_MODEL_DIR", "/home/gabriel/models/doors")  # override per machine via env var
from lib.ml.inference.backends.trt import Backend

## Setup

In [None]:
# TorchScript checkpoint to convert. The .onnx and .engine artifacts produced by
# the following cells are written next to it (same stem, different suffix).
MODEL_NAME = "2022-10-26_americold_modesto_0011_cha"
ts_model_path = Path(MODEL_DIR) / MODEL_NAME / f"{MODEL_NAME}.pt"

# The TensorRT cells need a CUDA GPU; the CPU fallback only matters for the export step.
# NOTE(review): the example input (and the model, later) are cast to fp16 — on the CPU
# fallback not every op may support half precision; confirm if CPU export is needed.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Example input used for ONNX tracing and to size the TensorRT optimization profile:
# BATCH_SIZE becomes the profile's maximum batch.
BATCH_SIZE = 16
IMG_SIZE = 224  # assumes square 3-channel images — TODO confirm against the model config
example_input = (
    torch.randn(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE, requires_grad=False).to(device).half()
)

## Convert TorchScript to ONNX

In [None]:
# Export the TorchScript model to ONNX next to the .pt file, then check and simplify
# it in place. `model_config` is requested from the TorchScript archive's extra files
# (filled into `extra_files` by torch.jit.load) but is not otherwise used here.
import inspect

onnx_model_path = ts_model_path.with_suffix(".onnx")
extra_files = {"model_config": ""}
# map_location keeps the weights on the same device as `example_input`; without it the
# model stays on whatever device it was saved from and the forward pass can mismatch.
ts_model = (
    torch.jit.load(ts_model_path, map_location=device, _extra_files=extra_files)
    .eval()
    .half()
)
with torch.no_grad():
    example_output = ts_model(example_input)

output_names = ["output0"]
dynamic = {
    # Batch, height and width of the input are exported as symbolic dimensions.
    "images": {0: "batch", 2: "height", 3: "width"},
    # NOTE(review): output axis 1 is also marked dynamic and named "anchors"; this looks
    # inherited from a YOLO export script — confirm it matches this model's output.
    "output0": {0: "batch", 1: "anchors"},
}

# `example_outputs` was required for ScriptModule export in older PyTorch releases and
# was later removed from torch.onnx.export, so only pass it when it is accepted.
export_kwargs = {}
if "example_outputs" in inspect.signature(torch.onnx.export).parameters:
    export_kwargs["example_outputs"] = example_output

with torch.no_grad():
    torch.onnx.export(
        ts_model,
        example_input,
        str(onnx_model_path),
        verbose=False,
        opset_version=12,
        do_constant_folding=True,
        input_names=["images"],
        output_names=output_names,
        dynamic_axes=dynamic,
        **export_kwargs,
    )

# Validate the exported graph, then overwrite the file with the simplified model.
onnx_model = onnx.load(str(onnx_model_path))
onnx.checker.check_model(onnx_model)

onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
onnx.save(onnx_model, str(onnx_model_path))

## Convert ONNX to a TensorRT engine

In [None]:
# Build a TensorRT engine from the simplified ONNX file and write it next to the .pt
# checkpoint. The optimization profile is sized from `example_input`: batch 1 (min),
# half the example batch (opt), the full example batch (max); the non-batch
# dimensions are fixed to the example's shape.
workspace = 4  # builder workspace limit, in GiB
workspace_bytes = workspace << 30  # GiB -> bytes (what `workspace * 1 << 30` evaluated to)
engine_model_path = ts_model_path.with_suffix(".engine")
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
config = builder.create_builder_config()
if hasattr(config, "set_memory_pool_limit"):
    # Newer TensorRT deprecates/removes max_workspace_size in favour of memory pools.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
else:
    config.max_workspace_size = workspace_bytes
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx_model_path)):
    # Surface the parser's own diagnostics; otherwise the failure reason is lost.
    for err_idx in range(parser.num_errors):
        LOGGER.error(f"ONNX parser error: {parser.get_error(err_idx)}")
    raise RuntimeError(f"failed to load ONNX file: {onnx_model_path}")
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
    LOGGER.info(f'TensorRT input "{inp.name}" with shape{inp.shape} {inp.dtype}')
for out in outputs:
    LOGGER.info(f'TensorRT output "{out.name}" with shape{out.shape} {out.dtype}')

batch_size = int(example_input.shape[0])
sample_shape = tuple(int(d) for d in example_input.shape[1:])
if batch_size <= 1:
    # With batch 1 the profile's min/opt/max batches all collapse to 1.
    LOGGER.warning(
        "TensorRT WARNING: example_input has batch size <= 1, so the engine will only "
        "accept batch 1; enlarge the example batch to allow larger batches"
    )
profile = builder.create_optimization_profile()
for inp in inputs:
    profile.set_shape(
        inp.name,
        (1, *sample_shape),  # min
        (max(1, batch_size // 2), *sample_shape),  # opt
        (batch_size, *sample_shape),  # max
    )
config.add_optimization_profile(profile)
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

if hasattr(builder, "build_serialized_network"):
    serialized_engine = builder.build_serialized_network(network, config)
else:
    # Older TensorRT releases only offer build_engine.
    engine = builder.build_engine(network, config)
    serialized_engine = None if engine is None else engine.serialize()
# The build calls return None instead of raising when the build fails.
if serialized_engine is None:
    raise RuntimeError(f"failed to build TensorRT engine from {onnx_model_path}")
with open(engine_model_path, "wb") as t:
    t.write(serialized_engine)

In [None]:
# Load the engine written above through the project's TensorRT Backend wrapper
# (lib.ml.inference.backends.trt). NOTE(review): Backend is not shown here — presumably
# it deserializes the .engine file for use on `device`; confirm in that module.
b = Backend(engine_model_path, device)

In [None]:
# Display the bindings exposed by the Backend instance. The attribute is defined in
# lib.ml.inference.backends.trt, not in this notebook — presumably it lists the engine's
# input/output tensors, as a quick sanity check of the conversion.
b.bindings