In [1]:
import torch
from torch import nn

In [2]:
torch.randn(3, 2,3,4).flatten(2).shape

torch.Size([3, 2, 12])

In [3]:


class PatchEmbedding(nn.Module):
    def __init__(self, img_size=256, patch_size=16, num_hiddens=512):
        super().__init__()

        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            
        img_size, patch_size = _make_tuple(img_size), _make_tuple(patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) ** (img_size[1] // patch_size[1])
        self.conv = nn.LazyConv2d(num_hiddens, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, X):
        return self.conv(X).flatten(2).transpose(1, 2)


In [4]:
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
X = torch.zeros(batch_size, 3, img_size, img_size)
assert img_size//patch_size, patch_emb(X).shape == (batch_size, (img_size//patch_size)**2, num_hiddens)



In [5]:
class ViTMLP(nn.Module):
    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        self.dense1 = nn.LazyLinear(mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.LazyLinear(mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, X):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(self.dense1(X)))))

In [6]:
from d2l_common import MultiHeadAttention
class VitBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens, num_heads, dropout, use_bias=False) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = MultiHeadAttention(num_hiddens, num_heads, dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = X + self.attention(*([self.ln1(X)]*3), valid_lens)
        return X+self.mlp(self.ln2(X))


In [8]:
X = torch.ones(2, 100, 24)
X.shape
encoder_blk = VitBlock(24, 24, 48, 8, 0.5)
encoder_blk.eval()
assert encoder_blk(X).shape == (2,100,24)



In [17]:
from d2l_common import Classifier


class ViT(Classifier):
    """Vision Transformer."""

    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1, use_bias=False, num_classes=10):
        super().__init__()
        self.img_size = img_size
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_steps = self.patch_embedding.num_patches+1
        # position embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_steps, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f'{i}', VitBlock(num_hiddens, num_hiddens, mlp_num_hiddens, num_heads, blk_dropout, use_bias))
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, num_classes))

    def forward(self, X):
        X = self.patch_embedding(X)
        X = torch.cat((self.cls_token.expand(X.shape[0], -1, -1), X), 1)
        print(X.shape, self.pos_embedding.shape)
        X = self.dropout(X + self.pos_embedding)
        X = self.blks(X)
        return self.head(X[:, 0])

In [18]:
from d2l_common import Trainer,FasionMNIST

img_size, patch_size = 96, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model =ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout, lr)

trainer = Trainer(max_epochs=10)
data = FasionMNIST(batch_size=128, resize=(img_size, img_size))
trainer.fit(model, data)



torch.Size([128, 37, 512]) torch.Size([128, 46657, 512])


RuntimeError: The size of tensor a (37) must match the size of tensor b (46657) at non-singleton dimension 1

In [30]:
# torch.Size([128, 37, 512]) torch.Size([128, 46657, 512])
(torch.randn(128, 1, 512) + torch.randn(64, 37, 512)).shape

torch.Size([128, 37, 512])