# Get started

For now:

- ~~Load the data~~
- ~~Import the ViT~~
- ~~Add the classification layers~~
- ~~Setup the training loop~~
- ~~Divide code over separate files~~
- ~~Validation set~~
- ~~Calculate accuracy during training on validation set~~

Future tasks (separate notebooks):
- ~~Data inspection~~ -> Created a smaller dataset to train faster
- Data augmentation -> Can be simply done in dataset class

Questions / things to look into:
- What are the class tokens? -> See _An Image is Worth 16x16 Words_

In [None]:
import sys
import os
from types import MethodType

from functools import partial
import torch.nn as nn
import torch
import timm
import numpy as np
from PIL import Image
from typing import Callable
from torchsummary import summary

from heads import *

In [None]:
# In case your GPU memory remains occupied by PyTorch after restarting the kernel
torch.cuda.empty_cache()

## Set the configuration
The finetuned model is trained for classification on ImageNet, so the MLP head on the ViT is also trained. However, since the MLP head is removed from the vision transformer to make it an encoder, this shouldn't be a problem. The encoder weights simply differ from the pretrained ones, provided the encoder was not frozen during finetuning (I am not sure, but I don't think it was frozen).

In [None]:
# Which ViT variant and which set of MAE weights to load
MODEL_SIZE = "base"  # options are ['base', 'large', 'huge']
WEIGHTS_VERSION = "pretrained"  # options are ['pretrained', 'finetuned']
WEIGHTS_FOLDER = "weights"  # folder that is joined with the checkpoint file name
NO_CLASSES = 5  # number of classes handed to the classification head

# Arguments handed to the dataset class
DATA_FOLDER = "data"
TRAIN_FOLDER = "train"
TRAIN_LABELS_CSV = "reducedTrainLabels.csv"  # alternative label file: "trainLabels.csv"

# Training settings
BATCH_SIZE = 4
EPOCHS = 90
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU when one is available

# Import the pre-trained MAE ViT

This [GitHub repository](https://github.com/facebookresearch/mae) provides a PyTorch implementation of the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377). 

In [None]:
def prepare_vision_transformer(
    checkpoint_directory: str,
    model_architecture: dict,
    classification_head: nn.Module,
    *,
    freeze_encoder: bool = False,
) -> nn.Module:
    """
    Build a timm VisionTransformer, load checkpoint weights into it and attach a head.

    Arguments:
        checkpoint_directory (string): path of the checkpoint file that is passed to
            torch.load; the loaded object must hold the state dict under the key "model"
        model_architecture (dict): keyword arguments that are forwarded to the
            constructor of timm's VisionTransformer
        classification_head (nn.Module): the classification head that will be attached
            directly to the ViT (it replaces the `head` attribute)
        freeze_encoder (bool): when True, every parameter of the ViT itself gets
            requires_grad=False; the attached head is left untouched. Defaults to
            False, which keeps all parameters trainable.
    Returns:
        The vision transformer with the loaded weights and `classification_head` as head.
    """
    vision_transformer = timm.models.vision_transformer.VisionTransformer(**model_architecture)
    # To ensure that the weights of the head are not set by the pretrained weights
    vision_transformer.head = None

    # Load onto the CPU: without map_location the tensors are restored to the device
    # they were saved from, which fails on a machine without CUDA.
    checkpoint = torch.load(checkpoint_directory, map_location="cpu")

    # strict=False: keys that only exist on one side (e.g. the removed head) are
    # reported in the printed message instead of raising an error.
    msg = vision_transformer.load_state_dict(checkpoint["model"], strict=False)
    print(msg)

    if freeze_encoder:
        # Done before the head is attached, so only the encoder parameters are frozen.
        for param in vision_transformer.parameters():
            param.requires_grad = False

    vision_transformer.head = classification_head

    return vision_transformer

In [None]:
# Architectures according to the original ViT paper: An image is worth 16x16 words
def _vit_config(patch_size, embed_dim, depth, num_heads):
    """Return the constructor settings of one ViT variant.

    Only the four arguments differ between the variants; the MLP ratio, the qkv
    bias and the LayerNorm factory are the same for all of them.
    """
    return {
        "patch_size": patch_size,
        "embed_dim": embed_dim,
        "depth": depth,
        "num_heads": num_heads,
        "mlp_ratio": 4,
        "qkv_bias": True,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }


BASE_VIT = _vit_config(patch_size=16, embed_dim=768, depth=12, num_heads=12)
LARGE_VIT = _vit_config(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
HUGE_VIT = _vit_config(patch_size=14, embed_dim=1280, depth=32, num_heads=16)

In [None]:
# Checkpoint file names per weights version, keyed by model size
chkpts_finetuned = dict(
    base="mae_finetuned_vit_base.pth",
    large="mae_finetuned_vit_large.pth",
    huge="mae_finetuned_vit_huge.pth",
)
chkpts_pretrained = dict(
    base="mae_pretrain_vit_base.pth",
    large="mae_pretrain_vit_large.pth",
    huge="mae_pretrain_vit_huge.pth",
)
chkpts_by_version = {'pretrained': chkpts_pretrained, 'finetuned': chkpts_finetuned}
chkpts = chkpts_by_version[WEIGHTS_VERSION]

# Constructor settings per model size
model_architectures = dict(base=BASE_VIT, large=LARGE_VIT, huge=HUGE_VIT)

# Pick the architecture and the checkpoint path that belong to the configured size
model_arch = model_architectures[MODEL_SIZE]
chkpt_file = chkpts[MODEL_SIZE]
chkpt_dir = os.path.join(WEIGHTS_FOLDER, chkpt_file)
print(f"Weights directory: \n\t{chkpt_dir}\nModel architecture: \n\t{model_arch}")

In [None]:
# The heads are defined in heads.py (brought into scope by `from heads import *`).
# OneLayer receives the embedding width of the chosen architecture and the number of classes.
ViT_HEAD = OneLayer(model_arch['embed_dim'], NO_CLASSES)
# Alternative head that can be swapped in instead:
# ViT_HEAD = PassThrough()

In [None]:
# Instantiate the model: build the ViT, load the chosen checkpoint and attach ViT_HEAD
vision_transformer = prepare_vision_transformer(
    checkpoint_directory=chkpt_dir,
    model_architecture=model_arch,
    classification_head=ViT_HEAD,
)
# The message printed while loading the weights should be: <All keys matched successfully>

In [None]:
# Print a model summary for a 3 x 224 x 224 input. This is done on the CPU because the
# model is only moved to DEVICE right before the training loop.
summary(vision_transformer, (3, 224, 224), device='cpu')

# Load the data
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

A class that contains the data. Extra data augmentation can easily be added. I already implemented the resize, since the input images do not all have the same size, which causes an error when making a torch.Tensor from a batch of images.

In [None]:
from importlib import reload  # Python 3.4+

import matplotlib.pyplot as plt
from torchvision import transforms

import data

# Reload first so that edits to data.py are picked up without restarting the kernel,
# and bind the class names only afterwards: names imported before reload() would keep
# pointing at the classes of the old module object.
reload(data)
from data import DiabeticRetinopathyDataset, Resize

In [None]:
# The only transform is a Resize to output_size=224; the output size depends on the model
resize_only = transforms.Compose([Resize(output_size=224)])
DR_dataset = data.DiabeticRetinopathyDataset(
    TRAIN_LABELS_CSV, DATA_FOLDER, TRAIN_FOLDER, transform=resize_only, size=40
)

# Seeded 80/20 split into a training and a validation subset, so the split is reproducible
# (alternative kept for reference: DR_dataset.train_val_split(split_rate=0.8))
generator = torch.Generator()
generator.manual_seed(42)
train_set, val_set = torch.utils.data.random_split(
    DR_dataset, [0.8, 0.2], generator=generator
)

In [None]:
# Visualize some data
def visualise_batch(images, labels):
    """Plot the given images in one row, each titled with its label.

    Arguments:
        images: iterable of image tensors; each is permuted with (1, 2, 0) before
            plotting (assumes channels-first tensors -- TODO confirm in data.py)
        labels: indexable collection of label tensors, one per image
    """
    for position, image in enumerate(images, start=1):
        # One row with as many columns as there are labels; subplot positions are 1-based
        axis = plt.subplot(1, len(labels), position)
        axis.set_title(f"{labels[position - 1].tolist()}")
        axis.imshow(image.permute(1, 2, 0))


visualise_batch(*train_set[[1, 2, 3]])

In [None]:
# Print how often each label occurs in the dataset as a small text table
label_count = np.unique(DR_dataset.labels, return_counts=True)
print(" label | count \n"
      "-------|-------")
for label, count in zip(*label_count):
    # Pad the count to a width of 6 with str.ljust. A helper named `display` is
    # deliberately avoided, since it would shadow IPython's display() in the notebook.
    print(f"   {label}   | {str(count).ljust(6)}")

# Training loop
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

In [None]:
from IPython.display import clear_output
import matplotlib.pyplot as plt

from training import train_one_epoch, validate, save_model
from WeightedKappaLoss import WeightedKappaLoss

In [None]:
# Training batches are reshuffled every epoch.
training_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Validation batches are kept in a fixed order: shuffling would change the batch
# composition on every pass, which makes batch-averaged metrics vary between epochs
# for reasons unrelated to the model.
validation_loader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Quadratic weighted kappa serves as training loss and as validation metric. The number
# of classes comes from NO_CLASSES so that it always agrees with the classification head.
loss_fn = WeightedKappaLoss(num_classes=NO_CLASSES, mode='quadratic')
acc_fn = WeightedKappaLoss(num_classes=NO_CLASSES, mode='quadratic', validate=True)  # Returns a slightly different value

# From 'Masked Autoencoders Are Scalable Vision Learners' their linear probing procedure:
# the learning rate is the base learning rate scaled by batch size / 256
blr = 0.1
lr = blr * BATCH_SIZE / 256
optimizer = torch.optim.SGD(vision_transformer.parameters(), lr=lr, momentum=0.9)
lr

In [None]:
# Linear probing: first switch off gradients for every parameter of the model ...
for frozen_param in vision_transformer.parameters():
    frozen_param.requires_grad_(False)
# ... then switch them back on for the parameters of the classification head only
for head_param in vision_transformer.head.parameters():
    head_param.requires_grad_(True)

In [None]:
# Define the name of the directory you want to save the current training session in
# (it is passed to save_model after every epoch).
# If the directory does not exist, it is expected to be created automatically
# (not visible in this notebook -- verify in training.save_model).
RUN_NAME = "Try_out"

In [None]:
# Move the model to the selected device once, before training starts
vision_transformer.to(DEVICE)
train_losses, val_losses, val_accs = [], [], []
# Epochs are counted from 1; the same number is passed to the training, validation
# and saving helpers and is used for the x-axis of the plots.
for epoch in range(1, EPOCHS + 1):
    vision_transformer.train(True)

    # Train over all training data
    avg_train_loss = train_one_epoch(
        model=vision_transformer,
        epoch_index=epoch,
        training_loader=training_loader,
        optimizer=optimizer,
        loss_fn=loss_fn,
    )
    train_losses.append(avg_train_loss)

    # Set the model to validation mode
    vision_transformer.eval()

    # Validation on validation data
    avg_val_loss, avg_val_acc = validate(
        model=vision_transformer,
        epoch_index=epoch,
        validation_loader=validation_loader,
        loss_fn=loss_fn,
        acc_fn=acc_fn,
    )
    val_losses.append(avg_val_loss)
    val_accs.append(avg_val_acc)

    # Save model
    save_model(vision_transformer, epoch, RUN_NAME)

    # Plot statistics: replace the previous output with the curves up to this epoch.
    # The x-axis starts at 1 so that a point lines up with the checkpoint of that epoch.
    clear_output(wait=True)
    epochs_range = np.arange(1, epoch + 1)
    ax = plt.subplot(1, 2, 1)
    ax.plot(epochs_range, np.array(train_losses), label="Train loss")
    ax.plot(epochs_range, np.array(val_losses), label="Val loss")
    ax.set_xlabel("Epoch")
    ax.legend()
    ax = plt.subplot(1, 2, 2)
    ax.plot(epochs_range, np.array(val_accs), label="Val acc")
    ax.set_xlabel("Epoch")
    ax.legend()
    plt.show()
        
    