# LOSO 실험 비교 분석

여러 체크포인트(`loso_*`)에서 생성된 LOSO 결과를 한 번에 불러와 성능을 비교합니다. 필요 시 `RUNS` 설정을 수정해 분석 대상을 조정하세요.


In [None]:
from pathlib import Path
import json

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display

# Seaborn theme applied to every figure drawn after this point.
sns.set_theme(style="whitegrid", context="notebook")

# The project root is the working directory if it holds a "checkpoints"
# folder, otherwise its parent; anything else is treated as a setup error.
NOTEBOOK_DIR = Path.cwd()
_root_candidates = [
    base
    for base in (NOTEBOOK_DIR, NOTEBOOK_DIR.parent)
    if (base / "checkpoints").exists()
]
if not _root_candidates:
    raise FileNotFoundError("checkpoints 디렉터리를 찾을 수 없습니다. 노트북 위치를 확인해 주세요.")
PROJECT_ROOT = _root_candidates[0]

CHECKPOINT_BASE = PROJECT_ROOT / "checkpoints"

# Class index -> display name used for table rows and plot labels.
CLASS_LABELS = dict(
    enumerate(
        [
            "Correct",
            "Knee Valgus",
            "Butt Wink",
            "Excessive Lean",
            "Partial Squat",
        ]
    )
)

# Experiments to compare. Each entry's checkpoint folder is named after its
# id; the list position later becomes the run's display order.
RUNS = [
    {"id": run_id, "name": label, "path": CHECKPOINT_BASE / run_id}
    for run_id, label in (
        ("loso_nossl", "No-SSL"),
        ("loso_sc", "SimCLR"),
        ("loso_ss", "SimSiam"),
        ("loso_ss_focalyes", "SimSiam+Focal"),
    )
]

print(f"프로젝트 루트: {PROJECT_ROOT}")
print("분석 대상:")
for run in RUNS:
    print(f"  - {run['name']}: {run['path']}")


In [None]:
def find_report_file(subject_dir: Path):
    """Return the classification report stored directly in ``subject_dir``.

    Only files matching ``classification_report*.json`` at the top level of
    the directory are considered. When several match, the one whose path
    sorts first is returned; when none match, the result is ``None``.
    """
    return min(subject_dir.glob("classification_report*.json"), default=None)

# Accumulators filled while walking every run and its subject folders.
fold_rows = []          # one dict per (run, subject): overall metrics
class_rows = []         # one dict per (run, subject, class): per-class metrics
prediction_frames = []  # one DataFrame per sample_predictions.csv found
missing_reports = []    # (run name, subject) pairs that have no report file


def _subject_sort_key(path):
    """Sort key: the digits in the folder name read as one int (0 if none)."""
    digits = "".join(ch for ch in path.name if ch.isdigit())
    return int(digits) if digits else 0


for order, run in enumerate(RUNS):
    run_dir = run["path"]
    if not run_dir.exists():
        print(f"[경고] {run['name']} 경로가 없어 건너뜁니다: {run_dir}")
        continue

    # Only sub-folders named "subject..." count as folds.
    subject_dirs = [
        child
        for child in run_dir.iterdir()
        if child.is_dir() and child.name.startswith("subject")
    ]
    subject_dirs.sort(key=_subject_sort_key)

    for subject_dir in subject_dirs:
        subject_id = subject_dir.name
        report_path = find_report_file(subject_dir)
        if report_path is None:
            # Remember the gap and report it once everything is loaded.
            missing_reports.append((run["name"], subject_id))
            continue

        report = json.loads(report_path.read_text())

        # Identification columns shared by the fold row and every class row.
        tags = {
            "run_id": run["id"],
            "run_name": run["name"],
            "run_order": order,
            "subject": subject_id,
        }

        macro = report.get("macro avg", {})
        macro_recall = macro.get("recall", np.nan)
        fold_rows.append(
            {
                **tags,
                "Accuracy": report.get("accuracy", np.nan),
                "Macro Precision": macro.get("precision", np.nan),
                "Macro Recall": macro_recall,
                "Macro F1": macro.get("f1-score", np.nan),
                # Stored under a second name: same value as macro recall.
                "Balanced Accuracy": macro_recall,
            }
        )

        # Per-class entries are the report keys made only of digits; the
        # remaining keys (aggregates such as "macro avg") are skipped.
        per_class = (
            (int(key), stats) for key, stats in report.items() if key.isdigit()
        )
        class_rows.extend(
            {
                **tags,
                "class_id": cls_id,
                "class_name": CLASS_LABELS.get(cls_id, f"Class {cls_id}"),
                "precision": stats.get("precision", np.nan),
                "recall": stats.get("recall", np.nan),
                "f1": stats.get("f1-score", np.nan),
                "support": stats.get("support", np.nan),
            }
            for cls_id, stats in per_class
        )

        # Sample-level predictions are optional; tag them when present.
        prediction_path = subject_dir / "sample_predictions.csv"
        if prediction_path.exists():
            frame = pd.read_csv(prediction_path)
            for column, value in (
                ("run_name", run["name"]),
                ("run_id", run["id"]),
                ("run_order", order),
                ("subject", subject_id),
            ):
                frame[column] = value
            prediction_frames.append(frame)

fold_df = pd.DataFrame(fold_rows)
class_df = pd.DataFrame(class_rows)
if prediction_frames:
    pred_df = pd.concat(prediction_frames, ignore_index=True)
else:
    pred_df = pd.DataFrame()

print(f"총 fold 레코드 수: {len(fold_df)}")
print(f"총 클래스 레코드 수: {len(class_df)}")
print(f"총 예측 샘플 수: {len(pred_df)}")
if missing_reports:
    print("[경고] 누락된 리포트:")
    for run_name, subject in missing_reports:
        print(f"  - {run_name}: {subject}")


## 1. 실험별 성능 요약
각 지표별로 실험 평균/표준편차/최고/최저를 비교합니다.


In [None]:
if fold_df.empty:
    raise RuntimeError("fold_df가 비어 있습니다. 상단 설정을 확인하세요.")

# fold_df columns to summarise; each must already exist in fold_df.
metrics = ["Accuracy", "Macro F1", "Balanced Accuracy"]

summary_rows = []
for metric in metrics:
    # run_order is the leading group key, so groupby (which sorts its keys)
    # yields the runs in the order they were listed in RUNS.
    grouped = fold_df.groupby(["run_order", "run_name"])[metric]
    for (_, run_name), values in grouped:
        mean = values.mean()
        # Sample standard deviation; NaN when a run has a single fold.
        std = values.std(ddof=1)
        summary_rows.append(
            {
                "Run": run_name,
                "Metric": metric,
                "Mean": mean,
                "Std": std,
                "Min": values.min(),
                "Max": values.max(),
                "Mean ± Std": f"{mean:.4f} ± {std:.4f}",
            }
        )

# Stable sort on Metric alone: metrics are grouped alphabetically while the
# rows inside each metric keep the RUNS order they were appended in, rather
# than being re-sorted by run name.
run_summary = pd.DataFrame(summary_rows).sort_values("Metric", kind="mergesort")
display(run_summary.style.format({"Mean": "{:.3f}", "Std": "{:.3f}", "Min": "{:.3f}", "Max": "{:.3f}"}))


## 2. 피험자별 Macro F1 비교
각 실험의 피험자별 Macro F1을 테이블과 히트맵으로 확인합니다.


In [None]:
def _subject_number(name):
    """Return the digits embedded in a subject name as an int (0 if none)."""
    digits = "".join(filter(str.isdigit, str(name)))
    return int(digits) if digits else 0


# One row per subject, one column per run, cells holding that fold's Macro F1.
macro_f1_pivot = fold_df.pivot(index="subject", columns="run_name", values="Macro F1")

# Order subjects by their number so that e.g. "subject2" precedes
# "subject10"; a plain lexicographic sort would place them the other way.
subject_order = sorted(macro_f1_pivot.index, key=lambda s: (_subject_number(s), s))
# Order run columns by run_order instead of alphabetically by name.
run_columns = list(
    dict.fromkeys(fold_df.sort_values("run_order", kind="mergesort")["run_name"])
)
macro_f1_pivot = macro_f1_pivot.reindex(index=subject_order, columns=run_columns)

display(macro_f1_pivot.style.format("{:.3f}"))

# Figure height grows with the number of subjects so annotations stay legible.
plt.figure(figsize=(10, 4 + 0.5 * len(macro_f1_pivot)))
sns.heatmap(
    macro_f1_pivot,
    annot=True,
    fmt=".2f",
    cmap="YlGnBu",
    vmin=0,
    vmax=1,
    cbar_kws={"label": "Macro F1"},
)
plt.title("Subject vs. Run Macro F1")
plt.ylabel("Subject")
plt.xlabel("Run")
plt.tight_layout()
plt.show()


## 3. 실험별 Macro F1 분포
실험별로 피험자(fold)별 Macro F1이 어떻게 분포하는지 박스플롯으로 비교합니다. 각 점은 개별 피험자의 값입니다.


In [None]:
# One box per run over its per-subject Macro F1 values, with every fold
# overlaid as a black dot.
# NOTE(review): newer seaborn releases may warn when `palette` is given
# without `hue` — confirm against the installed version.
fig, ax = plt.subplots(figsize=(8, 4))
sns.boxplot(data=fold_df, x="run_name", y="Macro F1", palette="Set2", ax=ax)
sns.swarmplot(
    data=fold_df, x="run_name", y="Macro F1", color="black", alpha=0.6, ax=ax
)
ax.set_ylim(0, 1)
ax.set_xlabel("Run")
ax.set_ylabel("Macro F1")
ax.set_title("Macro F1 distribution per Run")
fig.tight_layout()
plt.show()


## 4. 클래스별 Precision / Recall / F1 비교


In [None]:
# Average each per-class metric over the folds of every (run, class) pair.
score_columns = ["precision", "recall", "f1"]
group_keys = ["run_order", "run_name", "class_name"]
fold_means = class_df.groupby(group_keys)[score_columns].mean()

# Rows are grouped by class, with runs in run_order inside each class.
class_avg = fold_means.reset_index().sort_values(["class_name", "run_order"])
display(class_avg.style.format({column: "{:.3f}" for column in score_columns}))

# Grouped bars: one cluster per class, one bar per run.
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(
    data=class_avg, x="class_name", y="f1", hue="run_name", palette="viridis", ax=ax
)
ax.set_ylim(0, 1)
ax.set_ylabel("F1-Score")
ax.set_xlabel("Class")
ax.set_title("Class-wise F1 across Runs")
ax.legend(title="Run")
fig.tight_layout()
plt.show()


## 5. Confusion Matrix 비교 (모든 fold 예측 포함)
각 실험의 `sample_predictions.csv`를 합쳐 혼동행렬을 생성합니다.


In [None]:
if pred_df.empty:
    print("sample_predictions.csv가 없어 혼동행렬을 생성할 수 없습니다.")
else:
    pred_df = pred_df.copy()
    pred_df["true_idx"] = pred_df["true_idx"].astype(int)
    pred_df["pred_idx"] = pred_df["pred_idx"].astype(int)
    class_ids = sorted(CLASS_LABELS.keys())
    class_names = [CLASS_LABELS[i] for i in class_ids]

    # Group on run_order first so the panels follow the RUNS order rather
    # than the alphabetical order of the run names.
    run_groups = [
        (run_name, group)
        for (_, run_name), group in pred_df.groupby(["run_order", "run_name"])
    ]
    n_runs = len(run_groups)

    # At most two panels per row.
    fig_cols = min(2, n_runs)
    fig_rows = int(np.ceil(n_runs / fig_cols))
    # squeeze=False always returns a 2-D array of Axes, even for one panel.
    fig, axes = plt.subplots(
        fig_rows, fig_cols, figsize=(6 * fig_cols, 5 * fig_rows), squeeze=False
    )
    axes = axes.flatten()

    # Hide the panels of the grid that no run will occupy.
    for ax in axes[n_runs:]:
        ax.axis("off")

    for ax, (run_name, group) in zip(axes, run_groups):
        # Count (true, predicted) pairs pooled over every fold of the run.
        # The reindex fixes the matrix to the CLASS_LABELS classes: missing
        # classes become zero rows/columns, and any index that is not in
        # CLASS_LABELS is dropped from the matrix.
        cm = (
            pd.crosstab(group["true_idx"], group["pred_idx"])
            .reindex(index=class_ids, columns=class_ids, fill_value=0)
        )
        cm.index = class_names
        cm.columns = class_names
        sns.heatmap(cm, annot=True, fmt="d", cmap="rocket_r", cbar=False, ax=ax)
        ax.set_title(run_name)
        ax.set_ylabel("True")
        ax.set_xlabel("Pred")

    fig.suptitle("Confusion Matrices by Run", y=1.02, fontsize=14)
    plt.tight_layout()
    plt.show()


## 6. 추가 아이디어
- 상단 `RUNS` 리스트에 새로운 체크포인트 경로를 추가하면 자동으로 분석됩니다.
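- 다른 지표(예: Weighted F1)를 비교하려면 `metrics` 리스트만 수정해서는 안 됩니다. 먼저 리포트 로딩 셀에서 해당 값(리포트에 `weighted avg` 항목이 있는 경우 그 `f1-score`)을 `fold_df`의 컬럼으로 추가한 뒤, 그 컬럼명을 `metrics` 리스트에 넣으세요.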
