In [1]:
import sys
sys.path.insert(0, "../src")
import re
from pathlib import Path
from functools import partial
from collections import defaultdict

import pandas as pd
import numpy as np
from joblib import Parallel, delayed

import constants
from rte.aggregate import agg_predict_proba, agg_predict
from gen.util import read_data, write_jsonl
from gen.special import entropy3

# Init

In [2]:
# Root of the thesis scratch area; both the error files and the sentence-level
# prediction files live underneath it.
SCRATCH_ROOT = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/thesis")
error_p = SCRATCH_ROOT / "errors" / "scifact" / "sentence"
sent_pls = SCRATCH_ROOT / "predictions" / "sent" / "scifact"

In [3]:
def _parse_score_name(score_name):
    """Split a prediction-file name prefix into ``(dataset, model)``.

    The first ``-``-separated token is the dataset; the second token is also
    folded into the dataset when it contains ``"climatefever"``. Whatever
    follows the dataset prefix (minus separating dashes) is the model name.
    """
    tok = score_name.split("-")
    # Guard single-token names, which previously raised IndexError on tok[1].
    second = tok[1] if len(tok) > 1 and "climatefever" in tok[1] else ""
    dataset = "-".join([tok[0], second]).strip("-")
    # Remove the dataset as a *prefix*. ``str.strip(dataset)`` strips any of
    # the dataset's characters from both ends, so it could also eat the tail of
    # the model name (e.g. "scibert" -> "sciber").
    if score_name.startswith(dataset):
        model = score_name[len(dataset):]
    else:
        model = score_name
    return dataset, model.strip("-")


def preprocess_predictions(fn):
    """Load one sentence-level prediction file and group it per claim.

    Args:
        fn: Path to a prediction file readable by ``read_data``. Its records
            must contain ``claim_id``, ``predicted_label`` and
            ``predicted_proba``. The part of ``fn.stem`` before the first
            ``"."`` encodes ``<dataset>-<model>``.

    Returns:
        DataFrame with one row per ``claim_id`` (in first-seen order) and the
        columns:

        - ``predicted_label``: list of label ids, mapped through
          ``constants.LABEL2ID``.
        - ``predicted_proba``: list of probability lists, rounded to 4 decimals.
        - ``mean_proba``: the per-claim aggregate from ``agg_predict_proba``,
          rounded to 4 decimals.
        - ``score_name``, ``dataset``, ``model``: parsed from the file name.
    """
    df = pd.DataFrame(read_data(fn))
    rnd = partial(np.round, decimals=4)
    df = df.assign(
        predicted_label=df["predicted_label"].map(constants.LABEL2ID),
        predicted_proba=df["predicted_proba"].apply(rnd).apply(lambda x: x.tolist()),
    )

    # Aggregate all sentence probabilities of a claim into a single vector.
    get_proba = partial(agg_predict_proba, return_proba=True)
    df_mean = (
        df
        .groupby("claim_id", sort=False)[["predicted_proba"]]
        .agg({"predicted_proba": get_proba})
        .rename(columns={"predicted_proba": "mean_proba"})
    )
    df_mean["mean_proba"] = df_mean["mean_proba"].apply(rnd).apply(lambda x: x.tolist())
    df_grp = (
        df
        .groupby("claim_id", sort=False)
        .agg({"predicted_label": list, "predicted_proba": list})
        .join(df_mean, how="inner")
    )

    score_name = fn.stem.split(".")[0]
    dataset, model = _parse_score_name(score_name)
    df_grp["score_name"] = score_name
    df_grp["dataset"] = dataset
    df_grp["model"] = model

    return df_grp.reset_index()

In [4]:
# Aggregate every sentence-level prediction file into one per-claim frame.
# Paths are sorted because ``Path.iterdir`` order is filesystem-dependent;
# sorting keeps the row order of ``df_sent_all`` reproducible across runs.
# Directories are skipped since only files can be parsed.
sent_frames = Parallel()(
    delayed(preprocess_predictions)(p)
    for p in sorted(sent_pls.iterdir())
    if p.is_file()
)
# ignore_index avoids duplicate 0..k labels coming from each per-file frame.
df_sent_all = pd.concat(sent_frames, axis=0, ignore_index=True).rename(columns={"claim_id": "id"})

In [5]:
def update_with_preds(fn, df_sent):
    """Annotate an error file in place with the per-model predictions of its claims.

    Args:
        fn: Path to an error file (read via ``read_data``) whose stem looks like
            ``<agg_type>_..._disagree_<dataset>_total...``. Each record needs an
            ``id``. When ``agg_type`` is ``"majority"`` the ``predicted_label``
            column is attached; otherwise ``predicted_proba`` and ``mean_proba``
            are attached.
        df_sent: Frame of per-claim predictions with columns ``id``, ``model``,
            ``dataset`` and the prediction columns above.

    Behaviour:
        Only the first two models in sorted order are compared. Rows of the two
        models are paired positionally and must have the same id.

        For ``dataset == "alltrain"`` each value is nested as
        ``{column: {dataset: {model_prefix: value}}}``, with dataset keys
        sorted. Otherwise rows are filtered to ``dataset`` and nested as
        ``{column: {model_prefix: value}}``. ``model_prefix`` is the model name
        up to its first ``-``.

    Returns:
        The result of ``write_jsonl`` (the file is rewritten). Returns ``None``
        without writing when the file is empty or already contains prediction
        columns, so re-running is a no-op.
    """
    agg_type = fn.stem.split("_")[0]
    dataset = re.findall(".*_disagree_(.*)_total.*", fn.stem)[0]

    error_records = read_data(fn)
    # Nothing to annotate (empty file) or already annotated by an earlier run.
    if not error_records or "mean_proba" in error_records[0] or "predicted_label" in error_records[0]:
        return

    df_filter = (
        df_sent
        .set_index("id")
        .loc[pd.Index([doc["id"] for doc in error_records], name="id")]
    )
    if dataset != "alltrain":
        # Boolean mask rather than an f-string ``query``: names containing
        # quotes cannot break the expression.
        df_filter = df_filter[df_filter["dataset"] == dataset]

    all_models = sorted(df_sent["model"].unique())
    value_cols = ["predicted_label"] if agg_type == "majority" else ["predicted_proba", "mean_proba"]
    df_filter = df_filter[["model", "dataset"] + value_cols]

    def model_key(row):
        # e.g. "bert-base-uncased" -> "bert"
        return row["model"].split("-")[0]

    first_rows = df_filter[df_filter["model"] == all_models[0]].iterrows()
    second_rows = df_filter[df_filter["model"] == all_models[1]].iterrows()

    res = {}
    for (sfid, row_a), (sfid_b, row_b) in zip(first_rows, second_rows):
        assert sfid == sfid_b
        entry = res.setdefault(sfid, {})
        for c in value_cols:
            if dataset == "alltrain":
                per_col = entry.setdefault(c, {})
                for row in (row_a, row_b):
                    per_col.setdefault(row["dataset"], {})[model_key(row)] = row[c]
            else:
                entry[c] = {model_key(row_a): row_a[c], model_key(row_b): row_b[c]}

    if dataset == "alltrain":
        res = {k: {kk: dict(sorted(vv.items())) for kk, vv in v.items()} for k, v in res.items()}
    error_df = (
        pd.DataFrame(error_records)
        .set_index("id")
        .join(pd.DataFrame(res).T, how="left")
        .reset_index()
    )

    return write_jsonl(fn, error_df.to_dict("records"))

In [6]:
# Annotate every "disagree" error file with the per-model sentence predictions.
# The job generator is lazy, so the glob is consumed by Parallel itself.
update_jobs = (
    delayed(update_with_preds)(error_file, df_sent_all)
    for error_file in error_p.glob("*disagree*")
)
res = Parallel()(update_jobs)