In [None]:
%pip install labml-nn

Collecting labml-nn
  Downloading labml_nn-0.4.137-py3-none-any.whl.metadata (9.2 kB)
Collecting labml==0.4.168 (from labml-nn)
  Downloading labml-0.4.168-py3-none-any.whl.metadata (7.5 kB)
Collecting labml-helpers==0.4.89 (from labml-nn)
  Downloading labml_helpers-0.4.89-py3-none-any.whl.metadata (1.4 kB)
Collecting torchtext (from labml-nn)
  Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting fairscale (from labml-nn)
  Downloading fairscale-0.4.13.tar.gz (266 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 kB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting gitpython (from labml==0.4.168->labml-nn)
  Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)
Collecting gi

In [None]:
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import TransformerLayer
from labml_nn.utils import clone_module_list

In [None]:
class PatchEmbeddings(Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` visits
    every patch exactly once, projecting it to a ``d_model``-dim vector.
    """

    def __init__(self, d_model: int, patch_size: int, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, d_model,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor):
        """Embed the patches of ``x``.

        x: image tensor of shape [batch_size, channels, height, width].
        Returns a sequence-first tensor [n_patches, batch_size, d_model],
        with patches ordered row by row over the patch grid.
        """
        grid = self.conv(x)  # [batch_size, d_model, grid_h, grid_w]
        n_batch, n_feat, _, _ = grid.shape
        # bring the patch grid to the front, then flatten it into one sequence axis
        return grid.permute(2, 3, 0, 1).view(-1, n_batch, n_feat)

In [None]:
class LearnablePositionalEmbeddings(Module):
    """Learned absolute position embeddings for a sequence-first tensor.

    One ``d_model``-sized vector is stored per position, for up to
    ``max_len`` positions. They start at zero and are trained by backprop
    because they are wrapped in ``nn.Parameter``.
    """

    def __init__(self, d_model: int, max_len: int = 5_000):
        super().__init__()
        # the size-1 middle axis broadcasts over the batch dimension in forward()
        table = torch.zeros(max_len, 1, d_model)
        self.positional_encodings = nn.Parameter(table, requires_grad=True)

    def forward(self, x: torch.Tensor):
        """Add the first ``seq_len`` position embeddings to ``x``.

        x: tensor of shape [seq_len, batch_size, d_model]; seq_len should not
        exceed max_len.
        """
        seq_len = x.shape[0]
        return x + self.positional_encodings[:seq_len]

In [None]:
class ClassificationHead(Module):
    """Two-layer MLP that maps the final [CLS] embedding to class logits.

    Architecture: Linear(d_model -> n_hidden) -> ReLU -> Linear(n_hidden -> n_classes).
    The same MLP is applied at training and inference time.
    """

    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
        super(ClassificationHead, self).__init__()
        self.linear1 = nn.Linear(d_model, n_hidden)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(n_hidden, n_classes)

    def forward(self, x: torch.Tensor):
        """Compute logits.

        x: [CLS] embedding(s) with trailing dimension d_model, e.g. [batch_size, d_model].
        Returns unnormalised logits with trailing dimension n_classes.
        """
        x = self.activation(self.linear1(x))
        return self.linear2(x)

In [None]:
class VisionTransformer(Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: image -> patch embeddings -> prepend a learned [CLS] token ->
    add learned position embeddings -> ``n_layers`` transformer layers ->
    LayerNorm on the [CLS] output -> classification head (logits).
    """

    def __init__(self, transformer_layer: TransformerLayer,
                 n_layers: int, patch_emb: PatchEmbeddings,
                 pos_emb: LearnablePositionalEmbeddings,
                 classification: ClassificationHead):
        super().__init__()
        self.patch_emb = patch_emb
        self.pos_emb = pos_emb
        self.classification = classification
        # n_layers copies of the template layer, built by labml's clone_module_list
        self.transformer_layers = clone_module_list(transformer_layer, n_layers)

        d_model = transformer_layer.size  # embedding width used throughout the model
        # learned [CLS] embedding; its leading singleton axes are the sequence and
        # batch axes, so it can be expanded over the batch and concatenated in front
        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, d_model), requires_grad=True)
        self.ln = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor):
        """Classify a batch of images.

        x: image tensor of shape [batch_size, channels, height, width].
        Returns whatever the classification head produces for the normalised
        [CLS] output (logits).
        """
        tokens = self.patch_emb(x)  # [n_patches, batch_size, d_model]
        n_batch = tokens.shape[1]

        # expand() repeats the singleton batch axis as a view (no memory copy),
        # giving one [CLS] token per image: [1, batch_size, d_model]
        cls = self.cls_token_emb.expand(-1, n_batch, -1)
        tokens = self.pos_emb(torch.cat([cls, tokens], dim=0))

        # no attention mask is applied
        for layer in self.transformer_layers:
            tokens = layer(x=tokens, mask=None)

        # the [CLS] token was placed at sequence position 0
        return self.classification(self.ln(tokens[0]))