# Collision Avoidance - Train Model

In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch import nn
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import os
import cv2

In [None]:
#from google.colab import drive
#drive.mount('/gdrive')
#%cd /gdrive/MyDrive/tello/

In [3]:
# Convert every "free" image into a Canny edge map. The edge maps written to
# edges/free_edges are what the ImageFolder dataset below is built from.
# cv2.imwrite does not create missing directories (it just returns False), so
# make sure the output folder exists before writing.
os.makedirs("edges/free_edges", exist_ok=True)
for name in os.listdir("data/free"):
    if "png" not in name:
        continue
    img = cv2.imread("data/free/"+name)
    if img is None:
        # cv2.imread returns None instead of raising when a file cannot be
        # decoded; skip it rather than crash inside cvtColor.
        print("Skipping unreadable image: data/free/"+name)
        continue
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Light blur before edge detection to suppress pixel noise.
    img_blurred = cv2.GaussianBlur(img_gray, (3, 3), 5)
    canny = cv2.Canny(img_blurred, 20, 160)
    cv2.imwrite("edges/free_edges/"+name, canny)

In [4]:
# Convert every "blocked" image into a Canny edge map, using exactly the same
# grayscale -> blur -> Canny pipeline as the "free" images so both classes are
# preprocessed identically.
# cv2.imwrite does not create missing directories (it just returns False), so
# make sure the output folder exists before writing.
os.makedirs("edges/blocked_edges", exist_ok=True)
for name in os.listdir("data/blocked"):
    if "png" not in name:
        continue
    img = cv2.imread("data/blocked/"+name)
    if img is None:
        # cv2.imread returns None instead of raising when a file cannot be
        # decoded; skip it rather than crash inside cvtColor.
        print("Skipping unreadable image: data/blocked/"+name)
        continue
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Blur the current grayscale frame (not a leftover result from an earlier
    # image) and run Canny on the blurred image.
    img_blurred = cv2.GaussianBlur(img_gray, (3, 3), 5)
    canny = cv2.Canny(img_blurred, 20, 160)
    cv2.imwrite("edges/blocked_edges/"+name, canny)

### Create dataset instance

In [5]:
# Build the dataset from the 'edges' folder; ImageFolder uses each
# sub-directory as one class. Every sample is resized to 224x224, lightly
# augmented (rotation of at most 1 degree, horizontal flip with probability
# 0.2), reduced to a single channel and converted to a tensor.
dataset = datasets.ImageFolder(
    root='edges',
    transform=transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomRotation(1),
            transforms.RandomHorizontalFlip(p=0.2),
            transforms.Grayscale(1),
            transforms.ToTensor(),
        ]
    ),
)

### Split dataset into train and validation sets

In [6]:
# Random 75% / 25% split; the validation part takes whatever is left after
# rounding the training size down.
n_samples = len(dataset)
train_len = int(n_samples * 0.75)
valid_len = n_samples - train_len
train_dataset, valid_dataset = torch.utils.data.random_split(
    dataset,
    [train_len, valid_len],
)

### Create data loaders to load data in batches

In [7]:
# Training batches: 8 samples each, reshuffled on every pass, loaded in the
# main process (num_workers=0) into pinned memory.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    num_workers=0,
)

# Validation batches: same loading settings, but 64 samples per batch.
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=True,
    num_workers=0,
)

### Define the neural network

In [8]:
class EdgeNet(nn.Module):
    """Convolutional classifier for single-channel edge images.

    Two convolution / ReLU / max-pool stages feed an adaptive average pool
    that always yields a 6x6 map, followed by a three-layer fully connected
    head producing ``num_classes`` raw scores (logits) per sample.

    Args:
        num_classes: Number of output scores per sample.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Convolutional feature extractor (input has exactly one channel).
        stem_conv = nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2)
        deep_conv = nn.Conv2d(64, 256, kernel_size=3, padding=1)
        self.features = nn.Sequential(
            stem_conv,
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            deep_conv,
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Fixes the spatial size at 6x6 so the first linear layer always
        # receives 256 * 6 * 6 values regardless of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # Fully connected head; dropout precedes the first two linear layers.
        flat_features = 256 * 6 * 6
        hidden_units = 2048
        head_layers = [
            nn.Dropout(),
            nn.Linear(flat_features, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class scores for a batch of single-channel images."""
        pooled = self.avgpool(self.features(x))
        # Keep the batch dimension, flatten everything else for the head.
        return self.classifier(torch.flatten(pooled, 1))


def edge_net(progress=True, **kwargs):
    """Create a fresh, untrained ``EdgeNet``.

    Args:
        progress: Accepted but not used by this function.
        **kwargs: Forwarded unchanged to the ``EdgeNet`` constructor
            (e.g. ``num_classes``).

    Returns:
        The newly constructed ``EdgeNet`` instance.
    """
    return EdgeNet(**kwargs)

In [15]:
# Instantiate the network with its default settings (two output classes).
model = edge_net()

In [16]:
# Use the GPU when CUDA is available and fall back to the CPU otherwise, so
# the notebook runs unmodified on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

### Train the neural network

In [None]:
NUM_EPOCHS = 100
BEST_MODEL_PATH = 'saved_models/best_model.pth'
best_f1 = 0.0

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)

optimizer = optim.SGD(model.parameters(), lr=0.0008, momentum=0.9)

# Each epoch: one full pass over the training data, then a validation pass
# that computes precision / recall / F1 with class index 0 as the positive
# class. The weights with the best validation F1 so far are checkpointed.
# NOTE(review): ImageFolder numbers classes by sorted folder name, so index 0
# is presumably 'blocked_edges' and index 1 'free_edges' -- confirm that
# 'edges' contains only those two folders.
# NOTE(review): valid_dataset shares the randomly augmenting transform of the
# training data, so validation metrics are slightly noisy -- confirm intended.
for epoch in range(NUM_EPOCHS):

    # Training mode enables the dropout layers in the classifier head.
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    valid_fp_count = 0.0
    valid_fn_count = 0.0
    valid_tp_count = 0.0
    valid_tn_count = 0.0
    # Never updated below; retained in case later cells reference the name.
    valid_accuracy = 0.0

    # Evaluation mode disables dropout so validation scores are deterministic
    # for a given input; no_grad avoids building the autograd graph.
    model.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            predictions = model(images).argmax(1)
            # Masking the predictions is safe for batches that lack one of
            # the classes: the sum over an empty selection is simply 0.
            is_positive = labels == 0
            is_negative = labels == 1
            valid_fn_count += float(torch.sum(predictions[is_positive] != 0))
            valid_tp_count += float(torch.sum(predictions[is_positive] == 0))
            valid_fp_count += float(torch.sum(predictions[is_negative] != 1))
            valid_tn_count += float(torch.sum(predictions[is_negative] == 1))

    # The small constant keeps the divisions defined when a count is zero.
    precision = valid_tp_count/(valid_tp_count + valid_fp_count + 0.00001)
    recall = valid_tp_count/(valid_tp_count + valid_fn_count + 0.00001)

    print("Precision ", precision)
    print("Recall ", recall)
    print("fp ", valid_fp_count)
    print("fn ", valid_fn_count)
    print("tp ", valid_tp_count)
    print("tn ", valid_tn_count)

    valid_f1 = 2*precision*recall/(precision+recall + 0.00001)
    print('%d: %f' % (epoch, valid_f1))
    # Checkpoint only when the validation F1 improves on the best seen so far.
    if valid_f1 > best_f1:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_f1 = valid_f1