In [3]:
import torch
import torch.nn as nn
from sklearn.feature_extraction import image
from sklearn.datasets import load_sample_image
import numpy as np

### 0. Get a patch from the image

In [4]:
# Load the 'china.jpg' sample image through sklearn.datasets.load_sample_image.
img = load_sample_image('china.jpg')

In [5]:
# Inspect the image dimensions (the recorded output is (427, 640, 3)).
img.shape

(427, 640, 3)

In [6]:
# Dense extraction: every overlapping 16x16 window of the image is returned,
# i.e. (427 - 16 + 1) * (640 - 16 + 1) = 257500 patches for this image.
# Only patches[0] is used further down.
patches = image.extract_patches_2d(img, (16, 16))

In [7]:
# Shape of the patch stack: (number of patches, patch height, patch width, channels).
patches.shape

(257500, 16, 16, 3)

In [8]:
# Shape of a single patch, channels last.
patches[0].shape

(16, 16, 3)

In [9]:
# Take the first patch and turn it into a float32 tensor, then move the
# channel axis first: (16, 16, 3) -> (3, 16, 16), the layout nn.Conv2d expects.
patch = patches[0]
x = torch.tensor(patch, dtype=torch.float32).permute(2, 0, 1)
x.shape

torch.Size([3, 16, 16])

### 1. Define module classes

In [10]:
class PatchTokenization(nn.Module):
    """Project non-overlapping square patches to embeddings with one strided convolution.

    Args:
        patch_size: Side length of a patch; used as both kernel size and stride.
        chanels: Number of input channels.
        embed_dim: Number of output channels, i.e. the embedding size per patch.
            The default 768 equals 16 * 16 * 3.
    """

    def __init__(self, patch_size=16, chanels=3, embed_dim=768):
        super().__init__()
        # kernel_size == stride, so each patch_size x patch_size window is
        # projected exactly once and windows never overlap.
        self.proj = nn.Conv2d(
            in_channels=chanels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the convolution output for ``x``; the spatial grid is not flattened."""
        return self.proj(x)

In [11]:
# Patch embedding with the defaults: 16x16 patches, 3 input channels, 768-dim output.
patching = PatchTokenization()

In [12]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: Embedding size ``C`` of each token; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        proj_drop: Dropout probability applied after the output projection.
        attn_drop: Dropout probability applied to the attention weights.
        verbose: If true (the default), print the intermediate tensor shapes
            on every forward pass.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, proj_drop=0., attn_drop=0., verbose=True):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f'dim ({dim}) must be divisible by num_heads ({num_heads})')
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1 / sqrt(head_dim): keeps the logits' magnitude independent of the
        # head width so the softmax does not saturate.
        self.scale = self.head_dim ** -0.5
        self.verbose = verbose
        self.qkv = nn.Linear(dim, dim * 3)  # (B, N, C) -> (B, N, C * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

    def _log(self, label, tensor):
        """Print ``label`` and the size of ``tensor`` when verbose output is on."""
        if self.verbose:
            print(f'{label}: {tensor.size()}')

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C) and return the same shape."""
        B, N, C = x.shape
        qkv = self.qkv(x)
        self._log('qkv size', qkv)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        self._log('qkv reshaped size', qkv)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        self._log('q size', q)
        self._log('k size', k)
        self._log('v size', v)
        # Scaled dot-product scores: (B, heads, N, N).
        attn = (q @ k.transpose(-2, -1)) * self.scale
        self._log('attn size', attn)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim), then merge the heads.
        x = (attn @ v).transpose(1, 2)
        self._log('attn * v size', x)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

In [13]:
# Stand-alone attention module for 768-dim tokens, using the default of 8 heads.
attention = MultiHeadAttention(dim=768)

In [22]:
class MLP(nn.Module):
    """Single hidden layer: linear projection, activation, dropout.

    The second projection is currently disabled, so the output width is
    ``hidden_features`` rather than ``out_features``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the projection; defaults to ``in_features``.
        out_features: Accepted for interface compatibility; currently unused.
        act_layer: Activation class, instantiated without arguments.
        drop: Dropout probability applied after the activation.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Fall back to in_features so the documented default (None) does not
        # reach nn.Linear, which cannot build a layer with a None width.
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Return ``drop(act(fc1(x)))``."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        return x

In [45]:
class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, each with a residual.

    Args:
        dim: Embedding size of each token.
        num_heads: Number of attention heads.
        mlp_ratio: Accepted but not applied at the moment (see the note in ``__init__``).
        proj_drop: Dropout probability for the attention output projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
        act_layer: Activation class used inside the MLP.
        norm_layer: Normalisation class, instantiated with ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., proj_drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttention(dim, num_heads, proj_drop, attn_drop)
        # Not used by forward(); kept so the module's parameters stay the same.
        self.temporal_fc = nn.Linear(dim, dim)
        self.norm2 = norm_layer(dim)
        # NOTE(review): the MLP has no second projection back to `dim`, so its
        # hidden width must equal `dim` for the residual sum in forward() to be
        # shape-compatible. mlp_ratio is therefore ignored for now.
        mlp_hidden_dim = dim
        # Pass by keyword: MLP's third positional parameter is out_features,
        # not act_layer, so a positional call would silently drop act_layer.
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)

    def forward(self, x, B=None, T=None, W=None):
        """Run attention and MLP sub-layers on ``x``, returning a tensor of the same shape.

        ``B``, ``T`` and ``W`` are accepted for call compatibility and are not used.
        """
        x = x + self.attn(self.norm1(x))
        print(f'x + attn: {x.shape}')
        x = x + self.mlp(self.norm2(x))
        return x

In [46]:
# One transformer block for 768-dim tokens with 8 attention heads.
block = Block(dim=768, num_heads=8)

### 2. Pass a patch through the architecture

In [25]:
# Call the module itself rather than .forward() so nn.Module.__call__ runs
# any registered hooks. NOTE: this rebinds x, so the cell is not idempotent;
# rebuild x from the patch before re-running it.
x = patching(x)

In [26]:
# One 16x16 patch yields a single embedding: (embed_dim, 1, 1).
x.size()

torch.Size([768, 1, 1])

In [27]:
# (768, 1, 1) -> (1, 1, 768): put the embedding axis last so the tensor has the
# three-axis (B, N, C) layout MultiHeadAttention.forward unpacks.
# NOTE(review): this is a plain axis swap; it only doubles as a reshape here
# because the two other axes have length 1 — confirm before using larger inputs.
z = x.swapaxes(0, 2)

In [28]:
# Confirm the token layout after the axis swap.
z.size()

torch.Size([1, 1, 768])

In [29]:
# Unpack the three axes of z for inspection.
A, B, C = z.shape
print(A, B, C)

1 1 768


In [30]:
# Call the module itself rather than .forward() so nn.Module.__call__ runs
# any registered hooks.
y = attention(z)

qkv size: torch.Size([1, 1, 2304])
qkv reshaped size: torch.Size([1, 1, 3, 8, 96])
q size: torch.Size([1, 8, 1, 96])
k size: torch.Size([1, 8, 1, 96])
v size: torch.Size([1, 8, 1, 96])
attn size: torch.Size([1, 8, 1, 1])
attn * v size: torch.Size([1, 1, 8, 96])


In [31]:
# Attention keeps the input layout: same size as z.
print(y.size())

torch.Size([1, 1, 768])


In [33]:
# Unpack y's shape into the three extra arguments Block.forward takes (B, T, W).
# NOTE(review): y is (1, 1, 768) per the recorded output, so W holds the
# embedding size rather than a spatial width — confirm the intended meaning.
B, T, W = y.shape
print(B, T, W)

1 1 768


In [47]:
# Call the module itself rather than .forward() so nn.Module.__call__ runs
# any registered hooks.
w = block(y, B, T, W)

qkv size: torch.Size([1, 1, 2304])
qkv reshaped size: torch.Size([1, 1, 3, 8, 96])
q size: torch.Size([1, 8, 1, 96])
k size: torch.Size([1, 8, 1, 96])
v size: torch.Size([1, 8, 1, 96])
attn size: torch.Size([1, 8, 1, 1])
attn * v size: torch.Size([1, 1, 8, 96])
x + attn: torch.Size([1, 1, 768])


In [49]:
# The block preserves the token layout: same size as its input y.
print(w.size())

torch.Size([1, 1, 768])
