# Hierarchical 3D Medical Image Classification

This notebook trains a hierarchical classification model on **five 3D MedMNIST datasets merged into a single dataset**:
- **OrganMNIST3D** (11 classes, multi-region)
- **NoduleMNIST3D** (2 classes, chest)
- **AdrenalMNIST3D** (2 classes, abdomen)
- **FractureMNIST3D** (3 classes, chest)
- **VesselMNIST3D** (2 classes, brain)

## Training Pipeline
1. **Stage 1 (Coarse)**: Classify the anatomical region (abdomen, chest, brain, or the multi-region OrganMNIST3D group)
2. **Stage 2 (Fine)**: Classify pathology within each region

In [1]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

# Make the project root (one directory up) importable for the local modules below.
sys.path.insert(0, '..')

from config import (
    DEVICE, DATA_CONFIG, MODEL_CONFIG, TRAINING_CONFIG,
    PATHS, set_seed, DEFAULT_MERGED_DATASETS
)
from utils.data_loader import create_hierarchical_dataset
from utils.hierarchical_model import HierarchicalClassificationModel
from utils.trainer import HierarchicalTrainer
from utils.visualization import plot_training_history

# Single seed for the run, handed to the project's set_seed helper.
SEED = 42
set_seed(SEED)

print(f"Device: {DEVICE}")

GPU detected: NVIDIA GeForce RTX 3070
GPU memory: 8.21 GB
Platform: NVIDIA CUDA
Device: cuda


## 1. Load Merged Dataset

In [2]:
# Load all 5 merged datasets into train/val/test loaders plus a metadata dict.
print(f"Loading datasets: {DEFAULT_MERGED_DATASETS}")

loader_kwargs = dict(
    datasets_to_include=DEFAULT_MERGED_DATASETS,
    batch_size=DATA_CONFIG['batch_size'],
    num_workers=DATA_CONFIG['num_workers'],
)
train_loader, val_loader, test_loader, dataset_info = create_hierarchical_dataset(**loader_kwargs)

rule = "=" * 60
print("\n" + rule)
print("MERGED DATASET INFO")
print(rule)
print(f"Datasets: {dataset_info['datasets_included']}")
# Per-split sample counts, printed with thousands separators.
for split in ("train", "val", "test"):
    n_samples = dataset_info[split + "_samples"]
    print(f"{split.capitalize()} samples: {n_samples:,}")
print(f"\nCoarse classes (regions): {dataset_info['num_coarse_classes']}")
print(f"Region mapping: {dataset_info['idx_to_region']}")
print(f"Fine classes: {dataset_info['num_fine_classes']}")

Loading datasets: ['organ', 'nodule', 'adrenal', 'fracture', 'vessel']
Using downloaded and verified file: /home/luca/.medmnist/organmnist3d.npz
Using downloaded and verified file: /home/luca/.medmnist/nodulemnist3d.npz
Using downloaded and verified file: /home/luca/.medmnist/adrenalmnist3d.npz


KeyboardInterrupt: 

In [None]:
# Verify data format on the first training batch.
# Unlike a `for ... break` loop, this fails loudly on an empty loader instead
# of claiming success without having inspected anything.
print("\nVerifying data format...")
first_batch = next(iter(train_loader), None)
assert first_batch is not None, "train_loader yielded no batches"
imgs, coarse_labels, fine_labels = first_batch

# Images and both label tensors must describe the same samples.
n_in_batch = imgs.size(0)
assert coarse_labels.shape[0] == n_in_batch and fine_labels.shape[0] == n_in_batch, (
    f"batch size mismatch: imgs={n_in_batch}, coarse={coarse_labels.shape[0]}, "
    f"fine={fine_labels.shape[0]}"
)

print(f"  Image shape: {imgs.shape}")
print(f"  Coarse labels: {coarse_labels.shape} - unique: {coarse_labels.unique().tolist()}")
fine_unique = fine_labels.unique().tolist()
# Only signal truncation when more than 10 unique values were actually cut off.
fine_suffix = "..." if len(fine_unique) > 10 else ""
print(f"  Fine labels: {fine_labels.shape} - unique: {fine_unique[:10]}{fine_suffix}")
print(" Data format verified!")

## 2. Build Hierarchical Model

The model consists of:
- **Coarse Classifier**: Shared 3D CNN backbone + region classification head
- **Fine Classifiers**: Separate classifier for each region

In [None]:
# Configure region-specific classes
# Each region has a different number of fine-grained classes
region_configs = {
    'abdomen': 2,   # AdrenalMNIST3D classes
    'brain': 2,     # VesselMNIST3D classes
    'chest': 5,     # NoduleMNIST3D (2) + FractureMNIST3D (3)
    'multi': 11,    # OrganMNIST3D classes
}

# Get region index mapping from dataset
region_idx_to_name = dataset_info['idx_to_region']

# Fail fast: every region name the dataset can emit needs a fine-head entry
# above, because evaluation calls the fine classifier by region name.
missing_regions = set(region_idx_to_name.values()) - set(region_configs)
if missing_regions:
    raise ValueError(
        f"region_configs has no entry for dataset regions: {sorted(missing_regions)}"
    )

print("Region configurations:")
for region, num_classes in region_configs.items():
    print(f"  {region}: {num_classes} classes")

In [None]:
# Build the coarse-to-fine model from MODEL_CONFIG plus the dataset's region /
# label metadata, then move it to the selected device.
model_kwargs = dict(
    region_configs=region_configs,
    architecture=MODEL_CONFIG['architecture'],
    coarse_model_type=MODEL_CONFIG['coarse_architecture'],
    fine_model_type=MODEL_CONFIG['fine_architecture'],
    dropout_rate=MODEL_CONFIG['dropout_rate'],
    region_idx_to_name=region_idx_to_name,
    num_total_organs=dataset_info['num_fine_classes'],
    use_subtypes=False,
)
model = HierarchicalClassificationModel(**model_kwargs).to(DEVICE)

# Tally all parameters and the subset with requires_grad set, in one pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elems = param.numel()
    total_params += n_elems
    if param.requires_grad:
        trainable_params += n_elems

print("\nModel created:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Coarse architecture: {MODEL_CONFIG['coarse_architecture']}")
print(f"  Fine architecture: {MODEL_CONFIG['fine_architecture']}")

## 3. Train Hierarchical Model

Training proceeds in two stages:
1. **Stage 1**: Train coarse classifier to predict anatomical regions
2. **Stage 2**: Freeze coarse classifier, train region-specific fine classifiers

In [None]:
# Relative weights handed to the trainer for the coarse and fine objectives
# (presumably loss-term weights inside HierarchicalTrainer — confirm there).
COARSE_LOSS_WEIGHT = 0.3
FINE_LOSS_WEIGHT = 0.7

trainer = HierarchicalTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=DEVICE,
    coarse_weight=COARSE_LOSS_WEIGHT,
    fine_weight=FINE_LOSS_WEIGHT,
)

# NOTE(review): learning_rate is printed but not passed to the trainer here —
# confirm HierarchicalTrainer reads TRAINING_CONFIG['learning_rate'] itself.
print("Training configuration:")
for label, key in (
    ("Coarse epochs", 'coarse_epochs'),
    ("Fine epochs", 'fine_epochs'),
    ("Learning rate", 'learning_rate'),
):
    print(f"  {label}: {TRAINING_CONFIG[key]}")

In [None]:
# Run both training stages; `history` holds the metrics plotted in section 4.
stage_epochs = {
    'coarse_epochs': TRAINING_CONFIG['coarse_epochs'],
    'fine_epochs': TRAINING_CONFIG['fine_epochs'],
}
history = trainer.train(**stage_epochs)

separator = "=" * 60
print("\n" + separator)
print("TRAINING COMPLETE!")
print(separator)

## 4. Training Visualization

In [None]:
# Plot training history: one row per stage, loss on the left, accuracy on the right.
# Each entry: (history key, (row, col), y-label, title, legend label, colour).
history_panels = [
    ('coarse_train_loss', (0, 0), 'Loss', 'Stage 1: Coarse Classifier Loss', 'Coarse Train Loss', 'blue'),
    ('coarse_train_acc', (0, 1), 'Accuracy', 'Stage 1: Coarse Classifier Accuracy', 'Coarse Train Acc', 'blue'),
    ('fine_train_loss', (1, 0), 'Loss', 'Stage 2: Fine Classifier Loss', 'Fine Train Loss', 'green'),
    ('fine_train_acc', (1, 1), 'Accuracy', 'Stage 2: Fine Classifier Accuracy', 'Fine Train Acc', 'green'),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for key, (row, col), ylabel, title, label, color in history_panels:
    values = history[key]
    if not values:
        continue  # stage produced no metrics: leave its panel empty
    ax = axes[row, col]
    ax.plot(values, label=label, color=color)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# The figures directory may not exist on a fresh checkout; create it before saving.
figure_dir = PATHS['figures']
os.makedirs(figure_dir, exist_ok=True)
fig.savefig(f"{figure_dir}/hierarchical_training_{MODEL_CONFIG['architecture']}.png", dpi=150)
plt.show()

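## 5. Evaluation on Test Set

Stage 2 accuracy below is computed by routing each test sample to the fine classifier of its **ground-truth** region. It therefore does not include errors propagated from a wrong Stage 1 prediction.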

In [None]:
def evaluate_hierarchical_model(model, test_loader, device, region_idx_to_name,
                                route_by_prediction=False):
    """Evaluate a hierarchical (coarse -> fine) model on a labelled data loader.

    Stage 1 scores the coarse head (``model.forward_coarse``) against the region
    labels. Stage 2 sends each sample to one region's fine head
    (``model.forward_fine(imgs, region_name)``) and scores it against the fine
    labels.

    Args:
        model: Object exposing ``eval()``, ``forward_coarse(imgs)`` and
            ``forward_fine(imgs, region_name)``; both forwards return logits
            with classes along dim 1.
        test_loader: Iterable of ``(imgs, coarse_labels, fine_labels)`` batches.
            Label tensors may be shaped ``(N,)`` or ``(N, 1)``; both are
            flattened to ``(N,)``.
        device: Device the batches are moved to.
        region_idx_to_name: Mapping from coarse label index to region name
            (the name passed to ``forward_fine``).
        route_by_prediction: If False (default), samples are routed by their
            ground-truth region, so Stage-2 accuracy excludes Stage-1 errors
            ("oracle" routing). If True, samples are routed by the predicted
            region and a misrouted sample always counts as a fine-level miss,
            giving end-to-end hierarchical accuracy.

    Returns:
        dict with
          - ``coarse_accuracy``: float; NaN if the loader is empty.
          - ``fine_accuracy_per_region``: {region_name: float}; NaN for a
            region that received no samples.
          - ``overall_fine_accuracy``: float over all routed samples; NaN if none.
          - ``coarse_predictions`` / ``coarse_labels``: arrays in loader order.
          - ``fine_predictions`` / ``fine_labels``: arrays grouped per region
            within each batch (not loader order).
          - ``fine_regions``: the region name for each entry of the fine arrays,
            needed to interpret fine indices produced by different heads.
    """
    model.eval()

    region_names = list(region_idx_to_name.values())
    coarse_correct = 0
    coarse_total = 0
    fine_correct = {name: 0 for name in region_names}
    fine_total = {name: 0 for name in region_names}

    all_coarse_preds, all_coarse_labels = [], []
    all_fine_preds, all_fine_labels, all_fine_regions = [], [], []

    with torch.no_grad():
        for imgs, coarse_labels, fine_labels in tqdm(test_loader, desc="Evaluating"):
            imgs = imgs.to(device, dtype=torch.float32)
            # Per-batch heuristic rescale of 0-255 volumes to [0, 1];
            # NOTE(review): should mirror the trainer's preprocessing — confirm.
            if imgs.max() > 1:
                imgs = imgs / 255.0

            # Flatten so (N, 1) labels don't broadcast against (N,) predictions
            # into an (N, N) comparison, and so boolean masks stay 1-D.
            coarse_labels = coarse_labels.long().reshape(-1).to(device)
            fine_labels = fine_labels.long().reshape(-1).to(device)

            # Stage 1: coarse (region) prediction
            coarse_preds = model.forward_coarse(imgs).argmax(1)
            coarse_correct += (coarse_preds == coarse_labels).sum().item()
            coarse_total += imgs.size(0)
            all_coarse_preds.extend(coarse_preds.cpu().numpy())
            all_coarse_labels.extend(coarse_labels.cpu().numpy())

            # Stage 2: fine prediction, one region head at a time
            routing = coarse_preds if route_by_prediction else coarse_labels
            for region_idx, region_name in region_idx_to_name.items():
                mask = routing == region_idx
                if not mask.any():
                    continue

                region_fine_labels = fine_labels[mask]
                fine_preds = model.forward_fine(imgs[mask], region_name).argmax(1)

                # NOTE(review): assumes fine labels are indices in the same
                # space as this region head's outputs — confirm with the loader.
                hits = fine_preds == region_fine_labels
                if route_by_prediction:
                    # A sample sent to the wrong region cannot be a fine hit.
                    hits &= (coarse_labels[mask] == region_idx)

                fine_correct[region_name] += hits.sum().item()
                fine_total[region_name] += int(mask.sum().item())

                all_fine_preds.extend(fine_preds.cpu().numpy())
                all_fine_labels.extend(region_fine_labels.cpu().numpy())
                all_fine_regions.extend([region_name] * int(fine_preds.numel()))

    # Compute metrics; NaN (not 0.0) marks "no samples" so it isn't read as a score.
    nan = float('nan')
    coarse_acc = coarse_correct / coarse_total if coarse_total > 0 else nan

    fine_acc_per_region = {
        name: (fine_correct[name] / fine_total[name]) if fine_total[name] > 0 else nan
        for name in region_names
    }

    n_fine = sum(fine_total.values())
    overall_fine_acc = sum(fine_correct.values()) / n_fine if n_fine > 0 else nan

    return {
        'coarse_accuracy': coarse_acc,
        'fine_accuracy_per_region': fine_acc_per_region,
        'overall_fine_accuracy': overall_fine_acc,
        'coarse_predictions': np.array(all_coarse_preds),
        'coarse_labels': np.array(all_coarse_labels),
        'fine_predictions': np.array(all_fine_preds),
        'fine_labels': np.array(all_fine_labels),
        'fine_regions': np.array(all_fine_regions),
    }

In [None]:
# Score the trained model on the held-out test split and report the metrics.
results = evaluate_hierarchical_model(model, test_loader, DEVICE, region_idx_to_name)

divider = "=" * 60
print("\n" + divider)
print("TEST SET RESULTS")
print(divider)
print(f"\nStage 1 (Coarse) Accuracy: {results['coarse_accuracy']:.4f}")
print("\nStage 2 (Fine) Accuracy per Region:")
per_region = results['fine_accuracy_per_region']
for region_name, region_acc in per_region.items():
    print(f"  {region_name}: {region_acc:.4f}")
print(f"\nOverall Fine Accuracy: {results['overall_fine_accuracy']:.4f}")

## 6. Save Model

In [None]:
os.makedirs(PATHS['models'], exist_ok=True)

# Save model
model_path = f"{PATHS['models']}/hierarchical_{MODEL_CONFIG['architecture']}.pth"

# Store numpy arrays from the evaluation as plain lists so the checkpoint can
# be read back with torch.load(..., weights_only=True), whose restricted
# unpickler rejects numpy arrays unless they are explicitly allowlisted.
# NOTE(review): dataset_info / history may hold other non-primitive objects —
# check them too if weights_only loading is required.
serializable_results = {
    key: value.tolist() if isinstance(value, np.ndarray) else value
    for key, value in results.items()
}

torch.save({
    'model_state_dict': model.state_dict(),
    'region_configs': region_configs,
    'dataset_info': dataset_info,
    'history': history,
    'test_results': serializable_results,
    'config': {
        # Constructor arguments needed (together with region_configs and
        # dataset_info) to rebuild HierarchicalClassificationModel.
        'architecture': MODEL_CONFIG['architecture'],
        'coarse_architecture': MODEL_CONFIG['coarse_architecture'],
        'fine_architecture': MODEL_CONFIG['fine_architecture'],
        'dropout_rate': MODEL_CONFIG['dropout_rate'],
        # Must match the use_subtypes value the model was constructed with.
        'use_subtypes': False,
    }
}, model_path)

print(f"\n Model saved to: {model_path}")

## 7. Summary

In [None]:
# Final recap of data, model and test metrics for this run.
bar = "=" * 60
print("\n" + bar)
print("TRAINING SUMMARY")
print(bar)

summary_lines = [
    f"\nDatasets used: {', '.join(dataset_info['datasets_included'])}",
    f"Total training samples: {dataset_info['train_samples']:,}",
    "\nModel architecture:",
    f"  Coarse: {MODEL_CONFIG['coarse_architecture']}",
    f"  Fine: {MODEL_CONFIG['fine_architecture']}",
    f"  Total parameters: {total_params:,}",
    "\nFinal Results:",
    f"  Coarse accuracy: {results['coarse_accuracy']:.4f}",
    f"  Fine accuracy: {results['overall_fine_accuracy']:.4f}",
    f"\nModel saved to: {model_path}",
]
for line in summary_lines:
    print(line)