# 02 - Federated vs Central Training Comparison

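This notebook loads the metrics saved by central and federated GNN training,
plots their convergence curves, and compares final performance across models
(including the RF and MLP baselines when their metrics files are present).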

In [None]:
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

# Metrics are read from, and figures written to, ../results (resolved against
# the current working directory).
METRICS_DIR = Path("..", "results", "metrics")
FIGURES_DIR = Path("..", "results", "figures")

# Create the figures directory (and any missing parents); a no-op if it exists.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Load central GNN training history
# After this cell `history` is the parsed JSON object, or None when the file
# is missing, so later cells can guard on its truthiness.
history_path = METRICS_DIR / "gnn_training_history.json"
if history_path.exists():
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with open(history_path, encoding="utf-8") as history_file:
        history = json.load(history_file)
    print(f"Central training: {len(history['train_loss'])} epochs")
    # default= keeps the cell from raising ValueError when the history file
    # exists but contains no validation epochs yet.
    print(f"Best val F1: {max(history['val_f1'], default=float('nan')):.4f}")
else:
    print("No central training history found. Run scripts/run_all.py first.")
    history = None

In [None]:
# Plot central training curves
# Left panel: training loss per epoch. Right panel: validation F1 and ROC-AUC.
# Skipped entirely when no history was loaded.
if history:
    fig, (loss_ax, score_ax) = plt.subplots(1, 2, figsize=(10, 4))

    loss_ax.plot(history['train_loss'], 'b-o', markersize=3)
    loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Central GNN - Training Loss')

    score_ax.plot(history['val_f1'], 'g-o', markersize=3, label='F1')
    score_ax.plot(history['val_roc_auc'], 'r-s', markersize=3, label='ROC-AUC')
    score_ax.set(xlabel='Epoch', ylabel='Score', title='Central GNN - Validation Metrics')
    score_ax.legend()

    fig.tight_layout()
    # Persist the figure before showing it.
    fig.savefig(FIGURES_DIR / 'central_training_curves.png', dpi=150)
    plt.show()

In [None]:
# Load FL round metrics
# After this cell `rounds` and `comm` are always defined: the values stored
# under the 'rounds' and 'comm_bytes' keys of the JSON file, or empty lists
# when the file or a key is missing.
fl_path = METRICS_DIR / "fl_rounds.json"
if fl_path.exists():
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with open(fl_path, encoding="utf-8") as fl_file:
        fl_data = json.load(fl_file)
    rounds = fl_data.get('rounds', [])
    comm = fl_data.get('comm_bytes', [])
    print(f"FL rounds: {len(rounds)}")
    if rounds:
        print(f"Final round metrics: {rounds[-1]}")
    # comm values are divided by 1e6 to report megabytes.
    print(f"Total communication: {sum(comm) / 1e6:.2f} MB")
else:
    print("No FL metrics found. Run federated training first.")
    rounds, comm = [], []

In [None]:
# Plot FL convergence and communication cost
# Left panel: F1 per federated round. Right panel: per-round communication
# volume in MB (left empty when no communication data was loaded).
if rounds:
    fig, (f1_ax, comm_ax) = plt.subplots(1, 2, figsize=(10, 4))

    # Fall back to the 1-based position when an entry has no 'round' key,
    # and to 0 when it has no 'f1' key.
    round_ids, f1_scores = [], []
    for position, entry in enumerate(rounds, start=1):
        round_ids.append(entry.get('round', position))
        f1_scores.append(entry.get('f1', 0))

    f1_ax.plot(round_ids, f1_scores, 'g-o')
    f1_ax.set(xlabel='FL Round', ylabel='F1', title='Federated GNN - F1 per Round')

    if comm:
        comm_mb = [n_bytes / 1e6 for n_bytes in comm]
        comm_ax.bar(range(1, len(comm_mb) + 1), comm_mb)
        comm_ax.set(xlabel='Round', ylabel='MB', title='Communication Cost per Round')

    fig.tight_layout()
    fig.savefig(FIGURES_DIR / 'federated_convergence.png', dpi=150)
    plt.show()

In [None]:
# Side-by-side model comparison table
# `results` maps model name -> parsed metrics JSON, containing only the models
# whose '<name>_metrics.json' file exists in METRICS_DIR.
results = {}
for name in ['rf', 'mlp', 'central_gnn', 'federated_gnn']:
    p = METRICS_DIR / f'{name}_metrics.json'
    if p.exists():
        # Context manager closes each file as soon as it is parsed; the bare
        # json.load(open(p)) form leaked one handle per model.
        with open(p, encoding="utf-8") as metrics_file:
            results[name] = json.load(metrics_file)

if results:
    print(f"{'Model':<22} {'Precision':>10} {'Recall':>10} {'F1':>10} {'ROC-AUC':>10} {'Inf(ms)':>10}")
    print('-' * 75)
    for name, m in results.items():
        # Metrics missing from a file are shown as 0.
        print(f"{name:<22} {m.get('precision',0):>10.4f} {m.get('recall',0):>10.4f} "
              f"{m.get('f1',0):>10.4f} {m.get('roc_auc',0):>10.4f} {m.get('inference_ms',0):>10.2f}")
else:
    print('No metrics found yet.')

In [None]:
# Bar chart comparison
# Grouped bars: one group per loaded model, one bar per metric.
if results:
    models = list(results)
    metrics_to_plot = ['precision', 'recall', 'f1', 'roc_auc']
    x = np.arange(len(models))
    width = 0.2

    fig, ax = plt.subplots(figsize=(8, 5))
    for slot, metric in enumerate(metrics_to_plot):
        # Metrics missing from a model's file are drawn as 0.
        heights = [results[model].get(metric, 0) for model in models]
        ax.bar(x + slot * width, heights, width, label=metric.upper())

    # Centre each tick under its group of four bars.
    ax.set_xticks(x + width*1.5)
    tick_labels = [model.replace('_', ' ').title() for model in models]
    ax.set_xticklabels(tick_labels, rotation=15)
    ax.set_ylim(0, 1.05)
    ax.legend()
    ax.set_title('Model Comparison')

    fig.tight_layout()
    fig.savefig(FIGURES_DIR / 'model_comparison_bar.png', dpi=150)
    plt.show()