In [1]:
import json
import pathlib
from collections import defaultdict
import numpy as np 
import re 
from dataflow.core.lispress import parse_lispress, render_compact, render_pretty
from dataflow.core.linearize import lispress_to_seq

In [2]:
def read_nucleus_file(miso_pred_file):
    with open(miso_pred_file, "r") as f:
        data = [json.loads(x) for x in f.readlines()]
    to_ret = []
    data_by_idx = defaultdict(list)
    data_by_src_str = defaultdict(list)
    for line in data:
        data_by_src_str[line['src_str']].append(line) 
        data_by_idx[line['line_idx']].append(line) 

    for src_str, lines in data_by_src_str.items():
        total_probs = [np.exp(np.sum(np.log(x['expression_probs']))) 
                                if x['expression_probs'] is not None else 0.0 
                                    for x in lines ]
        mean_probs = [np.mean(x['expression_probs']) 
                                if x['expression_probs'] is not None and np.sum(x['expression_probs']) > 0.0 
                                else 0.0 for x in lines ]
        min_probs = []
        for x in lines:
            if x['expression_probs'] is not None and len(x['expression_probs']) > 0:
                min_probs.append(np.min(x['expression_probs']))
            else:
                min_probs.append(0.0)

        combo_lines = zip(lines, min_probs, mean_probs, total_probs)
        sorted_combo_lines = sorted(combo_lines, key=lambda x: x[-1], reverse=True)

        data_by_src_str[src_str] = sorted_combo_lines
        idx = lines[0]['line_idx']
        data_by_idx[idx] = sorted_combo_lines
    return data_by_src_str, data_by_idx

def read_gold_file(file):
    with open(file) as f:
        if file.endswith(".tgt"):
            to_ret = [render_compact(parse_lispress(line)) for line in f.readlines()]
        else:
            to_ret = [re.sub("__StartOfProgram", "", x).strip() for x in f.readlines()]
    return to_ret 


In [3]:
calflow_bart = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221101T105421.jsonl" 
calflow_t5 = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/t5-large-lm-adapt_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221102T103315.jsonl"
calflow_miso = "/brtx/604-nvme1/estengel/calflow_calibration/miso/tune_roberta_tok_fix_benchclamp_data/translate_output_calibrated/test_all.tgt"

calflow_models_and_paths = {"miso": calflow_miso, "bart": calflow_bart, "t5": calflow_t5} 

gold_path = "/brtx/601-nvme1/estengel/resources/data/smcalflow.agent.data.from_benchclamp"

gold_src = read_gold_file(f"{gold_path}/test_all.src_tok")
gold_tgt = read_gold_file(f"{gold_path}/test_all.tgt")
gold_idx = read_gold_file(f"{gold_path}/test_all.idx")
gold_tgt_by_idx = {idx: gold for idx, gold in zip(gold_idx, gold_tgt)}



calflow_data_by_model = {}

__, nuc_data = read_nucleus_file(calflow_miso)
calflow_data_by_model["miso"] = nuc_data

with open(calflow_bart) as f1:
    bart_data = [json.loads(x) for x in f1.readlines()]
calflow_data_by_model['bart'] = bart_data

with open(calflow_t5) as f1:
    t5_data = [json.loads(x) for x in f1.readlines()]
calflow_data_by_model['t5'] = t5_data

# treedst_bart = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_tree_dst_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221102T103357.jsonl" 
# not ready yet
# treedst_t5 = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/t5-large-lm-adapt_tree_dst_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/"



In [4]:
def get_low_prob(iterator, is_miso = False, threshold = 0.6):
    low_prob_idxs = []
    for idx, example in iterator:
        if is_miso: 
            try:
                min_prob = example[0][1]
            except:
                min_prob = np.min(example[0]['expression_probs'])
        else:
            probs = np.exp(np.array(example['token_logprobs'][0]))
            min_prob = np.min(probs)

        if min_prob < threshold:
            low_prob_idxs.append(idx) 
    return low_prob_idxs

low_idxs_by_model = {}
for model, data in calflow_data_by_model.items():
    if model == "miso":
        low_prob_idxs = get_low_prob(data.items(), is_miso=True)
    else:
        low_prob_idxs = get_low_prob(zip(gold_idx, data), is_miso=False)
    low_idxs_by_model[model] = low_prob_idxs

done = []
for model_a in low_idxs_by_model.keys():
    for model_b in low_idxs_by_model.keys():
        if model_a == model_b or (model_a, model_b) in done or (model_b, model_a) in done:
            continue
        done.append((model_a, model_b))

        # compute the overlap between the two 
        intersection = set(low_idxs_by_model[model_a]).intersection(set(low_idxs_by_model[model_b]))
        union = set(low_idxs_by_model[model_a]).union(set(low_idxs_by_model[model_b]))
        print(f"{model_a}: {len(low_idxs_by_model[model_a])/13496*100:.2f}%")
        print(f"{model_b}: {len(low_idxs_by_model[model_b])/13496*100:.2f}%")
        print(f"intersection of {model_a} and {model_b}: {len(intersection)} / {len(union)}: {len(intersection) / len(union):.2f}")

miso: 22.24%
bart: 6.09%
intersection of miso and bart: 562 / 3261: 0.17
miso: 22.24%
t5: 8.49%
intersection of miso and t5: 728 / 3419: 0.21
bart: 6.09%
t5: 8.49%
intersection of bart and t5: 312 / 1656: 0.19


In [5]:
def get_counts(data, low_idxs, is_miso=False, gold_idxs = None, gold_tgt_by_idx=None): 
    counts = {"low_correct": 0, "low_total": 0, "all_correct": 0, "all_total": 0, "high_correct": 0, "high_total": 0}
    if is_miso:
        enumerator = list(data.items())
    else:
        enumerator = list(zip(gold_idxs, data))

    for idx, example in enumerator:
        if is_miso: 
            try:
                pred_str = example[0][0]['tgt_str']
            except KeyError:
                pred_str = example[0]['tgt_str']
            gold_str = gold_tgt_by_idx[idx]
        else:
            pred_str = example['outputs'][0]
            gold_str = example['test_datum_canonical']
        try:
            pred_tgt = render_compact(parse_lispress(pred_str))
        except (AssertionError, IndexError) as e:
            pred_tgt = "(Error)"

        gold_tgt = render_compact(parse_lispress(gold_str))

        is_correct = pred_tgt == gold_tgt
        if idx in low_idxs:
            counts["low_total"] += 1
            if is_correct:
                counts["low_correct"] += 1
        else:
            counts["high_total"] += 1
            if is_correct:
                counts["high_correct"] += 1
        counts["all_total"] += 1
        if is_correct:
            counts["all_correct"] += 1
    return counts 


In [6]:

counts_by_model_pair = {}
done = []
for model_a in calflow_data_by_model.keys():
    for model_b in calflow_data_by_model.keys():
        print(f"evaluating {model_a} on low conf. from {model_b}")
        if (model_a, model_b) in done: 
            continue
        done.append((model_a, model_b))
        counts_by_model_pair[(model_a, model_b)] = get_counts(calflow_data_by_model[model_a], 
                                                              low_idxs_by_model[model_b], 
                                                              is_miso=model_a == "miso", 
                                                              gold_idxs=gold_idx, 
                                                              gold_tgt_by_idx=gold_tgt_by_idx)


evaluating miso on low conf. from miso
evaluating miso on low conf. from bart
evaluating miso on low conf. from t5
evaluating bart on low conf. from miso
evaluating bart on low conf. from bart
evaluating bart on low conf. from t5
evaluating t5 on low conf. from miso
evaluating t5 on low conf. from bart
evaluating t5 on low conf. from t5


In [7]:
# print report 
for (model_a, model_b), count_data in counts_by_model_pair.items():
    print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
    print(f"low: {count_data['low_correct']} / {count_data['low_total']}: {count_data['low_correct'] / count_data['low_total']*100:.2f}")
    print(f"high: {count_data['high_correct']} / {count_data['high_total']}: {count_data['high_correct'] / count_data['high_total']*100:.2f}")
    print(f"all: {count_data['all_correct']} / {count_data['all_total']}: {count_data['all_correct'] / count_data['all_total']*100:.2f}")
    print()

MISO on low conf. from MISO
low: 798 / 3001: 26.59
high: 9454 / 10495: 90.08
all: 10252 / 13496: 75.96

MISO on low conf. from BART
low: 239 / 822: 29.08
high: 10013 / 12674: 79.00
all: 10252 / 13496: 75.96

MISO on low conf. from T5
low: 390 / 1146: 34.03
high: 9862 / 12350: 79.85
all: 10252 / 13496: 75.96

BART on low conf. from MISO
low: 1726 / 3001: 57.51
high: 9421 / 10495: 89.77
all: 11147 / 13496: 82.59

BART on low conf. from BART
low: 258 / 822: 31.39
high: 10889 / 12674: 85.92
all: 11147 / 13496: 82.59

BART on low conf. from T5
low: 516 / 1146: 45.03
high: 10631 / 12350: 86.08
all: 11147 / 13496: 82.59

T5 on low conf. from MISO
low: 1634 / 3001: 54.45
high: 9390 / 10495: 89.47
all: 11024 / 13496: 81.68

T5 on low conf. from BART
low: 304 / 822: 36.98
high: 10720 / 12674: 84.58
all: 11024 / 13496: 81.68

T5 on low conf. from T5
low: 381 / 1146: 33.25
high: 10643 / 12350: 86.18
all: 11024 / 13496: 81.68



In [8]:
three_way_union = set(low_idxs_by_model['miso']).union(set(low_idxs_by_model['t5'])).union(set(low_idxs_by_model['bart']))
print(f"Size of three-way union: {len(three_way_union)}")
for model_a in calflow_data_by_model.keys():
    counts_by_model_pair[(model_a, 'union')] = get_counts(calflow_data_by_model[model_a], 
                                                            three_way_union,
                                                            is_miso=model_a == "miso", 
                                                            gold_idxs=gold_idx, 
                                                            gold_tgt_by_idx=gold_tgt_by_idx)
for (model_a, model_b), count_data in counts_by_model_pair.items():
    if model_b == "union": 
        print(f"{model_a.upper()} on low conf. from {model_b.upper()}")
        print(f"low: {count_data['low_correct']} / {count_data['low_total']}: {count_data['low_correct'] / count_data['low_total']*100:.2f}")
        print(f"high: {count_data['high_correct']} / {count_data['high_total']}: {count_data['high_correct'] / count_data['high_total']*100:.2f}")
        print(f"all: {count_data['all_correct']} / {count_data['all_total']}: {count_data['all_correct'] / count_data['all_total']*100:.2f}")
        print()

Size of three-way union: 3595
MISO on low conf. from UNION
low: 1129 / 3595: 31.40
high: 9123 / 9901: 92.14
all: 10252 / 13496: 75.96

BART on low conf. from UNION
low: 2020 / 3595: 56.19
high: 9127 / 9901: 92.18
all: 11147 / 13496: 82.59

T5 on low conf. from UNION
low: 1892 / 3595: 52.63
high: 9132 / 9901: 92.23
all: 11024 / 13496: 81.68



In [9]:
np.random.seed(12)
union_idxs = np.random.choice(list(union), size=20, replace=False)


for idx, src, tgt in zip(gold_idx, gold_src, gold_tgt):
    if idx in union_idxs:
        tgt = parse_lispress(tgt)
        tgt = render_pretty(tgt)
        print(f"{idx}: {src} -> \n {tgt}")
        print()

1564: __User I need you to change the name of my Baking event to The Big Cook . __Agent How is this ? __User Whoops , I meant change the Cooking event -> 
 (Yield
  (Execute
    (ReviseConstraint
      (refer (^(Dynamic) roleConstraint (Path.apply " output ")))
      (^((CalflowIntension (Constraint Event))) QueryEventIntensionConstraint)
      (Event.subject_? (?~= " Cooking ")))))

1772: __User Set a reminder to call my Mom on Saturday . -> 
 (FenceReminder)

2141: __User What is their email address ? __Agent The email address of Damon Straeter is dstraetor@thenextunicorn.com . __User please find the my mail account -> 
 (FenceSwitchTabs)

2328: __User show me my luncheon with david __Agent The " Luncheon " is on December 24 th from 11 : 00 to 11 : 30 AM . __User Change luncheon time to noon to 2 pm . -> 
 (let
  (x0
    (singleton
      (QueryEventResponse.results
        (FindEventWrapperWithDefaults
          (Event.subject_? (?~= " luncheon ")))))
    x1
    (DateAtTimeWithDefaul