In [4]:
import os
import json
import numpy as np
from sklearn.metrics import balanced_accuracy_score

def choose_best_threshold(labels, scores):
    """Pick the score threshold that maximizes balanced accuracy.

    Candidate thresholds are the 0th, 0.2th, ..., 99.8th percentiles of
    ``scores``. A sample is predicted as 1 when its score is strictly greater
    than the threshold and as 0 otherwise.

    Args:
        labels: Ground-truth labels, one per score.
        scores: Numeric scores aligned with ``labels``.

    Returns:
        A ``(best_threshold, best_balanced_accuracy)`` tuple. When several
        candidates reach the same balanced accuracy, the largest one wins.

    Raises:
        ValueError: If ``scores`` is empty.
    """
    score_arr = np.asarray(scores)
    if score_arr.size == 0:
        raise ValueError("scores must not be empty")

    # A single vectorized percentile call replaces one call per candidate.
    # np.unique drops repeated candidates (frequent when scores are tied) and
    # returns them in ascending order, which is the order percentiles already
    # had, so the tie-breaking below is unaffected.
    thresholds = np.unique(np.percentile(score_arr, np.arange(0, 100, 0.2)))

    best_bacc = 0.0
    best_thresh = 0.0
    for thresh in thresholds:
        preds = (score_arr > thresh).astype(int)
        bacc = balanced_accuracy_score(labels, preds)
        # ">=" (not ">") so ties resolve to the later, i.e. larger, threshold.
        if bacc >= best_bacc:
            best_bacc = bacc
            best_thresh = thresh

    return float(best_thresh), float(best_bacc)

# Directory holding the result files to evaluate, and the report to write.
input_dir = "../results/aggre_xsum"
output_file = "xsum_FENICE_tuning_method.txt"

result_entries = []

# Process every JSON file. The listing is sorted because os.listdir returns
# entries in arbitrary order, which would make the order of tied entries in
# the report vary from run to run.
for filename in sorted(os.listdir(input_dir)):
    if not filename.endswith(".json"):
        continue

    file_path = os.path.join(input_dir, filename)
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as err:
        # Include the parser message (line/column) so broken files can be
        # diagnosed instead of just being listed.
        print(f"⚠️ JSON decode error: {filename} ({err})")
        continue

    # Split the records into validation and test sets in a single pass.
    val_scores, val_labels = [], []
    test_scores, test_labels = [], []
    try:
        for item in data:
            cut = item.get("cut")
            if cut == "val":
                val_scores.append(item["score"])
                val_labels.append(item["label"])
            elif cut == "test":
                test_scores.append(item["score"])
                test_labels.append(item["label"])
    except (AttributeError, KeyError, TypeError) as err:
        # One malformed file should not abort the run and lose every result
        # collected so far; report it and move on, like undecodable files.
        print(f"⚠️ Skipped due to malformed records in: {filename} ({err!r})")
        continue

    if not val_scores or not test_scores:
        print(f"⚠️ Skipped due to missing val/test in: {filename}")
        continue

    # Tune the threshold on the validation split only, then apply it
    # unchanged to the test split.
    threshold, val_bacc = choose_best_threshold(val_labels, val_scores)
    test_preds = [1 if s > threshold else 0 for s in test_scores]
    test_bacc = balanced_accuracy_score(test_labels, test_preds)

    result_entries.append({
        "filename": filename,
        "threshold": threshold,
        "val_bacc": val_bacc,
        "test_bacc": test_bacc,
    })

# Sort by test balanced accuracy, descending.
# NOTE(review): if this ranking is used to choose a configuration, that is
# selection on the test set; rank by "val_bacc" for model selection - confirm
# the intended use of this report.
result_entries.sort(key=lambda x: x["test_bacc"], reverse=True)

# Write the report.
with open(output_file, "w", encoding="utf-8") as fout:
    for entry in result_entries:
        fout.write(f"{entry['filename']}\n")
        fout.write(f"  Best Threshold (val): {entry['threshold']:.4f}\n")
        fout.write(f"  Val Balanced Accuracy: {entry['val_bacc']:.4f}\n")
        fout.write(f"  Test Balanced Accuracy: {entry['test_bacc']:.4f}\n\n")

print(f"✅ 결과가 '{output_file}'에 test_bacc 기준 내림차순으로 저장되었습니다.")


⚠️ JSON decode error: fenice_wr0p3_wb0p7_wcc1_wc1_wm0_ww1_k1.json
⚠️ JSON decode error: fenice_wr0p7_wb0p3_wcc1_wc1_wm0_ww1_k1.json
⚠️ JSON decode error: fenice_wr0p5_wb0p5_wcc1_wc1_wm0_ww1_k2.json
⚠️ JSON decode error: fenice_wr1_wb0_wcc1_wc1_wm0_ww1_k1.json
⚠️ JSON decode error: fenice_wr0p5_wb0p5_wcc1_wc1_wm0_ww1_k3.json
✅ 결과가 'xsum_FENICE_tuning_method.txt'에 test_bacc 기준 내림차순으로 저장되었습니다.
