In [2]:
# summarize_rounds.py
import csv
import json
import math
import os
import re
import sys
from collections import defaultdict

INPUT = "./eval_results.raw"                # JSON-lines evaluation log to summarize
OUTPUT = "mean_test_acc_per_round.csv"      # per-round summary CSV written at the end

# Abort early when the input log is missing ("the file '<INPUT>' does not exist
# in the current directory"). isfile rather than exists, so that a directory
# with this name is rejected here instead of failing later inside open().
# The error goes to stderr so it is not mixed with the summary on stdout.
if not os.path.isfile(INPUT):
    print(f"[ERR] 현재 경로에 '{INPUT}' 파일이 없습니다.", file=sys.stderr)
    raise SystemExit(1)

def parse_client_id_from_role(role):
    """Extract the numeric client id from a role label such as ``"Client #3"``.

    Returns ``None`` when ``role`` is ``None``. Returns the parsed ``int`` when
    ``str(role)`` contains ``Client`` followed by ``#`` and digits (whitespace
    allowed on either side of the ``#``). Otherwise returns ``role`` itself,
    unchanged, so it can still act as a grouping key.

    NOTE(review): a role that does not match (e.g. a non-client entry) becomes
    its own client key in the aggregation; confirm that is intended.
    """
    if role is None:
        return None
    match = re.search(r"Client\s*#\s*(\d+)", str(role))
    if match is None:
        return role
    return int(match.group(1))

def _nan_to_none(value):
    """Map NaN to None so it is ranked like a missing value; pass others through."""
    if value is None or math.isnan(value):
        return None
    return value


def _acc_greater(new_acc, old_acc):
    """Return True if ``new_acc`` beats ``old_acc``, ranking NaN below any number."""
    if math.isnan(old_acc):
        return not math.isnan(new_acc)
    return new_acc > old_acc  # False when new_acc is NaN


def is_better(new, old, tol=1e-12):
    """Return True if candidate ``new`` should replace the incumbent ``old``.

    Both arguments are dicts with ``"val_loss"`` (float or None) and
    ``"test_acc"`` (float) keys.

    Ranking rules:
    - A lower ``val_loss`` wins.
    - Losses within ``tol`` of each other are a tie, broken by the higher
      ``test_acc``.
    - A missing loss always loses to a present one.
    - If both losses are missing, only ``test_acc`` is compared.

    NaN is treated as missing. NaN compares False against every number, so
    without this an incumbent holding a NaN loss or accuracy could never be
    replaced.

    Args:
        new: Candidate record.
        old: Currently selected record.
        tol: Absolute tolerance under which two losses count as equal.

    Returns:
        True if ``new`` ranks strictly better than ``old``.
    """
    new_loss = _nan_to_none(new["val_loss"])
    old_loss = _nan_to_none(old["val_loss"])

    if new_loss is None and old_loss is None:
        return _acc_greater(new["test_acc"], old["test_acc"])
    if new_loss is None:
        return False
    if old_loss is None:
        return True
    if new_loss < old_loss - tol:
        return True
    if abs(new_loss - old_loss) <= tol:
        # Effectively equal losses: prefer the higher test accuracy.
        return _acc_greater(new["test_acc"], old["test_acc"])
    return False

def _extract_candidate(rec):
    """Pull ``(round, client_id, candidate)`` out of one parsed log record.

    ``candidate`` is a dict with ``"val_loss"`` (float or None) and
    ``"test_acc"`` (float), the shape ``is_better`` compares.

    Sources, in order of preference:
    - round: the top-level ``"Round"`` key (also when it is present but null),
      then ``Results_raw["round"]``.
    - client id: ``Results_raw["client_id"]``, then the id parsed from
      ``"Role"``.

    Returns ``None`` so the caller can skip the record when:
    - the record is not a JSON object;
    - it lacks a round, client id or test_acc;
    - a value cannot be converted to a number (e.g. a non-numeric ``"Round"``
      label).

    Returning ``None`` keeps ``int()`` / ``float()`` from aborting the whole run.
    """
    if not isinstance(rec, dict):
        return None
    rr = rec.get("Results_raw")
    if not isinstance(rr, dict):
        rr = {}

    rnd = rec.get("Round")
    if rnd is None:
        rnd = rr.get("round")
    cid = rr.get("client_id")
    if cid is None:
        cid = parse_client_id_from_role(rec.get("Role"))
    test_acc = rr.get("test_acc")
    val_loss = rr.get("val_loss")

    if rnd is None or cid is None or test_acc is None:
        return None
    try:
        rnd = int(rnd)
        cand = {
            "val_loss": float(val_loss) if val_loss is not None else None,
            "test_acc": float(test_acc),
        }
    except (TypeError, ValueError):
        return None
    return rnd, cid, cand


best_per_round_client = {}   # key=(round, client_id) -> best candidate per is_better
rounds_seen = set()
n_skipped = 0                # non-blank lines that could not be used

with open(INPUT, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            rec = json.loads(line)
        except json.JSONDecodeError:
            # Truncated / differently formatted lines are skipped (and counted).
            n_skipped += 1
            continue

        parsed = _extract_candidate(rec)
        if parsed is None:
            n_skipped += 1
            continue
        rnd, cid, cand = parsed

        # Keep only one record per (round, client): the best one by is_better.
        key = (rnd, cid)
        cur = best_per_round_client.get(key)
        if cur is None or is_better(cand, cur):
            best_per_round_client[key] = cand
        rounds_seen.add(rnd)

if n_skipped:
    print(f"[WARN] skipped {n_skipped} unparsable or incomplete line(s) in '{INPUT}'")

# Mean test_acc per round, over each client's selected record.
round_to_accs = defaultdict(list)
for (rnd, _cid), info in best_per_round_client.items():
    round_to_accs[rnd].append(info["test_acc"])

# Every list in round_to_accs has at least one entry, so the division is safe.
rows = [
    {"round": rnd, "n_clients": len(accs), "mean_test_acc": sum(accs) / len(accs)}
    for rnd, accs in sorted(round_to_accs.items())
]

print("Rounds found:", sorted(rounds_seen))
for r in rows:
    print(f"Round {r['round']:>4}: n={r['n_clients']:>3}  mean_test_acc={r['mean_test_acc']:.6f}")

with open(OUTPUT, "w", newline="", encoding="utf-8") as wf:
    writer = csv.DictWriter(wf, fieldnames=["round", "n_clients", "mean_test_acc"])
    writer.writeheader()
    writer.writerows(rows)

print(f"Saved: {OUTPUT}")


Rounds found: [24, 49, 74, 99, 124, 149, 174, 199, 224, 249]
Round   24: n= 53  mean_test_acc=0.595283
Round   49: n= 53  mean_test_acc=0.618868
Round   74: n= 53  mean_test_acc=0.621698
Round   99: n= 53  mean_test_acc=0.628774
Round  124: n= 53  mean_test_acc=0.639151
Round  149: n= 53  mean_test_acc=0.655660
Round  174: n= 53  mean_test_acc=0.640566
Round  199: n= 53  mean_test_acc=0.631132
Round  224: n= 53  mean_test_acc=0.642925
Round  249: n= 53  mean_test_acc=0.633019
Saved: mean_test_acc_per_round.csv
