# NIH Chest X-ray Classification using Hierarchical Vision Transformers
## Matthew Ohanian

This notebook implements a hierarchical approach to chest X-ray classification using two Vision Transformer (ViT) models:
1. A binary classifier that determines if an X-ray contains any pathological finding
2. A multi-label classifier that identifies specific pathological conditions only when the binary classifier detects a finding

This approach mirrors a radiologist's workflow (first detecting an abnormality, then characterizing it) and may improve performance by letting each model specialize in one task.

In [5]:
# Import necessary libraries
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import roc_auc_score, accuracy_score

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Import custom modules
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import sys
sys.path.append('/content/drive/MyDrive/deep_learning_proj/')

# Import NIH dataset module
import NIH_ChestXRay_Dataset_Module as nih

# Import custom modules for hierarchical ViT models
import nih_hierarchical_vit as nih_hvit

Using device: cpu
Mounted at /content/drive


# 1. Data Preprocessing (5%)

## 1.1 Dataset Overview

The NIH Chest X-ray dataset is a large public collection of chest X-rays labeled with 14 common thoracic pathologies. Each image can carry multiple labels, making this a multi-label classification problem. The dataset also includes a "No Finding" label indicating the absence of pathological conditions.

## 1.2 Data Loading and Exploration

In [None]:
# Data Loading
data_dir = "/content/drive/MyDrive/deep_learning_proj/data/nih_data"
train_loader, val_loader, test_loader, class_weights = nih.get_nih_data_loaders(
    data_dir=data_dir,
    batch_size=128,
    sample_size=500,  # Adjust sample size as needed
    test_size=100,
    balance=True,
    verbose=True
)

# Get disease labels (excluding "No Finding")
disease_labels = [label for label in nih.NIHChestXRay.LABELS if label != "No Finding"]
print(f"Disease labels: {disease_labels}")

# Visualize class distribution
def count_labels(loader):
    counts = {label: 0 for label in nih.NIHChestXRay.LABELS}
    for _, labels_batch in loader:
        for i, label in enumerate(nih.NIHChestXRay.LABELS):
            counts[label] += labels_batch[:, i].sum().item()
    return counts

label_counts = count_labels(train_loader)
plt.figure(figsize=(12, 6))
plt.bar(label_counts.keys(), label_counts.values())
plt.xticks(rotation=90)
plt.title("Label Distribution in Training Set")
plt.ylabel("Count")
plt.tight_layout()
plt.show()

Loaders: train=400, val=50, test=50
Disease labels: ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']


## 1.3 Data Preprocessing Steps

The NIH chest X-ray images undergo several preprocessing steps before being fed into our models:

1. **Resizing**: All images are resized to 224×224 pixels to match the input size requirement of the Vision Transformer models.

2. **Normalization**: Images are normalized using ImageNet mean and standard deviation values (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) to ensure consistent input scaling.

3. **Data Augmentation**: For training data, we apply the following augmentations to increase model robustness:
   - Random horizontal flips
   - Random rotations (±10 degrees)
   - Random brightness and contrast adjustments

4. **Class Balancing**: Due to the imbalanced nature of the dataset (some conditions are much rarer than others), we employ class weighting during training to prevent the model from ignoring minority classes.

5. **Binary Label Creation**: For the binary classifier, we create a new target where any image with at least one finding (not labeled as "No Finding") is considered positive.

**Note**: The actual preprocessing implementation details can be found in the `NIH_ChestXRay_Dataset_Module.py` file.
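
Steps 2 and 5 above can be illustrated with a minimal sketch in plain Python. It assumes labels are stored as multi-hot lists with "No Finding" at index 0; the actual layout in `NIH_ChestXRay_Dataset_Module.py` may differ, and `normalize_pixel` and `to_binary_target` are hypothetical helper names, not functions from that module.

```python
# ImageNet channel statistics quoted in step 2 (RGB order).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Normalize one RGB pixel whose values are already scaled to [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]


def to_binary_target(multi_hot, no_finding_index=0):
    """Return 1 if any label other than "No Finding" is set, else 0.

    Assumption: multi_hot is a list of 0/1 flags and no_finding_index
    marks the "No Finding" column (hypothetical layout).
    """
    has_disease = any(
        flag for i, flag in enumerate(multi_hot) if i != no_finding_index
    )
    return 1 if has_disease else 0


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel([0.485, 0.456, 0.406]))
# Only "No Finding" set gives a normal target; a disease flag gives abnormal.
print(to_binary_target([1, 0, 0, 0]), to_binary_target([0, 0, 1, 0]))
```

Deriving the binary target from the disease columns, rather than trusting the "No Finding" flag alone, keeps the two targets consistent if a row is ever labeled with both.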

# 2. Model Implementation (10%)

## 2.1 Hierarchical Model Architecture

Our approach uses a hierarchical architecture with two specialized Vision Transformer (ViT) models:

1. **Binary Classifier**: Determines if an X-ray contains any pathological finding (abnormal vs. normal)
2. **Disease Classifier**: Identifies specific diseases only when the binary classifier detects an abnormality

This architecture mirrors the workflow of radiologists, who first detect whether any abnormality is present and then characterize specific conditions. It may improve performance and efficiency by allowing each model to specialize in its own task.

## 2.2 Model Definitions

In [None]:
# Create binary classifier model
binary_model = nih_hvit.ViTBinaryClassifier()
binary_model.to(device)

# Create multi-label disease classifier model
disease_model = nih_hvit.ViTDiseaseClassifier(num_labels=len(disease_labels), labels=disease_labels)
disease_model.to(device)

# Binary model parameters
binary_params = sum(p.numel() for p in binary_model.parameters())
print(f"Binary model parameters: {binary_params:,}")

# Disease model parameters
disease_params = sum(p.numel() for p in disease_model.parameters())
print(f"Disease model parameters: {disease_params:,}")

## 2.3 Model Architecture Details

### 2.3.1 Binary Classifier

The binary classifier model is based on a pre-trained Vision Transformer (ViT) model. It consists of:

- **Base Model**: Pre-trained ViT-B/16 model from the `timm` library, which uses 16×16 patches and has been pre-trained on ImageNet
- **Adaptation Layer**: A custom head replaces the original classification head to adapt the model for binary classification
- **Output**: A single logit; applying a sigmoid gives the probability that any finding is present

### 2.3.2 Disease Classifier

The disease classifier model is also based on a pre-trained Vision Transformer, but adapted for multi-label classification:

- **Base Model**: Pre-trained ViT-B/16 model from the `timm` library
- **Adaptation Layer**: A custom multi-label classification head with 14 outputs (one for each disease)
- **Output**: One logit per disease; applying a sigmoid to each gives an independent probability for that disease
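
To give a sense of scale, the arithmetic below works out the token sequence a ViT-B/16 sees for a 224×224 input and the size of each classification head. It is a rough sketch that assumes each custom head is a single linear layer on the 768-dimensional ViT-Base embedding and that the model uses one class token; the heads in `nih_hierarchical_vit` may contain more layers, and `linear_head_params` is a hypothetical helper.

```python
IMAGE_SIZE = 224   # input resolution from Section 1.3
PATCH_SIZE = 16    # the "/16" in ViT-B/16
CHANNELS = 3
HIDDEN_DIM = 768   # ViT-Base embedding width


def linear_head_params(in_features, out_features):
    """Weights plus biases of a single fully connected layer."""
    return in_features * out_features + out_features


patches_per_side = IMAGE_SIZE // PATCH_SIZE
num_patches = patches_per_side ** 2
seq_len = num_patches + 1                      # patches plus one class token
patch_values = PATCH_SIZE * PATCH_SIZE * CHANNELS

binary_head = linear_head_params(HIDDEN_DIM, 1)
disease_head = linear_head_params(HIDDEN_DIM, 14)

print(f"patches: {num_patches}, sequence length: {seq_len}")
print(f"values per flattened patch: {patch_values}")
print(f"binary head params: {binary_head:,}, disease head params: {disease_head:,}")
```

Under these assumptions the two heads differ by only a few thousand parameters, so nearly all of each model's size and compute sits in the shared backbone design.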

### 2.3.3 Inference Flow

During inference, the process follows these steps:
1. The input X-ray image is first passed through the binary classifier
2. If the binary classifier predicts a finding (score above threshold), the image is passed to the disease classifier
3. The disease classifier then predicts the probabilities of specific diseases
4. If the binary classifier predicts no finding, we skip the disease classifier and return all zeros for disease probabilities

This hierarchical approach potentially improves efficiency and accuracy by specializing each model for its specific task.
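
The gating logic in the steps above can be sketched in a few lines. This is an illustration, not the code in `nih_hierarchical_vit`: the 0.5 threshold is an assumed default (the text does not state the value used), and `hierarchical_predict` and `fake_disease_model` are hypothetical names.

```python
def hierarchical_predict(binary_prob, disease_fn, num_diseases, threshold=0.5):
    """Gate the disease classifier on the binary classifier's output.

    binary_prob: probability of any finding from the first stage.
    disease_fn: zero-argument callable that runs the second stage and
        returns a list of per-disease probabilities (hypothetical interface).
    """
    if binary_prob > threshold:
        return disease_fn()
    # Below the threshold the second stage is skipped entirely.
    return [0.0] * num_diseases


calls = []


def fake_disease_model():
    calls.append(1)  # record that the second stage actually ran
    return [0.9, 0.1, 0.3]


flagged = hierarchical_predict(0.8, fake_disease_model, num_diseases=3)
skipped = hierarchical_predict(0.2, fake_disease_model, num_diseases=3)
print(flagged, skipped, len(calls))
```

Passing the second stage as a callable makes the skip explicit: the disease model is never invoked for images below the threshold, which also means any disease present in a missed image is scored as zero.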

# 3. Methods (5%)

## 3.1 Training Strategy

We train the binary and disease classifiers separately to allow each model to specialize in its specific task:

1. **Binary Classifier Training**:
   - Trained to differentiate between normal X-rays ("No Finding") and abnormal ones (any disease present)
   - Uses binary cross-entropy loss; the class weights returned by the data loader can be supplied to the loss to handle the imbalance
   - Uses the AdamW optimizer with a learning rate of 1e-4 and weight decay of 1e-5
   - Schedules the learning rate with cosine annealing

2. **Disease Classifier Training**:
   - Trained to identify specific diseases in abnormal X-rays
   - Uses multi-label binary cross-entropy loss (independent binary classifier for each disease)
   - Uses the AdamW optimizer with the same hyperparameters (learning rate 1e-4, weight decay 1e-5)
   - Uses the same cosine annealing learning rate schedule
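
For intuition about the schedule, the closed-form cosine annealing expression can be evaluated directly. The sketch below uses the learning rate of 1e-4 stated above, and assumes a period of 5 steps (one per epoch) and a minimum learning rate of 0, which is PyTorch's default for `CosineAnnealingLR`; `cosine_annealing_lr` is a hypothetical helper, not part of PyTorch.

```python
import math


def cosine_annealing_lr(step, base_lr=1e-4, t_max=5, eta_min=0.0):
    """Closed-form cosine annealing value after `step` scheduler steps."""
    cosine = math.cos(math.pi * step / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)


# Steps 0..4 are the rates in effect during five training epochs;
# step 5 is the value left after the final scheduler step.
schedule = [cosine_annealing_lr(s) for s in range(6)]
for step, lr in enumerate(schedule):
    print(f"after {step} steps: lr = {lr:.2e}")
```

When the scheduler is stepped once at the end of each epoch, the final value of zero is reached only after the last epoch, so no training step actually runs with a zero learning rate.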

## 3.2 Evaluation Metrics

We use the following metrics to evaluate our models:

1. **Area Under the ROC Curve (AUC)**: Primary metric for both binary and multi-label classification
2. **Accuracy**: For the binary classifier
3. **Per-class AUC**: To evaluate performance on each specific disease
4. **Mean AUC**: Average AUC across all disease classes
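
As a reminder of what these metrics measure, AUC equals the fraction of positive/negative pairs that the model ranks in the right order. The plain-Python sketch below computes it pairwise and then averages across classes. It is for illustration only (the notebook imports `roc_auc_score` from scikit-learn for the real computation), and skipping undefined classes in the mean is an assumption about how the mean might be taken, not a description of `nih_hierarchical_vit`.

```python
def roc_auc(y_true, y_score):
    """AUC as the fraction of positive/negative pairs ranked correctly.

    Ties count as half. Returns None when y_true holds a single class,
    because the AUC is undefined in that case.
    """
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        return None
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def mean_auc(per_class_auc):
    """Average the defined per-class AUCs, skipping undefined ones."""
    defined = [a for a in per_class_auc.values() if a is not None]
    return sum(defined) / len(defined) if defined else None


per_class = {
    "Effusion": roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]),
    "Hernia": roc_auc([0, 0, 0, 0], [0.2, 0.1, 0.3, 0.4]),  # no positives
}
print(per_class, mean_auc(per_class))
```

The undefined case matters in practice: with a validation set of only 50 images, a rare class can easily have no positive examples at all.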

## 3.3 Training Configuration

In [None]:
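# Loss functions
# Note: both losses are unweighted as written; the class_weights returned by
# get_nih_data_loaders are not passed in here. BCEWithLogitsLoss accepts a
# pos_weight tensor if class weighting is wanted.
binary_criterion = torch.nn.BCEWithLogitsLoss()
disease_criterion = torch.nn.BCEWithLogitsLoss()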

# Optimizers
binary_optimizer = AdamW(binary_model.parameters(), lr=1e-4, weight_decay=1e-5)
disease_optimizer = AdamW(disease_model.parameters(), lr=1e-4, weight_decay=1e-5)

# Learning rate schedulers
binary_scheduler = CosineAnnealingLR(binary_optimizer, T_max=5)
disease_scheduler = CosineAnnealingLR(disease_optimizer, T_max=5)
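
One common way to weight a binary cross-entropy loss for imbalance is to set each class's positive weight to its negative-to-positive ratio. The sketch below shows that calculation in plain Python. The counts are made up for illustration, `pos_weights` is a hypothetical helper, and the format of the `class_weights` value returned by `get_nih_data_loaders` is not shown in this notebook, so it may be computed differently.

```python
def pos_weights(positive_counts, num_samples):
    """Per-class negative/positive ratio, a usual choice for pos_weight."""
    weights = {}
    for label, n_pos in positive_counts.items():
        if n_pos == 0:
            weights[label] = 1.0  # no positives seen: leave the class unweighted
        else:
            weights[label] = (num_samples - n_pos) / n_pos
    return weights


# Made-up counts for a 400-image training sample (illustration only).
counts = {"Infiltration": 80, "Pneumonia": 10, "Hernia": 0}
weights = pos_weights(counts, num_samples=400)
print(weights)
# To apply with PyTorch:
# torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(list(weights.values())))
```

Very rare classes get very large weights under this rule, so capping the ratio is worth considering when a class has only a handful of positives.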

# 4. Experiments and Results (10%)

## 4.1 Binary Classifier Training

In [None]:
# Training Binary Classifier
num_epochs = 5  # Adjust as needed; if changed, update T_max in Section 3.3 to match
binary_train_losses = []
binary_val_losses = []
binary_val_aucs = []
binary_val_accs = []
best_binary_auc = 0.0

print(f"Training binary classifier for {num_epochs} epochs...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Train
    train_loss = nih_hvit.train_binary_model(binary_model, train_loader, binary_criterion, binary_optimizer, device, epoch)
    binary_train_losses.append(train_loss)

    # Validate
    val_loss, val_auc, val_acc = nih_hvit.validate_binary_model(binary_model, val_loader, binary_criterion, device)
    binary_val_losses.append(val_loss)
    binary_val_aucs.append(val_auc)
    binary_val_accs.append(val_acc)

    # Update learning rate
    binary_scheduler.step()

    # Save best model
    if val_auc > best_binary_auc:
        best_binary_auc = val_auc
        torch.save(binary_model.state_dict(), f'binary_model_epoch_{epoch+1}_auc_{val_auc:.3f}.pt')

    print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}, Val Acc: {val_acc:.4f}")
    print("-" * 50)

# Plot training history
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(binary_train_losses, label='Train Loss')
plt.plot(binary_val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Binary Classifier Training and Validation Loss')

plt.subplot(1, 3, 2)
plt.plot(binary_val_aucs, label='Validation AUC')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()
plt.title('Binary Classifier Validation AUC')

plt.subplot(1, 3, 3)
plt.plot(binary_val_accs, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Binary Classifier Validation Accuracy')

plt.tight_layout()
plt.show()

### 4.1.1 Binary Classifier Results Summary

**[Placeholder - to be filled after training]**

Summary of the binary classifier performance including:
- Final validation AUC: [value]
- Final validation accuracy: [value]
- Training convergence observations
- Key insights from loss and metrics curves

## 4.2 Disease Classifier Training

In [None]:
# Training Disease Classifier
num_epochs = 5  # Adjust as needed; if changed, update T_max in Section 3.3 to match
disease_train_losses = []
disease_val_losses = []
disease_val_aucs = []
best_disease_auc = 0.0

print(f"Training disease classifier for {num_epochs} epochs...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Train
    train_loss = nih_hvit.train_multilabel_model(disease_model, train_loader, disease_criterion, disease_optimizer, device, epoch)
    disease_train_losses.append(train_loss)

    # Validate
    val_loss, val_auc, class_aucs = nih_hvit.validate_multilabel_model(disease_model, val_loader, disease_criterion, device, disease_labels)
    disease_val_losses.append(val_loss)
    disease_val_aucs.append(val_auc)

    # Update learning rate
    disease_scheduler.step()

    # Save best model
    if val_auc > best_disease_auc:
        best_disease_auc = val_auc
        torch.save(disease_model.state_dict(), f'disease_model_epoch_{epoch+1}_auc_{val_auc:.3f}.pt')

    print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}")
    print("Class AUCs:")
    for label, auc in class_aucs.items():
        print(f"  {label}: {auc:.4f}")
    print("-" * 50)

# Plot training history
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(disease_train_losses, label='Train Loss')
plt.plot(disease_val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Disease Classifier Training and Validation Loss')

plt.subplot(1, 2, 2)
plt.plot(disease_val_aucs, label='Validation AUC')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()
plt.title('Disease Classifier Validation AUC')

plt.tight_layout()
plt.show()

# Plot AUCs for last epoch
plt.figure(figsize=(12, 6))
plt.bar(class_aucs.keys(), class_aucs.values())
plt.xticks(rotation=90)
plt.ylabel('AUC')
plt.title('Disease Classifier Validation AUC by Class')
plt.tight_layout()
plt.show()

### 4.2.1 Disease Classifier Results Summary

**[Placeholder - to be filled after training]**

Summary of the disease classifier performance including:
- Final mean validation AUC: [value]
- Best and worst performing disease classes
- Training convergence observations
- Key insights from per-class performance

## 4.3 Hierarchical Model Evaluation

In [None]:
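# Evaluate hierarchical model
# Note: binary_model and disease_model hold their final-epoch weights here;
# load the best checkpoints saved during training first if those are preferred.
print("Evaluating hierarchical model on test set...")
test_results = nih_hvit.test_hierarchical_model(binary_model, disease_model, test_loader, device, disease_labels)

print("Binary classifier test results:")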
print(f"  AUC: {test_results['binary_auc']:.4f}")
print(f"  Accuracy: {test_results['binary_accuracy']:.4f}")

print("\nDisease classifier test results:")
print(f"  Mean AUC: {test_results['mean_disease_auc']:.4f}")
print("  AUC by class:")
for label, auc in test_results['disease_aucs'].items():
    print(f"    {label}: {auc:.4f}")

# Plot disease AUCs
plt.figure(figsize=(12, 6))
plt.bar(test_results['disease_aucs'].keys(), test_results['disease_aucs'].values())
plt.xticks(rotation=90)
plt.ylabel('AUC')
plt.title('Test AUC by Disease Class')
plt.tight_layout()
plt.show()

### 4.3.1 Hierarchical Model Test Results

**[Placeholder - to be filled after evaluation]**

Summary of the hierarchical model performance on the test set including:
- Binary classifier test AUC and accuracy
- Disease classifier mean test AUC
- Per-disease AUC analysis
- Key insights about the hierarchical approach effectiveness

## 4.4 Visualizing Predictions

In [None]:
# Visualize predictions
print("Visualizing example predictions from hierarchical model:")
nih_hvit.visualize_hierarchical_predictions(binary_model, disease_model, test_loader, device, disease_labels, num_examples=5)

### 4.4.1 Qualitative Analysis of Predictions

**[Placeholder - to be filled after running visualization]**

Analysis of the example predictions including:
- Observations about model confidence
- Cases where binary classifier and disease classifier align/disagree
- Potential failure modes identified

## 4.5 Comparison with Single Model Approach

In [None]:
# Load comparison results (replace with actual results after experiments)
comparison_data = {
    'Metric': ['Mean AUC', 'Pneumonia AUC', 'Effusion AUC', 'Cardiomegaly AUC', 'Atelectasis AUC', 'Training Time', 'Inference Time'],
    'Single Model': [0.0, 0.0, 0.0, 0.0, 0.0, '0 min', '0ms'],  # Placeholder values
    'Hierarchical Model': [0.0, 0.0, 0.0, 0.0, 0.0, '0 min', '0ms']  # Placeholder values
}
comparison_df = pd.DataFrame(comparison_data)
comparison_df.set_index('Metric', inplace=True)
comparison_df
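
To fill in the inference-time row, a small timing helper along the following lines could be used. It is a minimal sketch: `mean_latency_ms` is a hypothetical name, the workload is a stand-in for a model forward pass, and the warm-up and repeat counts are arbitrary choices.

```python
import time


def mean_latency_ms(fn, inputs, warmup=2, repeats=5):
    """Average wall-clock time per call of fn over inputs, in milliseconds."""
    for x in inputs[:warmup]:
        fn(x)  # warm-up calls are not timed
    start = time.perf_counter()
    for _ in range(repeats):
        for x in inputs:
            fn(x)
    elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / (repeats * len(inputs))


# Stand-in workload; replace with a model forward pass on real batches.
# On a GPU, call torch.cuda.synchronize() before reading the clock.
latency = mean_latency_ms(lambda x: sum(range(x)), [10_000, 20_000, 30_000])
print(f"{latency:.3f} ms per call")
```

Timing both approaches with the same helper, on the same batches and device, keeps the comparison fair.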

### 4.5.1 Performance Comparison Analysis

**[Placeholder - to be filled after comparison]**

Analysis of the hierarchical model compared to the single model approach including:
- Overall performance differences
- Performance on specific diseases
- Computational efficiency comparison
- Trade-offs between approaches

# 5. Conclusion and Future Work

This notebook demonstrated a hierarchical approach to chest X-ray classification using two Vision Transformer models. The binary classifier first determines if there's any pathological finding, and then the multi-label classifier identifies specific conditions only when the binary classifier detects a finding.

## 5.1 Key Findings

**[Placeholder - to be completed after all experiments]**

1. **Hierarchical Performance**: 
   - How does the hierarchical approach compare to the single model?
   - What are the benefits and drawbacks?
   - Did the approach improve performance on any specific diseases?

2. **Clinical Relevance**: The hierarchical approach more closely mimics radiologist workflow, potentially making it more interpretable and clinically relevant.

3. **Efficiency Considerations**: Every image passes through the binary classifier, and the disease classifier runs only on images with detected findings. Because both models use the same ViT-B/16 backbone, this saves inference compute relative to running both models on every image, though not relative to a single multi-label model.
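
The efficiency argument in item 3 can be made concrete with a back-of-the-envelope count of backbone forward passes per image. The sketch assumes both stages cost the same, since Section 2.3 describes both as ViT-B/16; `expected_forward_passes` is a hypothetical helper and the flagged fractions are illustrative, not measured.

```python
def expected_forward_passes(p_flagged):
    """Average backbone passes per image for the two-stage cascade.

    Every image goes through the binary classifier (1 pass); a fraction
    p_flagged also goes through the disease classifier (1 more pass).
    Assumes both stages cost the same, since both use ViT-B/16.
    """
    if not 0.0 <= p_flagged <= 1.0:
        raise ValueError("p_flagged must be in [0, 1]")
    return 1.0 + p_flagged


for p in (0.0, 0.5, 1.0):
    cascade = expected_forward_passes(p)
    print(f"flagged={p:.1f}: cascade={cascade:.1f}, always-both=2.0, single-model=1.0")
```

Under this assumption the cascade costs between one and two passes per image, so it is cheaper than always running both models but never cheaper than one multi-label model of the same size. A lighter binary stage would change that balance.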

## 5.2 Limitations

**[Placeholder - to be completed after experiments]**

- Limitations of the current approach
- Dataset limitations
- Challenges encountered

## 5.3 Future Work

1. **Joint Training**: Exploring end-to-end training approaches where both models are trained jointly with a shared loss function.

2. **Confidence Calibration**: Improving the calibration of confidence scores, especially for the binary classifier which acts as a gatekeeper.

3. **Attention Visualization**: Implementing visualization of attention maps for both models to improve interpretability.

4. **Clinical Validation**: Validating the hierarchical approach against radiologist performance and clinical outcomes.

5. **Model Optimization**: Exploring different architectures and hyperparameters to further improve performance.