In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 32
PATCH_SIZE = 4
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2
EMBED_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 4
MLP_DIM = 256
DROPOUT = 0.1
BATCH_SIZE = 64
LR = 5e-4
EPOCHS = 20

In [3]:
#Adding basic augmentation to prevent overfitting
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 170M/170M [00:03<00:00, 46.5MB/s]


In [4]:
class ViTClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        #Patch Embedding:(Convolutional approach)
        self.patch_embed = nn.Conv2d(3, EMBED_DIM, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)

        #Learnable tokens
        self.cls_token = nn.Parameter(torch.randn(1, 1, EMBED_DIM))
        self.pos_embedding = nn.Parameter(torch.randn(1, NUM_PATCHES + 1, EMBED_DIM))
        self.dropout = nn.Dropout(DROPOUT)

        #Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMBED_DIM,
            nhead=NUM_HEADS,
            dim_feedforward=MLP_DIM,
            dropout=DROPOUT,
            activation='gelu',
            batch_first=True,
            norm_first=True #Critical for stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)

        #Classification Head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(EMBED_DIM),
            nn.Linear(EMBED_DIM, 10)
        )

    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)

        return self.mlp_head(x[:, 0])

In [5]:
model = ViTClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-2)

print(f"Training on {device}...")
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Epoch [{epoch+1}/{EPOCHS}] - Loss: {running_loss/len(trainloader):.4f} - Acc: {100 * correct / total:.2f}%")



Training on cuda...
Epoch [1/20] - Loss: 1.7796 - Acc: 45.99%
Epoch [2/20] - Loss: 1.5024 - Acc: 51.05%
Epoch [3/20] - Loss: 1.4018 - Acc: 52.63%
Epoch [4/20] - Loss: 1.3371 - Acc: 54.91%
Epoch [5/20] - Loss: 1.2823 - Acc: 57.10%
Epoch [6/20] - Loss: 1.2343 - Acc: 59.25%
Epoch [7/20] - Loss: 1.1934 - Acc: 61.41%
Epoch [8/20] - Loss: 1.1557 - Acc: 61.13%
Epoch [9/20] - Loss: 1.1218 - Acc: 62.43%
Epoch [10/20] - Loss: 1.0884 - Acc: 63.79%
Epoch [11/20] - Loss: 1.0567 - Acc: 65.35%
Epoch [12/20] - Loss: 1.0215 - Acc: 64.92%
Epoch [13/20] - Loss: 1.0016 - Acc: 66.41%
Epoch [14/20] - Loss: 0.9782 - Acc: 66.92%
Epoch [15/20] - Loss: 0.9461 - Acc: 67.86%
Epoch [16/20] - Loss: 0.9349 - Acc: 68.53%
Epoch [17/20] - Loss: 0.9090 - Acc: 68.82%
Epoch [18/20] - Loss: 0.8864 - Acc: 69.04%
Epoch [19/20] - Loss: 0.8664 - Acc: 71.02%
Epoch [20/20] - Loss: 0.8502 - Acc: 70.64%
