In [29]:
!pip install einops




In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from einops import rearrange
import numpy as np


In [31]:
# Fix the global torch RNG so weight init, DataLoader shuffling, random flips
# and dropout start from the same state on every Restart & Run All.
# Bit-exact GPU reproducibility may additionally need deterministic cuDNN
# settings (see the PyTorch reproducibility notes).
SEED = 42
torch.manual_seed(SEED)  # seeds the RNGs of all devices (CPU and CUDA)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


In [32]:
# Per-channel statistics that map ToTensor's [0, 1] range to [-1, 1].
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]

# Training pipeline: Resize(64) scales the shorter image edge to 64 px, a
# random horizontal flip adds augmentation, then convert and normalize.
transform_train = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

# Evaluation pipeline: same preprocessing, without the random flip.
transform_test = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)


In [33]:
BATCH_SIZE = 128

# CIFAR-10 train/test splits, downloaded into ./data on first use.
train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform_train,
)
test_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform_test,
)

# Shuffle only the training stream; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [34]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies
    one shared linear projection per patch. The resulting feature map is then
    flattened into a token sequence.

    Args:
        img_size: image side length used to compute ``num_patches``.
        patch_size: side length of each square patch (also the conv stride).
        in_channels: number of input image channels.
        embed_dim: width of each output patch embedding.

    Attributes:
        num_patches: ``(img_size // patch_size) ** 2``. This equals the number
            of tokens ``forward`` produces only for ``img_size``-sized inputs.
        patch_size: the patch side length.
        projection: the patchifying convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.patch_size = patch_size

        # kernel == stride, so every output position sees exactly one patch.
        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn a batched image tensor into a sequence of patch tokens (B, N, E)."""
        # (B, C, H, W) -> (B, E, H/P, W/P) -> (B, E, N) -> (B, N, E)
        return self.projection(x).flatten(start_dim=2).transpose(1, 2)


In [35]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: token width ``E``; must be divisible by ``num_heads``.
        num_heads: number of heads ``H``; each head has width ``E // H``.
        attn_dropout: dropout probability applied to the attention weights.
            Defaults to 0.0 (disabled), which matches the previous behaviour.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, attn_dropout=0.0):
        super().__init__()
        # Fail fast at construction. Otherwise a bad combination only
        # surfaces later, in forward(), as an opaque reshape error.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Attention logits are divided by sqrt(head_dim); compute it once.
        self.scale = self.head_dim ** 0.5

        # A single projection yields queries, keys and values for all heads.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(attn_dropout)  # parameter-free
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Self-attend over ``x`` of shape (B, N, E); returns the same shape."""
        B, N, E = x.shape
        # (B, N, 3E) -> (B, N, 3, H, D) -> (3, B, H, N, D)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, H, N, D)

        attention = (q @ k.transpose(-2, -1)) / self.scale  # (B, H, N, N)
        attention = attention.softmax(dim=-1)
        attention = self.attn_drop(attention)

        out = attention @ v  # (B, H, N, D)
        # Merge heads back: (B, H, N, D) -> (B, N, H, D) -> (B, N, E)
        out = out.transpose(1, 2).reshape(B, N, E)
        return self.fc(out)


In [36]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention, then an MLP.

    Each sub-layer is applied to a LayerNorm-ed copy of its input, and the
    result is added back onto the un-normalised input (residual connection).

    Args:
        embed_dim: token width.
        num_heads: number of heads for ``MultiHeadSelfAttention``.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: dropout probability inside the MLP, applied after the GELU
            and after the output projection. The attention branch has none.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        hidden_dim = embed_dim * mlp_ratio

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)

        # The layer order fixes the state_dict keys (Linear layers at mlp.0 and
        # mlp.3), so keep it stable for checkpoint compatibility.
        mlp_layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Apply both residual sub-layers; the output has the same shape as ``x``."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


In [37]:
class VisionTransformer(nn.Module):
    """Minimal ViT classifier: patch embedding, [CLS] token, encoder, linear head.

    Args:
        img_size: input side length. The positional embedding has a fixed
            number of slots derived from it, so inputs must match.
        patch_size: side length of each square patch.
        in_channels: number of image channels.
        num_classes: number of output logits.
        embed_dim: token width used throughout the encoder.
        depth: number of stacked ``TransformerBlock``s.
        num_heads: attention heads per block.
        dropout: dropout after adding positional embeddings and inside each
            block's MLP.
    """

    def __init__(
        self,
        img_size=64,
        patch_size=32,
        in_channels=3,
        num_classes=10,
        embed_dim=256,
        depth=4,
        num_heads=4,
        dropout=0.1
    ):
        super().__init__()

        # NOTE(review): with the defaults (img 64, patch 32) every image becomes
        # a 2x2 grid, i.e. only 4 patch tokens. A smaller patch_size is worth
        # trying if accuracy is the goal.
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token prepended to every sequence. Its final state
        # feeds the classifier head.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learnable positional vector per patch, plus one for [CLS].
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, dropout=dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Replace the all-zero placeholders with a small truncated-normal init
        # (std 0.02, common ViT practice). This makes positions start out
        # distinguishable instead of identical.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch, presumably (B, in_channels, img_size, img_size).
                Other spatial sizes change the token count and fail when the
                positional embedding is added.

        Returns:
            Logits of shape (B, num_classes), read from the [CLS] position.
        """
        B = x.size(0)
        x = self.patch_embed(x)  # (B, N, E)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, E), a view
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N + 1, E)

        x = x + self.pos_embed
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        return self.head(x[:, 0])


In [38]:
# Instantiate the ViT with its default hyper-parameters on the chosen device.
model = VisionTransformer().to(device)

# Report the model size in millions of parameters.
num_params = sum(param.numel() for param in model.parameters())
print(num_params / 1e6, "Million parameters")


3.950346 Million parameters


In [39]:
# Multi-class loss over the model's logits (mean reduction by default).
criterion = nn.CrossEntropyLoss()

# AdamW with decoupled weight decay.
# NOTE(review): model.parameters() includes biases, LayerNorm weights,
# cls_token and pos_embed, so all of them are decayed. Many ViT recipes
# exclude those; consider parameter groups.
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)


In [40]:
def train(model, loader):
    """Run one training epoch over ``loader`` and report loss and accuracy.

    Uses the module-level ``optimizer``, ``criterion`` and ``device`` defined
    in earlier cells. Assumes ``criterion`` returns the *mean* loss over a
    batch (the ``nn.CrossEntropyLoss`` default).

    Args:
        model: the network to train; switched to training mode here.
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per *sample* over every
        sample seen. Accuracy is measured on the training-mode forward pass
        (dropout active), so it is only a running estimate.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The loss is a batch mean. Weight it by batch size so a smaller final
        # batch does not count as much as a full one.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError("train() received a loader that yielded no samples")
    return total_loss / seen, correct / seen


In [41]:
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader.dataset``.

    Puts the model in eval mode (and leaves it there), disables gradient
    tracking, and moves batches to the module-level ``device``.
    """
    model.eval()
    hits = 0

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            predicted = model(batch_images).argmax(dim=1)
            hits += (predicted == batch_labels).sum().item()

    total = len(loader.dataset)
    return hits / total


In [42]:
# Train for a fixed number of epochs. After each one, report the loss and
# accuracy returned by train() and the accuracy on the held-out test split.
epochs = 10

for epoch in range(epochs):
    loss, train_acc = train(model, train_loader)
    test_acc = evaluate(model, test_loader)
    print(
        f"Epoch [{epoch+1}/{epochs}] "
        f"Loss: {loss:.4f} "
        f"Train Acc: {train_acc:.4f} "
        f"Test Acc: {test_acc:.4f}"
    )


Epoch [1/10] Loss: 1.7792 Train Acc: 0.3547 Test Acc: 0.4048
Epoch [2/10] Loss: 1.5823 Train Acc: 0.4291 Test Acc: 0.4315
Epoch [3/10] Loss: 1.4849 Train Acc: 0.4675 Test Acc: 0.4748
Epoch [4/10] Loss: 1.4078 Train Acc: 0.4948 Test Acc: 0.4997
Epoch [5/10] Loss: 1.3422 Train Acc: 0.5202 Test Acc: 0.5148
Epoch [6/10] Loss: 1.2884 Train Acc: 0.5386 Test Acc: 0.5204
Epoch [7/10] Loss: 1.2350 Train Acc: 0.5597 Test Acc: 0.5319
Epoch [8/10] Loss: 1.1812 Train Acc: 0.5776 Test Acc: 0.5361
Epoch [9/10] Loss: 1.1375 Train Acc: 0.5945 Test Acc: 0.5555
Epoch [10/10] Loss: 1.0856 Train Acc: 0.6130 Test Acc: 0.5578
