# Draw! - Live Demo (with TensorRT)

The robot's task is to drive around on a canvas without colliding with anything, drawing as it goes.

The trained network estimates in which direction it is safe to drive:

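- forward : both tracks run at the same speed
- right : the left track runs at the set speed, the right track is set to 0
- left : the right track runs at the set speed, the left track is set to 0
- turn_right : the left track runs at the set speed, the right track at -speed (rotates in place)
- turn_left : the right track runs at the set speed, the left track at -speed (rotates in place)

(In the code below, forward uses 60% and the two turns use 90% of the speed slider value.)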

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import cv2
import PIL.Image
import numpy as np
import traitlets
from IPython.display import display
import ipywidgets.widgets as widgets
from jetbot import Robot, Camera, bgr8_to_jpeg
from torch2trt import TRTModule
import time

device = torch.device('cuda')

# TensorRT-optimized model (a torch2trt `TRTModule`); the weights file is loaded
# from a path relative to the notebook's working directory.
model_trt = TRTModule()

# To load the ResNet-50 model instead, use:
#model_trt.load_state_dict(torch.load('best_model_trt_resnet50.pth'))

model_trt.load_state_dict(torch.load('best_model_trt.pth'))

# Per-channel normalization constants (the standard ImageNet mean/std), kept on
# the GPU in half precision to match the half-precision input tensor built in
# `preprocess`.
mean = torch.Tensor([0.485, 0.456, 0.406]).cuda().half()
std = torch.Tensor([0.229, 0.224, 0.225]).cuda().half()

# NOTE(review): `normalize` is not used by the visible code; `preprocess`
# normalizes in place with `mean`/`std` instead.
normalize = torchvision.transforms.Normalize(mean, std)

def preprocess(image, bgr=True):
    """Turn a camera frame into a normalized, batched half-precision tensor.

    Args:
        image: 8-bit 3-channel frame, e.g. ``camera.value``. The JetBot camera
            publishes ``bgr8`` frames (the file converts them for display with
            ``bgr8_to_jpeg``), so by default the channels are reordered to RGB
            before normalization.
        bgr: pass ``False`` if ``image`` is already in RGB order.

    Returns:
        A ``1 x 3 x H x W`` ``torch.half`` tensor on ``device``, normalized
        with the module-level ``mean``/``std``.

    NOTE(review): assumes the model was trained on RGB images (e.g. JPEGs
    written by ``bgr8_to_jpeg`` and loaded with PIL); confirm against the
    training notebook.
    """
    if bgr:
        # Camera frames are BGR; the RGB-normalization constants and (presumably)
        # the training data are RGB, so swap the channel order first.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = PIL.Image.fromarray(image)
    # HWC uint8 -> CHW float in [0, 1], then move to the GPU as half precision.
    image = transforms.functional.to_tensor(image).to(device).half()
    image.sub_(mean[:, None, None]).div_(std[:, None, None])
    # Add the batch dimension expected by the model.
    return image[None, ...]

# 224x224 camera at 15 fps; `camera.value` holds the latest frame.
camera = Camera.instance(width=224, height=224, fps=15)

# Live camera view plus the controls read by `update`:
#   speed     - motor speed passed to `move_robot`
#   threshold - smoothed class probability that must be exceeded to move (otherwise stop)
#   alpha     - weight of the newest frame in the moving average of class probabilities
image = widgets.Image(format='jpeg', width=224, height=224)
speed_slider = widgets.FloatSlider(description='speed', min=0.0, max=1.0, value=0.0, step=0.05, orientation='horizontal')
threshold_slider = widgets.FloatSlider(description='threshold', min=0.0, max=1.0, value=0.35, step=0.01, orientation='horizontal')
alpha_slider = widgets.FloatSlider(description='alpha', min=0.0, max=1.0, value=0.5, step=0.01, orientation='horizontal')

# Mirror every camera frame into the image widget, converted to JPEG.
camera_link = traitlets.dlink((camera, 'value'), (image, 'value'), transform=bgr8_to_jpeg)

# Class labels, indexed by model output position.
# NOTE(review): assumed to match the class order used in training; confirm.
# One label widget per class shows its smoothed probability.
class_names = ['forward', 'left', 'right', 'turn_left', 'turn_right']
class_widgets = [widgets.Label(value=f"{class_name}: 0%", layout=widgets.Layout(width='auto'), style={'description_width': 'initial'}) for class_name in class_names]

# Average frames per second since `start_time`, written by `update`.
fps_widget = widgets.FloatText(description='FPS', value=0)

display(widgets.VBox([
    image, 
    widgets.VBox([speed_slider, threshold_slider, alpha_slider, fps_widget] + class_widgets)
]))

robot = Robot()

# Counters for the FPS estimate shown in `fps_widget`.
frame_counter = 0
start_time = time.time()

def move_robot(direction, speed):
    """Issue the motor command that corresponds to ``direction``.

    Motor values are ``(left, right)`` as passed to ``robot.set_motors``:

    * ``'forward'``    -> both tracks at 60% of ``speed``
    * ``'left'``       -> left track 0, right track ``speed``
    * ``'right'``      -> left track ``speed``, right track 0
    * ``'turn_left'``  -> tracks at -90% / +90% of ``speed``
    * ``'turn_right'`` -> tracks at +90% / -90% of ``speed``

    Any other value (e.g. ``'stop'``) calls ``robot.stop()``.
    """
    # Lambdas keep the arithmetic lazy: only the selected command is evaluated.
    commands = {
        'forward': lambda s: (s * 0.6, s * 0.6),
        'left': lambda s: (0, s),
        'right': lambda s: (s, 0),
        'turn_left': lambda s: (-s * 0.9, s * 0.9),
        'turn_right': lambda s: (s * 0.9, -s * 0.9),
    }
    command = commands.get(direction)
    if command is None:
        robot.stop()
        return
    left_value, right_value = command(speed)
    robot.set_motors(left_value, right_value)

# Exponentially smoothed class probabilities used to steer the robot (smooths
# frame-to-frame jitter); one row with one column per class, starting at zero
# on the GPU.
moving_average_probs = torch.zeros((1, len(class_names)), device='cuda')

def update_moving_average_probs(y):
    """Fold the newest class probabilities into the running average.

    Rebinds the global ``moving_average_probs`` to
    ``alpha * y + (1 - alpha) * moving_average_probs``, where ``alpha`` is read
    from the alpha slider on every call. A larger alpha reacts faster to new
    frames; a smaller one smooths more.

    Args:
        y: probabilities for the current frame. Presumably the same shape as
            ``moving_average_probs``; confirm against the caller.
    """
    global moving_average_probs
    new_weight = alpha_slider.value
    old_weight = 1 - new_weight
    moving_average_probs = new_weight * y + old_weight * moving_average_probs

# Hysteresis state shared with `update`:
#   min_time_between_changes - intended minimum number of seconds between movement changes
#   last_change_time         - time of the last change (0, so the first change is never held back)
#   last_direction           - last direction commanded (None until the first command)
min_time_between_changes = 0.5
last_change_time, last_direction = 0, None

def update(change):
    """Process one camera frame: classify it, refresh the UI and drive the robot.

    Registered via ``camera.observe(update, names='value')``, so ``change`` is a
    traitlets change dict whose ``'new'`` entry is the latest camera frame. It
    is also called once by hand with ``{'new': camera.value}``.

    Pipeline: ``preprocess`` -> ``model_trt`` -> softmax -> moving average
    (``update_moving_average_probs``) -> per-class label widgets -> choose a
    direction -> motor command -> FPS widget.

    Direction choice: the class with the highest smoothed probability is used
    if that probability exceeds the threshold slider; otherwise ``'stop'``.

    Hysteresis: a switch to a different moving direction is applied only once
    ``min_time_between_changes`` seconds have passed since the last switch.
    This keeps the robot from flickering between directions. Stopping is never
    delayed. While the direction is unchanged, the command is re-sent every
    frame so speed-slider changes take effect immediately.

    Side effects: rebinds the globals ``frame_counter``, ``last_change_time``
    and ``last_direction`` (and ``moving_average_probs`` via
    ``update_moving_average_probs``). It also writes the class/FPS widgets and
    commands the robot motors.
    """
    global frame_counter, last_change_time, last_direction

    # Softmax turns the raw model scores into a probability distribution.
    x = preprocess(change['new'])
    y = F.softmax(model_trt(x), dim=1)
    update_moving_average_probs(y)

    # Copy the smoothed probabilities to the host once, instead of forcing a
    # GPU synchronization for every `.item()` call.
    probs = moving_average_probs[0].tolist()
    pred_class = max(range(len(probs)), key=probs.__getitem__)
    pred_prob = probs[pred_class]

    for i, (widget, name, prob) in enumerate(zip(class_widgets, class_names, probs)):
        indicator = "<===" if i == pred_class else ""
        widget.value = f"{name}: {prob * 100:.2f}% {indicator}"

    speed = speed_slider.value
    if pred_prob > threshold_slider.value:
        direction = class_names[pred_class]
    else:
        direction = 'stop'

    current_time = time.time()
    if direction == last_direction:
        # Unchanged direction: re-send so speed-slider changes are applied.
        move_robot(direction, speed)
    elif direction == 'stop' or current_time - last_change_time >= min_time_between_changes:
        # Stopping is the safe fallback, so the hysteresis never holds it back;
        # any other change must wait for the minimum hold time.
        move_robot(direction, speed)
        last_direction = direction
        last_change_time = current_time
    # Otherwise keep the current motion until the hold time has elapsed.

    # Average FPS since `start_time`.
    frame_counter += 1
    elapsed_time = time.time() - start_time
    if elapsed_time > 0:  # avoid dividing by zero on coarse clocks
        fps_widget.value = frame_counter / elapsed_time

# Run `update` once by hand on the current frame so the widgets and motors are
# initialized, then have it run on every change of `camera.value`.
update({'new': camera.value})  # call the update function once to initialize
camera.observe(update, names='value') # called on every new camera frame, not from this cell

## Stop the robot

In [None]:
# Detach all observers (including `update`) so no new motor commands are
# issued, then stop the camera.
camera.unobserve_all()
camera.stop()
# NOTE(review): presumably gives an in-flight `update` call time to finish so
# its motor command cannot override the final stop below; confirm.
time.sleep(1)
robot.stop()