In [7]:
import json
import pathlib
from collections import defaultdict
import numpy as np 
import re 
from tqdm import tqdm
from dataflow.core.lispress import parse_lispress, render_compact
from dataflow.core.linearize import lispress_to_seq

In [2]:
def read_nucleus_file(miso_pred_file):
    with open(miso_pred_file, "r") as f:
        data = [json.loads(x) for x in f.readlines()]
    to_ret = []
    data_by_idx = defaultdict(list)
    data_by_src_str = defaultdict(list)
    for line in data:
        data_by_src_str[line['src_str']].append(line) 
        data_by_idx[line['line_idx']].append(line) 

    for src_str, lines in data_by_src_str.items():
        total_probs = [np.exp(np.sum(np.log(x['expression_probs']))) 
                                if x['expression_probs'] is not None else 0.0 
                                    for x in lines ]
        mean_probs = [np.mean(x['expression_probs']) 
                                if x['expression_probs'] is not None and np.sum(x['expression_probs']) > 0.0 
                                else 0.0 for x in lines ]
        min_probs = []
        for x in lines:
            if x['expression_probs'] is not None and len(x['expression_probs']) > 0:
                min_probs.append(np.min(x['expression_probs']))
            else:
                min_probs.append(0.0)

        combo_lines = zip(lines, min_probs, mean_probs, total_probs)
        sorted_combo_lines = sorted(combo_lines, key=lambda x: x[-1], reverse=True)

        data_by_src_str[src_str] = sorted_combo_lines
        idx = lines[0]['line_idx']
        data_by_idx[idx] = sorted_combo_lines
    return data_by_src_str, data_by_idx

def read_gold_file(file):
    with open(file) as f:
        if file.endswith(".tgt"):
            to_ret = [render_compact(parse_lispress(line)) for line in f.readlines()]
        else:
            to_ret = [re.sub("__StartOfProgram", "", x).strip() for x in f.readlines()]
    return to_ret 


In [3]:

calflow_miso = "/brtx/604-nvme1/estengel/calflow_calibration/miso/tune_roberta_tok_fix_benchclamp_data/translate_output_calibrated/dev_all.tgt"

gold_path = "/brtx/601-nvme1/estengel/resources/data/smcalflow.agent.data.from_benchclamp"

gold_src = read_gold_file(f"{gold_path}/dev_all.src_tok")
gold_tgt = read_gold_file(f"{gold_path}/dev_all.tgt")
gold_idx = read_gold_file(f"{gold_path}/dev_all.idx")
gold_tgt_by_idx = {idx: gold for idx, gold in zip(gold_idx, gold_tgt)}

__, nuc_data = read_nucleus_file(calflow_miso)

In [4]:

# go through examples, get min prob and pred tgt 
# compare pred to gold tgt 

def get_f1(nuc_data, gold_tgt_by_idx, cutoff):
    counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
    for idx, ex in nuc_data.items():
        try:
            min_prob = ex[0][1]
        except:
            min_prob = np.min(ex[0]['expression_probs'])
        try:
            pred_str = ex[0][0]['tgt_str']
        except KeyError:
            pred_str = ex[0]['tgt_str']
        gold_str = gold_tgt_by_idx[idx]
        try:
            pred_tgt = render_compact(parse_lispress(pred_str))
        except (AssertionError, IndexError) as e:
            pred_tgt = "(Error)"
        gold_tgt = render_compact(parse_lispress(gold_str))
        is_correct = pred_tgt == gold_tgt 

        if min_prob < cutoff and is_correct:
            counts['fn'] += 1
        elif min_prob > cutoff and is_correct:
            counts['tp'] += 1
        elif min_prob < cutoff and not is_correct:
            counts['tn'] += 1
        elif min_prob > cutoff and not is_correct:
            counts['fp'] += 1 
    precision = counts['tp'] / (counts['tp'] + counts['fp'])
    recall = counts['tp'] / (counts['tp'] + counts['fn'])
    f1 = 2 * (precision * recall) / (precision + recall)
    return f1, precision, recall, counts

In [10]:
all_scores = []
all_thresh = [x/100 for x in list(range(30, 50))]
for thresh in tqdm(all_thresh):
    f1_score = get_f1(nuc_data, gold_tgt_by_idx, thresh)
    all_scores.append(f1_score)

print(all_scores)

100%|██████████| 20/20 [01:48<00:00,  5.41s/it]

[(0.9116738111853254, 0.8681412568824758, 0.9598026868178001, {'tp': 9145, 'fp': 1389, 'fn': 383, 'tn': 1354}), (0.9114770459081838, 0.868816590563166, 0.9585432409739715, {'tp': 9133, 'fp': 1379, 'fn': 395, 'tn': 1364}), (0.911425288504771, 0.8696729907522166, 0.9573887489504618, {'tp': 9122, 'fp': 1367, 'fn': 406, 'tn': 1376}), (0.9115101795808115, 0.8707827582911211, 0.9562342569269522, {'tp': 9111, 'fp': 1352, 'fn': 417, 'tn': 1391}), (0.9114405930675215, 0.8717899578382522, 0.9548698572628044, {'tp': 9098, 'fp': 1338, 'fn': 430, 'tn': 1405}), (0.9118074705439959, 0.8729000671978496, 0.9543450881612091, {'tp': 9093, 'fp': 1324, 'fn': 435, 'tn': 1419}), (0.9122120162626111, 0.8741702741702742, 0.9537153652392947, {'tp': 9087, 'fp': 1308, 'fn': 441, 'tn': 1435}), (0.9122260205107582, 0.8754341952913933, 0.9522460117548279, {'tp': 9073, 'fp': 1291, 'fn': 455, 'tn': 1452}), (0.9121199919468491, 0.8763056092843327, 0.9509865659109992, {'tp': 9061, 'fp': 1279, 'fn': 467, 'tn': 1464}), (0




In [11]:
just_f1 = [x[0] for x in all_scores]
best_idx = np.argmax(just_f1)
best_thresh = all_thresh[best_idx]
best_f1 = just_f1[best_idx]
print(f"Best F1: {best_f1} at {best_thresh}")


Best F1: 0.9122453096631027 at 0.4


In [15]:
print(all_thresh[8:12])
print(just_f1[8:12])
print(best_idx)

[0.38, 0.39, 0.4, 0.41]
[0.9121199919468491, 0.9122329705763804, 0.9122453096631027, 0.9118820380750392]
10
