In [1]:
import cv2
import os
from os.path import isfile, join, exists
from IPython.display import display, HTML
import PIL
import shutil
import itertools
from datetime import datetime

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import optim
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.functional import mse_loss
import torchvision.models as models
from torchsummary import summary

from model_src.data import JetBotDataset

import warnings
# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
if 'google.colab' in str(get_ipython()):
  # Colab ships its own notebook-friendly replacement for cv2.imshow.
  from google.colab.patches import cv2_imshow
  imshow = cv2_imshow
else:
  def imshow(a):
    """Render an OpenCV-style image array inline in the notebook.

    Values are clipped to [0, 255] and cast to uint8. Arrays with a
    channel axis are converted from OpenCV's BGR / BGRA channel order to
    RGB / RGBA before being handed to PIL; 2-D arrays (and arrays with a
    single trailing channel) are displayed as grayscale.

    Args:
      a: numpy array of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).
    """
    # `import PIL` on its own does not load the `PIL.Image` submodule, so
    # import it explicitly instead of relying on another library having
    # imported it first.
    from PIL import Image

    a = a.clip(0, 255).astype('uint8')
    if a.ndim == 3:
      channels = a.shape[2]
      if channels == 1:
        # Neither cvtColor(BGR2RGB) nor Image.fromarray accepts (H, W, 1);
        # drop the channel axis and show the image as grayscale.
        a = a[:, :, 0]
      elif channels == 4:
        a = cv2.cvtColor(a, cv2.COLOR_BGRA2RGBA)
      else:
        a = cv2.cvtColor(a, cv2.COLOR_BGR2RGB)
    display(Image.fromarray(a))

In [3]:
# Both splits are read from the same directory; the test split is selected
# explicitly, the training one relies on JetBotDataset's default split_type.
DATASET_ROOT = 'dataset/augmented'
train_dataset = JetBotDataset(DATASET_ROOT, use_next=True)
test_dataset = JetBotDataset(DATASET_ROOT, split_type='test', use_next=True)

In [4]:
BATCH_SIZE = 128

# Training batches are reshuffled on every pass over the data.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation data is served in a fixed order. Shuffling it only changes which
# samples end up in the (possibly smaller) final batch, which makes any
# mean-of-batch-means metric fluctuate from one evaluation to the next.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
class SqueezeNetBasedNetwork(nn.Module):
    """Regression network built on top of torchvision's SqueezeNet 1.1.

    The backbone is created with ``pretrained=True`` and its 1000-value
    output is passed through a two-layer linear head (1000 -> 64 ->
    ``output_attr``). There is no activation between the two linear layers,
    so the head as a whole is a single affine map.

    Args:
        output_attr: number of values predicted per input sample.
    """

    def __init__(self, output_attr: int):
        super(SqueezeNetBasedNetwork, self).__init__()
        self.squeezenet = models.squeezenet1_1(pretrained=True)
        # Layers are created in this order and registered under indices
        # 0 and 1 of the Sequential container.
        hidden_layer = nn.Linear(1000, 64)
        output_layer = nn.Linear(64, output_attr)
        self.regressor = nn.Sequential(hidden_layer, output_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and the regression head on a batch.

        Args:
            x: input batch accepted by the SqueezeNet backbone
               (presumably shaped (N, 3, H, W) -- TODO confirm).

        Returns:
            The head's output with every dimension after the first
            flattened into one.
        """
        backbone_out = self.squeezenet(x)
        predictions = self.regressor(backbone_out)
        return predictions.flatten(1)

In [6]:
# Build the network with two outputs and move its parameters to the
# selected device (Module.to returns the module itself).
model = SqueezeNetBasedNetwork(2).to(device)

In [7]:
NUM_EPOCHS = 50
# Timestamp of this run; used as a prefix for every per-epoch checkpoint.
time = str(datetime.now()).replace(' ', '-').replace(':', '-')
BEST_MODEL_PATH = 'model/best_model.pth'
models_path = 'model/archive'
best_loss = torch.inf

# torch.save does not create missing directories, so make sure both
# checkpoint locations exist before the first epoch finishes.
os.makedirs(models_path, exist_ok=True)
os.makedirs(os.path.dirname(BEST_MODEL_PATH) or '.', exist_ok=True)

optimizer = optim.Adam(model.parameters())

for epoch in range(NUM_EPOCHS):

    # ---- training pass ----
    model.train()
    train_loss = 0.0
    for imgs, controls in train_dataloader:
        imgs = imgs.to(device)
        controls = controls.to(device)
        optimizer.zero_grad()
        # Cast predictions to float64 before computing the loss
        # (assumes the dataset yields float64 targets -- TODO confirm).
        outputs = model(imgs).double()
        loss = mse_loss(outputs, controls)
        loss.backward()
        optimizer.step()
        # Accumulate a plain Python float: adding the tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        train_loss += loss.item()
    # Mean of the per-batch mean losses.
    train_loss /= len(train_dataloader)

    # ---- evaluation pass ----
    model.eval()
    test_loss = 0.0
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            # Same dtype handling as in the training pass.
            outputs = model(images).double()
            loss = mse_loss(outputs, labels)
            test_loss += loss.item()
    test_loss /= len(test_dataloader)

    print('%f, %f' % (train_loss, test_loss))

    # Archive the weights of every epoch ...
    file_name = f'{time}-epoch-{epoch}.pt'
    file_path = models_path + '/' + file_name
    torch.save(model.state_dict(), file_path)
    # ... and separately keep the weights with the lowest test loss so far.
    if test_loss < best_loss:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_loss = test_loss

0.285021, 0.184251
0.336539, 0.184323
0.237402, 0.183777
0.237465, 0.183903
0.237442, 0.183615
0.237443, 0.183265
0.237554, 0.184957
0.237553, 0.180383
0.237601, 0.183107
0.237581, 0.183490
0.237622, 0.184459
0.237566, 0.194764
0.237658, 0.183245
0.237631, 0.185591
0.237656, 0.178149
0.237630, 0.183038
0.237585, 0.179051
0.237558, 0.181218
0.237606, 0.181286
0.237625, 0.178296
0.237493, 0.186531
0.237541, 0.182701
0.237486, 0.185194
0.237518, 0.185160
0.237497, 0.184009
0.237578, 0.187233
0.237640, 0.188868
0.237434, 0.185964
0.237559, 0.179526
0.237590, 0.182321
0.237526, 0.184730
0.237566, 0.185516
0.237521, 0.180065
0.237509, 0.180035
0.237445, 0.184937
0.237487, 0.180916


KeyboardInterrupt: 