In [1]:
#Imports
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose, Normalize, Resize
from tqdm import tqdm
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR  # or ReduceLROnPlateau

<torch._C.Generator at 0x7acc2193cdf0>

In [None]:
#Set Random Seed

np.random.seed(0)
torch.manual_seed(0)

Define ViT Architecture

In [5]:
def patchify(images, n_patches):
    n, c, h, w = images.shape
    assert h == w, "Patchify method is implemented for square images only"
    patch_size = h // n_patches
    patches = torch.zeros(n, n_patches**2, c * patch_size * patch_size)

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[
                    :,
                    i * patch_size : (i + 1) * patch_size,
                    j * patch_size : (j + 1) * patch_size,
                ]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches

def get_positional_embeddings(sequence_length, d):
    result = torch.zeros(sequence_length, d)
    for i in range(sequence_length):
        for j in range(0, d, 2):
            result[i][j] = np.sin(i / (10000 ** (j / d)))
            if j + 1 < d:
                result[i][j + 1] = np.cos(i / (10000 ** (j / d)))
    return result

class MyMSA(nn.Module):
    def __init__(self, d, n_heads=1):
        super(MyMSA, self).__init__()
        assert d % n_heads == 0
        self.d_head = d // n_heads
        self.n_heads = n_heads

        self.q_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        output = []
        for sequence in sequences:
            heads_out = []
            for h in range(self.n_heads):
                start = h * self.d_head
                end = (h + 1) * self.d_head
                q = self.q_mappings[h](sequence[:, start:end])
                k = self.k_mappings[h](sequence[:, start:end])
                v = self.v_mappings[h](sequence[:, start:end])
                attn = self.softmax(q @ k.T / (self.d_head ** 0.5))
                heads_out.append(attn @ v)
            output.append(torch.cat(heads_out, dim=1))
        return torch.stack(output)

class MyViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=2):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d),
        )

    def forward(self, x):
        x = x + self.mhsa(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class MyViT(nn.Module):
    def __init__(self, chw=(3, 32, 32), n_patches=4, n_blocks=1, hidden_d=64, n_heads=1, out_d=10):
        super().__init__()
        c, h, w = chw
        assert h % n_patches == 0 and w % n_patches == 0
        self.patch_size = (h // n_patches, w // n_patches)
        self.input_d = c * self.patch_size[0] * self.patch_size[1]

        self.linear_mapper = nn.Linear(self.input_d, hidden_d)
        self.class_token = nn.Parameter(torch.rand(1, hidden_d))
        self.register_buffer(
            "positional_embeddings",
            get_positional_embeddings(n_patches**2 + 1, hidden_d),
            persistent=False
        )

        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])
        self.mlp_head = nn.Sequential(nn.Linear(hidden_d, out_d), nn.Softmax(dim=-1))

    def forward(self, images):
        n = images.shape[0]
        patches = patchify(images, n_patches=4).to(self.positional_embeddings.device)
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)
        tokens += self.positional_embeddings.repeat(n, 1, 1)
        for block in self.blocks:
            tokens = block(tokens)
        return self.mlp_head(tokens[:, 0])

Define Data Transforms

In [3]:
# Define CIFAR-10 dataset transformations

transform = Compose([
    Resize((32, 32)),  # Ensure uniform size
    RandomRotation(degrees=15),  # Apply random rotation between -15 to +15 degrees
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load CIFAR-10 train and test sets
train_set = CIFAR10(
    root="./../datasets", train=True, download=True, transform=transform
)
test_set = CIFAR10(
    root="./../datasets", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

100%|██████████| 170M/170M [00:03<00:00, 43.1MB/s]


Define Training Loop

In [None]:
train_losses = []
val_losses = []

# Define the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyViT((3, 32, 32), n_patches=4, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)

# Hyperparameters
N_EPOCHS = 10
LR = 0.005
PATIENCE = 5  # Early stopping patience
MIN_DELTA = 0.001  # Minimum improvement to reset early stopping
STEP_SIZE = 10  # StepLR scheduler step size
GAMMA = 0.5     # LR decay factor

optimizer = Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

# Alternatively:
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

best_val_loss = float("inf")
patience_counter = 0

for epoch in range(N_EPOCHS):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} Training", leave=False):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)

        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_loss /= len(train_loader)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            val_loss += loss.item()
    val_loss /= len(test_loader)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{N_EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # ---- Scheduler Step ----
    scheduler.step()  # For StepLR
    # scheduler.step(val_loss)  # If using ReduceLROnPlateau

    # ---- Early Stopping Check ----
    if val_loss + MIN_DELTA < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "vit_best.pth")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break



Epoch 1/50 | Train Loss: 2.1875 | Val Loss: 2.1602




Epoch 2/50 | Train Loss: 2.1390 | Val Loss: 2.1280




Epoch 3/50 | Train Loss: 2.1204 | Val Loss: 2.1362




Epoch 4/50 | Train Loss: 2.1098 | Val Loss: 2.1051




Epoch 5/50 | Train Loss: 2.1089 | Val Loss: 2.1088




Epoch 6/50 | Train Loss: 2.1025 | Val Loss: 2.0973




Epoch 7/50 | Train Loss: 2.0977 | Val Loss: 2.0998




Epoch 8/50 | Train Loss: 2.0948 | Val Loss: 2.0942




Epoch 9/50 | Train Loss: 2.0923 | Val Loss: 2.0949




Epoch 10/50 | Train Loss: 2.0914 | Val Loss: 2.0939




Epoch 11/50 | Train Loss: 2.0802 | Val Loss: 2.0805




Epoch 12/50 | Train Loss: 2.0768 | Val Loss: 2.0840




Epoch 13/50 | Train Loss: 2.0770 | Val Loss: 2.0810




Epoch 14/50 | Train Loss: 2.0743 | Val Loss: 2.0834




Epoch 15/50 | Train Loss: 2.0729 | Val Loss: 2.0857




Epoch 16/50 | Train Loss: 2.0722 | Val Loss: 2.0785




Epoch 17/50 | Train Loss: 2.0725 | Val Loss: 2.0762




Epoch 18/50 | Train Loss: 2.0714 | Val Loss: 2.0837




Epoch 19/50 | Train Loss: 2.0705 | Val Loss: 2.0805


Epoch 20 Training:  33%|███▎      | 128/391 [01:20<02:57,  1.48it/s]

Plot losses

In [None]:
import matplotlib.pyplot as plt
import pickle

# Load the saved losses
with open("loss_history.pkl", "rb") as f:
    train_losses, val_losses = pickle.load(f)

# Plot the losses
plt.figure(figsize=(8, 5))
plt.plot(train_losses, label="Training Loss", marker='o')
plt.plot(val_losses, label="Validation Loss", marker='s')
plt.title("Training vs Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Save losses to file after training
import pickle
with open("loss_history.pkl", "wb") as f:
    pickle.dump((train_losses, val_losses), f)

Run Evaluation

In [None]:
# Set model to evaluation mode
model.eval()

correct, total = 0, 0
test_loss = 0.0

with torch.no_grad():  # No gradient calculation for evaluation
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Get predictions and compute loss
        y_hat = model(x)
        loss = criterion(y_hat, y)

        # Accumulate test loss
        test_loss += loss.item()  # Use .item() to get the scalar value

        # Compute accuracy
        _, predicted = torch.max(y_hat, dim=1)  # Get the class predictions
        correct += (predicted == y).sum().item()  # Sum the number of correct predictions
        total += y.size(0)  # Total number of samples in this batch

# Calculate final average test loss and accuracy
avg_test_loss = test_loss / len(test_loader)
test_accuracy = (correct / total) * 100

# Print results
print(f"Test Loss: {avg_test_loss:.2f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")

Finetuning/Transfer Learning Loop

In [None]:
# Load the pre-trained model
model.load_state_dict(torch.load("vit_cifar10.pth"))
model.to(device)

# This is adapted for transfer learning. Unfreeze all layers if you want full fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Fine-tuning setup
N_EPOCHS_FT = 5
LR_FT = 0.001
PATIENCE = 3
MIN_DELTA = 0.001

optimizer = Adam(model.parameters(), lr=LR_FT)
scheduler = StepLR(optimizer, step_size=2, gamma=0.1)
criterion = nn.CrossEntropyLoss()

# Early stopping setup
best_val_loss = float("inf")
patience_counter = 0

for epoch in range(N_EPOCHS_FT):
    model.train()
    train_loss = 0.0
    correct_train, total_train = 0, 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} fine-tuning", leave=False):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)

        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(y_hat, dim=1)
        correct_train += (predicted == y).sum().item()
        total_train += y.size(0)

    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = (correct_train / total_train) * 100

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch + 1}/{N_EPOCHS_FT} | Train Loss: {avg_train_loss:.4f} | "
          f"Train Acc: {train_accuracy:.2f}% | Val Loss: {avg_val_loss:.4f}")

    scheduler.step()

    # Early stopping logic
    if avg_val_loss + MIN_DELTA < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "vit_finetuned_best.pth")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

In [None]:
# Load best model after fine-tuning
model.load_state_dict(torch.load("vit_finetuned_best.pth"))