In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from einops import rearrange, repeat
from pandas import Series
from tqdm import tqdm
import torch.nn.functional as F

import os

In [2]:
from types import SimpleNamespace

# Run-wide settings shared by every later cell (device, batch size, optimizer
# hyper-parameters, ...). A SimpleNamespace is a plain attribute bag, so config
# fields cannot collide with container attributes (as they could on a pandas
# Series, e.g. `name`, `index`, `size`).
ctx = SimpleNamespace()
ctx.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so weight init, data shuffling and the random
# augmentations are reproducible across Restart & Run All.
ctx.seed = 0
torch.manual_seed(ctx.seed)

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: model width; must be divisible by ``nhead``.
        nhead: number of attention heads (head dim = embed_dim // nhead).
        pdrop: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``nhead``.
    """

    def __init__(self, embed_dim, nhead, pdrop=0.2):
        super(MultiHeadAttention, self).__init__()
        if embed_dim % nhead != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by nhead ({nhead})'
            )
        # One fused projection producing queries, keys and values side by side.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.nhead = nhead
        self.attend = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(pdrop)
        self.c_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attn_mask=None):
        """Self-attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape (batch, seq_len, embed_dim).
            attn_mask: optional 2-D tensor covering at least (seq_len, seq_len);
                only its top-left (seq_len, seq_len) slice is used. Entries equal
                to 0 (or False) block query row i from attending to key column j.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        b, t, c = x.shape
        hd = c // self.nhead

        # (b, t, 3c) -> three (b, t, c) chunks -> each split into heads (b, nh, t, hd).
        q, k, v = (
            chunk.reshape(b, t, self.nhead, hd).transpose(1, 2)
            for chunk in self.qkv(x).chunk(3, dim=-1)
        )

        # Scaled dot-product scores: (b, nh, t, t).
        wei = q @ k.transpose(-2, -1) * hd ** -0.5
        if attn_mask is not None:
            wei = wei.masked_fill(attn_mask[:t, :t] == 0, float('-inf'))
        attn = self.attend(wei)

        # Merge heads back: (b, nh, t, hd) -> (b, t, nh * hd).
        y = (attn @ v).transpose(1, 2).reshape(b, t, c)
        return self.drop(self.c_proj(y))

class FFN(nn.Module):
    """Position-wise MLP: Linear(d, 4d) -> ReLU -> Linear(4d, d) -> Dropout.

    Args:
        embed_dim: input/output width ``d``.
        pdrop: dropout probability applied to the output.
    """

    def __init__(self, embed_dim, pdrop=0.2):
        super().__init__()
        hidden_dim = 4 * embed_dim
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(pdrop),
        ]
        # Keep the container named `net` so state_dict keys (net.0.*, net.2.*)
        # match previously saved checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x`` (size embed_dim)."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """Pre-norm transformer layer: x + Attn(LN(x)), then x + FFN(LN(x)).

    Args:
        embed_dim: model width.
        nhead: number of attention heads.
        pdrop: dropout probability forwarded to both sub-layers.
    """

    def __init__(self, embed_dim, nhead, pdrop=0.2):
        super().__init__()
        # Registration order (ln_1, ln_2, attn, ffn) is kept so state_dict keys
        # and parameter init order are unchanged.
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, nhead, pdrop)
        self.ffn = FFN(embed_dim, pdrop)

    def forward(self, x, attn_mask=None):
        """Apply attention then the MLP, each as a residual update of ``x``."""
        attended = x + self.attn(self.ln_1(x), attn_mask=attn_mask)
        return attended + self.ffn(self.ln_2(attended))

class StackBlock(nn.Module):
    """A stack of ``n_layer`` transformer ``Block``s applied in sequence.

    Args:
        embed_dim: model width.
        nhead: attention heads per block.
        n_layer: number of blocks.
        pdrop: dropout probability forwarded to each block.
    """

    def __init__(self, embed_dim, nhead, n_layer, pdrop=0.2):
        super(StackBlock, self).__init__()
        # Kept as nn.Sequential so the attribute type and state_dict keys
        # (`blocks.<i>.*`) are unchanged and existing checkpoints still load;
        # it is only used as an ordered container.
        self.blocks = nn.Sequential(
            *[Block(embed_dim, nhead, pdrop) for _ in range(n_layer)]
        )

    def forward(self, x, attn_mask=None):
        """Run ``x`` through every block, forwarding ``attn_mask`` to each.

        Calling the nn.Sequential directly only passes a single input, which
        made it impossible for a mask to reach ``Block.forward``; the blocks
        are therefore iterated explicitly. With ``attn_mask=None`` the result
        is the same as calling the Sequential.
        """
        for block in self.blocks:
            x = block(x, attn_mask=attn_mask)
        return x

In [4]:
class VisionTransformer(nn.Module):
    """ViT classifier: patch-embed -> [CLS] + tokens -> transformer -> projected CLS.

    Args:
        img_size: input height/width the positional embedding is sized for.
        patch_size: side of each square patch (conv kernel and stride).
        embed_dim: token width.
        n_layer: number of transformer blocks.
        nhead: attention heads per block.
        pdrop: dropout probability inside the blocks.
        output_dim: width of the returned vector (e.g. number of classes).
        in_channels: number of input image channels (default 3, e.g. RGB).
    """

    def __init__(self, img_size, patch_size, embed_dim, n_layer, nhead, pdrop, output_dim,
                 in_channels=3):
        super(VisionTransformer, self).__init__()
        # Patches per side; the conv below yields a patch_num x patch_num grid.
        self.patch_num = img_size // patch_size
        # Small-magnitude init for the learned embeddings / output projection.
        scale = embed_dim ** -0.5

        # Non-overlapping patch embedding: kernel == stride == patch_size.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                               kernel_size=patch_size, stride=patch_size)
        self.cls_embed = nn.Parameter(torch.randn(embed_dim) * scale)
        # One position per patch plus one for the prepended CLS token.
        self.pos_embed = nn.Parameter(scale * torch.randn(self.patch_num ** 2 + 1, embed_dim))
        self.ln_pre = nn.LayerNorm(embed_dim)
        self.transformer = StackBlock(
            embed_dim, nhead, n_layer, pdrop
        )
        self.ln_post = nn.LayerNorm(embed_dim)
        self.proj = nn.Parameter(torch.randn(embed_dim, output_dim) * scale)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: images, expected (batch, in_channels, img_size, img_size).

        Returns:
            Tensor of shape (batch, output_dim).
        """
        x = self.conv1(x)                 # (b, embed_dim, patch_num, patch_num)
        x = x.flatten(2).transpose(1, 2)  # (b, patch_num**2, embed_dim), row-major patch order
        cls = self.cls_embed.expand(x.shape[0], 1, -1)  # (b, 1, embed_dim)
        x = torch.cat((cls, x), dim=1)
        x = x + self.pos_embed
        x = self.ln_pre(x)
        x = self.transformer(x)
        # Only the CLS token's representation feeds the output projection.
        x = self.ln_post(x[:, 0, :])
        return x @ self.proj

In [5]:
# 32x32 CIFAR-10 images split into 4x4 patches -> 8x8 = 64 patch tokens + 1 CLS token.
# Width 192 with 3 heads gives a head dim of 64; 6 transformer blocks.
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    embed_dim=192,
    n_layer=6,
    nhead=3,
    pdrop=0.1,
    output_dim=10,  # one logit per CIFAR-10 class
)
model = model.to(ctx.device)

In [6]:
# Per-channel mean/std used to normalize CIFAR-10 images. Defined once so the
# train and test pipelines are guaranteed to normalize identically.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

# Augmentations run on tensors (ToTensor first). RandomErasing is applied
# before Normalize, i.e. on the [0, 1] pixel values produced by ToTensor.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomErasing(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_data = datasets.CIFAR10(root='~/data', train=True, download=True, transform=train_transform)
test_data = datasets.CIFAR10(root='~/data', train=False, download=True, transform=test_transform)

classes = train_data.classes

ctx.batch_size = 64

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=ctx.batch_size, shuffle=True)
# Evaluation metrics don't depend on sample order, so the test set is not
# shuffled: evaluation stays deterministic and batch-for-batch reproducible.
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=ctx.batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
ctx.epochs = 1024
# NOTE(review): 1e-6 is very low for training a ViT from scratch; presumably
# chosen for continuing from a resumed checkpoint — confirm.
ctx.lr = 1e-6
ctx.eval_interval = 1
ctx.save_weight = 'vit.pt'
ctx.weight_decay = 0.1

# Resume from the last checkpoint if one exists. map_location lets a checkpoint
# written on a GPU be loaded on a CPU-only machine (and vice versa).
if os.path.exists(ctx.save_weight):
    model.load_state_dict(
        torch.load(ctx.save_weight, map_location=ctx.device, weights_only=True)
    )

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=ctx.lr, weight_decay=ctx.weight_decay)

pbar = tqdm(range(1, 1 + ctx.epochs))
len_train_dl = len(train_dataloader)
len_test_dl = len(test_dataloader)
for epoch in pbar:
    # ---- train one epoch ----
    train_avg_loss = 0
    model.train()
    for idx, (imgs, labels) in enumerate(train_dataloader):
        loss = criterion(model(imgs.to(ctx.device)), labels.to(ctx.device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Running mean of the per-batch losses over this epoch.
        train_avg_loss += loss.item() / len_train_dl
        pbar.set_postfix_str(f'{idx + 1}/{len_train_dl}')

    # ---- checkpoint ----
    # Save to a temporary file and atomically swap it in, so interrupting this
    # long-running loop mid-save cannot leave a truncated checkpoint behind.
    tmp_weight = ctx.save_weight + '.tmp'
    torch.save(model.state_dict(), tmp_weight)
    os.replace(tmp_weight, ctx.save_weight)

    # ---- evaluate ----
    if epoch % ctx.eval_interval == 0:
        model.eval()
        val_avg_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for imgs, labels in test_dataloader:
                imgs = imgs.to(ctx.device)
                labels = labels.to(ctx.device)
                logits = model(imgs)
                val_avg_loss += criterion(logits, labels).item() / len_test_dl
                total += labels.size(0)
                correct += (logits.argmax(dim=-1) == labels).sum().item()
        pbar.set_description_str(f'train_loss: {train_avg_loss:.4f}, val loss: {val_avg_loss:.4f}, val acc: {correct / total:.4f}')

train_loss: 0.0281, val loss: 0.6514, val acc: 0.8574:   9%|▉         | 95/1024 [1:02:58<12:38:26, 48.98s/it]         