In [23]:
import jsonlines
import os
import tqdm
import pandas as pd
import yaml

# Notebook-wide settings. Later cells read config["models_dir"] and
# config["data_dir"] to locate model outputs and the comparison data.
# config.yaml is resolved relative to the kernel's working directory.
with open("config.yaml", mode="r") as config_stream:
    config = yaml.load(config_stream, Loader=yaml.FullLoader)

# Evaluation

In [24]:
# Gather the metrics file written next to each model's final checkpoint and
# rank the models by accuracy. Every model directory is read for the same
# file, "metrics_sup2vsup2_test.txt" (presumably metrics on the sup2vsup2
# test split -- inferred from the file name, confirm against the training
# script).
dataset_names = [
    "sup2vsup2",
    "refvsup2",
    "sup2vsup2+refvsup2",
    "refvsup2policy"
]

# Collect one frame per model and concatenate once at the end.
# DataFrame.append was removed in pandas 2.0, and appending inside the loop
# re-copied the accumulated frame on every iteration.
metric_frames = []
for dataset_name in dataset_names:
    input_file = os.path.join(
        config["models_dir"],
        "comparisons_" + dataset_name + "_train_gpt2",
        "final_checkpoint",
        "metrics_sup2vsup2_test.txt")
    # The file has no header row; it is read as two columns: metric, value.
    m = pd.read_csv(input_file, header=None)
    m.columns = ["metric", "value"]
    m["dataset_name"] = dataset_name
    metric_frames.append(m)

# The default ignore_index=False keeps each file's own row labels, exactly as
# the repeated append did, so the displayed index is unchanged.
metrics = pd.concat(metric_frames)
metrics[metrics["metric"] == "acc"].sort_values("value", ascending=False)

Unnamed: 0,metric,value,dataset_name
1,acc,0.593708,sup2vsup2+refvsup2
1,acc,0.580295,sup2vsup2
1,acc,0.568223,refvsup2
1,acc,0.455798,refvsup2policy


# Policy composition

In [2]:
# Flatten every comparison record into one row whose two summaries are placed
# in a canonical (policyA, policyB) order, so that the same pair of policies
# always lands in the same columns regardless of the order stored on disk.
# The dict below fixes the column order of the resulting DataFrame.
columns = {
    name: []
    for name in (
        "policyA",
        "policyB",
        "choiceAB",
        "prompt",
        "summaryA",
        "summaryB",
        "noteA",
        "noteB",
        "batch",
        "split",
    )
}
input_files = [
    "comparisons_train.jsonl",
    "comparisons_valid.jsonl",
    "comparisons_test.jsonl"
]
for input_file in input_files:
    input_path = os.path.join(config["data_dir"], input_file)
    with jsonlines.open(input_path, "r") as fin:
        for line in tqdm.tqdm(fin):
            example = line["example"]
            first = example["summaries"][0]
            second = example["summaries"][1]
            policy0, policy1 = first["policy"], second["policy"]
            choice = line["choice"]

            # Sanity check: the top-level completions must match the summary
            # texts nested in the example (ignoring surrounding whitespace).
            assert first["text"].strip() == line["completion0"].strip()
            assert second["text"].strip() == line["completion1"].strip()

            # Side A is the summary whose policy compares strictly smaller.
            # Otherwise (including when both policies are equal) the two
            # sides are swapped and the choice is flipped with 1 - choice.
            keep_order = policy0 < policy1
            side_a, side_b = (first, second) if keep_order else (second, first)

            row = {
                "policyA": side_a["policy"],
                "policyB": side_b["policy"],
                "choiceAB": choice if keep_order else 1 - choice,
                "prompt": line["prompt"],
                "summaryA": side_a["text"].strip(),
                "summaryB": side_b["text"].strip(),
                "noteA": side_a["note"],
                "noteB": side_b["note"],
                "batch": example["batch"],
                "split": example["split"],
            }
            for name, value in row.items():
                columns[name].append(value)

df = pd.DataFrame(columns)
# Constant helper column: aggregating it per group yields the group size.
df["n"] = 1

92858it [00:02, 40190.30it/s]
33082it [00:00, 40424.17it/s]
50715it [00:01, 40478.60it/s]


In [19]:
# Per policy pair: mean of choiceAB and number of comparisons, restricted to
# pairs in which neither policy name contains "_ppo" or "_bo".
#
# Only the columns that are actually needed are aggregated. Calling .mean()
# on the whole frame also hits the text columns (prompt, summaries, notes),
# which raises TypeError on pandas >= 2.0 where numeric_only defaults to False.
pair_groups = df.groupby(["policyA", "policyB"])

# Mean choiceAB per pair. The descending sort is kept because the merge below
# numbers its rows in this order, and that numbering is the displayed index.
d = pair_groups["choiceAB"].mean().reset_index().sort_values("choiceAB", ascending=False)
# Number of comparisons per pair ("n" is 1 on every row, so count == size).
d2 = pair_groups["n"].count().reset_index().sort_values("n", ascending=False)
d = pd.merge(d, d2, on=["policyA", "policyB"])
d = d.sort_values("n", ascending=False)

# Boolean mask aligned with d: True where neither side's policy name contains
# "_ppo" or "_bo" (plain substring match, same as the str.find(...) == -1 test).
f = ~(
    d["policyA"].str.contains("_ppo|_bo", regex=True)
    | d["policyB"].str.contains("_ppo|_bo", regex=True)
)

# Total number of comparisons that survive the filter.
print(d.loc[f, "n"].sum())
d[f].sort_values("n", ascending=False).head(20)

96098


Unnamed: 0,policyA,policyB,choiceAB,n
216,ref,sup2,0.263811,18608
121,sup2,sup2,0.491735,16273
229,ref,sup1,0.189002,7947
111,sup1,sup1,0.506337,7653
115,sup4_6b_t0.7,sup4_6b_t0.7,0.500269,7442
224,ref,sup4_t0.7,0.228298,4919
104,sup4_t0.7,sup4_t0.7,0.520085,4257
195,ref,sup3_6b,0.330476,2333
72,sup2,sup3_6b,0.59117,2265
206,ref,sup4_6b_t0.7,0.291667,1752


In [132]:
# Number of comparisons per (policyA, policyB, split), largest groups first.
# Only the helper column "n" is summed: summing the whole frame would also
# try to add up the text and note columns, which is wasteful and can fail
# on pandas >= 2.0 where non-numeric columns are no longer dropped silently.
d = df.groupby(["policyA", "policyB", "split"])["n"].sum().reset_index()
d = d.sort_values("n", ascending=False)
d.head(20)

Unnamed: 0,policyA,policyB,split,n
97,ref,sup2,train,18065
245,sup2,sup2,train,16273
96,ref,sup1,train,7947
244,sup1,sup1,train,7653
383,sup4_ppo_rm3_kl10,sup4_ppo_rm3_kl10,train,6206
409,sup4_ppo_rm3_kl20,sup4_ppo_rm3_kl20,train,6098
340,sup4_6b_t0.7,sup4_6b_t0.7,train,5614
444,sup4_t0.7,sup4_t0.7,valid1,2785
410,sup4_ppo_rm3_kl20,sup4_ppo_rm3_kl20,valid1,2340
384,sup4_ppo_rm3_kl10,sup4_ppo_rm3_kl10,valid1,2070
