# ViT on (Fashion) MNIST

In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import sys
# Put the parent directory on the import path
# (presumably where the local `att_tools` package lives - confirm).
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import Trainer

from att_tools import (
    MNISTDataModule,
    ClassifierViT
)

In [None]:
# Seed PyTorch's global random number generator with a fixed value.
# The returned generator is bound to `_` so it is not echoed as cell output.
seed = 1223334444
_ = torch.manual_seed(seed)

## (Fashion) MNIST data

In [None]:
data_set = 'mnist'

# Gather the data module settings in one place before constructing it.
data_config = dict(
    data_set=data_set,
    data_dir='../run/data/',
    mean=0.5,
    std=0.5,
    batch_size=32,
)
mnist = MNISTDataModule(**data_config)

# Download the data if that has not been done yet, then create the test set.
mnist.prepare_data()
mnist.setup(stage='test')

In [None]:
# Take the first batch of the test loader as (inputs, labels).
test_loader = mnist.test_dataloader()
first_batch = next(iter(test_loader))
x_batch, y_batch = first_batch

In [None]:
class_names = mnist.test_set.classes

# Show the first twelve images of the batch, titled with their class name.
n_rows, n_cols = 3, 4
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5, 4.5))
for idx, ax in enumerate(axes.flat):
    # Channel 0 of the sample, passed through the data module's renormalize()
    # (presumably undoing the mean/std normalization - confirm).
    pixels = mnist.renormalize(x_batch[idx, 0])
    ax.imshow(pixels.numpy(), cmap='gray', vmin=0, vmax=1)
    ax.set_title(class_names[y_batch[idx]])
    # Strip ticks and axis labels; only the image and its title remain.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')
fig.tight_layout()

## ViT model

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Restore the classifier from the last checkpoint of run version 0.
ckpt_file = f'../run/{data_set}/version_0/checkpoints/last.ckpt'
vit = ClassifierViT.load_from_checkpoint(ckpt_file)

# Switch to evaluation mode and move the model to the selected device.
vit = vit.eval().to(device)

## Testing

In [None]:
# Trainer without a logger; it is only used to run the test loop here.
trainer = Trainer(logger=False)

# Evaluate the model on the test loader and keep the returned metrics.
test_metrics = trainer.test(model=vit, dataloaders=test_loader, verbose=True)

In [None]:
# Assemble the confusion matrix accumulated in the model's `test_confmat`
# metric. The tensor is moved to the CPU first because Tensor.numpy() raises
# for tensors that live on another device.
confmat = vit.test_confmat.compute().cpu().numpy()

# Normalize every row so that it sums to one. Rows without any entries keep
# their zeros instead of turning into NaNs through a 0/0 division.
row_sums = confmat.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1
norm_confmat = confmat / row_sums

In [None]:
# Heat map of the row-normalized confusion matrix, one tick per class.
tick_positions = list(range(len(class_names)))

fig, ax = plt.subplots(figsize=(6, 5))
img = ax.imshow(norm_confmat, vmin=0, vmax=1, cmap='viridis', aspect='equal')

ax.set_xlabel('predicted')
ax.set_ylabel('actual')
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)

# Class names as tick labels; vertical on the x-axis to avoid overlaps.
ax.set_xticklabels(class_names, rotation='vertical')
ax.set_yticklabels(class_names, rotation='horizontal')

fig.colorbar(img)
fig.tight_layout()

## Attention maps

In [None]:
# The Trainer may have changed the module's device or train/eval mode while
# testing, so re-assert evaluation mode and look up where the parameters
# actually live instead of assuming it.
vit.eval()
model_device = next(vit.parameters()).device

# Forward pass without gradient tracking; the input batch has to be on the
# same device as the model.
with torch.no_grad():
    y, weights = vit(x_batch.to(model_device), return_weights=True)

# Keep the attention weights on the CPU so they can be converted to NumPy.
weights = weights.cpu()

print(f'Attention weights shape: {weights.shape}')

In [None]:
# Random order of the samples in the batch; the figure rows follow it.
random_ids = torch.randperm(len(x_batch))

# `weights` is indexed below as [sample, head, token, token]. Show up to six
# samples and up to four heads, but never more than are actually available.
num_rows = min(6, len(x_batch))
num_heads = min(4, weights.shape[1])

# Row 0 of each attention matrix is reshaped into a square patch grid after
# dropping entry 0 (presumably the class token attending to itself - confirm).
# The grid size is derived from the weights rather than hard-coded.
num_patches = weights.shape[3] - 1
grid_size = int(round(num_patches ** 0.5))
if grid_size * grid_size != num_patches:
    raise ValueError(
        f'Cannot arrange {num_patches} patch weights in a square grid'
    )

# One column for the input image plus one column per attention head.
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(
    nrows=num_rows,
    ncols=1 + num_heads,
    figsize=(1 + num_heads, num_rows),
    squeeze=False
)
for row_idx, sample_idx in enumerate(random_ids[:num_rows]):
    # First column: the input image.
    img = mnist.renormalize(x_batch[sample_idx, 0]).numpy()
    axes[row_idx, 0].imshow(img, cmap='gray', vmin=0, vmax=1)

    # Remaining columns: one attention map per head.
    for head_idx in range(num_heads):
        att = weights[sample_idx, head_idx, 0, 1:].reshape(grid_size, grid_size)
        att = att.cpu().numpy()
        axes[row_idx, head_idx + 1].imshow(att, cmap='gray', vmin=0)

# Strip ticks and axis labels from every panel.
for panel in axes.ravel():
    panel.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()