In [2]:

import os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

BASE = Path.cwd().resolve().parent / "runs"
print("Looking for runs in:", BASE)

Looking for runs in: /Users/yuandouwang/Documents/projects/federated_lab/runs


In [3]:
def load_all_runs(runs_root: Path) -> pd.DataFrame:
    runs_root = runs_root if runs_root.name == "runs" else (runs_root / "runs")
    rows = []
    if not runs_root.exists():
        print("No runs folder found:", runs_root)
        return pd.DataFrame()
    for exp_dir in sorted(runs_root.glob("*")):
        if not exp_dir.is_dir():
            continue
        for ts_dir in sorted(exp_dir.glob("*")):
            csv_path = ts_dir / "fl_log.csv"
            if csv_path.exists():
                try:
                    df = pd.read_csv(csv_path)
                    df["exp"] = exp_dir.name
                    df["ts"] = ts_dir.name
                    rows.append(df)
                except Exception as e:
                    print("Failed to read", csv_path, e)
    if not rows:
        return pd.DataFrame()
    return pd.concat(rows, ignore_index=True)

df = load_all_runs(BASE)
print("Loaded rows:", len(df))
df.head(3)


Loaded rows: 170


Unnamed: 0,round,global_acc,manifest_cid,node_id,samples,loss,acc,claimed_acc,eval_acc,acc_diff,...,balance,stake_penalty,stake,reputation,malicious_detected,committee,is_malicious,strategy,exp,ts
0,0,0.7703,Qm00000022,0,7682,0.353403,0.91415,0.9141,0.6532,0.2609,...,74.568,0.0,174.568,19.5688,1,1,0,none,benign,20250813-223511
1,0,0.7703,Qm00000022,1,6725,0.203163,0.945145,0.9451,0.3488,0.5963,...,39.7479,0.0,139.7479,19.0325,1,1,0,none,benign,20250813-223511
2,0,0.7703,Qm00000022,2,1321,0.341979,0.912869,0.9129,0.3931,0.5198,...,26.0427,0.0,126.0427,18.4381,1,1,0,none,benign,20250813-223511


In [5]:
# Federated evaluation across rounds & nodes using base + delta reconstruction.
# - Auto-discovers files like:
#     models/global_round_{r}_base.pt (or fallback: global_round_{r}.pt)
#     updates/round_{r}_node_{i}_delta.pt
# - Reconstructs each node's local model: base + delta
# - Evaluates accuracy on a test loader (replace with your real test loader)
# - Summarizes results in a DataFrame and saves CSV
#
# If no artifacts are found at the default path, this will run a synthetic demo
# so you can see the full pipeline and results format.
#
# === HOW TO ADAPT ===
# 1) Set RUN_DIR to your run folder containing "models" and "updates".
# 2) Replace MODEL_FACTORY() with your actual model constructor.
# 3) Replace build_test_loader() with your real test DataLoader.
# 4) If your .pt files are full models (not state_dict), adjust load logic.
#
# Files will be saved to /mnt/data, including results CSV.

import os
import re
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
from typing import Dict, Any

# -----------------------------
# Config: paths
# -----------------------------
# Try to auto-detect a likely run dir under /mnt/data/runs/benign/*

CANDIDATE = None
runs_root = BASE / "runs" / "benign"
if runs_root.exists():
    # pick the most recently modified child that has both models/ and updates/
    candidates = []
    for p in runs_root.glob("*"):
        if p.is_dir() and (p / "models").exists() and (p / "updates").exists():
            candidates.append((p.stat().st_mtime, p))
    if candidates:
        CANDIDATE = sorted(candidates, key=lambda x: x[0], reverse=True)[0][1]

# Allow the user to change this path if needed:
RUN_DIR = CANDIDATE if CANDIDATE else (BASE / "benign" / "20250813-224933")
MODELS_DIR = RUN_DIR / "models"
UPDATES_DIR = RUN_DIR / "updates"

# -----------------------------
# Config: model + test loader
# -----------------------------
NUM_CLASSES = 10
IMAGE_SHAPE = (1, 28, 28)

class SimpleNet(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(IMAGE_SHAPE[0]*IMAGE_SHAPE[1]*IMAGE_SHAPE[2], 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def MODEL_FACTORY():
    # TODO: replace with your real model constructor
    return SimpleNet(num_classes=NUM_CLASSES)

def build_test_loader():
    # TODO: replace with your real test dataset/loader
    X_test = torch.randn(256, *IMAGE_SHAPE)
    y_test = torch.randint(0, NUM_CLASSES, (256,))
    return DataLoader(TensorDataset(X_test, y_test), batch_size=64)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Helpers
# -----------------------------
def reconstruct_full_state(base_state: Dict[str, torch.Tensor],
                           delta_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    full = {}
    for k, base_tensor in base_state.items():
        if k in delta_state:
            full[k] = base_tensor + delta_state[k]
        else:
            full[k] = base_tensor
    for k, v in delta_state.items():
        if k not in full:
            full[k] = v
    return full

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> Dict[str, Any]:
    model.eval()
    total, correct, total_loss = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.numel()
        total_loss += loss.item() * yb.size(0)
    acc = correct / total if total > 0 else 0.0
    avg_loss = total_loss / total if total > 0 else 0.0
    return {"acc": acc, "loss": avg_loss, "n": total}

def find_rounds_nodes(updates_dir: Path):
    # pattern: round_{r}_node_{i}_delta.pt
    pat = re.compile(r"round_(\d+)_node_(\d+)_delta\.pt$")
    mapping = {}
    if not updates_dir.exists():
        return mapping
    for f in updates_dir.glob("*.pt"):
        m = pat.search(f.name)
        if m:
            r = int(m.group(1))
            i = int(m.group(2))
            mapping.setdefault(r, []).append((i, f))
    # sort nodes by id per round
    for r in mapping:
        mapping[r] = sorted(mapping[r], key=lambda x: x[0])
    return dict(sorted(mapping.items(), key=lambda x: x[0]))

def load_base_for_round(models_dir: Path, r: int):
    # prefer global_round_{r}_base.pt else fallback to global_round_{r}.pt
    base1 = models_dir / f"global_round_{r-1}_base.pt"
    base2 = models_dir / f"global_round_{r-1}.pt"
    if base1.exists():
        return torch.load(base1, map_location="cpu")
    if base2.exists():
        return torch.load(base2, map_location="cpu")
    return None

# -----------------------------
# Main evaluation
# -----------------------------
test_loader = build_test_loader()

results = []
artifact_found = MODELS_DIR.exists() and UPDATES_DIR.exists()
if artifact_found:
    rounds = find_rounds_nodes(UPDATES_DIR)
    if not rounds:
        print(f"[Info] No update deltas found under {UPDATES_DIR}. Running synthetic demo instead.")
        artifact_found = False

if artifact_found:
    print(f"[Using artifacts from] {RUN_DIR}")
    for r, node_list in rounds.items():
        base_state = load_base_for_round(MODELS_DIR, r)
        if base_state is None:
            print(f"[Warn] No base model found for round {r} in {MODELS_DIR}, skipping this round.")
            continue
        for node_id, delta_path in node_list:
            try:
                delta_state = torch.load(delta_path, map_location="cpu")
            except Exception as e:
                print(f"[Error] Could not load delta {delta_path}: {e}")
                continue

            # reconstruct
            reconstructed_state = reconstruct_full_state(base_state, delta_state)
            model = MODEL_FACTORY().to(device)
            try:
                missing, unexpected = model.load_state_dict(reconstructed_state, strict=False)
            except Exception as e:
                print(f"[Error] load_state_dict failed for round {r}, node {node_id}: {e}")
                continue

            # evaluate
            metrics = evaluate(model, test_loader)
            results.append({
                "round": r,
                "node": node_id,
                "acc": metrics["acc"],
                "loss": metrics["loss"],
                "n": metrics["n"],
                "missing_keys": len(missing),
                "unexpected_keys": len(unexpected),
                "delta_file": str(delta_path.name)
            })
else:
    # Synthetic demo so the pipeline is visible and results are produced
    print("[Demo] No real artifacts detected. Running a synthetic demonstration.")
    # Create a base model
    base_model = MODEL_FACTORY().to(device)
    base_state = base_model.state_dict()
    # Simulate 2 rounds × 3 nodes with small random deltas
    for r in range(2):
        for node_id in range(3):
            delta_state = {k: torch.randn_like(v) * 0.01 for k, v in base_state.items()}
            reconstructed_state = reconstruct_full_state(base_state, delta_state)
            model = MODEL_FACTORY().to(device)
            model.load_state_dict(reconstructed_state, strict=False)
            metrics = evaluate(model, test_loader)
            results.append({
                "round": r,
                "node": node_id,
                "acc": metrics["acc"],
                "loss": metrics["loss"],
                "n": metrics["n"],
                "missing_keys": 0,
                "unexpected_keys": 0,
                "delta_file": f"synthetic_round_{r}_node_{node_id}_delta"
            })

# -----------------------------
# Save & show results
# -----------------------------
df = pd.DataFrame(results).sort_values(["round", "node"]).reset_index(drop=True)
print(df)
# out_csv = "/mnt/data/per_node_accuracy_by_round.csv"
# df.to_csv(out_csv, index=False)

# Display table to the user in the UI
# %pip install caas_jupyter_tools

# from caas_jupyter_tools import display_dataframe_to_user
# display_dataframe_to_user("Per-node accuracy by round", df)

# print(f"\nSaved results to: {out_csv}")
# print(f"RUN_DIR used: {RUN_DIR}")


[Using artifacts from] /Users/yuandouwang/Documents/projects/federated_lab/runs/benign/20250813-224933
[Warn] No base model found for round 0 in /Users/yuandouwang/Documents/projects/federated_lab/runs/benign/20250813-224933/models, skipping this round.
    round  node       acc      loss    n  missing_keys  unexpected_keys  \
0       1     0  0.109375  2.308768  256             4                2   
1       1     1  0.125000  2.293167  256             4                2   
2       1     2  0.089844  2.342567  256             4                2   
3       1     3  0.109375  2.336424  256             4                2   
4       1     4  0.097656  2.310634  256             4                2   
5       1     5  0.121094  2.339877  256             4                2   
6       1     6  0.082031  2.341358  256             4                2   
7       1     7  0.101562  2.321245  256             4                2   
8       1     8  0.101562  2.324399  256             4                2