In [1]:
import json
import pathlib
from collections import defaultdict
import numpy as np 
import re 
from dataflow.core.lispress import parse_lispress, render_compact, render_pretty
from dataflow.core.linearize import lispress_to_seq

from calibration_utils import read_nucleus_file, read_gold_file, single_exact_match

In [2]:
# Load the CalFlow gold data and each model's test-set predictions.
calflow_bart = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221101T105421.jsonl"
calflow_t5 = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/t5-large-lm-adapt_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221102T103315.jsonl"
calflow_miso = "/brtx/604-nvme1/estengel/calflow_calibration/miso/tune_roberta_tok_fix_benchclamp_data/translate_output_calibrated/test_all.tgt"

calflow_models_and_paths = {"miso": calflow_miso, "bart": calflow_bart, "t5": calflow_t5}

calflow_gold_path = "/brtx/601-nvme1/estengel/resources/data/smcalflow.agent.data.from_benchclamp"

calflow_gold_src = read_gold_file(f"{calflow_gold_path}/test_all.src_tok")
calflow_gold_tgt = read_gold_file(f"{calflow_gold_path}/test_all.tgt")
calflow_gold_idx = read_gold_file(f"{calflow_gold_path}/test_all.idx")
# NOTE(review): every other gold file above uses the `test_all.` prefix but this
# one is `test.`; confirm it is the intended split.
with open(f"{calflow_gold_path}/test.datum_id") as f1:
    # One JSON value per line; each line must be passed to json.loads.
    calflow_gold_datum_id = [json.loads(x) for x in f1]
# Gold target string keyed by example idx (idx and tgt files are line-aligned).
calflow_gold_tgt_by_idx = dict(zip(calflow_gold_idx, calflow_gold_tgt))


calflow_data_by_model = {}

# MISO: nucleus file parsed by the project helper (first return value unused).
__, nuc_data = read_nucleus_file(calflow_miso)
calflow_data_by_model["miso"] = nuc_data

# BART / T5: benchclamp JSONL outputs, one record per test example.
with open(calflow_bart) as f1:
    bart_data = [json.loads(x) for x in f1]
calflow_data_by_model['bart'] = bart_data

with open(calflow_t5) as f1:
    t5_data = [json.loads(x) for x in f1]
calflow_data_by_model['t5'] = t5_data


In [3]:
# Load the TreeDST gold data and each model's test-set predictions.
treedst_bart = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_tree_dst_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221102T103357.jsonl"
# NOTE(review): this path was previously marked "not ready yet"; the T5 results in
# this notebook are computed from it -- confirm the run had finished.
treedst_t5 = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/t5-large-lm-adapt_tree_dst_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221106T140554.jsonl"
treedst_miso = "/brtx/603-nvme1//estengel/calflow_calibration/tree_dst/tune_roberta/translate_output_calibrated/test.tgt"

treedst_models_and_paths = {"miso": treedst_miso, "bart": treedst_bart, "t5": treedst_t5}

treedst_gold_path = "/brtx/601-nvme1/estengel/resources/data/tree_dst.agent.data"

treedst_gold_src = read_gold_file(f"{treedst_gold_path}/test.src_tok")
treedst_gold_tgt = read_gold_file(f"{treedst_gold_path}/test.tgt")
treedst_gold_idx = read_gold_file(f"{treedst_gold_path}/test.idx")
with open(f"{treedst_gold_path}/test.datum_id") as f1:
    # One JSON value per line; each line must be passed to json.loads.
    treedst_gold_datum_id = [json.loads(x) for x in f1]
# Gold target string keyed by example idx (idx and tgt files are line-aligned).
treedst_gold_tgt_by_idx = dict(zip(treedst_gold_idx, treedst_gold_tgt))


treedst_data_by_model = {}

# MISO: nucleus file parsed by the project helper (first return value unused).
# Note this rebinds `nuc_data`, replacing the CalFlow data loaded earlier.
__, nuc_data = read_nucleus_file(treedst_miso)
treedst_data_by_model["miso"] = nuc_data

# BART / T5: benchclamp JSONL outputs, one record per test example.
with open(treedst_bart) as f1:
    bart_data = [json.loads(x) for x in f1]
treedst_data_by_model['bart'] = bart_data

with open(treedst_t5) as f1:
    t5_data = [json.loads(x) for x in f1]
treedst_data_by_model['t5'] = t5_data

In [4]:
def get_low_prob(iterator, is_miso = False, threshold = 0.6):
    low_prob_idxs = []
    for idx, example in iterator:
        if is_miso: 
            try:
                min_prob = example[0][1]
                assert min_prob is not None
            except:
                min_prob = np.min(example[0]['expression_probs'])
                if min_prob is None:
                    min_prob = 1.0
        else:
            probs = np.exp(np.array(example['token_logprobs'][0]))
            min_prob = np.min(probs)

        if min_prob < threshold:
            low_prob_idxs.append(idx) 
    return low_prob_idxs

def report_low_idxs(low_idxs_by_model, total=13496):
    """Print how much the low-confidence sets of each pair of models overlap.

    Each unordered pair is visited once, in the mapping's key order. For each
    pair, the function prints both models' low-confidence counts as a percentage
    of ``total``, followed by the intersection and union sizes of their idx sets
    and the intersection-over-union ratio.

    Args:
        low_idxs_by_model: mapping of model name -> list of low-confidence idxs.
        total: number of examples in the evaluated split, used as the percentage
            denominator. Defaults to 13496, the CalFlow test-set size in this
            notebook. Pass the real size for any other dataset (e.g. TreeDST),
            otherwise the printed percentages are wrong.
    """
    models = list(low_idxs_by_model)
    # Build each set once instead of once per pair.
    idx_sets = {model: set(idxs) for model, idxs in low_idxs_by_model.items()}
    for i, model_a in enumerate(models):
        for model_b in models[i + 1:]:
            intersection = idx_sets[model_a] & idx_sets[model_b]
            union = idx_sets[model_a] | idx_sets[model_b]
            print(f"{model_a}: {len(low_idxs_by_model[model_a]) / total * 100:.2f}%")
            print(f"{model_b}: {len(low_idxs_by_model[model_b]) / total * 100:.2f}%")
            # Two empty low-confidence sets have no meaningful overlap ratio.
            iou = f"{len(intersection) / len(union):.2f}" if union else "n/a"
            print(f"intersection of {model_a} and {model_b}: {len(intersection)} / {len(union)}: {iou}")

In [5]:
## CALFLOW
# Flag each model's low-confidence CalFlow examples. MISO predictions are keyed
# by example idx; the beam-search outputs are aligned positionally with the gold
# idx file, so they are paired with it via zip.
calflow_low_idxs_by_model = {
    model: (
        get_low_prob(data.items(), is_miso=True)
        if model == "miso"
        else get_low_prob(zip(calflow_gold_idx, data), is_miso=False)
    )
    for model, data in calflow_data_by_model.items()
}

report_low_idxs(calflow_low_idxs_by_model)

miso: 21.01%
bart: 6.09%
intersection of miso and bart: 555 / 3103: 0.18
miso: 21.01%
t5: 8.49%
intersection of miso and t5: 724 / 3258: 0.22
bart: 6.09%
t5: 8.49%
intersection of bart and t5: 312 / 1656: 0.19


In [6]:
## TREE DST
# Flag each model's low-confidence TreeDST examples. MISO predictions are keyed
# by example idx; the beam-search outputs are aligned positionally with the gold
# idx file, so they are paired with it via zip.
treedst_low_idxs_by_model = {
    model: (
        get_low_prob(data.items(), is_miso=True)
        if model == "miso"
        else get_low_prob(zip(treedst_gold_idx, data), is_miso=False)
    )
    for model, data in treedst_data_by_model.items()
}

# NOTE(review): report_low_idxs is called here without a dataset size, so its
# percentages are computed against 13,496 examples (the CalFlow test size), not
# the 22,841 TreeDST examples counted elsewhere in this notebook. The printed
# per-model % are therefore not TreeDST fractions.
report_low_idxs(treedst_low_idxs_by_model)

miso: 22.15%
bart: 8.28%
intersection of miso and bart: 1050 / 3058: 0.34
miso: 22.15%
t5: 7.68%
intersection of miso and t5: 923 / 3104: 0.30
bart: 8.28%
t5: 7.68%
intersection of bart and t5: 726 / 1429: 0.51


In [7]:
# Sanity check: number of MISO TreeDST examples. Read from the dict rather than
# from `nuc_data`, which holds the TreeDST data only because the TreeDST loading
# cell rebinds it after the CalFlow one.
print(len(treedst_data_by_model["miso"]))


22841


In [14]:
def get_counts(data, low_idxs, is_miso=False, gold_idxs=None, gold_tgt_by_idx=None):
    """Tally exact-match accuracy overall and split by a low-confidence subset.

    Args:
        data: model predictions. For MISO (``is_miso=True``), a mapping of example
            idx -> nucleus entries. Otherwise, a list of benchclamp output records
            aligned position-by-position with ``gold_idxs``.
        low_idxs: iterable of example idxs considered low-confidence.
        is_miso: selects the MISO input format described above.
        gold_idxs: example idxs aligned with ``data``; only used when ``is_miso``
            is False.
        gold_tgt_by_idx: mapping of idx -> gold target string; only used when
            ``is_miso`` is True (benchclamp records carry their own gold).

    Returns:
        dict with ``{low,high,all}_{correct,total}`` counts, where "high" covers
        every example whose idx is not in ``low_idxs``.

    Raises:
        ValueError: (non-MISO only) if ``data`` and ``gold_idxs`` differ in
            length; ``zip`` would otherwise silently drop examples and misalign
            the idx lookup.
    """
    counts = {"low_correct": 0, "low_total": 0, "all_correct": 0, "all_total": 0, "high_correct": 0, "high_total": 0}
    # Callers pass lists of thousands of idxs; set membership avoids a list scan
    # for every example.
    low_idx_set = set(low_idxs)

    if is_miso:
        enumerator = data.items()
    else:
        if len(gold_idxs) != len(data):
            raise ValueError(
                f"predictions ({len(data)}) and gold idxs ({len(gold_idxs)}) differ in length"
            )
        enumerator = zip(gold_idxs, data)

    for idx, example in enumerator:
        if is_miso:
            # Entries come in two layouts: [(pred_dict, ...), ...] or [pred_dict, ...].
            try:
                pred_str = example[0][0]['tgt_str']
            except KeyError:
                pred_str = example[0]['tgt_str']
            gold_str = gold_tgt_by_idx[idx]
        else:
            # Top beam output and the record's own canonical gold program.
            pred_str = example['outputs'][0]
            gold_str = example['test_datum_canonical']
        # Unparseable predictions are normalised to a sentinel that never matches gold.
        try:
            pred_tgt = render_compact(parse_lispress(pred_str))
        except (AssertionError, IndexError):
            pred_tgt = "(Error)"

        gold_tgt = render_compact(parse_lispress(gold_str))
        is_correct, __ = single_exact_match(pred_tgt, gold_tgt)

        split = "low" if idx in low_idx_set else "high"
        counts[f"{split}_total"] += 1
        counts["all_total"] += 1
        if is_correct:
            counts[f"{split}_correct"] += 1
            counts["all_correct"] += 1
    return counts


In [9]:

# Evaluate every model (model_a) on the examples each model (model_b) flagged as
# low-confidence, including its own. Every (model_a, model_b) pair is distinct,
# so no deduplication is needed.
calflow_counts_by_model_pair = {}
for model_a in calflow_data_by_model:
    for model_b in calflow_data_by_model:
        print(f"evaluating {model_a} on low conf. from {model_b}")
        calflow_counts_by_model_pair[(model_a, model_b)] = get_counts(
            calflow_data_by_model[model_a],
            calflow_low_idxs_by_model[model_b],
            is_miso=model_a == "miso",
            gold_idxs=calflow_gold_idx,
            gold_tgt_by_idx=calflow_gold_tgt_by_idx,
        )

# print report; an empty split prints "n/a" instead of dividing by zero
for (model_a, model_b), count_data in calflow_counts_by_model_pair.items():
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    for split in ("low", "high", "all"):
        correct = count_data[f"{split}_correct"]
        total = count_data[f"{split}_total"]
        pct = f"{correct / total * 100:.2f}" if total else "n/a"
        print(f"{split}: {correct} / {total}: {pct}")
    print()
    print()


evaluating miso on low conf. from miso
evaluating miso on low conf. from bart
evaluating miso on low conf. from t5
evaluating bart on low conf. from miso
evaluating bart on low conf. from bart
evaluating bart on low conf. from t5
evaluating t5 on low conf. from miso
evaluating t5 on low conf. from bart
evaluating t5 on low conf. from t5
MISO on low conf. from MISO
low: 1028 / 2836: 36.25
high: 9607 / 10660: 90.12
all: 10635 / 13496: 78.80

MISO on low conf. from BART
low: 277 / 822: 33.70
high: 10358 / 12674: 81.73
all: 10635 / 13496: 78.80

MISO on low conf. from T5
low: 454 / 1146: 39.62
high: 10181 / 12350: 82.44
all: 10635 / 13496: 78.80

BART on low conf. from MISO
low: 1574 / 2836: 55.50
high: 9573 / 10660: 89.80
all: 11147 / 13496: 82.59

BART on low conf. from BART
low: 258 / 822: 31.39
high: 10889 / 12674: 85.92
all: 11147 / 13496: 82.59

BART on low conf. from T5
low: 516 / 1146: 45.03
high: 10631 / 12350: 86.08
all: 11147 / 13496: 82.59

T5 on low conf. from MISO
low: 1484 /

In [16]:
# Evaluate each model on the union of all three models' low-confidence CalFlow sets.
calflow_three_way_union = set().union(
    calflow_low_idxs_by_model['miso'],
    calflow_low_idxs_by_model['t5'],
    calflow_low_idxs_by_model['bart'],
)
print(f"Size of three-way union: {len(calflow_three_way_union)}")
for model_a in calflow_data_by_model:
    calflow_counts_by_model_pair[(model_a, 'union')] = get_counts(
        calflow_data_by_model[model_a],
        calflow_three_way_union,
        is_miso=model_a == "miso",
        gold_idxs=calflow_gold_idx,
        gold_tgt_by_idx=calflow_gold_tgt_by_idx,
    )

# Report only the union rows; an empty split prints "n/a" instead of dividing by zero.
for (model_a, model_b), count_data in calflow_counts_by_model_pair.items():
    if model_b != "union":
        continue
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    for split in ("low", "high", "all"):
        correct = count_data[f"{split}_correct"]
        total = count_data[f"{split}_total"]
        pct = f"{correct / total * 100:.2f}" if total else "n/a"
        print(f"{split}: {correct} / {total}: {pct}")
    print()

Size of three-way union: 3440
MISO on low conf. from UNION
low: 1369 / 3440: 39.80
high: 9271 / 10056: 92.19
all: 10640 / 13496: 78.84

BART on low conf. from UNION
low: 1879 / 3440: 54.62
high: 9272 / 10056: 92.20
all: 11151 / 13496: 82.62

T5 on low conf. from UNION
low: 1750 / 3440: 50.87
high: 9279 / 10056: 92.27
all: 11029 / 13496: 81.72



In [11]:
# TREE DST
# Evaluate every model (model_a) on the examples each model (model_b) flagged as
# low-confidence, including its own. Every (model_a, model_b) pair is distinct,
# so no deduplication is needed.
treedst_counts_by_model_pair = {}
for model_a in treedst_data_by_model:
    for model_b in treedst_data_by_model:
        print(f"evaluating {model_a} on low conf. from {model_b}")
        treedst_counts_by_model_pair[(model_a, model_b)] = get_counts(
            treedst_data_by_model[model_a],
            treedst_low_idxs_by_model[model_b],
            is_miso=model_a == "miso",
            gold_idxs=treedst_gold_idx,
            gold_tgt_by_idx=treedst_gold_tgt_by_idx,
        )

# print report; an empty split prints "n/a" instead of dividing by zero
for (model_a, model_b), count_data in treedst_counts_by_model_pair.items():
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    for split in ("low", "high", "all"):
        correct = count_data[f"{split}_correct"]
        total = count_data[f"{split}_total"]
        pct = f"{correct / total * 100:.2f}" if total else "n/a"
        print(f"{split}: {correct} / {total}: {pct}")
    print()

evaluating miso on low conf. from miso
evaluating miso on low conf. from bart
evaluating miso on low conf. from t5
evaluating bart on low conf. from miso
evaluating bart on low conf. from bart
evaluating bart on low conf. from t5
evaluating t5 on low conf. from miso
evaluating t5 on low conf. from bart
evaluating t5 on low conf. from t5
MISO on low conf. from MISO
low: 758 / 2990: 25.35
high: 14488 / 19851: 72.98
all: 15246 / 22841: 66.75

MISO on low conf. from BART
low: 173 / 1118: 15.47
high: 15073 / 21723: 69.39
all: 15246 / 22841: 66.75

MISO on low conf. from T5
low: 174 / 1037: 16.78
high: 15072 / 21804: 69.12
all: 15246 / 22841: 66.75

BART on low conf. from MISO
low: 1600 / 2990: 53.51
high: 19357 / 19851: 97.51
all: 20957 / 22841: 91.75

BART on low conf. from BART
low: 320 / 1118: 28.62
high: 20637 / 21723: 95.00
all: 20957 / 22841: 91.75

BART on low conf. from T5
low: 337 / 1037: 32.50
high: 20620 / 21804: 94.57
all: 20957 / 22841: 91.75

T5 on low conf. from MISO
low: 152

In [15]:
# Evaluate each model on the union of all three models' low-confidence TreeDST sets.
treedst_three_way_union = set().union(
    treedst_low_idxs_by_model['miso'],
    treedst_low_idxs_by_model['t5'],
    treedst_low_idxs_by_model['bart'],
)
print(f"Size of three-way union: {len(treedst_three_way_union)}")
for model_a in treedst_data_by_model:
    print(f"Model {model_a}-union")
    treedst_counts_by_model_pair[(model_a, 'union')] = get_counts(
        treedst_data_by_model[model_a],
        treedst_three_way_union,
        is_miso=model_a == "miso",
        gold_idxs=treedst_gold_idx,
        gold_tgt_by_idx=treedst_gold_tgt_by_idx,
    )

# Report only the union rows; an empty split prints "n/a" instead of dividing by zero.
for (model_a, model_b), count_data in treedst_counts_by_model_pair.items():
    if model_b != "union":
        continue
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    for split in ("low", "high", "all"):
        correct = count_data[f"{split}_correct"]
        total = count_data[f"{split}_total"]
        pct = f"{correct / total * 100:.2f}" if total else "n/a"
        print(f"{split}: {correct} / {total}: {pct}")
    print()

Size of three-way union: 3155
Model miso-union
Model bart-union
Model t5-union
MISO on low conf. from UNION
low: 1196 / 3155: 37.91
high: 19225 / 19686: 97.66
all: 20421 / 22841: 89.41

BART on low conf. from UNION
low: 1699 / 3155: 53.85
high: 19258 / 19686: 97.83
all: 20957 / 22841: 91.75

T5 on low conf. from UNION
low: 1616 / 3155: 51.22
high: 19234 / 19686: 97.70
all: 20850 / 22841: 91.28



In [13]:
# Sanity check: number of BART TreeDST prediction records loaded from the JSONL file.
len(treedst_data_by_model["bart"])

22841