In [1]:
import torch
import torchvision
import cv2
import numpy as np
import traitlets
from IPython.display import display
import ipywidgets.widgets as widgets
from jetbot import Camera, bgr8_to_jpeg
import torch.nn.functional as F
import time
from jetbot import Robot

# Handle used by robot_update() to issue left/forward/right/stop drive commands.
robot = Robot()

# Per-channel mean and standard deviation, scaled by 255 so they apply to raw
# 0-255 pixel values. (The 0-1 constants look like the usual ImageNet
# statistics -- presumably matching how the model was trained; confirm.)
mean = np.array([0.485, 0.456, 0.406]) * 255.0
stdev = np.array([0.229, 0.224, 0.225]) * 255.0
normalize = torchvision.transforms.Normalize(mean=mean, std=stdev)

# (width, height) in pixels of each window cut out of a camera frame.
segment_dims = (74, 223)

def create_segments(frame, dims, overlap):
    """Cut ``frame`` into a grid of fixed-size windows.

    Windows are generated row-major (top to bottom, then left to right). Only
    windows that fit completely inside the frame are produced.

    Args:
        frame: Array indexed as ``frame[row, column, ...]`` (e.g. an H x W x C
            image). Only the first two entries of ``frame.shape`` are used.
        dims: ``(width, height)`` of each window, in pixels.
        overlap: Number of pixels shared by neighbouring windows along both
            axes. ``0`` gives non-overlapping tiles; a negative value leaves
            a gap between tiles.

    Returns:
        ``(coords_vec, frames_vec)``: ``coords_vec[k]`` is the ``(row, column)``
        of the top-left corner of window ``k`` and ``frames_vec[k]`` is the
        matching slice of ``frame`` (a slice, not a copy).

    Raises:
        ValueError: If ``overlap`` is not smaller than both window dimensions,
            which would make the stride zero or negative.
    """
    width, height = dims[0], dims[1]
    rows, columns = frame.shape[:2]

    row_step = height - overlap
    col_step = width - overlap
    if row_step <= 0 or col_step <= 0:
        raise ValueError(
            "overlap (%r) must be smaller than both segment dimensions %r"
            % (overlap, (width, height))
        )

    coords_vec = []
    frames_vec = []
    # "+ 1" keeps the last window that ends exactly on the bottom/right edge;
    # without it a start offset of rows - height (a valid window) is skipped.
    for i in range(0, rows - height + 1, row_step):
        for j in range(0, columns - width + 1, col_step):
            coords_vec.append((i, j))
            frames_vec.append(frame[i:i + height, j:j + width])
    return coords_vec, frames_vec


def preprocess(camera_value):
    """Turn one image array into a normalized, batched tensor on ``device``.

    Steps, in order: convert colour with ``cv2.COLOR_BGR2RGB``, move the
    channel axis to the front (HWC -> CHW), convert to a float tensor, apply
    the module-level ``normalize`` transform, move the tensor to the
    module-level ``device`` and prepend a batch axis of size 1.

    Args:
        camera_value: Image array with the channel axis last.

    Returns:
        Tensor of shape ``(1, C, H, W)`` located on ``device``.
    """
    rgb = cv2.cvtColor(camera_value, cv2.COLOR_BGR2RGB)
    chw = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
    on_device = normalize(chw).to(device)
    return on_device[None, ...]


def draw_bounding_boxes(frame, draw_list, sub_coords):
    """Outline the selected segments on ``frame`` (drawn in place).

    Args:
        frame: Image passed to ``cv2.rectangle``; it is modified in place.
        draw_list: Indices into ``sub_coords`` of the segments to outline.
        sub_coords: Per-segment ``(row, column)`` top-left corners, as
            produced by ``create_segments``.

    Returns:
        The same ``frame`` object, with the rectangles drawn on it.
    """
    # segment_dims is (width, height); see create_segments().
    box_width, box_height = segment_dims[0], segment_dims[1]
    for idx in draw_list:
        top, left = sub_coords[idx][0], sub_coords[idx][1]
        # OpenCV points are (x, y) == (column, row). The bottom edge must be
        # offset from the row (top), not from the column.
        cv2.rectangle(
            frame,
            (left, top),
            (left + box_width, top + box_height),
            (0, 255, 255),  # yellow, assuming the frame is BGR
            2,
        )
    return frame
    
    
def get_draw_list(frame, threshold=0.5):
    """Classify every segment of ``frame`` and list the ones that score high.

    The frame is split with ``create_segments(frame, segment_dims, 0)``. Each
    segment is resized to 224x224, preprocessed and passed through the
    module-level ``model``. A segment is selected when the softmax value at
    flat index 0 (the first class of the single-image batch) exceeds
    ``threshold``. What that first class represents depends on how the
    checkpoint was trained; it is not visible here.

    Args:
        frame: Camera image to analyse.
        threshold: Minimum first-class probability for a segment to be
            selected. Defaults to ``0.5``.

    Returns:
        ``(draw_list, sub_coords)``: indices of the selected segments, and the
        ``(row, column)`` top-left corner of every segment.
    """
    sub_coords, sub_frames = create_segments(frame, segment_dims, 0)
    draw_list = []
    # Inference only: without no_grad() each forward pass records an autograd
    # graph that is never used, wasting memory and time on every frame.
    with torch.no_grad():
        for index, sub_f in enumerate(sub_frames):
            frame_proc = preprocess(cv2.resize(sub_f, (224, 224)))
            y = F.softmax(model(frame_proc), dim=1)
            prob = float(y.flatten()[0])
            if prob > threshold:
                draw_list.append(index)
    return draw_list, sub_coords
        

def robot_update(draw_list):
    """Drive the module-level ``robot`` based on which segments were selected.

    Policy (all motions use speed 0.3):
      * exactly one index: 0 -> left, 1 -> forward, 2 -> right; any other
        single index issues no command, so the previous one stays in effect;
      * exactly two indices: left if the first one is 1, otherwise right;
      * any other count (including none): stop.

    Args:
        draw_list: Indices of the selected segments.
    """
    speed = 0.3
    hit_count = len(draw_list)

    # Zero hits or three-plus hits: halt.
    if hit_count not in (1, 2):
        robot.stop()
        return

    first = draw_list[0]

    if hit_count == 2:
        if first == 1:
            robot.left(speed)
        else:
            robot.right(speed)
        return

    # Single hit: steer according to its segment index.
    if first == 0:
        robot.left(speed)
    elif first == 1:
        robot.forward(speed)
    elif first == 2:
        robot.right(speed)
        
            
    
    
def update(change):
    """Observer callback: process one new camera frame end to end.

    Runs segment classification on the frame, draws the selected boxes on it,
    issues the matching drive command, then publishes the annotated frame to
    the ``image`` widget as JPEG bytes.

    Args:
        change: Change dictionary whose ``'new'`` entry holds the frame.
    """
    frame = change['new']
    selected, coords = get_draw_list(frame)
    annotated = draw_bounding_boxes(frame, selected, coords)
    robot_update(selected)
    image.value = bgr8_to_jpeg(annotated)

    
# Build an AlexNet without pretrained weights, replace the last classifier
# layer with a 3-output linear layer, then load the trained checkpoint.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model = torchvision.models.alexnet(pretrained=False)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 3)
model.load_state_dict(torch.load('best_model_laptop_farm.pth'))
model.eval()  # switch to evaluation mode for inference
device = torch.device('cuda')
model = model.to(device)  # place model on GPU

# 224x224 camera feed and a same-sized image widget that update() writes to.
camera = Camera.instance(width=224, height=224)
image = widgets.Image(format='jpeg', width=224, height=224)

display(widgets.VBox([widgets.HBox([image])]))

# From here on every new camera frame runs update(), which also drives the robot.
camera.observe(update, names='value')  # this attaches the 'update' function to the 'value' traitlet of our camera
update({'new': camera.value})  # we call the function once to initialize


VBox(children=(HBox(children=(Image(value=b'', format='jpeg', height='224', width='224'),)),))

In [None]:
# `draw_list` is only a local inside get_draw_list()/update(); it is never
# defined at notebook scope. Recompute it from the current camera frame so this
# cell runs on a fresh kernel instead of raising NameError.
draw_list, _ = get_draw_list(camera.value)
robot_update(draw_list)

In [None]:
# Detach the callback first so no further frame can issue a new drive command.
camera.unobserve(update, names='value')
time.sleep(0.1)  # add a small sleep to make sure frames have finished processing
# Without an explicit stop the motors keep executing the last command sent by
# robot_update() even though no more frames are being processed.
robot.stop()
camera.stop()
print("app stopped")