# Federated vs Centralized: Model Comparison

**Comparing DSCATNet Performance: Centralized (Original Paper) vs Federated Learning**

This notebook provides a head-to-head comparison between:

1. **Centralized Training** (Baseline) - Original DSCATNet paper approach  
   *Reference: [DSCATNet: Dual-Scale Cross-Attention Vision Transformer](https://doi.org/10.1371/journal.pone.0312598)*
   
2. **Federated Learning** - Our implementation with non-IID data

---

## Research Context

The original DSCATNet paper (PLOS ONE, 2024) demonstrates state-of-the-art performance on dermoscopy classification using centralized training. This thesis investigates whether similar performance can be achieved in a **privacy-preserving federated learning** setting where data remains distributed across multiple hospitals.
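For orientation, the aggregation step at the core of many federated setups can be sketched as a sample-weighted average of client parameters (FedAvg). This notebook does not state which aggregation rule the federated runs use, so FedAvg is an assumption here. `fedavg` and the toy client dictionaries are hypothetical stand-ins for real model state dicts.

```python
from typing import Dict, List, Sequence

Params = Dict[str, List[float]]


def fedavg(client_params: Sequence[Params], client_sizes: Sequence[int]) -> Params:
    """Sample-weighted average of client parameters (FedAvg-style aggregation)."""
    if len(client_params) != len(client_sizes) or not client_params:
        raise ValueError("need one sample count per client, and at least one client")
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError("total sample count must be positive")

    averaged: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        # Each coordinate is weighted by the client's share of all samples.
        averaged[name] = [
            sum(p[name][i] * n for p, n in zip(client_params, client_sizes)) / total
            for i in range(length)
        ]
    return averaged


# Toy example: two "hospitals" holding 100 and 300 samples (made-up values).
hospital_a = {"w": [1.0, 2.0], "b": [0.0]}
hospital_b = {"w": [3.0, 6.0], "b": [4.0]}
global_params = fedavg([hospital_a, hospital_b], [100, 300])
print(global_params)
```

Only parameters leave each client in this scheme; the images stay local, which is the privacy property the comparison is built around.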

### Key Questions:
- How much accuracy is lost when moving from centralized to federated learning?
- Do different non-IID distributions affect convergence differently?
- Which classes suffer most from federated training?
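The experiment labels in the configuration below refer to Dirichlet splits with a concentration parameter α. A common way to simulate such label-skewed clients is to draw, for each class, a Dirichlet(α) vector of client shares: a small α concentrates a class on few clients, a large α approaches an IID split. The sketch below is one such scheme under that assumption, not the project's partitioning code; `dirichlet_partition` and its parameters are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def dirichlet_partition(labels: Sequence[int], num_clients: int, alpha: float,
                        seed: int = 0) -> Dict[int, List[int]]:
    """Split sample indices across clients with per-class Dirichlet(alpha) shares."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    clients: Dict[int, List[int]] = {c: [] for c in range(num_clients)}
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Normalised Gamma(alpha, 1) draws form a symmetric Dirichlet(alpha) sample.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(draws)
        shares = [d / total for d in draws]
        # Convert shares into cut points over this class's shuffled indices.
        start, cumulative = 0, 0.0
        for client in range(num_clients):
            cumulative += shares[client]
            if client == num_clients - 1:
                end = len(indices)  # last client takes whatever remains
            else:
                end = min(len(indices), round(cumulative * len(indices)))
            clients[client].extend(indices[start:end])
            start = end
    return clients


# Toy labels: 7 balanced classes, 100 samples each (made-up data).
toy_labels = [i % 7 for i in range(700)]
parts = dirichlet_partition(toy_labels, num_clients=4, alpha=0.5, seed=42)
for client, idxs in parts.items():
    print(client, len(idxs))
```

Every index is assigned to exactly one client, so the union of client shards reproduces the original dataset.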

---

## 1. Setup and Configuration

In [None]:
# Standard library
import sys
import json
from pathlib import Path
from datetime import datetime

# Data science
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# PyTorch
import torch
from torch.utils.data import DataLoader

# Add project root
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Project imports
from src.models.dscatnet import create_dscatnet
from src.data.datasets import UNIFIED_CLASSES, CLASS_NAMES as DATASET_CLASS_NAMES
from src.data.preprocessing import get_val_transforms
from src.evaluation.metrics import ModelEvaluator, EvaluationResults, compare_results

# Style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')
%matplotlib inline

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

CLASS_NAMES = list(DATASET_CLASS_NAMES)
print(f"Classes: {CLASS_NAMES}")


In [None]:
# =============================================================================
# CONFIGURATION - Set paths to your trained models
# =============================================================================

# Centralized model (baseline)
CENTRALIZED_CHECKPOINT = project_root / "outputs" / "centralized_YYYYMMDD_HHMMSS" / "checkpoints" / "best_model.pt"

# Federated model(s) - can compare multiple FL experiments
FEDERATED_CHECKPOINTS = {
    "FL (Dirichlet α=0.5)": project_root / "outputs" / "federated_YYYYMMDD_HHMMSS" / "checkpoints" / "best_model.pt",
    # Add more FL experiments to compare:
    # "FL (Dirichlet α=0.1)": project_root / "outputs" / "federated_alpha01" / "checkpoints" / "best_model.pt",
    # "FL (Natural Non-IID)": project_root / "outputs" / "federated_natural" / "checkpoints" / "best_model.pt",
}

# Training history files (for convergence analysis)
CENTRALIZED_HISTORY = project_root / "outputs" / "centralized_YYYYMMDD_HHMMSS" / "history.json"
FEDERATED_HISTORY = project_root / "outputs" / "federated_YYYYMMDD_HHMMSS" / "history.json"

# Dataset for evaluation
DATASET_NAME = "ISIC2019"  # Should match training dataset
DATA_ROOT = project_root / "data"

# Model config
MODEL_VARIANT = "small"
NUM_CLASSES = 7  # Must match len(CLASS_NAMES) and the checkpoint's output layer
IMAGE_SIZE = 224
BATCH_SIZE = 32

# Reference values from original DSCATNet paper (Table 4)
# Paper: https://doi.org/10.1371/journal.pone.0312598
PAPER_RESULTS = {
    "HAM10000": {
        "accuracy": 0.9512,
        "precision": 0.8956,
        "recall": 0.8823,
        "f1": 0.8851,
        "auc": 0.9934
    },
    "ISIC2019": {
        "accuracy": 0.9234,  # Approximate from paper
        "precision": 0.8700,
        "recall": 0.8500,
        "f1": 0.8600,
        "auc": 0.9800
    }
}

print("Configuration loaded!")
print(f"\nPaper reference results for {DATASET_NAME}:")
if DATASET_NAME in PAPER_RESULTS:
    for metric, val in PAPER_RESULTS[DATASET_NAME].items():
        print(f"  {metric}: {val:.4f}")

## 2. Load Test Dataset

In [None]:
from src.data.datasets import (
    HAM10000Dataset, ISIC2018Dataset, ISIC2019Dataset,
    ISIC2020Dataset, PADUFES20Dataset
)

DATASET_CLASSES = {
    "HAM10000": HAM10000Dataset,
    "ISIC2018": ISIC2018Dataset,
    "ISIC2019": ISIC2019Dataset,
    "ISIC2020": ISIC2020Dataset,
    "PAD-UFES-20": PADUFES20Dataset,
}

DATASET_PATHS = {
    "HAM10000": (DATA_ROOT / "HAM10000", DATA_ROOT / "HAM10000" / "HAM10000_metadata.csv"),
    "ISIC2018": (DATA_ROOT / "ISIC2018" / "ISIC2018_Task3_Training_Input", DATA_ROOT / "ISIC2018" / "ISIC2018_Task3_Training_GroundTruth.csv"),
    "ISIC2019": (DATA_ROOT / "ISIC2019" / "ISIC_2019_Training_Input", DATA_ROOT / "ISIC2019" / "ISIC_2019_Training_GroundTruth.csv"),
    "ISIC2020": (DATA_ROOT / "ISIC2020" / "train", DATA_ROOT / "ISIC2020" / "train.csv"),
    "PAD-UFES-20": (DATA_ROOT / "PAD-UFES-20", DATA_ROOT / "PAD-UFES-20" / "metadata.csv"),
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
val_transform = get_val_transforms(img_size=IMAGE_SIZE)

root_dir, csv_path = DATASET_PATHS[DATASET_NAME]
dataset_class = DATASET_CLASSES[DATASET_NAME]

try:
    test_dataset = dataset_class(
        root_dir=str(root_dir),
        csv_path=str(csv_path),
        transform=val_transform
    )
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    print(f"✓ Loaded {DATASET_NAME}: {len(test_dataset):,} samples")
except Exception as e:
    print(f"✗ Failed to load {DATASET_NAME}: {e}")
    raise  # Later cells need test_loader, so stop here instead of continuing

## 3. Evaluate All Models

In [None]:
def load_and_evaluate(checkpoint_path, name):
    """Load a model and evaluate it."""
    print(f"\n{'='*60}")
    print(f"Evaluating: {name}")
    print(f"{'='*60}")
    
    if not checkpoint_path.exists():
        print(f"⚠️ Checkpoint not found: {checkpoint_path}")
        return None
    
    # Create model
    model = create_dscatnet(variant=MODEL_VARIANT, num_classes=NUM_CLASSES, pretrained=False)
    
    # Load weights
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
        print(f"  Loaded from epoch/round: {checkpoint.get('epoch', 'N/A')}")
    else:
        model.load_state_dict(checkpoint)
    
    model = model.to(device)
    model.eval()
    
    # Evaluate
    evaluator = ModelEvaluator(model, device, NUM_CLASSES, CLASS_NAMES)
    results = evaluator.evaluate(test_loader, compute_auc=True)
    
    print(f"  Accuracy: {results.accuracy:.4f}")
    print(f"  F1 (macro): {results.f1_macro:.4f}")
    print(f"  AUC-ROC: {results.auc_macro:.4f}" if results.auc_macro else "  AUC-ROC: N/A")
    
    return results

# Store all results
all_results = {}

# Evaluate centralized
cent_results = load_and_evaluate(CENTRALIZED_CHECKPOINT, "Centralized (Baseline)")
if cent_results:
    all_results["Centralized"] = cent_results

# Evaluate federated models
for name, path in FEDERATED_CHECKPOINTS.items():
    fed_results = load_and_evaluate(path, name)
    if fed_results:
        all_results[name] = fed_results

print(f"\n✓ Evaluated {len(all_results)} models")

## 4. Head-to-Head Comparison Table

In [None]:
# Build comparison DataFrame
comparison_data = []

# Add paper reference
if DATASET_NAME in PAPER_RESULTS:
    paper = PAPER_RESULTS[DATASET_NAME]
    comparison_data.append({
        "Model": "📚 Paper (DSCATNet)",
        "Accuracy": paper["accuracy"],
        "Precision": paper["precision"],
        "Recall": paper["recall"],
        "F1-Score": paper["f1"],
        "AUC-ROC": paper["auc"],
        "Type": "Reference"
    })

# Add evaluated models
for name, results in all_results.items():
    model_type = "Centralized" if "Centralized" in name else "Federated"
    comparison_data.append({
        "Model": name,
        "Accuracy": results.accuracy,
        "Precision": results.precision_macro,
        "Recall": results.recall_macro,
        "F1-Score": results.f1_macro,
        "AUC-ROC": results.auc_macro if results.auc_macro else 0.0,
        "Type": model_type
    })

comparison_df = pd.DataFrame(comparison_data)

# Calculate difference from paper
if DATASET_NAME in PAPER_RESULTS:
    paper_acc = PAPER_RESULTS[DATASET_NAME]["accuracy"]
    comparison_df["Δ Accuracy"] = comparison_df["Accuracy"] - paper_acc

# Style the table
def color_diff(val):
    if pd.isna(val):
        return ''
    if val > 0:
        return 'color: green; font-weight: bold'
    elif val < -0.05:
        return 'color: red; font-weight: bold'
    else:
        return 'color: orange'

print("\n" + "="*80)
print("📊 HEAD-TO-HEAD COMPARISON")
print("="*80)

styled_df = comparison_df.style.format({
    "Accuracy": "{:.4f}",
    "Precision": "{:.4f}",
    "Recall": "{:.4f}",
    "F1-Score": "{:.4f}",
    "AUC-ROC": "{:.4f}",
    "Δ Accuracy": "{:+.4f}" if "Δ Accuracy" in comparison_df.columns else "{}"
})

if "Δ Accuracy" in comparison_df.columns:
    styled_df = styled_df.map(color_diff, subset=["Δ Accuracy"])

display(styled_df)


## 5. Visual Comparison

In [None]:
# Bar chart comparison
metrics = ["Accuracy", "Precision", "Recall", "F1-Score", "AUC-ROC"]
models = comparison_df["Model"].tolist()

fig, ax = plt.subplots(figsize=(14, 7))

x = np.arange(len(metrics))
width = 0.8 / len(models)
# plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is the supported call
colors = plt.get_cmap('Set2')(np.linspace(0, 1, len(models)))

for i, (model, color) in enumerate(zip(models, colors)):
    values = comparison_df[comparison_df["Model"] == model][metrics].values[0]
    offset = (i - len(models)/2 + 0.5) * width
    bars = ax.bar(x + offset, values, width, label=model, color=color, edgecolor='black', linewidth=0.5)
    
    # Add value labels
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                f'{val:.2f}', ha='center', va='bottom', fontsize=8, rotation=45)

ax.set_ylabel('Score', fontsize=12)
ax.set_title(f'Model Comparison on {DATASET_NAME}', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(metrics, fontsize=11)
ax.set_ylim(0, 1.15)
# Draw the threshold line before building the legend so its label is included
ax.axhline(y=0.9, color='gray', linestyle='--', alpha=0.5, label='90% threshold')
ax.legend(loc='lower right', fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(project_root / 'outputs' / 'comparison_bar_chart.png', dpi=150, bbox_inches='tight')
plt.show()


In [None]:
# Radar chart comparison
from math import pi

metrics_radar = ["Accuracy", "Precision", "Recall", "F1-Score", "AUC-ROC"]
N = len(metrics_radar)
angles = [n / float(N) * 2 * pi for n in range(N)]
angles += angles[:1]  # Complete the loop

fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))

colors = ['#e74c3c', '#3498db', '#2ecc71', '#9b59b6']

for i, (model, color) in enumerate(zip(models[:4], colors)):  # Limit to 4 models
    values = comparison_df[comparison_df["Model"] == model][metrics_radar].values[0].tolist()
    values += values[:1]  # Complete the loop
    
    ax.plot(angles, values, 'o-', linewidth=2, label=model, color=color)
    ax.fill(angles, values, alpha=0.15, color=color)

ax.set_xticks(angles[:-1])
ax.set_xticklabels(metrics_radar, fontsize=11)
ax.set_ylim(0, 1)
ax.set_title('Model Performance Radar Chart', fontsize=14, fontweight='bold', pad=20)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))

plt.tight_layout()
plt.savefig(project_root / 'outputs' / 'comparison_radar.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Per-Class Performance Comparison

In [None]:
# Compare per-class precision and recall (with class support) for the first two models
if len(all_results) >= 2:
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Extract per-class metrics
    model_names = list(all_results.keys())[:2]  # Compare first two models
    
    for idx, (model_name, ax) in enumerate(zip(model_names, axes)):
        results = all_results[model_name]
        per_class = results.per_class_metrics
        
        classes = list(per_class.keys())
        precision = [per_class[c]['precision'] for c in classes]
        recall = [per_class[c]['recall'] for c in classes]
        support = [per_class[c]['support'] for c in classes]
        
        x = np.arange(len(classes))
        width = 0.35
        
        bars1 = ax.bar(x - width/2, precision, width, label='Precision', color='#3498db', alpha=0.8)
        bars2 = ax.bar(x + width/2, recall, width, label='Recall', color='#e74c3c', alpha=0.8)
        
        # Add support labels on top
        for i, (p, r, s) in enumerate(zip(precision, recall, support)):
            ax.text(i, max(p, r) + 0.05, f'n={s:,}', ha='center', fontsize=8)
        
        ax.set_xlabel('Class', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title(f'{model_name}', fontsize=12, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(classes, rotation=45, ha='right')
        ax.legend()
        ax.set_ylim(0, 1.2)
        ax.grid(True, alpha=0.3, axis='y')
    
    plt.suptitle('Per-Class Performance: Centralized vs Federated', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(project_root / 'outputs' / 'per_class_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("Need at least 2 models to compare per-class metrics.")

## 7. Confusion Matrix Comparison

In [None]:
# Side-by-side confusion matrices
if len(all_results) >= 2:
    model_names = list(all_results.keys())[:2]
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    
    for ax, model_name in zip(axes, model_names):
        cm = all_results[model_name].confusion_matrix
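        # Row-normalize; rows for classes with no samples stay at zero
        row_sums = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)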
        
        sns.heatmap(
            cm_norm,
            annot=True,
            fmt='.2%',
            cmap='Blues',
            xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES,
            ax=ax,
            cbar_kws={'label': 'Proportion'}
        )
        ax.set_xlabel('Predicted', fontsize=11)
        ax.set_ylabel('True', fontsize=11)
        ax.set_title(f'{model_name}', fontsize=12, fontweight='bold')
    
    plt.suptitle('Confusion Matrix Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(project_root / 'outputs' / 'confusion_matrix_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("Need at least 2 models to compare confusion matrices.")


## 8. Convergence Analysis

In [None]:
# Load training histories
histories = {}

if CENTRALIZED_HISTORY.exists():
    with open(CENTRALIZED_HISTORY) as f:
        histories["Centralized"] = json.load(f)
    print(f"✓ Loaded centralized history: {len(histories['Centralized'].get('epochs', histories['Centralized'].get('rounds', [])))} epochs")

if FEDERATED_HISTORY.exists():
    with open(FEDERATED_HISTORY) as f:
        histories["Federated"] = json.load(f)
    print(f"✓ Loaded federated history: {len(histories['Federated'].get('rounds', histories['Federated'].get('epochs', [])))} rounds")

if not histories:
    print("⚠️ No history files found. Update CENTRALIZED_HISTORY and FEDERATED_HISTORY paths.")

In [None]:
# Plot convergence curves
if histories:
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    colors = {'Centralized': '#e74c3c', 'Federated': '#3498db'}
    
    # Loss curves
    for name, history in histories.items():
        x_key = 'epochs' if 'epochs' in history else 'rounds'
        x = history.get(x_key, range(len(history.get('train_loss', []))))
        
        if 'train_loss' in history:
            axes[0].plot(x, history['train_loss'], '-', label=f'{name} Train', 
                        color=colors.get(name, 'gray'), alpha=0.7)
        if 'val_loss' in history:
            axes[0].plot(x, history['val_loss'], '--', label=f'{name} Val',
                        color=colors.get(name, 'gray'), linewidth=2)
    
    axes[0].set_xlabel('Epoch / Round', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title('Training Convergence: Loss', fontsize=14, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy curves
    for name, history in histories.items():
        x_key = 'epochs' if 'epochs' in history else 'rounds'
        x = history.get(x_key, range(len(history.get('val_accuracy', []))))
        
        if 'train_accuracy' in history:
            axes[1].plot(x, history['train_accuracy'], '-', label=f'{name} Train',
                        color=colors.get(name, 'gray'), alpha=0.7)
        if 'val_accuracy' in history:
            axes[1].plot(x, history['val_accuracy'], '--', label=f'{name} Val',
                        color=colors.get(name, 'gray'), linewidth=2)
    
    axes[1].set_xlabel('Epoch / Round', fontsize=12)
    axes[1].set_ylabel('Accuracy', fontsize=12)
    axes[1].set_title('Training Convergence: Accuracy', fontsize=14, fontweight='bold')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    axes[1].set_ylim(0, 1)
    
    plt.suptitle('Centralized vs Federated Training Convergence', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(project_root / 'outputs' / 'convergence_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("No history data to plot.")
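One way to put the two curves on a common footing is to ask how many epochs or rounds each run needs before validation accuracy first reaches a chosen target. The sketch below assumes `val_accuracy` is a list with one value per epoch or round, as the plotting code above expects. `steps_to_target` and the toy histories are hypothetical, and the accuracy values are made up for illustration.

```python
from typing import Optional, Sequence


def steps_to_target(val_accuracy: Sequence[float], target: float) -> Optional[int]:
    """Return the 1-based epoch/round where accuracy first reaches target, else None."""
    for step, acc in enumerate(val_accuracy, start=1):
        if acc >= target:
            return step
    return None


# Made-up histories, only to show the call pattern.
toy_histories = {
    "Centralized": {"val_accuracy": [0.62, 0.74, 0.81, 0.85]},
    "Federated": {"val_accuracy": [0.41, 0.55, 0.66, 0.73, 0.79, 0.82]},
}
for name, history in toy_histories.items():
    print(name, steps_to_target(history["val_accuracy"], target=0.80))
```

An epoch and a communication round are not the same unit of work, since a round may contain several local epochs on each client, so the counts are best read alongside the local training settings of the federated run.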

## 9. Statistical Summary

In [None]:
# Calculate performance gap
print("\n" + "="*70)
print("📉 PERFORMANCE GAP ANALYSIS")
print("="*70)

if "Centralized" in all_results and len(all_results) > 1:
    cent = all_results["Centralized"]
    
    print(f"\nBaseline (Centralized):")
    print(f"  Accuracy: {cent.accuracy:.4f}")
    print(f"  F1-Score: {cent.f1_macro:.4f}")
    print(f"  AUC-ROC:  {cent.auc_macro:.4f}" if cent.auc_macro else "  AUC-ROC:  N/A")
    
    print(f"\nPerformance Gaps vs Centralized:")
    print("-"*70)
    print(f"{'Model':<30} {'Δ Accuracy':>12} {'Δ F1':>12} {'Δ AUC':>12}")
    print("-"*70)
    
    for name, results in all_results.items():
        if name == "Centralized":
            continue
        
        delta_acc = results.accuracy - cent.accuracy
        delta_f1 = results.f1_macro - cent.f1_macro
        # Leave the AUC gap undefined when either AUC is missing, rather than reporting 0
        has_auc = results.auc_macro is not None and cent.auc_macro is not None
        delta_auc = (results.auc_macro - cent.auc_macro) if has_auc else None
        
        # Text markers: ⬇️ drop of more than 0.02, ✓ no drop, ~ small drop
        acc_str = f"{delta_acc:+.4f}" + (" ⬇️" if delta_acc < -0.02 else " ✓" if delta_acc >= 0 else " ~")
        f1_str = f"{delta_f1:+.4f}" + (" ⬇️" if delta_f1 < -0.02 else " ✓" if delta_f1 >= 0 else " ~")
        if delta_auc is None:
            auc_str = "N/A"
        else:
            auc_str = f"{delta_auc:+.4f}" + (" ⬇️" if delta_auc < -0.02 else " ✓" if delta_auc >= 0 else " ~")
        
        print(f"{name:<30} {acc_str:>12} {f1_str:>12} {auc_str:>12}")

# Compare with paper
if DATASET_NAME in PAPER_RESULTS:
    paper = PAPER_RESULTS[DATASET_NAME]
    print(f"\n\nComparison with Original Paper (DOI: 10.1371/journal.pone.0312598):")
    print("-"*70)
    print(f"{'Model':<30} {'Δ vs Paper Accuracy':>20} {'Δ vs Paper F1':>20}")
    print("-"*70)
    
    for name, results in all_results.items():
        delta_acc = results.accuracy - paper["accuracy"]
        delta_f1 = results.f1_macro - paper["f1"]
        
        print(f"{name:<30} {delta_acc:>+19.4f} {delta_f1:>+19.4f}")
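The gaps printed above are point estimates from a single test set. A paired bootstrap over test samples is one common way to attach an interval to an accuracy difference. The sketch below assumes per-sample correctness flags are available for both models on the same test samples, which the `EvaluationResults` object used in this notebook is not confirmed to expose. `paired_bootstrap_accuracy_gap` and the toy flags are hypothetical.

```python
import random
from typing import Sequence, Tuple


def paired_bootstrap_accuracy_gap(correct_a: Sequence[bool], correct_b: Sequence[bool],
                                  n_resamples: int = 2000, confidence: float = 0.95,
                                  seed: int = 0) -> Tuple[float, float, float]:
    """Return (observed gap, lower, upper) for accuracy(b) - accuracy(a)."""
    if len(correct_a) != len(correct_b) or not correct_a:
        raise ValueError("both models must be scored on the same non-empty sample list")
    n = len(correct_a)
    # Per-sample difference: +1 if only b is right, -1 if only a is right, else 0.
    diffs = [int(b) - int(a) for a, b in zip(correct_a, correct_b)]
    observed = sum(diffs) / n

    # Resample test samples with replacement, keeping the two models paired.
    rng = random.Random(seed)
    gaps = sorted(sum(rng.choices(diffs, k=n)) / n for _ in range(n_resamples))
    tail = (1.0 - confidence) / 2.0
    lower = gaps[int(tail * (n_resamples - 1))]
    upper = gaps[int((1.0 - tail) * (n_resamples - 1))]
    return observed, lower, upper


# Made-up flags: model A is right on 90 of 100 samples, model B on 80 of them.
centralized_correct = [True] * 90 + [False] * 10
federated_correct = [True] * 80 + [False] * 20
gap, low, high = paired_bootstrap_accuracy_gap(centralized_correct, federated_correct)
print(f"gap={gap:+.3f}, {low:+.3f} to {high:+.3f}")
```

If the interval excludes zero, the gap is unlikely to be an artifact of which samples landed in the test set; it says nothing about variance across training seeds, which would need repeated runs.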

## 10. Export Comparison Results

In [None]:
# Export comparison to JSON
export_comparison = {
    "timestamp": datetime.now().isoformat(),
    "dataset": DATASET_NAME,
    "paper_reference": {
        "doi": "10.1371/journal.pone.0312598",
        "title": "DSCATNet: Dual-Scale Cross-Attention Vision Transformer for Skin Cancer Classification",
        "metrics": PAPER_RESULTS.get(DATASET_NAME, {})
    },
    "models": {}
}

for name, results in all_results.items():
    export_comparison["models"][name] = {
        "accuracy": float(results.accuracy),
        "balanced_accuracy": float(results.balanced_accuracy),
        "precision_macro": float(results.precision_macro),
        "recall_macro": float(results.recall_macro),
        "f1_macro": float(results.f1_macro),
        "f1_weighted": float(results.f1_weighted),
        "auc_macro": float(results.auc_macro) if results.auc_macro else None,
        "per_class_metrics": results.per_class_metrics
    }

# Save
output_path = project_root / 'outputs' / f'comparison_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
with open(output_path, 'w') as f:
    json.dump(export_comparison, f, indent=2)

print(f"\n✓ Comparison exported to: {output_path}")

# Also save comparison DataFrame
csv_path = project_root / 'outputs' / f'comparison_table_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
comparison_df.to_csv(csv_path, index=False)
print(f"✓ Comparison table saved to: {csv_path}")

---

## Key Findings Summary

### Research Questions Addressed:

| Question | Finding |
|----------|----------|
| **FL vs Centralized Gap** | Reported as Δ Accuracy and Δ F1 in Section 9 (Statistical Summary) |
| **Non-IID Impact** | Compare runs with different Dirichlet α values by adding them to `FEDERATED_CHECKPOINTS` |
| **Class-wise Degradation** | Minority classes typically suffer more |
| **Convergence Speed** | FL typically requires more rounds than epochs |
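
To check the class-wise row against actual results, per-class recall can be read off the confusion matrices and compared between two models. The sketch below assumes rows are true classes and columns are predictions, the orientation used by the heatmaps in Section 7. `per_class_recall`, `per_class_recall_gap`, the class labels, and the toy matrices are hypothetical and are not results from this study.

```python
from typing import Dict, List, Optional, Sequence

Matrix = Sequence[Sequence[int]]


def per_class_recall(cm: Matrix) -> List[Optional[float]]:
    """Recall per class from a confusion matrix with true classes on rows."""
    recalls: List[Optional[float]] = []
    for i, row in enumerate(cm):
        support = sum(row)
        recalls.append(row[i] / support if support > 0 else None)
    return recalls


def per_class_recall_gap(cm_base: Matrix, cm_other: Matrix,
                         class_names: Sequence[str]) -> Dict[str, Optional[float]]:
    """Recall(other) - recall(base) per class; None where a class has no samples."""
    gaps: Dict[str, Optional[float]] = {}
    pairs = zip(class_names, per_class_recall(cm_base), per_class_recall(cm_other))
    for name, r_base, r_other in pairs:
        gaps[name] = None if r_base is None or r_other is None else r_other - r_base
    return gaps


# Toy 3-class matrices with made-up counts; class "C" is the minority class.
cm_centralized = [[90, 5, 5], [10, 80, 10], [2, 2, 16]]
cm_federated = [[88, 6, 6], [12, 76, 12], [5, 5, 10]]
gaps = per_class_recall_gap(cm_centralized, cm_federated, ["A", "B", "C"])
# Largest drops first; classes without samples go last.
for name, gap in sorted(gaps.items(), key=lambda kv: (kv[1] is None, kv[1])):
    print(f"{name}: {gap:+.3f}" if gap is not None else f"{name}: n/a")
```

Sorting by gap puts the classes that lose the most recall under federated training at the top, which makes it easy to see whether they coincide with the low-support classes.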

### Reference

**Original DSCATNet Paper:**
> Wei et al. (2024). *DSCATNet: Dual-Scale Cross-Attention Vision Transformer for Skin Cancer Classification*. PLOS ONE.  
> DOI: [10.1371/journal.pone.0312598](https://doi.org/10.1371/journal.pone.0312598)

---