### Path Setup

In [None]:
# Edit these paths before running the notebook.

# Directory handed to every Trainer as its save directory (see the training loop).
MODEL_SAVE_PATH = "models"

# Training data sources: add TFRecord files and/or dataset directories to this list.
DATA = []

### Includes

In [None]:
import pathlib

# torch has to be imported at the very beginning, before the project modules below (TODO: document why)
import torch
import numpy as np
from dataloader import FastDataLoader
from device import device
from matplotlib import pyplot as plt
from models import get_model
from numpy import random
from sampler import load_full_dataset
from scipy import signal
from torch.optim import Adam
from torch.utils.data import DataLoader

### Initialize Data Sampler

In [None]:
batch_size = 256  # mini-batch size for the FastDataLoaders built in the training loop
# Load every source listed in DATA and split it into train/test sets
# (train_test_split=0.8 is presumably the training fraction — TODO confirm).
trainset, testset = load_full_dataset(DATA, train_test_split=0.8)

In [None]:
# Sanity-check the input pipeline: show 12 random training samples titled
# "steering, throttle".
plt.style.use("classic")
fig, axes = plt.subplots(2, 6, figsize=(25, 8))
axes = axes.flatten()
for i, ax in enumerate(axes):
    # numpy's randint excludes the upper bound, so the index is always valid.
    sample_idx = random.randint(0, len(trainset))
    img, steering, throttle = trainset[sample_idx]
    ax.imshow(img.permute(1, 2, 0).cpu().numpy())
    ax.axis("off")
    ax.set_title(f"{steering: .4f}, {throttle: .4f}")

In [None]:
# Training helpers used by the cells below (loss_fn is used in the Test section;
# NOTE(review): angle_metric and direction_metric are not referenced in the cells shown here).
from metrics import angle_metric, direction_metric, loss_fn
from train import Trainer

import wandb

# Start a Weights & Biases run with default settings (no project/config is passed).
# Presumably Trainer logs its metrics to this run — TODO confirm.
wandb.init()

### Train

In [None]:
trainers = dict()  # model name -> Trainer; reused when the training cell is re-run

In [None]:
# Models to train, in order; each name is resolved by `models.get_model`.
all_models = [
    # Pytorch Hub
    # "alexnet",
    # "vgg16_bn",
    "resnet34",
    "googlenet",
    # Custom
    "cnn",
]

save_dir = pathlib.Path(MODEL_SAVE_PATH)
# When True, newly created trainers call `Trainer.load()` before training
# (presumably resuming from a checkpoint in `save_dir` — TODO confirm).
load_trainer = True


def _move_trainer_to(tr, target):
    """Move ``tr.model`` to ``target`` and re-home the optimizer state with it.

    ``Module.to`` moves parameters in place, so the optimizer still references the
    same parameter objects; reloading the optimizer's own ``state_dict`` makes PyTorch
    cast its per-parameter state tensors (e.g. Adam moments) to each parameter's device.
    """
    tr.model = tr.model.to(target)
    tr.optim.load_state_dict(tr.optim.state_dict())


def _close_loader(loader):
    """Best-effort shutdown of a data loader; ``None`` (never created) is skipped."""
    close = getattr(loader, "close", None)
    if close is None:
        return
    try:
        close()
    except Exception as exc:  # stay best-effort, but don't hide the failure silently
        print(f"Could not close data loader: {exc!r}")


for model_name in all_models:
    print(f"Training {model_name}")

    if model_name in trainers:
        trainer = trainers[model_name]
        # Move the model (and optimizer state) back onto the training device.
        _move_trainer_to(trainer, device)
    else:
        model = get_model(model_name)().to(device)
        optimizer = Adam(model.parameters(), lr=1e-4)
        trainer = Trainer(save_dir, model, optimizer, turning_weight=5, epochs=1000)
        if load_trainer:
            print("Loading trainer")
            trainer.load()

        trainers[model_name] = trainer
        del model, optimizer

    # `trainer.i` is presumably the number of completed epochs — TODO confirm.
    if trainer.i >= trainer.epochs:
        # Nothing left to train. Still offload it; otherwise every finished model
        # would stay resident on the device while the next one trains.
        _move_trainer_to(trainer, "cpu")
        continue

    # Pre-bind so the cleanup knows which loaders were actually created.
    sampler_train = sampler_test = None
    try:
        sampler_train = FastDataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
        sampler_test = FastDataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4)
        trainer.train(sampler_train, sampler_test)
    finally:
        try:
            # Offload to CPU and checkpoint, even if training was interrupted.
            _move_trainer_to(trainer, "cpu")
            trainer.save()
        finally:
            # Close each loader independently so one failure cannot leak the other's workers.
            _close_loader(sampler_test)
            _close_loader(sampler_train)

In [None]:
# Re-execute models.py in place so edits made since the kernel started take effect,
# then rebind get_model to the freshly defined factory.
import importlib

import models

models = importlib.reload(models)
get_model = models.get_model

In [None]:
# Clean cache: drop notebook globals that may keep models/tensors alive, then let
# Python and PyTorch release the memory. (`trainers` is intentionally kept, so the
# models it holds stay alive.)
import gc
import re

g = globals()
# IPython stores cell N's output as `_N` and its source as `_iN`; match any N rather
# than assuming fewer than 1000 executions.
history_name = re.compile(r"_i?\d+")
del_list = ["model", "trainer", "optimizer"] + [name for name in list(g) if history_name.fullmatch(name)]
for i in del_list:
    if i in g:
        del g[i]

# The `_N` names are aliases into IPython's `Out` dict, so deleting them alone frees
# nothing; clear that dict too (no-op outside IPython, where `Out` does not exist).
output_history = g.get("Out")
if isinstance(output_history, dict):
    output_history.clear()

gc.collect()
gc.collect()
torch.cuda.empty_cache()

### Test

In [None]:
# Small shuffled loader over the test split: each `next(test_iterator)` yields a batch of 12.
test_sampler = DataLoader(testset, shuffle=True, batch_size=12)
test_iterator = test_sampler.__iter__()

In [None]:
# Inspect the custom "cnn" model: put it on the compute device and turn its metric
# logs into arrays. Per the plots below, column 0 is the loss, 1 the angle metric
# and 2 the direction metric.
trainer = trainers["cnn"]
model = trainer.model.to(device)
train_log, validation_log = (np.array(log) for log in (trainer.train_log, trainer.validation_log))

In [None]:
# Draw the next shuffled test batch; once the loader is exhausted, start a new pass
# so this cell can be re-run indefinitely.
try:
    img, steering, throttle = next(test_iterator)
except StopIteration:
    test_iterator = iter(test_sampler)
    img, steering, throttle = next(test_iterator)

Y = torch.stack([steering, throttle], dim=1).type(torch.float32).to(device)
# NOTE(review): pixels are scaled by 1/256 here; confirm this matches the training pipeline.
X = img.to(device) / 256

# Predict in eval mode so layers such as dropout / batch-norm behave as at inference,
# then restore the model's previous mode so later training is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        Y_pred = model(X)
finally:
    model.train(was_training)

val_loss = loss_fn(Y[:, 0], Y[:, 1], Y_pred[:, 0], Y_pred[:, 1], throttle_weight=0.2)
print(val_loss)


def p_fn(values):
    """Format a sequence of numbers as a comma-separated string with 4 decimals each."""
    return ','.join([f'{v: .4f}' for v in values])


plt.style.use("classic")
fig, axes = plt.subplots(2, 6, figsize=(40, 8))
axes = axes.flatten()

# The last batch of a pass may hold fewer than 12 samples; leave spare panels empty.
for i, ax in enumerate(axes):
    ax.axis("off")
    if i >= len(img):
        continue
    ax.imshow(img[i].permute(1, 2, 0).cpu().numpy())
    ax.set_title(f"Model:{p_fn(Y_pred[i].tolist())}\nTrue: {p_fn([steering[i], throttle[i]])}")

### Plot Loss

In [None]:
# Loss (column 0 of both logs). Left: raw training values; middle: validation values
# (both log-scaled); right: moving average of the training values.
plt.style.use("classic")
# Shrink the averaging window for short logs: in 'valid' mode scipy swaps the inputs
# when the data is shorter than the kernel, which would produce a meaningless curve.
smooth_window = max(1, min(100, len(train_log)))
fig, (ax_train, ax_val, ax_smooth) = plt.subplots(1, 3, figsize=(25, 4))

ax_train.plot(train_log[:, 0], '.', markersize=1, color="black")
ax_train.set_yscale("log")
ax_train.set(title="Train loss", xlabel="train log entry")

ax_val.plot(validation_log[:, 0], '-', color="black")
ax_val.set_yscale("log")
ax_val.set(title="Validation loss", xlabel="validation log entry")

ax_smooth.plot(
    signal.convolve(train_log[:, 0], np.ones(smooth_window) / smooth_window, 'valid'),
    '.', markersize=1, color="red",
)
ax_smooth.set(title=f"Train loss, moving average (window={smooth_window})", xlabel="train log entry")
plt.show()

### Plot Angle Metric

In [None]:
# Angle metric (column 1 of both logs). Left: raw training values; middle: validation
# values; right: moving average of the training values.
plt.style.use("classic")
# Shrink the averaging window for short logs: in 'valid' mode scipy swaps the inputs
# when the data is shorter than the kernel, which would produce a meaningless curve.
smooth_window = max(1, min(100, len(train_log)))
fig, (ax_train, ax_val, ax_smooth) = plt.subplots(1, 3, figsize=(25, 4))

ax_train.plot(train_log[:, 1], '.', markersize=1, color="black")
ax_train.set(title="Train angle metric", xlabel="train log entry")

ax_val.plot(validation_log[:, 1], '-', color="black")
ax_val.set(title="Validation angle metric", xlabel="validation log entry")

ax_smooth.plot(
    signal.convolve(train_log[:, 1], np.ones(smooth_window) / smooth_window, 'valid'),
    '.', markersize=1, color="red",
)
ax_smooth.set(title=f"Train angle metric, moving average (window={smooth_window})", xlabel="train log entry")
plt.show()

### Plot Direction Metric

In [None]:
# Direction metric (column 2 of both logs). Left: raw training values; middle:
# validation values; right: moving average of the training values.
plt.style.use("classic")
# Shrink the averaging window for short logs: in 'valid' mode scipy swaps the inputs
# when the data is shorter than the kernel, which would produce a meaningless curve.
smooth_window = max(1, min(100, len(train_log)))
fig, (ax_train, ax_val, ax_smooth) = plt.subplots(1, 3, figsize=(25, 4))

ax_train.plot(train_log[:, 2], '.', markersize=1, color="black")
ax_train.set(title="Train direction metric", xlabel="train log entry")

ax_val.plot(validation_log[:, 2], '-', color="black")
ax_val.set(title="Validation direction metric", xlabel="validation log entry")

ax_smooth.plot(
    signal.convolve(train_log[:, 2], np.ones(smooth_window) / smooth_window, 'valid'),
    '.', markersize=1, color="red",
)
ax_smooth.set(title=f"Train direction metric, moving average (window={smooth_window})", xlabel="train log entry")
plt.show()