Convert model to TensorRT fails #206

Open
janblumenkamp opened this issue Sep 6, 2023 · 1 comment
@janblumenkamp

I am trying to convert the model to a pre-compiled Torch-TensorRT TorchScript module for inference on an embedded device. I am using this script:

import torch_tensorrt
import torch

model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")

trt_ts_module = torch_tensorrt.compile(model,
    # If the inputs to the module are plain Tensors, specify them via the `inputs` argument:
    inputs = [
        torch_tensorrt.Input( # Specify input object with shape and dtype
            min_shape=[1, 3, 224, 224],
            opt_shape=[1, 3, 224, 224],
            max_shape=[1, 3, 224, 224],
            # For static size shape=[1, 3, 224, 224]
            dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
    ],
    enabled_precisions = {torch.half}, # Run with FP16
)

torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedded Torchscript

But I get the following error:

Traceback (most recent call last):
  File "[...]/torchrt_compile_dino.py", line 6, in <module>
    trt_ts_module = torch_tensorrt.compile(model,
  File "[...]/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 132, in compile
    [...]
  File "/local/scratch/jb2270/miniconda3/envs/panoptes/lib/python3.10/site-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/jb2270/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py", line 290
    def forward(self, *args, is_training=False, **kwargs):
                                                 ~~~~~~~ <--- HERE
        ret = self.forward_features(*args, **kwargs)
        if is_training:

It would be great if the source files could be adapted so that a TensorRT export is possible.
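For context, the TorchScript scripting frontend rejects any forward that mixes *args with keyword-only arguments that have defaults, whereas tracing only records one concrete call and never inspects the signature. A minimal sketch of the difference (toy module, not the DINOv2 code):

import torch

class VarArgsModule(torch.nn.Module):
    def forward(self, *args, is_training=False, **kwargs):
        return args[0] * 2

m = VarArgsModule()
x = torch.randn(1, 3)

# torch.jit.script(m) raises the NotSupportedError shown in the traceback above
traced = torch.jit.trace(m, (x,))  # tracing a concrete call works
print(traced(x).shape)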

@janblumenkamp
Author

I managed to convert the model after taking inspiration from the ONNX PR #129 (most notably this line). The key is to expose forward_features through a plain fixed-signature forward and to trace the model with torch.jit.trace before compiling it with Torch-TensorRT:

import torch
import torch_tensorrt
from dinov2.layers.attention import Attention
from dinov2.layers.block import Block
from dinov2.models.vision_transformer import DinoVisionTransformer
from functools import partial
import time
import numpy as np

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

def benchmark(model, input_shape=(1, 3, 224, 224), dtype='fp32', nwarmup=50, nruns=500):
    input_data = torch.randn(input_shape)
    input_data = input_data.to("cuda")
    if dtype=='fp16':
        input_data = input_data.half()

    print("Warm up ...")
    with torch.no_grad():
        for _ in range(nwarmup):
            features = model(input_data)
    torch.cuda.synchronize()
    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(1, nruns+1):
            start_time = time.time()
            features = model(input_data)
            torch.cuda.synchronize()
            end_time = time.time()
            timings.append(end_time - start_time)
            if i%100==0:
                print('Iteration %d/%d, ave batch time %.2f ms'%(i, nruns, np.mean(timings)*1000))

    print("Input shape:", input_data.size())
    print("Output features size:", features.size())

    print('Average batch time: %.2f ms'%(np.mean(timings)*1000))

class DinoV2Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = DinoVisionTransformer(
            img_size=518,
            patch_size=14,
            embed_dim=384,
            depth=12,
            num_heads=6,
            mlp_ratio=4,
            init_values=1.0,
            ffn_layer='mlp',
            block_chunks=0,
            # use the standard Attention class (rather than MemEffAttention) so the model can be traced
            block_fn=partial(Block, attn_class=Attention),
        )

        url = 'https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth'
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        self.model.load_state_dict(state_dict)

    def forward(self, input):
        # fixed-signature wrapper around forward_features so torch.jit.trace can handle it
        return self.model.forward_features(input)["x_norm_patchtokens"]

dev = torch.device('cuda')
inp = torch.rand(1, 3, 224, 224, device=dev)
model = DinoV2Small().to(dev).eval()
model_load = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14").to(dev).eval()

model_load_out = model_load.forward_features(inp)['x_norm_patchtokens']
model_out = model(inp)

assert ((model_load_out - model_out)**2).mean() < 1e-5

print("Loaded model")
benchmark(model_load)

traced_model = torch.jit.trace(model, inp)
print("Traced model")
benchmark(traced_model)

trt_ts_module = torch_tensorrt.compile(traced_model,
    inputs = [torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float)],
    enabled_precisions = {torch.float},
    truncate_long_and_double=True,
)

model_trt_out = trt_ts_module(inp)

assert ((model_load_out - model_trt_out)**2).mean() < 1e-3

print("compiled model")
benchmark(trt_ts_module)

torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")

This is based on this notebook.
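For completeness, the saved module can be reloaded with torch.jit.load for inference on the target device, and the same compile call should also accept FP16 precision. The FP16 flags below mirror the first post but are an untested assumption, not something I verified on the embedded device:

import torch
import torch_tensorrt  # must stay imported so the embedded TensorRT ops are registered

# Reload the compiled TorchScript module on the deployment device
trt_module = torch.jit.load("trt_torchscript_module.ts").to("cuda").eval()
out = trt_module(torch.rand(1, 3, 224, 224, device="cuda"))

# Hypothetical FP16 variant of the compile call above (untested sketch)
trt_fp16_module = torch_tensorrt.compile(traced_model,
    inputs = [torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float)],
    enabled_precisions = {torch.float, torch.half},  # let TensorRT choose FP16 kernels
    truncate_long_and_double=True,
)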
