In [1]:
import json
import pathlib
from collections import defaultdict
import numpy as np 
import re 
from dataflow.core.lispress import parse_lispress, render_compact
from dataflow.core.linearize import lispress_to_seq

In [2]:
def _expression_prob_stats(expression_probs):
    """Summarise one hypothesis' per-step probabilities as ``(min, mean, total)``.

    ``total`` is the product of the probabilities, computed as
    ``exp(sum(log(p)))``. A missing (``None``) or empty probability list
    yields ``(0.0, 0.0, 0.0)`` so that such a hypothesis is never ranked
    above one that actually carries probabilities.
    """
    if expression_probs is None or len(expression_probs) == 0:
        return 0.0, 0.0, 0.0

    probs = np.asarray(expression_probs)
    min_prob = np.min(probs)
    mean_prob = np.mean(probs) if np.sum(probs) > 0.0 else 0.0
    # log(0) is -inf and exp(-inf) is 0.0, which is the desired product;
    # only the divide-by-zero warning is silenced here.
    with np.errstate(divide="ignore"):
        total_prob = np.exp(np.sum(np.log(probs)))
    return min_prob, mean_prob, total_prob


def read_nucleus_file(miso_pred_file):
    """Load a JSON-lines prediction file and rank the hypotheses per source.

    Every non-blank line of ``miso_pred_file`` is parsed as a JSON object that
    must provide the keys ``src_str``, ``line_idx`` and ``expression_probs``.

    Returns:
        A pair ``(data_by_src_str, data_by_idx)`` of ``defaultdict(list)``.

        * ``data_by_src_str[src_str]`` is a list of
          ``(record, min_prob, mean_prob, total_prob)`` tuples sorted by
          ``total_prob`` in descending order (ties keep file order).
        * ``data_by_idx[idx]`` holds that same sorted list for the
          ``line_idx`` of the first record of each ``src_str`` group. Any
          other ``line_idx`` that shares a ``src_str`` with an earlier one
          keeps its plain list of raw records.
    """
    with open(miso_pred_file, "r") as f:
        # Skip blank lines (e.g. a trailing newline) instead of letting
        # json.loads fail on them.
        data = [json.loads(raw) for raw in f if raw.strip()]

    data_by_idx = defaultdict(list)
    data_by_src_str = defaultdict(list)
    for line in data:
        data_by_src_str[line['src_str']].append(line)
        data_by_idx[line['line_idx']].append(line)

    # Only existing keys are reassigned below, so the dict size never changes
    # while it is being iterated.
    for src_str, lines in data_by_src_str.items():
        scored_lines = [
            (line, *_expression_prob_stats(line['expression_probs']))
            for line in lines
        ]
        sorted_combo_lines = sorted(scored_lines, key=lambda x: x[-1], reverse=True)

        data_by_src_str[src_str] = sorted_combo_lines
        idx = lines[0]['line_idx']
        data_by_idx[idx] = sorted_combo_lines
    return data_by_src_str, data_by_idx

def read_gold_file(file):
    """Read a gold file and return one processed string per line.

    When the file name ends in ``.tgt`` each line is parsed with
    ``parse_lispress`` and re-rendered with ``render_compact``. For any other
    file every occurrence of ``__StartOfProgram`` is removed from the line and
    surrounding whitespace is stripped.
    """
    to_ret = []
    with open(file) as f:
        is_program_file = file.endswith(".tgt")
        for raw_line in f:
            if is_program_file:
                processed = render_compact(parse_lispress(raw_line))
            else:
                processed = re.sub("__StartOfProgram", "", raw_line).strip()
            to_ret.append(processed)
    return to_ret


In [3]:
# --- Prediction files (absolute paths) -------------------------------------
calflow_bart = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221101T105421.jsonl"
calflow_miso = "/brtx/604-nvme1/estengel/calflow_calibration/miso/tune_roberta_tok_fix_benchclamp_data/translate_output_calibrated/test_all.tgt"
calflow_t5 = ""  # no T5 path is set in this cell

calflow_models_and_paths = dict(miso=calflow_miso, bart=calflow_bart)

# --- Gold data ---------------------------------------------------------------
gold_path = "/brtx/601-nvme1/estengel/resources/data/smcalflow.agent.data.from_benchclamp"

gold_src = read_gold_file(f"{gold_path}/test_all.src_tok")
gold_tgt = read_gold_file(f"{gold_path}/test_all.tgt")
gold_idx = read_gold_file(f"{gold_path}/test_all.idx")
# Pair ids and programs positionally; zip stops at the shorter list and a
# repeated id keeps its last program.
gold_tgt_by_idx = dict(zip(gold_idx, gold_tgt))

# --- Model outputs -----------------------------------------------------------
calflow_data_by_model = {}

# Only the by-index view of the MISO predictions is used; the by-source view
# is bound to the throwaway name ``__``.
__, nuc_data = read_nucleus_file(calflow_miso)
calflow_data_by_model["miso"] = nuc_data

# The BART output file is read as JSON lines, one record per line.
with open(calflow_bart) as f1:
    bart_data = [json.loads(raw_line) for raw_line in f1]
calflow_data_by_model['bart'] = bart_data


In [4]:
def get_low_prob(iterator, is_miso = False, threshold = 0.5):
    """Return the ids of examples whose minimum token probability is low.

    Args:
        iterator: iterable of ``(idx, example)`` pairs.
        is_miso: selects how ``example`` is read.
            * ``True``: ``example[0]`` is the top hypothesis. If it is a
              dict, the minimum of its ``expression_probs`` is used (a
              ``None`` or empty list counts as 0.0); otherwise it is indexed
              as a sequence and element ``1`` is taken as the minimum
              probability.
            * ``False``: ``example['token_logprobs'][0]`` is a sequence of
              log-probabilities; they are exponentiated and the minimum taken.
        threshold: an example is reported when its minimum probability is
            strictly below this value. The cut-off used to be hard-coded to
            0.5 and this argument was ignored; the default is 0.5 so calls
            that rely on the default keep returning the same ids.

    Returns:
        List of ``idx`` values, in iteration order.
    """
    low_prob_idxs = []
    for idx, example in iterator:
        if is_miso:
            top = example[0]
            if isinstance(top, dict):
                # Raw record: derive the minimum from the probabilities.
                probs = top['expression_probs']
                if probs is not None and len(probs) > 0:
                    min_prob = np.min(probs)
                else:
                    min_prob = 0.0
            else:
                # Scored entry: the minimum is stored at position 1.
                min_prob = top[1]
        else:
            probs = np.exp(np.array(example['token_logprobs'][0]))
            min_prob = np.min(probs)

        if min_prob < threshold:
            low_prob_idxs.append(idx)
    return low_prob_idxs


# Collect the low-confidence example ids of every model.
low_idxs_by_model = {}
for model, data in calflow_data_by_model.items():
    # MISO data is a mapping keyed by example id; every other model's outputs
    # are paired positionally with gold_idx.
    if model != "miso":
        low_prob_idxs = get_low_prob(zip(gold_idx, data), is_miso=False)
    else:
        low_prob_idxs = get_low_prob(data.items(), is_miso=True)

    low_idxs_by_model[model] = low_prob_idxs

# Report the overlap of the low-confidence sets once per unordered model pair.
done = []
for model_a in low_idxs_by_model:
    for model_b in low_idxs_by_model:
        already_compared = (model_a, model_b) in done or (model_b, model_a) in done
        if model_a != model_b and not already_compared:
            done.append((model_a, model_b))

            low_a = set(low_idxs_by_model[model_a])
            low_b = set(low_idxs_by_model[model_b])
            intersection = low_a & low_b
            union = low_a | low_b
            print(f"{model_a}: {len(low_idxs_by_model[model_a])}")
            print(f"{model_b}: {len(low_idxs_by_model[model_b])}")
            print(f"intersection: {len(intersection)} / {len(union)}: {len(intersection) / len(union):.2f}")



miso: 2649
bart: 345
intersection: 224 / 2770: 0.08086642599277978


In [7]:
def get_counts(data, low_idxs, is_miso=False, gold_idxs = None, gold_tgt_by_idx=None):
    """Count exact-match accuracy overall and split by low/high confidence.

    Args:
        data: model predictions. With ``is_miso=True`` a mapping from example
            id to a list whose first item is either a tuple whose first
            element is a record with ``tgt_str``, or such a record itself.
            Otherwise a sequence of records with ``outputs`` (the first entry
            is the prediction) and ``test_datum_canonical`` (the gold program).
        low_idxs: ids counted as low-confidence; every other id is counted as
            high-confidence.
        is_miso: selects the ``data`` layout described above.
        gold_idxs: ids paired positionally with ``data``; required when
            ``is_miso`` is false.
        gold_tgt_by_idx: mapping from id to gold program; required when
            ``is_miso`` is true.

    Returns:
        Dict with the keys ``low_correct``, ``low_total``, ``all_correct``,
        ``all_total``, ``high_correct`` and ``high_total``.

    A prediction that ``parse_lispress`` rejects with ``AssertionError`` or
    ``IndexError`` is compared as the literal ``"(Error)"``. Failures while
    parsing the gold program are not caught.
    """
    counts = {"low_correct": 0, "low_total": 0, "all_correct": 0, "all_total": 0, "high_correct": 0, "high_total": 0}
    # A set makes the per-example membership test constant time instead of a
    # scan over the whole id list.
    low_idx_set = set(low_idxs)

    if is_miso:
        enumerator = data.items()
    else:
        enumerator = zip(gold_idxs, data)

    for idx, example in enumerator:
        if is_miso:
            try:
                pred_str = example[0][0]['tgt_str']
            except KeyError:
                # example[0] is a raw record rather than a scored tuple.
                pred_str = example[0]['tgt_str']
            gold_str = gold_tgt_by_idx[idx]
        else:
            pred_str = example['outputs'][0]
            gold_str = example['test_datum_canonical']

        try:
            pred_tgt = render_compact(parse_lispress(pred_str))
        except (AssertionError, IndexError):
            pred_tgt = "(Error)"

        gold_tgt = render_compact(parse_lispress(gold_str))
        is_correct = pred_tgt == gold_tgt

        bucket = "low" if idx in low_idx_set else "high"
        counts[f"{bucket}_total"] += 1
        counts["all_total"] += 1
        if is_correct:
            counts[f"{bucket}_correct"] += 1
            counts["all_correct"] += 1
    return counts


In [9]:

# Score every model on the examples that each *other* model was unsure about.
counts_by_model_pair = {}
done = []
for model_a in calflow_data_by_model:
    for model_b in calflow_data_by_model:
        # The message is printed before the skip check, so self-pairs are
        # announced as well even though nothing is computed for them.
        print(f"evaluating {model_a} on low conf. from {model_b}")
        if model_a != model_b and (model_a, model_b) not in done:
            done.append((model_a, model_b))
            counts_by_model_pair[(model_a, model_b)] = get_counts(
                calflow_data_by_model[model_a],
                low_idxs_by_model[model_b],
                is_miso=(model_a == "miso"),
                gold_idxs=gold_idx,
                gold_tgt_by_idx=gold_tgt_by_idx,
            )


evaluating miso on low conf. from miso
evaluating miso on low conf. from bart
evaluating bart on low conf. from miso
evaluating bart on low conf. from bart


In [11]:
# Report, for each (evaluated model, confidence source) pair, the number of
# correct predictions, the total and the accuracy in percent on the low,
# high and complete subsets.
for model_pair, count_data in counts_by_model_pair.items():
    model_a, model_b = model_pair
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    print(f"low: {count_data['low_correct']} / {count_data['low_total']}: {count_data['low_correct'] / count_data['low_total']*100:.2f}")
    print(f"high: {count_data['high_correct']} / {count_data['high_total']}: {count_data['high_correct'] / count_data['high_total']*100:.2f}")
    print(f"all: {count_data['all_correct']} / {count_data['all_total']}: {count_data['all_correct'] / count_data['all_total']*100:.2f}")
    print()

MISO on low conf. from BART
low: 89 / 345: 25.80
high: 10163 / 13151: 77.28
all: 10252 / 13496: 75.96

BART on low conf. from MISO
low: 1515 / 2649: 57.19
high: 9632 / 10847: 88.80
all: 11147 / 13496: 82.59

