Define dataset

In [1]:
import numpy
from jetcam.csi_camera import CSICamera
import torchvision.transforms as transforms
from xy_dataset import XYDataset

# Task name; joined with each dataset letter below to form the first
# argument given to XYDataset.
TASK = 'road_following'

# Categories handed to every XYDataset instance.
CATEGORIES = ['apex']

# One XYDataset is created per entry in this list.
DATASETS = ['A', 'B', 'C', 'D']

# Preprocessing pipeline passed to every dataset: colour jitter, resize to
# 224x224, conversion to a tensor, then per-channel normalisation.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Map each dataset letter to its XYDataset.
datasets = {}
for name in DATASETS:
    dataset_id = '_'.join((TASK, name))
    datasets[name] = XYDataset(dataset_id, CATEGORIES, TRANSFORMS, random_hflip=True)

Define model

In [2]:

import torch
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget
import torchvision
# Work with the fourth entry of DATASETS.
dataset = datasets[DATASETS[3]]
device = torch.device('cuda')

# x, y coordinate for each category
output_dim = len(dataset.categories) * 2

# ALEXNET (active choice): replace the last classifier layer so the network
# emits `output_dim` values.
model = torchvision.models.alexnet(pretrained=True)
model.classifier[-1] = torch.nn.Linear(in_features=4096, out_features=output_dim)

# Alternative backbones, kept for reference. Enable exactly one of them
# instead of the AlexNet lines above.

# SQUEEZENET
#model = torchvision.models.squeezenet1_1(pretrained=True)
#model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
#model.num_classes = len(dataset.categories)

# RESNET 18
#model = torchvision.models.resnet18(pretrained=True)
#model.fc = torch.nn.Linear(512, output_dim)

# RESNET 34
#model = torchvision.models.resnet34(pretrained=True)
#model.fc = torch.nn.Linear(512, output_dim)

# DENSENET 121
#model = torchvision.models.densenet121(pretrained=True)
#model.classifier = torch.nn.Linear(1024, output_dim)

# Move the selected model to the CUDA device.
model = model.to(device)

# Text fields: where model weights are saved/loaded, and the file path read
# by the training loop when it appends its per-epoch log.
model_path_widget = ipywidgets.Text(value='road_following_model.pth', description='model path')
file_widget = ipywidgets.Text(value='file.txt', description='file path')

# Buttons that trigger loading and saving of the model weights.
model_load_button = ipywidgets.Button(description='load model')
model_save_button = ipywidgets.Button(description='save model')

def load_model(c):
    """Load the state dict stored at the path in `model_path_widget` into `model`.

    `c` is the argument supplied by the button's click callback; it is unused.
    """
    checkpoint_path = model_path_widget.value
    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict)


# Run load_model whenever the "load model" button is clicked.
model_load_button.on_click(load_model)
    
def save_model(c):
    """Write the state dict of `model` to the path in `model_path_widget`.

    `c` is the argument supplied by the button's click callback; it is unused.
    """
    checkpoint_path = model_path_widget.value
    state_dict = model.state_dict()
    torch.save(state_dict, checkpoint_path)


# Run save_model whenever the "save model" button is clicked.
model_save_button.on_click(save_model)

# Layout: the two text fields stacked above a row with the load/save buttons.
model_widget = ipywidgets.VBox(children=[
    model_path_widget,
    file_widget,
    ipywidgets.HBox(children=[model_load_button, model_save_button]),
])

display(model_widget)

VBox(children=(Text(value='road_following_model.pth', description='model path'), Text(value='file.txt', descri…

Train model

In [3]:
import copy
import time
import pandas as pd
# Number of samples per DataLoader batch.
BATCH_SIZE = 30

# Adam with its default hyper-parameters; the SGD line is kept as an alternative.
optimizer = torch.optim.Adam(params=model.parameters())
#optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Controls and read-outs for the training loop.
train_button = ipywidgets.Button(description='train')
eval_button = ipywidgets.Button(description='evaluate')
epochs_widget = ipywidgets.IntText(value=1, description='epochs')
progress_widget = ipywidgets.FloatProgress(description='progress', min=0.0, max=1.0)
loss_widget = ipywidgets.FloatText(description='loss')




def train_eval(is_training):
    """Train `model` on `dataset`, or run one evaluation pass over it.

    When `is_training` is true, one pass over the data is made per remaining
    epoch; `epochs_widget` is decremented after each pass until it reaches 0,
    and one row (epoch number, elapsed seconds, mean loss) is recorded per
    epoch. When `is_training` is false, a single pass is made without
    gradient tracking or optimizer steps and nothing is recorded. In both
    modes no pass is made if `epochs_widget` is 0 or less.

    Side effects:
      * `progress_widget` and `loss_widget` are updated after every batch.
        The loss shown is the mean squared error per sample seen so far in
        the current pass.
      * The train/evaluate buttons are disabled while running and are always
        re-enabled afterwards, even if an exception is raised.
      * `model` is left in eval mode on exit.
      * If at least one epoch was recorded, the per-epoch table is appended
        (followed by a newline) to the file named by `file_widget`.

    Args:
        is_training: True to train, False to evaluate only.
    """
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    train_button.disabled = True
    eval_button.disabled = True
    # Short pause kept from the original implementation; presumably gives the
    # front end time to reflect the disabled buttons -- TODO confirm.
    time.sleep(1)

    total_epochs = epochs_widget.value
    records = []  # one dict per completed training epoch

    try:
        # Module.train(mode) switches between train and eval mode in place.
        model.train(is_training)

        while epochs_widget.value > 0:
            seen = 0          # samples processed in this pass
            sum_loss = 0.0    # sum of per-sample losses in this pass
            t_start = time.time()

            # Only track gradients when they are actually needed.
            with torch.set_grad_enabled(is_training):
                for images, category_idx, xy in loader:
                    # send data to device
                    images = images.to(device)
                    xy = xy.to(device)

                    if is_training:
                        # zero gradients of parameters
                        optimizer.zero_grad()

                    # execute model to get outputs
                    outputs = model(images)

                    # compute MSE loss over the x, y outputs that belong to
                    # each sample's category
                    cat_indices = category_idx.flatten().tolist()
                    count = len(cat_indices)
                    loss = 0.0
                    for batch_idx, cat_idx in enumerate(cat_indices):
                        predicted = outputs[batch_idx][2 * cat_idx:2 * cat_idx + 2]
                        loss += torch.mean((predicted - xy[batch_idx]) ** 2)
                    loss /= count

                    if is_training:
                        # run backpropagation to accumulate gradients
                        loss.backward()

                        # step optimizer to adjust parameters
                        optimizer.step()

                    # `loss` is a batch mean, so weight it by the batch size
                    # before accumulating; dividing by `seen` then yields a
                    # true per-sample mean that does not depend on BATCH_SIZE.
                    seen += count
                    sum_loss += float(loss) * count
                    progress_widget.value = seen / len(dataset)
                    loss_widget.value = sum_loss / seen

            if not is_training:
                # evaluation is a single pass and is not logged
                break

            epochs_widget.value = epochs_widget.value - 1
            records.append({
                "epoch": total_epochs - epochs_widget.value,
                "time": time.time() - t_start,
                # guard against a pass that yielded no samples
                "loss": sum_loss / seen if seen else float("nan"),
            })
    finally:
        # Always restore eval mode and the UI, even after an error.
        model.eval()
        train_button.disabled = False
        eval_button.disabled = False

    if records:
        log = pd.DataFrame(records, columns=["epoch", "time", "loss"])
        with open(file_widget.value, 'a') as f:
            # trailing newline keeps successive appends on separate lines
            f.write(log.to_string(index=False) + "\n")
    
# Wire the buttons: "train" runs training epochs, "evaluate" runs one
# evaluation pass. The callback argument supplied by the button is ignored.
train_button.on_click(lambda _button: train_eval(is_training=True))
eval_button.on_click(lambda _button: train_eval(is_training=False))

# Layout: epochs, progress and loss stacked above a row with both buttons.
train_eval_widget = ipywidgets.VBox(children=[
    epochs_widget,
    progress_widget,
    loss_widget,
    ipywidgets.HBox(children=[train_button, eval_button]),
])

display(train_eval_widget)

VBox(children=(IntText(value=1, description='epochs'), FloatProgress(value=0.0, description='progress', max=1.…

Convert model to TensorRT (TRT)

In [10]:
import torch
import torchvision

import os

from torch2trt import torch2trt

CATEGORIES = ['apex']
output_dim = 2 * len(CATEGORIES)  # x, y coordinate for each category

device = torch.device('cuda')

# The architecture built here must be the same one whose weights were saved
# to `model_path_widget.value`; otherwise load_state_dict() fails on
# mismatched keys. The training cell above trains AlexNet, so AlexNet is the
# active choice here. If you train a different backbone, switch both cells.

# ALEXNET
model = torchvision.models.alexnet(pretrained=False)
model.classifier[-1] = torch.nn.Linear(4096, output_dim)

# RESNET 18
#model = torchvision.models.resnet18(pretrained=False)
#model.fc = torch.nn.Linear(512, output_dim)

# SQUEEZENET
#model = torchvision.models.squeezenet1_1(pretrained=False)
#model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
#model.num_classes = len(CATEGORIES)

# RESNET 34
#model = torchvision.models.resnet34(pretrained=False)
#model.fc = torch.nn.Linear(512, output_dim)

# Load the trained weights, then switch to eval mode and half precision for
# the FP16 TensorRT conversion.
model = model.to(device).eval()
model.load_state_dict(torch.load(model_path_widget.value))
model = model.half()

# Example input handed to torch2trt for the conversion.
data = torch.zeros((1, 3, 224, 224)).cuda().half()

model_trt = torch2trt(model, [data], fp16_mode=True)

# Derive the output name from the model path by replacing its extension,
# whatever its length (e.g. 'road_following_model.pth' ->
# 'road_following_model_trt.pth').
trt_path = os.path.splitext(model_path_widget.value)[0] + '_trt.pth'
torch.save(model_trt.state_dict(), trt_path)