In [1]:
#Importing th required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm #PyTorch Image Model library

The PyTorch Image Models (timm) library is a powerful, open-source collection of state-of-the-art computer vision models, layers, utilities, optimizers, and schedulers for use with PyTorch.

Configurations for ViT Transfer Learning

In [2]:
NUM_CLASSES = 100
MODEL_NAME = 'vit_tiny_patch16_224'
IMG_SIZE = 224
BATCH_SIZE = 128
LEARNING_RATE = 0.1
NUM_EPOCHS = 1
DEVICE = "cpu"

We choose a small, common pre-trained ViT model (Tiny patch 16, 224x224 input).
The input data will be resized to 224x224 to match the pre-trained weights' expectations.

Data Preprocessing

In [3]:
#resizing & center crop to the size which the model expects (224 X 224)
#using the same standard deviation & mean as the original paper
transform_train = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

transform_test = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

In [4]:
#Loading the CIFAR-100 datasets
train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

100%|██████████| 169M/169M [00:04<00:00, 38.9MB/s]


In [5]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)



Model initialization & Transfer Learning setup

In [6]:
model = timm.create_model(
    MODEL_NAME,
    pretrained=True, #using IMAGENET pretrained weights
    num_classes=NUM_CLASSES #replacing the classification head
).to(DEVICE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

In [7]:
print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f}M")

Model parameters: 5.54M


In [8]:
#freezing all layers except the newly added replaced head
for name, param in model.named_parameters():
  if 'head' not in name:
    param.requires_grad = False
  else:
    param.requires_grad = True #ensuring that the new head is trainable

In [9]:
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [10]:
#training loop
for epoch in range(NUM_EPOCHS):
  model.train()
  running_loss = 0.0

  for i, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

    optimizer.zero_grad()
    outputs = model(inputs)

    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  avg_train_loss = running_loss/len(train_loader)

  #evaluation
  model.eval()
  correct = 0
  total = 0

  with torch.no_grad():
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

      outputs = model(inputs)

      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = correct/total
  print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {avg_train_loss:.4f} | Test Accuracy: {accuracy:.2f}%")

Epoch 1/1 | Train Loss: 20.3714 | Test Accuracy: 0.43%


Saving the model parameters

In [11]:
SAVE_PATH = 'vit_cifar100_finetuned.pth'

In [12]:
torch.save(model.state_dict(), SAVE_PATH)