In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import math


In [None]:
# Seed PyTorch's global RNG so that weight initialisation, the DataLoader
# shuffle order, the random flips and the dropout masks start from the same
# state on every Restart-Kernel -> Run-All (GPU kernels may still introduce
# small non-deterministic differences).
SEED = 42
torch.manual_seed(SEED)

# Use GPU if available (Colab usually has one)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cpu


In [None]:
batch_size = 64

# Both splits end with the same two steps: convert the PIL image to a float
# tensor in [0, 1], then shift/scale every channel with mean 0.5 and std 0.5.
_shared_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]

# Only the training split is augmented (random horizontal flip).
train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), *_shared_steps]
)
test_transform = transforms.Compose(list(_shared_steps))

# CIFAR-10 is stored under ./data and downloaded on first use.
_cifar_kwargs = dict(root="./data", download=True)

train_dataset = torchvision.datasets.CIFAR10(
    train=True, transform=train_transform, **_cifar_kwargs
)
test_dataset = torchvision.datasets.CIFAR10(
    train=False, transform=test_transform, **_cifar_kwargs
)

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



100%|██████████| 170M/170M [00:02<00:00, 83.4MB/s]


In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    Args:
        img_size: side length of the (square) input image, in pixels.
        patch_size: side length of each square patch, in pixels.
        in_channels: number of image channels.
        embed_dim: size of the embedding produced for every patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=64):
        super().__init__()

        patches_per_side = img_size // patch_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side * patches_per_side

        # One linear map from a flattened patch (C * P * P values) to embed_dim
        flat_patch_dim = in_channels * patch_size * patch_size
        self.projection = nn.Linear(flat_patch_dim, embed_dim)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, num_patches, embed_dim); patches are ordered
            row by row, and each patch is flattened channel-first.
        """
        p = self.patch_size
        batch, channels, _, _ = x.shape

        # Slide a P-wide window with stride P along height, then along width:
        # (B, C, H, W) -> (B, C, nH, nW, P, P)
        tiles = x.unfold(2, p, p).unfold(3, p, p)

        # Merge the two grid axes into a single patch axis: (B, C, nH*nW, P, P)
        tiles = tiles.contiguous().view(batch, channels, -1, p, p)

        # Put the patch axis next to the batch axis and flatten each patch:
        # (B, N, C, P, P) -> (B, N, C*P*P)
        tokens = tiles.permute(0, 2, 1, 3, 4).flatten(2)

        # (B, N, C*P*P) -> (B, N, embed_dim)
        return self.projection(tokens)


In [None]:
class Embeddings(nn.Module):
    """Prepend a learnable CLS token and add learnable position embeddings.

    Args:
        num_patches: number of patch tokens per image (the CLS slot is added
            on top of this).
        embed_dim: size of every token embedding.
    """

    def __init__(self, num_patches, embed_dim):
        super().__init__()

        seq_len = num_patches + 1  # patches plus the CLS slot
        # Both parameters start at zero and are learned during training.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.position_embeddings = nn.Parameter(torch.zeros(1, seq_len, embed_dim))

    def forward(self, x):
        """Map (B, N, D) patch tokens to (B, N + 1, D) with CLS at index 0."""
        batch, _, _ = x.shape

        # Share the single CLS vector across the batch without copying it
        cls = self.cls_token.expand(batch, -1, -1)

        # CLS goes first; the position table broadcasts over the batch axis
        return torch.cat((cls, x), dim=1) + self.position_embeddings


In [None]:
class ScaledDotProductAttention(nn.Module):
    """Compute softmax(Q K^T / sqrt(d_k)) V with no mask and no dropout."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V):
        """Attend over the second-to-last axis of K and V.

        Q, K and V share their last dimension d_k (taken from Q); any leading
        dimensions are handled by batched matmul.
        """
        scale = math.sqrt(Q.size(-1))

        # Similarity of every query with every key, scaled by sqrt(d_k)
        logits = torch.matmul(Q, K.transpose(-2, -1)) / scale

        # Normalise over the key axis so each query's weights sum to 1
        weights = F.softmax(logits, dim=-1)

        return torch.matmul(weights, V)


In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with a fused Q/K/V projection.

    Args:
        embed_dim: size of every token embedding; must be a multiple of
            ``num_heads``.
        num_heads: number of attention heads (positive).

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()

        # Fail fast: with a floor-divided head_dim the reshape in forward()
        # would otherwise die later with an opaque shape-mismatch error.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # A single linear layer produces Q, K and V side by side
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attention = ScaledDotProductAttention()
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, D); returns (B, N, D)."""
        B, N, D = x.shape

        # (B, N, 3*D) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Each head attends independently: (B, heads, N, head_dim)
        out = self.attention(Q, K, V)

        # Bring the heads back next to head_dim and concatenate them: (B, N, D)
        out = out.transpose(1, 2).contiguous()
        out = out.view(B, N, D)

        out = self.fc_out(out)
        return out


In [None]:
class TransformerEncoderBlock(nn.Module):
    """One encoder layer: self-attention and a feed-forward MLP.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``, i.e. the
    normalisation is applied after the residual addition.

    Args:
        embed_dim: size of every token embedding.
        num_heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward network.
        dropout: dropout probability inside the feed-forward network
            (defaults to the previously hard-coded 0.1).
    """

    def __init__(self, embed_dim=64, num_heads=4, mlp_dim=128, dropout=0.1):
        super().__init__()

        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)

        # Position-wise MLP: expand to mlp_dim, apply ReLU and dropout,
        # then project back to embed_dim.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
        )

        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Transform ``x`` of shape (B, N, embed_dim); the shape is preserved."""
        # Attention sub-layer: residual connection, then LayerNorm
        x = x + self.attention(x)
        x = self.norm1(x)

        # Feed-forward sub-layer: residual connection, then LayerNorm
        x = x + self.ffn(x)
        x = self.norm2(x)

        return x


In [None]:
class VisionTransformer(nn.Module):
    """Small Vision Transformer classifier.

    The defaults reproduce the original hard-coded configuration
    (32x32 RGB input, 4x4 patches, 64-d embeddings, 4 encoder blocks,
    10 classes), so ``VisionTransformer()`` builds the same architecture.

    Args:
        img_size: side length of the square input images.
        patch_size: side length of each square patch.
        in_channels: number of image channels.
        embed_dim: token embedding size used throughout the model.
        num_heads: attention heads per encoder block.
        mlp_dim: hidden width of each block's feed-forward network.
        depth: number of stacked encoder blocks.
        num_classes: number of output logits.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=64,
                 num_heads=4, mlp_dim=128, depth=4, num_classes=10):
        super().__init__()

        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )
        # Take the patch count from the patch embedder instead of repeating a
        # literal, so the position table always matches the token sequence.
        self.embed = Embeddings(
            num_patches=self.patch_embed.num_patches,
            embed_dim=embed_dim,
        )

        self.encoder = nn.Sequential(
            *(
                TransformerEncoderBlock(
                    embed_dim=embed_dim, num_heads=num_heads, mlp_dim=mlp_dim
                )
                for _ in range(depth)
            )
        )

        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images (B, C, H, W); returns logits (B, num_classes)."""
        x = self.patch_embed(x)
        x = self.embed(x)
        x = self.encoder(x)

        # The CLS token sits at sequence index 0 and summarises the image
        cls_output = x[:, 0]
        out = self.classifier(cls_output)

        return out


In [None]:
# Build the ViT with its default hyper-parameters, then move it to `device`
model = VisionTransformer()
model = model.to(device)


In [None]:
# Cross-entropy on the raw logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
epochs = 20

for epoch in range(epochs):
    model.train()  # enable dropout for training
    total_loss = 0.0   # sum of per-sample losses seen this epoch
    num_samples = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is the mean over this batch; weight it by the batch size so a
        # shorter final batch does not count as much as a full one.
        batch_count = labels.size(0)
        total_loss += loss.item() * batch_count
        num_samples += batch_count

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / max(num_samples, 1):.4f}")


Epoch [1/20], Loss: 1.7135
Epoch [2/20], Loss: 1.3352
Epoch [3/20], Loss: 1.2071
Epoch [4/20], Loss: 1.1329
Epoch [5/20], Loss: 1.0724
Epoch [6/20], Loss: 1.0238
Epoch [7/20], Loss: 0.9825
Epoch [8/20], Loss: 0.9482
Epoch [9/20], Loss: 0.9152
Epoch [10/20], Loss: 0.8897
Epoch [11/20], Loss: 0.8624
Epoch [12/20], Loss: 0.8377
Epoch [13/20], Loss: 0.8163
Epoch [14/20], Loss: 0.8003
Epoch [15/20], Loss: 0.7787
Epoch [16/20], Loss: 0.7604
Epoch [17/20], Loss: 0.7407
Epoch [18/20], Loss: 0.7300
Epoch [19/20], Loss: 0.7122
Epoch [20/20], Loss: 0.6963


In [None]:
# Switch to inference mode (disables dropout) and count top-1 hits on the test set.
model.eval()
correct, total = 0, 0

# No gradients are needed for evaluation
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        # Predicted class = index of the largest logit in each row
        predicted = torch.max(outputs, dim=1).indices

        total += len(labels)
        correct += predicted.eq(labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 70.27%
