In [1]:
from jetcam.csi_camera import CSICamera
import torchvision.transforms as transforms
from xy_dataset import XYDataset
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget
import torch
import torchvision
import threading
import time
from utils import preprocess
import torch.nn.functional as F

In [2]:
# Camera Setup
# Create the CSI camera at 224x224, the same size the dataset transforms
# resize to and the size used for every image widget below.
camera = CSICamera(width=224, height=224)
# Presumably starts jetcam's background capture so that camera.value keeps
# updating (the widgets and the live loop read camera.value) — TODO confirm.
camera.running = True

# Dataset Setup
# TASK is the prefix of every dataset name; CATEGORIES are the labelled point
# types (the model emits one x/y pair per category); DATASETS are the
# selectable dataset names.
TASK = 'road_following'
CATEGORIES = ['apex']
DATASETS = ['A', 'B']

# Per-sample preprocessing handed to XYDataset: colour jitter, resize to
# 224x224, conversion to a tensor, then per-channel normalization.
# The mean/std values are the ones commonly used with ImageNet-pretrained
# torchvision models.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# One XYDataset per entry of DATASETS, keyed by name. The first argument is
# TASK + '_' + name (e.g. 'road_following_A'); presumably a directory name —
# confirm in xy_dataset.
datasets = {name: XYDataset(TASK + '_' + name, CATEGORIES, TRANSFORMS, random_hflip=True) for name in DATASETS}
# The active dataset is a module-level global; it starts as the first entry.
dataset = datasets[DATASETS[0]]

# UI Setup
# Drop any observers already attached to the camera before linking, so that
# re-running this cell does not stack up links.
camera.unobserve_all()
# Clickable live view (clicks are used for labelling) and a plain image widget
# for the most recently saved snapshot, both sized like the camera frame.
camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)
snapshot_widget = ipywidgets.Image(width=camera.width, height=camera.height)
# One-way link: each new camera.value is passed through bgr8_to_jpeg and
# assigned to camera_widget.value.
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)

# Selectors for the active dataset and category, plus a count field that is
# initialised from the active dataset's count for the selected category.
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')
category_widget = ipywidgets.Dropdown(options=dataset.categories, description='category')
count_widget = ipywidgets.IntText(description='count', value=dataset.get_count(category_widget.value))

def set_dataset(change):
    """Make the dataset chosen in the dropdown the active one.

    Rebinds the module-level ``dataset`` to the entry of ``datasets`` named by
    ``change['new']`` and refreshes the count field for the category that is
    currently selected.
    """
    global dataset
    selected_name = change['new']
    dataset = datasets[selected_name]
    active_category = category_widget.value
    count_widget.value = dataset.get_count(active_category)


# Run the handler whenever the dropdown's selected value changes.
dataset_widget.observe(set_dataset, names='value')

def update_counts(change):
    """Show the active dataset's sample count for the newly selected category."""
    chosen_category = change['new']
    count_widget.value = dataset.get_count(chosen_category)


# Refresh the count whenever a different category is selected.
category_widget.observe(update_counts, names='value')

def save_snapshot(_, content, msg):
    """Save a labelled sample when the live camera image is clicked.

    Message handler for ``camera_widget``. On a ``'click'`` event the current
    camera frame is stored in the active dataset under the selected category
    together with the click offsets, an annotated copy is shown in
    ``snapshot_widget`` and the count field is refreshed. Other events are
    ignored.

    Args:
        _: The widget that sent the message (unused).
        content: Message payload; ``content['event']`` names the event and
            ``content['eventData']`` carries ``offsetX`` / ``offsetY``.
        msg: Extra message data supplied by the widget framework (unused).
    """
    if content['event'] != 'click':
        return

    event_data = content['eventData']
    x, y = event_data['offsetX'], event_data['offsetY']

    # Read the frame exactly once: camera.value is re-read by other code while
    # the camera is running, so two separate reads could return different
    # frames and the preview would not match the image that was saved.
    frame = camera.value

    dataset.save_entry(category_widget.value, frame, x, y)

    # Draw the marker on a copy so the frame object itself stays untouched.
    snapshot = cv2.circle(frame.copy(), (x, y), 8, (0, 255, 0), 3)
    snapshot_widget.value = bgr8_to_jpeg(snapshot)
    count_widget.value = dataset.get_count(category_widget.value)


camera_widget.on_msg(save_snapshot)

# Data-collection panel: live view and last snapshot side by side, with the
# dataset/category selectors and the sample count stacked underneath.
data_collection_widget = ipywidgets.VBox([
    ipywidgets.HBox([camera_widget, snapshot_widget]),
    dataset_widget, category_widget, count_widget
])

# Render the panel in this cell's output.
display(data_collection_widget)

# Model Setup
# All model work runs on the CUDA device.
device = torch.device('cuda')
# Two outputs (x, y) per category.
output_dim = 2 * len(dataset.categories)
# Pretrained ResNet-18 whose final fully-connected layer is replaced by a
# fresh linear layer mapping its 512 input features to the regression outputs.
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, output_dim)
model = model.to(device)

# Controls for saving/loading the model weights to/from a user-editable path.
model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(description='model path', value='road_following_model.pth')

def load_model(c):
    """Load model weights from the path typed into the model-path text box.

    Args:
        c: The button that was clicked (unused).
    """
    # map_location places the checkpoint tensors on the working device even if
    # the file was saved from a different one.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path_widget.value, map_location=device)
    model.load_state_dict(state_dict)


model_load_button.on_click(load_model)

def save_model(c):
    """Write the model's state dict to the path typed into the model-path box.

    Args:
        c: The button that was clicked (unused).
    """
    weights = model.state_dict()
    destination = model_path_widget.value
    torch.save(weights, destination)


model_save_button.on_click(save_model)

# Model panel: the path text box above a row holding the load/save buttons.
model_widget = ipywidgets.VBox([
    model_path_widget,
    ipywidgets.HBox([model_load_button, model_save_button])
])
# Render the panel in this cell's output.
display(model_widget)

# Live Prediction
# Two-state toggle ('stop' / 'live', initially 'stop') and a JPEG image widget,
# sized like the camera frame, that shows the frame annotated with the
# predicted point.
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')
prediction_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height)

def live(state_widget, model, camera, prediction_widget):
    """Continuously run the model on camera frames and display the prediction.

    Loops for as long as ``state_widget.value`` is ``'live'``. Each iteration
    reads the current camera frame, runs the model on the preprocessed frame,
    maps the (x, y) output pair of the selected category to pixel coordinates
    and shows the frame with a circle drawn at that point.

    Args:
        state_widget: Toggle whose value controls the loop.
        model: Network to run; its flattened output is read as consecutive
            (x, y) pairs, one per category.
        camera: Frame source; ``value``, ``width`` and ``height`` are used.
        prediction_widget: Image widget that receives the annotated JPEG.
    """
    # Inference must not run in training mode: BatchNorm layers would otherwise
    # update their running statistics on every live frame.
    model.eval()

    while state_widget.value == 'live':
        image = camera.value
        preprocessed = preprocess(image)

        # No gradients are needed for prediction, so skip autograd tracking.
        with torch.no_grad():
            output = model(preprocessed).cpu().numpy().flatten()

        category_index = dataset.categories.index(category_widget.value)

        # NOTE(review): the /5.0 + 0.55 and /6.0 + 0.4 mappings look like
        # hand-tuned calibration values — confirm they match how XYDataset
        # normalizes the stored coordinates.
        x = int(camera.width * (output[2 * category_index] / 5.0 + 0.55))
        y = int(camera.height * (output[2 * category_index + 1] / 6.0 + 0.4))

        # Annotate a copy so the camera's own frame is left unmodified.
        prediction = cv2.circle(image.copy(), (x, y), 8, (255, 0, 0), 3)
        prediction_widget.value = bgr8_to_jpeg(prediction)

def start_live(change):
    """Launch the live-prediction loop in a new thread when the toggle turns 'live'."""
    if change['new'] != 'live':
        return
    worker = threading.Thread(
        target=live,
        args=(state_widget, model, camera, prediction_widget),
    )
    worker.start()


# React to every change of the state toggle's value.
state_widget.observe(start_live, names='value')

# Live panel: the annotated prediction image above the stop/live toggle.
live_execution_widget = ipywidgets.VBox([prediction_widget, state_widget])
# Render the panel in this cell's output.
display(live_execution_widget)

# Training & Evaluation
# Mini-batch size used by the DataLoader built in train_eval.
BATCH_SIZE = 8
# Adam over all model parameters, with the optimizer's default settings.
optimizer = torch.optim.Adam(model.parameters())
# Controls and read-outs: an epochs input (default 10), evaluate/train buttons,
# the most recent loss, and a progress bar ranging from 0.0 to 1.0.
epochs_widget = ipywidgets.IntText(description='epochs', value=10)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')

def train_eval(is_training):
    """Train or evaluate the model on the active dataset.

    Live prediction is stopped and the train/evaluate buttons are disabled
    while this runs. Training performs ``epochs_widget.value`` passes over the
    dataset; evaluation performs a single pass without updating any weights.
    The progress bar and the loss field are updated after every batch.
    Afterwards the model is put back into eval mode, the buttons are
    re-enabled (even if an error occurred) and live prediction is resumed.

    Args:
        is_training: ``True`` to optimise the model, ``False`` to only compute
            the loss.
    """
    global model, dataset
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    state_widget.value = 'stop'
    train_button.disabled = eval_button.disabled = True
    try:
        # Give a running live-prediction thread time to notice the 'stop' state.
        time.sleep(1)
        model = model.train() if is_training else model.eval()

        # Evaluation is always one pass; training honours the epochs field.
        num_epochs = max(int(epochs_widget.value), 0) if is_training else 1
        for _ in range(num_epochs):
            # Start every pass with an empty progress bar.
            progress_widget.value = 0.0
            for images, category_idx, xy in train_loader:
                images, xy = images.to(device), xy.to(device)

                # Each sample has its own category, so pick that sample's
                # (x, y) output columns with gather. A plain slice cannot take
                # a batch-sized tensor as its bounds.
                # Assumes category_idx holds one integer index per sample —
                # TODO confirm against XYDataset.__getitem__.
                first_col = 2 * category_idx.to(device).long().view(-1, 1)
                columns = torch.cat([first_col, first_col + 1], dim=1)

                # Track gradients only when they are going to be used.
                with torch.set_grad_enabled(is_training):
                    outputs = model(images)
                    loss = torch.mean((outputs.gather(1, columns) - xy) ** 2)

                if is_training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                progress_widget.value += len(category_idx) / len(dataset)
                loss_widget.value = loss.item()
    finally:
        # Never leave the model in training mode for live prediction, and never
        # leave the UI locked after a failure.
        model = model.eval()
        train_button.disabled = eval_button.disabled = False
    state_widget.value = 'live'

# Wire the buttons: both call train_eval, differing only in the is_training
# flag; the button object passed to the callback is ignored.
train_button.on_click(lambda c: train_eval(True))
eval_button.on_click(lambda c: train_eval(False))

# Training panel: epochs input, progress bar and loss read-out stacked above a
# row holding the train/evaluate buttons.
train_eval_widget = ipywidgets.VBox([epochs_widget, progress_widget, loss_widget, ipywidgets.HBox([train_button, eval_button])])
display(train_eval_widget)

# Complete UI
# Combined layout: data collection next to live prediction, with the training
# and model panels underneath. The individual panels were each displayed
# already, so this renders the same widgets once more in a single view.
all_widget = ipywidgets.VBox([ipywidgets.HBox([data_collection_widget, live_execution_widget]), train_eval_widget, model_widget])
display(all_widget)


VBox(children=(HBox(children=(ClickableImageWidget(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x0…

VBox(children=(Text(value='road_following_model.pth', description='model path'), HBox(children=(Button(descrip…

VBox(children=(Image(value=b'', format='jpeg', height='224', width='224'), ToggleButtons(description='state', …

VBox(children=(IntText(value=10, description='epochs'), FloatProgress(value=0.0, description='progress', max=1…

VBox(children=(HBox(children=(VBox(children=(HBox(children=(ClickableImageWidget(value=b'\xff\xd8\xff\xe0\x00\…