Convert model to TensorRT fails #206
Comments
I managed to convert this after getting inspiration from the ONNX PR #129 (most notably this line). It is important to build the model with the plain `Block`/`Attention` classes instead of `NestedTensorBlock`/`MemEffAttention` (see the `block_fn` argument below):

```python
import torch
import torch_tensorrt
from dinov2.models import vision_transformer as vits
from dinov2.layers import MemEffAttention  # NestedTensorBlock as Block
from dinov2.layers.attention import Attention
from dinov2.layers.block import Block
from dinov2.models.vision_transformer import DinoVisionTransformer
from functools import partial
import time
import numpy as np
import torch.backends.cudnn as cudnn

cudnn.benchmark = True


def benchmark(model, input_shape=(1, 3, 224, 224), dtype='fp32', nwarmup=50, nruns=500):
    input_data = torch.randn(input_shape)
    input_data = input_data.to("cuda")
    if dtype == 'fp16':
        input_data = input_data.half()
    print("Warm up ...")
    with torch.no_grad():
        for _ in range(nwarmup):
            features = model(input_data)
    torch.cuda.synchronize()
    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(1, nruns + 1):
            start_time = time.time()
            features = model(input_data)
            torch.cuda.synchronize()
            end_time = time.time()
            timings.append(end_time - start_time)
            if i % 100 == 0:
                print('Iteration %d/%d, ave batch time %.2f ms' % (i, nruns, np.mean(timings) * 1000))
    print("Input shape:", input_data.size())
    print("Output features size:", features.size())
    print('Average batch time: %.2f ms' % (np.mean(timings) * 1000))


class DinoV2Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ViT-S/14 built with the plain Block/Attention classes so that
        # torch.jit.trace and Torch-TensorRT can handle the graph.
        self.model = DinoVisionTransformer(
            img_size=518,
            patch_size=14,
            embed_dim=384,
            depth=12,
            num_heads=6,
            mlp_ratio=4,
            init_values=1.0,
            ffn_layer='mlp',
            block_chunks=0,
            block_fn=partial(Block, attn_class=Attention),
        )
        url = 'https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth'
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        self.model.load_state_dict(state_dict)

    def forward(self, input):
        return self.model.forward_features(input)["x_norm_patchtokens"]


dev = torch.device('cuda')
inp = torch.rand(1, 3, 224, 224, device=dev)
model = DinoV2Small().to(dev).eval()

# Sanity check against the reference torch.hub model.
model_load = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14").to(dev).eval()
model_load_out = model_load.forward_features(inp)['x_norm_patchtokens']
model_out = model(inp)
assert ((model_load_out - model_out) ** 2).mean() < 1e-5
print("Loaded model")
benchmark(model_load)

traced_model = torch.jit.trace(model, inp)
print("Traced model")
benchmark(traced_model)

trt_ts_module = torch_tensorrt.compile(
    traced_model,
    inputs=[torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float)],
    enabled_precisions={torch.float},
    truncate_long_and_double=True,
)
model_trt_out = trt_ts_module(inp)
assert ((model_load_out - model_trt_out) ** 2).mean() < 1e-3
print("compiled model")
benchmark(trt_ts_module)

torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
```

This is based on this notebook.
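For deployment, the saved `trt_torchscript_module.ts` should be loadable with plain TorchScript on the target machine, provided `torch_tensorrt` is imported first so the TensorRT runtime ops are registered. A minimal sketch (not from the original comment; only the file name above is taken from the script):

```python
# Hypothetical reload of the saved module on the target device.
# Importing torch_tensorrt registers the TensorRT execution ops that the
# serialized TorchScript module depends on.
import torch
import torch_tensorrt  # noqa: F401

trt_model = torch.jit.load("trt_torchscript_module.ts").eval()
inp = torch.rand(1, 3, 224, 224, device="cuda")
with torch.no_grad():
    patch_tokens = trt_model(inp)  # x_norm_patchtokens, (1, 256, 384) for a 224x224 input
print(patch_tokens.shape)
```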
I am trying to convert the model to a precompiled Torch-TensorRT module for inference on an embedded device. I am using this script:
But I get the following error:
It would be great if the source files could be adapted so that a TensorRT export is possible.
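Since the target is an embedded device, a half-precision build of the module compiled in the comment above may also be worth trying. A minimal sketch, assuming `traced_model` and `benchmark` from that script; this is untested, and the FP16 outputs should be re-checked against the FP32 ones:

```python
# Hypothetical FP16 build: enabled_precisions={torch.half} lets TensorRT
# select FP16 kernels, which usually matters most on embedded GPUs.
trt_fp16_module = torch_tensorrt.compile(
    traced_model,
    inputs=[torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)],
    enabled_precisions={torch.half},
    truncate_long_and_double=True,
)
benchmark(trt_fp16_module, dtype='fp16')
torch.jit.save(trt_fp16_module, "trt_torchscript_module_fp16.ts")
```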