<a href="https://colab.research.google.com/github/ikhlas15/ATHENS-AI-Medical-Imaging/blob/main/H12_transfer_learning_medical_imaging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Notebook 12: Transfer Learning for Medical Image Classification**

### **Course**: Artificial Intelligence in Medical Imaging: From Fundamentals to Applications

***

## **1. Introduction**

Welcome to Notebook 12! So far, we have trained our own CNNs from scratch. While this is a great way to learn, it can be inefficient, especially when working with small medical datasets. Today, we will explore one of the most powerful and widely used techniques in deep learning: **Transfer Learning**.

Transfer learning is the process of taking a model that has been pre-trained on a very large dataset (like ImageNet, which contains millions of everyday images) and adapting it for a new, specific task (like detecting pneumonia in chest X-rays). The core idea is that the features learned on the large dataset—such as edges, textures, shapes, and object parts—are often general enough to be useful for our new task.

#### **What you will learn today:**
*   The motivation and benefits of using transfer learning in medical imaging AI.
*   How to load a widely used model (ResNet18) pre-trained on ImageNet.
*   How to adapt the pre-trained model for our specific medical imaging task.
*   The two-phase fine-tuning strategy:
    1.  **Feature Extraction:** Freezing the pre-trained layers and training only the new classification head.
    2.  **Fine-Tuning:** Unfreezing the entire network and continuing to train with a small learning rate.
*   How to handle common challenges, such as mismatched input channels (grayscale vs. RGB) and image sizes.

***

## **2. Setup: Installing and Importing Libraries**

Let's begin by preparing our environment.

In [None]:
# Install required packages
!pip install torch torchvision medmnist

import torch
import torch.nn as nn
import torch.optim as optim
# TODO: import DataLoader from PyTorch (torch.utils.data)

import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
# TODO: import the PneumoniaMNIST dataset from medmnist


# Set our standard random seed and device
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {...}")

Using device: cuda


***

## **3. Preparing Data for Transfer Learning**

This is a critical step. Pre-trained models have specific expectations for their input data because they were trained in a particular way. To use a model pre-trained on ImageNet, we must preprocess our medical images to match the ImageNet format as closely as possible.

This involves three key transformations:
1.  **Resize Images:** ImageNet models are typically trained on 224x224 images. We must resize our smaller 28x28 PneumoniaMNIST images to this size.
2.  **Handle Input Channels:** Our X-ray images are grayscale (1 channel), but ImageNet models expect 3-channel RGB images. The simplest solution is to duplicate the single grayscale channel three times.
3.  **Normalize with ImageNet Statistics:** We must normalize our images using the exact `mean` and `standard deviation` that were used to train the original model. For ImageNet, these are standard, well-known values.
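
To make the channel and normalization steps concrete, here is a minimal sketch in plain Python that follows a single pixel through them. The helper names (`to_three_channels`, `normalize`) and the mid-gray example value are illustrative assumptions, not part of torchvision; only the mean and std values come from the text above.

```python
# ImageNet channel statistics quoted in the text (RGB order).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def to_three_channels(gray_value):
    """Replicate one grayscale intensity into three identical channels."""
    return [gray_value, gray_value, gray_value]


def normalize(pixel, mean, std):
    """Apply (x - mean) / std independently to each channel."""
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]


# Hypothetical mid-gray pixel, already scaled to [0, 1] as ToTensor would do.
gray = 0.5
rgb = to_three_channels(gray)
normalized = normalize(rgb, IMAGENET_MEAN, IMAGENET_STD)

print(rgb)
print([round(v, 4) for v in normalized])
```

Note that the three channels start out identical but end up with different values after normalization, because each channel is shifted and scaled by its own ImageNet statistics.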


In [None]:
# Define the special transformations for transfer learning
# Hint: Use transforms.Compose([...]) to stack multiple operations
transfer_learning_transforms = ...([
    transforms.Resize((224, 224)),

    # Hint: Convert 1-channel images to 3 channels for ResNet
    transforms.Grayscale(num_output_channels=____),

    transforms.ToTensor(),

    # Hint: Normalize using ImageNet mean and std lists
    transforms. ... (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# Load the datasets with these new transforms
# Hint: Fill "train" or "val" for the dataset split
train_dataset = PneumoniaMNIST(split='______',
                               transform=transfer_learning_transforms,
                               download=True)

val_dataset = PneumoniaMNIST(split='_____',
                             transform=...,
                             download=True)


# Create DataLoaders
# Hint: Common batch size is 32, shuffle only the training set
train_loader = DataLoader(dataset=train_dataset, batch_size=____, shuffle=______)
val_loader   = DataLoader(dataset=val_dataset, batch_size=____, shuffle=______)


print(f"Training on {len(train_dataset)} images.")
print(f"Validating on {len(val_dataset)} images.")



Training on 4708 images.
Validating on 524 images.


***

## **4. Loading and Adapting the Pre-trained Model**

We will use **ResNet18**, a popular and efficient architecture, pre-trained on ImageNet.

### **4.1. Loading the Model**
We load the model using `torchvision.models`, setting `weights=ResNet18_Weights.DEFAULT` to download the learned weights.

In [None]:
# Load the pretrained ResNet18 model
model = models....(weights="ResNet18_Weights.DEFAULT")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth



### **4.2. Adapting the Model**
The loaded model is designed for ImageNet, so there are two points to address:
1.  **Input Channels (no change needed):** We already handled the channel mismatch in our data transforms, so the first convolutional layer can stay as it is.
2.  **Replace the Classifier Head:** The final layer of ResNet18, `model.fc`, is a `Linear` layer that outputs 1000 values (for the 1000 ImageNet classes). We must replace this with a new `Linear` layer that outputs 2 values for our binary classification task (Normal vs. Pneumonia).
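
To get a feel for how small the new head is, the following back-of-the-envelope sketch counts the parameters of the old and new final layers. It assumes the standard torchvision ResNet18, whose `fc` layer receives 512 input features; `linear_param_count` is a hypothetical helper written for this illustration.

```python
def linear_param_count(in_features, out_features):
    """Number of weights plus biases in a fully connected layer."""
    return in_features * out_features + out_features


# Assumption: standard ResNet18, where the final layer receives 512 features.
in_features = 512

original_head = linear_param_count(in_features, 1000)  # 1000 ImageNet classes
new_head = linear_param_count(in_features, 2)          # Normal vs. Pneumonia

print(f"Original head parameters: {original_head}")
print(f"New head parameters:      {new_head}")
```

Under this assumption the replacement layer has only about a thousand parameters, which is why it can be trained quickly on a small dataset.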


In [None]:
# Get the number of input features for the classifier
num_ftrs = model.fc. ...

# Replace the final fully connected layer with a new one for our task
model.fc = nn.Linear(..., 2)

# Move the model to the GPU
model = model.to(device)

print("--- Model architecture adapted for our task ---")
# print(model) # Uncomment to see the full architecture

--- Model architecture adapted for our task ---


***

## **5. Fine-Tuning Strategy: Phase 1 (Train the Head Only)**

Our model now consists of a pre-trained "body" (the convolutional layers) and a new, randomly initialized "head" (our `Linear` layer). If we start training the whole network immediately, the large gradients produced by the untrained head could disrupt the carefully learned weights of the body.

The best practice is to first **freeze the body and train only the head**.

### **5.1. Freezing the Pre-trained Layers**
We loop through all the model's parameters and set their `requires_grad` attribute to `False`.

In [None]:
# Freeze all the parameters in the model
for param in model.parameters():
    param.requires_grad = ...

# Unfreeze ONLY the parameters of the new final layer
for param in model.fc.parameters():
    param.requires_grad = ...

### **5.2. Training the Head**
Now, we create an optimizer that will *only* update the parameters where `requires_grad` is `True`.
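
The sketch below illustrates the freeze-then-train idea on a tiny stand-in network rather than ResNet18, so it runs in a fraction of a second on CPU. `TinyNet`, its layer sizes, and the random data are assumptions made purely for illustration; the point is to check that one optimizer step changes the head while leaving the frozen body untouched.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet18: a small 'body' plus an 'fc' head."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(torch.relu(self.body(x)))


model = TinyNet()

# Freeze everything, then re-enable gradients for the head only.
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

# Hand the optimizer only the parameters that are still trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable, lr=0.001)

body_before = model.body.weight.clone()
head_before = model.fc.weight.clone()

# One training step on random data (illustrative only).
inputs = torch.randn(16, 8)
labels = torch.randint(0, 2, (16,))
loss = nn.CrossEntropyLoss()(model(inputs), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

body_unchanged = torch.equal(body_before, model.body.weight)
head_changed = not torch.equal(head_before, model.fc.weight)
num_trainable = sum(p.numel() for p in trainable)
print(body_unchanged, head_changed, num_trainable)
```

The same pattern applies to the real model: only the parameters passed to the optimizer with `requires_grad=True` are updated.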

In [None]:
# Create an optimizer that only updates the parameters of the new classifier
# Hint: Use optim.Adam and filter only parameters with requires_grad=True
optimizer = optim. ...(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
# Hint: Use the standard classification loss for multi-class problems
criterion = nn. ...

# Re-usable training function from previous notebooks
def train_one_epoch(model, data_loader, criterion, optimizer, device):
    model.train()
    total_loss, total_correct = 0.0, 0
    for images, labels in data_loader:
        # Hint: Move both images and labels to the correct device
        images, labels = images. ... , labels.... .squeeze().long()
        # Hint: reset gradients

        outputs = model(images)
        loss = criterion(outputs, labels)
        # Hint: compute gradients
        loss.
        # Hint: update parameters
        optimizer.
        total_loss += loss.item() * images.size(0)
        total_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
    return total_loss / len(data_loader.dataset), total_correct / len(data_loader.dataset)

# Train for a few epochs
print("--- Phase 1: Training the classifier head ---")
num_epochs_phase1 =  # Set a number (e.g., 3)
for epoch in range(num_epochs_phase1):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch [{epoch+1}/{num_epochs_phase1}] - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

--- Phase 1: Training the classifier head ---
Epoch [1/3] - Train Loss: 0.3450, Train Acc: 0.8549
Epoch [2/3] - Train Loss: 0.2334, Train Acc: 0.9034
Epoch [3/3] - Train Loss: 0.2080, Train Acc: 0.9153
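
The training function above multiplies each batch loss by the batch size before summing, then divides by the dataset size. The plain-Python sketch below shows why this matters when the last batch is smaller than the others. It assumes a batch size of 32 and the default `DataLoader` behavior of keeping the final partial batch; the per-batch loss values are made up for illustration.

```python
dataset_size = 4708  # training-set size printed earlier in the notebook
batch_size = 32      # assumption: the common batch size suggested in the hint

# 147 full batches plus one final batch of 4 images.
full_batches, last_batch = divmod(dataset_size, batch_size)
batch_sizes = [batch_size] * full_batches + [last_batch]

# Hypothetical per-batch mean losses: the small final batch happens to be noisy.
batch_losses = [0.20] * full_batches + [1.00]

# What the training function computes: a per-sample average.
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / dataset_size

# Naive alternative: average of batch means, which over-weights the small batch.
naive = sum(batch_losses) / len(batch_losses)

print(f"Batches: {full_batches} full + 1 of size {last_batch}")
print(f"Per-sample average loss: {weighted:.4f}")
print(f"Average of batch means:  {naive:.4f}")
```

In this made-up case the two averages differ because the four-image batch counts as much as a full batch in the naive version.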


***

## **6. Fine-Tuning Strategy: Phase 2 (Unfreeze and Train All Layers)**

Now that our new head is trained and stable, we can **unfreeze the entire network** and continue training. This will allow the pre-trained feature extractor to slightly adjust its weights to better suit our specific medical dataset.

It is crucial to use a **very small learning rate** during this phase to avoid making drastic changes that would destroy the valuable pre-trained features.
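
To see what the learning rate means in practice with Adam, here is a minimal plain-Python sketch of the very first Adam update for a single weight. `adam_first_step` is a hypothetical helper that follows the standard Adam formulas with their usual default hyperparameters; it is not PyTorch's implementation. On the first step the update size is approximately equal to the learning rate, whatever the gradient magnitude.

```python
import math


def adam_first_step(grad, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """Size of the first Adam update for one weight, starting from zero moments."""
    m = (1 - beta1) * grad           # first-moment estimate
    v = (1 - beta2) * grad ** 2      # second-moment estimate
    m_hat = m / (1 - beta1)          # bias correction at step t = 1
    v_hat = v / (1 - beta2)
    return lr * m_hat / (math.sqrt(v_hat) + eps)


steps = {}
for grad in (0.003, 0.3, 30.0):
    step_phase1 = adam_first_step(grad, lr=1e-3)  # learning rate used for the head
    step_phase2 = adam_first_step(grad, lr=1e-5)  # learning rate used for fine-tuning
    steps[grad] = (step_phase1, step_phase2)
    print(f"grad={grad:>6}: lr=1e-3 -> {step_phase1:.2e}, lr=1e-5 -> {step_phase2:.2e}")
```

Under these assumptions, dropping the learning rate from 1e-3 to 1e-5 makes each initial weight update about 100 times smaller, which is what keeps the pre-trained features largely intact.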


In [None]:
# Unfreeze all layers
for param in model.parameters():
    param.requires_grad =

# Create a new optimizer for the whole model with a very low learning rate
optimizer = optim.Adam(...., lr=1e-5) # A much smaller learning rate

# Continue training for a few more epochs
print("\n--- Phase 2: Fine-tuning the entire model ---")
num_epochs_phase2 = 5
for epoch in range(...):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch [{epoch+1}/{num_epochs_phase2}] - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")



--- Phase 2: Fine-tuning the entire model ---
Epoch [1/5] - Train Loss: 0.1539, Train Acc: 0.9382
Epoch [2/5] - Train Loss: 0.0742, Train Acc: 0.9794
Epoch [3/5] - Train Loss: 0.0433, Train Acc: 0.9892
Epoch [4/5] - Train Loss: 0.0283, Train Acc: 0.9938
Epoch [5/5] - Train Loss: 0.0165, Train Acc: 0.9975


***

## **7. Final Evaluation**

Finally, let's evaluate our fully fine-tuned model on the validation set.

In [None]:
# Re-usable evaluate function
def evaluate(model, data_loader, device):
    # Hint: Put the model in evaluation mode
    model.
    total_correct = 0
    # Hint: Disable gradient computation during evaluation
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device).squeeze().long()
            # Hint: run the forward pass
            outputs = model(images)
            # Hint: Use torch.argmax(...) to get predicted class indices
            total_correct += (torch. ...(outputs, dim=1) == labels).sum().item()
    # Hint: accuracy = correct_predictions / total_samples
    return total_correct / len(data_loader. ...)

# Run validation
val_accuracy = ...(model, val_loader, device)
print(f"\nFinal Validation Accuracy: {val_accuracy:.4f}")


Final Validation Accuracy: 0.9637


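You should see a high validation accuracy (0.9637 in the run above) after only eight training epochs in total. Compare this with the baseline CNN you trained from scratch in earlier notebooks. Also note that the final training accuracy (0.9975) is higher than the validation accuracy, which suggests some overfitting during Phase 2.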

***

## **8. Summary and Next Steps**


In the final notebook of this course, **`13_explainability_interpretability.ipynb`**, we will explore techniques to "look inside the black box" of our trained models to understand *why* they are making certain predictions, a critical step for building trust and deploying AI in clinical settings.
