# Collision Avoidance (DNN)

## 0. Import modules

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim # for SGD
from torch.utils.data import random_split, DataLoader

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

import os # isdir, mkdir
import matplotlib.pyplot as plt
from time import localtime, strftime

## 1. Prepare the dataset

In [None]:
# Root directory of the ImageFolder dataset to train on.
# Swap which line is commented out to switch to the other recorded dataset.
# DATASET_PATH = "./datasets/dataset_white"
DATASET_PATH = "./datasets/dataset_blue"

# Every image is resized to IMAGE_HEIGHT x IMAGE_WIDTH and converted to
# IMAGE_CHANNEL channel(s) before it is fed to the network.
IMAGE_WIDTH, IMAGE_HEIGHT = 32, 32
IMAGE_CHANNEL = 1

In [None]:
# Preprocessing pipeline: resize -> grayscale -> tensor -> normalize -> flatten.
# The final flatten turns each image tensor into a 1-D vector so it can be fed
# straight into the fully connected DNN defined below.
_preprocess = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.Grayscale(num_output_channels=IMAGE_CHANNEL),
    transforms.ToTensor(),
    transforms.Normalize([0.449], [0.226]),
    transforms.Lambda(torch.flatten),  # https://stackoverflow.com/questions/60900406/
])

# ImageFolder derives class labels from the sub-directory names of DATASET_PATH.
total_dataset = ImageFolder(DATASET_PATH, _preprocess)

print(f"{len(total_dataset)} images have been loaded.")

In [None]:
SPLIT_RATIO = (0.8, 0.1, 0.1) # train : valid : test

total_data_num = len(total_dataset)

train_data_num = int(total_data_num * SPLIT_RATIO[0])
valid_data_num = int(total_data_num * SPLIT_RATIO[1])
model_data_num = train_data_num + valid_data_num

# The test split takes whatever is left. Computing it as int(total * ratio)
# can lose samples to truncation (e.g. 105 images -> 84 + 10 + 10 = 104), and
# random_split rejects lengths that do not sum to the dataset size.
test_data_num  = total_data_num - model_data_num

# Split off the test set first, then divide the remainder into train / valid.
model_dataset, test_dataset  = random_split(total_dataset, [model_data_num, test_data_num])
train_dataset, valid_dataset = random_split(model_dataset, [train_data_num, valid_data_num])

#-- Logger --#
print(f"Train Dataset: {len(train_dataset)} images.")
print(f"Validation Dataset: {len(valid_dataset)} images.")
print(f"Test Dataset: {len(test_dataset)} images.")
#-- Logger --#

In [None]:
BATCH_SIZE = 8

# Training batches are reshuffled every epoch. drop_last avoids a trailing
# batch of size 1, which BatchNorm1d rejects in training mode
# ("Expected more than 1 value per channel when training").
train_loader = DataLoader(
    train_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    drop_last = True,
    num_workers = 0
)

# Validation must iterate the held-out split (not train_dataset), otherwise
# the "validation accuracy" measures accuracy on training data. Sample order
# does not affect the accuracy metric, so no shuffling is needed.
valid_loader = DataLoader(
    valid_dataset,
    batch_size = BATCH_SIZE,
    shuffle = False,
    num_workers = 0
)

## 2. Define the model (DNN)

In [None]:
INPUT_SIZE = IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_CHANNEL

class DNN(nn.Module):

    """Fully connected classifier for flattened images.

    For every entry of ``hidden_dims`` the network stacks a ``Linear`` layer,
    optionally ``BatchNorm1d``, a ``ReLU`` and optionally ``Dropout``; a final
    ``Linear`` layer maps to ``output_dim`` class scores.

    Args:
        input_dim: Length of each flattened input vector.
        output_dim: Number of classes.
        hidden_dims: Widths of the hidden layers, in order.
        do_batch_normal: Insert ``BatchNorm1d`` after each hidden ``Linear``.
        dropout: Dropout probability after each hidden ReLU (0 disables it).

    Output of ``forward`` for a ``(batch, input_dim)`` input:
        * training mode: raw logits of shape ``(batch, output_dim)``. This is
          what ``nn.CrossEntropyLoss`` expects; it applies log-softmax itself,
          so feeding it softmax probabilities would apply softmax twice and
          weaken the gradients.
        * eval mode: softmax probabilities over ``dim=1``.
        ``argmax(1)`` is the same in both modes.
    """

    def __init__(self, input_dim=INPUT_SIZE, output_dim=2, hidden_dims=(128, 64, 32), do_batch_normal=True, dropout=0):

        super().__init__()

        dims = (input_dim, *hidden_dims)
        layers = []

        # hidden layers: Linear -> [BatchNorm1d] -> ReLU -> [Dropout]
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))

            if do_batch_normal:
                layers.append(nn.BatchNorm1d(out_dim))

            layers.append(nn.ReLU())

            if dropout > 0:
                layers.append(nn.Dropout(dropout))

        # output layer: produces logits; softmax is applied in forward() in eval mode only
        layers.append(nn.Linear(dims[-1], output_dim))

        # The attribute name (mangled to ``_DNN__model``) is unchanged so that
        # state_dict keys stay compatible with previously saved checkpoints;
        # the removed trailing Softmax had no parameters, so no key shifts.
        self.__model = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.__model(x)
        if self.training:
            return logits
        return torch.softmax(logits, dim=1)

## 3. Train the model

In [None]:
model = DNN(hidden_dims=(128, 64, 32))

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')

if has_cuda:
    print("This environment supports the CUDA.") # Logger
else:
    print("This environment does not support the CUDA.") # Logger
    print("The model will be running on the CPU instead.") # Logger

model = model.to(device)

In [None]:
# The best checkpoint of each run goes to its own timestamped file.
os.makedirs("./best_models", exist_ok=True)

CURRENT_TIME = strftime('%Y%m%d_%H%M%S', localtime())
BEST_MODEL_PATH = f"./best_models/best_model_dnn_{CURRENT_TIME}.pth"

# --- hyper parameters ---
EPOCHS = 30
LEARNING_RATE = 0.001
MOMENTUM = 0.9
L2_CONST = 1e-4  # passed to SGD as weight_decay (L2 regularization)

best_accuracy = 0.0  # best validation accuracy seen so far

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=L2_CONST,
)

accuracy_history = []  # one validation accuracy per epoch, used for plotting

# Width of the epoch counter in log lines (e.g. 2 when EPOCHS == 30).
EPOCH_DIGIT = len(str(EPOCHS))

In [None]:
# model training loop
for epoch in range(EPOCHS):
    
    model.train()
    
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
    valid_error = 0.0
    for images, labels in iter(valid_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        valid_error += float(torch.sum(torch.abs(labels - outputs.argmax(1))))
        
    valid_accuracy = 1.0 - float(valid_error) / float(valid_data_num)
    
    if valid_accuracy < 0:
        valid_accuracy = 0
    
    accuracy_history.append(valid_accuracy)
    
    print(f"[Epoch {epoch: >{EPOCH_DIGIT}d}] Accuracy: {valid_accuracy: .5f}") # Logger
    
    if valid_accuracy > best_accuracy:
        print("\tSave the best model") # Logger
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = valid_accuracy

print("Training Complete!") # Logger
print(f"Best validation accuracy: {best_accuracy: .5f}") # Logger

In [None]:
# Save the per-epoch validation accuracy curve next to the other run artifacts.
os.makedirs("./plots", exist_ok=True)

PLOT_PATH = f"./plots/validation_accuracy_plot_dnn_{CURRENT_TIME}.png"

title = "Validation Accuracy Plot (DNN)"
subtitle = f"(lr={LEARNING_RATE}, momentum={MOMENTUM}, l2_constant={L2_CONST})"

plt.plot(accuracy_history)
plt.suptitle(title)
plt.title(subtitle)
plt.xlabel("Epoch")
plt.ylabel("Validation accuracy")

# Save before show(): show() may clear the current figure in some backends.
plt.savefig(PLOT_PATH)
plt.show()

## 4. Test the model

In [None]:
# One image per batch so that every test case can be reported individually.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0)

In [None]:
correct_case_count = 0
evaluated_case_count = 0

# eval(): BatchNorm uses its running statistics and Dropout is disabled.
model.eval()

# Inference only: no_grad() avoids building an autograd graph per test image.
with torch.no_grad():
    for case, (image, label) in enumerate(test_loader):
        image = image.to(device)
        label = int(label)  # batch_size=1, so the batch holds a single label
        predict = model(image).flatten()  # [Blocked score, Free score]
        score_blocked = float(predict[0])
        score_free = float(predict[1])
        evaluated_case_count += 1

        #-- Logger --#
        print(f"[Test Case {case}]")
        print(f"\t[Prediction] {score_blocked: .5f} : {score_free: .5f}")
        print(f"\t[Real output] {label}") # 0: Blocked, 1: Free
        #-- Logger --#

        # A tie between the two scores is counted as incorrect.
        is_correct = (label == 1 and score_blocked < score_free) or \
                     (label == 0 and score_blocked > score_free)
        if is_correct:
            correct_case_count += 1
            print("\t[Result] Correct") # Logger
        else:
            print("\t[Result] Incorrect") # Logger

# Divide by the cases actually evaluated, guarding against an empty test split.
test_accuracy = correct_case_count / evaluated_case_count if evaluated_case_count else 0.0
print(f"[Total Test Accuracy] {test_accuracy : .5f}")