In [None]:
# This is a practice notebook covering the concepts behind ViT (Vision Transformers). Nothing here is original work; everything is referenced.

In [None]:
# References: https://towardsdatascience.com/vision-transformers-in-pytorch-43d13cb7ec7a

Convolutional neural networks (CNNs) have been the predominant backbone for almost all networks used in computer vision and image-related tasks, owing to their advantages in 2D neighbourhood awareness and translation equivariance over traditional multi-layer perceptrons (MLPs).

Why are CNNs so popular in the computer vision domain? The answer lies in the inherent nature of convolutions. The kernels, or convolutional windows, aggregate features from nearby pixels, so neighbouring features are considered together during learning. In addition, because the kernels slide across the whole image, a feature appearing anywhere in the image can be detected and used for classification; we refer to this as translation equivariance. These characteristics allow CNNs to extract features regardless of where they lie in the image, and they have driven significant improvements in image classification in recent years.
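
Translation equivariance can be checked numerically. The snippet below is a minimal sketch, not taken from the referenced article: it assumes a single randomly initialised `nn.Conv2d` layer with circular padding (so that a shift wraps around the border and the property holds exactly rather than only away from the edges), and the shift amounts are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One 3x3 convolution. Circular padding is an assumption made for this demo:
# it lets a shifted image wrap around, so equivariance holds at the borders too.
conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3,
                 padding=1, padding_mode="circular", bias=False)

x = torch.randn(1, 1, 16, 16)   # one single-channel 16x16 "image"
shift = (3, 5)                  # shift by 3 rows and 5 columns (arbitrary)

with torch.no_grad():
    # Convolve first, then shift the feature map ...
    conv_then_shift = torch.roll(conv(x), shifts=shift, dims=(2, 3))
    # ... versus shift the image first, then convolve.
    shift_then_conv = conv(torch.roll(x, shifts=shift, dims=(2, 3)))

# Both orders should give the same feature map up to floating-point error.
equivariant = torch.allclose(conv_then_shift, shift_then_conv, atol=1e-5)
print(equivariant)
```

Shifting the input shifts the feature map by the same amount, which is what lets the same kernel detect a feature wherever it appears.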

<p align="center">
  <img src="https://production-media.paperswithcode.com/methods/Screen_Shot_2021-01-26_at_9.43.31_PM_uI4jjMq.png" alt>
  <em><p align="center">Vision Transformer</p></em>
</p>

The Vision Transformer (ViT) divides an image into patches, converts each patch into an embedding, and feeds the resulting sequence to a Transformer, just as token embeddings are fed in language processing, so that attention can be computed between the patches.
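
The patch-to-sequence step can be sketched in a few lines of PyTorch. This is an illustrative sketch only, not the implementation used by any particular library: it assumes a 256x256 RGB image, 32x32 patches, and an embedding size of 1024, and it omits the class token and position embeddings that a full ViT adds. `nn.MultiheadAttention` with `batch_first=True` assumes a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn

image_size, patch_size, channels, dim = 256, 32, 3, 1024
img = torch.randn(1, channels, image_size, image_size)

# Cut the image into non-overlapping patch_size x patch_size tiles:
# unfold over height, then width -> (B, C, H/p, W/p, p, p)
patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# Bring the patch grid forward and flatten each patch -> (B, num_patches, C*p*p)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, channels * patch_size ** 2)

# Linear projection of each flattened patch to a dim-sized embedding ("token").
to_embedding = nn.Linear(channels * patch_size ** 2, dim)
tokens = to_embedding(patches)

# Self-attention over the patch sequence: every patch attends to every other patch.
attn = nn.MultiheadAttention(embed_dim=dim, num_heads=16, batch_first=True)
out, weights = attn(tokens, tokens, tokens)

print(patches.shape)   # torch.Size([1, 64, 3072])
print(tokens.shape)    # torch.Size([1, 64, 1024])
print(weights.shape)   # torch.Size([1, 64, 64]), averaged over heads
```

The 64x64 attention matrix holds one weight for every pair of patches, which is the "attention between each other" described above.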

It is worth noting that, according to the extensive studies in the original paper, vision transformers only outperform CNNs when the pre-training dataset reaches a very large scale. Hence, training one from scratch is less advisable if your computational resources are fairly limited.

In [None]:
# Good source for various implementations of ViT: https://github.com/lucidrains/vit-pytorch

In [15]:
pip install vit-pytorch

Collecting vit-pytorch
  Downloading vit_pytorch-0.26.3-py3-none-any.whl (50 kB)
Collecting einops>=0.3
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops, vit-pytorch
Successfully installed einops-0.3.2 vit-pytorch-0.26.3


In [17]:
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

In [18]:
preds

tensor([[ 4.9436e-01, -2.3076e-01,  7.2608e-02, -5.5751e-01, -2.6363e-01,
         -1.7895e-01, -7.0695e-01, -4.9113e-01,  8.5270e-01, -7.1864e-02,
          3.6665e-01,  3.3224e-02,  3.5915e-02, -1.9711e-01,  7.0423e-02,
         -1.5850e+00, -1.6209e-01, -1.1476e-01, -1.1295e+00,  1.6895e-01,
         -6.1484e-01, -4.9369e-01,  1.0236e-01,  1.2722e-01,  9.4640e-01,
         -1.1465e+00,  5.3555e-01,  3.8996e-01,  2.6781e-01, -3.4655e-01,
         -9.9033e-01, -1.1465e+00,  2.1329e-01,  1.2906e-01,  1.8500e-01,
          5.4008e-01, -1.4790e-01,  8.7655e-01, -6.7191e-01, -2.3208e-01,
         -3.7972e-01,  1.3862e+00,  5.9828e-02,  6.8012e-01,  4.9246e-01,
         -2.7190e-01, -8.4761e-02,  1.1026e+00, -3.7720e-01,  6.1621e-01,
         -6.8438e-01, -4.1190e-01, -5.9758e-01, -6.4834e-01, -7.0096e-01,
          9.3936e-02,  7.4022e-01, -1.0899e+00, -4.8905e-02, -3.2548e-01,
          4.7038e-01,  6.4257e-02,  5.9614e-01, -1.0586e-02,  4.6490e-02,
         -1.1619e+00,  2.8293e-01,  2.

## params:
- image_size: int.
  Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
- patch_size: int.
  Size of each patch (side length in pixels). image_size must be divisible by patch_size.
  The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16.
- num_classes: int.
  Number of classes to classify.
- dim: int.
  Last dimension of the output tensor after the linear transformation nn.Linear(..., dim).
- depth: int.
  Number of Transformer blocks.
- heads: int.
  Number of heads in the multi-head attention layer.
- mlp_dim: int.
  Dimension of the MLP (feed-forward) layer.
- channels: int, default 3.
  Number of image channels.
- dropout: float in [0, 1], default 0.
  Dropout rate.
- emb_dropout: float in [0, 1], default 0.
  Embedding dropout rate.
- pool: string, either 'cls' (cls token pooling) or 'mean' (mean pooling).
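
The relationship between image_size, patch_size, and the number of patches can be worked through for the configuration used earlier in this notebook (image_size = 256, patch_size = 32). The helper name `patch_stats` is hypothetical and introduced only for illustration; it is not part of vit-pytorch.

```python
def patch_stats(image_size: int, patch_size: int, channels: int = 3):
    """Return (number of patches, flattened size of one patch).

    Hypothetical helper for illustration; not a vit-pytorch function.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    num_patches = (image_size // patch_size) ** 2   # patches per image
    patch_dim = channels * patch_size ** 2          # values in one flattened patch
    return num_patches, patch_dim


num_patches, patch_dim = patch_stats(image_size=256, patch_size=32)
print(num_patches)      # 64  (an 8 x 8 grid of patches, greater than 16)
print(patch_dim)        # 3072  (3 channels * 32 * 32 pixels)
print(num_patches + 1)  # 65  (sequence length once a class token is prepended)
```

Each of the 64 patches is flattened to 3072 values before the linear layer maps it to dim.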
