In [12]:
import numpy as np
import torch 
import torch.nn as nn
import einops
from einops import rearrange, reduce , repeat

In [132]:
class LinearProj(nn.Module):
    """Flatten every image patch and project it to a fixed-size embedding.

    Args:
        in_features: Number of values in one flattened patch (``p1 * p2 * c``).
            Defaults to ``28 * 28 * 3``.
        out_features: Width of the embedding produced for each patch.
            Defaults to 8.
    """

    def __init__(self, in_features: int = 28 * 28 * 3, out_features: int = 8) -> None:
        super().__init__()

        self.lin = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Map patches of shape ``(b, p, p1, p2, c)`` to ``(b, p, out_features)``.

        ``p1 * p2 * c`` must equal the ``in_features`` given at construction,
        otherwise the linear layer raises a shape error.
        """
        # Merge the two spatial axes and the channel axis of each patch into
        # one feature axis so the linear layer sees one vector per patch.
        x = rearrange(x, 'b p p1 p2 c -> b p (p1 p2 c)')
        return self.lin(x)

In [133]:
# Patch projection layer: 28*28*3 flattened values -> 8 features per patch.
l = LinearProj()

In [134]:
# Dummy input in the 'b p p1 p2 c' layout LinearProj.forward expects:
# 2 samples x 4 patches x 28 x 28 pixels x 3 channels.
p = torch.randn(2, 4,28, 28, 3)

In [135]:
# Project every patch to an 8-dim vector -> (2, 4, 8).
# NOTE(review): this rebinds `p`, so the cell is not idempotent: re-running it
# feeds the 3-D result back into `l`, whose rearrange pattern needs 5 axes.
# Later cells (`v(p)`) depend on the rebound value.
p = l(p)
p.shape

torch.Size([2, 4, 8])

In [136]:
model_dimension = 8
class VITblock(nn.Module):
    """Pre-norm transformer encoder block.

    Applies LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add. Every layer is ``model_dimension`` wide,
    so the output has the same shape as the input.
    """

    def __init__(self) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(model_dimension)
        self.norm2 = nn.LayerNorm(model_dimension)
        self.mlp = nn.Sequential(
            nn.Linear(model_dimension , model_dimension)
        )
        # batch_first=True is required: forward() receives (batch, patches, dim).
        # With the default batch_first=False the layer reads axis 0 as the
        # sequence, so attention would mix tokens across *samples* instead of
        # across the patches of one sample.
        self.MHSA = nn.MultiheadAttention(
            embed_dim = model_dimension, num_heads = 1, batch_first = True
        )

    def forward(self,x):
        """Run the block on ``x`` of shape ``(bs, num_patches, model_dimension)``.

        Returns a tensor of the same shape.
        """
        # Self-attention sub-layer; the residual uses the un-normalised input.
        x1 = self.norm1(x)
        x2,_ = self.MHSA(x1,x1,x1)  # attention weights are discarded
        x  = x2 + x

        # Feed-forward sub-layer with its own residual connection.
        x3 = self.norm2(x)
        x4 = self.mlp(x3)
        x  = x4 + x

        return x


In [137]:
# One encoder block, built with the module-level model_dimension (8).
v = VITblock()

In [138]:
# Shape check: the block should return the same (2, 4, 8) shape it was given.
v(p).shape

torch.Size([2, 4, 8])

In [142]:
# Inspect the raw block output. No seed is set in the cells shown here, so the
# printed values will differ from run to run.
v(p)

tensor([[[-0.5854, -0.7514,  0.2198,  0.0442,  0.5123,  0.4325, -1.5617,
          -1.4152],
         [-0.1297,  0.5873,  1.1905,  0.3146, -0.0548,  0.8600, -0.9206,
           0.6139],
         [ 0.4040, -0.0265,  0.4570, -0.1232,  1.3797,  1.3209, -0.4143,
          -1.0343],
         [-1.3091, -0.1063, -0.9931,  0.9654,  1.5126,  1.3525,  0.7966,
          -1.0016]],

        [[ 0.4665, -0.9921, -0.3968, -1.5002, -0.4527, -0.5212, -1.4241,
          -1.6723],
         [ 0.7723, -0.7284, -0.4974,  1.1497,  0.3244, -1.0656,  1.4813,
          -0.5505],
         [-0.5572,  1.0361,  0.1432, -0.0859, -0.1507,  0.7860,  0.4746,
           1.7172],
         [-0.4817,  0.4083, -0.1530, -0.5637,  0.6309,  1.0865, -0.6380,
          -1.2389]]], grad_fn=<AddBackward0>)

In [145]:
# A single random 8-dim vector; the following cells prepend it to the patch
# tokens (presumably standing in for a class token - confirm).
x = torch.randn(1 , 8)

In [146]:
# Copy the vector along a new leading axis of size 2 -> shape (2, 1, 8),
# i.e. the same vector for each of the two samples.
x = repeat(x , 'a b -> r a b', r = 2)
x

tensor([[[ 2.2520, -1.0080, -0.1512,  0.4547,  1.3315, -0.7934,  2.4654,
           0.2535]],

        [[ 2.2520, -1.0080, -0.1512,  0.4547,  1.3315, -0.7934,  2.4654,
           0.2535]]])

In [148]:
# Prepend the repeated vector to the 4 block outputs along the token axis (dim 1).
# NOTE(review): rebinding `x` makes this cell non-idempotent - re-running it
# concatenates onto the already-extended tensor and dim 1 keeps growing.
x = torch.cat(( x , v(p)), dim = 1)

In [149]:
# Expect (2, 5, 8): 1 prepended vector + 4 patch tokens per sample.
x.shape

torch.Size([2, 5, 8])

In [150]:
# Random (5, 8) tensor, one row per token; it is added to `x` in the next cell
# (presumably a stand-in for positional embeddings - confirm).
# NOTE(review): this rebinds `p`, which held the projected patches until now.
p = torch.randn(5 , 8)

In [155]:
# Add the (5, 8) tensor to every sample; it broadcasts over the batch axis of
# the (2, 5, 8) `x`.
# NOTE(review): `x` is rebound, so re-running this cell adds `p` a second time.
x = (x + p)

In [156]:
# Inspect the summed tokens.
x

tensor([[[ 2.2596, -0.0104,  0.9144,  0.1786,  1.1872, -1.6954,  2.0234,
           0.0288],
         [-0.6215, -1.5826, -0.4923, -0.2332,  0.6252, -0.0492, -0.2681,
          -1.9030],
         [ 0.0706, -1.0837,  0.2758,  1.1160,  0.5501,  1.3569, -2.0969,
          -0.7577],
         [ 0.5817, -1.5051,  0.3715, -2.4433,  1.9197,  0.2633,  0.0136,
          -0.7897],
         [-2.0136,  0.6554, -1.9909,  2.0055,  2.0433,  2.9869, -0.3803,
          -2.1213]],

        [[ 2.2596, -0.0104,  0.9144,  0.1786,  1.1872, -1.6954,  2.0234,
           0.0288],
         [ 0.4304, -1.8233, -1.1088, -1.7776, -0.3398, -1.0029, -0.1306,
          -2.1602],
         [ 0.9726, -2.3994, -1.4122,  1.9512,  0.9293, -0.5687,  0.3050,
          -1.9222],
         [-0.3796, -0.4425,  0.0576, -2.4060,  0.3893, -0.2715,  0.9025,
           1.9618],
         [-1.1861,  1.1700, -1.1507,  0.4764,  1.1616,  2.7209, -1.8149,
          -2.3587]]], grad_fn=<AddBackward0>)

In [157]:
# The broadcast add leaves the shape unchanged at (2, 5, 8).
x.shape

torch.Size([2, 5, 8])

In [159]:
# Token 0 of every sample. Both rows match because the same prepended vector
# and the same row `p[0]` were used for each sample.
x[: , 0]

tensor([[ 2.2596, -0.0104,  0.9144,  0.1786,  1.1872, -1.6954,  2.0234,  0.0288],
        [ 2.2596, -0.0104,  0.9144,  0.1786,  1.1872, -1.6954,  2.0234,  0.0288]],
       grad_fn=<SelectBackward0>)

In [161]:
class Patchify(nn.Module):
    """Cut a channels-last image batch into non-overlapping square patches.

    Args:
        p: Side length of one patch in pixels. The input height and width
            must both be divisible by ``p`` (einops raises otherwise).
    """

    def __init__(self , p = 2) -> None:
        super().__init__()
        self.p = p

    def forward(self , x ):
        """Split ``x`` of shape ``(b, H, W, c)`` into ``(b, (H/p)*(W/p), p, p, c)``.

        Axis 1 enumerates the patches in row-major order over the patch grid;
        axes 2 and 3 are the pixel rows and columns inside one patch. This is
        the ``b p p1 p2 c`` layout that ``LinearProj.forward`` consumes.
        """
        # In '(h p1)' the inner factor p1 is the pixel offset inside a patch
        # and h is the patch-grid row (likewise w / p2). The output must group
        # by (h w) and keep p1, p2 as trailing axes so each patch holds
        # *adjacent* pixels. Grouping by (p1 p2) instead yields strided
        # sub-grids of the image, not patches.
        x = rearrange(x , 'b (h p1) (w p2) c -> b (h w) p1 p2 c', p1=self.p, p2=self.p)
        return x

In [162]:
# Patch extractor with the default p=2.
# NOTE(review): `p` is rebound yet again (patch tensor -> (5, 8) tensor -> module).
p = Patchify()

In [163]:
# Dummy channels-last batch: 3 images of 4 x 4 pixels with 3 channels.
x = torch.randn(3 , 4,4,3)

In [168]:
# Inspect the first image (4 x 4 x 3).
# NOTE(review): this cell's execution count (168) is higher than the cells
# below it (164, 166), so the notebook was run out of order.
x[0]

tensor([[[ 1.0733e+00,  1.7882e-04, -6.0929e-01],
         [ 2.6887e-01, -5.9926e-01, -8.6274e-01],
         [-2.5372e-01,  5.2756e-01, -1.0957e+00],
         [ 6.2930e-01,  8.5015e-01,  8.7898e-01]],

        [[-1.6170e-01,  1.1345e+00, -1.2453e-01],
         [ 6.4747e-01,  8.0367e-01, -2.9888e-01],
         [-9.6629e-01, -1.7787e+00, -8.0773e-01],
         [-5.5597e-01,  8.2175e-01,  4.1799e-01]],

        [[-1.7154e-01, -9.4263e-01, -7.7320e-01],
         [-7.7652e-02, -1.1777e+00,  5.5803e-02],
         [ 2.9365e+00,  5.1143e-01, -1.0198e+00],
         [ 1.2620e+00, -1.6780e-01, -1.0016e+00]],

        [[-6.2960e-01, -2.1806e+00, -9.5690e-01],
         [-3.5090e-01,  9.1976e-01,  5.1849e-01],
         [-1.5406e-01, -1.0412e+00,  1.1421e+00],
         [-6.6721e-01, -8.5002e-01, -3.4606e-01]]])

In [164]:
# Shape check: (3, 4, 4, 3) input with p=2 -> (3, 4, 2, 2, 3).
p(x).shape

torch.Size([3, 4, 2, 2, 3])

In [166]:
# Result for the first image alone: 4 groups of 2 x 2 x 3.
p(x)[0].shape

torch.Size([4, 2, 2, 3])