In [None]:
import torch
import torch.nn as nn
import math

# 1. Patch Embedding Module

In [28]:
class PatchEmbed(nn.Module):
    """Cut a square image into non-overlapping square patches and embed each one.

    Parameters
    ----------
    img_size : int
        Side length of the (square) input image.

    patch_size : int
        Side length of each (square) patch.

    in_chans : int
        Number of channels of the input image.

    embed_dim : int
        Size of the embedding produced for every patch.

    Attributes
    ----------
    img_size : int
        Stored copy of the `img_size` argument.

    patch_size : int
        Stored copy of the `patch_size` argument.

    n_patches : int
        Total number of patches, `(img_size // patch_size) ** 2`.

    proj : nn.Conv2d
        Convolution whose kernel size and stride both equal `patch_size`,
        so a single call both splits the image into patches and projects
        every patch to `embed_dim` features.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side * patches_per_side

        # kernel_size == stride means the convolution windows never overlap.
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed every patch of a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches, embed_dim)`.
        """
        # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        feature_map = self.proj(x)

        # Collapse the spatial grid into one axis, then put that axis before
        # the embedding axis: (n_samples, n_patches, embed_dim).
        return feature_map.flatten(start_dim=2).transpose(1, 2)

## 1.1 Simple Example

In [33]:
# Kernel size and stride both equal 16, so every output position of the
# convolution corresponds to one non-overlapping 16x16 patch.
conv_kwargs = dict(in_channels=3, out_channels=100, kernel_size=16, stride=16)
proj_layer = nn.Conv2d(**conv_kwargs)

# Insertion order matters here: the values are unpacked positionally below.
img_kwargs = dict(n_samples=1, channel=3, img_size_1=224, img_size_2=224)

input_img = torch.randn(*img_kwargs.values())
embed_img = proj_layer(input_img)

# Merge the two spatial axes into one, then move it in front of the channels.
processed_output_img = embed_img.flatten(2).transpose(1, 2)

for label, tensor in (
    ("Input Image Shape     :", input_img),
    ("Output Conv2d Shape   :", embed_img),
    ("Processed Output Shape:", processed_output_img),
):
    print(label, tensor.shape)

Input Image Shape     : torch.Size([1, 3, 224, 224])
Output Conv2d Shape   : torch.Size([1, 100, 14, 14])
Processed Output Shape: torch.Size([1, 196, 100])


<p>
    The output spatial size is 14 because both kernel_size and stride are 16.
    There are 14 patches per row to cover the 224-pixel width (224/16 = 14).
    Similarly, there are 14 patches per column to cover the 224-pixel height (224/16 = 14).
</p>
<p>
    The total number of patches is 196, i.e. (224/16) ** 2, because
    each row has 14 patches and each column has 14 patches,
    which results in 14 ** 2 = (224/16) ** 2 = 196.
</p>
<p>
    Finally, we flatten the (14, 14) grid of patches into a single axis of length 196 to get a flattened output.
</p>

## 1.2 Demo on Patch Embedding

In [34]:
patch_kwargs = dict(img_size=224, patch_size=16, in_chans=3, embed_dim=100)
patch_embedding = PatchEmbed(**patch_kwargs)

# Fresh random image; the size values come from the earlier `img_kwargs` dict.
input_img = torch.randn(*img_kwargs.values())
embed_img = patch_embedding(input_img)

for label, tensor in (
    ("Input Image Shape  :", input_img),
    ("Output Conv2d Shape:", embed_img),
):
    print(label, tensor.shape)

Input Image Shape  : torch.Size([1, 3, 224, 224])
Output Conv2d Shape: torch.Size([1, 196, 100])


# 2. Attention Module

In [38]:
class Attention(nn.Module):
    """Multi-head self-attention.

    Parameters
    ----------
    dim : int
        The input and output dimension of per token features. Must be
        divisible by `n_heads`.

    n_heads : int
        Number of attention heads.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    attn_p : float
        Dropout probability applied to the attention weights (after the
        softmax).

    proj_p : float
        Dropout probability applied to the output tensor.

    Attributes
    ----------
    n_heads : int
        Number of attention heads.

    dim : int
        Per token feature dimension.

    head_dim : int
        Feature dimension of a single head, `dim // n_heads`.

    scale : float
        Normalizing constant for the dot product, `1 / sqrt(head_dim)`.

    qkv : nn.Linear
        Linear projection for the query, key and value.

    proj : nn.Linear
        Linear mapping that takes in the concatenated output of all attention
        heads and maps it into a new space.

    attn_drop, proj_drop : nn.Dropout
        Dropout layers.

    Raises
    ------
    ValueError
        If `dim` is not divisible by `n_heads`.
    """

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()

        # Fail early with a clear message; otherwise the reshape in `forward`
        # would fail later with a much less helpful error.
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})."
            )

        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        self.scale = math.sqrt(1/self.head_dim)

        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Raises
        ------
        ValueError
            If the last dimension of `x` differs from `self.dim`.
        """

        n_samples, n_tokens, dim = x.shape

        if dim != self.dim:
            raise ValueError(
                f"Expected input with last dimension {self.dim}, got {dim}."
            )

        # (n_samples, n_patches + 1, 3 * dim)
        qkv = self.qkv(x)

        # (n_samples, n_patches + 1, 3, n_heads, head_dim)
        qkv = qkv.reshape(
            n_samples, n_tokens, 3, self.n_heads, self.head_dim
        )

        # (3, n_samples, n_heads, n_patches + 1, head_dim)
        qkv = qkv.permute(
            2, 0, 3, 1, 4
        )

        q, k, v = qkv[0], qkv[1], qkv[2]

        # (n_samples, n_heads, head_dim, n_patches + 1)
        k_t = k.transpose(-2, -1)

        # (n_samples, n_heads, n_patches + 1, n_patches + 1)
        dp = (q @ k_t) * self.scale
        attn = dp.softmax(dim=-1)

        # Dropout on the attention weights; previously `attn_drop` was
        # created but never used, so `attn_p` had no effect.
        attn = self.attn_drop(attn)

        # (n_samples, n_heads, n_patches + 1, head_dim)
        weighted_avg = attn @ v

        # (n_samples, n_patches + 1, n_heads, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2)

        # Concatenate the heads: (n_samples, n_patches + 1, dim)
        weighted_avg = weighted_avg.flatten(2)

        # (n_samples, n_patches + 1, dim)
        x = self.proj(weighted_avg)
        x = self.proj_drop(x)

        return x

## 2.1 Demo on Attention

In [41]:
attn_kwargs = dict(dim=100, n_heads=10)
attention = Attention(**attn_kwargs)

# Feed the patch embeddings from the previous demo through self-attention.
attention_img = attention(embed_img)

for label, tensor in (
    ("Input Patch Embedding Shape:", embed_img),
    ("Output Attention Shape     :", attention_img),
):
    print(label, tensor.shape)

Input Patch Embedding Shape: torch.Size([1, 196, 100])
Output Attention Shape     : torch.Size([1, 196, 100])


# 3. MLP Module

In [45]:
class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU non-linearity in between.

    Parameters
    ----------
    in_features : int
        Size of each input feature vector.

    hidden_features : int
        Width of the hidden layer.

    out_features : int
        Size of each output feature vector.

    p : float
        Dropout probability.

    Attributes
    ----------
    fc1 : nn.Linear
        First linear layer (`in_features -> hidden_features`).

    fc2 : nn.Linear
        Second linear layer (`hidden_features -> out_features`).

    act : nn.GELU
        GELU activation applied after `fc1`.

    drop : nn.Dropout
        Single dropout layer, reused after the activation and after `fc2`.
    """

    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        # Both linear layers are registered before the parameter-free layers.
        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=out_features)

        self.act = nn.GELU()
        self.drop = nn.Dropout(p=p)

    def forward(self, x):
        """Apply `fc1 -> GELU -> dropout -> fc2 -> dropout`.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, in_features)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, out_features)`.
        """
        # (n_samples, n_patches + 1, hidden_features)
        hidden = self.drop(self.act(self.fc1(x)))

        # (n_samples, n_patches + 1, out_features)
        return self.drop(self.fc2(hidden))

## 3.1 Demo on MLP

In [46]:
mlp_kwargs = dict(in_features=100, hidden_features=150, out_features=200)
mlp = MLP(**mlp_kwargs)

# Feed the attention output from the previous demo through the MLP.
mlp_out_img = mlp(attention_img)

for label, tensor in (
    ("Input Attention Shape:", attention_img),
    ("Output MLP Shape     :", mlp_out_img),
):
    print(label, tensor.shape)

Input Attention Shape: torch.Size([1, 196, 100])
Output MLP Shape     : torch.Size([1, 196, 200])


# 4. Block Module

In [48]:
class Block(nn.Module):
    """Transformer block.

    Applies layer normalization followed by attention, then layer
    normalization followed by an MLP, each wrapped in a residual connection.

    Parameters
    ----------
    dim : int
        Embedding dimension.

    n_heads : int
        Number of attention heads.

    mlp_ratio : float
        Determines the hidden dimension size of the `MLP` module with respect
        to `dim`.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    p : float
        Dropout probability used for the attention output projection and
        inside the `MLP` module.

    attn_p : float
        Dropout probability forwarded to the `Attention` module as `attn_p`.

    Attributes
    ----------
    norm1, norm2 : LayerNorm
        Layer normalization.

    attn : Attention
        Attention module.

    mlp : MLP
        MLP module.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=p
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)
        # Forward `p` to the MLP as well; previously it was dropped, so the
        # MLP dropout stayed at 0 no matter what `p` was.
        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_features,
            out_features=dim,
            p=p,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        """

        # Residual connection around the (normalized) attention.
        x = x + self.attn(self.norm1(x))
        # Residual connection around the (normalized) MLP.
        x = x + self.mlp(self.norm2(x))

        return x

# 5. Complete VisionTransformer Module

In [52]:
class VisionTransformer(nn.Module):
    """Minimal Vision Transformer classifier.

    Parameters
    ----------
    img_size : int
        Side length of the (square) input image.

    patch_size : int
        Side length of each (square) patch.

    in_chans : int
        Number of input channels.

    n_classes : int
        Number of output classes.

    embed_dim : int
        Dimensionality of the token/patch embeddings.

    depth : int
        Number of transformer blocks.

    n_heads : int
        Number of attention heads per block.

    mlp_ratio : float
        Forwarded to every `Block`; determines the hidden dimension of its
        `MLP` module.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    p, attn_p : float
        Dropout probabilities. `p` is used for the positional dropout and
        forwarded to every `Block`; `attn_p` is forwarded to every `Block`.

    Attributes
    ----------
    patch_embed : PatchEmbed
        Instance of the `PatchEmbed` layer.

    cls_token : nn.Parameter
        Learnable parameter that is prepended as the first token of the
        sequence. Shape `(1, 1, embed_dim)`.

    pos_embed : nn.Parameter
        Positional embedding of the cls token and all the patches.
        Shape `(1, n_patches + 1, embed_dim)`.

    pos_drop : nn.Dropout
        Dropout applied after adding the positional embedding.

    blocks : nn.ModuleList
        List of `Block` modules.

    norm : nn.LayerNorm
        Final layer normalization.

    head : nn.Linear
        Classification layer mapping `embed_dim` to `n_classes`.
    """

    def __init__(
            self,
            img_size=384,
            patch_size=16,
            in_chans=3,
            n_classes=1000,
            embed_dim=768,
            depth=12,
            n_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            p=0.,
            attn_p=0.,
    ):
        super().__init__()

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # One extra position for the cls token.
        n_tokens = 1 + self.patch_embed.n_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=p)

        # Every block is built from the same settings.
        block_kwargs = dict(
            dim=embed_dim,
            n_heads=n_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            p=p,
            attn_p=attn_p,
        )
        self.blocks = nn.ModuleList(
            [Block(**block_kwargs) for _ in range(depth)]
        )

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """Classify a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        logits : torch.Tensor
            Logits over all the classes - `(n_samples, n_classes)`.
        """
        batch_size = x.shape[0]

        # (n_samples, n_patches, embed_dim)
        patch_tokens = self.patch_embed(x)

        # Broadcast the single cls token over the batch: (n_samples, 1, embed_dim)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # (n_samples, 1 + n_patches, embed_dim)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)

        # Only the cls token (position 0) is used for classification.
        return self.head(tokens[:, 0])

# 6. Verify Vision Transformer

In [78]:
import numpy as np
import timm
import torch

## 6.1 Helpers

In [79]:
def get_n_params(module):
    """Count the trainable parameters of a module.

    Parameters
    ----------
    module : nn.Module
        Module whose parameters are inspected.

    Returns
    -------
    int
        Total number of elements over all parameters with
        `requires_grad=True`.
    """
    total = 0
    for param in module.parameters():
        # Frozen parameters are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

In [111]:
def assert_tensors_equal(t1, t2):
    """Assert that two tensors are element-wise close.

    Both tensors are detached and converted to NumPy arrays, then compared
    with `np.testing.assert_allclose` using an absolute tolerance of 1e-5.

    Parameters
    ----------
    t1, t2 : torch.Tensor
        Tensors to compare.

    Raises
    ------
    AssertionError
        If the tensors differ by more than the tolerance.
    """
    actual, desired = (tensor.detach().numpy() for tensor in (t1, t2))
    return np.testing.assert_allclose(actual, desired, atol=1e-5)

## 6.2 Load the Model from Timm

In [105]:
# Name of the timm model that serves as the reference implementation.
model_name = "vit_base_patch16_384"
# pretrained=True asks timm for pretrained weights
# (presumably downloaded or read from a local cache - confirm).
model_official = timm.create_model(model_name, pretrained=True)
# Switch to evaluation mode; as the cell's last expression, the returned
# module is also displayed.
model_official.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [106]:
# Hyperparameters chosen to mirror the structure of the official model above.
custom_config = dict(
    img_size=384,
    in_chans=3,
    patch_size=16,
    embed_dim=768,
    depth=12,
    n_heads=12,
    qkv_bias=True,
    mlp_ratio=4,
)

model_custom = VisionTransformer(**custom_config)
# Evaluation mode; as the cell's last expression the module is also displayed.
model_custom.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=768, out_features=1000, bi

In [107]:
# Walk both models' parameters in registration order and check that each
# pair holds the same number of elements.
official_params = model_official.named_parameters()
custom_params = model_custom.named_parameters()

for (name_official, param_official), (name_custom, param_custom) in zip(
    official_params, custom_params
):
    assert param_official.numel() == param_custom.numel()
    print(f"{name_official} | {name_custom}")

cls_token | cls_token
pos_embed | pos_embed
patch_embed.proj.weight | patch_embed.proj.weight
patch_embed.proj.bias | patch_embed.proj.bias
blocks.0.norm1.weight | blocks.0.norm1.weight
blocks.0.norm1.bias | blocks.0.norm1.bias
blocks.0.attn.qkv.weight | blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias | blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight | blocks.0.attn.proj.weight
blocks.0.attn.proj.bias | blocks.0.attn.proj.bias
blocks.0.norm2.weight | blocks.0.norm2.weight
blocks.0.norm2.bias | blocks.0.norm2.bias
blocks.0.mlp.fc1.weight | blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias | blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight | blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias | blocks.0.mlp.fc2.bias
blocks.1.norm1.weight | blocks.1.norm1.weight
blocks.1.norm1.bias | blocks.1.norm1.bias
blocks.1.attn.qkv.weight | blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias | blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight | blocks.1.attn.proj.weight
blocks.1.attn.proj.bias | blocks.1.attn.proj.b

## 6.3 Load the state_dict from Timm Model

In [113]:
# Copy the official model's weights into the custom model; this relies on
# both models exposing matching state_dict keys.
model_custom.load_state_dict(model_official.state_dict())
# Evaluation mode; as the cell's last expression the module is also displayed.
model_custom.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=768, out_features=1000, bi

## 6.4 Compare the two model outputs

In [114]:
# Random input with 3 channels and 384x384 pixels, matching `custom_config`.
inp = torch.rand(1, 3, 384, 384)
# Run the same input through both models so their outputs can be compared.
res_c = model_custom(inp)
res_o = model_official(inp)

In [115]:
# Both models must have the same number of trainable parameters ...
assert get_n_params(model_custom) == get_n_params(model_official)
# ... and produce the same outputs within the helper's tolerance.
assert_tensors_equal(res_c, res_o)

# 7. Inference

In [116]:
from PIL import Image

In [118]:
# Number of top predictions to report.
k = 10

# Map line index -> raw line of the labels file (trailing newline included;
# it is stripped when the labels are printed). The `with` block closes the
# file handle, which the previous bare `open(...)` never did.
with open("./inference/classes.txt") as labels_file:
    imagenet_labels = dict(enumerate(labels_file))

In [121]:
# Load the image as an array and rescale the pixel values with x / 128 - 1.
pixels = np.array(Image.open("./inference/cat.png"))
img = (pixels / 128) - 1

# Move the last axis first and add a leading batch axis
# (assumes the array is height x width x channels - confirm).
inp = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(dim=0).float()

logits = model_custom(inp)
probs = logits.softmax(dim=-1)

In [123]:
# Largest k probabilities of the first (only) sample and their class indices.
top_probs, top_ixs = torch.topk(probs[0], k)

In [124]:
# Print one line per prediction: rank, class label and probability.
ranked = zip(top_ixs.tolist(), top_probs.tolist())
for rank, (class_ix, class_prob) in enumerate(ranked):
    label = imagenet_labels[class_ix].strip()
    print(f"{rank}: {label:<45} --- {class_prob:.4f}")

0: tabby, tabby_cat                              --- 0.8001
1: tiger_cat                                     --- 0.1752
2: Egyptian_cat                                  --- 0.0172
3: lynx, catamount                               --- 0.0018
4: Persian_cat                                   --- 0.0011
5: Siamese_cat, Siamese                          --- 0.0002
6: bow_tie, bow-tie, bowtie                      --- 0.0002
7: weasel                                        --- 0.0001
8: lens_cap, lens_cover                          --- 0.0001
9: remote_control, remote                        --- 0.0001
