# Run the models inside a directory on the test set

In [None]:
from data import DiabeticRetinopathyDataset, CropBlack, Resize
from torchvision import transforms
import torch
import timm
import os
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from torchsummary import summary

In [None]:
# Head weights are read from the folder MODEL_ROOT_DIR/MODEL_NAME
MODEL_ROOT_DIR = "models"
MODEL_NAME = "Kaggle One Layer"

# Arguments handed to DiabeticRetinopathyDataset when the test set is built
# (presumably data root, image sub-folder and labels file -- confirm in data.py)
DATA_FOLDER = "data"
TEST_FOLDER = "test"
TEST_LABELS_CSV = "testLabels.csv"

# WEIGHTS_VERSION selects the 'pretrained' or 'finetuned' checkpoint file names;
# the chosen file is looked up inside WEIGHTS_FOLDER
WEIGHTS_VERSION = "pretrained"  # Don't change
WEIGHTS_FOLDER = "weights"

# NO_CLASSES is passed to the classification head, BATCH_SIZE to the test DataLoader
NO_CLASSES = 5
BATCH_SIZE = 10

# MODEL_SIZE picks both the ViT architecture and the matching checkpoint file
MODEL_SIZE = "large"  # options are ['base', 'large', 'huge']
# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Release cached, currently unused GPU memory held by PyTorch before the model is built
torch.cuda.empty_cache()

## Load the model

In [None]:
# TODO: implement freeze of whole model except head
def prepare_vision_transformer(
    checkpoint_directory: str,
    model_architecture: dict,
    classification_head: nn.Module,
):
    """
    This function returns the vision transformer with the right head and weights.
    Arguments:
        checkpoint_directory (string): path of the checkpoint file that stores the weights of the ViT;
            the state dict is read from its "model" entry
        model_architecture (dict): keyword arguments used to instantiate timm's VisionTransformer
        classification_head (nn.Module): The classification head that will be attached directly to the ViT
    Returns:
        The VisionTransformer with the checkpoint weights loaded and `classification_head` set as its `head`.
    """
    vision_transformer = timm.models.vision_transformer.VisionTransformer(**model_architecture)
    # To ensure that the weights of the head are not set by the pretrained weights
    vision_transformer.head = None

    # Deserialize onto the CPU so that a checkpoint saved from a GPU can also be
    # read on a CPU-only machine; the freshly built model lives on the CPU anyway.
    checkpoint = torch.load(checkpoint_directory, map_location="cpu")

    # strict=False: mismatching keys are reported in the printed message instead of raising
    msg = vision_transformer.load_state_dict(checkpoint["model"], strict=False)
    print(msg)

    vision_transformer.head = classification_head

    return vision_transformer

In [None]:
# Architectures according to the original ViT paper: An image is worth 16x16 words
def _vit_settings(patch_size, embed_dim, depth, num_heads):
    """Return the VisionTransformer keyword arguments for one model size.

    Only the four given values differ between the sizes; the MLP ratio, the
    QKV bias flag and the LayerNorm factory are the same for all of them.
    """
    return {
        "patch_size": patch_size,
        "embed_dim": embed_dim,
        "depth": depth,
        "num_heads": num_heads,
        "mlp_ratio": 4,
        "qkv_bias": True,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }


BASE_VIT = _vit_settings(patch_size=16, embed_dim=768, depth=12, num_heads=12)
LARGE_VIT = _vit_settings(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
HUGE_VIT = _vit_settings(patch_size=14, embed_dim=1280, depth=32, num_heads=16)

In [None]:
# Choose the weights and the architecture
# Checkpoint file name per model size, for the fine-tuned and the pre-trained weights
chkpts_finetuned = dict(
    base="mae_finetuned_vit_base.pth",
    large="mae_finetuned_vit_large.pth",
    huge="mae_finetuned_vit_huge.pth",
)
chkpts_pretrained = dict(
    base="mae_pretrain_vit_base.pth",
    large="mae_pretrain_vit_large.pth",
    huge="mae_pretrain_vit_huge.pth",
)
# Keep only the table that belongs to the configured weights version
chkpts_by_version = dict(pretrained=chkpts_pretrained, finetuned=chkpts_finetuned)
chkpts = chkpts_by_version[WEIGHTS_VERSION]

# Architecture settings per model size
model_architectures = dict(base=BASE_VIT, large=LARGE_VIT, huge=HUGE_VIT)

# Resolve the architecture and the checkpoint file for the configured size
model_arch = model_architectures[MODEL_SIZE]
chkpt_dir = os.path.join(WEIGHTS_FOLDER, chkpts[MODEL_SIZE])
print(f"Weights directory: \n\t{chkpt_dir}\nModel architecture: \n\t{model_arch}")

In [None]:
from heads import OneLayer

In [None]:
# The heads are defined in heads.py
# OneLayer is built from the ViT embedding width and the number of classes
ViT_HEAD = OneLayer(model_arch['embed_dim'], NO_CLASSES)
# Alternative head, disabled. NOTE(review): PassThrough is not imported in the
# cells shown here -- add it to the `heads` import before enabling this line.
# ViT_HEAD = PassThrough()



In [None]:
# instantiate the model
# Builds the ViT for the chosen size, loads the weights stored at chkpt_dir
# and attaches ViT_HEAD as the classification head
vision_transformer = prepare_vision_transformer(
    checkpoint_directory=chkpt_dir,
    model_architecture=model_arch,
    classification_head=ViT_HEAD,
)
# Output should be: <All keys matched successfully>
# NOTE(review): the weights are loaded with strict=False, so a key mismatch is
# only printed and never raised -- read the printed message.

In [None]:
# Print a torchsummary overview of the model for a 3x224x224 input (224 matches the
# Resize(output_size=224) test transform); runs on the CPU, the model is only moved to DEVICE later
summary(vision_transformer, (3, 224, 224), device='cpu')

## Load the data

In [None]:
# Build the test dataset; the transform applies CropBlack first and then Resize to 224
DR_test_set = DiabeticRetinopathyDataset(
    TEST_LABELS_CSV,
    DATA_FOLDER,
    TEST_FOLDER,
    transform=transforms.Compose([CropBlack(),
                                  Resize(output_size=224)]),
    sample_rates=None,  # NOTE(review): presumably disables per-class re-sampling -- confirm in data.py
#     size=50,  (disabled; presumably limits the number of samples for quick runs -- confirm in data.py)
)

In [None]:
# Show the number of samples in the test set as the cell output
len(DR_test_set)

In [None]:
# Visualize some data
def visualise_batch(images, labels):
    """Show the images of a batch next to each other in a single row.

    Every image is drawn in its own subplot, titled with the label stored at
    the same position in `labels`; the row has `len(labels)` columns.
    """
    for column, image in enumerate(images, start=1):
        axis = plt.subplot(1, len(labels), column)
        title = f"{labels[column - 1].tolist()}"
        axis.set_title(title)
        # Move the first axis of the image to the last position before plotting
        axis.imshow(image.permute(1, 2, 0))


# Plot the samples at indices 1, 2 and 3 of the test set
visualise_batch(*DR_test_set[[1, 2, 3]])

In [None]:
# Print a small table with the number of test samples per label
label_count = np.unique(DR_test_set.labels, return_counts=True)
print(" label | count \n"
      "-------|-------")
for label, count in zip(*label_count):
    # Pad the count to a 6-character column with str.ljust. No helper named
    # `display` is defined: that would shadow IPython's built-in display()
    # for every later cell of the notebook.
    print(f"   {label}   | {str(count).ljust(6)}")

## Import the head weights

In [None]:
# Collect the head checkpoint files stored for this model
model_dir = os.path.join(MODEL_ROOT_DIR, MODEL_NAME)
weight_files = np.array([f for f in os.listdir(model_dir)
                         if os.path.isfile(os.path.join(model_dir, f))])

# 'OneLayer_35.pth' is for example shown before 'OneLayer_4.pth', so fix this:
# order by the epoch number parsed as an integer. Comparing the numbers as
# (zero-padded) strings would go wrong again from epoch 100 onwards.
# File names are expected to look like '<HeadName>_<epoch>.<ext>'; a name
# without an underscore raises IndexError, a non-numeric epoch raises ValueError.
numbers = [int(os.path.splitext(f)[0].rsplit("_", 1)[1]) for f in weight_files]

# Start with the latest epoch
weight_files = weight_files[np.argsort(numbers)][::-1]
weight_files

In [None]:
def load_head_weights(model: nn.Module, weight_file: str):
    """
    Load saved classification-head weights into `model.head`.
    Arguments:
        model (nn.Module): model whose `head` sub-module receives the weights
        weight_file (string): name of the weight file inside MODEL_ROOT_DIR/MODEL_NAME
    Returns:
        The same `model` object; its head weights are replaced in place.
    """
    checkpoint_path = os.path.join(MODEL_ROOT_DIR, MODEL_NAME, weight_file)
    # Deserialize onto the CPU so that weights saved from a GPU can also be read
    # on a CPU-only machine; load_state_dict copies them into the head's
    # existing parameters, wherever those currently live.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # strict=True: the file must contain exactly the keys of the head
    msg = model.head.load_state_dict(checkpoint, strict=True)
    print(msg)

    return model


# weight_files is ordered latest epoch first, so this loads the most recent head
vision_transformer = load_head_weights(vision_transformer, weight_files[0])

## Run the testing

In [None]:
from IPython.display import clear_output
from tqdm import tqdm
from WeightedKappaLoss import WeightedKappaLoss

In [None]:
# Iterate the test set in its stored order (shuffle=False) in batches of BATCH_SIZE
test_loader = torch.utils.data.DataLoader(DR_test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Quadratic weighted kappa, used as the test metric. The number of classes comes
# from NO_CLASSES so it cannot drift from the size of the classification head.
# NOTE(review): the class is named "...Loss" -- confirm that what it returns is
# the kappa score itself and not a loss derived from it.
acc_fn = WeightedKappaLoss(num_classes=NO_CLASSES, mode='quadratic', validate=True)

In [None]:
vision_transformer.to(DEVICE)
test_accs = []

### Run the testing for one weight file
# --> Pick the epoch the weight file belongs to:
epoch = 31
# weight_files is ordered latest epoch first, so negative indices count up from
# the earliest checkpoint.
# NOTE(review): this only selects epoch `epoch` if the saved epochs start at 1
# and have no gaps -- confirm against the printed file name.
weight_file = weight_files[-epoch]
print(f"Analyzing weights from {weight_file}")
# Load the weights
# The files were listed from MODEL_ROOT_DIR/MODEL_NAME, so they have to be read
# from there as well; load_head_weights builds exactly that path (joining only
# MODEL_ROOT_DIR and the file name points at a file that does not exist).
vision_transformer = load_head_weights(vision_transformer, weight_file)

# Put the model in evaluation mode
vision_transformer.eval()
# NOTE(review): `validate` is not defined or imported in the cells shown here;
# it has to be made available (presumably from the training code) before this
# cell can run.
# No gradients are needed for testing, so do not build the autograd graph.
with torch.no_grad():
    _, avg_test_acc = validate(model=vision_transformer,
                               epoch_index=epoch,
                               validation_loader=test_loader,
                               loss_fn=None,
                               acc_fn=acc_fn,
                              )
test_accs.append(avg_test_acc)

print(f"The Quadratic Weighted Kappa Score on the test set for {weight_file} is: {avg_test_acc}")
