In [None]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline

Transfer Learning for Computer Vision Tutorial
==============================================

**Author**: [Sasank Chilamkurthy](https://chsasank.github.io)

In this tutorial, you will learn how to train a convolutional neural
network for image classification using transfer learning. You can read
more about transfer learning in the [cs231n
notes](https://cs231n.github.io/transfer-learning/).

Quoting these notes,

> In practice, very few people train an entire Convolutional Network
> from scratch (with random initialization), because it is relatively
> rare to have a dataset of sufficient size. Instead, it is common to
> pretrain a ConvNet on a very large dataset (e.g. ImageNet, which
> contains 1.2 million images with 1000 categories), and then use the
> ConvNet either as an initialization or a fixed feature extractor for
> the task of interest.

These two major transfer learning scenarios look as follows:

-   **Finetuning the ConvNet**: Instead of random initialization, we
    initialize the network with a pretrained network, like one that
    was trained on the ImageNet 1000-class dataset. The rest of the
    training looks as usual.
-   **ConvNet as fixed feature extractor**: Here, we will freeze the
    weights for all of the network except that of the final fully
    connected layer. This last fully connected layer is replaced with a
    new one with random weights and only this layer is trained.


In [None]:
# License: BSD
# Author: Sasank Chilamkurthy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from collections import Counter
from tempfile import TemporaryDirectory
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, OneCycleLR
import random
from sklearn.metrics import classification_report, accuracy_score, roc_auc_score, confusion_matrix


cudnn.benchmark = True
plt.ion()   # interactive mode

Load Data
=========

We will use torchvision and torch.utils.data packages for loading the
data.

The problem we're going to solve today is to train a model to classify
the images stored in `fer13_v2.npy`, which ships with separate training,
validation and test splits. (The original version of this tutorial
classified **ants** and **bees** from a very small subset of ImageNet,
with about 120 training and 75 validation images per class; that
description does not match the data loaded below.) Training a large
network from scratch on a dataset of limited size usually generalizes
poorly. Since we are using transfer learning, we should be able to
generalize reasonably well.


In [None]:
# Pick the compute device as a plain string ("cuda", "mps", "cpu", ...).
# `torch.accelerator` only exists in recent PyTorch releases, so fall back to the
# classic CUDA check (and finally the CPU) when it is missing.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Hyperparameters
# --- training length / early stopping ---
num_epochs = 200
patience = 20          # passed to train_model as the early-stopping patience

# --- data loading ---
batch_size = 64

# --- optimisation ---
lr = 1e-3              # higher lr for classifier
weight_decay = 1e-4
label_smoothing = 0.05

# --- values not referenced by the cells shown in this notebook ---
# (momentum / StepLR-style / warm-up settings; kept so other cells that may
# rely on them keep working)
momentum = 0.9
step_size = 7
gamma = 0.1
warmup_epochs = 5

In [None]:

# # Load the .npy dataset
# path_to_fer = 'fer13_v2.npy'
# m = np.load(path_to_fer, allow_pickle=True).item()
# x_train, y_train = m['train']
# x_val, y_val = m['val']
# x_test, y_test = m['test']

# # Define custom dataset class
# class FERDataset(Dataset):
#     def __init__(self, images, labels, transform=None):
#         self.images = images
#         self.labels = labels
#         self.transform = transform
    
#     def __len__(self):
#         return len(self.images)
    
#     def __getitem__(self, idx):
#         image = self.images[idx].astype(np.uint8)
#         label = int(self.labels[idx])  # Ensure label is an integer
#         return self.transform(image), label

# # Define transformations
# data_transforms = {
#     'train': transforms.Compose([
#         transforms.ToPILImage(),
#         transforms.Resize((224, 224)),  # Resize from 48x48 to 224x224
#         transforms.Grayscale(num_output_channels=3),  # Convert 1-channel to 3-channel
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize([0.5], [0.5])  # Normalize for grayscale
#     ]),
#     'val': transforms.Compose([
#         transforms.ToPILImage(),
#         transforms.Resize((224, 224)),  # Resize from 48x48 to 224x224
#         transforms.Grayscale(num_output_channels=3),  # Convert 1-channel to 3-channel
#         transforms.ToTensor(),
#         transforms.Normalize([0.5], [0.5])  # Normalize for grayscale
#     ]),
# }

# # Create datasets
# datasets = {
#     'train': FERDataset(x_train, y_train, transform=data_transforms['train']),
#     'val': FERDataset(x_val, y_val, transform=data_transforms['val']),
#     'test': FERDataset(x_test, y_test, transform=data_transforms['val'])
# }

# # Create dataloaders
# dataloaders = {
#     x: DataLoader(datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
#     for x in ['train', 'val', 'test']
# }

# dataset_sizes = {x: len(datasets[x]) for x in ['train', 'val', 'test']}

# # Get the number of unique classes in the dataset
# num_classes = len(np.unique(np.concatenate([y_train, y_val, y_test])))

# print(f"Number of classes: {num_classes}")


# # Select device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using {device} device")

In [None]:
# Load the .npy dataset: the file holds a pickled dict with one
# (images, labels) pair per split under the keys 'train', 'val' and 'test'.
# NOTE(review): allow_pickle=True executes pickle on load -- only use this with
# a trusted file.
path_to_fer = 'fer13_v2.npy'
m = np.load(path_to_fer, allow_pickle=True).item()
(x_train, y_train), (x_val, y_val), (x_test, y_test) = m['train'], m['val'], m['test']

In [None]:

# Tally how many samples each class has in every split
train_counts = Counter(y_train)
val_counts = Counter(y_val)
test_counts = Counter(y_test)

# Every label that occurs in at least one split, in ascending order
class_labels = sorted(set(y_train) | set(y_val) | set(y_test))

# Grouped bar chart: one group per class, one bar per split
fig, ax = plt.subplots(figsize=(10, 5))
bar_width = 0.25
index = np.arange(len(class_labels))

counts_by_split = (
    ('Train', train_counts),
    ('Validation', val_counts),
    ('Test', test_counts),
)
for position, (split_label, split_counts) in enumerate(counts_by_split):
    heights = [split_counts[c] for c in class_labels]
    ax.bar(index + position * bar_width, heights, bar_width, label=split_label)

ax.set_xlabel('Class Label')
ax.set_ylabel('Number of Images')
ax.set_title('Number of Images per Class')
# Centre the tick under the middle (validation) bar of each group
ax.set_xticks(index + bar_width)
ax.set_xticklabels(class_labels)
ax.legend()

plt.show()

In [None]:
# Number of samples every class should contribute to the balanced training set
TARGET_COUNT = 4000

# Per-class sample counts of the raw training split
train_counts = Counter(y_train)


def _geometry_stage():
    """Array -> PIL image, upscale to 384x384, replicate gray to 3 channels."""
    return [
        transforms.ToPILImage(),
        transforms.Resize((384, 384), interpolation=InterpolationMode.BILINEAR),
        transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3-channel
    ]


def _tensor_stage():
    """PIL image -> float tensor, normalised to match EfficientNet normalization."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Training pipeline: random flip + small rotation as data augmentation
train_transforms = transforms.Compose(
    _geometry_stage()
    + [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]
    + _tensor_stage()
)

# Validation / test pipeline: deterministic, no augmentation
val_test_transforms = transforms.Compose(_geometry_stage() + _tensor_stage())

# Balance the training set so that every class contributes exactly TARGET_COUNT
# samples: classes with more samples are randomly subsampled, classes with fewer
# keep all their samples and are topped up with randomly repeated ones
# (plain oversampling -- the repeats only differ through the random transforms
# applied at load time).
#
# A dedicated, seeded RNG keeps the balanced set identical across
# "Restart & Run All" executions.
BALANCE_SEED = 42
balance_rng = random.Random(BALANCE_SEED)

train_label_array = np.asarray(y_train)
balanced_indices = []

for label in train_counts:
    # Positions of all training samples that carry this label
    indices = np.flatnonzero(train_label_array == label).tolist()
    if len(indices) > TARGET_COUNT:
        # Randomly select TARGET_COUNT images (without replacement)
        selected_indices = balance_rng.sample(indices, TARGET_COUNT)
    else:
        # Keep every image once, then draw repeats (with replacement) to reach TARGET_COUNT
        repeats = balance_rng.choices(indices, k=TARGET_COUNT - len(indices))
        selected_indices = indices + repeats
    balanced_indices.extend(selected_indices)

# Convert balanced dataset to numpy arrays
balanced_x_train = np.array([x_train[idx] for idx in balanced_indices])
balanced_y_train = np.array([y_train[idx] for idx in balanced_indices])

# Define dataset class
class FERDataset(Dataset):
    """Map-style dataset over parallel sequences of images and labels.

    Args:
        images: indexable collection of image arrays; each item must support
            ``.astype`` (e.g. a NumPy array).
        labels: indexable collection of labels, one per image; each label is
            converted with ``int()`` when an item is fetched.
        transform: optional callable applied to the ``uint8`` image. When it is
            ``None`` the ``uint8`` array is returned unchanged.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, images, labels, transform=None):
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        image = self.images[idx].astype(np.uint8)
        label = int(self.labels[idx])
        # The constructor allows transform=None, so only call it when provided
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# One FERDataset per split; only the (balanced) training split is augmented
_split_sources = {
    'train': (balanced_x_train, balanced_y_train, train_transforms),
    'val': (x_val, y_val, val_test_transforms),    # No augmentation
    'test': (x_test, y_test, val_test_transforms)  # No augmentation
}
datasets = {
    split: FERDataset(split_images, split_labels, transform=split_transform)
    for split, (split_images, split_labels, split_transform) in _split_sources.items()
}

# Number of distinct labels across the original (unbalanced) splits
all_labels = np.concatenate([y_train, y_val, y_test])
num_classes = len(np.unique(all_labels))

print(f"Number of classes: {num_classes}")

# One DataLoader per split; only the training data is shuffled
dataloaders = {
    split: DataLoader(
        datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=4,
    )
    for split in ['train', 'val', 'test']
}

# Number of samples per split (used to average loss / accuracy per epoch)
dataset_sizes = {split: len(datasets[split]) for split in ['train', 'val', 'test']}

print(f"Final dataset sizes: {dataset_sizes}")

In [None]:
# Display the number of samples in each split as the cell output.
dataset_sizes

Visualize a few images
======================

Let's visualize a few training images so as to understand the data
augmentations.


In [None]:
def show_sample_images(dataloader, num_images=6):
    """Plot up to ``num_images`` images from ``dataloader`` in a single row.

    The images are expected as (C, H, W) tensors with 3 channels that were
    normalised with the ImageNet mean / std used by the transforms in this
    notebook; that normalisation is undone before plotting.

    Args:
        dataloader: iterable yielding ``(images, labels)`` batches.
        num_images: number of images (and subplot columns) to show. Batches
            smaller than this are handled by reading further batches.
    """
    # Statistics passed to transforms.Normalize in the data pipelines
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    fig, axes = plt.subplots(1, num_images, figsize=(15, 5))
    axes = np.atleast_1d(axes)  # plt.subplots returns a bare Axes for a single column

    images_shown = 0
    for images, labels in dataloader:
        for image, label in zip(images, labels):
            if images_shown >= num_images:
                break
            # (C, H, W) -> (H, W, C), undo Normalize, clamp to the displayable range
            pixels = image.permute(1, 2, 0).numpy() * std + mean
            axes[images_shown].imshow(np.clip(pixels, 0.0, 1.0), cmap='gray')
            axes[images_shown].set_title(f"Label: {int(label)}")
            axes[images_shown].axis('off')
            images_shown += 1
        if images_shown >= num_images:
            break

    # Hide the axes that stayed empty because the loader ran out of images
    for unused_ax in axes[images_shown:]:
        unused_ax.axis('off')
    plt.show()

# Show sample images
show_sample_images(dataloaders['train'])

Training the model
==================

Now, let's write a general function to train a model. Here, we will
illustrate:

-   Scheduling the learning rate
-   Saving the best model

In the following, parameter `scheduler` is an LR scheduler object from
`torch.optim.lr_scheduler`.


In [None]:




# Define training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, patience=5):
    """Train ``model`` with early stopping and return it with the best weights loaded.

    Uses the notebook-level globals ``dataloaders`` (with 'train' and 'val'
    entries), ``dataset_sizes`` and ``device``.

    Args:
        model: network to train; expected to already live on ``device``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer updating the trainable parameters of ``model``.
        scheduler: learning-rate scheduler. ``OneCycleLR`` / ``CyclicLR`` are
            defined in optimizer steps and are therefore advanced after every
            training batch; every other scheduler is advanced once per epoch.
        num_epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without a new best
            validation accuracy.

    Returns:
        ``model``, with the weights of the best validation epoch loaded when a
        checkpoint was written during this call.

    Side effects:
        Writes ``models/best_model.pth``, prints per-epoch metrics and shows the
        loss / accuracy curves.
    """
    since = time.time()

    # Create models directory
    model_dir = "models"
    os.makedirs(model_dir, exist_ok=True)
    best_model_path = os.path.join(model_dir, 'best_model.pth')

    # Batch-level schedulers would barely move if they were stepped once per epoch
    step_per_batch = isinstance(
        scheduler, (lr_scheduler.OneCycleLR, lr_scheduler.CyclicLR)
    )

    best_acc = 0.0
    checkpoint_saved = False  # guards against loading a stale file from an earlier run
    patience_counter = 0
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'val'

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()
                        if step_per_batch:
                            scheduler.step()

                # Weight the batch-mean loss by the batch size so the epoch value
                # is a per-sample average even when the last batch is smaller
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            if is_train:
                if not step_per_batch:
                    scheduler.step()
                train_losses.append(epoch_loss)
                train_accuracies.append(epoch_acc)
            else:
                val_losses.append(epoch_loss)
                val_accuracies.append(epoch_acc)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val':
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_path)
                    checkpoint_saved = True
                    patience_counter = 0  # Reset early stopping counter
                else:
                    patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model (only if this run actually wrote one)
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; keeping the final weights.")

    # Plot training results
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training & Validation Loss')

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Acc')
    plt.plot(val_accuracies, label='Val Acc')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Training & Validation Accuracy')

    plt.show()

    return model

Visualizing the model predictions
=================================

Generic function to display predictions for a few images


In [None]:
# Function to unnormalize and show an image
def imshow(img):
    """Display a (C, H, W) CPU tensor image (Grayscale or RGB) on the current axes.

    3-channel images are un-normalised with the ImageNet mean / std that the
    transforms in this notebook pass to ``Normalize``. Any other channel count
    falls back to the previous ``Normalize([0.5], [0.5])`` assumption.
    """
    img = img.numpy().transpose((1, 2, 0))  # Convert from Tensor format (C, H, W) to (H, W, C)

    if img.shape[-1] == 3:
        # Undo transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = img * std + mean
    else:
        img = img * 0.5 + 0.5  # Unnormalize (assuming Normalize([0.5], [0.5]))

    # Float images must lie in [0, 1] for matplotlib; clamp rounding overshoot
    img = np.clip(img, 0.0, 1.0)

    if img.shape[-1] == 1:  # If the image is grayscale, remove the last dimension
        img = img.squeeze(-1)

    plt.imshow(img, cmap="gray" if len(img.shape) == 2 else None)  # Use 'gray' colormap if grayscale
    plt.axis("off")

# Function to visualize model predictions
def visualize_model(model, num_images=6):
    """Show predictions for the first ``num_images`` validation images.

    Reads batches from the notebook-level ``dataloaders['val']``, runs them
    through ``model`` on ``device`` and draws each image in a two-column grid
    titled with the predicted and the true label.

    The model is switched to eval mode for the duration of the call; its
    previous train/eval mode is restored on every exit path, including errors.
    """
    was_training = model.training
    model.eval()
    images_so_far = 0
    plt.figure(figsize=(10, num_images * 2))

    try:
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    # Grid with two columns and enough rows for num_images panels
                    ax = plt.subplot((num_images + 1) // 2, 2, images_so_far)
                    ax.axis("off")
                    ax.set_title(f'Predicted: {preds[j].item()} | True: {labels[j].item()}')

                    imshow(inputs.cpu().data[j])  # Show image

                    if images_so_far == num_images:
                        return
    finally:
        # Single place that puts the model back into its original mode
        model.train(mode=was_training)

Finetuning the ConvNet
======================

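Load a pretrained EfficientNetV2-S model and replace its final fully
connected layer. Note that the code below freezes every other weight, so
only the new layer is trained -- this is the "fixed feature extractor"
scenario described in the introduction.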


In [None]:

# Load the ImageNet-pretrained EfficientNetV2-S and swap its classification
# head for a fresh linear layer with one output per class.
model_ft = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, num_classes)

# Freeze entire model first
for param in model_ft.parameters():
    param.requires_grad = False

# Unfreeze classifier parameters (only the new linear layer is trained)
for param in model_ft.classifier[1].parameters():
    param.requires_grad = True

model_ft = model_ft.to(device)

# Define loss function with label smoothing
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

# Optimizer over the classifier head only (AdamW applies decoupled weight decay)
optimizer_ft = optim.AdamW(model_ft.classifier.parameters(), lr=lr, weight_decay=weight_decay)

# One-cycle learning rate schedule peaking at the configured `lr`.
# NOTE: the schedule length is expressed in optimizer steps
# (steps_per_epoch * epochs), so it has to be advanced once per training batch
# to traverse the whole cycle.
scheduler = OneCycleLR(optimizer_ft, max_lr=lr,
                        steps_per_epoch=len(dataloaders['train']),
                        epochs=num_epochs,
                        pct_start=0.1)  # Warm-up for 10% of training

Train and evaluate
==================

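The original tutorial quotes around 15-25 min on CPU and less than a
minute on GPU for its small ants-and-bees dataset. With the settings used
in this notebook (up to 200 epochs on 384x384 inputs with early stopping),
expect training to take considerably longer; a GPU is strongly recommended.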


In [None]:

# Run training: fits the classifier head and returns the model with the
# best-validation weights loaded.
model_ft = train_model(
    model_ft,
    criterion,
    optimizer_ft,
    scheduler,
    num_epochs=num_epochs,
    patience=patience,
)

In [None]:
def evaluate_model(model, dataloader, num_classes, device):
    """Evaluate ``model`` on ``dataloader`` and compute classification metrics.

    Args:
        model: network returning one logit per class; put into eval mode here.
        dataloader: iterable yielding ``(images, labels)`` tensor batches with
            integer labels in ``[0, num_classes)``.
        num_classes: number of classes; fixes the size of the confusion matrix
            and of the one-hot encoding used for the AUC.
        device: device the image batches are moved to before the forward pass.

    Returns:
        Tuple ``(cm, accuracy, auc_score, class_report)``: the
        ``num_classes x num_classes`` confusion matrix (rows = true class), the
        accuracy, the AUC computed from the softmax probabilities against the
        one-hot labels, and scikit-learn's text classification report.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()  # Set model to evaluation mode
    true_batches, pred_batches, prob_batches = [], [], []

    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images.to(device))  # Forward pass

            probs = torch.softmax(outputs, dim=1)  # Convert logits to probabilities
            preds = outputs.argmax(dim=1)  # Get predicted class

            # Labels are only needed on the host, so they are never sent to the device
            true_batches.append(labels.cpu().numpy())
            pred_batches.append(preds.cpu().numpy())
            prob_batches.append(probs.cpu().numpy())

    if not true_batches:
        raise ValueError("dataloader yielded no batches; nothing to evaluate")

    y_true = np.concatenate(true_batches)
    y_pred = np.concatenate(pred_batches)
    y_prob = np.concatenate(prob_batches)

    # Fix the label set so the matrix is always num_classes x num_classes, even
    # when some class never appears in y_true / y_pred
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))

    # Compute classification report (Precision, Recall, F1-score)
    class_report = classification_report(y_true, y_pred, digits=4)

    # Compute Accuracy
    accuracy = accuracy_score(y_true, y_pred)

    # Compute AUC (for multi-class classification using One-vs-Rest)
    y_true_bin = np.eye(num_classes)[y_true]  # Convert to one-hot encoding
    auc_score = roc_auc_score(y_true_bin, y_prob, multi_class="ovr")

    return cm, accuracy, auc_score, class_report


# Function to plot confusion matrix
def plot_confusion_matrix(cm, class_labels, normalize=False):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        cm: square array-like of counts, rows = actual class, columns =
            predicted class.
        class_labels: tick labels, one per class.
        normalize: when True, every row is divided by its sum so cells show the
            fraction of each actual class. Rows without any sample stay at 0
            instead of turning into NaN.
    """
    cm = np.asarray(cm)
    plt.figure(figsize=(9, 8))

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Divide only where the row has samples; empty rows keep the zeros from `out`
        cm = np.divide(
            cm.astype("float"),
            row_sums,
            out=np.zeros(cm.shape, dtype=float),
            where=row_sums != 0,
        )

    plt.imshow(cm, interpolation="nearest", cmap="Blues")
    plt.title("Confusion Matrix", fontsize=20)
    plt.colorbar()

    # Add text annotations; switch to white text on the dark half of the colour scale
    fmt = ".2f" if normalize else "d"
    threshold = cm.max() / 2 if cm.size else 0
    for (i, j), value in np.ndenumerate(cm):
        plt.text(j, i, format(value, fmt), ha="center", va="center",
                 color="white" if value > threshold else "black")

    plt.xticks(np.arange(len(class_labels)), class_labels, rotation=45)
    plt.yticks(np.arange(len(class_labels)), class_labels)
    plt.xlabel("Predicted Label", fontsize=14)
    plt.ylabel("Actual Label", fontsize=14)
    plt.show()

In [None]:
# Evaluate the trained model on the test split.
# Use the `device` selected at the top of the notebook instead of a hard-coded
# "cuda", so the cell also runs on CPU-only (or other accelerator) machines and
# matches the device the model was moved to.
cm, accuracy, auc_score, class_report = evaluate_model(model_ft, dataloaders['test'], num_classes=num_classes, device=device)

print(f"Accuracy: {accuracy:.4f}")
print(f"AUC Score: {auc_score:.4f}")
print("\nClassification Report:\n", class_report)

# Plot confusion matrix
plot_confusion_matrix(cm, class_labels=[f"Class {i}" for i in range(num_classes)], normalize=False)

In [None]:
# Show predicted vs. true labels for the first few validation images
visualize_model(model_ft, num_images=6)

Inference on test dataset
==========================

Use the trained model to make predictions on test dataset and visualize
the predicted class labels along with the images.


In [None]:
# Rebuild a model and load saved weights for inference.
# NOTE(review): these names are already imported in the import cell at the top
# of the notebook; presumably repeated so this cell can be run on its own.
import torch
import torchvision.models as models
import torch.nn as nn

# Define the device
# (rebinds `device` from the plain string set earlier to a torch.device object)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the number of classes (7 in your case)
# NOTE(review): overrides the `num_classes` computed from the labels earlier -- confirm both agree.
num_classes = 7

# Initialize the EfficientNet-B0 model
# NOTE(review): the training cell above builds `efficientnet_v2_s` and train_model
# saves to "models/best_model.pth", while this cell builds `efficientnet_b0` and
# loads a different file. load_state_dict will fail if the checkpoint does not
# come from an EfficientNet-B0 with this head -- confirm which checkpoint is meant.
model_ft = models.efficientnet_b0(weights=None)  # No pre-trained weights
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, num_classes)

# Move model to device
model_ft = model_ft.to(device)

# Load the saved weights
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU can also be loaded on a CPU-only machine.
model_path = "models/best_model_0.0001.pth"  # Update this path if your model is saved elsewhere
model_ft.load_state_dict(torch.load(model_path, map_location=device))

# Set model to evaluation mode
model_ft.eval()

print("Model loaded successfully and ready for inference!")

In [None]:
# Evaluate the reloaded model on the test split.
# Pass the `device` the model was moved to instead of a hard-coded "cuda";
# otherwise the inputs and the model end up on different devices on machines
# without a GPU.
cm, accuracy, auc_score, class_report = evaluate_model(model_ft, dataloaders['test'], num_classes=num_classes, device=device)

print(f"Accuracy: {accuracy:.4f}")
print(f"AUC Score: {auc_score:.4f}")
print("\nClassification Report:\n", class_report)

# Plot confusion matrix
plot_confusion_matrix(cm, class_labels=[f"Class {i}" for i in range(num_classes)], normalize=False)

In [None]:

# Hard-coded confusion matrix (rows = actual class, columns = predicted class)
cm = np.array([
    [468,  29, 121,  84, 154,  36, 102],
    [  8,  33,   5,   2,   6,   2,   0],
    [ 72,  12, 328,  59, 134,  74,  81],
    [ 99,  10,  89, 1413, 125,  65, 154],
    [123,  14, 181,  69, 508,  17, 132],
    [ 36,   4, 130,  37,  25, 579,  40],
    [163,   1, 163, 147, 284,  54, 706]
])

num_classes = cm.shape[0]
class_labels = [f"Class {i}" for i in range(num_classes)]  # Create class names

# Expand the matrix back into per-sample label lists: cell (a, p) contributes
# cm[a, p] samples with actual label a and predicted label p, visited row by row.
class_ids = np.arange(num_classes)
cell_counts = cm.ravel()
y_true = np.repeat(np.repeat(class_ids, num_classes), cell_counts).tolist()
y_pred = np.repeat(np.tile(class_ids, num_classes), cell_counts).tolist()

# Compute classification metrics
accuracy = accuracy_score(y_true, y_pred)
class_report = classification_report(y_true, y_pred, target_names=class_labels, digits=4)

# Compute AUC score from one-hot encodings of the true and the predicted labels.
# NOTE(review): the scores are hard 0/1 predictions rather than probabilities,
# so this value is not comparable with an AUC computed from softmax outputs.
y_true_bin = np.eye(num_classes)[y_true]  # One-hot encoding
y_pred_bin = np.eye(num_classes)[y_pred]  # One-hot predictions
auc_score = roc_auc_score(y_true_bin, y_pred_bin, multi_class="ovr")

# Function to plot confusion matrix
# NOTE(review): this redefines the plot_confusion_matrix declared in the earlier
# evaluation cell; the later definition wins once this cell has run.
def plot_confusion_matrix(cm, class_labels, normalize=False):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        cm: square array-like of counts, rows = actual class, columns =
            predicted class.
        class_labels: tick labels, one per class.
        normalize: when True, every row is divided by its sum so cells show the
            fraction of each actual class. Rows without any sample stay at 0
            instead of turning into NaN.
    """
    cm = np.asarray(cm)
    plt.figure(figsize=(9, 8))

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Divide only where the row has samples; empty rows keep the zeros from `out`
        cm = np.divide(
            cm.astype("float"),
            row_sums,
            out=np.zeros(cm.shape, dtype=float),
            where=row_sums != 0,
        )

    plt.imshow(cm, interpolation="nearest", cmap="Blues")
    plt.title("Confusion Matrix", fontsize=20)
    plt.colorbar()

    # Annotate every cell; white text on the dark half of the colour scale
    fmt = ".2f" if normalize else "d"
    threshold = cm.max() / 2 if cm.size else 0
    for (i, j), value in np.ndenumerate(cm):
        plt.text(j, i, format(value, fmt), ha="center", va="center",
                 color="white" if value > threshold else "black")

    plt.xticks(np.arange(len(class_labels)), class_labels, rotation=45)
    plt.yticks(np.arange(len(class_labels)), class_labels)
    plt.xlabel("Predicted Label", fontsize=14)
    plt.ylabel("Actual Label", fontsize=14)
    plt.show()

# Print results
print(f"Accuracy: {accuracy:.4f}")
print(f"AUC Score: {auc_score:.4f}")
print("\nClassification Report:\n", class_report)

# Plot confusion matrix with raw counts (normalize=False); pass normalize=True
# to show row-normalized fractions instead
plot_confusion_matrix(cm, class_labels, normalize=False)