In [None]:
# To compile the CLIP text model, the input token ids had to be cast to float before calling argmax().
# Otherwise ONNX parsing fails with "This version of TensorRT does not support INT32 ArgMin/ArgMax input data."

In [53]:
import os
from typing import Dict
import json
import itertools
import time

import clip
import onnx
from onnxruntime import InferenceSession
import torch
import numpy as np
import tensorrt as trt

from tensorrt_inference.backend import (
    build_engine, save_engine, load_engine
)

In [None]:
# Load the CLIP "ViT-B/32" checkpoint; the second value returned by clip.load is discarded.
model, _ = clip.load("ViT-B/32")
# Switch to evaluation mode and move the model to the first CUDA device.
model = model.eval()
model.to("cuda:0")

## Inputs

In [4]:
# Dummy inputs on cuda:0, each stacked into a batch of size 1:
# one random 3x224x224 image and one sequence of 77 random int32 ids drawn from [0, 100).
image_tensors = torch.stack([torch.rand((3, 224, 224))]).to("cuda:0")
text_tensors = torch.stack([torch.randint(0, 100, (77,), dtype=torch.int32)]).to("cuda:0")

In [5]:
# Reference PyTorch outputs; the visual tower is fed the image batch cast to fp16.
image_embeddings = model.visual(image_tensors.half())
# NOTE(review): the model is called with the text batch only. Upstream CLIP's forward
# takes (image, text), so this presumably relies on a locally modified CLIP — confirm.
text_embeddings = model(text_tensors)

## ONNX export

In [6]:
def convert_to_onnx(
    model_pytorch, output_path: str, inputs_pytorch: Dict[str, torch.Tensor], quantization: bool
) -> None:
    """Export a PyTorch model to an ONNX file.

    Args:
        model_pytorch: model handed to ``torch.onnx.export``.
        output_path: path of the ONNX file to write.
        inputs_pytorch: sample inputs keyed by the name each input receives in
            the ONNX graph; the values are passed positionally, in dict order.
        quantization: when True, ``TensorQuantizer.use_fb_fake_quant`` from
            pytorch-quantization is switched on for the duration of the export
            and always switched off again afterwards, even if the export fails.

    Raises:
        ImportError: if ``quantization`` is True and pytorch-quantization
            cannot be imported.
    """
    if quantization:
        try:
            from pytorch_quantization.nn import TensorQuantizer
        except ImportError as err:
            raise ImportError(
                "It seems that pytorch-quantization is not yet installed. "
                "It is required when you enable the quantization flag and use CUDA device. "
                "Please find installation instructions on "
                "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization or use:\n"
                "pip3 install git+ssh://git@github.com/NVIDIA/TensorRT#egg=pytorch-quantization\\&"
                "subdirectory=tools/pytorch-quantization/"
            ) from err

        TensorQuantizer.use_fb_fake_quant = True

    # Every input gets axes 0 and 1 marked dynamic, whatever its rank.
    # NOTE(review): for image inputs axis 1 is not a sequence axis, yet it is
    # still labelled "sequence" in the exported graph — confirm this is intended.
    dynamic_axis = {name: {0: "batch_size", 1: "sequence"} for name in inputs_pytorch}
    dynamic_axis["output"] = {0: "batch_size"}
    try:
        with torch.no_grad():
            torch.onnx.export(
                model_pytorch,
                args=tuple(inputs_pytorch.values()),
                f=output_path,
                opset_version=12,
                do_constant_folding=True,
                input_names=list(inputs_pytorch.keys()),
                output_names=["output"],
                dynamic_axes=dynamic_axis,
                training=torch.onnx.TrainingMode.EVAL,
                verbose=False,
            )
    finally:
        # Restore the class-level flag even when the export raises, so a failed
        # export does not leave fake-quant export mode enabled for later code.
        if quantization:
            TensorQuantizer.use_fb_fake_quant = False

In [7]:
# Output locations for the two exported ONNX graphs (visual tower and text model).
# NOTE(review): absolute, user-specific paths — adjust before running on another machine.
clip_vit_onnx_path = "/home/g.racic/clip_vit_onnx.onnx"
clip_transformer_onnx_path = "/home/g.racic/clip_transformer_onnx.onnx"

In [8]:
# Sample input for the visual export (fp16); convert_to_onnx uses the key "image" as the ONNX input name.
sample_image = {"image": image_tensors.half()}

In [9]:
# Export the visual tower; the trailing True is convert_to_onnx's `quantization` flag.
convert_to_onnx(model.visual, clip_vit_onnx_path, sample_image, True)

E0118 07:45:24.499392 140626572773184 amp_wrapper.py:31] AMP is not avaialble.


In [10]:
# Sample input for the text export; convert_to_onnx uses the key "text" as the ONNX input name.
sample_text = {"text": text_tensors}

In [11]:
# Export the whole model with the text batch as its only input; True is the `quantization` flag.
convert_to_onnx(model, clip_transformer_onnx_path, sample_text, True)

  "If indices include negative values, the exported graph will produce incorrect results.")


## ONNX inference

In [12]:
# Open the exported visual graph with ONNX Runtime and print the declared shape of its first input.
sess_vit = InferenceSession(clip_vit_onnx_path)
print("The model expects input shape: ", sess_vit.get_inputs()[0].shape)

The model expects input shape:  ['batch_size', 'sequence', 224, 224]


In [None]:
# Run the visual graph on the fp16 sample image, copied to host memory as a NumPy array.
sess_vit.run(None, {"image": image_tensors.half().cpu().numpy()})
# Previously measured: 150 ms ± 456 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)

In [13]:
# Open the exported text graph with ONNX Runtime and print the declared shape of its first input.
sess_transformer = InferenceSession(clip_transformer_onnx_path)
print("The model expects input shape: ", sess_transformer.get_inputs()[0].shape)

The model expects input shape:  ['batch_size', 'sequence']


In [None]:
# Run the text graph on the sample token ids, copied to host memory as a NumPy array.
sess_transformer.run(None, {"text": text_tensors.cpu().numpy()})
# Previously measured: 82.5 ms ± 515 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)

In [14]:
# Load the exported text graph with onnx for inspection.
# NOTE(review): model_proto is not referenced again in the visible cells — confirm it is still needed.
model_proto = onnx.load(clip_transformer_onnx_path)

## TRT compilation

In [80]:
# TensorRT logger at VERBOSE severity, and a runtime created with that logger.
trt_logger = trt.Logger(trt.Logger.VERBOSE)
runtime = trt.Runtime(trt_logger)

In [81]:
# (min, optimal, max) input shapes later passed to build_engine for the TRT optimizer.
# All three entries are identical in each list, matching the sample inputs (batch of 1).
text_tensor_shapes = [(1, 77), (1, 77), (1, 77)]
image_tensor_shapes = [(1, 3, 224, 224), (1, 3, 224, 224), (1, 3, 224, 224)]

In [32]:
# Build the TensorRT engine for the text model from its ONNX export, with fp16 on and int8 off.
text_engine = build_engine(
    runtime=runtime,
    onnx_file_path=clip_transformer_onnx_path,
    logger=trt_logger,
    min_shape=text_tensor_shapes[0],
    optimal_shape=text_tensor_shapes[1],
    max_shape=text_tensor_shapes[2],
    # 10000 * 1024 * 1024 = 10,485,760,000; presumably bytes (10000 MiB) — confirm against build_engine.
    workspace_size=10000 * 1024 * 1024,
    fp16=True,
    int8=False
)

In [33]:
# File the text engine is saved to and loaded back from.
# NOTE(review): absolute, user-specific path — adjust before running on another machine.
trt_text_clip_path = "/home/g.racic/trt_text_clip"

In [34]:
# Save the text engine. The build_engine result is bound to `text_engine`; the name
# `engine` is never defined in this notebook, so passing it fails on a fresh kernel
# (or silently saves a stale object left over in the kernel).
save_engine(engine=text_engine, engine_file_path=trt_text_clip_path)

In [35]:
# Load the text engine back from trt_text_clip_path; the result is what the benchmark cell calls.
trt_text_model = load_engine(
    runtime=runtime, engine_file_path=trt_text_clip_path
)

In [83]:
# Build the TensorRT engine for the visual tower from its ONNX export, with fp16 on and int8 off.
image_engine = build_engine(
    runtime=runtime,
    onnx_file_path=clip_vit_onnx_path,
    logger=trt_logger,
    min_shape=image_tensor_shapes[0],
    optimal_shape=image_tensor_shapes[1],
    max_shape=image_tensor_shapes[2],
    # 10000 * 1024 * 1024 = 10,485,760,000; presumably bytes (10000 MiB) — confirm against build_engine.
    workspace_size=10000 * 1024 * 1024,
    fp16=True,
    int8=False
)

In [84]:
# File the image engine is saved to and loaded back from.
# NOTE(review): absolute, user-specific path — adjust before running on another machine.
trt_image_clip_path = "/home/g.racic/trt_image_clip"

In [85]:
# Save the image engine. The build_engine result is bound to `image_engine`; the name
# `engine` is never defined in this notebook, so passing it fails on a fresh kernel
# (or silently saves a stale object left over in the kernel).
save_engine(engine=image_engine, engine_file_path=trt_image_clip_path)

In [86]:
# Load the image engine back from trt_image_clip_path; the result is what the benchmark cell calls.
trt_image_model = load_engine(
    runtime=runtime, engine_file_path=trt_image_clip_path
)

## Benchmark

In [47]:
def benchmark(model_fn, input_data, batch_size, nwarmup=50, nruns=1000):
    """Time repeated calls of ``model_fn`` and print latency and throughput.

    Args:
        model_fn: callable invoked with one element of ``input_data`` per call.
        input_data: iterable of inputs; it is cycled endlessly, so it must not
            be empty (``next`` on an empty cycle raises ``StopIteration``).
        batch_size: number of examples per call; only used to turn the mean
            call time into an examples-per-second figure.
        nwarmup: number of untimed calls made before measuring.
        nruns: number of timed calls.

    Returns:
        None. The running average (every 100 iterations) and the final
        throughput are printed.
    """
    # Synchronise only when CUDA is usable, so the helper also runs on CPU-only
    # hosts, where torch.cuda.synchronize() would raise.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    _data = itertools.cycle(input_data)
    print("Warm up ...")
    with torch.no_grad():
        for _ in range(nwarmup):
            model_fn(next(_data))
    # Make sure the warm-up work has finished before the first timed call.
    sync()
    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(1, nruns + 1):
            # perf_counter is monotonic and high resolution; time.time() is a
            # wall clock that can jump and may be too coarse for ms-scale calls.
            start_time = time.perf_counter()
            model_fn(next(_data))
            # Wait for queued GPU work so the interval covers the whole call.
            sync()
            timings.append(time.perf_counter() - start_time)
            if i % 100 == 0:
                print('Iteration %d/%d, avg batch time %.2f ms' % (i, nruns, np.mean(timings) * 1000))

    print('Average throughput: %.2f example/second' % (batch_size / np.mean(timings)))

In [54]:
# TensorRT image engine, batch size 1; the input is a host NumPy array keyed by the ONNX input name.
# NOTE(review): the PyTorch baseline is fed a tensor already on the GPU, so the two timings
# presumably do not cover the same host/device transfer work — confirm before comparing.
benchmark(trt_image_model, [{"image": image_tensors.half().cpu().numpy()}], 1)

Warm up ...
Start timing ...
Iteration 100/1000, avg batch time 1.52 ms
Iteration 200/1000, avg batch time 1.49 ms
Iteration 300/1000, avg batch time 1.48 ms
Iteration 400/1000, avg batch time 1.47 ms
Iteration 500/1000, avg batch time 1.47 ms
Iteration 600/1000, avg batch time 1.47 ms
Iteration 700/1000, avg batch time 1.46 ms
Iteration 800/1000, avg batch time 1.46 ms
Iteration 900/1000, avg batch time 1.46 ms
Iteration 1000/1000, avg batch time 1.46 ms
Average throughput: 684.90 example/second


In [55]:
# PyTorch baseline for the visual tower: fp16 image batch, batch size 1.
benchmark(model.visual, [image_tensors.half()], 1)

Warm up ...
Start timing ...
Iteration 100/1000, avg batch time 7.48 ms
Iteration 200/1000, avg batch time 7.49 ms
Iteration 300/1000, avg batch time 7.49 ms
Iteration 400/1000, avg batch time 7.49 ms
Iteration 500/1000, avg batch time 7.49 ms
Iteration 600/1000, avg batch time 7.49 ms
Iteration 700/1000, avg batch time 7.49 ms
Iteration 800/1000, avg batch time 7.49 ms
Iteration 900/1000, avg batch time 7.49 ms
Iteration 1000/1000, avg batch time 7.49 ms
Average throughput: 133.56 example/second


In [56]:
# TensorRT text engine, batch size 1; the input is a host NumPy array keyed by the ONNX input name.
# NOTE(review): the PyTorch baseline is fed a tensor already on the GPU, so the two timings
# presumably do not cover the same host/device transfer work — confirm before comparing.
benchmark(trt_text_model, [{"text": text_tensors.cpu().numpy()}], 1)

Warm up ...
Start timing ...
Iteration 100/1000, avg batch time 1.47 ms
Iteration 200/1000, avg batch time 1.42 ms
Iteration 300/1000, avg batch time 1.40 ms
Iteration 400/1000, avg batch time 1.39 ms
Iteration 500/1000, avg batch time 1.39 ms
Iteration 600/1000, avg batch time 1.38 ms
Iteration 700/1000, avg batch time 1.38 ms
Iteration 800/1000, avg batch time 1.37 ms
Iteration 900/1000, avg batch time 1.37 ms
Iteration 1000/1000, avg batch time 1.37 ms
Average throughput: 729.70 example/second


In [57]:
# PyTorch baseline for the text path: the model is called on the token-id batch, batch size 1.
benchmark(model, [text_tensors], 1)

Warm up ...
Start timing ...
Iteration 100/1000, avg batch time 7.39 ms
Iteration 200/1000, avg batch time 7.39 ms
Iteration 300/1000, avg batch time 7.40 ms
Iteration 400/1000, avg batch time 7.40 ms
Iteration 500/1000, avg batch time 7.40 ms
Iteration 600/1000, avg batch time 7.41 ms
Iteration 700/1000, avg batch time 7.41 ms
Iteration 800/1000, avg batch time 7.41 ms
Iteration 900/1000, avg batch time 7.41 ms
Iteration 1000/1000, avg batch time 7.41 ms
Average throughput: 134.95 example/second
