In [1]:
import json
import pathlib
from collections import defaultdict
import numpy as np 
import re 
from dataflow.core.lispress import parse_lispress, render_compact, render_pretty
from dataflow.core.linearize import lispress_to_seq

from calibration_utils import read_nucleus_file, read_gold_file, single_exact_match

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- SMCalFlow: prediction files, one per model ---------------------------
calflow_miso = "/brtx/604-nvme1/estengel/calflow_calibration/miso/tune_roberta_tok_fix_benchclamp_data/translate_output_calibrated/test_all.tgt"
calflow_bart = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221101T105421.jsonl"
calflow_t5 = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/t5-large-lm-adapt_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221102T103315.jsonl"

calflow_models_and_paths = dict(miso=calflow_miso, bart=calflow_bart, t5=calflow_t5)

# --- SMCalFlow: gold files (source tokens, targets, indices, datum ids) ----
calflow_gold_path = "/brtx/601-nvme1/estengel/resources/data/smcalflow.agent.data.from_benchclamp"

calflow_gold_idx = read_gold_file(f"{calflow_gold_path}/test_all.idx")
calflow_gold_tgt = read_gold_file(f"{calflow_gold_path}/test_all.tgt")
calflow_gold_src = read_gold_file(f"{calflow_gold_path}/test_all.src_tok")
# The datum-id file is JSON lines: one JSON value per line.
with open(f"{calflow_gold_path}/test_all.datum_id") as f1:
    calflow_gold_datum_id = list(map(json.loads, f1))
# Lookup from example index to its gold target.
calflow_gold_tgt_by_idx = dict(zip(calflow_gold_idx, calflow_gold_tgt))

# --- SMCalFlow: load every model's predictions ------------------------------
# read_nucleus_file returns a pair; only the second element is kept.
__, nuc_data = read_nucleus_file(calflow_miso)
calflow_data_by_model = {"miso": nuc_data}

# BART and T5 outputs are JSON-lines files.
with open(calflow_bart) as f1:
    bart_data = list(map(json.loads, f1))
calflow_data_by_model["bart"] = bart_data

with open(calflow_t5) as f1:
    t5_data = list(map(json.loads, f1))
calflow_data_by_model["t5"] = t5_data


In [3]:
# --- TreeDST: prediction files, one per model -------------------------------
treedst_miso = "/brtx/603-nvme1//estengel/calflow_calibration/tree_dst/tune_roberta/translate_output_calibrated/test.tgt"
treedst_bart = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_tree_dst_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221102T103357.jsonl"
# NOTE(review): this path carried the remark "not ready yet"; the file is
# nevertheless read further down in this cell -- confirm the run is complete.
treedst_t5 = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/t5-large-lm-adapt_tree_dst_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221106T140554.jsonl"

treedst_models_and_paths = dict(miso=treedst_miso, bart=treedst_bart, t5=treedst_t5)

# --- TreeDST: gold files (source tokens, targets, indices, datum ids) -------
treedst_gold_path = "/brtx/601-nvme1/estengel/resources/data/tree_dst.agent.data"

treedst_gold_idx = read_gold_file(f"{treedst_gold_path}/test.idx")
treedst_gold_tgt = read_gold_file(f"{treedst_gold_path}/test.tgt")
treedst_gold_src = read_gold_file(f"{treedst_gold_path}/test.src_tok")
# The datum-id file is JSON lines: one JSON value per line.
with open(f"{treedst_gold_path}/test.datum_id") as f1:
    treedst_gold_datum_id = list(map(json.loads, f1))
# Lookup from example index to its gold target.
treedst_gold_tgt_by_idx = dict(zip(treedst_gold_idx, treedst_gold_tgt))

# --- TreeDST: load every model's predictions --------------------------------
# read_nucleus_file returns a pair; only the second element is kept.
# Note that this rebinds `nuc_data`, previously holding the CalFlow MISO data.
__, nuc_data = read_nucleus_file(treedst_miso)
treedst_data_by_model = {"miso": nuc_data}

# BART and T5 outputs are JSON-lines files.
with open(treedst_bart) as f1:
    bart_data = list(map(json.loads, f1))
treedst_data_by_model["bart"] = bart_data

with open(treedst_t5) as f1:
    t5_data = list(map(json.loads, f1))
treedst_data_by_model["t5"] = t5_data

In [29]:
def get_low_prob(iterator, is_miso = False, threshold = 0.6):
    """Collect the indices of examples whose minimum probability is below ``threshold``.

    Args:
        iterator: iterable of ``(idx, example)`` pairs.
        is_miso: selects how the minimum probability is read from ``example``.
            When True, only ``example[0]`` is inspected. It is either a
            sequence whose element ``1`` is the minimum probability, or a
            mapping whose ``'expression_probs'`` entry is reduced with
            ``np.min``. When False, ``example['token_logprobs'][0]`` is
            exponentiated and its minimum is used.
        threshold: an example is reported when its minimum probability is
            strictly below this value.

    Returns:
        list of the ``idx`` values of the low-probability examples, in
        iteration order.

    A missing (``None``) probability is treated as ``1.0``, i.e. never low.
    """
    low_prob_idxs = []
    for idx, example in iterator:
        if is_miso:
            top = example[0]
            try:
                min_prob = top[1]
            except (KeyError, IndexError, TypeError):
                # Mapping form: there is no positional element 1, so reduce the
                # per-expression probabilities instead.
                expression_probs = top['expression_probs']
                min_prob = None if expression_probs is None else np.min(expression_probs)
            # Both forms may carry no probability at all. Previously the
            # sequence form with a None probability crashed with a TypeError
            # in the fallback; it now gets the same 1.0 default as the
            # mapping form.
            if min_prob is None:
                min_prob = 1.0
        else:
            probs = np.exp(np.array(example['token_logprobs'][0]))
            min_prob = np.min(probs)

        if min_prob < threshold:
            low_prob_idxs.append(idx)
    return low_prob_idxs

def report_low_idxs(low_idxs_by_model, is_treedst=False, denom=None):
    """Print, for every unordered pair of models, how much their low-confidence sets overlap.

    For each pair the function prints three lines: the share of examples each
    model flagged (``len(idxs) / denom``) and the intersection size, union
    size, and intersection-over-union of the two index sets.

    Args:
        low_idxs_by_model: mapping from model name to a collection of
            low-confidence example indices.
        is_treedst: selects the default denominator when ``denom`` is not
            given: 22841 if True, otherwise 13496. These are hard-coded
            example counts and match the totals printed elsewhere in this
            notebook.
        denom: optional explicit number of examples used for the percentages;
            overrides ``is_treedst``.

    Returns:
        None. Output goes to stdout.
    """
    if denom is None:
        denom = 22841 if is_treedst else 13496

    # Build each set once instead of once per pair.
    idx_sets = {model: set(idxs) for model, idxs in low_idxs_by_model.items()}
    models = list(low_idxs_by_model.keys())

    # Pair each model only with the models after it: every unordered pair is
    # visited exactly once, in the same order as a full nested loop would.
    for position, model_a in enumerate(models):
        for model_b in models[position + 1:]:
            # compute the overlap between the two
            intersection = idx_sets[model_a] & idx_sets[model_b]
            union = idx_sets[model_a] | idx_sets[model_b]
            # The ratio is undefined when neither model flagged anything.
            overlap = len(intersection) / len(union) if union else float("nan")
            print(f"{model_a}: {len(low_idxs_by_model[model_a])/denom*100:.2f}%")
            print(f"{model_b}: {len(low_idxs_by_model[model_b])/denom*100:.2f}%")
            print(f"intersection of {model_a} and {model_b}: {len(intersection)} / {len(union)}: {overlap:.2f}")

In [35]:
# find the threshold by quantile
# get all the min probs overall 
def get_percentiles(data_by_model, percentile=25):
    """Compute a confidence threshold as a percentile of the pooled minimum probabilities.

    Each example contributes one minimum probability. For models whose name
    contains ``"miso"``, ``data`` is a mapping and only ``example[0]`` is
    inspected: either element ``1`` of a sequence, or ``np.min`` of the
    ``'expression_probs'`` entry of a mapping. For all other models, ``data``
    is a sequence of examples and the minimum of
    ``exp(example['token_logprobs'][0])`` is used. A missing (``None``)
    probability counts as ``1.0``.

    The function prints the shape of the pooled array, then one line per
    model, then one pooled summary line. In the per-model line the reported
    bound is that model's own percentile, while the easy/hard shares are
    measured against the pooled percentile value.

    Args:
        data_by_model: mapping from model name to that model's predictions.
        percentile: percentile (0-100) of the pooled minimum probabilities to
            use as the threshold. Defaults to 25.

    Returns:
        The pooled percentile value.
    """
    min_probs = []
    min_probs_by_model = defaultdict(list)
    for model, data in data_by_model.items():
        is_miso = "miso" in model
        examples = data.values() if is_miso else data
        for example in examples:
            if is_miso:
                top = example[0]
                try:
                    min_prob = top[1]
                except (KeyError, IndexError, TypeError):
                    # Mapping form: no positional element 1, so reduce the
                    # per-expression probabilities instead.
                    expression_probs = top['expression_probs']
                    min_prob = None if expression_probs is None else np.min(expression_probs)
                # Previously the sequence form with a None probability crashed
                # with a TypeError in the fallback; both forms now default to 1.0.
                if min_prob is None:
                    min_prob = 1.0
            else:
                probs = np.exp(np.array(example['token_logprobs'][0]))
                min_prob = np.min(probs)
            min_probs.append(min_prob)
            min_probs_by_model[model].append(min_prob)

    min_probs = np.array(min_probs)
    print(min_probs.shape)

    value = np.percentile(min_probs, percentile)
    num_hard, num_easy = np.sum(min_probs < value), np.sum(min_probs >= value)
    for model, model_probs in min_probs_by_model.items():
        # Convert once and reuse for the counts and the percentile.
        model_probs = np.array(model_probs)
        model_num_hard, model_num_easy = np.sum(model_probs < value), np.sum(model_probs >= value)
        model_value = np.percentile(model_probs, percentile)
        print(f"\t{model}: {percentile:.2f}% of examples have min prob < {model_value:.2f}: {model_num_easy/len(model_probs)*100:.2f} easy, {model_num_hard/len(model_probs)*100:.2f} hard")
    perc_easy, perc_hard = num_easy / len(min_probs), num_hard / len(min_probs)
    print(f"{percentile:.2f}% of examples have min prob < {value:.2f}: {perc_easy*100:.2f} easy, {perc_hard*100:.2f} hard")
    return value

# Threshold for SMCalFlow, computed over all three models' predictions.
print("Calflow")
cf_thresh = get_percentiles(calflow_data_by_model)
print(cf_thresh)

# Threshold for TreeDST, computed the same way.
print("TreeDST")
tdst_thresh = get_percentiles(treedst_data_by_model)
print(tdst_thresh)

Calflow
(40488,)
	miso: 25.00% of examples have min prob < 0.72: 66.69 easy, 33.31 hard
	bart: 25.00% of examples have min prob < 0.91: 81.46 easy, 18.54 hard
	t5: 25.00% of examples have min prob < 0.88: 76.85 easy, 23.15 hard
25.00% of examples have min prob < 0.86: 75.00 easy, 25.00 hard
0.8592969851961498
TreeDST
(68523,)
	miso: 25.00% of examples have min prob < 0.94: 81.42 easy, 18.58 hard
	bart: 25.00% of examples have min prob < 0.84: 55.50 easy, 44.50 hard
	t5: 25.00% of examples have min prob < 0.99: 88.07 easy, 11.93 hard
25.00% of examples have min prob < 0.85: 75.00 easy, 25.00 hard
0.8523301291401095


In [15]:
## CALFLOW

# For each model, collect the example indices whose minimum probability is
# below the CalFlow threshold.
calflow_low_idxs_by_model = {}
for model, data in calflow_data_by_model.items():
    # MISO data is iterated via .items(); the other models' outputs are
    # paired positionally with the gold indices.
    low_prob_idxs = get_low_prob(
        data.items() if model == "miso" else zip(calflow_gold_idx, data),
        is_miso=(model == "miso"),
        threshold=cf_thresh,
    )
    calflow_low_idxs_by_model[model] = low_prob_idxs

report_low_idxs(calflow_low_idxs_by_model)

miso: 33.31%
bart: 18.54%
intersection of miso and bart: 2052 / 4946: 0.41
miso: 33.31%
t5: 23.15%
intersection of miso and t5: 2409 / 5211: 0.46
bart: 18.54%
t5: 23.15%
intersection of bart and t5: 1771 / 3855: 0.46


In [30]:
## TREE DST

# For each model, collect the example indices whose minimum probability is
# below the TreeDST threshold.
treedst_low_idxs_by_model = {}
for model, data in treedst_data_by_model.items():
    # MISO data is iterated via .items(); the other models' outputs are
    # paired positionally with the gold indices.
    low_prob_idxs = get_low_prob(
        data.items() if model == "miso" else zip(treedst_gold_idx, data),
        is_miso=(model == "miso"),
        threshold=tdst_thresh,
    )
    treedst_low_idxs_by_model[model] = low_prob_idxs

report_low_idxs(treedst_low_idxs_by_model, is_treedst=True)

miso: 18.58%
bart: 44.50%
intersection of miso and bart: 2971 / 11436: 0.26
miso: 18.58%
t5: 11.93%
intersection of miso and t5: 2555 / 4412: 0.58
bart: 44.50%
t5: 11.93%
intersection of bart and t5: 2371 / 10517: 0.23


In [17]:
# Leftover debugging, kept commented out: run get_low_prob on the first 100
# TreeDST MISO examples and print how many MISO examples there are.
# low_prob_idxs = get_low_prob(list(treedst_data_by_model['miso'].items())[0:100], is_miso=True)
# print(len(treedst_data_by_model['miso']))
# Sanity check on the number of loaded examples.
# NOTE(review): `nuc_data` is assigned in both the CalFlow and the TreeDST
# loading cells, so which dataset this measures depends on execution order.
print(len(nuc_data))


22841


In [18]:
def get_counts(data, low_idxs, is_miso=False, gold_idxs = None, gold_tgt_by_idx=None): 
    """Count exact-match correctness overall and split by low/high confidence.

    Args:
        data: one model's predictions. When ``is_miso`` is True it is a
            mapping from example index to an example whose first element holds
            the prediction under ``'tgt_str'`` (either directly or nested one
            level deeper at position 0). Otherwise it is a sequence of
            mappings with ``'outputs'`` and ``'test_datum_canonical'`` entries.
        low_idxs: collection of example indices considered low-confidence.
            Any iterable of hashable indices is accepted.
        is_miso: selects the layout of ``data`` as described above.
        gold_idxs: example indices paired positionally with ``data``; used
            only when ``is_miso`` is False.
        gold_tgt_by_idx: mapping from example index to the gold target string;
            used only when ``is_miso`` is True.

    Returns:
        dict with the integer counters ``low_correct``, ``low_total``,
        ``high_correct``, ``high_total``, ``all_correct`` and ``all_total``.
    """
    counts = {"low_correct": 0, "low_total": 0, "all_correct": 0, "all_total": 0, "high_correct": 0, "high_total": 0}

    # Membership is tested once per example. Callers pass lists with thousands
    # of indices, so convert to a set once rather than scanning a list each time.
    low_idx_set = low_idxs if isinstance(low_idxs, (set, frozenset)) else set(low_idxs)

    if is_miso:
        enumerator = data.items()
    else:
        enumerator = zip(gold_idxs, data)

    for idx, example in enumerator:
        if is_miso: 
            try:
                pred_str = example[0][0]['tgt_str']
            except KeyError:
                # example[0] is itself the mapping holding the prediction.
                pred_str = example[0]['tgt_str']
            gold_str = gold_tgt_by_idx[idx]
        else:
            pred_str = example['outputs'][0]
            gold_str = example['test_datum_canonical']
        try:
            pred_tgt = render_compact(parse_lispress(pred_str))
        except (AssertionError, IndexError):
            # A prediction that fails to parse is replaced by a placeholder
            # and then compared against the gold like any other prediction.
            pred_tgt = "(Error)"

        gold_tgt = render_compact(parse_lispress(gold_str))
        is_correct, __ = single_exact_match(pred_tgt, gold_tgt)

        # Every example lands in exactly one of low/high, and always in "all".
        bucket = "low" if idx in low_idx_set else "high"
        for prefix in (bucket, "all"):
            counts[f"{prefix}_total"] += 1
            if is_correct:
                counts[f"{prefix}_correct"] += 1
    return counts 


In [22]:

# Evaluate every model on the low-confidence examples of every model
# (including itself) for SMCalFlow.
calflow_counts_by_model_pair = {}

# get_counts tests membership once per example. Build each model's index set
# once here so those lookups do not scan a list of thousands of indices.
calflow_low_idx_sets = {model: set(idxs) for model, idxs in calflow_low_idxs_by_model.items()}

# The nested loop visits each ordered (model_a, model_b) pair exactly once, so
# no bookkeeping of already-evaluated pairs is needed.
for model_a in calflow_data_by_model.keys():
    for model_b in calflow_data_by_model.keys():
        print(f"evaluating {model_a} on low conf. from {model_b}")
        calflow_counts_by_model_pair[(model_a, model_b)] = get_counts(
            calflow_data_by_model[model_a],
            calflow_low_idx_sets[model_b],
            is_miso=model_a == "miso",
            gold_idxs=calflow_gold_idx,
            gold_tgt_by_idx=calflow_gold_tgt_by_idx,
        )

# print report 
for (model_a, model_b), count_data in calflow_counts_by_model_pair.items():
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    print(f"low: {count_data['low_correct']} / {count_data['low_total']}: {count_data['low_correct'] / count_data['low_total']*100:.2f}")
    print(f"high: {count_data['high_correct']} / {count_data['high_total']}: {count_data['high_correct'] / count_data['high_total']*100:.2f}")
    print(f"all: {count_data['all_correct']} / {count_data['all_total']}: {count_data['all_correct'] / count_data['all_total']*100:.2f}")
    print()


evaluating miso on low conf. from miso
evaluating miso on low conf. from bart
evaluating miso on low conf. from t5
evaluating bart on low conf. from miso
evaluating bart on low conf. from bart
evaluating bart on low conf. from t5
evaluating t5 on low conf. from miso
evaluating t5 on low conf. from bart
evaluating t5 on low conf. from t5
MISO on low conf. from MISO
low: 2229 / 4496: 49.58
high: 8411 / 9000: 93.46
all: 10640 / 13496: 78.84

MISO on low conf. from BART
low: 1074 / 2502: 42.93
high: 9566 / 10994: 87.01
all: 10640 / 13496: 78.84

MISO on low conf. from T5
low: 1454 / 3124: 46.54
high: 9186 / 10372: 88.57
all: 10640 / 13496: 78.84

BART on low conf. from MISO
low: 2774 / 4496: 61.70
high: 8377 / 9000: 93.08
all: 11151 / 13496: 82.62

BART on low conf. from BART
low: 1122 / 2502: 44.84
high: 10029 / 10994: 91.22
all: 11151 / 13496: 82.62

BART on low conf. from T5
low: 1681 / 3124: 53.81
high: 9470 / 10372: 91.30
all: 11151 / 13496: 82.62

T5 on low conf. from MISO
low: 2655 

In [23]:
# Examples that at least one of the three models is unsure about.
calflow_three_way_union = (
    set(calflow_low_idxs_by_model['miso'])
    | set(calflow_low_idxs_by_model['t5'])
    | set(calflow_low_idxs_by_model['bart'])
)
print(f"Size of three-way union: {len(calflow_three_way_union)}")

# Score each model against that shared low-confidence set; results are stored
# next to the pairwise ones under the pseudo-model name 'union'.
for model_a in calflow_data_by_model.keys():
    calflow_counts_by_model_pair[(model_a, 'union')] = get_counts(
        calflow_data_by_model[model_a],
        calflow_three_way_union,
        is_miso=model_a == "miso",
        gold_idxs=calflow_gold_idx,
        gold_tgt_by_idx=calflow_gold_tgt_by_idx,
    )

# Report only the union rows; the pairwise rows live in the same dict.
for (model_a, model_b), count_data in calflow_counts_by_model_pair.items():
    if model_b != "union":
        continue
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    print(f"low: {count_data['low_correct']} / {count_data['low_total']}: {count_data['low_correct'] / count_data['low_total']*100:.2f}")
    print(f"high: {count_data['high_correct']} / {count_data['high_total']}: {count_data['high_correct'] / count_data['high_total']*100:.2f}")
    print(f"all: {count_data['all_correct']} / {count_data['all_total']}: {count_data['all_correct'] / count_data['all_total']*100:.2f}")
    print()

Size of three-way union: 5450
MISO on low conf. from UNION
low: 2912 / 5450: 53.43
high: 7728 / 8046: 96.05
all: 10640 / 13496: 78.84

BART on low conf. from UNION
low: 3415 / 5450: 62.66
high: 7736 / 8046: 96.15
all: 11151 / 13496: 82.62

T5 on low conf. from UNION
low: 3285 / 5450: 60.28
high: 7744 / 8046: 96.25
all: 11029 / 13496: 81.72



In [21]:
# TREE DST

# Evaluate every model on the low-confidence examples of every model
# (including itself) for TreeDST.
treedst_counts_by_model_pair = {}

# get_counts tests membership once per example. Build each model's index set
# once here so those lookups do not scan a list of thousands of indices.
treedst_low_idx_sets = {model: set(idxs) for model, idxs in treedst_low_idxs_by_model.items()}

# The nested loop visits each ordered (model_a, model_b) pair exactly once, so
# no bookkeeping of already-evaluated pairs is needed.
for model_a in treedst_data_by_model.keys():
    for model_b in treedst_data_by_model.keys():
        print(f"evaluating {model_a} on low conf. from {model_b}")
        treedst_counts_by_model_pair[(model_a, model_b)] = get_counts(
            treedst_data_by_model[model_a],
            treedst_low_idx_sets[model_b],
            is_miso=model_a == "miso",
            gold_idxs=treedst_gold_idx,
            gold_tgt_by_idx=treedst_gold_tgt_by_idx,
        )

# print report 
for (model_a, model_b), count_data in treedst_counts_by_model_pair.items():
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    print(f"low: {count_data['low_correct']} / {count_data['low_total']}: {count_data['low_correct'] / count_data['low_total']*100:.2f}")
    print(f"high: {count_data['high_correct']} / {count_data['high_total']}: {count_data['high_correct'] / count_data['high_total']*100:.2f}")
    print(f"all: {count_data['all_correct']} / {count_data['all_total']}: {count_data['all_correct'] / count_data['all_total']*100:.2f}")
    print()

evaluating miso on low conf. from miso
evaluating miso on low conf. from bart


KeyboardInterrupt: 

In [24]:
# Examples that at least one of the three models is unsure about.
treedst_three_way_union = (
    set(treedst_low_idxs_by_model['miso'])
    | set(treedst_low_idxs_by_model['t5'])
    | set(treedst_low_idxs_by_model['bart'])
)
print(f"Size of three-way union: {len(treedst_three_way_union)}")

# Score each model against that shared low-confidence set; results are stored
# next to the pairwise ones under the pseudo-model name 'union'.
for model_a in treedst_data_by_model.keys():
    print(f"Model {model_a}-union")
    treedst_counts_by_model_pair[(model_a, 'union')] = get_counts(
        treedst_data_by_model[model_a],
        treedst_three_way_union,
        is_miso=model_a == "miso",
        gold_idxs=treedst_gold_idx,
        gold_tgt_by_idx=treedst_gold_tgt_by_idx,
    )

# Report only the union rows; the pairwise rows live in the same dict.
for (model_a, model_b), count_data in treedst_counts_by_model_pair.items():
    if model_b != "union":
        continue
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    print(f"low: {count_data['low_correct']} / {count_data['low_total']}: {count_data['low_correct'] / count_data['low_total']*100:.2f}")
    print(f"high: {count_data['high_correct']} / {count_data['high_total']}: {count_data['high_correct'] / count_data['high_total']*100:.2f}")
    print(f"all: {count_data['all_correct']} / {count_data['all_total']}: {count_data['all_correct'] / count_data['all_total']*100:.2f}")
    print()

Size of three-way union: 11535
Model miso-union
Model bart-union
Model t5-union
MISO on low conf. from UNION
low: 9265 / 11535: 80.32
high: 10675 / 11306: 94.42
all: 19940 / 22841: 87.30

BART on low conf. from UNION
low: 9801 / 11535: 84.97
high: 11156 / 11306: 98.67
all: 20957 / 22841: 91.75

T5 on low conf. from UNION
low: 9701 / 11535: 84.10
high: 11149 / 11306: 98.61
all: 20850 / 22841: 91.28



In [None]:
# Number of BART prediction records loaded for TreeDST.
len(treedst_data_by_model['bart'])