# Chest X-Ray Disease Classification (Multi-Label, 7 Diseases)

Train a ResNet50 model for multi-label classification of 7 chest disease categories from the NIH Chest X-Ray dataset.

## Dataset
- ~28,000 chest X-ray images (patient-aware splits)
- Split: 70% train / 15% validation / 15% test
- 7 label classes: 6 findings plus 'No Finding' (multi-label: each image can carry several labels)
- **Focus on quality**: Excluded Cardiomegaly, Consolidation, and Pleural_Thickening because they had insufficient training data

## Model
- Architecture: ResNet50 pretrained on ImageNet
- Classification Type: Multi-label (BCEWithLogitsLoss with per-class positive weights, `pos_weight`)
- Image Size: 300×300
- Batch Size: 32
- Optimizer: AdamW (lr=0.001, weight_decay=0.01)
- Scheduler: CosineAnnealingWarmRestarts
- Data Augmentation: Mixup, horizontal flip, rotation, affine translation, color jitter (brightness/contrast), random erasing
- Mixed Precision: Enabled (AMP)

## Results
- Average F1-Score: 0.442 (with optimal thresholds)
- Exact Match Accuracy: 40.57%
- Hamming Accuracy: 83.36%
- AUC-ROC: 0.7934
- 27.5% relative improvement in average F1 over the default 0.5 decision threshold

In [None]:
# ------------------------------------------------------------------------------
# IMPORTS
# ------------------------------------------------------------------------------

import os 
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import timm
from torch.cuda.amp import autocast, GradScaler
import random

In [None]:
# ------------------------------------------------------------------------------
# DATASET CLASS - MULTI-LABEL
# ------------------------------------------------------------------------------

import pandas as pd

class ChestXrayDatasetMultiLabel(Dataset):
    """Dataset for NIH Chest X-ray with multi-label classification.

    Each item is ``(image, label_vector)``, where ``label_vector`` is a float32
    multi-hot tensor of length ``len(disease_classes)``.

    The metadata CSV must contain an ``'Image Index'`` column (file name) and a
    ``'Finding Labels'`` column whose entries are separated by ``'|'``. Findings
    that are not in ``disease_classes`` are ignored.

    Args:
        metadata_csv: Path to the split's metadata CSV.
        image_dir: Root directory of the split. For each image, every
            ``image_dir/<class name>/`` sub-folder is searched first, in
            ``disease_classes`` order; otherwise ``image_dir`` itself is used.
        disease_classes: Ordered class names. Position ``i`` is column ``i`` of
            the label vector.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, metadata_csv, image_dir, disease_classes, transform=None):
        self.metadata = pd.read_csv(metadata_csv)
        self.image_dir = image_dir
        self.disease_classes = disease_classes
        self.transform = transform
        self.num_classes = len(disease_classes)

        # Name -> column lookup built once, so label encoding is O(1) per finding
        # instead of a list scan plus list.index() for every label of every item.
        # setdefault keeps the first occurrence, matching list.index() semantics.
        self._class_to_idx = {}
        for position, name in enumerate(disease_classes):
            self._class_to_idx.setdefault(name, position)

        print(f"Loaded {len(self.metadata)} samples from {metadata_csv}")

    def __len__(self):
        return len(self.metadata)

    def _resolve_image_path(self, img_name):
        """Return the first existing class-folder path for ``img_name``, else the split-root path."""
        for class_folder in self.disease_classes:
            candidate = os.path.join(self.image_dir, class_folder, img_name)
            if os.path.exists(candidate):
                return candidate
        return os.path.join(self.image_dir, img_name)

    def _encode_labels(self, finding_labels):
        """Convert a '|'-separated 'Finding Labels' value into a float32 multi-hot vector."""
        label_vector = torch.zeros(self.num_classes, dtype=torch.float32)
        if pd.notna(finding_labels):
            for disease in str(finding_labels).split('|'):
                column = self._class_to_idx.get(disease.strip())
                if column is not None:
                    label_vector[column] = 1.0
        return label_vector

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        img_path = self._resolve_image_path(row['Image Index'])

        # The context manager releases the file handle deterministically.
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self._encode_labels(row['Finding Labels'])

In [None]:
# ------------------------------------------------------------------------------
# LOAD DATA & TRANSFORMS
# ------------------------------------------------------------------------------

# Disease classes (excluded: Cardiomegaly, Consolidation, Pleural_Thickening due to insufficient data).
# The order is significant: position i is column i of every label vector,
# class-weight entry, and prediction array in the cells below.
DISEASE_CLASSES = [
    'Atelectasis', 'Effusion', 'Infiltration', 'Mass',
    'No Finding', 'Nodule', 'Pneumothorax'
]

NUM_CLASSES = len(DISEASE_CLASSES)
print(f"Classes: {DISEASE_CLASSES}")

# ImageNet channel statistics, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Data transforms with augmentation for training
train_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    # RandomErasing operates on tensors, so it must come after ToTensor().
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.2)),
])

# Standard (deterministic) transforms for validation/test
eval_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Data directory. Set the XRAY_DATA_DIR environment variable to use the
# notebook on another machine; otherwise the original local path is used.
DATA_DIR = os.environ.get('XRAY_DATA_DIR', 'C:/xray_data/data')

# Create multi-label datasets (one metadata CSV + image folder per split)
train_dataset = ChestXrayDatasetMultiLabel(
    metadata_csv=os.path.join(DATA_DIR, 'train_metadata.csv'),
    image_dir=os.path.join(DATA_DIR, 'train'),
    disease_classes=DISEASE_CLASSES,
    transform=train_transform
)

val_dataset = ChestXrayDatasetMultiLabel(
    metadata_csv=os.path.join(DATA_DIR, 'val_metadata.csv'),
    image_dir=os.path.join(DATA_DIR, 'val'),
    disease_classes=DISEASE_CLASSES,
    transform=eval_transform
)

test_dataset = ChestXrayDatasetMultiLabel(
    metadata_csv=os.path.join(DATA_DIR, 'test_metadata.csv'),
    image_dir=os.path.join(DATA_DIR, 'test'),
    disease_classes=DISEASE_CLASSES,
    transform=eval_transform
)

print("\nDataset sizes:")
print(f"  Train: {len(train_dataset):,}")
print(f"  Val:   {len(val_dataset):,}")
print(f"  Test:  {len(test_dataset):,}")

In [None]:
# ------------------------------------------------------------------------------
# CLASS WEIGHTS
# ------------------------------------------------------------------------------

# Per-class weights, passed to BCEWithLogitsLoss as pos_weight in the model
# setup cell. Both lists below are positional: entry i belongs to DISEASE_CLASSES[i].
class_weights = torch.tensor([
    2.34, 2.06, 1.88, 3.74, 1.00, 3.02, 4.59
], dtype=torch.float32)

# Per-class training sample counts, hard-coded for display only
# (not recomputed from train_dataset).
training_counts = [2871, 3266, 5022, 1419, 13365, 1564, 624]

# Guard the positional alignment. Editing DISEASE_CLASSES (e.g. re-adding an
# excluded class) without updating these lists would silently mis-assign weights.
if not (len(class_weights) == len(training_counts) == len(DISEASE_CLASSES)):
    raise ValueError(
        f"class_weights ({len(class_weights)}), training_counts ({len(training_counts)}) "
        f"and DISEASE_CLASSES ({len(DISEASE_CLASSES)}) must have the same length"
    )

print(f"\n{'='*70}")
print("CLASS WEIGHTS")
print(f"{'='*70}")
print(f"{'Disease':<25} {'Weight':<10} {'Samples'}")
print("-" * 50)

for disease, weight, count in zip(DISEASE_CLASSES, class_weights.tolist(), training_counts):
    print(f"{disease:<25} {weight:<10.2f} {count:,}")

print(f"{'='*70}")

In [None]:
# ------------------------------------------------------------------------------
# MIXUP AUGMENTATION
# ------------------------------------------------------------------------------

def mixup_data(x, y, alpha=0.2):
    """Apply Mixup: blend every sample of a batch with a random partner sample.

    Args:
        x: Batch tensor; the first dimension is the batch dimension.
        y: Targets aligned with ``x`` along the first dimension.
        alpha: Concentration of the Beta(alpha, alpha) mixing distribution.
            Non-positive values disable mixing (``lam`` is 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)``. ``y_a`` is ``y`` itself, ``y_b`` is ``y``
        reordered to match each sample's partner, and ``lam`` is the weight of
        the original sample.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # One random partner per sample, on the same device as the batch.
    partner = torch.randperm(x.size(0)).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[partner, :]
    return mixed_x, y, y[partner], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: the ``lam``-weighted combination of the loss against both target sets.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_original = criterion(pred, y_a)
    loss_partner = criterion(pred, y_b)
    return lam * loss_original + (1 - lam) * loss_partner


# ------------------------------------------------------------------------------
# MODEL SETUP
# ------------------------------------------------------------------------------

BATCH_SIZE = 32
NUM_WORKERS = 0  # 0 = batches are loaded in the main process

# Only the training split is shuffled; validation/test batches keep CSV order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"\nDevice: {device}")

# ImageNet-pretrained ResNet50 whose classifier head emits NUM_CLASSES logits.
model = timm.create_model('resnet50', pretrained=True, num_classes=NUM_CLASSES).to(device)

# class_weights act as BCE pos_weight: positive targets of class i are scaled by class_weights[i].
pos_weights = class_weights.to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# Cosine schedule with warm restarts: the first cycle lasts 10 scheduler steps,
# and each later cycle is twice as long (T_mult=2).
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

# Automatic mixed precision is only used on CUDA devices.
use_amp = (device.type == 'cuda')
if use_amp:
    scaler = GradScaler()
else:
    scaler = None

print(f"Model: ResNet50 ({sum(param.numel() for param in model.parameters()):,} parameters)")
print(f"Mixed Precision: {'Enabled' if use_amp else 'Disabled'}")
print(f"Mixed Precision: {'Enabled' if use_amp else 'Disabled'}")

In [None]:
# ------------------------------------------------------------------------------
# TRAINING LOOP
# ------------------------------------------------------------------------------

import time
from sklearn.metrics import roc_auc_score

num_epochs = 20
patience = 5

# Checkpoint selection and early stopping use macro AUC on the validation set.
# Threshold-0.5 Hamming accuracy (still reported as val_acc) is dominated by the
# many true negatives in this imbalanced multi-label task, so it barely reacts
# to improvements on the positive classes.
best_val_auc = -float('inf')
best_val_acc = 0.0  # validation accuracy of the selected (best-AUC) epoch
epochs_without_improvement = 0

for epoch in range(num_epochs):
    epoch_start = time.time()
    print(f"\nEpoch {epoch+1}/{num_epochs} | LR: {optimizer.param_groups[0]['lr']:.6f}")

    # ---- Training ----
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        images, labels_a, labels_b, lam = mixup_data(images, labels, alpha=0.2)

        optimizer.zero_grad()

        if use_amp:
            with autocast():
                outputs = model(images)
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

        # Approximate training accuracy: compare 0.5-thresholded predictions with
        # the mixed soft targets rounded at 0.5. no_grad avoids building an
        # autograd graph for this metric-only computation.
        with torch.no_grad():
            preds = (torch.sigmoid(outputs) > 0.5).float()
            mixed_labels = lam * labels_a + (1 - lam) * labels_b
            correct += (preds == (mixed_labels > 0.5).float()).sum().item()
        total += labels.numel()

    train_loss = running_loss / len(train_loader)
    train_acc = 100 * correct / total

    # ---- Validation ----
    model.eval()
    val_correct = 0
    val_total = 0
    all_preds = []   # per-batch sigmoid probabilities
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            probs = torch.sigmoid(model(images))
            preds = (probs > 0.5).float()

            val_correct += (preds == labels).sum().item()
            val_total += labels.numel()

            all_preds.append(probs.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    val_acc = 100 * val_correct / val_total
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    val_auc = roc_auc_score(all_labels, all_preds, average='macro')

    epoch_time = time.time() - epoch_start
    print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
    print(f"Val   - Acc: {val_acc:.2f}%, AUC: {val_auc:.4f} ({epoch_time/60:.1f} min)")

    scheduler.step()

    # Save best model (by validation macro AUC)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        best_val_acc = val_acc
        epochs_without_improvement = 0
        torch.save(model.state_dict(), 'best_resnet50.pth')
        print(f"Best model saved (val_auc: {val_auc:.4f}, val_acc: {val_acc:.2f}%)")
    else:
        epochs_without_improvement += 1

    # Early stopping
    if epochs_without_improvement >= patience:
        print(f"\nEarly stopping after {patience} epochs without improvement")
        print(f"Best validation AUC: {best_val_auc:.4f} (accuracy {best_val_acc:.2f}%)")
        break

print(f"\nTraining complete. Best validation AUC: {best_val_auc:.4f} (accuracy {best_val_acc:.2f}%)")

In [None]:
# ------------------------------------------------------------------------------
# TEST EVALUATION & OPTIMAL THRESHOLD TUNING
# ------------------------------------------------------------------------------

from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, f1_score, precision_score, recall_score

print("\nEvaluating model on test set...")

# map_location lets a checkpoint written on a GPU be reloaded on a CPU-only run.
model.load_state_dict(torch.load('best_resnet50.pth', map_location=device))
model.eval()


def _collect_probs(loader):
    """Run the model over ``loader``; return (labels, sigmoid probabilities) as stacked numpy arrays."""
    labels_out, probs_out = [], []
    with torch.no_grad():
        for images, labels in loader:
            probs = torch.sigmoid(model(images.to(device)))
            labels_out.append(labels.cpu().numpy())
            probs_out.append(probs.cpu().numpy())
    return np.vstack(labels_out), np.vstack(probs_out)


def _tune_threshold(true_labels, pred_probs, thresholds):
    """Return the threshold maximising binary F1 (0.5 if no threshold yields F1 > 0)."""
    best_f1, best_threshold = 0, 0.5
    for threshold in thresholds:
        predictions = (pred_probs > threshold).astype(int)
        if predictions.sum() == 0:
            continue
        f1 = f1_score(true_labels, predictions, zero_division=0)
        if f1 > best_f1:
            best_f1, best_threshold = f1, threshold
    return best_threshold


test_labels, test_probs = _collect_probs(test_loader)
test_preds_default = (test_probs > 0.5).astype(int)

exact_match_default = (test_labels == test_preds_default).all(axis=1).mean() * 100
hamming_default = (test_labels == test_preds_default).mean() * 100
auc = roc_auc_score(test_labels, test_probs, average='macro')
_, _, f1s_default, _ = precision_recall_fscore_support(test_labels, test_preds_default, average=None, zero_division=0)
avg_f1_default = f1s_default.mean()

print(f"\nBaseline Performance (threshold=0.5):")
print(f"  Exact Match: {exact_match_default:.2f}%")
print(f"  Hamming:     {hamming_default:.2f}%")
print(f"  AUC:         {auc:.4f}")
print(f"  Avg F1:      {avg_f1_default:.3f}")

# Thresholds are tuned on the VALIDATION split and only then applied to the
# test split. Selecting them on the test labels would leak test information
# and make the reported "optimal" F1 optimistically biased.
print("\nOptimizing thresholds per disease class (on validation set)...")
val_labels, val_probs = _collect_probs(val_loader)

threshold_range = np.arange(0.05, 0.96, 0.05)
optimal_thresholds = {}
threshold_results = []

for disease_idx, disease in enumerate(DISEASE_CLASSES):
    best_threshold = _tune_threshold(
        val_labels[:, disease_idx], val_probs[:, disease_idx], threshold_range
    )
    optimal_thresholds[disease] = best_threshold

    # Held-out (test) performance of the validation-tuned threshold.
    true_labels = test_labels[:, disease_idx]
    predictions = (test_probs[:, disease_idx] > best_threshold).astype(int)
    test_f1 = f1_score(true_labels, predictions, zero_division=0)
    default_f1 = f1s_default[disease_idx]

    threshold_results.append({
        'Disease': disease,
        'Optimal_Threshold': best_threshold,
        'Optimal_F1': test_f1,
        'Optimal_Precision': precision_score(true_labels, predictions, zero_division=0),
        'Optimal_Recall': recall_score(true_labels, predictions, zero_division=0),
        'Default_F1': default_f1,
        'Improvement': test_f1 - default_f1,
        'Support': int(true_labels.sum()),
    })

print(f"\n{'='*70}")
print("OPTIMAL THRESHOLDS (tuned on validation, evaluated on test)")
print(f"{'='*70}")
print(f"{'Disease':<20} {'Threshold':<11} {'F1':<8} {'Precision':<11} {'Recall':<8} {'Gain'}")
print("-" * 70)

for result in threshold_results:
    improvement_str = f"+{result['Improvement']:.3f}" if result['Improvement'] > 0 else f"{result['Improvement']:.3f}"
    print(f"{result['Disease']:<20} {result['Optimal_Threshold']:<11.2f} "
          f"{result['Optimal_F1']:<8.3f} {result['Optimal_Precision']:<11.3f} "
          f"{result['Optimal_Recall']:<8.3f} {improvement_str}")

avg_optimal_f1 = sum(r['Optimal_F1'] for r in threshold_results) / len(threshold_results)
avg_improvement = avg_optimal_f1 - avg_f1_default

# The gain can be negative once thresholds come from a separate split, so let
# the format spec emit the sign. Also guard against a zero baseline F1.
relative_gain = f"{avg_improvement / avg_f1_default * 100:+.1f}%" if avg_f1_default > 0 else "n/a"

print(f"\n{'='*70}")
print(f"Average F1 - Default: {avg_f1_default:.3f}, Optimal: {avg_optimal_f1:.3f}")
print(f"Improvement: {avg_improvement:+.3f} ({relative_gain})")
print(f"{'='*70}")

In [None]:
# ------------------------------------------------------------------------------
# APPLY OPTIMAL THRESHOLDS & FINAL EVALUATION
# ------------------------------------------------------------------------------

import matplotlib.pyplot as plt

# Binarise each probability column with its class-specific threshold.
# The result keeps test_labels' dtype, with 1/0 values.
test_preds_optimal = np.zeros_like(test_labels)
for col, name in enumerate(DISEASE_CLASSES):
    test_preds_optimal[:, col] = test_probs[:, col] > optimal_thresholds[name]

# Exact match: every label of an image is right; Hamming: fraction of correct label cells.
matches = test_labels == test_preds_optimal
exact_match_optimal = matches.all(axis=1).mean() * 100
hamming_optimal = matches.mean() * 100

print(f"\nFinal Results:")
print(f"  Exact Match: {exact_match_optimal:.2f}%")
print(f"  Hamming:     {hamming_optimal:.2f}%")
print(f"  AUC:         {auc:.4f}")
print(f"  Avg F1:      {avg_optimal_f1:.3f}")

# Per-class results
optimal_precisions, optimal_recalls, optimal_f1s, supports = precision_recall_fscore_support(
    test_labels, test_preds_optimal, average=None, zero_division=0
)

print(f"\nPer-Class F1 Scores:")
for name, class_f1 in zip(DISEASE_CLASSES, optimal_f1s):
    print(f"  {name:<20} {class_f1:.3f}")

# Visualization: left = F1 at default vs tuned thresholds, right = tuned thresholds.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

positions = np.arange(len(DISEASE_CLASSES))
bar_height = 0.35

ax1.barh(positions - bar_height / 2, f1s_default, bar_height,
         label='Default (0.5)', color='steelblue', alpha=0.8)
ax1.barh(positions + bar_height / 2, optimal_f1s, bar_height,
         label='Optimal', color='forestgreen', alpha=0.8)
ax1.set_yticks(positions)
ax1.set_yticklabels(DISEASE_CLASSES)
ax1.set_xlabel('F1-Score')
ax1.set_title('Per-Class F1 Scores: Default vs Optimal Thresholds')
ax1.set_xlim(0, 1)
ax1.legend()
ax1.grid(axis='x', alpha=0.3)

tuned_values = [optimal_thresholds[name] for name in DISEASE_CLASSES]
ax2.barh(DISEASE_CLASSES, tuned_values, color='coral')
ax2.axvline(x=0.5, color='gray', linestyle='--', label='Default (0.5)', linewidth=2)
ax2.set_xlabel('Threshold Value')
ax2.set_title('Optimal Thresholds by Disease')
ax2.set_xlim(0, 1)
ax2.legend()
ax2.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.savefig('resnet50_performance.png', dpi=300, bbox_inches='tight')
print("\nVisualization saved to 'resnet50_performance.png'")

In [None]:
# ------------------------------------------------------------------------------
# INFERENCE ON SINGLE IMAGE
# ------------------------------------------------------------------------------

def predict_multilabel_image(image_path, model, device, transform, disease_classes,
                             per_class_thresholds=None):
    """Run the multi-label model on a single image file.

    Args:
        image_path: Path to the image file; it is converted to RGB before ``transform``.
        model: Network producing one logit per class. It is put into eval mode.
        device: Device the input tensor is moved to.
        transform: Callable mapping a PIL image to an image tensor (e.g. ``eval_transform``).
        disease_classes: Class names, ordered like the model's output logits.
        per_class_thresholds: Optional mapping of class name -> probability threshold.
            Classes missing from the mapping (or every class when ``None``) use 0.5.

    Returns:
        ``(image, probabilities, predictions)``:
        - ``image``: the loaded RGB PIL image.
        - ``probabilities``: a numpy array of sigmoid probabilities, one per class.
        - ``predictions``: ``(class_name, probability)`` pairs strictly above their
          threshold, sorted by descending probability.
    """
    thresholds = per_class_thresholds or {}

    # The context manager closes the file handle; convert() returns a loaded copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

    model.eval()
    with torch.no_grad():
        probabilities = torch.sigmoid(model(image_tensor)).cpu().numpy()[0]

    predictions = [
        (disease, prob)
        for disease, prob in zip(disease_classes, probabilities)
        if prob > thresholds.get(disease, 0.5)
    ]
    predictions.sort(key=lambda pair: pair[1], reverse=True)
    return image, probabilities, predictions


# Test inference
# Pick one random test image and locate it on disk.
idx = random.randint(0, len(test_dataset) - 1)
row = test_dataset.metadata.iloc[idx]
img_name = row['Image Index']

# Same lookup order as the dataset: each class sub-folder first, then the split root.
img_path = next(
    (
        os.path.join(test_dataset.image_dir, folder, img_name)
        for folder in DISEASE_CLASSES
        if os.path.exists(os.path.join(test_dataset.image_dir, folder, img_name))
    ),
    os.path.join(test_dataset.image_dir, img_name),
)

print(f"\nTest Image: {img_name}")

# Ground truth: '|'-separated findings, restricted to the modelled classes.
finding_labels = row['Finding Labels']
ground_truth_labels = (
    [label.strip() for label in str(finding_labels).split('|') if label.strip() in DISEASE_CLASSES]
    if pd.notna(finding_labels)
    else []
)

print("\nGround Truth:")
if not ground_truth_labels:
    print("  No findings")
else:
    for label in ground_truth_labels:
        print(f"  {label}")

# Prediction with the tuned per-class thresholds
original_image, all_probs, predicted_diseases = predict_multilabel_image(
    img_path, model, device, eval_transform, DISEASE_CLASSES, optimal_thresholds
)

print("\nPredictions:")
if not predicted_diseases:
    print("  No findings")
else:
    for label, prob in predicted_diseases:
        print(f"  {label}: {prob*100:.1f}% (threshold: {optimal_thresholds.get(label, 0.5):.2f})")

# Visualization: the image on the left, per-class probabilities with threshold markers on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

ax1.imshow(original_image, cmap='gray')  # note: matplotlib ignores cmap for RGB input
ax1.set_title(f'Chest X-Ray: {img_name}')
ax1.axis('off')

class_thresholds = [optimal_thresholds.get(name, 0.5) for name in DISEASE_CLASSES]
colors = ['green' if p > t else 'lightgray' for p, t in zip(all_probs, class_thresholds)]
ax2.barh(DISEASE_CLASSES, all_probs, color=colors)
ax2.set_xlabel('Probability')
ax2.set_title('Disease Probabilities (Green = Above Threshold)')
ax2.set_xlim(0, 1)

# A short dashed red tick per bar marks that class's decision threshold.
for pos, t in enumerate(class_thresholds):
    ax2.plot([t, t], [pos - 0.4, pos + 0.4], 'r--', linewidth=2, alpha=0.7)

ax2.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()