In [6]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# ==========================
# DEVICE
# ==========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ==========================
# HYPERPARAMETERS
# ==========================
patch_size = 4
embed_dim = 128
num_heads = 4
num_layers = 4
num_classes = 10
batch_size = 128
lr = 3e-4
epochs = 20

# ==========================
# PREPROCESSING
# ==========================
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

# ==========================
# PATCH EMBEDDING
# ==========================
class PatchEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)               # (B, D, H', W')
        x = x.flatten(2)               # (B, D, N)
        x = x.transpose(1,2)           # (B, N, D)
        return x

# ==========================
# VISION TRANSFORMER
# ==========================
class ViT(nn.Module):
    def __init__(self):
        super().__init__()

        self.patch_embed = PatchEmbedding()

        # 64 patches + 1 CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 65, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=512,
            dropout=0.1,
            batch_first=True
        )

        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        B = x.shape[0]

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        x = self.transformer(x)
        cls_output = self.norm(x[:, 0])
        return self.head(cls_output)

# ==========================
# MODEL SETUP
# ==========================
model = ViT().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

Using device: cuda


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Calculate training accuracy
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(trainloader)
    acc = 100 * correct / total

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Train Acc: {acc:.2f}%")


Epoch 1, Loss: 1.7992, Train Acc: 32.44%
Epoch 2, Loss: 1.4487, Train Acc: 46.89%
Epoch 3, Loss: 1.2950, Train Acc: 52.99%
Epoch 4, Loss: 1.1938, Train Acc: 56.65%
Epoch 5, Loss: 1.1297, Train Acc: 59.40%
Epoch 6, Loss: 1.0743, Train Acc: 61.28%
Epoch 7, Loss: 1.0250, Train Acc: 63.13%
Epoch 8, Loss: 0.9874, Train Acc: 64.56%
Epoch 9, Loss: 0.9483, Train Acc: 66.25%
Epoch 10, Loss: 0.9181, Train Acc: 67.15%
Epoch 11, Loss: 0.8885, Train Acc: 68.19%
Epoch 12, Loss: 0.8576, Train Acc: 69.49%
Epoch 13, Loss: 0.8392, Train Acc: 70.15%
Epoch 14, Loss: 0.8204, Train Acc: 70.70%
Epoch 15, Loss: 0.7952, Train Acc: 71.54%
Epoch 16, Loss: 0.7846, Train Acc: 72.19%
Epoch 17, Loss: 0.7609, Train Acc: 73.08%
Epoch 18, Loss: 0.7470, Train Acc: 73.28%
Epoch 19, Loss: 0.7400, Train Acc: 73.89%
Epoch 20, Loss: 0.7223, Train Acc: 74.67%


In [8]:
correct = 0
total = 0

model.eval()
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.cuda(), labels.cuda()
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print("Accuracy:", 100 * correct / total)


Accuracy: 73.92
