In [1]:
import json
import pathlib
from collections import defaultdict
import numpy as np 
import re 
from dataflow.core.lispress import parse_lispress, render_compact
from dataflow.core.linearize import lispress_to_seq

In [2]:
def read_nucleus_file(miso_pred_file):
    """Load a JSON-lines prediction file and rank the hypotheses of each source string.

    Every non-blank line is parsed as a JSON object that must provide the keys
    ``src_str``, ``line_idx`` and ``expression_probs`` (a list of probabilities,
    or ``None``).

    Args:
        miso_pred_file: path to the JSON-lines prediction file.

    Returns:
        A ``(data_by_src_str, data_by_idx)`` pair of ``defaultdict(list)``:

        * ``data_by_src_str`` maps each ``src_str`` to a list of
          ``(line, min_prob, mean_prob, total_prob)`` tuples sorted by
          ``total_prob``, highest first (ties keep file order).
        * ``data_by_idx`` maps the ``line_idx`` of the first line seen for a
          source string to that same ranked list. Any other ``line_idx`` that
          shares the source string keeps its raw list of line dicts, so callers
          must be prepared for both shapes.
    """
    def _score(probs):
        """Return ``(min, mean, total)`` probability for one hypothesis."""
        # A missing or empty probability list carries no evidence: score it 0.0
        # on every measure so it can never outrank a real hypothesis.
        if probs is None or len(probs) == 0:
            return 0.0, 0.0, 0.0
        mean_prob = np.mean(probs) if np.sum(probs) > 0.0 else 0.0
        # Product of the probabilities, accumulated in log space. A zero
        # probability yields log(0) = -inf and a total of 0.0, which is the
        # wanted result, so the divide-by-zero warning is silenced.
        with np.errstate(divide="ignore"):
            total_prob = np.exp(np.sum(np.log(probs)))
        return np.min(probs), mean_prob, total_prob

    with open(miso_pred_file, "r") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        data = [json.loads(x) for x in f if x.strip()]

    data_by_idx = defaultdict(list)
    data_by_src_str = defaultdict(list)
    for line in data:
        data_by_src_str[line['src_str']].append(line)
        data_by_idx[line['line_idx']].append(line)

    for src_str, lines in data_by_src_str.items():
        scored = [(line, *_score(line['expression_probs'])) for line in lines]
        # Rank by total probability (last tuple field), best hypothesis first.
        ranked = sorted(scored, key=lambda item: item[-1], reverse=True)

        data_by_src_str[src_str] = ranked
        # Only the index of the first line of the group is re-pointed at the ranked list.
        data_by_idx[lines[0]['line_idx']] = ranked
    return data_by_src_str, data_by_idx

def read_gold_file(file):
    """Read a gold data file and return one normalised string per line.

    Args:
        file: path to the file, either a ``str`` or a path-like object such as
            ``pathlib.Path``.

    Returns:
        A list with one entry per line of the file. When the file name ends in
        ``.tgt`` every line is parsed with ``parse_lispress`` and re-rendered
        with ``render_compact``; for any other file the ``__StartOfProgram``
        marker is removed and surrounding whitespace is stripped.
    """
    # str() lets callers pass pathlib.Path objects, which have no .endswith().
    is_tgt_file = str(file).endswith(".tgt")
    with open(file) as f:
        if is_tgt_file:
            return [render_compact(parse_lispress(line)) for line in f]
        return [line.replace("__StartOfProgram", "").strip() for line in f]

def get_probs_and_accs(nucleus_file, gold_src_file, gold_tgt_file):
    """Collect confidence scores and exact-match accuracy of the top hypotheses.

    For every (gold source, gold target) pair the top-ranked entry that
    ``read_nucleus_file`` stores under the gold source string is compared,
    after ``render_compact(parse_lispress(...))`` normalisation, against the
    gold target.

    Args:
        nucleus_file: prediction file passed to ``read_nucleus_file``.
        gold_src_file: gold source file passed to ``read_gold_file``.
        gold_tgt_file: gold target file passed to ``read_gold_file``.

    Returns:
        ``(min_probs, mean_probs, accs)``: three parallel lists holding, per
        gold example, the top entry's minimum probability, its mean
        probability, and whether its rendered prediction equals the gold target.
    """
    ranked_by_src, _ = read_nucleus_file(nucleus_file)
    gold_targets = read_gold_file(gold_tgt_file)
    gold_sources = read_gold_file(gold_src_file)
    # assert(len(nucleus) == len(gold_tgt))

    min_probs, mean_probs, accs = [], [], []
    for source, target in zip(gold_sources, gold_targets):
        # best entry: (line, min_prob, mean_prob, total_prob)
        best = ranked_by_src[source][0]
        prediction = render_compact(parse_lispress(best[0]['tgt_str']))

        # use the min prob, not the summed prob
        min_probs.append(best[1])
        # TODO (elias): add total number of tokens to get mean prob
        mean_probs.append(best[2])
        accs.append(prediction == target)
    return min_probs, mean_probs, accs

In [3]:
# Prediction files of the two models being compared.
calflow_bart = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221101T105421.jsonl"
calflow_miso = "/brtx/604-nvme1/estengel/calflow_calibration/miso/tune_roberta_tok_fix_benchclamp_data/translate_output_calibrated/test_all.tgt"

# Directory holding the gold test split.
gold_path = "/brtx/601-nvme1/estengel/resources/data/smcalflow.agent.data.from_benchclamp"

# Gold sources, targets and example indices: one entry per line of each file.
gold_src = read_gold_file(f"{gold_path}/test_all.src_tok")
gold_tgt = read_gold_file(f"{gold_path}/test_all.tgt")
gold_idx = read_gold_file(f"{gold_path}/test_all.idx")
# Gold target looked up by example index (a repeated index keeps its last target).
gold_tgt_by_idx = dict(zip(gold_idx, gold_tgt))

# Only the by-index view of the MISO predictions is used below.
__, nuc_data = read_nucleus_file(calflow_miso)

# BART predictions: one JSON object per line.
with open(calflow_bart) as f1:
    bart_data = [json.loads(x) for x in f1]


In [4]:
# An example counts as low-confidence when the smallest probability of its
# top prediction falls below this value.
LOW_PROB_THRESHOLD = 0.5

# get the set of low-probability examples from MISO
miso_low_prob_idxs = []
for idx, example in nuc_data.items():
    try:
        # ranked entry: (line, min_prob, mean_prob, total_prob)
        min_prob = example[0][1]
    except KeyError:
        # raw line dict that read_nucleus_file did not rank: compute the
        # minimum directly (same fallback trigger as in get_counts)
        min_prob = np.min(example[0]['expression_probs'])

    if min_prob < LOW_PROB_THRESHOLD:
        miso_low_prob_idxs.append(idx)

# get the set of low-probability examples from BART
bart_low_prob_idxs = []
for idx, src, example in zip(gold_idx, gold_src, bart_data):
    # token_logprobs[0] holds log-probabilities of the first output; exponentiate to probabilities
    probs = np.exp(np.array(example['token_logprobs'][0]))
    min_prob = np.min(probs)
    if min_prob < LOW_PROB_THRESHOLD:
        bart_low_prob_idxs.append(idx)


# compute the overlap between the two
miso_low_prob_set = set(miso_low_prob_idxs)
bart_low_prob_set = set(bart_low_prob_idxs)
intersection = miso_low_prob_set & bart_low_prob_set
union = miso_low_prob_set | bart_low_prob_set
print(f"MISO: {len(miso_low_prob_idxs)}")
print(f"BART: {len(bart_low_prob_idxs)}")
print(f"intersection: {len(intersection)} / {len(union)}: {len(intersection) / len(union)}")



MISO: 2649
BART: 345
intersection: 224 / 2770: 0.08086642599277978


In [5]:
def get_counts(data, low_idxs, is_bart=False, gold_idxs = None, gold_tgt_by_idx=None):
    """Count exact-match accuracy overall and split by a low-confidence index set.

    Args:
        data: the predictions. With ``is_bart=True`` an iterable of dicts that
            provide ``outputs`` (first element is the prediction) and
            ``test_datum_canonical`` (the gold string). Otherwise a mapping
            from example index to a list whose first element is either a tuple
            starting with a line dict or the line dict itself; the line dict
            provides ``tgt_str``.
        low_idxs: collection of example indices treated as low-confidence.
        is_bart: selects which of the two ``data`` layouts above is expected.
        gold_idxs: example indices aligned with ``data``; required when
            ``is_bart`` is true.
        gold_tgt_by_idx: mapping from example index to gold string; required
            when ``is_bart`` is false.

    Returns:
        A dict with the integer counters ``low_correct``, ``low_total``,
        ``high_correct``, ``high_total``, ``all_correct`` and ``all_total``.
        A prediction that fails to parse with ``AssertionError`` or
        ``IndexError`` is counted as incorrect.
    """
    counts = {"low_correct": 0, "low_total": 0, "all_correct": 0, "all_total": 0, "high_correct": 0, "high_total": 0}

    # Hash-based membership keeps the loop linear in the number of examples.
    try:
        low_idx_lookup = frozenset(low_idxs)
    except TypeError:
        # unhashable indices: fall back to the container exactly as given
        low_idx_lookup = low_idxs

    pairs = zip(gold_idxs, data) if is_bart else data.items()

    for idx, example in pairs:
        if is_bart:
            pred_str = example['outputs'][0]
            gold_str = example['test_datum_canonical']
        else:
            try:
                # ranked entry: tuple whose first field is the line dict
                pred_str = example[0][0]['tgt_str']
            except KeyError:
                # raw entry: the line dict itself
                pred_str = example[0]['tgt_str']
            gold_str = gold_tgt_by_idx[idx]
        try:
            pred_tgt = render_compact(parse_lispress(pred_str))
        except (AssertionError, IndexError):
            # unparseable prediction: sentinel that never equals a gold program
            pred_tgt = "(Error)"

        gold_tgt = render_compact(parse_lispress(gold_str))

        is_correct = pred_tgt == gold_tgt
        confidence_bucket = "low" if idx in low_idx_lookup else "high"
        # every example is tallied in its confidence bucket and in "all"
        for bucket in (confidence_bucket, "all"):
            counts[f"{bucket}_total"] += 1
            if is_correct:
                counts[f"{bucket}_correct"] += 1
    return counts


In [6]:

# accuracy of each model, split by that model's own low-confidence examples
intra_miso_counts = get_counts(
    nuc_data,
    miso_low_prob_idxs,
    is_bart=False,
    gold_tgt_by_idx=gold_tgt_by_idx,
)
intra_bart_counts = get_counts(
    bart_data,
    bart_low_prob_idxs,
    is_bart=True,
    gold_idxs=gold_idx,
)

In [7]:
# performance of miso on bart low-conf examples
inter_miso_bart_counts = get_counts(
    nuc_data,
    bart_low_prob_idxs,
    is_bart=False,
    gold_tgt_by_idx=gold_tgt_by_idx,
)
# performance of bart on miso low-conf examples
inter_bart_miso_counts = get_counts(
    bart_data,
    miso_low_prob_idxs,
    is_bart=True,
    gold_idxs=gold_idx,
)

In [8]:
def print_accuracy_report(title, counts):
    """Print the low / high / all accuracy lines for one ``get_counts`` result.

    Args:
        title: heading printed above the three accuracy lines.
        counts: dict with ``<bucket>_correct`` and ``<bucket>_total`` entries
            for the buckets ``low``, ``high`` and ``all``.
    """
    print(title)
    for bucket in ("low", "high", "all"):
        correct = counts[f"{bucket}_correct"]
        total = counts[f"{bucket}_total"]
        # an empty bucket has no defined accuracy; report nan rather than crash
        accuracy = correct / total if total else float("nan")
        print(f"{bucket}: {correct} / {total}: {accuracy}")
    print()


# MISO report
print_accuracy_report("MISO-MISO", intra_miso_counts)
print_accuracy_report("MISO on BART low-conf", inter_miso_bart_counts)

# BART report
print_accuracy_report("BART", intra_bart_counts)
print_accuracy_report("BART on MISO low-conf", inter_bart_miso_counts)

MISO-MISO
low: 587 / 2649: 0.22159305398263496
high: 9665 / 10847: 0.8910297778187517
all: 10252 / 13496: 0.7596324836988737

MISO on BART low-conf
low: 89 / 345: 0.2579710144927536
high: 10163 / 13151: 0.7727929435023952
all: 10252 / 13496: 0.7596324836988737

BART
low: 104 / 345: 0.30144927536231886
high: 11043 / 13151: 0.8397080069956657
all: 11147 / 13496: 0.8259484291641968

BART on MISO low-conf
low: 1515 / 2649: 0.5719139297848245
high: 9632 / 10847: 0.8879874619710519
all: 11147 / 13496: 0.8259484291641968

