# Testing ViT implementation
This notebook tests my implementation of ViT, described in [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929). You can also find the paper under `refrences`.

I explain some concepts and add notes along the way.

In [None]:
#@title Import packages
import seaborn
seaborn.set()
import torch
import io
import urllib.request

from data.load_data import load_data
from data.utils import patchify, unpatchify
from architectures.ClassificationViT import ClassificationViT

In [None]:
#@title Download Testing Data
def load_from_url(url):
    return torch.load(io.BytesIO(urllib.request.urlopen(url).read()))

test_data = load_from_url('https://github.com/Berkeley-CS182/cs182hw9/raw/main/test_reference.pt')
auto_grader_data = load_from_url('https://github.com/Berkeley-CS182/cs182hw9/raw/main/autograder_student.pt')
auto_grader_data['output'] = {}

In [None]:
#@title Utilities for Testing
def save_auto_grader_data():
    torch.save(
        {'output': auto_grader_data['output']},
        'autograder.pt'
    )

def rel_error(x, y):
    return torch.max(
        torch.abs(x - y)
        / (torch.maximum(torch.tensor(1e-8), torch.abs(x) + torch.abs(y)))
    ).item()

def check_error(name, x, y, tol=1e-3):
    error = rel_error(x, y)
    if error > tol:
        print(f'The relative error for {name} is {error}, should be smaller than {tol}')
    else:
        print(f'The relative error for {name} is {error}')

def check_acc(acc, threshold):
    if acc < threshold:
        print(f'The accuracy {acc} should >= threshold accuracy {threshold}')
    else:
        print(f'The accuracy {acc} is better than threshold accuracy {threshold}')
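
To make the tolerance in `check_error` concrete, here is a small worked example of the relative error defined above. The helper is repeated so the snippet runs on its own; the input values are made up for illustration. Because the denominator is `|x| + |y|`, the metric is bounded by 1, and it reaches that bound whenever one value is exactly zero and the other is not, however small the absolute difference.

```python
import torch

def rel_error(x, y):
    # Same formula as the notebook helper: the largest element-wise
    # |x - y| / max(1e-8, |x| + |y|).
    return torch.max(
        torch.abs(x - y)
        / torch.maximum(torch.tensor(1e-8), torch.abs(x) + torch.abs(y))
    ).item()

# Illustrative values only: a difference of 0.002 on a value near 2.
close = rel_error(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.002]))

# A zero compared with a tiny non-zero value hits the upper bound of 1.
near_zero = rel_error(torch.tensor([0.0]), torch.tensor([1e-6]))

print(f'close: {close:.6f}, near_zero: {near_zero:.6f}')
```

The first case is roughly `0.002 / 4.002`, well under the default `tol=1e-3`, while the second would be reported as a failure even though the absolute difference is only `1e-6`.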

## Vision Transformer
The first part of this notebook covers implementing the Vision Transformer (ViT) and training it on the CIFAR-10 dataset.


### Image patchify and unpatchify

In ViT, an image is split into fixed-size patches. Each patch is linearly embedded, position embeddings are added, and the resulting sequence of vectors is fed to a standard Transformer encoder. The architecture is shown in the following figure.
![vit](https://github.com/google-research/vision_transformer/blob/main/vit_figure.png?raw=true)

To get started with implementing ViT, we need to split a batch of images into a batch of fixed-size patches in `patchify`, and combine a batch of patches back into the original batch of images in `unpatchify`. The `patchify` function has been implemented for you. **Please implement `unpatchify`,** assuming that the output image is square.

This implementation uses [einops](https://github.com/arogozhnikov/einops) for flexible tensor operations. You can check out its [tutorial](https://einops.rocks/1-einops-basics/).
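
As a rough illustration of what the two operations do, here is a minimal sketch using plain `torch` reshapes and permutes instead of einops. The names `patchify_sketch` and `unpatchify_sketch`, the `patch_size=4` default, and the ordering of values inside each patch are assumptions made for this example; the functions in `data/utils` may use a different signature or layout.

```python
import math
import torch

def patchify_sketch(images, patch_size=4):
    # images: (B, C, H, W) -> patches: (B, num_patches, C * patch_size**2)
    b, c, h, w = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, gh, gw, C, p, p)
    return x.reshape(b, gh * gw, c * patch_size * patch_size)

def unpatchify_sketch(patches, patch_size=4, channels=3):
    # patches: (B, num_patches, C * patch_size**2) -> images: (B, C, H, W)
    b, n, _ = patches.shape
    g = math.isqrt(n)  # patches per side, assuming a square image
    assert g * g == n, 'number of patches must be a perfect square'
    x = patches.reshape(b, g, g, channels, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5)  # (B, C, g, p, g, p)
    return x.reshape(b, channels, g * patch_size, g * patch_size)

# Round trip on a CIFAR-sized dummy batch.
images = torch.arange(2 * 3 * 32 * 32, dtype=torch.float32).reshape(2, 3, 32, 32)
patches = patchify_sketch(images)
restored = unpatchify_sketch(patches)
print(patches.shape, torch.equal(images, restored))
```

The square-image assumption is what lets `unpatchify_sketch` recover the grid size from the number of patches alone.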

In [None]:
#@title Test your implementation
x = test_data['input']['patchify']
y = test_data['output']['patchify']
check_error('patchify', patchify(x), y)

x = auto_grader_data['input']['patchify']
auto_grader_data['output']['patchify'] = patchify(x)
save_auto_grader_data()


x = test_data['input']['unpatchify']
y = test_data['output']['unpatchify']
check_error('unpatchify', unpatchify(x), y)

x = auto_grader_data['input']['unpatchify']
auto_grader_data['output']['unpatchify'] = unpatchify(x)

save_auto_grader_data()

### ViT Model

The transformer implementation simply wraps PyTorch's `nn.TransformerEncoder`. You can find it under `architectures/Transformer`.

We combine the transformer with the patching and classification head in `architectures/ClassificationViT`.
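
To show how these pieces could fit together, here is a minimal sketch of a classifier built around `nn.TransformerEncoder`. This is not the code in `architectures/ClassificationViT`: the class name `TinyViTSketch`, every hyperparameter, the learnable class token, and the choice to take already-patchified input are assumptions made for illustration.

```python
import torch
from torch import nn

class TinyViTSketch(nn.Module):
    """Hypothetical ViT-style classifier; all sizes are illustrative."""

    def __init__(self, num_classes=10, patch_dim=48, num_patches=64,
                 dim=64, depth=2, heads=4):
        super().__init__()
        # Linear embedding of each flattened patch.
        self.embed = nn.Linear(patch_dim, dim)
        # Learnable class token and position embeddings.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=2 * dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        # Classification head applied to the class token only.
        self.head = nn.Linear(dim, num_classes)

    def forward(self, patches):
        # patches: (B, num_patches, patch_dim)
        tokens = self.embed(patches)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        encoded = self.encoder(tokens)
        return self.head(encoded[:, 0])  # (B, num_classes)

model = TinyViTSketch()
logits = model(torch.randn(2, 64, 48))
print(logits.shape)
```

The input shape `(2, 64, 48)` corresponds to two 32×32 RGB images cut into 4×4 patches, which is an assumption about the patch size rather than something the notebook states.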


#### Testing Implementation

In [None]:
# Compare the forward pass against the reference output.
model = ClassificationViT(10)
model.load_state_dict(test_data['weights']['ClassificationViT'])
x = test_data['input']['ClassificationViT.forward']
y = model(x)
check_error('ClassificationViT.forward', y, test_data['output']['ClassificationViT.forward'])

# Record the forward pass on the autograder inputs.
model.load_state_dict(auto_grader_data['weights']['ClassificationViT'])
x = auto_grader_data['input']['ClassificationViT.forward']
y = model(x)
auto_grader_data['output']['ClassificationViT.forward'] = y
save_auto_grader_data()

### Data Loader and Preprocess

We use ```torchvision``` to download and prepare images and labels. ViT is usually trained on much larger image datasets, but due to our limited computational resources, we train our ViT on CIFAR-10.

You can find the code under `data/load_data`.

In [None]:
train_loader, test_loader = load_data()
train_loader, test_loader

### Supervised Training ViT
You can test the full implementation by running the training script `main.py`. You can find a notebook that does so inside `notebooks/training_test`
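
The `check_acc` helper defined earlier expects an accuracy value. Here is a minimal sketch of how one could be computed from a data loader. It assumes the loader yields `(images, labels)` batches, which is typical for torchvision datasets but not confirmed for `load_data`; the linear model and random tensors below are stand-ins so the snippet runs without downloading CIFAR-10.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def accuracy(model, loader):
    # Fraction of samples whose arg-max prediction matches the label.
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total

# Stand-in data and model (hypothetical, for illustration only).
fake_images = torch.randn(8, 3, 32, 32)
fake_labels = torch.randint(0, 10, (8,))
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=4)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

acc = accuracy(model, loader)
print(f'accuracy: {acc:.3f}')
```

With a trained model and the real `test_loader`, the resulting value could be passed to `check_acc` together with a threshold of your choosing.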