# Chest X-Ray Disease Classification (Multi-Label, 7 Classes)

Train an EfficientNet-B3 model for multi-label classification of 7 chest X-ray label classes (6 diseases plus "No Finding") from the NIH Chest X-Ray dataset.

## Dataset
- ~28,000 chest X-ray images (patient-aware splits)
- Split: 70% train / 15% validation / 15% test
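- 7 label classes: Atelectasis, Effusion, Infiltration, Mass, Nodule, Pneumothorax, and No Finding (multi-label: each image can have several findings)
- **Focus on quality**: Cardiomegaly, Consolidation, and Pleural_Thickening were excluded because they had insufficient training data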

## Model
- Architecture: EfficientNet-B3 pretrained on ImageNet
- Classification Type: Multi-Label (BCEWithLogitsLoss with per-class `pos_weight` class weights)
- Image Size: 300×300
- Batch Size: 32
- Optimizer: AdamW (lr=0.001, weight_decay=0.01)
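- Scheduler: CosineAnnealingWarmRestarts (T_0=10, T_mult=2, eta_min=1e-6), stepped once per epoch
- Data Augmentation: Mixup (α = 0.2), horizontal flip, rotation (±10°), random translation (up to 10%), brightness/contrast jitter, random erasing
- Mixed Precision: Enabled on CUDA (AMP)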

## Results
- Exact Match Accuracy: 47.80%
- Hamming Accuracy: 87.35%
- Average F1-Score: 0.347
- AUC-ROC: 0.7847

In [None]:
# ------------------------------------------------------------------------------
# IMPORTS
# ------------------------------------------------------------------------------

import os 
import json
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, confusion_matrix
import numpy as np
import timm  # For EfficientNet models
from torch.cuda.amp import autocast, GradScaler  # Mixed precision training
import random  # For Mixup
from tqdm.auto import tqdm  # Progress bar

In [None]:
# ------------------------------------------------------------------------------
# DATASET CLASS - MULTI-LABEL
# ------------------------------------------------------------------------------

import pandas as pd

class ChestXrayDatasetMultiLabel(Dataset):
    """Dataset for NIH Chest X-ray with multi-label classification.

    Each row of ``metadata_csv`` describes one image. The columns read here
    are ``'Image Index'`` (the image file name) and ``'Finding Labels'`` (one
    or more finding names separated by ``'|'``). Findings that are not in
    ``disease_classes`` are ignored.

    Args:
        metadata_csv: Path to the per-image metadata CSV.
        image_dir: Root directory of the images. An image is looked up in
            ``image_dir/<class name>/`` for each class in ``disease_classes``
            (in order) and then directly in ``image_dir``.
        disease_classes: Ordered class names; position ``i`` in this list is
            column ``i`` of the returned label vector.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, metadata_csv, image_dir, disease_classes, transform=None):
        self.metadata = pd.read_csv(metadata_csv)
        self.image_dir = image_dir
        self.disease_classes = disease_classes
        self.transform = transform
        self.num_classes = len(disease_classes)
        # Name -> column map built once, so label parsing is a dict lookup
        # instead of a list.index() scan per finding. setdefault keeps the
        # first occurrence, matching list.index() if a name were duplicated.
        self._class_to_idx = {}
        for position, name in enumerate(disease_classes):
            self._class_to_idx.setdefault(name, position)

        print(f"Loaded {len(self.metadata)} samples from {metadata_csv}")

    def __len__(self):
        return len(self.metadata)

    def _resolve_image_path(self, img_name):
        """Return the path of ``img_name``, searching the class sub-folders first.

        Falls back to ``image_dir/img_name`` when no class folder contains the
        file; that fallback path is not checked for existence here.
        """
        for class_folder in self.disease_classes:
            candidate = os.path.join(self.image_dir, class_folder, img_name)
            if os.path.exists(candidate):
                return candidate
        return os.path.join(self.image_dir, img_name)

    def _encode_labels(self, finding_labels):
        """Convert a ``'A|B|...'`` finding string into a float32 multi-hot vector.

        Missing values (NaN/None) give an all-zero vector; unknown names are
        skipped.
        """
        label_vector = torch.zeros(self.num_classes, dtype=torch.float32)
        if pd.notna(finding_labels):
            for disease in str(finding_labels).split('|'):
                column = self._class_to_idx.get(disease.strip())
                if column is not None:
                    label_vector[column] = 1.0
        return label_vector

    def __getitem__(self, idx):
        """Return ``(image, label_vector)`` for row ``idx`` of the metadata."""
        row = self.metadata.iloc[idx]
        img_path = self._resolve_image_path(row['Image Index'])

        # Release the file handle as soon as the pixels are decoded:
        # convert() returns a fully loaded copy that outlives the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self._encode_labels(row['Finding Labels'])

In [None]:
# ------------------------------------------------------------------------------
# LOAD DATA & TRANSFORMS
# ------------------------------------------------------------------------------

# The 7 labels the model is trained on. Cardiomegaly, Consolidation and
# Pleural_Thickening are excluded because they had insufficient data.
DISEASE_CLASSES = [
    'Atelectasis',
    'Effusion',
    'Infiltration',
    'Mass',
    'No Finding',
    'Nodule',
    'Pneumothorax',
]
NUM_CLASSES = len(DISEASE_CLASSES)

print(f"Number of classes: {NUM_CLASSES}")
print(f"Classes: {DISEASE_CLASSES}")

# ImageNet channel statistics, matching the ImageNet-pretrained backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, then geometric and photometric augmentation on
# the PIL image, then tensor conversion + normalisation, and finally random
# erasing, which operates on the normalised tensor.
train_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.2)),
])

# Evaluation pipeline: same resize and normalisation, no randomness.
eval_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

DATA_DIR = 'C:/xray_data/data'


def _make_split_dataset(split, transform):
    """Build the dataset for ``split`` from ``DATA_DIR/<split>_metadata.csv`` and ``DATA_DIR/<split>/``."""
    return ChestXrayDatasetMultiLabel(
        metadata_csv=os.path.join(DATA_DIR, split + '_metadata.csv'),
        image_dir=os.path.join(DATA_DIR, split),
        disease_classes=DISEASE_CLASSES,
        transform=transform,
    )


train_dataset = _make_split_dataset('train', train_transform)
val_dataset = _make_split_dataset('val', eval_transform)
test_dataset = _make_split_dataset('test', eval_transform)

print("\nDataset sizes:")
print(f"Train: {len(train_dataset):,} images")
print(f"Val:   {len(val_dataset):,} images")
print(f"Test:  {len(test_dataset):,} images")

In [None]:
# ------------------------------------------------------------------------------
# CLASS WEIGHTS
# ------------------------------------------------------------------------------

# Positive-class weights for BCEWithLogitsLoss(pos_weight=...), together with
# the training count and a short description of the weighting strategy.
# Keyed by class NAME rather than position, so the weights cannot silently
# drift out of alignment with DISEASE_CLASSES if that list is reordered.
_CLASS_WEIGHT_TABLE = {
    #                weight  count   strategy
    'Atelectasis':  (2.34,  2871,  "Moderate boost"),
    'Effusion':     (2.06,  3266,  "Light boost"),
    'Infiltration': (1.88,  5022,  "Light boost (common)"),
    'Mass':         (3.74,  1419,  "Strong boost"),
    'No Finding':   (1.00,  13365, "Baseline (no weight)"),
    'Nodule':       (3.02,  1564,  "Moderate boost"),
    'Pneumothorax': (4.59,  624,   "Strong boost (rare)"),
}

_missing_weights = [d for d in DISEASE_CLASSES if d not in _CLASS_WEIGHT_TABLE]
if _missing_weights:
    raise KeyError(f"No class weight defined for: {_missing_weights}")

# Materialise the three per-class sequences in DISEASE_CLASSES order.
class_weights = torch.tensor(
    [_CLASS_WEIGHT_TABLE[d][0] for d in DISEASE_CLASSES], dtype=torch.float32
)
training_counts = [_CLASS_WEIGHT_TABLE[d][1] for d in DISEASE_CLASSES]
strategies = [_CLASS_WEIGHT_TABLE[d][2] for d in DISEASE_CLASSES]

print(f"\n{'='*70}")
print(f"CLASS WEIGHTS ({len(DISEASE_CLASSES)} Disease Focus)")
print(f"{'='*70}")
print(f"{'Disease':<25} {'Weight':<10} {'Training Count':<15} {'Strategy'}")
print("-" * 75)

for disease, weight, count, strategy in zip(
    DISEASE_CLASSES, class_weights.tolist(), training_counts, strategies
):
    print(f"{disease:<25} {weight:<10.2f} {count:<15,} {strategy}")

# Range of the boosted (> 1.0) weights, derived instead of hard-coded so the
# summary stays correct if the table changes.
_boosted = [w for w in class_weights.tolist() if w > 1.0]

print(f"\nStrategy:")
print(f"  • Focus on {len(DISEASE_CLASSES)} trainable diseases with adequate data")
print(f"  • Moderate class weights ({min(_boosted, default=1.0):.2f}-{max(_boosted, default=1.0):.2f}) to handle imbalance")
print(f"  • No oversampling - preserves natural disease correlations")
print(f"{'='*70}")


In [None]:
# ------------------------------------------------------------------------------
# MIXUP AUGMENTATION
# ------------------------------------------------------------------------------

def mixup_data(x, y, alpha=0.2):
    """Blend every sample in the batch with a randomly chosen partner (Mixup).

    Args:
        x: Batch of inputs; dimension 0 is the batch dimension.
        y: Targets aligned with ``x`` along dimension 0.
        alpha: Concentration of the Beta(alpha, alpha) distribution the mixing
            coefficient is drawn from. Non-positive values disable mixing
            (the coefficient is fixed to 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` being ``y`` itself and ``y_b`` being ``y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    partner = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[partner, :]
    return mixed_x, y, y[partner], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: the losses against both targets, weighted by ``lam`` and ``1 - lam``.

    ``criterion`` is any callable ``(pred, target) -> loss``.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second


# ------------------------------------------------------------------------------
# MODEL SETUP
# ------------------------------------------------------------------------------

BATCH_SIZE = 32
NUM_WORKERS = 0  # load batches in the main process (no worker subprocesses)

# Data loaders: only the training split is shuffled; val/test keep CSV order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
_eval_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, **_eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **_eval_loader_kwargs)

# Device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device: {device}")

# Model: ImageNet-pretrained EfficientNet-B3 with a NUM_CLASSES-output head.
model = timm.create_model('efficientnet_b3', pretrained=True, num_classes=NUM_CLASSES).to(device)

# Loss: BCE on logits, with class_weights used as the per-class pos_weight.
pos_weights = class_weights.to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

# Optimisation: AdamW plus cosine annealing with warm restarts.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

# Mixed precision is only used on CUDA; no scaler is needed otherwise.
use_amp = device.type == 'cuda'
scaler = GradScaler() if use_amp else None

_param_count = sum(p.numel() for p in model.parameters())
print(f"\nModel: EfficientNet-B3")
print(f"Parameters: {_param_count:,}")
print(f"Mixed Precision: {'Enabled' if use_amp else 'Disabled'}")
print(f"Ready to train!")

In [None]:
# ------------------------------------------------------------------------------
# TRAINING LOOP
# ------------------------------------------------------------------------------

import time
from sklearn.metrics import roc_auc_score

num_epochs = 20
patience = 5
epochs_without_improvement = 0

# Checkpoint selection and early stopping track validation macro AUC-ROC.
# Hamming accuracy (val_acc) is dominated by true negatives on this
# imbalanced multi-label task, so selecting on it favours epochs that predict
# "absent" and works against the pos_weight balancing in the loss. AUC is
# threshold-free and is already computed every epoch.
best_val_auc = float('-inf')
best_val_acc = 0.0  # validation accuracy of the saved (best-AUC) checkpoint


def _macro_auc(labels, probs):
    """Mean per-class ROC AUC over classes that have both positives and negatives.

    Identical to ``roc_auc_score(labels, probs, average='macro')`` when every
    class has both label values; classes that do not are skipped instead of
    making sklearn raise. Returns NaN if no class qualifies.
    """
    n_rows = labels.shape[0]
    scores = [
        roc_auc_score(labels[:, c], probs[:, c])
        for c in range(labels.shape[1])
        if 0 < labels[:, c].sum() < n_rows
    ]
    return float(np.mean(scores)) if scores else float('nan')


print("\nStarting training...")

for epoch in range(num_epochs):
    epoch_start = time.time()
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    print(f"Epochs without improvement: {epochs_without_improvement}/{patience}")
    print(f"{'='*60}")

    # Track batch timing
    batch_times = []

    # ---------------- Training ----------------
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        batch_start = time.time()

        images, labels = images.to(device), labels.to(device)

        # Mixup: the loss is the lam-weighted mix of the losses on both targets.
        images, labels_a, labels_b, lam = mixup_data(images, labels, alpha=0.2)

        optimizer.zero_grad()

        # Forward/backward (with or without mixed precision)
        if use_amp:
            with autocast():
                outputs = model(images)
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

        # Approximate per-label accuracy under Mixup: predictions are compared
        # with the mixed target rounded at 0.5.
        preds = (torch.sigmoid(outputs) > 0.5).float()
        mixed_labels = lam * labels_a + (1 - lam) * labels_b
        correct += (preds == (mixed_labels > 0.5).float()).sum().item()
        total += labels.numel()

        batch_time = time.time() - batch_start
        batch_times.append(batch_time)

        # Print progress every 20 batches
        if (batch_idx + 1) % 20 == 0:
            avg_time = np.mean(batch_times[-20:])
            print(f"  Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f} | {avg_time:.2f}s/batch")

    train_loss = running_loss / len(train_loader)
    train_acc = 100 * correct / total
    avg_batch_time = np.mean(batch_times)

    print(f"\nTraining Results:")
    print(f"  Loss: {train_loss:.4f} | Accuracy: {train_acc:.2f}%")
    print(f"  Avg batch time: {avg_batch_time:.3f}s | Throughput: {BATCH_SIZE/avg_batch_time:.1f} img/s")

    # ---------------- Validation ----------------
    print(f"\nRunning validation...")
    model.eval()
    val_correct = 0
    val_total = 0
    all_preds = []   # per-batch sigmoid probabilities
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            probs = torch.sigmoid(model(images))  # computed once, reused below

            preds = (probs > 0.5).float()
            val_correct += (preds == labels).sum().item()
            val_total += labels.numel()

            all_preds.append(probs.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    val_acc = 100 * val_correct / val_total

    # Macro AUC-ROC over the validation set
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    val_auc = _macro_auc(all_labels, all_preds)

    epoch_time = time.time() - epoch_start

    print(f"\nValidation Results:")
    print(f"  Accuracy: {val_acc:.2f}% | AUC: {val_auc:.4f}")
    print(f"\nEpoch completed in {epoch_time/60:.1f} minutes")

    scheduler.step()

    # Save best model (by AUC). A NaN AUC never compares greater, so it
    # counts as "no improvement".
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        best_val_acc = val_acc
        epochs_without_improvement = 0
        torch.save(model.state_dict(), 'best_efficientnet_b3.pth')
        print(f"✓ New best model saved! (val_auc: {val_auc:.4f}, val_acc: {val_acc:.2f}%)")
    else:
        epochs_without_improvement += 1
        print(f"No improvement for {epochs_without_improvement} epoch(s)")

    print(f"{'='*60}")

    # Early stopping check
    if epochs_without_improvement >= patience:
        print(f"\n{'='*60}")
        print(f"Early stopping")
        print(f"No improvement for {patience} consecutive epochs")
        print(f"Best validation AUC: {best_val_auc:.4f} (accuracy: {best_val_acc:.2f}%)")
        print(f"Training stopped at epoch {epoch+1}/{num_epochs}")
        print(f"{'='*60}")
        break

print(f"\nTraining complete! Best validation AUC: {best_val_auc:.4f} (accuracy: {best_val_acc:.2f}%)")

In [None]:
# ------------------------------------------------------------------------------
# TEST EVALUATION
# ------------------------------------------------------------------------------

from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

TEST_THRESHOLD = 0.5  # decision threshold applied to sigmoid probabilities

# Load the best checkpoint written by the training loop. map_location lets the
# load succeed even if the checkpoint was saved from a different device.
model.load_state_dict(torch.load('best_efficientnet_b3.pth', map_location=device))
model.eval()

test_labels = []
test_probs = []

print("\nEvaluating on test set...")

# Single pass over the test set: keep the probabilities (needed for AUC) and
# derive hard predictions from them, instead of running inference twice.
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        test_probs.append(torch.sigmoid(outputs).cpu().numpy())
        test_labels.append(labels.cpu().numpy())

test_labels = np.vstack(test_labels)
test_probs = np.vstack(test_probs)
test_preds = test_probs > TEST_THRESHOLD  # boolean array, same shape as test_probs

# Calculate metrics
exact_match_acc = (test_labels == test_preds).all(axis=1).mean() * 100
hamming_acc = (test_labels == test_preds).mean() * 100
test_auc = roc_auc_score(test_labels, test_probs, average='macro')

print(f"\n{'='*70}")
print("TEST SET RESULTS")
print(f"{'='*70}")
print(f"Overall Metrics:")
print(f"  • Exact Match Accuracy: {exact_match_acc:.2f}% (all labels correct)")
print(f"  • Hamming Accuracy: {hamming_acc:.2f}% (per-label accuracy)")
print(f"  • AUC-ROC: {test_auc:.4f}")
print(f"")

# Per-class metrics
print(f"Per-Class Metrics:")
print(f"{'Disease':<25} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Support':<10}")
print("=" * 75)

precisions, recalls, f1s, supports = precision_recall_fscore_support(
    test_labels, test_preds, average=None, zero_division=0
)

for i, disease in enumerate(DISEASE_CLASSES):
    print(f"{disease:<25} {precisions[i]:<10.3f} {recalls[i]:<10.3f} {f1s[i]:<10.3f} {int(supports[i]):<10}")

print(f"\n{'='*70}")
print("PERFORMANCE SUMMARY")
print(f"{'='*70}")
print(f"Results with 7-Class Model:")
print(f"  • Exact match: {exact_match_acc:.2f}%")
print(f"  • Hamming accuracy: {hamming_acc:.2f}%")
print(f"  • Average F1 (all classes): {f1s.mean():.3f}")
print(f"")
print(f"Classes with lowest F1 scores:")
lowest_f1_indices = np.argsort(f1s)[:3]
for idx in lowest_f1_indices:
    print(f"  • {DISEASE_CLASSES[idx]}: F1 = {f1s[idx]:.3f}")
print(f"{'='*70}")

In [None]:
# ------------------------------------------------------------------------------
# VISUALIZE RESULTS
# ------------------------------------------------------------------------------

import matplotlib.pyplot as plt
import seaborn as sns

# One horizontal bar per class; the F1 axis is fixed to [0, 1] so that charts
# from different runs are directly comparable.
fig, ax = plt.subplots(figsize=(12, 6))
ax.barh(DISEASE_CLASSES, f1s, color='steelblue')
ax.set_xlabel('F1-Score')
ax.set_title('Per-Class F1 Scores - EfficientNet-B3')
ax.set_xlim(0, 1)
ax.grid(axis='x', alpha=0.3)
fig.tight_layout()
fig.savefig('per_class_f1_scores.png', dpi=300, bbox_inches='tight')
print("Visualization saved to 'per_class_f1_scores.png'")

In [None]:
# ------------------------------------------------------------------------------
# INFERENCE ON SINGLE IMAGE
# ------------------------------------------------------------------------------

def predict_multilabel_image(image_path, model, device, transform, disease_classes, threshold=0.5):
    """Predict which classes are present in a single chest X-ray image.

    Args:
        image_path: Path to the image file.
        model: Network returning one logit per class; it is put in eval mode.
        device: Device the input tensor is moved to (should match ``model``).
        transform: Callable mapping the RGB PIL image to a CHW tensor.
        disease_classes: Class names, ordered like the model outputs.
        threshold: A class is reported when its probability is strictly greater.

    Returns:
        ``(image, probabilities, predictions)``: the loaded RGB PIL image, a
        1-D numpy array of per-class sigmoid probabilities, and a list of
        ``(class_name, probability)`` pairs above ``threshold`` sorted from
        highest to lowest probability.
    """
    # Decode and release the file handle immediately; convert() returns a
    # fully loaded copy that stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.sigmoid(outputs).cpu().numpy()[0]

    predictions = [
        (disease_classes[i], prob)
        for i, prob in enumerate(probabilities)
        if prob > threshold
    ]
    predictions.sort(key=lambda item: item[1], reverse=True)
    return image, probabilities, predictions


# Example: Test on random image
import random
import matplotlib.pyplot as plt

# Single source of truth for the probability cut-off: it is passed to the
# predictor and also used for the plot title, threshold line and legend, so
# the annotations always match the threshold that was actually applied.
DISPLAY_THRESHOLD = 0.3

# Pick a random test image.
idx = random.randint(0, len(test_dataset) - 1)
row = test_dataset.metadata.iloc[idx]
img_name = row['Image Index']

# Locate the file the same way the dataset does: class sub-folders first,
# then the split's root directory.
img_path = None
for class_folder in DISEASE_CLASSES:
    path = os.path.join(test_dataset.image_dir, class_folder, img_name)
    if os.path.exists(path):
        img_path = path
        break
if img_path is None:
    img_path = os.path.join(test_dataset.image_dir, img_name)

print(f"Testing: {img_name}")

# Ground truth: this image's findings restricted to the trained classes.
ground_truth_labels = []
finding_labels = row['Finding Labels']
if pd.notna(finding_labels):
    diseases = [d.strip() for d in str(finding_labels).split('|')]
    ground_truth_labels = [d for d in diseases if d in DISEASE_CLASSES]

print("Ground Truth:")
if ground_truth_labels:
    for disease in ground_truth_labels:
        print(f"  {disease}")
else:
    print("  No findings")
print()

original_image, all_probs, predicted_diseases = predict_multilabel_image(
    img_path, model, device, eval_transform, DISEASE_CLASSES, threshold=DISPLAY_THRESHOLD
)

print("Predicted Diseases:")
if predicted_diseases:
    for disease, prob in predicted_diseases:
        print(f"  {disease}: {prob*100:.1f}%")
else:
    print("  None detected")

if predicted_diseases:
    # Only plot classes above the threshold
    detected_diseases = [d for d, _ in predicted_diseases]
    detected_probs = [p for _, p in predicted_diseases]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # cmap is ignored by imshow for RGB input; kept for parity with grayscale use.
    ax1.imshow(original_image, cmap='gray')
    ax1.set_title('Chest X-Ray')
    ax1.axis('off')

    ax2.barh(detected_diseases, detected_probs, color='red')
    ax2.set_xlabel('Probability')
    ax2.set_title(f'Detected Diseases (Above {DISPLAY_THRESHOLD:.0%} Threshold)')
    ax2.set_xlim(0, 1)
    ax2.axvline(x=DISPLAY_THRESHOLD, color='darkred', linestyle='--', linewidth=2,
                label=f'Threshold ({DISPLAY_THRESHOLD:.0%})')
    ax2.legend()
    ax2.grid(axis='x', alpha=0.3)
else:
    # Nothing above threshold - just show the image
    fig, ax1 = plt.subplots(1, 1, figsize=(7, 6))
    ax1.imshow(original_image, cmap='gray')
    ax1.set_title('Chest X-Ray\n(No Diseases Detected Above Threshold)')
    ax1.axis('off')

plt.tight_layout()
plt.show()