# Collision Avoidance (ResNet18)

Reference: [NVIDIA JetBot GitHub Repository][JETBOT_GITHUB_LINK]

[JETBOT_GITHUB_LINK]: https://github.com/NVIDIA-AI-IOT/jetbot/tree/master/notebooks/collision_avoidance

## 0. Import modules

In [None]:
import os
import time

import matplotlib.pyplot as plt

import torch
import torch.optim as optim # for SGD
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

import torchvision
import torchvision.models as models # for resnet18
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

## 1. Prepare the dataset

In [None]:
# Root folder of the image dataset; swap in the commented path below to use
# the alternative dataset instead.
DATASET_PATH = "./datasets/dataset_white"
# DATASET_PATH = "./datasets/dataset_black"

# Per-channel normalization constants, taken from the reference notebook.
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD  = (0.229, 0.224, 0.225)

# Size every image is resized to before being turned into a tensor.
IMAGE_HEIGHT = 224
IMAGE_WIDTH  = 224

# Preprocessing applied to every image, in order:
# color jitter -> resize -> tensor conversion -> normalization.
image_transform = transforms.Compose([
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# The full dataset; it is split into train / validation / test subsets later.
total_dataset = ImageFolder(DATASET_PATH, image_transform)

print(f"{len(total_dataset)} images have been loaded.") # Logger

In [None]:
# Split the full dataset into train / validation / test subsets.
SPLIT_RATIO = (0.8, 0.1, 0.1) # train : valid : test

# The test subset receives whatever is left after train and validation, so the
# three ratios are expected to describe the whole dataset.
if abs(sum(SPLIT_RATIO) - 1.0) > 1e-6:
    raise ValueError(f"SPLIT_RATIO must sum to 1.0, got {SPLIT_RATIO}")

total_data_num = len(total_dataset)

train_data_num = int(total_data_num * SPLIT_RATIO[0])
valid_data_num = int(total_data_num * SPLIT_RATIO[1])
model_data_num = train_data_num + valid_data_num

# int() truncates, so computing the test size from its own ratio can leave a
# few images unassigned (e.g. 101 images -> 80 + 10 + 10 = 100), and
# random_split rejects lengths that do not sum to the dataset size.
# Giving the remainder to the test subset keeps the sizes exact.
test_data_num  = total_data_num - model_data_num

# First carve off the test subset, then split the rest into train / validation.
model_dataset, test_dataset  = random_split(total_dataset, (model_data_num, test_data_num))
train_dataset, valid_dataset = random_split(model_dataset, (train_data_num, valid_data_num))

#-- Logger --#
print(f"Train Dataset: {len(train_dataset)} images.") # print(train_data_num)
print(f"Validation Dataset: {len(valid_dataset)} images.") # print(valid_data_num)
print(f"Test Dataset: {len(test_dataset)} images.") # print(test_data_num)
#-- Logger --#

In [None]:
# Number of images per mini-batch for both training and validation.
BATCH_SIZE = 8

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# Validation must iterate over the held-out validation subset, not the
# training subset; otherwise the reported "validation accuracy" is measured
# on data the model was trained on. Order does not affect accuracy, so
# shuffling is disabled.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

## 2. Define the model (ResNet18)

In [None]:
# Build ResNet18 with pretrained weights and replace its final fully connected
# layer with a 2-output head.
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, 2) # 2 for 'blocked' and 'free'

# Pick the device once: CUDA when available, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

if cuda_available:
    print("This environment supports the CUDA.") # Logger
else:
    print("This environment does not support the CUDA.") # Logger
    print("The model will run on the CPU instead.") # Logger

# Move the model parameters to the selected device.
model = model.to(device)

# print(model) 

## 3. Train the model

In [None]:
# Timestamp used to give this run's output files unique names.
CURRENT_TIME = time.strftime('%Y%m%d_%I%M%S%p', time.localtime())
BEST_MODEL_PATH = f"./best_models/best_model_resnet18_{CURRENT_TIME}.pth"

# Make sure the checkpoint directory exists; otherwise torch.save() fails
# the first time a best model is written.
os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)

# hyper parameters
NUM_EPOCHS = 30
LEARNING_RATE = 0.001
MOMENTUM = 0.9
L2_CONST = 1e-5

best_accuracy = 0.0 # validation accuracy

# SGD optimizer with L2 regularization
optimizer = optim.SGD(model.parameters(),
                      lr=LEARNING_RATE,
                      momentum=MOMENTUM,
                      weight_decay=L2_CONST)

# Validation accuracy of every epoch, in order (used for plotting afterwards).
accuracy_history = []


# model training loop
for epoch in range(NUM_EPOCHS):

    # Training phase: switch back to training mode, because the validation
    # phase of the previous epoch put the model into evaluation mode.
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation phase: evaluation mode makes BatchNorm use its running
    # statistics instead of updating them with validation batches, and
    # no_grad() avoids building autograd graphs that are never used.
    model.eval()
    correct_count = 0
    sample_count = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            predictions = model(images).argmax(1)
            correct_count += int((predictions == labels).sum())
            sample_count += labels.size(0)

    # Normalize by the number of samples actually evaluated so the accuracy
    # is always a valid fraction, whatever the loader iterates over.
    valid_accuracy = correct_count / sample_count if sample_count else 0.0
    accuracy_history.append(valid_accuracy)

    print(f"[Epoch {epoch}] Validation accuracy: {valid_accuracy: .5f}") # Logger

    # Keep only the weights of the best-performing epoch on disk.
    if valid_accuracy > best_accuracy:
        print("\tSave the best model.") # Logger
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = valid_accuracy

print("Training Complete!") # Logger
print(f"Best validation accuracy: {best_accuracy}") # Logger

In [None]:
# Output file for the accuracy curve of this training run.
# NOTE(review): the file name says "restnet18" (probably meant "resnet18");
# kept unchanged so that existing paths keep matching.
PLOT_PATH = f"./plots/validation_accuracy_plot_restnet18_{CURRENT_TIME}.png"

# Make sure the plot directory exists; otherwise savefig() fails.
os.makedirs(os.path.dirname(PLOT_PATH), exist_ok=True)

title = "Validation Accuracy Plot (ResNet18)"
subtitle = f"(lr={LEARNING_RATE}, momentum={MOMENTUM})"

# One point per epoch: x is the epoch index, y the validation accuracy.
plt.plot(accuracy_history)
plt.suptitle(title)
plt.title(subtitle)
plt.xlabel("Epoch")
plt.ylabel("Validation accuracy")

# Save before show(): the figure is written to disk first, then displayed.
plt.savefig(PLOT_PATH)
plt.show()

## 4. Test the model

In [None]:
# Rebuild the same architecture (ResNet18 with a 2-output head) without
# pretrained weights; the trained weights are restored from the checkpoint.
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 2)

print(f"Load model from \"{BEST_MODEL_PATH}\".") # Logger

# Restore the best weights saved during training.
model.load_state_dict(torch.load(BEST_MODEL_PATH))

# Move to the selected device, switch to evaluation mode and convert the
# weights to half precision.
# NOTE(review): half precision may not be supported for every operation on
# the CPU fallback path, depending on the PyTorch version - confirm.
model = model.to(device).eval().half()

In [None]:
# Loader for the held-out test subset: one image per batch, shuffled,
# loaded in the main process (no worker processes).
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0)

In [None]:
# Evaluate the loaded model on the test subset, one image at a time, and
# report a per-case result plus the overall accuracy.
correct_case_count = 0
evaluated_case_count = 0

# Feed inputs in the same precision as the model weights, so this cell keeps
# working if the model's precision is changed where it is loaded.
model_dtype = next(model.parameters()).dtype

# Inference only: no_grad() avoids building autograd graphs.
with torch.no_grad():
    for case, (image, label) in enumerate(test_loader):
        image = image.to(device=device, dtype=model_dtype)
        label = int(label)

        # Class probabilities of the single image in this batch.
        predict = F.softmax(model(image), dim=1)[0]
        prob_class0 = float(predict[0])
        prob_class1 = float(predict[1])

        #-- Logger --#
        print(f"[Test Case {case}]")
        print(f"\t[Prediction] {prob_class0 : .5f} : {prob_class1 : .5f}")
        print(f"\t[Real output] {label}")

        # A case is correct only when the labelled class has the strictly
        # larger probability; an exact tie counts as incorrect.
        is_correct = (
            (label == 1 and prob_class0 < prob_class1)
            or (label == 0 and prob_class0 > prob_class1)
        )
        if is_correct:
            correct_case_count += 1
            print(f"\t[Result] Correct")
        else:
            print(f"\t[Result] Incorrect")
        #-- Logger --#

        evaluated_case_count += 1

# Divide by the number of cases actually evaluated; report NaN instead of
# raising ZeroDivisionError when the test subset is empty.
if evaluated_case_count:
    test_accuracy = correct_case_count / evaluated_case_count
else:
    test_accuracy = float('nan')

print(f"[Total Test Accuracy] {test_accuracy : .5f}")