In [1]:
#https://medium.com/@brianpulfer/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c
#Code for finetuning is added

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose, Normalize, Resize
from tqdm import tqdm

In [2]:
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7d0929348d70>

Define VIT architecture

In [11]:
def patchify(images, n_patches):
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"

    patches = torch.zeros(n, n_patches**2, h * w * c // n_patches**2)
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[
                    :,
                    i * patch_size : (i + 1) * patch_size,
                    j * patch_size : (j + 1) * patch_size,
                ]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches


class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        )
        self.k_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        )
        self.v_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        )
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head : (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head**0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])


class MyViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d),
        )

    def forward(self, x):
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out


class MyViT(nn.Module):
    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=128, n_heads=8, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # ( C , H , W )
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patches sizes
        assert (
            chw[1] % n_patches == 0
        ), "Input shape not entirely divisible by number of patches"
        assert (
            chw[2] % n_patches == 0
        ), "Input shape not entirely divisible by number of patches"
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        self.register_buffer(
            "positional_embeddings",
            get_positional_embeddings(n_patches**2 + 1, hidden_d),
            persistent=False,
        )

        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList(
            [MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)]
        )

        # 5) Classification MLPk
        self.mlp = nn.Sequential(nn.Linear(self.hidden_d, out_d), nn.Softmax(dim=-1))

    def forward(self, images):
        # Dividing images into patches
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # Running linear layer tokenization
        # Map the vector corresponding to each patch to the hidden size dimension
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Adding positional embedding
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        # Transformer Blocks
        for block in self.blocks:
            out = block(out)

        # Getting the classification token only
        out = out[:, 0]

        return self.mlp(out)  # Map to output dimension, output category distribution


def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = (
                np.sin(i / (10000 ** (j / d)))
                if j % 2 == 0
                else np.cos(i / (10000 ** ((j - 1) / d)))
            )
    return result

Preprocess Data

In [4]:
# Define CIFAR-10 dataset transformations
transform = Compose([
    Resize((32, 32)),  # Ensure the images are 32x32 (CIFAR-10 size)
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalizing CIFAR-10 images
])

# Load CIFAR-10 train and test sets
train_set = CIFAR10(
    root="./../datasets", train=True, download=True, transform=transform
)
test_set = CIFAR10(
    root="./../datasets", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./../datasets/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:01<00:00, 90.1MB/s]


Extracting ./../datasets/cifar-10-python.tar.gz to ./../datasets
Files already downloaded and verified


In [None]:
#For an arbitrary dataset stored under a directory "MyData"

#from torchvision import datasets, transforms
#from torch.utils.data import DataLoader
#from torchvision.transforms import Compose, ToTensor, Normalize, Resize
#from torchvision import datasets

# Define dataset transformations
#transform = Compose([
#    Resize((32, 32)),  # Resize images to 32x32 (or adjust this based on your dataset)
#    ToTensor(),
#    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalizing the images
#])

# Define paths to train and test folders (replace with your actual paths)
#train_dir = "MyData/train"
#test_dir = "MyData/test"

# Load the train and test datasets
#train_set = datasets.ImageFolder(root=train_dir, transform=transform)
#test_set = datasets.ImageFolder(root=test_dir, transform=transform)

# Create data loaders
#train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
#test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

# Check the number of classes and the class labels
#print(f"Number of classes: {len(train_set.classes)}")
#print(f"Class labels: {train_set.classes}")

Training Loop

In [12]:
# Define the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MyViT(
    (3, 32, 32), n_patches=4, n_blocks=2, hidden_d=8, n_heads=2, out_d=10
).to(device)

# Hyperparameters for training
N_EPOCHS = 5
LR = 0.005  # Initial learning rate for training

optimizer = Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(N_EPOCHS):
    model.train()  # Set model to training mode
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)

        train_loss += loss.detach().cpu().item() / len(train_loader)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

# Save the trained model after initial training
torch.save(model.state_dict(), "vit_cifar10.pth")



Epoch 1/5 loss: 2.18




Epoch 2/5 loss: 2.14




Epoch 3/5 loss: 2.13




Epoch 4/5 loss: 2.12


                                                                      

Epoch 5/5 loss: 2.11




Check model performance

In [19]:
# Set model to evaluation mode
model.eval()

correct, total = 0, 0
test_loss = 0.0

with torch.no_grad():  # No gradient calculation for evaluation
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Get predictions and compute loss
        y_hat = model(x)
        loss = criterion(y_hat, y)

        # Accumulate test loss
        test_loss += loss.item()  # Use .item() to get the scalar value

        # Compute accuracy
        _, predicted = torch.max(y_hat, dim=1)  # Get the class predictions
        correct += (predicted == y).sum().item()  # Sum the number of correct predictions
        total += y.size(0)  # Total number of samples in this batch

# Calculate final average test loss and accuracy
avg_test_loss = test_loss / len(test_loader)
test_accuracy = (correct / total) * 100

# Print results
print(f"Test Loss: {avg_test_loss:.2f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")

Testing: 100%|██████████| 79/79 [00:12<00:00,  6.13it/s]

Test Loss: 2.08
Test Accuracy: 37.03%





Finetuning

In [16]:
#Optionally added a learning rate scheduler
from torch.optim.lr_scheduler import StepLR

# Load the pre-trained model
model.load_state_dict(torch.load("vit_cifar10.pth"))
model.to(device)

# Unfreeze all layers if necessary (optional)
for param in model.parameters():
    param.requires_grad = True

model.train()  # Set model to training mode

# Fine-tuning setup
N_EPOCHS_FT = 5  # Fine-tuning for 5 additional epochs
LR_FT = 0.001  # Initial learning rate for fine-tuning
optimizer = Adam(model.parameters(), lr=LR_FT)

# Set up a learning rate scheduler
scheduler = StepLR(optimizer, step_size=2, gamma=0.1)  # Decay learning rate every 2 epochs by 10%

# Fine-tuning loop
for epoch in range(N_EPOCHS_FT):
    model.train()
    train_loss = 0.0
    correct_train, total_train = 0, 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in fine-tuning", leave=False):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)

        train_loss += loss.item()  # Accumulate loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy calculation
        _, predicted = torch.max(y_hat, dim=1)
        correct_train += (predicted == y).sum().item()
        total_train += y.size(0)

    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = (correct_train / total_train) * 100

    print(f"Epoch {epoch + 1}/{N_EPOCHS_FT} fine-tuning loss: {avg_train_loss:.2f}, "
          f"Train Accuracy: {train_accuracy:.2f}%")

    # Step the learning rate scheduler
    scheduler.step()

  model.load_state_dict(torch.load("vit_cifar10.pth"))


Epoch 1/5 fine-tuning loss: 2.09, Train Accuracy: 36.04%




Epoch 2/5 fine-tuning loss: 2.09, Train Accuracy: 36.44%




Epoch 3/5 fine-tuning loss: 2.08, Train Accuracy: 37.01%




Epoch 4/5 fine-tuning loss: 2.08, Train Accuracy: 37.12%


                                                                         

Epoch 5/5 fine-tuning loss: 2.08, Train Accuracy: 37.17%




In [17]:
# Testing the fine-tuned model
model.eval()  # Set model to evaluation mode

correct, total = 0, 0
test_loss = 0.0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        test_loss += loss.detach().cpu().item() / len(test_loader)

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
        total += len(x)

print(f"Fine-Tuned Test loss: {test_loss:.2f}")
print(f"Fine-Tuned Test accuracy: {correct / total * 100:.2f}%")

Testing: 100%|██████████| 79/79 [00:13<00:00,  5.80it/s]

Fine-Tuned Test loss: 2.08
Fine-Tuned Test accuracy: 37.03%





In [18]:
# Save the fine-tuned model
torch.save(model.state_dict(), "vit_cifar10_finetuned.pth")