# Results and Analysis

## Hierarchical Classification: Architecture Comparison

### Evaluation metrics: **Accuracy, Precision, Recall, F1-Score** (weighted averages), plus **AUC and ROC Curves** for the Stage 1 (coarse) models

In [None]:
import sys
sys.path.insert(0, '..')

import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, classification_report, accuracy_score,
    precision_score, recall_score, f1_score, roc_auc_score, roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm import tqdm

from config import (
    DEVICE, DATA_CONFIG, MODEL_CONFIG, PATHS, set_seed, DEFAULT_MERGED_DATASETS
)
from utils.data_loader import create_hierarchical_dataset, REGION_FINE_CLASS_COUNTS
from utils.hierarchical_model import HierarchicalClassificationModel

# Font sizes applied to every figure exported (as PGF) later in this notebook.
_PGF_FONT_SIZES = {
    "axes.titlesize": 14,
    "axes.labelsize": 12,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "font.size": 10,
}
matplotlib.rcParams.update(_PGF_FONT_SIZES)

# Project helper; called with a fixed value for reproducible runs.
set_seed(42)
print(f"Device: {DEVICE}")

## 1. Load Test Data

In [None]:
# Build all loaders; only `test_loader` and `dataset_info` are used below.
train_loader, val_loader, test_loader, dataset_info = create_hierarchical_dataset(
    datasets_to_include=DEFAULT_MERGED_DATASETS,
    batch_size=DATA_CONFIG['batch_size'],
    num_workers=DATA_CONFIG['num_workers'],
)

# Region metadata shared by every model constructed in later cells.
region_configs = dataset_info['region_num_classes']
region_idx_to_name = dataset_info['idx_to_region']
num_coarse_classes = len(region_idx_to_name)

print(f"Test samples: {dataset_info['test_samples']}")
print(f"Regions: {region_idx_to_name}")

## 2. Evaluate Stage 1 (Coarse) Models

Load models from `stage1_coarse_*.pth` files and evaluate their coarse (region) predictions on the test split.

In [None]:
ARCHITECTURES_STAGE1 = ['efficientnet3d_b0', 'resnet18_3d', 'resnet34_3d', 'densenet121_3d', 'enhanced', 'base']

stage1_results = {}

print("=" * 60)
print("SCANNING FOR STAGE 1 (COARSE) MODELS")
print("=" * 60)

for arch in ARCHITECTURES_STAGE1:
    model_path = f"{PATHS['models']}/stage1_coarse_{arch}.pth"
    if not os.path.exists(model_path):
        continue

    print(f"\nFound: {os.path.basename(model_path)}")

    try:
        # Create model
        model = HierarchicalClassificationModel(
            region_configs=region_configs,
            architecture=arch,
            coarse_model_type=arch,
            fine_model_type=arch,
            dropout_rate=MODEL_CONFIG['dropout_rate'],
            region_idx_to_name=region_idx_to_name,
            num_total_organs=dataset_info['num_fine_classes'],
            use_subtypes=False
        ).to(DEVICE)

        # Load weights. Accept both a bare state dict and a training-style
        # checkpoint dict that wraps it under 'model_state_dict'.
        checkpoint = torch.load(model_path, map_location=DEVICE)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']

        # strict=False is kept (presumably a stage-1 file does not hold every
        # parameter of the full hierarchical model — confirm), but a load in
        # which NO key matched would silently evaluate random weights.
        load_info = model.load_state_dict(checkpoint, strict=False)
        n_matched = len(checkpoint) - len(load_info.unexpected_keys)
        if n_matched == 0:
            raise RuntimeError("no checkpoint keys match the model; refusing to evaluate random weights")
        if load_info.unexpected_keys:
            print(f"  -> Warning: ignored {len(load_info.unexpected_keys)} unexpected checkpoint keys")

        print("  -> Evaluating coarse classifier...")
        model.eval()

        all_preds, all_labels, all_probs = [], [], []

        with torch.no_grad():
            for imgs, coarse_labels, _ in tqdm(test_loader, desc=f"Stage1 {arch}"):
                imgs = imgs.to(DEVICE, dtype=torch.float32)
                # Rescale batches that look like raw 0-255 intensities.
                if imgs.max() > 1:
                    imgs = imgs / 255.0
                coarse_labels = coarse_labels.long().to(DEVICE)

                logits = model.forward_coarse(imgs)
                probs = torch.softmax(logits, dim=1)
                preds = logits.argmax(1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(coarse_labels.cpu().numpy())
                all_probs.append(probs.cpu().numpy())

        y_true = np.array(all_labels)
        y_pred = np.array(all_preds)
        y_probs = np.concatenate(all_probs, axis=0)

        # Compute metrics (weighted by class support)
        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, average='weighted', zero_division=0)
        rec = recall_score(y_true, y_pred, average='weighted', zero_division=0)
        f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

        # AUC: the binary case needs the positive-class score only, because
        # label_binarize() returns a single column for two classes.
        auc_val = None
        try:
            if num_coarse_classes == 2:
                auc_val = roc_auc_score(y_true, y_probs[:, 1])
            else:
                y_true_bin = label_binarize(y_true, classes=list(range(num_coarse_classes)))
                auc_val = roc_auc_score(y_true_bin, y_probs, average='weighted', multi_class='ovr')
        except ValueError as auc_err:
            # e.g. a region with no positive samples in the test split
            print(f"  -> AUC unavailable: {auc_err}")

        stage1_results[arch] = {
            'accuracy': acc, 'precision': prec, 'recall': rec, 'f1': f1, 'auc': auc_val,
            'predictions': y_pred, 'labels': y_true, 'probs': y_probs
        }
        print(f"  -> Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}")

    except Exception as e:
        # Best-effort scan: report and move on to the next architecture.
        print(f"  -> Error: {e}")

print(f"\nStage 1 models evaluated: {list(stage1_results.keys())}")

## 3. Evaluate Stage 2 (Fine) Models

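Load models from `hierarchical_model_*.pth` files, falling back to the legacy `hierarchical_*.pth` name. Fine classifiers are evaluated with ground-truth region routing, so these scores do not include Stage 1 routing errors.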

In [None]:
ARCHITECTURES_STAGE2 = ['efficientnet3d_b0', 'efficientnet3d_b0_bare_metal', 'resnet18_3d', 'resnet34_3d', 'densenet121_3d']

stage2_results = {}

print("\n" + "=" * 60)
print("SCANNING FOR STAGE 2 (HIERARCHICAL) MODELS")
print("=" * 60)


def _stage2_checkpoint_path(arch_name):
    """Return the checkpoint path to try for ``arch_name``.

    The current ``hierarchical_model_<arch>.pth`` name is preferred; if that
    file is missing the legacy ``hierarchical_<arch>.pth`` path is returned,
    which may not exist either (the caller checks).
    """
    current = f"{PATHS['models']}/hierarchical_model_{arch_name}.pth"
    if os.path.exists(current):
        return current
    return f"{PATHS['models']}/hierarchical_{arch_name}.pth"


for arch in ARCHITECTURES_STAGE2:
    ckpt_path = _stage2_checkpoint_path(arch)
    if not os.path.exists(ckpt_path):
        continue

    # '_bare_metal' variants are built with their base architecture's model type.
    backbone = arch.replace('_bare_metal', '')

    print(f"\nFound: {os.path.basename(ckpt_path)}")

    try:
        model = HierarchicalClassificationModel(
            region_configs=region_configs,
            architecture=backbone,
            coarse_model_type=backbone,
            fine_model_type=backbone,
            dropout_rate=MODEL_CONFIG['dropout_rate'],
            region_idx_to_name=region_idx_to_name,
            num_total_organs=dataset_info['num_fine_classes'],
            use_subtypes=False,
        ).to(DEVICE)

        checkpoint = torch.load(ckpt_path, map_location=DEVICE)
        # Accept either a bare state dict or one wrapped under 'model_state_dict'.
        is_wrapped = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
        model.load_state_dict(checkpoint['model_state_dict'] if is_wrapped else checkpoint)

        print("  -> Evaluating fine classifiers...")
        model.eval()

        pred_batches, label_batches = [], []

        with torch.no_grad():
            for imgs, coarse_labels, fine_labels in tqdm(test_loader, desc=f"Stage2 {arch}"):
                imgs = imgs.to(DEVICE, dtype=torch.float32)
                # Rescale batches that look like raw 0-255 intensities.
                if imgs.max() > 1:
                    imgs = imgs / 255.0
                coarse_labels = coarse_labels.long().to(DEVICE)
                fine_labels = fine_labels.long().to(DEVICE)

                # Oracle routing: each sample goes to the fine head of its
                # ground-truth region, so coarse mistakes do not affect this score.
                for region_idx, region_name in region_idx_to_name.items():
                    in_region = coarse_labels == region_idx
                    if not in_region.any():
                        continue
                    region_logits = model.forward_fine(imgs[in_region], region_name)
                    pred_batches.append(region_logits.argmax(1).cpu().numpy())
                    label_batches.append(fine_labels[in_region].cpu().numpy())

        # NOTE(review): predictions from all regions are pooled into one label
        # space. If fine labels are region-local indices rather than global ids,
        # precision/recall/F1 conflate unrelated classes sharing an index — confirm.
        y_true = np.concatenate(label_batches) if label_batches else np.array([])
        y_pred = np.concatenate(pred_batches) if pred_batches else np.array([])

        weighted = dict(average='weighted', zero_division=0)
        scores = {
            'accuracy': accuracy_score(y_true, y_pred),
            'precision': precision_score(y_true, y_pred, **weighted),
            'recall': recall_score(y_true, y_pred, **weighted),
            'f1': f1_score(y_true, y_pred, **weighted),
        }
        stage2_results[arch] = {**scores, 'predictions': y_pred, 'labels': y_true}
        print(
            f"  -> Accuracy: {scores['accuracy']:.4f}, Precision: {scores['precision']:.4f}, "
            f"Recall: {scores['recall']:.4f}, F1: {scores['f1']:.4f}"
        )

    except Exception as e:
        print(f"  -> Error: {e}")

print(f"\nStage 2 models evaluated: {list(stage2_results.keys())}")

## 4. Display Results Summary

In [None]:
def _print_comparison_table(title, metric_cols):
    """Print a heavy banner with ``title`` followed by a fixed-width header row."""
    print("\n" + "=" * 100)
    print(title)
    print("=" * 100)
    header_cells = [f"{'Architecture':<28}"] + [f"{col:<12}" for col in metric_cols]
    print("\n" + " ".join(header_cells))
    print("-" * 100)


_METRIC_KEYS = ('accuracy', 'precision', 'recall', 'f1')

# STAGE 1 RESULTS (includes AUC, which may be missing)
_print_comparison_table(
    "HIERARCHICAL MODEL COMPARISON - STAGE 1 (COARSE/REGION)",
    ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC'],
)
for arch, m in stage1_results.items():
    auc_cell = "N/A" if m['auc'] is None else f"{m['auc']:.4f}"
    row_cells = [f"{arch:<28}"] + [f"{m[key]:<12.4f}" for key in _METRIC_KEYS] + [f"{auc_cell:<12}"]
    print(" ".join(row_cells))

# STAGE 2 RESULTS
_print_comparison_table(
    "HIERARCHICAL MODEL COMPARISON - STAGE 2 (FINE/PATHOLOGY)",
    ['Accuracy', 'Precision', 'Recall', 'F1-Score'],
)
for arch, m in stage2_results.items():
    row_cells = [f"{arch:<28}"] + [f"{m[key]:<12.4f}" for key in _METRIC_KEYS]
    print(" ".join(row_cells))

print("-" * 100)

## 5. Confusion Matrices

In [None]:
def plot_confusion_matrix(y_true, y_pred, title, labels=None, save_path=None, class_values=None):
    """Draw an annotated confusion-matrix heatmap, optionally saving it.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices, same length as ``y_true``.
        title: Figure title.
        labels: Optional tick names, one per row/column of the matrix. Without
            ``class_values`` the matrix covers the sorted union of classes seen
            in ``y_true`` and ``y_pred``, so ``labels`` must follow that set.
            ``None`` lets seaborn pick numeric tick labels.
        save_path: If given, the figure is saved there (format from extension).
        class_values: Optional explicit list/order of class values forwarded to
            ``sklearn.metrics.confusion_matrix(labels=...)``; classes absent from
            the data then still get an all-zero row and column.

    Raises:
        ValueError: If ``labels`` does not have one entry per matrix row.
    """
    cm = confusion_matrix(y_true, y_pred, labels=class_values)

    if labels is not None and len(labels) != len(cm):
        raise ValueError(
            f"got {len(labels)} tick labels for a {len(cm)}x{len(cm)} confusion matrix"
        )
    # seaborn needs "auto" (not None) to generate default tick labels.
    tick_labels = "auto" if labels is None else labels

    fig, ax = plt.subplots(figsize=(12, 10))

    hm = sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='viridis',
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ax=ax,
        square=True,
        cbar=False,
        linewidths=0,
        annot_kws={"size": 10}
    )

    ax.set_xticklabels(ax.get_xticklabels(), fontsize=10)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=10)

    # Custom colorbar in its own axis so it matches the heatmap height.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.15)
    # Counts are integers: round the evenly spaced ticks and drop duplicates so
    # small matrices do not get fractional tick labels.
    ticks = np.unique(np.linspace(cm.min(), cm.max(), 11).round().astype(int))
    cbar = fig.colorbar(
        hm.collections[0],
        cax=cax,
        drawedges=False,
        ticks=ticks
    )
    cbar.solids.set_edgecolor("face")
    cbar.solids.set_rasterized(False)
    cbar.set_label("Count", fontsize=12)
    cbar.outline.set_visible(False)
    cbar.ax.tick_params(size=0, labelsize=10)

    ax.set_title(title, fontsize=14, fontweight='bold', pad=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight="bold", labelpad=4)
    ax.set_ylabel("True Label", fontsize=12, fontweight="bold", labelpad=4)

    # Keep every artist vector-based (no rasterized patches in the PGF output).
    for artist in fig.findobj():
        if hasattr(artist, "set_rasterized"):
            artist.set_rasterized(False)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0.05, dpi=150)

    plt.show()
    # Release the figure; this function is called in loops over architectures.
    plt.close(fig)

In [None]:
# Stage 1 Confusion Matrices
region_labels = [dataset_info['idx_to_region'][i] for i in range(len(dataset_info['idx_to_region']))]

for arch, results in stage1_results.items():
    # confusion_matrix() only has rows/columns for classes that appear in the
    # labels or predictions, so name exactly that (sorted) set. Otherwise a
    # region absent from the test split would shift or break the tick labels.
    present = np.unique(np.concatenate([results['labels'], results['predictions']]))
    tick_names = [dataset_info['idx_to_region'][int(i)] for i in present]

    plot_confusion_matrix(
        results['labels'],
        results['predictions'],
        f'Confusion Matrix - Stage 1 ({arch})',
        labels=tick_names,
        save_path=f"{PATHS['figures']}/confusion_coarse_{arch}.pgf"
    )

In [None]:
# Stage 2 Confusion Matrices
for arch, results in stage2_results.items():
    # Use the same class set confusion_matrix() derives internally: the sorted
    # union of true AND predicted labels. Using only the true labels breaks the
    # plot as soon as a model predicts a class that never occurs in the split.
    fine_unique = np.unique(np.concatenate([results['labels'], results['predictions']])).tolist()
    fine_labels_names = [str(i) for i in fine_unique]

    plot_confusion_matrix(
        results['labels'],
        results['predictions'],
        f'Confusion Matrix - Stage 2 ({arch})',
        labels=fine_labels_names,
        save_path=f"{PATHS['figures']}/confusion_fine_{arch}.pgf"
    )

## 6. ROC Curves (Stage 1)

In [None]:
# ROC Curves for Stage 1 (Coarse Classification)
for arch, results in stage1_results.items():
    if 'probs' not in results:
        continue

    y_true = np.asarray(results['labels'])
    y_probs = results['probs']

    # One-hot encode via broadcasting instead of label_binarize(): with exactly
    # two classes label_binarize() returns a single column, which would break
    # the per-class column indexing below.
    y_true_bin = (y_true[:, None] == np.arange(num_coarse_classes)).astype(int)

    fig = plt.figure(figsize=(10, 8))

    for i in range(num_coarse_classes):
        region_name = region_idx_to_name[i]
        n_pos = int(y_true_bin[:, i].sum())
        # A one-vs-rest ROC needs both positives and negatives; otherwise
        # roc_curve() warns and the AUC comes out as NaN.
        if n_pos == 0 or n_pos == len(y_true):
            print(f"  [{arch}] skipping ROC for '{region_name}': only one class present in test labels")
            continue
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{region_name} (AUC = {roc_auc:.3f})')

    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curves - Stage 1 (Coarse) - {arch}')
    plt.legend(loc='lower right')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{PATHS['figures']}/roc_stage1_{arch}.pgf", bbox_inches="tight", pad_inches=0.05, dpi=150)
    plt.show()
    # Free the figure before the next architecture's plot.
    plt.close(fig)

## 7. Summary

In [None]:
def _best_by_accuracy(results):
    """Return the ``(architecture, metrics)`` item with the highest accuracy.

    On ties the first architecture in dict order wins (``max`` semantics).
    """
    return max(results.items(), key=lambda item: item[1]['accuracy'])


print("\n" + "=" * 60)
print("ANALYSIS COMPLETE")
print("=" * 60)
print(f"\nStage 1 (Coarse) models evaluated: {len(stage1_results)}")
print(f"Stage 2 (Fine) models evaluated: {len(stage2_results)}")

# Only report a "best" model for stages where at least one model was evaluated.
if stage1_results:
    best_stage1 = _best_by_accuracy(stage1_results)
    best_arch, best_metrics = best_stage1
    print(f"\nBest Stage 1 model: {best_arch} (Accuracy: {best_metrics['accuracy']:.4f})")

if stage2_results:
    best_stage2 = _best_by_accuracy(stage2_results)
    best_arch, best_metrics = best_stage2
    print(f"Best Stage 2 model: {best_arch} (Accuracy: {best_metrics['accuracy']:.4f})")