In [2]:
import torch
import torch.nn as nn
from sklearn.feature_extraction import image
from sklearn.datasets import load_sample_image
import numpy as np

### 0. Get patch from image

In [10]:
img = load_sample_image('china.jpg')

In [12]:
img.shape

(427, 640, 3)

In [13]:
patches = image.extract_patches_2d(img, (16, 16))

In [15]:
patches.shape

(257500, 16, 16, 3)

In [17]:
patches[0].shape

(16, 16, 3)

In [65]:
patch = patches[0]
x = torch.tensor(patch)
x = x.type(torch.float32)
x = x.permute(2, 0, 1)
x.shape

torch.Size([3, 16, 16])

### 1. Define modules classes

In [61]:
class PatchTokenization(nn.Module):
    def __init__(self, patch_size=16, chanels=3, embed_dim=768):  # embed_dim = 16x16x3
        super().__init__()
        self.proj = nn.Conv2d(chanels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        return x

In [62]:
patching = PatchTokenization()

In [76]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads=8, proj_drop=0., attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)  # (W, H, C) -> (W, H, C * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        W, H, C = x.shape
        qkv = self.qkv(x).reshape(W, H, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3, head, C/head, W, H)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = torch.mul(q, k.transpose(-2, -1))  # sometimes they use @ operator? 
        # attn = attn * (head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = torch.mul(attn, v).reshape(W, H, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

In [77]:
attention = MultiHeadAttention(dim=768)

In [23]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)

        return x 

### 2. Patch trough architecture

In [66]:
x = patching.forward(x) 


torch.Size([768, 1, 1])

In [78]:
y = attention.forward(x)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (768x1 and 768x2304)