# Collision Avoidance - Train Model

In [6]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import os
import cv2

ModuleNotFoundError: No module named 'torch'

In [None]:
#from google.colab import drive
#drive.mount('/gdrive')
#%cd /gdrive/MyDrive/tello/

In [3]:
for name in os.listdir("data/free"):
    if not "png" in name:
        continue
    img = cv2.imread("data/free/"+name)
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    canny = cv2.Canny(img_gray, 60, 160)
    cv2.imwrite("edges/free_edges/"+name, canny)
    #cv2.imshow('Canny', canny)
    #cv2.waitKey(0)
    #print(canny.shape)

In [4]:
for name in os.listdir("data/blocked"):
    if not "png" in name:
        continue
    img = cv2.imread("data/blocked/"+name)
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    canny = cv2.Canny(img_gray, 60, 160)
    cv2.imwrite("edges/blocked_edges/"+name, canny)
    #cv2.imshow('Canny', canny)
    #cv2.waitKey(0)
    #print(canny.shape)

### Create dataset instance

In [5]:
dataset = datasets.ImageFolder(
    'edges',
    transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomRotation(1),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.ToTensor()
    ])
)

NameError: name 'datasets' is not defined

### Split dataset into train and test sets

In [None]:
train_len = int(len(dataset)*0.75)
valid_len = len(dataset) - train_len
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [train_len, valid_len])

### Create data loaders to load data in batches

In [None]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

### Define the neural network

In [None]:
model = models.alexnet(pretrained=True)

In [None]:
model.features[0] = torch.nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)

In [None]:
device = torch.device('cuda') # Change to cpu if no cuda available
model = model.to(device)

### Train the neural network

In [None]:
import matplotlib.pyplot as plt
import numpy as np
cont = 0
for images, labels in iter(train_loader):
    if cont >= 10:
        break
    cont += 1
    img = np.reshape(images[0], (224, 224, 3))
    plt.imshow(img)
    plt.show()
    print(labels[0])

In [None]:
NUM_EPOCHS = 30
BEST_MODEL_PATH = 'saved_models/best_model_1.pth'
best_f1 = 0.0

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
    
    valid_fp_count = 0.0
    valid_fn_count = 0.0
    valid_tp_count = 0.0
    valid_tn_count = 0.0
    valid_accuracy = 0.0
    for images, labels in iter(valid_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        if len(labels[labels==0]) > 0:
            valid_fn_count += float(torch.sum(labels[labels==0] != outputs[labels==0].argmax(1)))
            valid_tp_count += float(torch.sum(labels[labels==0] == outputs[labels==0].argmax(1)))
        if len(labels[labels==1]) > 0:
            valid_fp_count += float(torch.sum(labels[labels==1] != outputs[labels==1].argmax(1)))
            valid_tn_count += float(torch.sum(labels[labels==1] == outputs[labels==1].argmax(1)))
    precision = valid_tp_count/(valid_tp_count + valid_fp_count)
    recall = valid_tp_count/(valid_tp_count + valid_fn_count)
    
    print("Precision ", precision)
    print("Recall ", recall)
    print("fp ", valid_fp_count)
    print("fn ", valid_fn_count)
    print("tp ", valid_tp_count)
    print("tn ", valid_tn_count)
    
    valid_f1 = 2*precision*recall/(precision+recall)
    print('%d: %f' % (epoch, valid_f1))
    if valid_f1 > best_f1:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_f1 = valid_f1

Precision  0.8823529411764706
Recall  0.7142857142857143
fp  2.0
fn  6.0
tp  15.0
tn  43.0
Accuracy  4.0
0: 0.789474
