<a href="https://colab.research.google.com/github/fmaliks25/github-101/blob/main/ViT_paper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Setup and Data Preparation

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [3]:
# Preprocessing shared by both CIFAR-10 splits:
#   1. resize every image to the 224x224 resolution the ViT below expects,
#   2. convert to a float tensor with values in [0, 1],
#   3. normalize with mean 0.5 / std 0.5, mapping [0, 1] -> [-1, 1]
#      (the single-element tuples are broadcast across all channels).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)


In [4]:
# CIFAR-10 train and test splits, cached under ./data (downloaded on first use).
# Both splits go through the same `transform` pipeline; train is built first.
train_dataset, test_dataset = (
    datasets.CIFAR10(root='data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Mini-batches of 32 images. Only the training split is shuffled, so the
# test split is always iterated in the same order.
train_loader, test_loader = (
    DataLoader(dataset=split, batch_size=32, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (test_dataset, False))
)


# 2. Define the Vision Transformer (ViT) Model
Define the necessary components for the ViT model, including the patch embedding, transformer encoder, and the final classification head.

In [6]:
class PatchEmbedding(nn.Module):
    """Split images into non-overlapping square patches and embed each patch.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies one
    shared linear projection to every patch independently (equivalent to
    flattening each patch and applying a shared ``nn.Linear``).

    Args:
        in_channels: Number of channels in the input images.
        patch_size: Side length, in pixels, of each square patch.
        emb_size: Dimension of each output patch embedding.
        img_size: Input image side length. Stored as ``self.img_size`` but not
            used by ``forward``.
    """

    def __init__(self, in_channels, patch_size, emb_size, img_size):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.img_size = img_size

    def forward(self, x):
        """Turn a batch of images into a sequence of patch tokens.

        Expects ``x`` shaped (B, C, H, W) and returns (B, N, emb_size), where
        N = (H // patch_size) * (W // patch_size), patches in row-major order.
        """
        # (B, C, H, W) -> (B, E, H/p, W/p): one embedding per patch position.
        patch_grid = self.proj(x)
        # (B, E, H/p, W/p) -> (B, E, N) -> (B, N, E).
        return patch_grid.flatten(start_dim=2).permute(0, 2, 1)


In [7]:
class ViT(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learnable
    position embeddings -> Transformer encoder -> LayerNorm + Linear head
    applied to the final [CLS] token.

    Args:
        img_size: Side length of the square input images. Inputs must be exactly
            ``img_size x img_size``: the position-embedding table has
            ``(img_size // patch_size) ** 2 + 1`` rows, and any other token
            count fails to broadcast in ``forward``.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        emb_size: Token embedding dimension (must be divisible by ``num_heads``,
            a requirement of PyTorch multi-head attention).
        num_classes: Number of output logits.
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads per encoder layer.
        mlp_dim: Hidden size of each encoder layer's feed-forward MLP.
    """

    def __init__(self, img_size, patch_size, in_channels, emb_size, num_classes, num_layers, num_heads, mlp_dim):
        super(ViT, self).__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        num_patches = (img_size // patch_size) ** 2

        # Learnable [CLS] token and position embeddings (one row per patch plus
        # one for [CLS]). A small std-0.02 truncated-normal init is the usual
        # ViT choice; the previous unit-variance torch.randn values dominated
        # the patch embeddings they are added to.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_size))
        self.position_embedding = nn.Parameter(torch.zeros(num_patches + 1, emb_size))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.position_embedding, std=0.02)

        # The ViT paper applies LayerNorm *before* each attention/MLP block and
        # uses a GELU MLP; PyTorch's defaults are post-norm with ReLU, which
        # trains less stably from scratch.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_size,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final LayerNorm on the [CLS] representation, then a linear classifier.
        # With pre-norm encoder layers this LayerNorm is the only one applied to
        # the encoder output.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        batch_size = x.size(0)
        x = self.patch_embedding(x)  # (B, N, E)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (B, 1, E)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N + 1, E)
        # (N + 1, E) table broadcasts over the batch dimension.
        x = x + self.position_embedding
        x = self.transformer_encoder(x)
        cls_token_final = x[:, 0]  # (B, E)
        x = self.mlp_head(cls_token_final)
        return x


# 3. Initialize and Train the Model
Initialize the ViT model, move it to the available device (GPU if present), and define the loss function and optimizer. The training loop itself runs in the next section.



In [8]:
# Model parameters
img_size = 224
patch_size = 16
in_channels = 3
emb_size = 768
num_classes = 10
num_layers = 6
num_heads = 8
mlp_dim = 2048

# Initialize the ViT model
model = ViT(img_size, patch_size, in_channels, emb_size, num_classes, num_layers, num_heads, mlp_dim)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


# 4. Evaluate the Model
After training, evaluate the model's performance on the test set.

In [None]:
# Train for `num_epochs` passes over the training set, then report accuracy
# on the held-out test set (previously test_loader was built but never used).
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    seen = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion uses the default 'mean' reduction, so weight each batch by
        # its size to get a true per-sample average even if the last batch is
        # smaller than the rest.
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / max(seen, 1):.4f}")

print("Training complete!")

# Evaluate on the test set. eval() disables dropout in the encoder layers, and
# no_grad() skips building the autograd graph.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / max(total, 1):.2f}%")
