Pneumonia is one of the leading respiratory illnesses worldwide, and its timely and accurate diagnosis is essential for effective treatment. Manually reviewing chest X-rays is a critical step in this process, and AI can provide valuable support by helping to expedite the assessment. In your role as a consultant data scientist, you will test the ability of a deep learning model to distinguish pneumonia cases from normal images of lungs in chest X-rays.

Your task is to fine-tune a pre-trained convolutional neural network, specifically the ResNet-18 model, to classify X-ray images into two categories: normal lungs and those affected by pneumonia. Reusing its pre-trained weights lets you train an accurate classifier faster and with fewer resources than training from scratch.

## The Data

<img src="x-rays_sample.png" align="center"/>

You have a dataset of chest X-rays that have been preprocessed for use with a ResNet-18 model. You can see a sample of 5 images from each category above. Upon unzipping the `chestxrays.zip` file (code provided below), you will find your dataset inside the `data/chestxrays` folder divided into `test` and `train` folders. 

There are 150 training images and 50 testing images for each category, NORMAL and PNEUMONIA (300 and 100 in total). For your convenience, this data has already been loaded into a `train_loader` and a `test_loader` using the `DataLoader` class from the PyTorch library. 
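
Once the archive has been unzipped, the counts above can be checked directly from the folder layout. The sketch below is illustrative only: `count_images` is a hypothetical helper, and it assumes each split folder holds one subfolder per class (as `ImageFolder` expects) containing nothing but image files.

```python
from pathlib import Path


def count_images(split_dir):
    """Return {class_name: number_of_files} for one split folder.

    Hypothetical helper: assumes one subfolder per class, images only.
    """
    split_path = Path(split_dir)
    return {
        class_dir.name: sum(1 for f in class_dir.iterdir() if f.is_file())
        for class_dir in sorted(split_path.iterdir())
        if class_dir.is_dir()
    }


for split in ("train", "test"):
    split_dir = Path("data/chestxrays") / split
    if split_dir.is_dir():  # skipped quietly if the archive is not unzipped yet
        print(split, count_images(split_dir))
```

If the dataset matches the description, each split should report the same number of files for NORMAL and PNEUMONIA.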

In [62]:
# Make sure to run this cell to use torchmetrics.
!pip install torch torchvision torchmetrics

Defaulting to user installation because normal site-packages is not writeable

In [63]:
# Import required libraries
# -------------------------
# Data loading
import random
import numpy as np
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Train model
import torch
from torchvision import models
import torch.nn as nn
import torch.optim as optim

# Evaluate model
from torchmetrics import Accuracy, F1Score

# Set random seeds for reproducibility
torch.manual_seed(101010)
np.random.seed(101010)
random.seed(101010)

In [64]:
import os
import zipfile

# Unzip the data folder
if not os.path.exists('data/chestxrays'):
    with zipfile.ZipFile('data/chestxrays.zip', 'r') as zip_ref:
        zip_ref.extractall('data')

In [65]:
# Define the transformations to apply to the images for use with ResNet-18
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std),
])

# Apply the image transforms
train_dataset = ImageFolder('data/chestxrays/train', transform=transform)
test_dataset = ImageFolder('data/chestxrays/test', transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset) // 2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
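
To make the preprocessing concrete, here is a small plain-Python sketch of what the cell above does numerically. It assumes `ToTensor` has already scaled pixel values to [0, 1], and it applies the per-channel `(value - mean) / std` rule that `Normalize` uses. `normalize_pixel` is a hypothetical helper written for illustration. The second half works out how many batches the two loaders yield for 300 training and 100 test images.

```python
import math

# ImageNet channel statistics, copied from the transform cell above
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]


mid_gray = normalize_pixel([0.5, 0.5, 0.5])
print([round(v, 3) for v in mid_gray])  # [0.066, 0.196, 0.418]

# Batch counts implied by the loaders: 300 training and 100 test images
train_size, test_size = 300, 100
train_batch_size = train_size // 2  # 150, as in len(train_dataset) // 2
train_batches = math.ceil(train_size / train_batch_size)
test_batches = math.ceil(test_size / test_size)  # one batch holds the whole test set
print(train_batches, test_batches)  # 2 1
```

With only two training batches per epoch, each epoch performs just two optimizer steps.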

In [66]:
# Load ResNet-18 with its default pre-trained ImageNet weights
weights = models.ResNet18_Weights.DEFAULT
resnet18 = models.resnet18(weights=weights)

In [67]:
# Freeze every pre-trained layer so that only the new head will be trained
for param in resnet18.parameters():
    param.requires_grad = False

## Adjusting the Final Layer

In [68]:
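# Number of features feeding the original fully connected layer
num_features = resnet18.fc.in_features
print(num_features)

# Replace the 1000-class ImageNet head with a single-logit output for binary classification
resnet18.fc = nn.Linear(num_features, 1)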

512
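
Because the backbone was frozen before the new `fc` layer was created, the new layer is the only part of the network with trainable parameters. The arithmetic below is a rough sketch of how small that head is compared with the 1000-class ImageNet head it replaces. The variable names are illustrative and not part of the notebook.

```python
in_features = 512  # printed above as resnet18.fc.in_features

# New binary head: one weight per input feature plus one bias
head_params = in_features * 1 + 1
print(head_params)  # 513

# Original ImageNet head for comparison: 1000 output classes
original_head_params = in_features * 1000 + 1000
print(original_head_params)  # 513000
```

Only these 513 values are updated during training; everything else keeps its pre-trained weights.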


## Model Training

In [69]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

resnet18 = resnet18.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(resnet18.fc.parameters(), lr=0.01)

num_epochs = 3  # Increase the number of epochs to improve accuracy

for epoch in range(num_epochs):
    resnet18.train()
    running_loss = 0.0
    
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        labels = labels.float().view(-1,1)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = resnet18(inputs)
        
        # Compute the loss between outputs and labels
        loss = criterion(outputs, labels)
        
        # Backward pass, calculates gradients
        loss.backward()
        
        # Update parameters
        optimizer.step()
        
        # Accumulate running loss
        running_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {loss.item():.4f}") 
    
    # Average loss for the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {running_loss / len(train_loader):.4f}")

Epoch [1/3], Batch [1], Loss: 0.7188
Epoch [1/3], Average Loss: 1.3915
Epoch [2/3], Batch [1], Loss: 1.0120
Epoch [2/3], Average Loss: 0.8973
Epoch [3/3], Batch [1], Loss: 1.1386
Epoch [3/3], Average Loss: 0.9199
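
To help read these numbers, the sketch below computes binary cross-entropy on a raw logit in plain Python, using a common numerically stable form. It is an illustration of the loss that `nn.BCEWithLogitsLoss` represents, not a copy of PyTorch's implementation, and `bce_with_logits` is a hypothetical helper.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit (stable form).

    Equivalent to -[t * log(sigmoid(z)) + (1 - t) * log(1 - sigmoid(z))].
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# A model that outputs logit 0 (probability 0.5) for every image
chance_loss = bce_with_logits(0.0, 1.0)
print(f"{chance_loss:.4f}")  # 0.6931

# Confident and correct vs. confident and wrong
print(f"{bce_with_logits(2.0, 1.0):.4f}")  # 0.1269
print(f"{bce_with_logits(2.0, 0.0):.4f}")  # 2.1269
```

The average losses printed above (1.3915, 0.8973, 0.9199) all sit above ln 2 ≈ 0.693, the loss of a model that always predicts 0.5. This suggests the head has not yet settled on a useful decision boundary after three epochs, so more epochs or a smaller learning rate may be worth trying.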


### Below is the provided model evaluation code. Run the cell below to evaluate the accuracy and F1-score of your fine-tuned model.

In [70]:
#-------------------
# Evaluate the model
#-------------------

# Set model to evaluation mode
model = resnet18
model.eval()

# Initialize metrics for accuracy and F1 score
accuracy_metric = Accuracy(task="binary")
f1_metric = F1Score(task="binary")

# Create lists to store all predictions and labels
all_preds = []
all_labels = []

# Disable gradient calculation for evaluation
with torch.no_grad():
    for inputs, labels in test_loader:
        # Move the batch to the same device as the model
        inputs = inputs.to(device)

        # Forward pass
        outputs = model(inputs)
        preds = torch.sigmoid(outputs).round()  # Round to 0 or 1

        # Extend the lists with predictions and labels
        all_preds.extend(preds.tolist())
        all_labels.extend(labels.unsqueeze(1).tolist())

# Convert lists back to tensors once, after the loop has finished
all_preds = torch.tensor(all_preds)
all_labels = torch.tensor(all_labels)

# Calculate accuracy and F1 score
test_accuracy = accuracy_metric(all_preds, all_labels).item()
test_f1_score = f1_metric(all_preds, all_labels).item()

print(f"Test Accuracy: {test_accuracy:.3f}")
print(f"Test F1 Score: {test_f1_score:.3f}")

Test Accuracy: 0.580
Test F1 Score: 0.704
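
The two scores can be related to a confusion matrix by hand. The sketch below recomputes accuracy and F1 from raw counts. The counts are inferred, not reported by the notebook: they assume 50 test images per class and PNEUMONIA as the positive class (label 1, since `ImageFolder` assigns class indices in alphabetical order). `accuracy_and_f1` is a hypothetical helper.

```python
def accuracy_and_f1(tp, fp, fn, tn):
    """Compute binary accuracy and F1 from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom else 0.0
    return accuracy, f1


# Inferred counts (an assumption): all 50 pneumonia images flagged,
# but 42 of the 50 normal images flagged as pneumonia too
acc, f1 = accuracy_and_f1(tp=50, fp=42, fn=0, tn=8)
print(f"{acc:.3f} {f1:.3f}")  # 0.580 0.704
```

Under those assumptions, the reported scores match a model that labels almost every image as pneumonia. That would explain why F1 looks noticeably better than accuracy here.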
