# JETARM - Build TensorRT model

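Use this notebook to build the TensorRT model. It loads the trained ResNet-18 classifier, converts it with torch2trt, and saves the optimized model.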

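Initialize the PyTorch model. It is a ResNet-18 whose final fully connected layer is replaced by a 4-output layer, then moved to the GPU and put in evaluation mode.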

In [1]:
import torch
import torchvision

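# ResNet-18 architecture without ImageNet weights; the trained weights are loaded below
model = torchvision.models.resnet18(pretrained=False)
# Replace the 1000-class ImageNet head with a 4-output layer (512 = ResNet-18 feature size)
model.fc = torch.nn.Linear(512, 4)
# Move to the GPU and switch to inference mode (BatchNorm uses its running statistics)
model = model.cuda().eval()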
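ResNet-18's pooled feature vector has 512 entries, which is why the replacement layer is ``Linear(512, 4)``. As a small worked example, a fully connected layer has one weight per input/output pair plus one bias per output. The helper ``linear_param_count`` below is hypothetical and exists only for this illustration.

```python
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of trainable parameters in a fully connected (Linear) layer."""
    weights = in_features * out_features      # weight matrix: out_features x in_features
    biases = out_features if bias else 0      # one bias term per output
    return weights + biases


# Replacement head from the cell above: 512 * 4 + 4
print(linear_param_count(512, 4))
# Original 1000-class ImageNet head, for comparison: 512 * 1000 + 1000
print(linear_param_count(512, 1000))
```

The new head is tiny compared with the original one; almost all of the model's parameters are in the convolutional backbone.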

Load the trained weights from the ``model_resnet18.pth`` file.

In [2]:
model.load_state_dict(torch.load('model_resnet18.pth'))

<All keys matched successfully>
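``load_state_dict`` returns the parameter names that did not line up (missing and unexpected keys), and PyTorch prints ``<All keys matched successfully>`` when both lists are empty. The pure-Python sketch below illustrates that name check, assuming plain dictionaries stand in for the two state dicts (``None`` values replace the tensors). ``compare_keys`` is a hypothetical helper; the real method also checks tensor shapes and, with the default ``strict=True``, raises an error on any mismatch.

```python
def compare_keys(model_state: dict, checkpoint_state: dict) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected) parameter names, sorted for stable output.

    missing:    names the model expects but the checkpoint lacks
    unexpected: names in the checkpoint that the model does not have
    """
    model_keys = set(model_state)
    ckpt_keys = set(checkpoint_state)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    return missing, unexpected


# Toy example: a checkpoint saved with a 'module.' prefix (as nn.DataParallel adds)
model_state = {"fc.weight": None, "fc.bias": None}
checkpoint_state = {"module.fc.weight": None, "module.fc.bias": None}
missing, unexpected = compare_keys(model_state, checkpoint_state)
print(missing)     # ['fc.bias', 'fc.weight']
print(unexpected)  # ['module.fc.bias', 'module.fc.weight']
```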

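Create a handle for the GPU device and make sure the model is on it. The model was already moved to the GPU with ``.cuda()`` when it was initialized, so this step only makes the device explicit.

In [3]:
device = torch.device('cuda')
model = model.to(device)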

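Convert and optimize the model with torch2trt for faster inference with TensorRT. torch2trt builds the engine from the example input ``data``, so by default the engine is built for inputs of that shape (one 3×224×224 image), and ``fp16_mode=True`` lets TensorRT use half-precision (FP16) kernels. This optimization can take a couple of minutes to complete.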

In [4]:
from torch2trt import torch2trt

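# Dummy input that sets the engine's input shape: batch of 1, 3 channels (RGB), 224x224 pixels
data = torch.zeros((1, 3, 224, 224)).cuda()

# Build the TensorRT engine; fp16_mode=True allows half-precision kernels
model_trt = torch2trt(model, [data], fp16_mode=True)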
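With ``fp16_mode=True``, TensorRT may run layers in IEEE 754 half precision, which keeps an 11-bit significand (about 3 decimal digits) and has a largest finite value of 65504. As a rough illustration of the precision involved, the sketch below round-trips Python floats through half precision using the standard library's ``struct`` ``'e'`` format. It shows only per-value rounding, not how TensorRT chooses the precision of each layer.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


for value in (0.1, 1.0 / 3.0, 1000.1, 65504.0):
    rounded = to_fp16(value)
    print(value, "->", rounded, "abs error:", abs(value - rounded))
```

Rounding of this order is why the outputs of the FP16 TensorRT model can differ slightly from those of the FP32 PyTorch model.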

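Save the optimized model. Its state dict holds the serialized TensorRT engine, which is not portable across TensorRT versions or GPU models, so build it on the device where it will run.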

In [5]:
torch.save(model_trt.state_dict(), 'model_trt.pth')
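To use the saved engine later, for example in a separate inference notebook, torch2trt provides ``TRTModule``, which can be restored from the saved state dict. This is a minimal sketch, assuming torch2trt is installed with the same TensorRT version and on the same device that built the engine. ``load_trt_model`` is a hypothetical helper name, and its imports live inside the function so that defining it does not require a GPU.

```python
def load_trt_model(path: str = "model_trt.pth"):
    """Rebuild a torch2trt TRTModule from a state dict saved with torch.save()."""
    # Imported lazily: torch and torch2trt are only needed when the model is loaded.
    import torch
    from torch2trt import TRTModule

    model_trt = TRTModule()
    model_trt.load_state_dict(torch.load(path))
    return model_trt


# Usage on the device that built the engine (requires CUDA, TensorRT and torch2trt):
# import torch
# model_trt = load_trt_model("model_trt.pth")
# output = model_trt(torch.zeros((1, 3, 224, 224)).cuda())  # shape (1, 4)
```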