# Road Following by Classification - Live Demo with TensorRT

In this notebook, we use the trained model to make the JetBot follow a road.

## Create the TensorRT Model
First, we define the DNN model. It must be identical to the one used for training.

In [1]:
import torch
import torchvision
from torch2trt import torch2trt

model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 3)
model = model.cuda().eval().half()

model.load_state_dict(torch.load('best_model_resnet18.pth'))

<All keys matched successfully>

In [6]:
device = torch.device('cuda')

In [3]:
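from torch2trt import torch2trt

# Example input that fixes the input shape and dtype for the conversion (FP16, 1x3x224x224)
data = torch.zeros((1, 3, 224, 224)).cuda().half()
model_trt = torch2trt(model, [data], fp16_mode=True)

# Save the converted model so it can be reloaded later without converting again
torch.save(model_trt.state_dict(), 'best_model_resnet18_trt.pth')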
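
Since the converted model is saved to a file and reloaded in the next section, one option is to run the conversion only when that file does not exist yet. The sketch below shows the pattern in plain Python. `ensure_engine` and `build_fn` are hypothetical names introduced for illustration, and `fake_build` merely stands in for the `torch2trt` conversion and `torch.save` call above.

```python
from pathlib import Path
import tempfile


def ensure_engine(path, build_fn):
    """Run build_fn(path) only if `path` does not exist yet.

    Returns True when a build happened, False when the existing file is reused.
    `build_fn` is a hypothetical callable that converts the model and saves it.
    """
    path = Path(path)
    if path.exists():
        return False
    build_fn(path)
    return True


def fake_build(path):
    # Stand-in for the real conversion: it only writes a placeholder file.
    path.write_bytes(b"placeholder")


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "best_model_resnet18_trt.pth"
    first = ensure_engine(target, fake_build)   # file missing, so it builds
    second = ensure_engine(target, fake_build)  # file present, so it is reused
    print(first, second)
```

In the notebook, `build_fn` would wrap the conversion cell, so restarting the kernel would not trigger a new conversion as long as the saved file is still present.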

## Load the TensorRT Model
Next, make sure `best_model_resnet18_trt.pth` (saved in the previous section) is in the notebook's working directory, uploading it through the file browser if necessary.

Then load its parameters into a `TRTModule`.

In [5]:
import torch
import torchvision
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('best_model_resnet18_trt.pth'))

<All keys matched successfully>

## Preprocessing Function

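Now we create a function that preprocesses the image data captured by the camera. This is very similar to what we did in the collision avoidance example. Here the image is only scaled to the range [0, 1]; the mean/standard-deviation normalization is left commented out.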

In [6]:
import cv2
import numpy as np
import torchvision.transforms as transforms

#mean = 255.0 * np.array([0.485, 0.456, 0.406])
#stdev = 255.0 * np.array([0.229, 0.224, 0.225])

#normalize = torchvision.transforms.Normalize(mean, stdev)

def preprocess(camera_value):
    x = camera_value
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)  # camera frames are BGR
    x = x.transpose((2, 0, 1))              # HWC -> CHW
    x = torch.from_numpy(x).float()
    # x = normalize(x)
    x = x.cuda().half()                     # match the FP16 TensorRT model
    x = x[None, ...]                        # add the batch dimension
    x = x/255.                              # scale pixel values to [0, 1]
    return x
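
To make the layout changes in `preprocess` easier to follow, here is a minimal NumPy-only sketch of the same steps on a synthetic frame. It is an illustration under assumptions, not the notebook's function: `preprocess_numpy` is a hypothetical name, the channel swap uses slicing instead of `cv2.cvtColor`, and the GPU transfer and FP16 cast are left out.

```python
import numpy as np


def preprocess_numpy(frame_bgr):
    """NumPy-only stand-in for the layout steps of preprocess()."""
    x = frame_bgr[..., ::-1]            # BGR -> RGB by reversing the channel axis
    x = x.transpose((2, 0, 1))          # HWC -> CHW
    x = x.astype(np.float32) / 255.0    # scale pixel values to [0, 1]
    return x[None, ...]                 # add the batch dimension -> NCHW


# Synthetic 224x224 frame that is pure blue in BGR order (B=255, G=0, R=0).
frame = np.zeros((224, 224, 3), dtype=np.uint8)
frame[..., 0] = 255

batch = preprocess_numpy(frame)
print(batch.shape, batch.dtype)
print(batch[0, :, 0, 0])  # RGB values of the top-left pixel
```

After the channel swap, the blue value ends up in the last channel, which is the order the RGB-trained model expects.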

## Camera Instance
Now we create a camera instance.

In [7]:
import traitlets
from IPython.display import display
import ipywidgets.widgets as widgets
from jetbot import Camera, bgr8_to_jpeg

image = widgets.Image(format='jpeg', width=300, height=300)
camera = Camera.instance(width=224, height=224)
camera_link = traitlets.dlink((camera, 'value'), (image, 'value'), transform=bgr8_to_jpeg)
display(image)


## Inference
Let's run a single inference. The first call takes a while because the model has to be loaded.

In [8]:
import torch.nn.functional as F

with torch.no_grad():
    # Inference uses model_trt only, so the original PyTorch model is not needed here
    x = camera.value
    x = preprocess(x)
    y = model_trt(x)
    y = F.softmax(y, dim=1)
    y_idx = torch.argmax(y, dim=1).item()
    print(y_idx)

0
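
The printed value is the index of the class with the highest score. Softmax turns the raw scores into probabilities that sum to 1, but it does not change which index is the largest. The plain-Python sketch below illustrates this with made-up scores; `softmax` and `argmax` here are small helper functions written for the example, not the PyTorch calls used above.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


logits = [2.0, 0.5, -1.0]   # made-up scores for the three classes
probs = softmax(logits)

print(argmax(logits), argmax(probs))  # same index before and after softmax
print(round(sum(probs), 6))           # probabilities sum to 1
```

The softmax step is therefore only needed if the probabilities themselves are used, for example to threshold on confidence.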


## Robot Instance
Create the robot instance to drive the motors.

In [9]:
from jetbot import Robot

robot = Robot()

## Define Actions
Create functions to move forward, turn left, and turn right.

In [22]:
def move_forward():
    robot.set_motors(0.3, 0.3)

def turn_left():
    robot.set_motors(0.1, 0.25)

def turn_right():
    robot.set_motors(0.25, 0.1)
    
actions_dict = {0:"Go Forward", 1:"Turn Left", 2:"Turn Right"}
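
As an alternative to one function per action, the class index can be mapped directly to a pair of motor speeds. The sketch below is a hedged illustration that runs without hardware: `RecordingRobot`, `MOTOR_SPEEDS`, and `apply_action` are hypothetical names, and the only `Robot` method assumed is `set_motors(left, right)` as used above.

```python
class RecordingRobot:
    """Hypothetical stand-in for jetbot.Robot that only records motor values."""

    def __init__(self):
        self.last = None

    def set_motors(self, left, right):
        self.last = (left, right)


# Motor speeds (left, right) per class index, copied from the functions above.
MOTOR_SPEEDS = {
    0: (0.3, 0.3),    # Go Forward
    1: (0.1, 0.25),   # Turn Left
    2: (0.25, 0.1),   # Turn Right
}


def apply_action(bot, y_idx):
    """Drive `bot` according to the predicted class index."""
    left, right = MOTOR_SPEEDS[y_idx]
    bot.set_motors(left, right)


demo_bot = RecordingRobot()
apply_action(demo_bot, 1)
print(demo_bot.last)
```

Keeping the speeds in one table makes them easier to tune in one place, at the cost of being slightly less explicit than named functions.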

## Run JetBot
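Run the JetBot in a loop. Each step grabs a camera frame, classifies it, runs the matching action, and prints the current frame rate.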

In [23]:
import time
t0 = time.time()

display(image)
steps = 1000

softmax = torch.nn.Softmax(dim=1)

with torch.no_grad():
    for i in range(steps):
        x = camera.value
        x = preprocess(x)
        y = model_trt(x)
        y = softmax(y)
        y_idx = torch.argmax(y, dim=1).item()

        if y_idx == 0: move_forward()
        elif y_idx == 1: turn_left()
        elif y_idx == 2: turn_right()
        
        now = time.time()
        dt = now-t0
        t0 = now
        FPS = 1/dt
        
        print(f"\rStep:{i+1}/{steps}   Action:{actions_dict[y_idx]}   FPS:{FPS:.1f}", end="")

robot.stop()


Step:1000/1000   Action:Turn Right   FPS:63.6
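
The FPS value printed by the loop is computed from a single iteration (`1/dt`), so it can jump around from step to step. One way to get a steadier reading is an exponential moving average. The sketch below is an assumption-labeled illustration: `FpsMeter` and its `alpha` parameter are hypothetical, and the iteration times are made up.

```python
class FpsMeter:
    """Exponential moving average of frames per second."""

    def __init__(self, alpha=0.1):
        self.alpha = alpha      # weight of the newest sample
        self.fps = None         # no estimate until the first update

    def update(self, dt):
        """Feed the duration of one loop iteration in seconds."""
        if dt <= 0:
            return self.fps     # ignore degenerate timings
        sample = 1.0 / dt
        if self.fps is None:
            self.fps = sample
        else:
            self.fps = self.alpha * sample + (1 - self.alpha) * self.fps
        return self.fps


meter = FpsMeter(alpha=0.5)
for dt in (0.02, 0.02, 0.01):   # made-up iteration times in seconds
    smoothed = meter.update(dt)
print(f"{smoothed:.1f}")
```

In the loop above, `meter.update(dt)` would replace the direct `1/dt` computation; a smaller `alpha` gives a smoother but slower-reacting value.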

When you are done, stop the robot and the camera.

In [24]:
robot.stop()
camera.stop()
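
If the driving loop is interrupted by an exception or a kernel interrupt, a `robot.stop()` placed after the loop is never reached, and the motors may keep their last command. A `try`/`finally` block is one way to guard against that. The sketch below runs without hardware: `FakeRobot`, `drive`, and `failing_step` are hypothetical names, and the failure is simulated.

```python
class FakeRobot:
    """Hypothetical stand-in for jetbot.Robot, for running without hardware."""

    def __init__(self):
        self.stopped = False

    def set_motors(self, left, right):
        self.stopped = False

    def stop(self):
        self.stopped = True


def drive(bot, steps, step_fn):
    """Call step_fn(bot, i) `steps` times and always stop the robot afterwards."""
    try:
        for i in range(steps):
            step_fn(bot, i)
    finally:
        bot.stop()   # runs on normal exit, exceptions, and KeyboardInterrupt


def failing_step(bot, i):
    bot.set_motors(0.3, 0.3)
    if i == 2:
        raise RuntimeError("simulated camera failure")


bot = FakeRobot()
try:
    drive(bot, 10, failing_step)
except RuntimeError as err:
    print("loop aborted:", err)
print("stopped:", bot.stopped)
```

The same structure could wrap the loop in the "Run JetBot" cell, with the real `robot.stop()` in the `finally` clause.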