# Image Classification with Vision Transformers
in this notebook we'll go through some code examples of various transformers.  
we'll either implement the architecture from scratch using pytorch (following [keras code examples](https://keras.io/examples/)) or use the implementation from [vit-pytorch](https://github.com/guyk1971/vit-pytorch)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython.core.display import display, HTML
# from IPython.core.debugger import set_trace
display(HTML('<style>.container { width:75% !important; }</style>')) 
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline


In [3]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets
from torchvision import transforms as T

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Image Classification with ViT
Reference: [keras example](https://keras.io/examples/vision/image_classification_with_vision_transformer/)  
In this section we'll implement the [Vision transformer](https://arxiv.org/abs/2010.11929) on CIFAR100 dataset

## prepare the data
including the transforms that has to be done when accessing items in the dataset (part of the `_get_item()` method)  
see [cifar10 tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) for a demo how compose transforms. we'll do the same normalization here.  

another reference : [pytorch cifar](https://github.com/kuangliu/pytorch-cifar)

for the normalization, one has to calculate the mean and stdev of the dataset. this can be done according to this [blog post](https://towardsdatascience.com/how-to-calculate-the-mean-and-standard-deviation-normalizing-datasets-in-pytorch-704bd7d05f4c)  
Other than the normalization, the keras example has some additional transformations. to follow the flow of the keras example, we'll start with defining the CIFAR100 class without these data augmentation transforms. we'll add them in a following subsection




In [4]:
num_classes = 100
input_shape = (32, 32, 3)

training_data = datasets.CIFAR100(
    root="~/hda/data/cifar-100",
    train=True,
    download=True,
    transform=T.ToTensor()
)

test_data = datasets.CIFAR100(
    root="~/hda/data/cifar-100",
    train=False,
    download=True,
    transform=T.ToTensor()
)

print(f"x_train shape: {training_data.data.shape} - y_train shape: {len(training_data.targets)}")
print(f"x_test shape: {test_data.data.shape} - y_test shape: {len(test_data.targets)}")

Files already downloaded and verified
Files already downloaded and verified
x_train shape: (50000, 32, 32, 3) - y_train shape: 50000
x_test shape: (10000, 32, 32, 3) - y_test shape: 10000


In [5]:
# Calculating the statistics for normalization: https://towardsdatascience.com/how-to-calculate-the-mean-and-standard-deviation-normalizing-datasets-in-pytorch-704bd7d05f4c
train_dataloader = DataLoader(dataset=training_data, batch_size=64)

def get_mean_and_std(dataloader):
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for data, _ in dataloader:
        # Mean over batch, height and width, but not over the channels
        channels_sum += torch.mean(data, dim=[0,2,3])
        channels_squared_sum += torch.mean(data**2, dim=[0,2,3])
        num_batches += 1
    
    mean = channels_sum / num_batches

    # std = sqrt(E[X^2] - (E[X])^2)
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

    return mean, std

mean,std = get_mean_and_std(train_dataloader)
train_mean=tuple(mean.cpu().numpy())
train_std=tuple(std.cpu().numpy())
print(train_mean,train_std)


(0.50704783, 0.48648766, 0.4408386) (0.26733738, 0.25643668, 0.27614135)


In [None]:
# check the data
img,lbl=training_data[0]
plt.imshow(img.permute(1,2,0).cpu())

In [None]:
img.dtype

In [None]:
training_data.data.dtype

In [None]:
img.min()

## Configure the Hyperparameters
configure as follows: 
```
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier
```



In [6]:
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier

## Data Augmentation

```
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
```

Lets write the data_augmentation module in pytorch. 
following the example from "Modern Computer Vision with Pytorch" chapter 06, we'll define the data augmentation for pytorch - they have used a special package there but recommended the `transforms` module in pytorch. so let's use that. [transforms documentation](https://pytorch.org/vision/stable/transforms.html#transforms-on-pil-image-and-torch-tensor)  
Note : for an illustration of pytorch transformations see [this page](https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py)




- **Normalization** - not sure that we need the normalization. the CIFAR100 dataset already normalizes the input as part of the `__get_item__()` method. according to the [cifar10 tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) : "_The output of torchvision datasets are PILImage images of range [0, 1]. We transform them to Tensors of normalized range [-1, 1]._". the normalization is done by first transforming them [`ToTensor`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor) (which transform them to float at [0,1]) followed by the normalization with (0.5,0.5,0.5) for mean and stdev


see also this [data loadting tutorial](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)

In [7]:
data_augmentation = T.Compose([
        T.ToTensor(),
        T.Normalize(train_mean, train_std),
        T.Resize((image_size,image_size)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(5),
        ])

# lets redefine the datasets, now with the data_augmentation transforms
training_data = datasets.CIFAR100(
    root="~/hda/data/cifar-100",
    train=True,
    download=True,
    transform=data_augmentation
)

test_data = datasets.CIFAR100(
    root="~/hda/data/cifar-100",
    train=False,
    download=True,
    transform=data_augmentation
)

print(f"x_train shape: {training_data.data.shape} - y_train shape: {len(training_data.targets)}")
print(f"x_test shape: {test_data.data.shape} - y_test shape: {len(test_data.targets)}")

Files already downloaded and verified
Files already downloaded and verified
x_train shape: (50000, 32, 32, 3) - y_train shape: 50000
x_test shape: (10000, 32, 32, 3) - y_test shape: 10000


In [None]:
# check the data
img,lbl=training_data[2]
plt.imshow(img.permute(1,2,0).cpu())

In [None]:
plt.imshow(training_data.data[2])

In [8]:
# create data loaders
train_dataloader = DataLoader(dataset=training_data, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)

## Creating a ViT Model 

### Using `vit-pytorch`

In [9]:
from vit_pytorch import ViT

In [10]:
def create_vit_classifier():
    vit = ViT(image_size = image_size,
              patch_size = patch_size,
              num_classes = num_classes,
              dim = projection_dim,
              depth = transformer_layers,
              heads = num_heads,
              mlp_dim = 2048,
              dropout = 0.1,
              emb_dropout = 0.1)
    return vit

In [11]:
def train_epoch(model, device, train_dataloader, optim, epoch):
    model.train()
    for b_i, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device), y.to(device)
        optim.zero_grad()
        pred_prob = model(X)
        loss = F.nll_loss(pred_prob, y) # nll is the negative likelihood loss
        loss.backward()
        optim.step()
        if b_i % 10 == 0:
            print('epoch: {} [{}/{} ({:.0f}%)]\t training loss: {:.6f}'.format(
                epoch, b_i * len(X), len(train_dataloader.dataset),
                100. * b_i / len(train_dataloader), loss.item()))
            

            
            
            
def test_epoch(model, device, test_dataloader):
    model.eval()
    loss = 0
    success = 0
    with torch.no_grad():
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)
            pred_prob = model(X)
            loss += F.nll_loss(pred_prob, y, reduction='sum').item()  # loss summed across the batch
            pred = pred_prob.argmax(dim=1, keepdim=True)  # us argmax to get the most likely prediction
            success += pred.eq(y.view_as(pred)).sum().item()

    loss /= len(test_dataloader.dataset)

    print('\nTest dataset: Overall Loss: {:.4f}, Overall Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss, success, len(test_dataloader.dataset),
        100. * success / len(test_dataloader.dataset)))

In [12]:
def run_experiment(model):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,weight_decay=weight_decay)

    for epoch in range(1, num_epochs):
        train_epoch(model, device, train_dataloader, optimizer, epoch)
        test_epoch(model, device, test_dataloader)
    

vit_classifier = create_vit_classifier()

history = run_experiment(vit_classifier)


Test dataset: Overall Loss: -18.9049, Overall Accuracy: 100/10000 (1%)


Test dataset: Overall Loss: -56.5573, Overall Accuracy: 100/10000 (1%)


Test dataset: Overall Loss: -114.0480, Overall Accuracy: 100/10000 (1%)


Test dataset: Overall Loss: -190.7753, Overall Accuracy: 100/10000 (1%)



KeyboardInterrupt: 