In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from torchvision import datasets, transforms
from torchview import draw_graph
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from model import *

In [None]:
# Use Apple's Metal (MPS) backend when PyTorch reports it as available;
# `mps_device` stays None otherwise, which later cells read as "no GPU".
mps_device = torch.device("mps") if torch.backends.mps.is_available() else None
if mps_device is None:
    print ("MPS device not found.")
else:
    # Allocate a one-element tensor on the device as a quick smoke test.
    x = torch.ones(1, device=mps_device)
    print('MPS GPU found!')

In [None]:
# Training hyper-parameters.
lr = 0.005          # learning rate handed to the optimizer
epochs = 20         # number of full passes over the training split
batch_size = 256    # samples per mini-batch for both data loaders
train_perc = 0.8    # fraction of the image folder used for training

# Seed PyTorch's global RNG so that a Restart & Run All reproduces the same
# train/test split, weight initialisation, shuffling and augmentation draws.
seed = 42
torch.manual_seed(seed)

In [None]:
# Training pipeline: resize, random flip/rotation as augmentation, then tensor
# conversion and per-channel normalisation with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.Resize( (128, 128) ),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomRotation(20),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Evaluation pipeline: identical preprocessing WITHOUT the random augmentation,
# so the held-out loss, accuracy and confusion matrix do not change from one
# pass to the next.
eval_transform = transforms.Compose([transforms.Resize( (128, 128) ),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Two views over the same folder: they index the same files but apply
# different transforms.
dataset = datasets.ImageFolder('./dataset/train', transform=transform)
eval_dataset = datasets.ImageFolder('./dataset/train', transform=eval_transform)

# Split once, then re-point the held-out indices at the augmentation-free view.
# random_split alone would make both subsets share `dataset` and therefore
# apply the random flip/rotation to the test images as well.
train_set, held_out = torch.utils.data.random_split(dataset, [train_perc, 1-train_perc])
test_set = torch.utils.data.Subset(eval_dataset, held_out.indices)

loader_train = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [None]:
# Loss criterion; `nn` is not imported by name in the import cell, so it is
# presumably re-exported by `from model import *` (verify against model.py).
fn = nn.CrossEntropyLoss()

# Network under training and the Adam optimizer that updates its parameters.
cnn = SimpleCNN()
optimizer = torch.optim.Adam(params=cnn.parameters(), lr=lr)

In [None]:
# Show model
# `cnn.parameters()` is a generator, so formatting it directly prints its repr
# rather than a number; sum the element counts of every parameter tensor.
n_params = sum(p.numel() for p in cnn.parameters())
print(f'Number of parameters {n_params}')

# Render the computation graph for a single 3x128x128 input; the bare last
# expression lets the notebook display the resulting graph object.
model_graph = draw_graph(cnn, input_size=(1, 3, 128, 128), expand_nested=True)
model_graph.visual_graph

In [None]:
# Resolve the compute device once: the MPS GPU when one was detected, otherwise
# the CPU. The model and every batch are moved to this same device.
device = mps_device if mps_device is not None else torch.device("cpu")
cnn.to(device)

# Per-epoch metrics: mean batch loss on each split and accuracy on the test split.
loss_history = {'train': [], 'test': []}
accuracy = []
for epoch in range(1, epochs + 1):
    print(f"------> Epoch {epoch}")
    loss_train = 0.0
    loss_test = 0.0
    n_correct = 0
    n_examples = 0

    print('Train progress:')
    cnn.train()
    for batch, labels in tqdm(loader_train):
        batch = batch.to(device)
        labels = labels.to(device)

        # Computing prediction and updating weights
        optimizer.zero_grad()
        pred = cnn(batch)
        loss = fn(pred, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the batch loss as a plain Python float
        loss_train += loss.item()

    # Mean of the per-batch losses for this epoch
    loss_history['train'].append(loss_train / len(loader_train))

    # Computing loss on test set to check for optimal fitting
    print('Test progress:')
    cnn.eval()
    with torch.no_grad():
        for batch, labels in tqdm(loader_test):
            batch = batch.to(device)
            labels = labels.to(device)
            pred = cnn(batch)
            loss_test += fn(pred, labels).item()

            # Computing accuracy on test set. Softmax is monotonic, so the
            # arg-max of the raw model outputs selects the same class.
            predicted = pred.argmax(dim=1)
            n_correct += (predicted == labels).sum().item()
            n_examples += labels.size(0)

    loss_history['test'].append(loss_test / len(loader_test))
    accuracy.append(n_correct / n_examples)

    # Double-quoted f-strings with single-quoted keys: reusing the same quote
    # character inside the braces is a SyntaxError before Python 3.12.
    print(f"Train loss = {loss_history['train'][-1]:.3f}")
    print(f"Test loss = {loss_history['test'][-1]:.3f}, test accuracy {accuracy[-1]:.3f}")

In [None]:
# Loss curves for both splits. The x values are 1-based so they line up with
# the epoch numbers printed during training.
fig, ax = plt.subplots()
ax.set_title('Losses')
ax.plot(range(1, len(loss_history['train']) + 1), loss_history['train'],
        '-o', color='red', label='train loss')
ax.plot(range(1, len(loss_history['test']) + 1), loss_history['test'],
        '-o', color='blue', label='test loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
# Without this call the `label=` values above are never drawn.
ax.legend()

# Test-set accuracy per epoch.
fig, ax = plt.subplots()
ax.set_title('Accuracy')
ax.plot(range(1, len(accuracy) + 1), accuracy, '-o', color='orange')
ax.set_xlabel('Epoch')
# Trailing semicolon keeps the notebook from echoing the Text object's repr.
ax.set_ylabel('Accuracy');

In [None]:
# Confusion matrix: collect true and predicted class indices over the test loader.
correct_labels = []
predicted_labels = []

# Move the model to the CPU once (the loader yields CPU tensors) instead of on
# every batch, and select eval mode explicitly rather than relying on the state
# left behind by the training cell.
cnn.cpu()
cnn.eval()
with torch.no_grad():
    for batch, true_labels in tqdm(loader_test):
        cnn_out = cnn(batch)
        # Softmax is monotonic, so arg-max over the raw outputs picks the same class.
        pred_labels = cnn_out.argmax(dim=1)
        # Store plain Python ints rather than 0-d tensors, so the consumer of
        # these lists does not have to convert tensor scalars.
        correct_labels.extend(true_labels.tolist())
        predicted_labels.extend(pred_labels.tolist())

In [None]:
# Tabulate true vs. predicted classes, then draw the table; the bare last
# expression lets the notebook show the resulting display object.
cm = confusion_matrix(y_true=correct_labels, y_pred=predicted_labels)
ConfusionMatrixDisplay(confusion_matrix=cm).plot()
