# Draw! - Live Demo (with TensorRT)

The task of the robot is to drive around without collision to draw on a canvas.

The trained network should estimate in which direction it is safe to drive:

- forward: both tracks run at the same speed
- right: the left track runs at the set speed, the right track is stopped (0)
- left: the right track runs at the set speed, the left track is stopped (0)
- turn_right: the left track runs at +speed, the right track at -speed
- turn_left: the right track runs at +speed, the left track at -speed

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import cv2
import PIL.Image
import numpy as np
import traitlets
from IPython.display import display
import ipywidgets.widgets as widgets
from jetbot import Robot, Camera, bgr8_to_jpeg
from torch2trt import TRTModule
import time

device = torch.device('cuda')

# Restore the TensorRT-optimized network from its saved state dict.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('best_model_trt.pth'))

# Per-channel ImageNet statistics, kept on the GPU in fp16 so they can be
# applied in place to the half-precision tensors built by `preprocess`.
mean, std = (
    torch.Tensor(stats).cuda().half()
    for stats in ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)

# NOTE(review): `normalize` is not used in this section; `preprocess` applies
# mean/std manually. Left in place in case other cells rely on it.
normalize = torchvision.transforms.Normalize(mean, std)

def preprocess(image):
    """Turn a camera frame into a normalized fp16 batch for the TRT model.

    Args:
        image: camera frame, presumably an HxWx3 uint8 array in BGR channel
            order (the same value the camera link JPEG-encodes with
            `bgr8_to_jpeg`).

    Returns:
        A (1, 3, H, W) half-precision tensor on `device`, normalized with the
        module-level `mean` / `std`.
    """
    # The camera frame is BGR, while the ImageNet mean/std (and, presumably,
    # the JPEG training images) are RGB -- TODO confirm against the training
    # pipeline. Swap the channels so red/blue are not exchanged at inference.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = PIL.Image.fromarray(image)
    # to_tensor yields CxHxW floats in [0, 1]; move to GPU and match the fp16 engine.
    image = transforms.functional.to_tensor(image).to(device).half()
    image.sub_(mean[:, None, None]).div_(std[:, None, None])
    # Add the batch dimension expected by the model.
    return image[None, ...]

# Camera running at the network's 224x224 input size, plus a JPEG image
# widget that mirrors its frames and a slider for the driving speed.
camera = Camera.instance(width=224, height=224, fps=15)
image = widgets.Image(format='jpeg', width=224, height=224)
speed_slider = widgets.FloatSlider(
    description='speed',
    min=0.0,
    max=1,
    value=0.0,
    step=0.1,
    orientation='horizontal',
)

# One-way link: each new camera frame is JPEG-encoded and pushed to the widget.
camera_link = traitlets.dlink(
    (camera, 'value'),
    (image, 'value'),
    transform=bgr8_to_jpeg,
)

# Label order must match the model's output indices (`update` pairs
# class_names[i] with output i).
class_names = ['forward', 'left', 'right', 'turn_left', 'turn_right']
class_widgets = [
    widgets.Label(
        value=f"{class_name}: 0%",
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
    )
    for class_name in class_names
]

display(
    widgets.VBox([
        widgets.HBox([image, widgets.VBox([speed_slider, *class_widgets])]),
    ])
)

robot = Robot()

# Minimum softmax probability required before the robot acts on a prediction;
# at or below it the robot is stopped.
DECISION_THRESHOLD = 0.5

# Per-class (left_track, right_track) multipliers of the slider speed, indexed
# in the same order as `class_names` / the model outputs.
# NOTE: tracks are set explicitly with `set_motors` because JetBot's
# `Robot.left/right` counter-rotate both tracks, which would make `left`/`right`
# behave exactly like `turn_left`/`turn_right`.
MOTOR_FACTORS = (
    (1.0, 1.0),    # forward: both tracks at the set speed
    (0.0, 1.0),    # left: right track at the set speed, left track stopped
    (1.0, 0.0),    # right: left track at the set speed, right track stopped
    (-1.0, 1.0),   # turn_left: right track +speed, left track -speed
    (1.0, -1.0),   # turn_right: left track +speed, right track -speed
)


def class_motor_speeds(pred_class, pred_prob, speed, threshold=DECISION_THRESHOLD):
    """Map a prediction to (left, right) motor values.

    Args:
        pred_class: index of the predicted class (order of MOTOR_FACTORS).
        pred_prob: probability assigned to that class.
        speed: base driving speed (the slider value).
        threshold: the prediction is acted on only if pred_prob exceeds it.

    Returns:
        A (left, right) tuple of motor values, or None when the robot should
        stop (confidence too low or class index out of range).
    """
    if pred_prob <= threshold or not 0 <= pred_class < len(MOTOR_FACTORS):
        return None
    left_factor, right_factor = MOTOR_FACTORS[pred_class]
    return left_factor * speed, right_factor * speed


def update(change):
    """Camera callback: classify the new frame, refresh the labels, drive.

    Args:
        change: traitlets change dict; `change['new']` is the latest camera frame.

    Reads the module-level `model_trt`, `speed_slider`, `class_names` and
    `class_widgets`, and commands `robot`.
    """
    x = preprocess(change['new'])
    # Softmax turns the raw outputs into a probability distribution.
    y = F.softmax(model_trt(x), dim=1)

    # Copy the probabilities to the host once, instead of one `.item()`
    # (and one GPU sync) per class.
    probs = y[0].float().cpu().tolist()
    # max() returns the first maximal index, matching torch.argmax on ties.
    pred_class = max(range(len(probs)), key=probs.__getitem__)
    pred_prob = probs[pred_class]

    for i, (widget, name) in enumerate(zip(class_widgets, class_names)):
        indicator = "<===" if i == pred_class else ""
        widget.value = f"{name}: {probs[i] * 100:.2f}% {indicator}"

    motors = class_motor_speeds(pred_class, pred_prob, speed_slider.value)
    if motors is None:
        robot.stop()
    else:
        robot.set_motors(*motors)

# Invoke the callback once by hand so the labels and motors reflect the
# current frame immediately, then register it to run on every new frame
# (i.e. whenever `camera.value` changes).
update(dict(new=camera.value))
camera.observe(handler=update, names='value')

## Stop the robot

In [None]:
# Detach the camera observers (including `update`) so new frames no longer
# drive the motors, and stop the motors right away instead of letting the
# robot keep moving without vision while the camera shuts down.
camera.unobserve_all()
robot.stop()
camera.stop()
# Presumably a callback already in progress could still issue one last motor
# command; wait for it to finish, then stop again so the final command is a stop.
time.sleep(1)
robot.stop()