In [None]:
from pathlib import Path

# Root directory holding the per-experiment evaluation JSON files. The path is
# relative, so it is resolved against the current working directory.
RESULTS_ROOT = Path("../results")
# Every use below joins sub-paths onto RESULTS_ROOT, so it must be a directory
# (a plain file named "results" would pass `exists()` but break later).
assert RESULTS_ROOT.is_dir(), f"results directory not found: {RESULTS_ROOT.resolve()}"

In [None]:
# Helpers: recompute F1/MCC from per-sample results (decontextual by default) and format LaTeX table rows.
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef

def recompute_scores(results, task="decontextual"):
    """Recompute F1 and MCC for one evaluation setting from per-sample results.

    For each sample, the reference label is whether ``logp_comparator`` is at
    least ``logp_target``; the prediction is whether ``score_comparator`` is at
    least ``score_target``. Both comparisons read from the same setting
    (``sample[task]``).

    Args:
        results: Parsed results JSON. Must contain a ``"samples"`` list whose
            items map ``task`` to a dict with ``score_target``,
            ``score_comparator``, ``logp_target`` and ``logp_comparator``.
        task: Which setting to score, e.g. ``"decontextual"`` or
            ``"contextual"``.

    Returns:
        ``(f1, mcc)`` as computed by scikit-learn's ``f1_score`` and
        ``matthews_corrcoef``.
    """
    y_true, y_pred = [], []
    for sample in results["samples"]:
        # Take target and comparator from the *same* setting; previously the
        # target values were always read from "decontextual", so any other
        # `task` compared mismatched quantities.
        setting = sample[task]
        y_pred.append(setting["score_comparator"] >= setting["score_target"])
        y_true.append(setting["logp_comparator"] >= setting["logp_target"])
    f1 = f1_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)
    return f1, mcc

def format_row(scores):
    """Render one LaTeX table row (without the row terminator) from its cells.

    Strings are passed through unchanged. Numbers are rounded to two decimals
    with the leading zero dropped (0.35 -> ".35"). Negative numbers keep their
    sign (-0.35 -> "-.35") and are wrapped in "$...$" so LaTeX typesets a
    proper minus sign.

    Args:
        scores: Iterable of str and/or real numbers, one per column. It is
            iterated only once, so a generator is fine.

    Returns:
        The formatted cells joined with " & ".
    """
    cells = []
    for score in scores:
        if isinstance(score, str):
            cells.append(score)
            continue
        text = f"{score:.2f}"
        sign = "-" if text.startswith("-") else ""
        # Strip the leading zero *after* the sign: a bare lstrip("0") on
        # "-0.35" is a no-op because the string starts with "-".
        text = sign + text[len(sign):].lstrip("0")
        cells.append(f"${text}$" if score < 0 else text)
    return " & ".join(cells)

In [None]:
# Control results: print one LaTeX table row per control (Task, Model, Editor).

import json

def _load_metrics(results_file):
    """Read a results JSON file (str or Path) and return its "metrics" dict."""
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with Path(results_file).open("r", encoding="utf-8") as handle:
        return json.load(handle)["metrics"]


def format_row_for_control(control, fact_results_file, bias_results_file):
    """Print one LaTeX table row summarising a control experiment.

    Columns, in order: label, fact decontextual F1 / MCC, fact contextual
    F1 / MCC, bias probe_recall_k, bias F1 / MCC. The row ends with a LaTeX
    line break.

    Args:
        control: Row label, printed as the first cell. For the label "Task"
            the probe_recall_k column is printed as "---" and that metric is
            not read.
        fact_results_file: str or Path to the fact-classification results
            JSON; its "metrics" must hold "decontextual" and "contextual"
            dicts, each with "f1" and "mcc".
        bias_results_file: str or Path to the bias-classification results
            JSON; its "metrics" must hold "f1", "mcc" and, unless
            control == "Task", "probe_recall_k".
    """
    fact_metrics = _load_metrics(fact_results_file)
    bias_metrics = _load_metrics(bias_results_file)

    scores = (
        control,
        fact_metrics["decontextual"]["f1"],
        fact_metrics["decontextual"]["mcc"],
        fact_metrics["contextual"]["f1"],
        fact_metrics["contextual"]["mcc"],
        "---" if control == "Task" else bias_metrics["probe_recall_k"],
        bias_metrics["f1"],
        bias_metrics["mcc"],
    )
    print(format_row(scores) + r" \\")

# (row label, fact-classification results file, bias-classification results file)
CONTROL_RUNS = (
    (
        "Task",
        RESULTS_ROOT / "icml_eval_fact_cls_gptj_control/linear/1/fact_cls_layer_26_control_task.json",
        RESULTS_ROOT / "icml_eval_bias_cls_gptj_control_task/linear/11/error_cls_layer_23_control_task.json",
    ),
    (
        "Model",
        RESULTS_ROOT / "icml_eval_fact_cls_gptj_random/linear/1/fact_cls_layer_26.json",
        RESULTS_ROOT / "icml_eval_bias_cls_gptj_random/linear/11/error_cls_layer_23.json",
    ),
    (
        "Editor",
        RESULTS_ROOT / "icml_eval_fact_cls_gptj_identity/identity/1/fact_cls_layer_26.json",
        RESULTS_ROOT / "icml_eval_bias_cls_gptj_identity/identity/11/error_cls_layer_23.json",
    ),
)

# Print the rows in table order.
for run_label, fact_file, bias_file in CONTROL_RUNS:
    format_row_for_control(run_label, fact_file, bias_file)

In [None]:
# `bias_results` only exists as a local inside `format_row_for_control`, so a
# bare `bias_results.keys()` raises NameError on a fresh kernel. Load one
# bias-classification results file explicitly to inspect its schema.
bias_results_file = RESULTS_ROOT / "icml_eval_bias_cls_gptj_identity/identity/11/error_cls_layer_23.json"
with bias_results_file.open("r", encoding="utf-8") as handle:
    bias_results = json.load(handle)
bias_results.keys()

In [None]:
# Show which metric keys the bias results expose.
# NOTE(review): `bias_results` is only a local inside `format_row_for_control`;
# a preceding cell must define it at module level, otherwise this raises NameError.
bias_results["metrics"].keys()