In [1]:
!pip install -q einops

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from torch.utils.data import DataLoader

In [3]:
BATCH_SIZE = 128
EPOCHS = 25
LR = 3e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_SIZE = 32
PATCH_SIZE = 4
NUM_CLASSES = 10
DIM = 256
DEPTH = 6
HEADS = 8
MLP_DIM = 512
DROPOUT = 0.1

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 170M/170M [00:08<00:00, 20.5MB/s]


In [5]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels=3, emb_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, emb_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
    def forward(self, x):
        x = self.proj(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x

In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.mlp(self.norm2(x))
        return x

In [7]:
class VisionTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            IMAGE_SIZE, PATCH_SIZE, 3, DIM
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, DIM))
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.patch_embed.n_patches + 1, DIM)
        )
        self.dropout = nn.Dropout(DROPOUT)
        self.transformer = nn.Sequential(
            *[TransformerBlock(DIM, HEADS, MLP_DIM, DROPOUT) for _ in range(DEPTH)]
        )
        self.norm = nn.LayerNorm(DIM)
        self.head = nn.Linear(DIM, NUM_CLASSES)
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        cls_output = x[:, 0]
        return self.head(cls_output)

In [8]:
model = VisionTransformer().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [9]:
def train():
    model.train()
    correct = total = 0
    running_loss = 0
    for images, labels in trainloader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    acc = 100 * correct / total
    return running_loss / len(trainloader), acc

In [10]:
def test():
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100 * correct / total

In [11]:
for epoch in range(EPOCHS):
    train_loss, train_acc = train()
    test_acc = test()
    scheduler.step()
    print(f"Epoch [{epoch+1}/{EPOCHS}] "
          f"Loss: {train_loss:.3f} "
          f"Train Acc: {train_acc:.2f}% "
          f"Test Acc: {test_acc:.2f}%")

Epoch [1/25] Loss: 1.886 Train Acc: 30.42% Test Acc: 42.85%
Epoch [2/25] Loss: 1.593 Train Acc: 42.13% Test Acc: 47.11%
Epoch [3/25] Loss: 1.465 Train Acc: 47.21% Test Acc: 52.49%
Epoch [4/25] Loss: 1.397 Train Acc: 49.52% Test Acc: 54.15%
Epoch [5/25] Loss: 1.318 Train Acc: 52.28% Test Acc: 56.76%
Epoch [6/25] Loss: 1.264 Train Acc: 54.19% Test Acc: 58.65%
Epoch [7/25] Loss: 1.220 Train Acc: 56.00% Test Acc: 60.17%
Epoch [8/25] Loss: 1.171 Train Acc: 58.03% Test Acc: 60.77%
Epoch [9/25] Loss: 1.139 Train Acc: 59.10% Test Acc: 61.37%
Epoch [10/25] Loss: 1.103 Train Acc: 60.23% Test Acc: 62.67%
Epoch [11/25] Loss: 1.065 Train Acc: 61.97% Test Acc: 63.35%
Epoch [12/25] Loss: 1.038 Train Acc: 62.87% Test Acc: 65.39%
Epoch [13/25] Loss: 1.007 Train Acc: 64.01% Test Acc: 65.19%
Epoch [14/25] Loss: 0.982 Train Acc: 65.14% Test Acc: 66.06%
Epoch [15/25] Loss: 0.958 Train Acc: 65.73% Test Acc: 67.03%
Epoch [16/25] Loss: 0.927 Train Acc: 66.84% Test Acc: 68.07%
Epoch [17/25] Loss: 0.910 Train A