In [None]:
import os
import PIL
import numpy as np
import cv2
import torch
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms

import traitlets
import ipywidgets.widgets as widgets
from IPython.display import display

from jetbot import Robot, Camera, bgr8_to_jpeg
from mobile import MobileController

In [None]:

# Per-channel normalization statistics, moved to the GPU and cast to float16 so
# they can be combined directly with the half-precision tensor built in
# `_preprocess`. The values match the standard ImageNet mean / std that
# torchvision documents for its pretrained models.
mean = torch.Tensor([0.485, 0.456, 0.406])
mean = mean.cuda().half()

std = torch.Tensor([0.229, 0.224, 0.225])
std = std.cuda().half()

In [None]:
def _preprocess(image):
    """Turn one camera frame into a normalized, batched half-precision tensor.

    The frame is converted to a channel-first float tensor, its channel axis is
    reversed, the result is moved to ``device`` as float16, normalized in place
    with the module-level ``mean`` / ``std``, and given a leading batch axis.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (presumably the
            HxWx3 uint8 BGR frame delivered by the camera -- TODO confirm).

    Returns:
        Half-precision tensor on ``device`` with a leading batch axis of size 1.
    """
    pil_frame = PIL.Image.fromarray(image)
    chw = transforms.functional.to_tensor(pil_frame)
    # Reverse the channel axis (axis 0 after to_tensor). The copy is required
    # because torch.from_numpy cannot wrap an array with negative strides.
    flipped = chw.numpy()[::-1].copy()
    batch = torch.from_numpy(flipped).to(device).half()
    batch.sub_(mean[:, None, None])
    batch.div_(std[:, None, None])
    return batch[None, ...]

In [None]:
# Build the steering network and load its trained weights.
MODEL_PATH = 'best_steering_model_ResNet.pth'
device = torch.device('cuda')

# ResNet18 backbone without pretrained weights; the final fully connected layer
# is replaced so the network emits two values per frame.
model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 2)

# Restore the checkpoint, then move to the GPU, switch to evaluation mode and
# cast the weights to float16 (the input tensors are float16 as well).
model.load_state_dict(torch.load(MODEL_PATH))
model = model.to(device).eval().half()


In [None]:
# Initialize the JetBot and the controller that drives it.
WHEEL_TRACK = 10  # passed to MobileController; units are not shown here -- TODO confirm
robot = Robot()
# MobileController normally rounds speed values into fixed ranges to make manual
# operation from a controller easier. During inference that rounding would change
# the speed away from the predicted value, so pass_through=True is set to forward
# the inferred speed to the robot unchanged.
mobile_controller = MobileController(WHEEL_TRACK, robot, pass_through=True)

In [None]:
# Gauge widgets that mirror what the controller is currently commanding.
speed = widgets.FloatSlider(
    description='speed',
    min=-1.0,
    max=1.0,
)
steering = widgets.FloatSlider(
    description='steering',
    min=-1.0,
    max=1.0,
)

# One-way links: the controller traits drive the sliders, not the reverse.
traitlets.dlink(
    (mobile_controller, 'speed'),
    (speed, 'value'),
)
traitlets.dlink(
    (mobile_controller, 'radius'),
    (steering, 'value'),
)

In [None]:
# Camera setup: obtain the camera instance at 224x224 / 10 fps and create an
# image widget of the same size for a live preview.
camera = Camera.instance(fps=10, width=224, height=224)
image = widgets.Image(format='jpeg', width=224, height=224)
# One-way link: each new camera value is converted with bgr8_to_jpeg before being
# assigned to the widget. The link object is kept so it can be unlinked later.
camera_link = traitlets.dlink((camera,'value'), (image,'value'), transform=bgr8_to_jpeg)

In [None]:
# Layout object describing a 100px x 64px element; no other cell visible in this
# notebook references it.
layout = widgets.Layout(
    height='64px',
    width='100px',
)

# Sliders stacked vertically, followed horizontally by the live camera preview.
panel = widgets.VBox([speed, steering])
display(
    widgets.HBox([panel, image])
)

In [None]:
def update(change):
    """Run one inference step on a new camera frame and drive the robot.

    Intended to be used as a traitlets observer on the camera's ``value`` trait.

    Args:
        change: Change dictionary whose ``'new'`` entry holds the latest frame.
    """
    image = _preprocess(change['new'])
    # Inference only: disable autograd so no graph is recorded for every frame.
    with torch.no_grad():
        xy = model(image).float().cpu().numpy().flatten()
    # The two network outputs are forwarded, in order, as the first and second
    # arguments of set_control (presumably throttle, then steering -- confirm
    # against MobileController).
    throttle = float(xy[0])
    handle = float(xy[1])
    mobile_controller.set_control(throttle, handle)

# Call the handler once by hand with the current camera frame before it is
# attached to the camera (presumably to warm up the model -- confirm).
update({'new':camera.value})


In [None]:
# Start autonomous driving: run `update` whenever the camera's value changes.
camera.observe(update, names='value')

In [None]:
# Stop autonomous driving: detach the inference handler, break the
# camera-to-widget preview link, then stop the robot.
# NOTE(review): an `update` call already in flight could presumably still command
# the robot after this stop -- confirm whether a short wait is needed here.
camera.unobserve(update, names='value')
camera_link.unlink()
robot.stop()

In [None]:
# Stand-alone stop cell, so the robot can be stopped again on its own.
robot.stop()