Pneumonia is one of the leading respiratory illnesses worldwide, and timely, accurate diagnosis is essential for effective treatment. Manual review of chest X-rays is a critical step in this process, and AI can support it by speeding up the assessment. Deep learning models can distinguish pneumonia cases from normal lungs in chest X-ray images.

By fine-tuning a pre-trained convolutional neural network, specifically the ResNet-50 model, you can classify X-ray images into two categories: normal lungs and those affected by pneumonia. Reusing its pre-trained weights yields an accurate classifier with less training time and fewer resources than training from scratch.
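
To make the resource savings concrete, the arithmetic below counts how many parameters would actually be trained if the pre-trained backbone is frozen and only a new single-output linear layer is learned. The figures are assumptions taken from the standard torchvision ResNet-50 (2048 input features to the final layer, 25,557,032 parameters in the 1000-class model), not measurements from this notebook.

```python
# Parameter-count sketch for a frozen ResNet-50 with a new 1-output head.
# Assumed figures (standard torchvision ResNet-50, not measured here):
FC_IN_FEATURES = 2048            # inputs to the final fully connected layer
IMAGENET_CLASSES = 1000          # outputs of the original final layer
TOTAL_PARAMS_ORIGINAL = 25_557_032

# Weights plus biases of the original 1000-class head
original_head = FC_IN_FEATURES * IMAGENET_CLASSES + IMAGENET_CLASSES
backbone_params = TOTAL_PARAMS_ORIGINAL - original_head

# A binary classifier with a single logit needs one weight per feature plus one bias
new_head = FC_IN_FEATURES * 1 + 1
trainable_fraction = new_head / (backbone_params + new_head)

print(f"Frozen backbone parameters: {backbone_params:,}")
print(f"Trainable head parameters:  {new_head:,}")
print(f"Trainable fraction:         {trainable_fraction:.4%}")
```

Under these assumptions only 2,049 parameters receive gradient updates, against roughly 23.5 million frozen ones, which is why a few hundred images can be enough to fit the new head.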

## The Data

<img src="x-rays_sample.png" align="center"/>

The chest X-ray dataset has been preprocessed for use with a ResNet-50 model by applying `transforms.Resize(224)` and `transforms.CenterCrop(224)`. A sample of five images from each category is shown above. The dataset in the `data/chestxrays` folder is divided into `train` and `test` folders.

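There are 150 training images and 50 testing images for each category, NORMAL and PNEUMONIA (300 and 100 in total). The notebook moves 50 training images per class into a validation set, which leaves 100 per class (200 in total) for training. The data is loaded into a `train_loader`, a `val_loader`, and a `test_loader` using the `DataLoader` class from the PyTorch library.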
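
The geometry of that preprocessing can be sketched in plain Python. This is an illustration only, assuming `Resize(224)` scales the shorter side to 224 pixels while keeping the aspect ratio and `CenterCrop(224)` then takes the central 224×224 window; the helper name `resize_then_center_crop_box` is hypothetical, and torchvision's exact rounding may differ in edge cases.

```python
def resize_then_center_crop_box(width, height, size=224):
    """Return (resized_w, resized_h, crop_box) for a shorter-side resize
    followed by a centered square crop. crop_box is (left, top, right, bottom)."""
    # Scale so the shorter side equals `size`, keeping the aspect ratio
    if width <= height:
        new_w, new_h = size, int(size * height / width)
    else:
        new_w, new_h = int(size * width / height), size
    # Take the central size x size window
    left = round((new_w - size) / 2)
    top = round((new_h - size) / 2)
    return new_w, new_h, (left, top, left + size, top + size)

# Hypothetical 1000x800 (width x height) radiograph
result = resize_then_center_crop_box(1000, 800)
print(result)  # (280, 224, (28, 0, 252, 224))
```

For a landscape image, the crop discards equal strips on the left and right, so anatomy near the image borders may be lost.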

In [None]:
# -------
# Install
# -------
# !pip install torch torchvision torchmetrics

In [1]:
# -------------------------
# Import required libraries
# -------------------------

# Data loading
import random
import os
import shutil
import numpy as np
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Train model
import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
import torch.optim as optim

# Evaluate model
from torchmetrics import Accuracy, F1Score

# Select the best available device: CUDA GPU first, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Using device:", device)

Using device: mps


In [2]:
#---------------------------------------------------------------------------------
# Move 50 random images per class from the training set to create a validation set
#---------------------------------------------------------------------------------
def move_files(src_class_dir, dest_class_dir, n=50):
    if not os.path.exists(dest_class_dir):
        os.makedirs(dest_class_dir)
    files = os.listdir(src_class_dir)
    random_files = random.sample(files, n)
    for f in random_files:
        shutil.move(os.path.join(src_class_dir, f), os.path.join(dest_class_dir, f))

if not os.path.exists('data/chestxrays/val'):
    move_files('data/chestxrays/train/NORMAL', 'data/chestxrays/val/NORMAL')
    move_files('data/chestxrays/train/PNEUMONIA', 'data/chestxrays/val/PNEUMONIA')
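
Because `random.sample` is called without a seed, each fresh run of the split picks a different validation set. The variant below is a minimal sketch of a reproducible split, assuming a fixed seed is acceptable; `move_random_files` and the seed value are hypothetical, and the demonstration runs on a temporary directory of empty placeholder files rather than on the real dataset.

```python
import os
import random
import shutil
import tempfile

def move_random_files(src_dir, dest_dir, n, seed=42):
    """Move n randomly chosen files from src_dir to dest_dir, reproducibly."""
    files = sorted(os.listdir(src_dir))      # sort: listdir order is arbitrary
    if n > len(files):
        raise ValueError(f"Requested {n} files but only {len(files)} available")
    rng = random.Random(seed)                # local RNG, leaves global state untouched
    chosen = rng.sample(files, n)
    os.makedirs(dest_dir, exist_ok=True)
    for name in chosen:
        shutil.move(os.path.join(src_dir, name), os.path.join(dest_dir, name))
    return chosen

# Demonstration on a throwaway directory with 150 empty placeholder files
with tempfile.TemporaryDirectory() as root:
    src = os.path.join(root, "train", "NORMAL")
    dest = os.path.join(root, "val", "NORMAL")
    os.makedirs(src)
    for i in range(150):
        open(os.path.join(src, f"img_{i:03d}.png"), "w").close()
    moved = move_random_files(src, dest, n=50)
    remaining_count = len(os.listdir(src))
    moved_count = len(os.listdir(dest))
    print(remaining_count, moved_count)  # 100 50
```

Sorting the directory listing before sampling matters as much as the seed, since the order returned by `os.listdir` is not guaranteed.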


In [3]:
#------------------------------------
# Transformations and create datasets
#------------------------------------

# Define the transformations to apply to the images for use with ResNet-50.
# The images must be normalized the same way as ResNet-50's original training data.
# transforms.Normalize takes the per-channel means and standard deviations of the
# three color channels (R, G, B) from the original ResNet-50 training dataset (ImageNet).
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]

# Training transforms: Add horizontal flip for augmentation
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std)
])

# Validation and test transforms: no augmentation, just normalization
val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std)
])

# Create datasets
train_dataset = ImageFolder('data/chestxrays/train', transform=train_transform)
val_dataset = ImageFolder('data/chestxrays/val', transform=val_test_transform)
test_dataset = ImageFolder('data/chestxrays/test', transform=val_test_transform)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

print("Training set size:", len(train_dataset))
print("Validation set size:", len(val_dataset))
print("Test set size:", len(test_dataset))

Training set size: 200
Validation set size: 100
Test set size: 100
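
The normalization step is simple per-channel arithmetic, `(value - mean) / std`, applied after `ToTensor` has scaled pixel values to the range [0, 1]. The sketch below works through it for a single pixel in plain Python; `normalize_pixel` is a hypothetical helper for illustration, not part of torchvision.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]

# A mid-gray pixel: all three channels equal, as in a grayscale X-ray converted to RGB
gray = [0.5, 0.5, 0.5]
normalized = normalize_pixel(gray)
print([round(v, 4) for v in normalized])  # [0.0655, 0.1964, 0.4178]
```

Even though the three input channels are identical for a grayscale image, the normalized channels differ because each uses its own mean and standard deviation.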


In [4]:
#----------------------
# Instantiate the model
#----------------------

# Load the pre-trained ResNet-50 model with the IMAGENET1K_V2 weights
# (80.858% top-1 accuracy on ImageNet-1K).
# Assign to `model` rather than `resnet50` so the imported constructor is not
# shadowed and the cell can be re-run.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

#-----------------
# Modify the model
#-----------------

# Freeze the parameters of the pre-trained backbone
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer with a single-logit head for binary classification
# (the new layer's parameters are trainable by default)
model.fc = nn.Linear(model.fc.in_features, 1)

# Move the model to the selected device (GPU, MPS, or CPU)
model.to(device)

#-------------------------
# Define the training loop
#-------------------------

# Training function with validation and early stopping
def train_with_validation(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=5):
    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        running_loss = 0.0
        running_accuracy = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            labels = labels.float().unsqueeze(1)
            
            # Use mixed precision training for forward pass and loss computation
            with torch.amp.autocast(device_type=device.type):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()

            preds = torch.sigmoid(outputs) > 0.5
            running_loss += loss.item() * inputs.size(0)
            running_accuracy += torch.sum(preds == labels)  # count of correct predictions

        train_loss = running_loss / len(train_loader.dataset)
        # train_acc = running_accuracy.double() / len(train_loader.dataset) # For GPU/CPU
        train_acc = running_accuracy.float() / len(train_loader.dataset) # For MPS

        # Validation Phase
        model.eval()
        val_loss = 0.0
        val_accuracy = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                labels = labels.float().unsqueeze(1)
                
                # Use mixed precision for the forward pass and loss computation (no weight updates here)
                with torch.amp.autocast(device_type=device.type):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                
                preds = torch.sigmoid(outputs) > 0.5
                val_loss += loss.item() * inputs.size(0)
                val_accuracy += torch.sum(preds == labels)  # count of correct predictions

        val_loss = val_loss / len(val_loader.dataset)
        # val_acc = val_accuracy.double() / len(val_loader.dataset)  # For GPU/CPU
        val_acc = val_accuracy.float() / len(val_loader.dataset) # For MPS

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Step the scheduler with the validation loss
        scheduler.step(val_loss)

        # Early Stopping Check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Save model
            torch.save(model.state_dict(), 'model.pth')
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print("Early stopping triggered")
            break

    print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

#--------------------
# Fine-tune the model
#--------------------

# Set up loss and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)

# Use ReduceLROnPlateau scheduler to halve the LR if validation loss stagnates
# (the `verbose` argument is deprecated in recent PyTorch releases, so it is omitted)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Decay lr by 10% every epoch (alternative scheduler)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

# Train the model with early stopping and validation
train_with_validation(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=5)



Epoch [1/50], Train Loss: 0.6884, Train Acc: 0.5250, Val Loss: 0.6568, Val Acc: 0.5300
Epoch [2/50], Train Loss: 0.5919, Train Acc: 0.8050, Val Loss: 0.6311, Val Acc: 0.8500
Epoch [3/50], Train Loss: 0.5410, Train Acc: 0.8650, Val Loss: 0.6122, Val Acc: 0.8300
Epoch [4/50], Train Loss: 0.4818, Train Acc: 0.9050, Val Loss: 0.5752, Val Acc: 0.7600
Epoch [5/50], Train Loss: 0.4578, Train Acc: 0.8750, Val Loss: 0.5368, Val Acc: 0.8200
Epoch [6/50], Train Loss: 0.3998, Train Acc: 0.9100, Val Loss: 0.5058, Val Acc: 0.8700
Epoch [7/50], Train Loss: 0.3749, Train Acc: 0.9050, Val Loss: 0.4880, Val Acc: 0.9100
Epoch [8/50], Train Loss: 0.3500, Train Acc: 0.9450, Val Loss: 0.4591, Val Acc: 0.8900
Epoch [9/50], Train Loss: 0.3341, Train Acc: 0.9200, Val Loss: 0.4286, Val Acc: 0.9100
Epoch [10/50], Train Loss: 0.3261, Train Acc: 0.9100, Val Loss: 0.4025, Val Acc: 0.9100
Epoch [11/50], Train Loss: 0.3291, Train Acc: 0.9050, Val Loss: 0.3904, Val Acc: 0.9000
Epoch [12/50], Train Loss: 0.3167, Train 
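
The early-stopping bookkeeping in the training loop can be isolated and replayed on a list of validation losses. The sketch below is a plain-Python restatement of that logic, assuming the same rule (stop once the loss has failed to improve for `patience` consecutive epochs); `early_stopping_epoch` is a hypothetical helper, and the second loss sequence is made up for illustration.

```python
def early_stopping_epoch(val_losses, patience=5):
    """Return (stop_epoch, best_epoch, best_loss); stop_epoch is None if never triggered.
    Epochs are 1-indexed to match the training log."""
    best_loss = float("inf")
    best_epoch = None
    epochs_no_improve = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_no_improve = 0          # the checkpoint would be saved here
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            return epoch, best_epoch, best_loss
    return None, best_epoch, best_loss

# Validation losses of the 11 complete epochs in the log above
logged = [0.6568, 0.6311, 0.6122, 0.5752, 0.5368, 0.5058,
          0.4880, 0.4591, 0.4286, 0.4025, 0.3904]
print(early_stopping_epoch(logged))   # (None, 11, 0.3904)

# Hypothetical sequence that plateaus after epoch 3
plateau = [0.60, 0.50, 0.45, 0.46, 0.47, 0.45, 0.48, 0.46]
print(early_stopping_epoch(plateau))  # (8, 3, 0.45)
```

In the logged epochs the validation loss improves every time, so the counter never leaves zero. In the hypothetical plateau, a tie with the best loss (epoch 6) does not count as an improvement because the comparison is strict.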

### Evaluating the accuracy and F1-score of the fine-tuned model

In [5]:
#-------------------
# Evaluate the model
#-------------------

# Load the best model weights saved during training
model.load_state_dict(torch.load('model.pth', map_location=device, weights_only=True))

# Set model to evaluation mode
model.eval()

# Initialize metrics for accuracy and F1 score
accuracy_metric = Accuracy(task="binary")
f1_metric = F1Score(task="binary")

# Create lists to store all predictions and labels
all_preds = []
all_labels = []

# Disable gradient calculation for evaluation
with torch.no_grad():
    for inputs, labels in test_loader:
        # Move inputs and labels to the device
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        preds = torch.sigmoid(outputs).round()  # Round to 0 or 1

        # Extend the lists with predictions and labels
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.unsqueeze(1).cpu().tolist())

# Convert lists to tensors
all_preds = torch.tensor(all_preds)
all_labels = torch.tensor(all_labels)

# Compute metrics for the entire test set
test_acc = accuracy_metric(all_preds, all_labels).item()
test_f1 = f1_metric(all_preds, all_labels).item()

print(f"Test accuracy: {test_acc:.3f}")
print(f"Test F1-score: {test_f1:.3f}")



Test accuracy: 0.810
Test F1-score: 0.838
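
Accuracy and F1-score together say more than either alone. As a rough worked example, the sketch below searches for confusion matrices consistent with the two reported numbers. It assumes the test set holds 50 NORMAL and 50 PNEUMONIA images as described in the data section, and that PNEUMONIA is the positive class (`ImageFolder` assigns class indices alphabetically, so NORMAL is 0 and PNEUMONIA is 1). The result is an inference from the rounded metrics, not an output of the notebook.

```python
def binary_metrics(tp, fp, fn, tn):
    """Accuracy and F1-score from confusion-matrix counts (positive = PNEUMONIA)."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    f1 = 2 * tp / (2 * tp + fp + fn)
    return accuracy, f1

# Assumption: 50 PNEUMONIA (positive) and 50 NORMAL (negative) test images
N_POS, N_NEG = 50, 50
matches = []
for tp in range(N_POS + 1):
    for tn in range(N_NEG + 1):
        fn, fp = N_POS - tp, N_NEG - tn
        acc, f1 = binary_metrics(tp, fp, fn, tn)
        # Keep only count combinations that reproduce the reported, rounded metrics
        if round(acc, 3) == 0.810 and round(f1, 3) == 0.838:
            matches.append((tp, fp, fn, tn))

print(matches)  # [(49, 18, 1, 32)] as (tp, fp, fn, tn)
```

Under these assumptions the only consistent outcome is 49 of 50 pneumonia cases detected and 18 of 50 normal images flagged as pneumonia, which suggests that the errors are dominated by false positives rather than missed cases.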
