# *Attention*

In order to be able to run this notebook without problems, one must do the following first:
- Add a shortcut of [this](https://drive.google.com/drive/folders/1l-74Qd42waVghTfSsJAPJIMXeg0f77XK?usp=sharing) drive folder to your own drive.
- Mount on your drive.
- Set the *drive_folder* variable to point to the path of the shared folder on your drive.

In [18]:
drive_folder = './drive/MyDrive/Shortcuts/AML_HW1/'
!ls './drive/MyDrive/Shortcuts/AML_HW1/'

AML_HW1.pdf  AML_HW1.zip  Practical


In [17]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!cp './drive/MyDrive/Shortcuts/AML_HW1/AML_HW1.zip' './'
!unzip AML_HW1.zip

# **Imports**

In [26]:
import os
import glob
import time
import torch
import torchvision
import numpy as np
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader

# **Parameters**

In [27]:
data_path = 'AML_HW1/Practical/Q4/tiny-image-subset/'

BATCH_SIZE = 128
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cuda:0


# **General Utilities**

In [28]:
def conv_block(in_channels, out_channels):
    '''
    returns a block conv-bn-relu-pool
    '''
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )
    
class ProtoNetBack(nn.Module):
    def __init__(self, input_channels = 1):
        super(ProtoNetBack, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(input_channels, 64, 1),
            conv_block(64, 128),
            conv_block(128, 128),
            conv_block(128, 128),
            conv_block(128, 128),
        )

    def get_embedding_size(self, input_size = (1,28,28)):
        device = next(self.parameters()).device
        x = torch.rand([2,*input_size]).to(device)
        with torch.no_grad():
            output = self.forward(x)
            emb_size = output.shape[-1]
        
        del x,output
        torch.cuda.empty_cache()

        return emb_size

    def forward(self, x):
        return self.layers (x).reshape ([x.shape[0] , -1])

In [29]:
class Trainer():

    def __init__(self, name, net, optimizer, criterion, epochs, trainloader,
                 valloader, log_every_iter=100, eval_every_epoch=1, save_every_iter=500,
                 save_path='./', device=None):
        
        self.epochs = epochs
        self.trainloader = trainloader
        self.valloader = valloader
        self.net = net
        self.optimizer = optimizer
        self.criterion = criterion
        self.log_every_iter = log_every_iter
        self.eval_every_epoch = eval_every_epoch
        self.device = device
        self.save_path = save_path
        self.save_every_iter = save_every_iter
        self.name = name
    
    def train(self):
        print('Number of iterations in each epoch: {}'.format(len(self.trainloader)))
        print('Number of validation iterations: {}'.format(len(self.valloader)))
        print('Training started ...')
        
        history = {
            'train_loss': [],
            'train_acc': [],
            'val_acc': []
        }
        epoch_train_loss = 0.0
        epoch_train_acc = 0.0
        epoch_count = 0.0
        
        for epoch in range(self.epochs):
            
            running_loss = 0.0
            running_lp_fraction = 0.0
            running_accuracy = 0.0
            tick = time.time()
            t = time.time()

            for i, data in enumerate(self.trainloader, 0):
                
                overhead_duration = time.time() - t
                t = time.time()

                inputs, labels = data
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.net(inputs)
                loss = self.criterion(outputs, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                preds = torch.argmax(outputs, axis=1)
                acc = ((preds == labels) * 1.0).mean()

                running_loss += loss.item()
                running_accuracy += acc.item()
                epoch_train_loss += loss.item() * inputs.shape[0]
                epoch_train_acc += acc.item() * inputs.shape[0]
                epoch_count += inputs.shape[0]

                processing_duration = time.time() - t
                t = time.time()
                running_lp_fraction += overhead_duration / processing_duration

                if i % self.log_every_iter == self.log_every_iter - 1:
                    tock = time.time()
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / self.log_every_iter:.3f}, \
                            accuracy = {running_accuracy / self.log_every_iter:.4f}')
                    tick = time.time()
                    running_loss = 0.0
                    running_lp_fraction = 0.0
                    running_accuracy = 0.0
                
                # if i % self.save_every_iter == self.save_every_iter - 1:
                #     self.save_model(epoch)
                #     print('Training saved')

            if epoch % self.eval_every_epoch == self.eval_every_epoch - 1:
                epoch_val_acc = self.evaluate()
                self.net.train()
            
            history['train_loss'].append(epoch_train_loss / epoch_count)
            history['train_acc'].append(epoch_train_acc / epoch_count)
            history['val_acc'].append(epoch_val_acc)
            epoch_train_loss = 0.0
            epoch_train_acc = 0.0
            epoch_count = 0.0

        self.save_model(self.epochs, False)
        print('Finished Training')
        return history

    def evaluate(self):
        print('Evaluating on validation set ...')
        sum_acc, sum_loss = 0.0, 0.0
        total_data = 0
        self.net.eval()
        for i, data in enumerate(self.valloader, 0):
            inputs, labels = data
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            outputs = self.net(inputs)
            total_data += outputs.shape[0]
            preds = torch.argmax(outputs, axis=1)
            acc = ((preds == labels) * 1.0).mean()
            sum_acc += acc.item() * preds.shape[0]

        epoch_acc = sum_acc / total_data
        print('{}: {}'.format('accuracy', epoch_acc)) 
        print('-------------------------------')
        return epoch_acc

    def save_model(self, epoch, checkpoint=True):
        if checkpoint:
            torch.save({
                'model_state_dict': self.net.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
            }, os.path.join(self.save_path, '%04d.pth' % (epoch + 1)))
        else:
            torch.save({
                'model_state_dict': self.net.state_dict()
            }, os.path.join(self.save_path, '{}.pth'.format(self.name)))

# **Part A**

## Model Definition

In [6]:
class CustomNet(nn.Module):
    def __init__(self, input_channels=1):
        super(CustomNet, self).__init__()
        self.pn_back = ProtoNetBack(input_channels=input_channels)
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pn_back(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        out = self.fc3(x)
        return out

## Train on CIFAR-10

In [7]:
from torchvision.transforms.transforms import RandomHorizontalFlip
transform = transforms.Compose(
    [
     transforms.ToTensor(),
     transforms.Resize((32, 32)),
     transforms.RandomCrop(25),
     transforms.Resize((32, 32)),
     transforms.RandomHorizontalFlip(0.5),
     transforms.ColorJitter(brightness=0.5, hue=0.3),
    #  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])

trainset = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to ./
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


In [8]:
net = CustomNet(input_channels=3).to(DEVICE)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

tr = Trainer('4a', net, optimizer, criterion, 20, trainloader, testloader,
             log_every_iter=30, eval_every_epoch=1, save_every_iter=190,
             device=DEVICE)
history = tr.train()

Number of iterations in each epoch: 391
Number of validation iterations: 79
Training started ...
[1,    30] loss: 2.136,                             accuracy = 0.1948
[1,    60] loss: 1.948,                             accuracy = 0.2745
[1,    90] loss: 1.870,                             accuracy = 0.3000
[1,   120] loss: 1.783,                             accuracy = 0.3419
[1,   150] loss: 1.731,                             accuracy = 0.3677
[1,   180] loss: 1.630,                             accuracy = 0.3995
[1,   210] loss: 1.622,                             accuracy = 0.4047
[1,   240] loss: 1.600,                             accuracy = 0.4232
[1,   270] loss: 1.507,                             accuracy = 0.4568
[1,   300] loss: 1.488,                             accuracy = 0.4602
[1,   330] loss: 1.483,                             accuracy = 0.4727
[1,   360] loss: 1.412,                             accuracy = 0.4885
[1,   390] loss: 1.351,                             accuracy = 

In [None]:
# plots

xaxis = list(range(len(history['train_loss'])))
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(8, 4)
ax1.plot(xaxis, history['train_loss'])
ax1.set_title('Train Loss')
ax1.set_xlabel('Epoch Number')
ax1.set_ylabel('Loss')
ax2.plot(xaxis, history['train_acc'], label='Train')
ax2.plot(xaxis, history['val_acc'], label='Validation')
ax2.set_title('Accuracy')
ax2.set_xlabel('Epoch Number')
ax2.set_ylabel('Accuracy')
ax2.legend()

# **Part B**

## Hyperparameters

In [30]:
BATCH_SIZE = 128

## Tiny Image Dataset

In [31]:
class TinyImageDataset(Dataset):

    def __init__(self, path, transform=None):
        self.classes = sorted(glob.glob('{}/*'.format(path)))
        self.num_images = len(glob.glob('{}/*/*'.format(path), recursive=True))
        self.x = []
        self.y = []
        for i, c in enumerate(self.classes):
            image_paths = glob.glob('{}/*'.format(c))
            for image_path in image_paths:
                img = self.pil_loader(image_path)
                self.x.append(img)
                self.y.append(i)
        self.transform = transform

    def pil_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        return self.transform(self.x[idx]), self.y[idx]

transform = transforms.Compose(
    [
     transforms.ToTensor(),
     transforms.Resize((32, 32)),
    #  transforms.RandomCrop(25),
    #  transforms.Resize((32, 32)),
    #  transforms.RandomHorizontalFlip(0.5),
    #  transforms.ColorJitter(brightness=0.5, hue=0.3),
    #  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

trainset = TinyImageDataset(data_path + 'train', transform=transform)
valset = TinyImageDataset(data_path + 'val', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

In [None]:
img = trainset[100][0].numpy()
plt.imshow(img.transpose((1, 2, 0)))

## Part B.1 (Training from scratch)

In [35]:
net = CustomNet(input_channels=3).to(DEVICE)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

tr = Trainer('4b1', net, optimizer, criterion, 20, trainloader, valloader,
             log_every_iter=15, eval_every_epoch=1, device=DEVICE)
history = tr.train()

Number of iterations in each epoch: 40
Number of validation iterations: 4
Training started ...
[1,    15] loss: 2.089,                             accuracy = 0.2448
[1,    30] loss: 1.709,                             accuracy = 0.3990
Evaluating on validation set ...
accuracy: 0.27199999904632566
-------------------------------
[2,    15] loss: 1.646,                             accuracy = 0.4198
[2,    30] loss: 1.506,                             accuracy = 0.4854
Evaluating on validation set ...
accuracy: 0.4059999983310699
-------------------------------
[3,    15] loss: 1.425,                             accuracy = 0.5000
[3,    30] loss: 1.367,                             accuracy = 0.5312
Evaluating on validation set ...
accuracy: 0.47200000047683716
-------------------------------
[4,    15] loss: 1.320,                             accuracy = 0.5484
[4,    30] loss: 1.241,                             accuracy = 0.5672
Evaluating on validation set ...
accuracy: 0.5079999966621399

In [None]:
# plots

xaxis = list(range(len(history['train_loss'])))
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(8, 4)
ax1.plot(xaxis, history['train_loss'])
ax1.set_title('Train Loss')
ax1.set_xlabel('Epoch Number')
ax1.set_ylabel('Loss')
ax2.plot(xaxis, history['train_acc'], label='Train')
ax2.plot(xaxis, history['val_acc'], label='Validation')
ax2.set_title('Accuracy')
ax2.set_xlabel('Epoch Number')
ax2.set_ylabel('Accuracy')
ax2.legend()

## Part B.2 (Finetuning feed-forward layers of pretrained model)

In [45]:
net = CustomNet(input_channels=3).to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()

checkpoint = torch.load('4a.pth')
net.load_state_dict(checkpoint['model_state_dict'])
for param in net.parameters():
    param.requires_grad = False
net.fc1 = nn.Linear(512, 256)
net.fc2 = nn.Linear(256, 128)
net.fc3 = nn.Linear(128, 10)
net = net.to(DEVICE)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

tr = Trainer('4b2', net, optimizer, criterion,
             20, trainloader, valloader, log_every_iter=15, eval_every_epoch=1,
             device=DEVICE)
history = tr.train()

Number of iterations in each epoch: 40
Number of validation iterations: 4
Training started ...
[1,    15] loss: 2.173,                             accuracy = 0.2354
[1,    30] loss: 1.890,                             accuracy = 0.3422
Evaluating on validation set ...
accuracy: 0.3859999964237213
-------------------------------
[2,    15] loss: 1.616,                             accuracy = 0.4354
[2,    30] loss: 1.554,                             accuracy = 0.4474
Evaluating on validation set ...
accuracy: 0.45600000524520873
-------------------------------
[3,    15] loss: 1.497,                             accuracy = 0.5010
[3,    30] loss: 1.412,                             accuracy = 0.5182
Evaluating on validation set ...
accuracy: 0.4339999966621399
-------------------------------
[4,    15] loss: 1.390,                             accuracy = 0.5333
[4,    30] loss: 1.370,                             accuracy = 0.5198
Evaluating on validation set ...
accuracy: 0.44800000262260437

In [None]:
# plots

xaxis = list(range(len(history['train_loss'])))
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(8, 4)
ax1.plot(xaxis, history['train_loss'])
ax1.set_title('Train Loss')
ax1.set_xlabel('Epoch Number')
ax1.set_ylabel('Loss')
ax2.plot(xaxis, history['train_acc'], label='Train')
ax2.plot(xaxis, history['val_acc'], label='Validation')
ax2.set_title('Accuracy')
ax2.set_xlabel('Epoch Number')
ax2.set_ylabel('Accuracy')
ax2.legend()

## Part B.3 (Finetuning the whole pretrained model)

In [49]:
net = CustomNet(input_channels=3).to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()

checkpoint = torch.load('4a.pth')
net.load_state_dict(checkpoint['model_state_dict'])
net.fc1 = nn.Linear(512, 256)
net.fc2 = nn.Linear(256, 128)
net.fc3 = nn.Linear(128, 10)
net = net.to(DEVICE)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

def accuracy_fn(outputs, y):
    preds = torch.argmax(outputs, axis=1)
    n = outputs.shape[0]
    return (preds == y).sum() / n

tr = Trainer('4b3', net, optimizer, criterion,
             20, trainloader, valloader, log_every_iter=15, eval_every_epoch=1,
             device=DEVICE)
history = tr.train()

Number of iterations in each epoch: 40
Number of validation iterations: 4
Training started ...
[1,    15] loss: 2.292,                             accuracy = 0.1464
[1,    30] loss: 2.252,                             accuracy = 0.2422
Evaluating on validation set ...
accuracy: 0.28600000095367434
-------------------------------
[2,    15] loss: 2.131,                             accuracy = 0.3245
[2,    30] loss: 2.059,                             accuracy = 0.3005
Evaluating on validation set ...
accuracy: 0.3740000023841858
-------------------------------
[3,    15] loss: 1.909,                             accuracy = 0.3698
[3,    30] loss: 1.799,                             accuracy = 0.4161
Evaluating on validation set ...
accuracy: 0.39399999952316284
-------------------------------
[4,    15] loss: 1.679,                             accuracy = 0.4573
[4,    30] loss: 1.665,                             accuracy = 0.4646
Evaluating on validation set ...
accuracy: 0.4400000014305115

In [None]:
# plots

xaxis = list(range(len(history['train_loss'])))
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(8, 4)
ax1.plot(xaxis, history['train_loss'])
ax1.set_title('Train Loss')
ax1.set_xlabel('Epoch Number')
ax1.set_ylabel('Loss')
ax2.plot(xaxis, history['train_acc'], label='Train')
ax2.plot(xaxis, history['val_acc'], label='Validation')
ax2.set_title('Accuracy')
ax2.set_xlabel('Epoch Number')
ax2.set_ylabel('Accuracy')
ax2.legend()