# Multi-Objective Hyperparameter Optimization for Breast Cancer Classification

## Tutorial: Step-by-Step Guide

This notebook walks through the implementation of multi-objective hyperparameter optimization for CNN-based breast cancer classification under dataset shift.

**Datasets:**
- **VinDr-Mammo** (Source): Training and validation
- **INbreast** (Target): Zero-shot evaluation only

**Objectives:**
1. Maximize PR-AUC
2. Maximize AUROC
3. Minimize Brier score
4. Minimize robustness degradation
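
For reference, the first three objectives can be written in a few lines of plain Python. This is a minimal sketch, not the repository's metric code: the function names are illustrative, and average precision is assumed here as the PR-AUC estimator.

```python
def brier_score(probs, labels):
    """Mean squared difference between predicted probability and 0/1 label."""
    return sum((p - y) ** 2 for p, y in zip(probs, labels)) / len(labels)


def auroc(probs, labels):
    """Chance that a random positive outscores a random negative (ties count 0.5)."""
    pos = [p for p, y in zip(probs, labels) if y == 1]
    neg = [p for p, y in zip(probs, labels) if y == 0]
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))


def average_precision(probs, labels):
    """Mean precision at the rank of each positive (ties keep input order)."""
    ranked = sorted(zip(probs, labels), key=lambda t: -t[0])
    hits, total = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank
    return total / hits


probs = [0.9, 0.8, 0.7, 0.6]
labels = [1, 0, 1, 0]
print(round(brier_score(probs, labels), 4))
print(round(auroc(probs, labels), 4))
print(round(average_precision(probs, labels), 4))
```

All three functions assume that both classes are present in `labels`. The fourth objective, robustness degradation, is defined in Step 8.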

---

## Step 1: Setup and Imports

First, ensure all dependencies are installed and import required modules.

In [None]:
# Install dependencies (if needed)
# !pip install -r requirements.txt

import os
import sys
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from pathlib import Path

# Check CUDA availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

---

## Step 2: Configuration

Configure dataset paths and settings. **Adjust these to match your data!**
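
A fail-fast check can catch a wrong path before any data loading starts. The helper below is a hypothetical sketch and is not part of `config`; it only assumes that each dataset path should be an existing directory.

```python
from pathlib import Path


def require_dir(path, name):
    """Return `path` as a Path, or raise if it is not an existing directory."""
    p = Path(path).expanduser()
    if not p.is_dir():
        raise FileNotFoundError(f"{name} directory not found: {p}")
    return p
```

It could be called as `require_dir(config.VINDR_MAMMO_PATH, "VinDr-Mammo")` right after the paths are set.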

In [None]:
import config

# Set your dataset paths
config.VINDR_MAMMO_PATH = "/path/to/vindr_mammo"  # TODO: Update this!
config.INBREAST_PATH = "/path/to/inbreast"        # TODO: Update this!

# Verify paths exist
print(f"VinDr-Mammo path: {config.VINDR_MAMMO_PATH}")
print(f"  Exists: {os.path.exists(config.VINDR_MAMMO_PATH)}")
print(f"\nINbreast path: {config.INBREAST_PATH}")
print(f"  Exists: {os.path.exists(config.INBREAST_PATH)}")

# Display current configuration
print("\n=== VinDr-Mammo Configuration ===")
for key, value in config.VINDR_CONFIG.items():
    print(f"  {key}: {value}")

print("\n=== INbreast Configuration ===")
for key, value in config.INBREAST_CONFIG.items():
    print(f"  {key}: {value}")

---

## Step 3: Load and Parse VinDr-Mammo Dataset

Use the dataset-specific parser to load VinDr-Mammo metadata.

In [None]:
from data.parsers import parse_dataset
from optimization.nsga3_runner import load_metadata

# Load VinDr-Mammo metadata
print("Loading VinDr-Mammo dataset...")
vindr_metadata = load_metadata(
    dataset_name="vindr",
    dataset_path=config.VINDR_MAMMO_PATH,
    dataset_config=config.VINDR_CONFIG
)

print(f"\nLoaded {len(vindr_metadata)} images")
print(f"Unique patients: {vindr_metadata['patient_id'].nunique()}")
print(f"Unique breasts: {vindr_metadata['breast_id'].nunique()}")

# Display sample
print("\nSample metadata:")
vindr_metadata.head()

In [None]:
# Check label distribution
label_counts = vindr_metadata['label'].value_counts()
print("\n=== Label Distribution ===")
print(f"Benign (0): {label_counts.get(0, 0)} ({label_counts.get(0, 0)/len(vindr_metadata)*100:.1f}%)")
print(f"Malignant (1): {label_counts.get(1, 0)} ({label_counts.get(1, 0)/len(vindr_metadata)*100:.1f}%)")

# Visualize distribution
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
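# Reindex so bars appear in label order (value_counts sorts by frequency) and match the tick names
label_counts.reindex([0, 1], fill_value=0).plot(kind='bar')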
plt.title('Label Distribution')
plt.xlabel('Label')
plt.ylabel('Count')
plt.xticks([0, 1], ['Benign (0)', 'Malignant (1)'], rotation=0)

plt.subplot(1, 2, 2)
view_counts = vindr_metadata['view'].value_counts()
view_counts.plot(kind='bar')
plt.title('View Distribution')
plt.xlabel('View')
plt.ylabel('Count')
plt.xticks(rotation=0)

plt.tight_layout()
plt.show()

---

## Step 4: Create Train/Validation Split

Split by patient (80/20) so that no patient's images appear in both sets, which prevents data leakage.
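
The idea behind a patient-wise split can be sketched in plain Python. `patient_wise_split` and the record layout below are hypothetical and only illustrate the principle; `create_train_val_split` may differ, for example by stratifying on the label.

```python
import random


def patient_wise_split(records, train_ratio=0.8, seed=42):
    """Split records so that every patient lands entirely in train or in val."""
    patients = sorted({r["patient_id"] for r in records})
    rng = random.Random(seed)
    rng.shuffle(patients)  # shuffle patients, not images
    n_train = int(round(train_ratio * len(patients)))
    train_ids = set(patients[:n_train])
    train = [r for r in records if r["patient_id"] in train_ids]
    val = [r for r in records if r["patient_id"] not in train_ids]
    return train, val


# 10 synthetic patients with 4 images each
records = [{"patient_id": f"P{i // 4}", "image_id": i} for i in range(40)]
train, val = patient_wise_split(records)
print(len(train), len(val))
```

Because the ratio applies to patients rather than images, the image-level proportions can deviate from 80/20 when patients have different numbers of images.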

In [None]:
from data.dataset import create_train_val_split
from utils.seed import set_all_seeds

# Set random seed for reproducibility
set_all_seeds(config.RANDOM_SEED)

# Create patient-wise split
train_metadata, val_metadata = create_train_val_split(
    vindr_metadata,
    train_ratio=config.TRAIN_VAL_SPLIT,
    random_seed=config.RANDOM_SEED
)

print("=== Train/Validation Split ===")
print(f"\nTrain set:")
print(f"  Images: {len(train_metadata)}")
print(f"  Patients: {train_metadata['patient_id'].nunique()}")
print(f"  Breasts: {train_metadata['breast_id'].nunique()}")
print(f"  Label 0: {(train_metadata['label'] == 0).sum()}")
print(f"  Label 1: {(train_metadata['label'] == 1).sum()}")

print(f"\nValidation set:")
print(f"  Images: {len(val_metadata)}")
print(f"  Patients: {val_metadata['patient_id'].nunique()}")
print(f"  Breasts: {val_metadata['breast_id'].nunique()}")
print(f"  Label 0: {(val_metadata['label'] == 0).sum()}")
print(f"  Label 1: {(val_metadata['label'] == 1).sum()}")

# Verify no patient overlap
train_patients = set(train_metadata['patient_id'])
val_patients = set(val_metadata['patient_id'])
overlap = train_patients & val_patients
print(f"\nPatient overlap: {len(overlap)} (should be 0)")
assert len(overlap) == 0, "ERROR: Patient overlap detected!"

---

## Step 5: Load and Inspect Sample Images

Verify that images can be loaded correctly.

In [None]:
from PIL import Image
from pathlib import Path

# Load a few sample images
image_dir = Path(config.VINDR_MAMMO_PATH) / config.VINDR_CONFIG["image_dir"]

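# Fix random_state so the same sample images are shown on every run
sample_images = train_metadata.sample(min(4, len(train_metadata)), random_state=config.RANDOM_SEED)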

fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()

for idx, (_, row) in enumerate(sample_images.iterrows()):
    img_path = image_dir / row['image_path']
    
    if img_path.exists():
        img = Image.open(img_path)
        axes[idx].imshow(img, cmap='gray')
        axes[idx].set_title(
            f"Patient: {row['patient_id']}\n"
            f"View: {row['view']}, Label: {row['label']} "
            f"({'Benign' if row['label'] == 0 else 'Malignant'})"
        )
        axes[idx].axis('off')
    else:
        axes[idx].text(0.5, 0.5, f"Image not found:\n{img_path}", 
                      ha='center', va='center')
        axes[idx].axis('off')

plt.tight_layout()
plt.show()

---

## Step 6: Create Model and DataLoaders

Set up a ResNet-50 model with partial fine-tuning, then build the data loaders.
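
How `unfreeze_fraction` maps onto the network is not shown in this notebook. One plausible reading, sketched below, is to unfreeze the trailing fraction of the backbone's parameterized top-level modules, counted from the output end. The module names follow torchvision's ResNet-50; `blocks_to_unfreeze` is a hypothetical helper, and the actual `ResNet50WithPartialFineTuning` may count layers or parameters differently.

```python
import math

# Parameterized top-level modules of torchvision's ResNet-50, in forward order.
RESNET50_BLOCKS = ["conv1", "bn1", "layer1", "layer2", "layer3", "layer4", "fc"]


def blocks_to_unfreeze(fraction, blocks=RESNET50_BLOCKS):
    """Return the trailing blocks covering `fraction` of the list (at least the head)."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be in [0, 1]")
    n = max(1, math.ceil(fraction * len(blocks)))
    return blocks[-n:]


print(blocks_to_unfreeze(0.3))
```

In PyTorch, one would then set `requires_grad = True` only for parameters whose names start with one of the returned prefixes, and `False` for the rest.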

In [None]:
from models import ResNet50WithPartialFineTuning
from data.dataset import create_dataloaders

# Example hyperparameters
hparams = {
    "learning_rate": 0.001,
    "weight_decay": 0.0001,
    "dropout_rate": 0.2,
    "augmentation_strength": 0.5,
    "unfreeze_fraction": 0.3,
}

# Create model
model = ResNet50WithPartialFineTuning(
    unfreeze_fraction=hparams["unfreeze_fraction"],
    dropout_rate=hparams["dropout_rate"],
    pretrained=True,
)

trainable_params = model.get_trainable_params()
frozen_params = model.get_frozen_params()
total_params = trainable_params + frozen_params

print("=== Model Configuration ===")
print(f"Architecture: ResNet-50")
print(f"Pretrained: ImageNet")
print(f"Unfreeze fraction: {hparams['unfreeze_fraction']:.2f}")
print(f"Dropout rate: {hparams['dropout_rate']:.2f}")
print(f"\nParameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")
print(f"  Frozen: {frozen_params:,} ({frozen_params/total_params*100:.1f}%)")

In [None]:
# Create dataloaders
train_loader, val_loader = create_dataloaders(
    train_metadata=train_metadata,
    val_metadata=val_metadata,
    image_dir=str(image_dir),
    batch_size=config.BATCH_SIZE,
    augmentation_strength=hparams["augmentation_strength"],
    num_workers=2,
)

print(f"\n=== DataLoaders ===")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Batch size: {config.BATCH_SIZE}")
print(f"Augmentation strength: {hparams['augmentation_strength']:.2f}")

---

## Step 7: Train a Single Model (Demo)

Train one model to demonstrate the pipeline. For the full optimization, run `optimization/nsga3_runner.py` (see Step 11).

In [None]:
from training.trainer import Trainer

# Create trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    val_metadata=val_metadata,
    learning_rate=hparams["learning_rate"],
    weight_decay=hparams["weight_decay"],
    max_epochs=5,  # Use 5 epochs for demo (change to 100 for real training)
    device=device,
    checkpoint_dir="./demo_checkpoints",
)

print("=== Training Configuration ===")
print(f"Max epochs: 5 (demo)")
print(f"Learning rate: {hparams['learning_rate']:.6f}")
print(f"Weight decay: {hparams['weight_decay']:.6f}")
print(f"Device: {device}")
print(f"\nStarting training...\n")

In [None]:
# Train the model
best_metrics = trainer.train()

print("\n=== Training Complete ===")
print(f"Best validation metrics:")
print(f"  PR-AUC: {best_metrics['pr_auc']:.4f}")
print(f"  AUROC: {best_metrics['auroc']:.4f}")
print(f"  Brier: {best_metrics['brier']:.4f}")

In [None]:
# Plot training history
plt.figure(figsize=(15, 4))

plt.subplot(1, 3, 1)
plt.plot(trainer.history['train_loss'])
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 2)
plt.plot(trainer.history['val_pr_auc'], label='PR-AUC')
plt.plot(trainer.history['val_auroc'], label='AUROC')
plt.title('Validation Metrics')
plt.xlabel('Epoch')
plt.ylabel('Score')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 3)
plt.plot(trainer.history['val_brier'])
plt.title('Validation Brier Score')
plt.xlabel('Epoch')
plt.ylabel('Brier Score')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## Step 8: Evaluate Robustness

Measure performance degradation under intensity perturbations.
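
The exact perturbations applied by `RobustnessEvaluator` are not shown here. As an illustration only, the sketch below assumes a gamma curve plus a brightness shift on intensities normalized to [0, 1], and defines degradation as the clean score minus the perturbed score. All names and parameter values are hypothetical.

```python
def perturb_intensity(pixels, gamma=1.2, brightness=0.05):
    """Apply a gamma curve and a brightness shift, clipping to [0, 1]."""
    return [min(1.0, max(0.0, p ** gamma + brightness)) for p in pixels]


def degradation(score_clean, score_perturbed):
    """Positive when the perturbation hurts, negative when the score improves."""
    return score_clean - score_perturbed


print(perturb_intensity([0.0, 0.5, 1.0]))
print(round(degradation(0.80, 0.74), 4))
```

In the pipeline, the two scores would be the PR-AUC on the unmodified validation images and on their perturbed copies.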

In [None]:
from training.robustness import RobustnessEvaluator

# Evaluate robustness
robustness_eval = RobustnessEvaluator(
    model=model,
    val_loader=val_loader,
    val_metadata=val_metadata,
    device=device,
)

robustness_degradation = robustness_eval.evaluate()

print("=== Robustness Evaluation ===")
print(f"Robustness degradation: {robustness_degradation:.4f}")
print(f"\nInterpretation:")
print(f"  - Lower is better (more robust)")
print(f"  - Degradation = PR-AUC_standard - PR-AUC_perturbed")
print(f"  - Negative value means model improved under perturbation")

---

## Step 9: Compute All Objectives

Calculate the 4 optimization objectives.

In [None]:
# Summary of all objectives
objectives = {
    "PR-AUC": best_metrics['pr_auc'],
    "AUROC": best_metrics['auroc'],
    "Brier Score": best_metrics['brier'],
    "Robustness Degradation": robustness_degradation,
}

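# Match names explicitly: the substring test "AUC" in "AUROC" is False
maximize = {"PR-AUC", "AUROC"}

print("=== All Objectives ===")
for name, value in objectives.items():
    direction = "↑ (maximize)" if name in maximize else "↓ (minimize)"
    print(f"{name:25s}: {value:7.4f} {direction}")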

# For NSGA-III (all minimization)
nsga3_objectives = [
    -objectives["PR-AUC"],
    -objectives["AUROC"],
    objectives["Brier Score"],
    objectives["Robustness Degradation"],
]

print("\n=== Converted for NSGA-III (all minimization) ===")
print(f"Objective vector: {nsga3_objectives}")

---

## Step 10: Test Noisy OR Aggregation

Verify breast-level aggregation from image-level predictions.
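
Noisy OR treats each view as an independent chance to detect malignancy: the breast is negative only if every view is negative. The sketch below generalizes the two-view formula to any number of views; `noisy_or` is an illustrative name, separate from the repository's `noisy_or_aggregation`.

```python
from math import prod


def noisy_or(probs):
    """P(at least one view is positive), treating views as independent."""
    return 1.0 - prod(1.0 - p for p in probs)


print(round(noisy_or([0.3, 0.4]), 3))  # 0.58
print(round(noisy_or([0.3]), 3))       # 0.3
```

With a single view the result equals that view's probability, and adding views can only raise the breast-level probability.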

In [None]:
from utils.noisy_or import noisy_or_aggregation, aggregate_to_breast_level

# Example: Manual Noisy OR
p_cc = 0.3
p_mlo = 0.4
p_breast = noisy_or_aggregation(p_cc, p_mlo)

print("=== Noisy OR Example ===")
print(f"CC view probability: {p_cc:.3f}")
print(f"MLO view probability: {p_mlo:.3f}")
print(f"Breast-level probability: {p_breast:.3f}")
print(f"\nFormula: 1 - (1 - {p_cc}) * (1 - {p_mlo}) = {p_breast:.3f}")

# Test with validation set
model.eval()
image_predictions = {}

with torch.no_grad():
    for images, labels, image_ids in val_loader:
        images = images.to(device)
        # Flatten to one score per image, so (B, 1) and (B,) outputs both work
        preds = model(images).reshape(-1).cpu().numpy()
        for img_id, pred in zip(image_ids, preds):
            image_predictions[img_id] = float(pred)

breast_preds, breast_labels = aggregate_to_breast_level(
    image_predictions, val_metadata
)

print(f"\n=== Aggregation Results ===")
print(f"Image-level predictions: {len(image_predictions)}")
print(f"Breast-level predictions: {len(breast_preds)}")
print(f"Reduction factor: {len(image_predictions) / len(breast_preds):.1f}x")

# Show sample breast with both views
sample_breast = val_metadata.groupby('breast_id').filter(lambda x: len(x) == 2).groupby('breast_id').first()
if len(sample_breast) > 0:
    breast_id = sample_breast.index[0]
    breast_images = val_metadata[val_metadata['breast_id'] == breast_id]
    
    print(f"\nSample breast: {breast_id}")
    for _, row in breast_images.iterrows():
        pred = image_predictions.get(row['image_id'], 0.0)
        print(f"  {row['view']:3s} view: p = {pred:.3f}")
    
    # Aggregate this breast's two views directly, so the result does not depend
    # on the ordering of breast_preds returned by aggregate_to_breast_level
    view_preds = [image_predictions.get(i, 0.0) for i in breast_images['image_id']]
    p_sample = noisy_or_aggregation(view_preds[0], view_preds[1])
    print(f"  Breast-level (Noisy OR): p = {p_sample:.3f}")

---

## Step 11: Run NSGA-III Optimization (Full Pipeline)

For full multi-objective optimization, run the NSGA-III script from the command line.

**Note:** This is computationally expensive! Each evaluation trains a full CNN.

```bash
# From terminal:
python optimization/nsga3_runner.py
```

Configuration in `config.py`:
```python
NSGA3_CONFIG = {
    "pop_size": 24,        # Population size
    "n_generations": 50,   # Number of generations
    "n_objectives": 4,     # 4 objectives
}
```

**Total evaluations:** 24 × 50 = 1,200 (adjust for your compute budget)
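
When adjusting `pop_size`, it can help to know how many reference directions NSGA-III would work with. Many implementations generate them with the Das-Dennis construction; whether this repository does so is not stated here, so the sketch below is an assumption. It only counts the points for a given number of objectives and partitions.

```python
from math import comb


def das_dennis_count(n_objectives, n_partitions):
    """Number of Das-Dennis reference points on the unit simplex."""
    return comb(n_objectives + n_partitions - 1, n_partitions)


for p in range(1, 6):
    print(p, das_dennis_count(4, p))
```

For 4 objectives, 3 partitions give 20 directions and 4 partitions give 35, so a population of 24 sits between those two settings.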

In [None]:
# Display optimization configuration
print("=== NSGA-III Configuration ===")
print(f"Population size: {config.NSGA3_CONFIG['pop_size']}")
print(f"Generations: {config.NSGA3_CONFIG['n_generations']}")
print(f"Objectives: {config.NSGA3_CONFIG['n_objectives']}")
print(f"\nTotal evaluations: {config.NSGA3_CONFIG['pop_size'] * config.NSGA3_CONFIG['n_generations']}")
print(f"\nEstimated time (assuming 1 hour per evaluation):")
print(f"  {config.NSGA3_CONFIG['pop_size'] * config.NSGA3_CONFIG['n_generations']} hours")
print(f"  = {config.NSGA3_CONFIG['pop_size'] * config.NSGA3_CONFIG['n_generations'] / 24:.1f} days")
print(f"\n⚠️  Consider reducing pop_size and n_generations for testing!")

---

## Step 12: Analyze Pareto Front (After Optimization)

After NSGA-III completes, analyze the Pareto front.
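
A solution belongs to the Pareto front if no other solution is at least as good on every objective and strictly better on one. The following sketch filters a list of objective vectors in minimization form; the function names and the sample values are made up for illustration.

```python
def dominates(a, b):
    """True if `a` is no worse than `b` everywhere and strictly better somewhere."""
    return all(x <= y for x, y in zip(a, b)) and any(x < y for x, y in zip(a, b))


def pareto_front(points):
    """Return the points that no other point dominates."""
    return [p for p in points if not any(dominates(q, p) for q in points if q is not p)]


# Minimization form: (-PR-AUC, -AUROC, Brier, robustness degradation)
points = [
    (-0.80, -0.90, 0.10, 0.05),
    (-0.78, -0.88, 0.12, 0.06),  # worse than the first on every objective
    (-0.75, -0.91, 0.09, 0.02),
]
print(pareto_front(points))
```

The second vector is dropped, while the first and third remain because each wins on at least one objective.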

```bash
# From terminal:
python optimization/analyze_pareto.py --results_dir ./optimization_results
```

In [None]:
# Example: Load and analyze Pareto results (if available)
import glob

results_dir = "./optimization_results"

# Find most recent results
csv_files = glob.glob(f"{results_dir}/pareto_solutions_*.csv")

if csv_files:
    # Lexicographic sort: picks the newest file only if the names embed a sortable timestamp
    csv_files.sort(reverse=True)
    latest_results = csv_files[0]
    
    print(f"Loading Pareto solutions from: {latest_results}")
    pareto_df = pd.read_csv(latest_results)
    
    print(f"\nNumber of Pareto solutions: {len(pareto_df)}")
    
    # Summary statistics
    print("\n=== Pareto Front Summary ===")
    print(pareto_df[['pr_auc', 'auroc', 'brier', 'robustness_degradation']].describe())
    
    # Find extreme solutions
    print("\n=== Extreme Solutions ===")
    print(f"\nBest PR-AUC: {pareto_df['pr_auc'].max():.4f}")
    print(f"Best AUROC: {pareto_df['auroc'].max():.4f}")
    print(f"Best Brier: {pareto_df['brier'].min():.4f}")
    print(f"Best Robustness: {pareto_df['robustness_degradation'].min():.4f}")
    
    # Plot Pareto front (2D projections)
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    
    pairs = [
        ('pr_auc', 'auroc'),
        ('pr_auc', 'brier'),
        ('pr_auc', 'robustness_degradation'),
        ('auroc', 'brier'),
        ('auroc', 'robustness_degradation'),
        ('brier', 'robustness_degradation'),
    ]
    
    for idx, (obj1, obj2) in enumerate(pairs):
        axes[idx].scatter(pareto_df[obj1], pareto_df[obj2], alpha=0.6)
        axes[idx].set_xlabel(obj1.replace('_', ' ').title())
        axes[idx].set_ylabel(obj2.replace('_', ' ').title())
        axes[idx].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
else:
    print("No Pareto results found.")
    print("Run 'python optimization/nsga3_runner.py' first.")

---

## Step 13: Evaluate on INbreast (Zero-Shot Transfer)

Load INbreast and evaluate the trained model without any fine-tuning.
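
Zero-shot evaluation also means the decision threshold is fixed on source validation data and never tuned on INbreast. The criterion used by `find_optimal_threshold` is not shown in this notebook; the sketch below assumes Youden's J (sensitivity + specificity - 1) purely for illustration, with hypothetical function names and toy data.

```python
def sens_spec(probs, labels, threshold):
    """Sensitivity and specificity when predicting positive for p >= threshold."""
    tp = sum(1 for p, y in zip(probs, labels) if y == 1 and p >= threshold)
    fn = sum(1 for p, y in zip(probs, labels) if y == 1 and p < threshold)
    tn = sum(1 for p, y in zip(probs, labels) if y == 0 and p < threshold)
    fp = sum(1 for p, y in zip(probs, labels) if y == 0 and p >= threshold)
    return tp / (tp + fn), tn / (tn + fp)


def youden_threshold(probs, labels):
    """Observed score that maximizes sensitivity + specificity - 1."""
    return max(sorted(set(probs)), key=lambda t: sum(sens_spec(probs, labels, t)) - 1)


# Choose the threshold on source data, then apply it unchanged to target data.
src_probs, src_labels = [0.1, 0.3, 0.6, 0.8], [0, 0, 1, 1]
tgt_probs, tgt_labels = [0.2, 0.7, 0.5, 0.9], [0, 0, 1, 1]
t = youden_threshold(src_probs, src_labels)
print(t, sens_spec(tgt_probs, tgt_labels, t))  # 0.6 (0.5, 0.5)
```

In this toy case the threshold separates the source data perfectly but yields only 0.5 sensitivity and 0.5 specificity on the target, which is the kind of gap the source-versus-target comparison in this step is meant to expose.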

In [None]:
# Load INbreast dataset
print("Loading INbreast dataset...")
inbreast_metadata = load_metadata(
    dataset_name="inbreast",
    dataset_path=config.INBREAST_PATH,
    dataset_config=config.INBREAST_CONFIG
)

print(f"\nLoaded {len(inbreast_metadata)} images")
print(f"Unique patients: {inbreast_metadata['patient_id'].nunique()}")
print(f"Unique breasts: {inbreast_metadata['breast_id'].nunique()}")

# Display sample
print("\nSample metadata:")
inbreast_metadata.head()

In [None]:
# Check label distribution
label_counts = inbreast_metadata['label'].value_counts()
print("\n=== INbreast Label Distribution ===")
print(f"Benign (0): {label_counts.get(0, 0)} ({label_counts.get(0, 0)/len(inbreast_metadata)*100:.1f}%)")
print(f"Malignant (1): {label_counts.get(1, 0)} ({label_counts.get(1, 0)/len(inbreast_metadata)*100:.1f}%)")

# Check BI-RADS distribution (including subcategories)
print("\n=== BI-RADS Distribution (with subcategories) ===")
birads_counts = inbreast_metadata['birads_original'].value_counts()
for birads, count in birads_counts.items():
    label = inbreast_metadata[inbreast_metadata['birads_original'] == birads]['label'].iloc[0]
    print(f"  {birads}: {count} images → Label {label}")

In [None]:
from evaluation.evaluate_target import evaluate_target_zero_shot
from training.metrics import find_optimal_threshold

# Find optimal threshold on source validation set
threshold = find_optimal_threshold(breast_preds, breast_labels)
print(f"Optimal threshold (from source): {threshold:.4f}")

# Get INbreast image directory
inbreast_image_dir = str(Path(config.INBREAST_PATH) / config.INBREAST_CONFIG["image_dir"])

# Zero-shot evaluation
print("\n=== Zero-Shot Evaluation on INbreast ===")
print("NOTE: No fine-tuning, no threshold tuning - pure transfer!\n")

target_metrics = evaluate_target_zero_shot(
    model=model,
    target_metadata=inbreast_metadata,
    image_dir=inbreast_image_dir,
    threshold=threshold,
    device=device,
)

print("\n=== INbreast Results ===")
print(f"PR-AUC: {target_metrics['pr_auc']:.4f}")
print(f"AUROC: {target_metrics['auroc']:.4f}")
print(f"Brier Score: {target_metrics['brier']:.4f}")
print(f"Threshold (transferred): {threshold:.4f}")
print(f"Sensitivity: {target_metrics['sensitivity']:.4f}")
print(f"Specificity: {target_metrics['specificity']:.4f}")

In [None]:
# Compare source and target performance
comparison = pd.DataFrame({
    'Metric': ['PR-AUC', 'AUROC', 'Brier Score'],
    'VinDr-Mammo (Source)': [best_metrics['pr_auc'], best_metrics['auroc'], best_metrics['brier']],
    'INbreast (Target)': [target_metrics['pr_auc'], target_metrics['auroc'], target_metrics['brier']],
})

comparison['Difference'] = comparison['INbreast (Target)'] - comparison['VinDr-Mammo (Source)']

print("\n=== Source vs Target Comparison ===")
print(comparison.to_string(index=False))

# Visualize comparison
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(comparison))
width = 0.35

ax.bar(x - width/2, comparison['VinDr-Mammo (Source)'], width, label='VinDr-Mammo (Source)', alpha=0.8)
ax.bar(x + width/2, comparison['INbreast (Target)'], width, label='INbreast (Target)', alpha=0.8)

ax.set_xlabel('Metric')
ax.set_ylabel('Score')
ax.set_title('Source vs Target Performance (Zero-Shot Transfer)')
ax.set_xticks(x)
ax.set_xticklabels(comparison['Metric'])
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---

## Summary

This notebook demonstrated:

1. ✅ **Dataset loading** with dataset-specific parsers (VinDr-Mammo, INbreast)
2. ✅ **BI-RADS mapping** including subcategories (4A, 4B, 4C)
3. ✅ **Patient-wise splitting** (no data leakage)
4. ✅ **Model creation** (ResNet-50 with partial fine-tuning)
5. ✅ **Training** with early stopping
6. ✅ **Robustness evaluation** under perturbations
7. ✅ **Noisy OR aggregation** for breast-level predictions
8. ✅ **4 objectives** (PR-AUC, AUROC, Brier, Robustness)
9. ✅ **Zero-shot transfer** to INbreast

### Next Steps

For full multi-objective optimization:

```bash
# Run NSGA-III
python optimization/nsga3_runner.py

# Analyze Pareto front
python optimization/analyze_pareto.py

# Evaluate selected solutions
python evaluation/evaluate_source.py --checkpoint path/to/checkpoint.pt --hyperparameters config.json
python evaluation/evaluate_target.py --checkpoint path/to/checkpoint.pt --threshold 0.45 --hyperparameters config.json
```

### Documentation

- **[README.md](README.md)** - Main documentation
- **[DATASET_SETUP_GUIDE.md](DATASET_SETUP_GUIDE.md)** - Dataset preparation
- **[IMPLEMENTATION_NOTES.md](IMPLEMENTATION_NOTES.md)** - Technical details
- **[TEST_RESULTS.md](TEST_RESULTS.md)** - Test documentation

### Testing

```bash
python test_correctness.py  # 79 tests
python test_parsers.py      # 34 tests
python test_integration.py  # 1 test
```