In [1]:
import json
import os
import re
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

from calibration_metric.vis.calibration_plot import plot_df
from calibration_metric.metric import ECEMetric

plt.rcParams["font.family"] = "Nimbus Roman"

from calibration_utils import (read_nucleus_file, 
                                read_gold_file,get_probs_and_accs, 
                                read_benchclamp_file, 
                                get_probs_and_accs_benchclamp,
                                get_probs_and_accs_sql,
                                get_accs_sql)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:


# Locations of the model-output logs and of the processed Spider test split.
bart_output_path = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/bart-large_spider_past_none_db_val_all_0.0001_5000_test_eval_unconstrained-beam_bs_5/model_outputs.20230208T031316.jsonl"
t5_output_path = "/brtx/602-nvme1/estengel/calflow_calibration/benchclamp/logs/1.0/t5-base-lm-adapt_spider_past_none_db_val_all_0.0001_10000_test_eval_unconstrained-beam_bs_5/model_outputs.20230206T093954.jsonl"
spider_gold_path = "/brtx/601-nvme1/estengel/resources/data/benchclamp/processed/Spider/test_all.jsonl"

# BART predictions and their per-example confidences / exact-match results.
# NOTE(review): spider_gold_path is deliberately NOT passed to
# get_probs_and_accs_benchclamp here; confirm whether that variant is wanted.
bart_data = read_benchclamp_file(bart_output_path)
bart_min_probs, bart_mean_probs, bart_exact_accs = get_probs_and_accs_benchclamp(bart_data)

# The processed test split itself, indexed in parallel with bart_data later on.
input_test_data = read_benchclamp_file(spider_gold_path)

# T5 predictions, processed the same way as the BART ones.
t5_data = read_benchclamp_file(t5_output_path)
t5_min_probs, t5_mean_probs, t5_exact_accs = get_probs_and_accs_benchclamp(t5_data)

In [8]:
import numpy as np

# Seed NumPy's global RNG so the random sampling further down is repeatable.
np.random.seed(12)
ece_metric = ECEMetric(n_bins=20, binning_strategy="adaptive")

# Bin the BART examples using their min-probability confidence and exact-match
# results. The three results are indexed below as: a per-bin accuracy value,
# a per-bin confidence, and the bin id assigned to each example.
min_values_em, min_bins, min_bin_number = ece_metric.adaptive_bin(bart_min_probs, bart_exact_accs)

# Group every prediction (together with its input record at the same index)
# under the bin it was assigned to.
data_by_bin = defaultdict(list)
for example_idx, prediction in enumerate(bart_data):
    source_record = input_test_data[example_idx]
    bin_id = min_bin_number[example_idx]
    record = (
        min_bins[bin_id],       # confidence of the bin
        min_values_em[bin_id],  # accuracy of the bin
        prediction,
        source_record,
    )
    data_by_bin[bin_id].append(record)


# Write the input records of every bin to its own JSONL file for later analysis.
# The output directory is created first: open(..., "w") does not create missing
# parent directories, so without this the cell fails on a fresh checkout.
bin_output_dir = "spider_test_by_bart_bin"
os.makedirs(bin_output_dir, exist_ok=True)

for bin_num, bin_records in data_by_bin.items():
    # Every record in a bin carries the same bin confidence; read it off the first.
    bin_conf = bin_records[0][0]
    print(bin_num, bin_conf)
    bin_conf_str = f"{bin_conf:.2f}"
    bin_str = f"{bin_num}_{bin_conf_str}"
    with open(f"{bin_output_dir}/{bin_str}.jsonl", "w") as f1:
        # Named throwaways instead of `_`, which IPython uses for the last output.
        for (_conf, _acc, _datum, input_datum) in bin_records:
            f1.write(json.dumps(input_datum) + "\n")


# Print a random sample of up to `sample_n` examples from every confidence bin.
sample_n = 10
for bin_num, cands in data_by_bin.items():
    # Clamp the sample size to the bin size: with replace=False,
    # np.random.choice raises ValueError when asked for more items than exist.
    n_to_sample = min(sample_n, len(cands))
    # Passing an int samples from np.arange(len(cands)), i.e. from the indices.
    sample_idxs = np.random.choice(len(cands), size=n_to_sample, replace=False)
    examples = [cands[i] for i in sample_idxs]
    print(f"Bin number: {bin_num}")
    # `_input_datum` is unused here; avoid `__`, which IPython reserves for outputs.
    for (conf, acc, datum, _input_datum) in examples:
        full_input = datum['test_datum_natural']
        print(f"\tConfidence: {conf}, Acc: {acc}")
        print(f"\tInput: {full_input}")
        print(f"\tPred Output: {datum['outputs'][0]}")
        print(f"\tGold Output: {datum['test_datum_canonical']}")
        print(f"\tCorrect: {datum['metrics']['exact_match/top1']}")
        print()



0 0.7829297644731674
1 0.7445562160616894
2 0.695669958393968
3 0.6188248897252094
4 0.5245958175969945
5 0.43645621017136105
6 0.2872891378969516
Bin number: 0
	Confidence: 0.7829297644731674, Acc: 0.37037037037037035
	Input:  ,  | pets_1 | student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype ( cat ) , pet_age , weight , What major is every student who does not own a cat as a pet, and also how old are they?
	Pred Output: SELECT major , age FROM student EXCEPT SELECT T1 . major , T1 . age FROM student AS T1 JOIN has_pet AS T2 ON T1 . stuid = T2 . stuid JOIN pets AS T3 ON T2 . petid = T3 . petid WHERE T3 . pettype = "cat"
	Gold Output: SELECT major , age FROM student WHERE stuid NOT IN ( SELECT T1 . stuid FROM student AS T1 JOIN has_pet AS T2 ON T1 . stuid = T2 . stuid JOIN pets AS T3 ON T3 . petid = T2 . petid WHERE T3 . pettype = 'cat' )
	Correct: incorrect

	Confidence: 0.7829297644731674, Acc: 0.3703703703703703

## Initial observations
- at higher confidence (>0.5)
    - A common source of exact-match errors is a mismatch in quotation-mark style (the model predicts `"`, the reference has `'`).
    - A missing semicolon sometimes also causes an error.
    - Another common source: the model does not make use of information in the input context, e.g. overriding table names or hallucinating values (especially for countries and places, e.g. "Kanghanistan" instead of "Kabul", "Angolance" instead of "Angola").
    - Background knowledge is sometimes required (beyond what's in the prompt). For example, the input "What are the names of conductors who have conducted at more than one orchestra?" has a gold program that groups by conductor_id instead of conductor_name. That's probably the smarter thing to do, but conductor_name isn't necessarily wrong (unless multiple conductors have the same name). Without knowing more about the assumptions underlying the database, there is no way to say that this is definitively wrong.
- at lower confidence (<0.5)
   - There seem to be more outright mistakes, many of them overly simple programs.
   - There may be more relative clauses / multi-hop queries in the input (e.g. "What are the names of high schoolers who have likes, and how many likes does each have?").


## Ambiguous examples:

 - Input:  ,  | tvshow | tv_channel : id , series_name , country , language , content , pixel_aspect_ratio_par , hight_definition_tv , pay_per_view_ppv , package_option | tv_series : id , episode , air_date , rating , share , 18_49_rating_share , viewers_m , weekly_rank , channel | cartoon : id , title , directed_by , written_by , original_air_date , production_code , channel , List the number of different series names and contents in the TV Channel table.
 - Pred Output: SELECT count ( DISTINCT series_name ) , content FROM tv_channel
 - Gold Output: SELECT count ( DISTINCT series_name ) , count ( DISTINCT content ) FROM tv_channel ;

  - Input:  ,  | student_transcripts_tracking | addresses : address_id , line_1 , line_2 , line_3 , city , zip_postcode , state_province_county , country , other_address_details | courses : course_id , course_name , course_description , other_details | departments : department_id , department_name , department_description , other_details | degree_programs : degree_program_id , department_id , degree_summary_name , degree_summary_description , other_details | sections : section_id , course_id , section_name , section_description , other_details | semesters : semester_id , semester_name , semester_description , other_details | students : student_id , current_address_id , permanent_address_id , first_name , middle_name , last_name , cell_mobile_number , email_address , ssn , date_first_registered , date_left , other_student_details | student_enrolment : student_enrolment_id , degree_program_id , semester_id , student_id , other_details | student_enrolment_courses : student_course_id , course_id , student_enrolment_id | transcripts : transcript_id , transcript_date , other_details | transcript_contents : student_course_id , transcript_id , What are the names and id of courses having at most 2 sections?
  - Pred Output: SELECT T1 . course_name , T2 . course_id FROM courses AS T1 JOIN sections AS T2 ON T1 . course_id = T2 . course_id GROUP BY T2 . course_id HAVING count ( * ) <= 2
  - Gold Output: SELECT T1 . course_name , T1 . course_id FROM courses AS T1 JOIN sections AS T2 ON T1 . course_id = T2 . course_id GROUP BY T1 . course_id HAVING count ( * ) <= 2