# Get started

For now:
- ~~Load the data~~
- ~~Import the ViT~~
- ~~Add the classification layers~~
- ~~Setup the training loop~~
- Divide code over separate files
- Validation set
- Calculate accuracy during training on validation set

Future task (separate notebooks):
- Data inspection
- Data augmentation -> can simply be done in the dataset class (via its `transform` argument)

Questions / things to look into:
- What is the class token? -> see _An Image is Worth 16x16 Words_

In [None]:
import os
import sys
from functools import partial
from types import MethodType
from typing import Callable, Optional, Tuple, Union

import numpy as np
import timm
import torch
import torch.nn as nn
from PIL import Image
from torchsummary import summary

## Set the configuration
The fine-tuned model was trained for classification on ImageNet, so its MLP head was trained as well. This is not a problem: the MLP head is removed from the vision transformer to turn it into an encoder (and our own head is attached instead). The encoder weights simply differ from the pre-trained ones, provided they were not frozen during fine-tuning (I am not sure, but I don't think they were).

In [None]:
MODEL_SIZE = "base"  # options are ['base', 'large', 'huge']
WEIGHTS_VERSION = "pretrained"  # options are ['pretrained', 'finetuned']
WEIGHTS_FOLDER = "weights"
NO_CLASSES = 5  # number of outputs of the classification head
# Embedding width of each ViT size (the "embed_dim" of the architecture configs),
# so it follows MODEL_SIZE instead of being hard-coded for "base".
EMBED_DIM = {"base": 768, "large": 1024, "huge": 1280}[MODEL_SIZE]

DATA_FOLDER = "data"
TRAIN_FOLDER = "train"
TRAIN_LABELS_CSV = "trainLabels.csv"

BATCH_SIZE = 4
EPOCHS = 2
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Import the pre-trained MAE ViT

This [GitHub repository](https://github.com/facebookresearch/mae) provides a PyTorch implementation of the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377).

In [None]:
def prepare_vision_transformer(
    checkpoint_directory: str,
    model_architecture: dict,
    classification_head: nn.Module,
    freeze_encoder: bool = False,
):
    """
    Build a timm VisionTransformer, load its weights from a checkpoint and attach a new head.

    Arguments:
        checkpoint_directory (string): path to the checkpoint file (e.g. ``weights/mae_pretrain_vit_base.pth``).
            The state dict is read from the ``"model"`` key of the loaded object.
        model_architecture (dict): keyword arguments passed to
            ``timm.models.vision_transformer.VisionTransformer`` (e.g. ``BASE_VIT``).
        classification_head (nn.Module): the head attached directly to the ViT after the weights are loaded.
        freeze_encoder (bool): if True, disable gradients for every parameter except those of
            ``classification_head``. Defaults to False (everything trainable).

    Returns:
        The VisionTransformer with the loaded weights and ``classification_head`` as its head.
    """
    vision_transformer = timm.models.vision_transformer.VisionTransformer(**model_architecture)
    # To ensure that the weights of the head are not set by the pretrained weights:
    # with no head module, any head entries in the checkpoint are reported instead of loaded.
    vision_transformer.head = None

    # Map tensors to the CPU so checkpoints saved on a GPU also load on CPU-only machines;
    # load_state_dict copies them into the model's parameters, which are then moved as needed.
    checkpoint = torch.load(checkpoint_directory, map_location="cpu")

    # strict=False: the checkpoint keys need not match the model exactly; the returned
    # report (missing / unexpected keys) is printed so mismatches stay visible.
    msg = vision_transformer.load_state_dict(checkpoint["model"], strict=False)
    print(msg)

    vision_transformer.head = classification_head

    if freeze_encoder:
        # Freeze the whole model, then re-enable gradients for the head only.
        vision_transformer.requires_grad_(False)
        classification_head.requires_grad_(True)

    return vision_transformer

In [None]:
def _vit_config(patch_size, embed_dim, depth, num_heads):
    """Return VisionTransformer kwargs; only patch size, width, depth and heads vary per size."""
    return {
        "patch_size": patch_size,
        "embed_dim": embed_dim,
        "depth": depth,
        "num_heads": num_heads,
        "mlp_ratio": 4,
        "qkv_bias": True,
        # A fresh LayerNorm factory per configuration, as with three separate literals.
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }


# Architectures according to the original ViT paper: An image is worth 16x16 words
BASE_VIT = _vit_config(patch_size=16, embed_dim=768, depth=12, num_heads=12)
LARGE_VIT = _vit_config(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
HUGE_VIT = _vit_config(patch_size=14, embed_dim=1280, depth=32, num_heads=16)

In [None]:
# # Architectures taken from https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_vit.py#L56-L74
# def vit_base_patch16(**kwargs):
#     """ViT-Base as defined in: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"""
#     model = timm.models.vision_transformer.VisionTransformer(
#         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
#         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
#     return model

# def vit_large_patch16(**kwargs):
#     """ViT-Large as defined in: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"""
#     model = timm.models.vision_transformer.VisionTransformer(
#         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
#         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
#     return model

# def vit_huge_patch14(**kwargs):
#     """ViT-Huge as defined in: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"""
#     model = timm.models.vision_transformer.VisionTransformer(
#         patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
#         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
#     return model

In [None]:
# Pick the checkpoint file and the matching architecture from the configuration
chkpts_finetuned = dict(
    base="mae_finetuned_vit_base.pth",
    large="mae_finetuned_vit_large.pth",
    huge="mae_finetuned_vit_huge.pth",
)
chkpts_pretrained = dict(
    base="mae_pretrain_vit_base.pth",
    large="mae_pretrain_vit_large.pth",
    huge="mae_pretrain_vit_huge.pth",
)
# An unknown WEIGHTS_VERSION / MODEL_SIZE raises KeyError on lookup.
chkpts = dict(pretrained=chkpts_pretrained, finetuned=chkpts_finetuned)[WEIGHTS_VERSION]

model_architectures = dict(base=BASE_VIT, large=LARGE_VIT, huge=HUGE_VIT)

model_arch = model_architectures[MODEL_SIZE]
chkpt_dir = os.path.join(WEIGHTS_FOLDER, chkpts[MODEL_SIZE])
print(f"Weights directory: \n\t{chkpt_dir}\nModel architecture: \n\t{model_arch}")

In [None]:
# Define the classification head of the ViT. Should be a torch.nn.module

class PassThrough(nn.Module):
    """Dummy head: returns the encoder output unchanged (an identity mapping without parameters)."""

    def forward(self, x):
        # Nothing to compute; hand the input straight back.
        return x
    
# ViT_head = PassThrough()  # alternative: return the raw encoder output

# Use a linear layer as classification head: NO_CLASSES logits from the embedding vector.
ViT_head = nn.Linear(in_features=model_arch['embed_dim'], out_features=NO_CLASSES)

In [None]:
# Build the ViT: load the checkpoint weights, then attach the classification head.
vision_transformer = prepare_vision_transformer(
    classification_head=ViT_head,
    model_architecture=model_arch,
    checkpoint_directory=chkpt_dir,
)
# With the pretrained weights the printed report is expected to read
# <All keys matched successfully>; other checkpoints (e.g. fine-tuned ones that
# include their own head) may list unexpected keys — TODO confirm per checkpoint.

In [None]:
# Layer summary of the model for a single 3 x 224 x 224 input, run on the CPU.
summary(vision_transformer, input_size=(3, 224, 224), device='cpu')

# Load the data
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

A dataset class that holds the data; extra data augmentation can easily be added through its `transform` argument. The resize is already implemented, because the input images do not all have the same size, which causes an error when stacking a batch of images into a single `torch.Tensor`.

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
import matplotlib.pyplot as plt


In [None]:
class DiabeticRetinopathyDataset(Dataset):
    """The Diabetic Retinopathy dataset from Kaggle."""

    def __init__(self, csv_file: str, root_dir: str, image_dir: str, size: Optional[int] = None, transform=None):
        """
        Arguments:
            csv_file (string): The csv file with the labels. Its first column holds the image
                names (without the '.jpeg' extension), its second column the labels.
            root_dir (string): The directory where the data is stored. The csv file should be in this directory.
            image_dir (string): The path to the directory with the images, relative to root_dir.
            size (int, optional): Only consider the first ``size`` samples from the csv file.
            transform (Callable, optional): Optional transform applied to the list of PIL images
                of the requested samples; it should return a list of images accepted by ToTensor.
        """
        self.df = pd.read_csv(os.path.join(root_dir, csv_file))
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, image_dir)
        self.transform = transform
        self.ToTensor = transforms.ToTensor()

        self.items = self.df.iloc[:, 0]
        self.labels = self.df.iloc[:, 1]
        if size is not None:
            self.items = self.items[:size]
            self.labels = self.labels[:size]

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Obtain images and labels from the dataset by giving an index or a
        list/np.array/torch.Tensor of indices as input.

        Return (torch.Tensor, torch.Tensor): a batch of images (B x C x H x W) and labels (B).
            Size-1 dimensions are squeezed, so a single index yields (C x H x W) and a 0-d label.

        Raises:
            IndexError: if any index is negative or not smaller than len(self).
        '''
        if torch.is_tensor(idx) or isinstance(idx, np.ndarray):
            idx = idx.tolist()
        # Accept numpy integer scalars as well as Python ints.
        if isinstance(idx, (int, np.integer)):
            idx = [int(idx)]
        else:
            idx = list(idx)

        # Raise IndexError (the __getitem__ protocol) instead of asserting: asserts are
        # stripped under `python -O`, and negative indices were not checked before.
        n_items = len(self.items)
        if any(i < 0 or i >= n_items for i in idx):
            raise IndexError("Outside the size of the dataset")

        names = self.items[idx].tolist()
        sel_labels = self.labels[idx].tolist()
        sel_imgs_paths = [os.path.join(self.image_dir, name) + '.jpeg' for name in names]
        images = [self._load_image(p) for p in sel_imgs_paths]

        if self.transform:
            images = self.transform(images)

        images_tensors = [self.ToTensor(im) for im in images]

        # squeeze() drops the batch dimension when a single sample is requested, so the
        # default DataLoader collation yields (B x C x H x W) images and (B) labels.
        # torch.tensor(..., dtype=int64) avoids a round-trip through float32 for the labels.
        return (
            torch.stack(images_tensors, dim=0).squeeze(),
            torch.tensor(sel_labels, dtype=torch.int64).squeeze(),
        )

    @staticmethod
    def _load_image(path: str):
        """Open an image and read its pixel data so the file handle is released right away."""
        with Image.open(path) as img:
            img.load()
        return img

    def __len__(self):
        return len(self.items)

In [None]:
class Resize(object):
    """
    A class to resize the samples, usable as a dataset transform for data augmentation.

    The input is a list of PIL images; the output is a new list of resized PIL images.

    Arguments:
        output_size (int or tuple/list of two ints): the target size. An int gives a square
            (output_size x output_size) image; a tuple/list is used as (width, height),
            the order PIL's Image.resize expects.
    """
    def __init__(self, output_size: Union[int, Tuple[int, int]]):
        self.output_size = output_size

    def __call__(self, samples):
        if isinstance(self.output_size, (tuple, list)):
            target_size = tuple(self.output_size)
        else:
            # Scalar size: square output, exactly as before.
            target_size = (self.output_size,) * 2
        return [im.resize(target_size) for im in samples]

In [None]:
# Training set: the first 800 rows of the label CSV, resized to the model's input size.
DR_dataset = DiabeticRetinopathyDataset(
    csv_file=TRAIN_LABELS_CSV,
    root_dir=DATA_FOLDER,
    image_dir=TRAIN_FOLDER,
    size=800,
    transform=transforms.Compose([Resize(output_size=224)]),  # output size depends on the model
)

In [None]:
# Visualize some data
def visualise_batch(images, labels):
    """
    Plot images side by side in one row, each titled with its label.

    Arguments:
        images (torch.Tensor): a batch of images (B x C x H x W), or a single image (C x H x W)
        labels (torch.Tensor): the labels (B), or a single 0-d label
    """
    # The dataset squeezes away the batch dimension for a single sample; restore it so we
    # iterate over images rather than over the channels of one image.
    if images.dim() == 3:
        images = images.unsqueeze(0)
    labels = labels.reshape(-1)

    for i, (im, label) in enumerate(zip(images, labels)):
        ax = plt.subplot(1, len(labels), i + 1)
        ax.set_title(f"{label.tolist()}")
        # matplotlib expects channels last: (H x W x C)
        ax.imshow(im.permute(1, 2, 0))
    
sample_images, sample_labels = DR_dataset[[1, 2, 3]]
visualise_batch(sample_images, sample_labels)

# Training loop
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

In [None]:
from tqdm import tqdm

In [None]:
# Reshuffle every epoch and serve BATCH_SIZE samples per batch.
training_loader = DataLoader(DR_dataset, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
def train_one_epoch(epoch_index, log_every: int = 20):
    """
    Run one pass over ``training_loader``, updating ``vision_transformer`` with ``optimizer``.

    Uses the notebook-level ``training_loader``, ``vision_transformer``, ``optimizer``,
    ``loss_fn`` and ``DEVICE``.

    Arguments:
        epoch_index (int): epoch number, used to label the progress bar.
        log_every (int): print the mean loss of the last ``log_every`` batches at this interval.

    Returns:
        float: mean loss over the most recent logging window (a trailing, incomplete window
        counts too), or 0.0 if the loader yielded no batches.
    """
    running_loss = 0.
    batches_in_window = 0
    last_loss = 0.

    # Wrap the loader itself (not enumerate) so tqdm can read its length and show progress.
    for i, (inputs, labels) in enumerate(tqdm(training_loader, desc=f"epoch {epoch_index}")):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Gradients accumulate by default, so clear them for every batch.
        optimizer.zero_grad()

        # Assumes the model head outputs one logit per class, e.g. nn.Linear(embed_dim, NO_CLASSES).
        outputs = vision_transformer(inputs)

        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        running_loss += loss.item()
        batches_in_window += 1

        if batches_in_window == log_every:
            # Average over the batches actually summed in this window.
            last_loss = running_loss / batches_in_window
            print(f"  batch {i+1} loss: {last_loss}")
            running_loss = 0.
            batches_in_window = 0

    if batches_in_window:
        last_loss = running_loss / batches_in_window

    return last_loss

In [None]:
loss_fn = nn.CrossEntropyLoss()
# All parameters are registered; parameters that are frozen later never receive a
# gradient (grad stays None), and SGD skips those.
optimizer = torch.optim.SGD(params=vision_transformer.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Freeze the whole model, except the classification head:
# requires_grad_ sets the flag on every parameter of the module it is called on.
vision_transformer.requires_grad_(False)
vision_transformer.head.requires_grad_(True)

In [None]:
vision_transformer.to(DEVICE)
for epoch in range(EPOCHS):
    # Training mode for the parameter-update pass
    vision_transformer.train()
    avg_loss = train_one_epoch(epoch + 1)  # epochs are reported starting from 1

    # Set the model to validation mode
    vision_transformer.eval()
    # Run on the validation set and calculate performance
    # TODO