# MinViT

In this notebook, I attempt to explain the vision transformer (ViT) architecture, which has found its way into computer vision as a powerful alternative to Convolutional Neural Networks (CNNs).

This implementation focuses on classifying images from the CIFAR-10 dataset, but the architecture is adaptable to many other tasks, including semantic segmentation, instance segmentation, and image generation.

We begin by downloading the CIFAR-10 dataset and converting each image to a `torch.Tensor`.

In [17]:
import numpy as np
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])

train_data = datasets.CIFAR10(root='./data/cifar-10', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data/cifar-10', train=False, download=True, transform=transform)

train_data, test_data

Files already downloaded and verified
Files already downloaded and verified


(Dataset CIFAR10
     Number of datapoints: 50000
     Root location: ./data/cifar-10
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
            ),
 Dataset CIFAR10
     Number of datapoints: 10000
     Root location: ./data/cifar-10
     Split: Test
     StandardTransform
 Transform: Compose(
                ToTensor()
            ))

Each image is a 3-channel (RGB), 32x32-pixel image. Indexing the dataset with an image index returns an `(image, target)` pair, so a second index of 0 selects the image tensor and 1 selects the class label. The raw array in `train_data.data` is stored as height x width x channel, while each image returned through the `ToTensor` transform is a channel-first `torch.float32` tensor with values from 0 to 1.

In [48]:
train_data.data.shape, len(train_data.targets)

((50000, 32, 32, 3), 50000)

In [49]:
train_data[0][0].numpy().shape, train_data[0][1]

((3, 32, 32), 6)

In [54]:
train_data[0][0][0], train_data[0][0].dtype

(tensor([[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]]),
 torch.float32)
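
The change in shape and dtype between `train_data.data` and `train_data[0][0]` comes from the `ToTensor` transform, which is documented to turn a height x width x channel `uint8` array in [0, 255] into a channel x height x width float tensor in [0, 1]. The following is a minimal sketch that reproduces that conversion by hand. The array `raw` is a hypothetical stand-in filled with random pixels, not a CIFAR-10 image, so nothing needs to be downloaded.

```python
import numpy as np
import torch

# Hypothetical stand-in for one raw image: height x width x channel, uint8.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

# Move channels first (H, W, C) -> (C, H, W), then rescale [0, 255] -> [0, 1].
img = torch.from_numpy(raw).permute(2, 0, 1).to(torch.float32) / 255.0

print(raw.shape, raw.dtype)
print(img.shape, img.dtype)
```

This is only an illustration of the layout change; in the notebook itself the transform passed to `datasets.CIFAR10` does this work.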

If you are familiar with the transformer architecture, you likely know that transformers operate on sequences of tokens. A sentence, for example, is split into tokens like this:

<span style="background-color:rgba(107,64,216,.3);white-space:pre;">This</span><span style="background-color:rgba(104,222,122,.4);white-space:pre;"> is</span><span style="background-color:rgba(244,172,54,.4);white-space:pre;"> a</span><span style="background-color:rgba(239,65,70,.4);white-space:pre;"> test</span><span style="background-color:rgba(39,181,234,.4);white-space:pre;">.</span>
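
A vision transformer applies the same idea to images by treating fixed-size patches as the tokens. As a rough sketch, assuming a single 3x32x32 image and a patch size of 4 (the tensor `image` here is a hypothetical placeholder filled with a counter, not CIFAR-10 data), all patches can be cut out and flattened in one go with `reshape` and `permute`:

```python
import torch

patch_size = 4
# Hypothetical placeholder image: 3 channels, 32x32 pixels.
image = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(3, 32, 32)

c, h, w = image.shape
rows, cols = h // patch_size, w // patch_size

# Split height and width into (block index, offset inside block).
patches = image.reshape(c, rows, patch_size, cols, patch_size)
# Bring the two block indices to the front: (rows, cols, C, p, p).
patches = patches.permute(1, 3, 0, 2, 4)
# One row per patch, each patch flattened to C * p * p values.
patches = patches.reshape(rows * cols, c * patch_size * patch_size)

print(patches.shape)
```

Each row of `patches` plays the role of one token: with a patch size of 4 there are 8 x 8 = 64 patches, each holding 3 x 4 x 4 = 48 values.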

In [45]:
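patch_size = 4
# Step over the 32x32 image in non-overlapping patch_size x patch_size windows.
for i in range(0, 32, patch_size):
    for j in range(0, 32, patch_size):
        # Keep all 3 channels and slice out one spatial window.
        patch = train_data[0][0][:, i:i+patch_size, j:j+patch_size]
        print(patch.shape)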

torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size([3, 4, 4])
torch.Size