In [27]:
# import libraries
from jetbot import Robot, Camera, bgr8_to_jpeg
from IPython.display import display
from uuid import uuid1
import ipywidgets.widgets as widgets
import traitlets, time, os
import torch
import torchvision
import torch.nn.functional as F

# AlexNet architecture with its last classifier layer replaced by a 2-output head.
# No pretrained weights are fetched: the trained weights come from the checkpoint
# loaded in the next cell.
model = torchvision.models.alexnet(pretrained=False)
model.classifier[6] = torch.nn.Linear(
    in_features=model.classifier[6].in_features,
    out_features=2,
)

In [28]:
# Load the trained weights into the (still on-CPU) model. map_location='cpu' lets a
# checkpoint that was saved from GPU tensors load without first materializing a
# temporary GPU copy (and without requiring CUDA just to read the file); the model
# is moved to the target device in the next cell.
model.load_state_dict(torch.load('model_stop_02.pth', map_location='cpu'))

In [29]:
device = torch.device('cuda')
# eval() switches off the Dropout layers in AlexNet's classifier. Left in training
# mode, every live prediction would randomly drop activations, making the
# blocked probability noisy and non-deterministic.
model = model.to(device).eval()
robot = Robot()

In [30]:
# create buttons

# Dataset folders written by the "add free" / "add blocked" buttons; save_free and
# save_blocked read these names as globals, so they must exist before a click.
# NOTE(review): paths follow the usual JetBot data-collection layout - adjust if
# your dataset lives elsewhere.
free_dir = os.path.join('dataset', 'free')
blocked_dir = os.path.join('dataset', 'blocked')

button_layout = widgets.Layout(width='100px', height='80px', align_self='center')
stop_button = widgets.Button(description='stop', button_style='danger', layout=button_layout)
forward_button = widgets.Button(description='forward', layout=button_layout)
backward_button = widgets.Button(description='backward', layout=button_layout)
left_button = widgets.Button(description='left', layout=button_layout)
right_button = widgets.Button(description='right', layout=button_layout)

free_button = widgets.Button(description='add free', button_style='success', layout=button_layout)
blocked_button = widgets.Button(description='add blocked', button_style='danger', layout=button_layout)
# Start each counter at the number of files already in its folder (0 if the folder
# does not exist yet) so the display matches what is on disk.
free_count = widgets.IntText(
    layout=button_layout,
    value=len(os.listdir(free_dir)) if os.path.isdir(free_dir) else 0,
)
blocked_count = widgets.IntText(
    layout=button_layout,
    value=len(os.listdir(blocked_dir)) if os.path.isdir(blocked_dir) else 0,
)



# define buttons and actions
def stop(change):
    """Button callback: halt the robot's motors.

    `change` is whatever on_click passes in (the clicked button); it is ignored.
    """
    halt = robot.stop
    halt()
    
def _timed_move(move, speed, duration):
    """Call ``move(speed)``, wait ``duration`` seconds, then stop the robot.

    The stop happens in ``finally`` so the robot is never left driving if the
    move call or the sleep is interrupted (e.g. by KeyboardInterrupt or a
    motor error).
    """
    try:
        move(speed)
        time.sleep(duration)
    finally:
        robot.stop()

def step_forward(change):
    """Button callback: drive forward at speed 0.3 for 0.3 s. `change` is unused."""
    _timed_move(robot.forward, 0.3, 0.3)

def step_backward(change):
    """Button callback: drive backward at speed 0.3 for 0.3 s. `change` is unused."""
    _timed_move(robot.backward, 0.3, 0.3)

def step_left(change):
    """Button callback: turn left at speed 0.2 for 0.2 s. `change` is unused."""
    _timed_move(robot.left, 0.2, 0.2)

def step_right(change):
    """Button callback: turn right at speed 0.2 for 0.2 s. `change` is unused."""
    _timed_move(robot.right, 0.2, 0.2)

def save_snapshot(directory, data=None):
    """Write a JPEG snapshot to a uniquely named ``.jpg`` file in `directory`.

    Args:
        directory: destination folder; created (with parents) if it is missing.
        data: bytes to write. Defaults to ``image.value``, the JPEG-encoded
            camera frame currently shown in the preview widget.

    Returns:
        The path of the file that was written.
    """
    # Read the frame before touching the filesystem so a failure here does not
    # leave an empty .jpg behind.
    if data is None:
        data = image.value
    os.makedirs(directory, exist_ok=True)
    image_path = os.path.join(directory, str(uuid1()) + '.jpg')
    with open(image_path, 'wb') as f:
        f.write(data)
    return image_path

def _save_and_count(directory, counter, label):
    """Snapshot the current frame into `directory`, then refresh `counter`.

    `counter` is a widget whose ``value`` is set to the number of entries in
    `directory` after the save; `label` is used in the confirmation message.
    """
    save_snapshot(directory)
    counter.value = len(os.listdir(directory))
    print("saved " + label)

def save_free():
    """Button action: store the current camera frame as a "free" example."""
    _save_and_count(free_dir, free_count, "free")

def save_blocked():
    """Button action: store the current camera frame as a "blocked" example."""
    _save_and_count(blocked_dir, blocked_count, "blocked")
    
# Wire each driving button to its action; on_click calls the handler with the
# clicked button as its single argument.
for _button, _handler in (
    (stop_button, stop),
    (forward_button, step_forward),
    (backward_button, step_backward),
    (left_button, step_left),
    (right_button, step_right),
):
    _button.on_click(_handler)
del _button, _handler

# The save actions take no arguments, so wrap them to discard the button.
free_button.on_click(lambda _clicked: save_free())
blocked_button.on_click(lambda _clicked: save_blocked())

# Arrange the controls into three centered rows (they are displayed in a later cell).
def _centered_row(*children):
    """Return an HBox holding `children`, centered within its parent."""
    return widgets.HBox(list(children), layout=widgets.Layout(align_self='center'))

top_box = _centered_row(free_button, forward_button, blocked_button)
middle_box = _centered_row(left_button, backward_button, right_button)
bottom = _centered_row(free_count, stop_button, blocked_count)

In [31]:
import cv2
import numpy as np

# Standard ImageNet per-channel mean / std, rescaled from the [0, 1] range to the
# 0-255 range of the raw camera pixels that preprocess() feeds in.
mean = np.array([0.485, 0.456, 0.406]) * 255.0
stdev = np.array([0.229, 0.224, 0.225]) * 255.0

normalize = torchvision.transforms.Normalize(mean=mean, std=stdev)

def preprocess(camera_value):
    """Convert one camera frame into a normalized model input batch.

    `camera_value` is a 3-channel BGR image laid out height x width x channel
    (presumably uint8 from the JetBot camera - TODO confirm). The result is a
    float tensor of shape (1, 3, H, W), normalized with `normalize` and placed
    on `device`.
    """
    rgb = cv2.cvtColor(camera_value, cv2.COLOR_BGR2RGB)
    channels_first = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
    normalized = normalize(channels_first).to(device)
    return normalized.unsqueeze(0)

In [32]:
# Frames are captured at 224x224, the size fed unresized to the model by preprocess().
camera = Camera.instance(width=224, height=224, fps=10)
image = widgets.Image(format='jpeg', width=224, height=224)
blocked_slider = widgets.FloatSlider(
    description='blocked',
    min=0.0,
    max=1.0,
    orientation='vertical',
)

# One-way link: every new BGR frame is JPEG-encoded into the preview widget.
camera_link = traitlets.dlink(
    (camera, 'value'),
    (image, 'value'),
    transform=bgr8_to_jpeg,
)

display(widgets.HBox([image, blocked_slider]))

HBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C…

In [33]:
# display gui
display(top_box)
display(middle_box)
display(bottom)

HBox(children=(Button(button_style='success', description='add free', layout=Layout(align_self='center', heigh…

HBox(children=(Button(description='left', layout=Layout(align_self='center', height='80px', width='100px'), st…

HBox(children=(IntText(value=0, layout=Layout(align_self='center', height='80px', width='100px')), Button(butt…

In [34]:
import torch.nn.functional as F
import time

def update(change):
    """Camera observer: classify the newest frame and report P(blocked).

    `change` is the traitlets change dict; only ``change['new']`` (the new
    frame) is used. The blocked probability is written to `blocked_slider`
    and printed.
    """
    frame = change['new']

    # Pure inference: without no_grad every frame would record an autograd graph
    # through the whole network, wasting time and memory on each camera tick.
    with torch.no_grad():
        logits = model(preprocess(frame))

    # softmax normalizes the output vector so it sums to 1 (a probability distribution)
    probs = F.softmax(logits, dim=1)

    # NOTE(review): assumes output index 0 is the "blocked" class - confirm against
    # the class order used when the checkpoint was trained.
    prob_blocked = float(probs.flatten()[0])

    blocked_slider.value = prob_blocked

    # Driving is currently disabled: the robot commands are left commented out.
    if prob_blocked < 0.5:
        #robot.forward(0.1)
        print("free: ", prob_blocked)
    else:
        #robot.left(0.2)
        print("blocked: ", prob_blocked)

    # Brief pause; presumably gives other threads (e.g. the camera) a chance to run.
    time.sleep(0.001)
        
# Run the classifier once by hand on the current frame to initialize the pipeline
# (and confirm it works) before attaching it to the camera.
update(dict(new=camera.value))

blocked:  0.9852383732795715


In [35]:
# Attach `update` to the camera's 'value' trait so it runs on every new frame.
camera.observe(handler=update, names='value')

In [25]:
# Detach the classifier from the camera stream, then make sure the motors are off.
camera.unobserve(handler=update, names='value')
robot.stop()

blocked:  0.9995446801185608
blocked:  0.9911380410194397
blocked:  0.9799060225486755
