In [1]:
import json
import re 
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd 
from collections import Counter, defaultdict
from scipy import stats

from calibration_metric.vis.calibration_plot import plot_df
from calibration_metric.metric import ECEMetric

# Set the default matplotlib font family for every figure in this notebook.
# NOTE(review): presumably "Nimbus Roman" must be installed on the machine — confirm.
plt.rcParams["font.family"] = "Nimbus Roman"

from calibration_utils import (read_nucleus_file, 
                                read_gold_file,get_probs_and_accs, 
                                read_benchclamp_file, 
                                get_probs_and_accs_benchclamp,
                                get_probs_and_accs_sql,
                                get_accs_sql)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:


# NOTE(review): the paths below are absolute cluster paths and will only resolve
# on the original machine; consider hoisting them into a configurable DATA_DIR.

# BART outputs (per the log directory name: bart-large on CalFlow, test split,
# unconstrained beam search, beam size 5).
bart_data = read_benchclamp_file("/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221101T105421.jsonl")
calflow_gold_path = "/brtx/601-nvme1/estengel/resources/data/benchclamp/processed/CalFlowV2/test_all.jsonl"
bart_min_probs, bart_mean_probs, bart_exact_accs = get_probs_and_accs_benchclamp(bart_data)
# The processed CalFlow test file, read once via the path defined above instead
# of repeating the literal. Later cells index it in parallel with bart_data, so
# presumably it is line-aligned with the model outputs — verify.
input_test_data = read_benchclamp_file(calflow_gold_path)


# T5 outputs (per the log directory name: t5-large-lm-adapt, same setup as above).
t5_data = read_benchclamp_file("/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/t5-large-lm-adapt_calflow_last_user_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20221102T103315.jsonl")
t5_min_probs, t5_mean_probs, t5_exact_accs = get_probs_and_accs_benchclamp(t5_data)

In [5]:
import numpy as np

# Seed NumPy's global RNG so the random example sampling later in this cell is
# repeatable across runs.
np.random.seed(12)
ece_metric = ECEMetric(n_bins=20, binning_strategy="adaptive")

# Adaptive binning of the BART confidences against exact-match correctness.
# NOTE(review): the three results are used below as (per-bin accuracy,
# per-bin confidence, per-example bin index) — confirm against ECEMetric.
min_values_em, min_bins, min_bin_number = ece_metric.adaptive_bin(
    bart_min_probs, bart_exact_accs
)

# Group each model output, paired with the test input at the same position,
# under the bin it was assigned to.
data_by_bin = defaultdict(list)
for idx, pred_datum in enumerate(bart_data):
    src_datum = input_test_data[idx]
    assigned_bin = min_bin_number[idx]
    record = (
        min_bins[assigned_bin],
        min_values_em[assigned_bin],
        pred_datum,
        src_datum,
    )
    data_by_bin[assigned_bin].append(record)


import os

# Dump the test inputs of every bin to its own JSONL file for later analysis.
# NOTE(review): the directory says "spider" although the data in this notebook
# is read from CalFlow paths; the name is kept unchanged so existing consumers
# of these files keep working.
out_dir = "spider_test_by_bart_bin"
# open(..., "w") does not create missing parent directories, so make sure the
# target exists; otherwise a fresh run fails with FileNotFoundError.
os.makedirs(out_dir, exist_ok=True)

for bin_num in data_by_bin.keys():
    # Every record in a bin carries the same bin confidence, so read it off the first.
    bin_conf = data_by_bin[bin_num][0][0]
    print(bin_num, bin_conf)
    bin_conf_str = f"{bin_conf:.2f}"
    bin_str = f"{bin_num}_{bin_conf_str}"
    # File name pattern: <bin number>_<bin confidence to 2 decimals>.jsonl
    with open(f"{out_dir}/{bin_str}.jsonl", "w") as f1:
        for (_, _, datum, input_datum) in data_by_bin[bin_num]:
            f1.write(json.dumps(input_datum) + "\n")


# Print up to `sample_n` randomly chosen examples from every bin for manual inspection.
sample_n = 10
for bin_num, cands in data_by_bin.items():
    # Cap the draw at the bin size: sampling without replacement raises
    # ValueError when more items are requested than the bin contains.
    n_draw = min(sample_n, len(cands))
    # Passing an int samples from range(len(cands)), i.e. positions into `cands`.
    sample_idxs = np.random.choice(len(cands), size=n_draw, replace=False)
    examples = [cands[i] for i in sample_idxs]
    print(f"Bin number: {bin_num}")
    for (conf, acc, datum, __) in examples:
        full_input = datum['test_datum_natural']
        print(f"\tConfidence: {conf}, Acc: {acc}")
        print(f"\tInput: {full_input}")
        print(f"\tPred Output: {datum['outputs'][0]}")
        print(f"\tGold Output: {datum['test_datum_canonical']}")
        print(f"\tCorrect: {datum['metrics']['exact_match/top1']}")
        print()



0 0.9595875112258064
1 0.9488612507376211
2 0.9373496615935393
3 0.9182750658179325
4 0.8860958020737665
5 0.8494594551585415
6 0.8036726988478053
7 0.7555402719688629
8 0.7093986850678904
9 0.6607705418283155
10 0.6100244892142395
11 0.5589441292817081
12 0.5082299789746288
13 0.4405896026166398
14 0.35963488882527367
15 0.2270245195474269
Bin number: 0
	Confidence: 0.9595875112258064, Acc: 0.9420289855072463
	Input: Oops I meant can you find the rundown event in June? | The event matching "rundown" in June is on June 10th at 1:00 PM. | when is my build a fire
	Pred Output: (Yield (Event.start (singleton (QueryEventResponse.results (FindEventWrapperWithDefaults (Event.subject_? (?~= "build a fire")))))))
	Gold Output: (Yield (Event.start (singleton (QueryEventResponse.results (FindEventWrapperWithDefaults (Event.subject_? (?~= "build a fire")))))))
	Correct: correct

	Confidence: 0.9595875112258064, Acc: 0.9420289855072463
	Input: update my calendar | I can help you create, update, an

## Initial observations
- 