# Draw! - Live Demo simple version (with TensorRT)

The robot's task is to drive around without colliding with anything, drawing on a canvas as it goes.

The trained network should estimate in which direction it is safe to drive:

- forward : both tracks run at the same speed
- right : left track runs at the set speed, right track is stopped (0)
- left : right track runs at the set speed, left track is stopped (0)
- turn_right : left track runs at +speed, right track at -speed
- turn_left : right track runs at +speed, left track at -speed

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import cv2
import PIL.Image
import numpy as np
import traitlets
from jetbot import Robot, Camera
from torch2trt import TRTModule
import time

device = torch.device('cuda')

# Load the TensorRT-optimized network from its serialized state dict.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('best_model_trt.pth'))

# Per-channel normalization constants (the usual torchvision ImageNet
# mean/std values), placed on the GPU in half precision so they can be
# applied directly to the half-precision frames built by preprocess().
mean, std = (
    torch.Tensor(values).cuda().half()
    for values in ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)

def preprocess(image):
    """Turn a raw camera frame into a normalized half-precision GPU batch.

    Args:
        image: Frame array as delivered by the camera (anything
            ``PIL.Image.fromarray`` accepts).

    Returns:
        A batch of one image, shape ``(1, C, H, W)``, float16, on ``device``.
        ``to_tensor`` scales pixel values (to [0, 1] for uint8 input); the
        result is then normalized with the module-level ``mean`` / ``std``.
    """
    # NOTE(review): no colour-channel conversion happens here. If the camera
    # delivers BGR frames while the training images were loaded as RGB, the
    # channels are swapped at inference time — confirm against training.
    tensor = transforms.functional.to_tensor(PIL.Image.fromarray(image))
    tensor = tensor.to(device).half()
    channel_mean = mean.reshape(-1, 1, 1)
    channel_std = std.reshape(-1, 1, 1)
    return ((tensor - channel_mean) / channel_std).unsqueeze(0)

# 224x224 frames at 15 fps; preprocess() does no resizing, so this is also
# the resolution the network sees.
camera = Camera.instance(fps=15, width=224, height=224)

robot = Robot()

# Network output index -> driving command.
# NOTE(review): this order must match the class indices used in training
# (it is alphabetical, as torchvision's ImageFolder would assign) — confirm.
class_names = [
    'forward',
    'left',
    'right',
    'turn_left',
    'turn_right',
]

def move_robot(direction, speed):
    """Drive the robot in ``direction`` using ``speed`` for the active track(s).

    Args:
        direction: One of 'forward', 'left', 'right', 'turn_left',
            'turn_right'. Any other value stops the robot.
        speed: Motor value applied to the driving track(s); turns in place
            drive the opposite track at ``-speed``.
    """
    # Pick the (left, right) motor pair for the requested direction.
    if direction == 'forward':
        left_motor, right_motor = speed, speed
    elif direction == 'left':
        left_motor, right_motor = 0, speed
    elif direction == 'right':
        left_motor, right_motor = speed, 0
    elif direction == 'turn_left':
        left_motor, right_motor = -speed, speed
    elif direction == 'turn_right':
        left_motor, right_motor = speed, -speed
    else:
        # Unknown command (e.g. 'stop'): halt instead of guessing.
        robot.stop()
        return
    robot.set_motors(left_motor, right_motor)

def update(change, speed=1, threshold=0.35):
    """Classify the newest camera frame and drive in the predicted direction.

    Used as a traitlets observer (``camera.observe(update, names='value')``),
    so it receives a change dict; only ``change['new']`` (the frame) is read.

    Args:
        change: Change notification; ``change['new']`` is the image passed
            to ``preprocess``.
        speed: Motor speed handed to ``move_robot`` for a confident
            prediction. Defaults to 1.
        threshold: The chosen class's softmax probability must be strictly
            greater than this for the robot to act on it; otherwise the robot
            is stopped. Defaults to 0.35.

    Raises:
        Any exception from preprocessing, inference or the motor command is
        re-raised after the robot has been stopped.
    """
    try:
        x = preprocess(change['new'])
        with torch.no_grad():  # inference only; no autograd bookkeeping needed
            y = F.softmax(model_trt(x), dim=1)

        pred_class = torch.argmax(y, dim=1).item()
        # Read the probability of the selected class (single-image batch) so
        # the confidence check always refers to the class actually chosen.
        pred_prob = y[0, pred_class].item()

        if pred_prob > threshold:
            move_robot(class_names[pred_class], speed)
        else:
            move_robot('stop', 0)
    except Exception:
        # Otherwise the motors would keep executing the last command with no
        # further updates; stop the robot first, then propagate the error.
        robot.stop()
        raise

# Run one inference on the current frame right away, then register update()
# so it is called asynchronously each time the camera publishes a new frame.
update(dict(new=camera.value))
camera.observe(update, names=['value'])

## Stop the robot

In [None]:
# Detach the callback first so no new inference is started, then stop the
# motors immediately rather than letting the last command run while waiting.
camera.unobserve_all()
robot.stop()
camera.stop()
# Presumably gives an update() call that was already running time to finish;
# stop again in case it issued one last motor command after the first stop.
time.sleep(1)
robot.stop()