In [127]:
import itertools
import operator
import sys
import os
import numpy as np
from tqdm.notebook import tqdm
sys.path.append("../src")
from glob import glob
import pandas as pd
import json
from matplotlib import pyplot as plt

from neuraldb.scoring.r_precision import f1, recall


In [132]:
# Locate every test-metrics file written underneath the experiment's staging directories.
# NOTE(review): search_root is an absolute, machine-specific path; adjust it when running elsewhere.
search_root = "/checkpoint/jth/job_staging/neuraldb_expts/experiment=ssg_final"
checkpoint_name = "metrics_test.json"

# glob yields paths in arbitrary (filesystem-dependent) order; sort them so that the
# order in which experiments are processed and printed is the same on every run.
files = sorted(glob("{}*/**/{}".format(search_root, checkpoint_name), recursive=True))

print(len(files))

27


In [133]:

def expand(idx, chunk):
    """Convert one path component into a list of ``"key=value"`` strings.

    Args:
        idx: Position of the component within the path. Currently unused; kept
            so existing callers that pass it keep working.
        chunk: A single path component (the text between two ``/``).

    Returns:
        A list of ``"key=value"`` strings:

        * ``"seed-<n>"`` becomes ``["seed=<n>"]``;
        * a comma-separated component is split on ``,`` and only the pieces
          that contain ``=`` are kept;
        * a component containing ``=`` is returned as a one-element list;
        * anything else yields an empty list.
    """
    if chunk.startswith("seed-"):
        return ["seed={}".format(chunk.replace("seed-", ""))]

    if "," in chunk:
        # Drop pieces without "=" (e.g. from a trailing comma or a bare token):
        # they are not settings, and the caller's key/value unpacking would fail on them.
        return [piece for piece in chunk.split(",") if "=" in piece]

    if "=" in chunk:
        return [chunk]

    return []

# Build one settings dict per metrics file from the key=value pairs encoded in its path.
experiments = []
for file in files:
    # Flatten the settings contributed by every component of the path.
    chunks = itertools.chain.from_iterable(
        expand(idx, chunk) for idx, chunk in enumerate(file.split("/"))
    )

    # Split on the first "=" only, so a value that itself contains "=" stays intact
    # instead of breaking the key/value unpacking.
    data = dict(chunk.split("=", 1) for chunk in chunks)
    data["file"] = file
    data['dir'] = os.path.dirname(file)
    experiments.append(data)


print(len(experiments))

27


In [146]:
from collections import defaultdict


# Score every experiment and store the metrics back on its dict.
#
# Each raw record is indexed positionally: [0] is the predicted answer, [1] the gold
# answer, [2] the value averaged into "EM", and [3] a dict holding the instance
# (question, db_id and metadata.query_type).
#
# Keys written on each experiment: "EM", "raw", "A_count_ape", "A_count_aae",
# "A_type_<name>" for every per-type counter, and "A_EM".
for experiment in tqdm(experiments):
    # The metrics file holds one JSON document per line; collect the raw records of all of them.
    all_raw = []
    with open(experiment['file']) as f:
        for line in f:
            partial_results = json.loads(line)
            all_raw.extend(partial_results['test']['raw'])

    experiment["EM"] = np.mean([rec[2] for rec in all_raw])
    experiment["raw"] = all_raw

    # query_type -> "<question>_<db_id>" -> list of (predicted, gold) pairs.
    gold = defaultdict(lambda: defaultdict(list))
    for instance in all_raw:
        details = instance[3]["instance"]
        query_type = details["metadata"]["query_type"]
        question = details["question"] + "_{}".format(details['db_id'])
        gold[query_type][question].append((instance[0], instance[1]))

    aem = 0
    aem_count = 0

    scores = defaultdict(int)
    counts = defaultdict(int)

    count_acts = []
    count_preds = []
    for t, questions in gold.items():

        if t in {"atomic_boolean", "join_boolean", "atomic_extractive", "join_extractive"}:
            for question, answers in questions.items():
                correct = set()
                predicted = set()

                for answer in answers:
                    if answer[0] != "[NULL_ANSWER]":
                        predicted.add(answer[0])

                    if answer[1] != "[NULL_ANSWER]":
                        correct.add(answer[1])

                # Compute F1 once and reuse it (presumably f1 is a pure function of its inputs).
                f1_score = f1(correct, predicted)
                exact = 1.0 if f1_score == 1.0 else 0

                aem_count += 1
                aem += exact

                scores[t + "_recall"] += recall(correct, predicted)
                scores[t + "_f1"] += f1_score
                scores[t + "_em"] += exact

                counts[t + "_recall"] += 1
                counts[t + "_f1"] += 1
                counts[t + "_em"] += 1

                if "atomic" in t:
                    counts["all_atomic"] += 1
                    scores["all_atomic"] += exact
                elif "join" in t:
                    counts["all_join"] += 1
                    scores["all_join"] += exact

        elif t == "min/max":

            for question, answers in questions.items():
                # key -> value parsed from answers of the form "<key> [LIST] <value>".
                argmin_aggr_gold = {}
                argmin_aggr_pred = {}
                for answer in answers:
                    if answer[0] != "[NULL_ANSWER]" and "[LIST]" in answer[0]:
                        key, value = answer[0].split("[LIST]", maxsplit=1)
                        argmin_aggr_pred[key.strip()] = value.strip()

                    if answer[1] != "[NULL_ANSWER]":
                        key, value = answer[1].split("[LIST]", maxsplit=1)
                        argmin_aggr_gold[key.strip()] = value.strip()

                aem_count += 1
                counts[t] += 1

                # Only compare when both sides have at least one entry; an empty gold set
                # with a non-empty prediction is scored as a miss instead of raising IndexError.
                if argmin_aggr_gold and argmin_aggr_pred:
                    # NOTE(review): values are compared as strings, so the ordering is
                    # lexicographic ("10" < "9") - confirm this is intended for numeric values.
                    min_key_gold = min(argmin_aggr_gold.items(), key=lambda item: item[1])[0]
                    min_key_pred = min(argmin_aggr_pred.items(), key=lambda item: item[1])[0]
                    max_key_gold = max(argmin_aggr_gold.items(), key=lambda item: item[1])[0]
                    max_key_pred = max(argmin_aggr_pred.items(), key=lambda item: item[1])[0]

                    # A question counts as correct if either the minimum or the maximum key matches.
                    if min_key_gold == min_key_pred or max_key_gold == max_key_pred:
                        aem += 1
                        scores[t] += 1

        elif t == "set":

            for question, answers in questions.items():
                set_gold = set()
                set_pred = set()
                for answer in answers:
                    if answer[0] != "[NULL_ANSWER]":
                        set_pred.add(answer[0].strip())

                    if answer[1] != "[NULL_ANSWER]":
                        set_gold.add(answer[1].strip())

                # Set queries contribute their F1 (not a 0/1 match) to the aggregate score.
                set_f1 = f1(set_gold, set_pred)

                aem_count += 1
                counts[t] += 1
                aem += set_f1
                scores[t] += set_f1

        elif t == "count":
            for question, answers in questions.items():

                set_gold = set()
                set_pred = set()
                for answer in answers:
                    if answer[0] != "[NULL_ANSWER]" and answer[0] != "0":
                        set_pred.add(answer[0].strip())

                    if answer[1] != "[NULL_ANSWER]" and answer[1] != "0":
                        set_gold.add(answer[1].strip())

                # A count is correct when the number of distinct non-null, non-"0" answers matches.
                same_size = 1 if len(set_gold) == len(set_pred) else 0

                aem_count += 1
                aem += same_size
                scores[t] += same_size
                counts[t] += 1

                count_acts.append(len(set_gold))
                count_preds.append(len(set_pred))

    if count_preds:
        # Questions whose gold count is 0 contribute 0 to both sums but still count in the denominator.
        # NOTE(review): the absolute error also skips act == 0 - confirm that is intended.
        ape = sum(abs(act - pred) / act if act else 0 for act, pred in zip(count_acts, count_preds)) / len(count_preds) * 100
        aae = sum(abs(act - pred) if act else 0 for act, pred in zip(count_acts, count_preds)) / len(count_preds)
    else:
        # No "count" queries in this run: the errors are undefined rather than a ZeroDivisionError.
        ape = aae = float("nan")

    experiment["A_count_ape"] = ape
    experiment["A_count_aae"] = aae
    print(ape, aae)
    for k, v in counts.items():
        experiment["A_type_{}".format(k)] = scores[k] / v

    # Guard against a metrics file without any records.
    experiment["A_EM"] = aem / aem_count if aem_count else float("nan")

HBox(children=(FloatProgress(value=0.0, max=27.0), HTML(value='')))

9.34199503253212 0.2557544757033248
9.032755394647985 0.26342710997442453
9.934594391116141 0.2672634271099744
8.005957105778027 0.4106217616580311
8.582860427820876 0.4209844559585492
8.396687534131026 0.44041450777202074
8.67185741226662 0.23657289002557544
8.71863962976495 0.2391304347826087
9.229138295634463 0.2647058823529412
9.774904691784494 0.2710997442455243
9.482812228975918 0.2531969309462916
9.952761078080774 0.2672634271099744
9.722628181707474 0.2557544757033248
9.676251923694384 0.2608695652173913
9.789007170720733 0.26214833759590794
8.90842353097514 0.37176165803108807
7.063453962017488 0.36658031088082904
8.125953774131526 0.3484455958549223
7.375270225997397 0.3963730569948187
7.3354463271299934 0.34974093264248707
7.6611592817637435 0.41580310880829013
10.22284512705969 0.27458492975734355
8.303335305251016 0.23499361430395913
8.681814495990746 0.24265644955300128
47.884210652409855 0.859514687100894
110.79142086804924 1.9118773946360152
56.15459466034176 0.97956577

In [138]:
# One row per experiment run; metrics that a run did not produce are filled with 0.
results = pd.DataFrame(experiments).fillna(0)

# Aggregations for the pivot table: mean and std for "EM" and every A_/prop_/type_ column,
# plus the maximum of any count_type_negative* column.
cols = {
    name: [np.mean, np.std]
    for name in results.columns
    if name == "EM" or name.startswith(("A_", "prop_", "type_"))
}
cols.update({
    name: [np.max]
    for name in results.columns
    if name.startswith("count_type_negative")
})

# Column groupings selected by naming convention.
breakdown_cols = [name for name in results.columns if name.startswith("prop_")]
type_cols = [
    name for name in results.columns
    if (name.startswith("type_") and "negative" not in name) or name.startswith("x")
]
print(type_cols)
type_cols2 = [
    name for name in results.columns
    if (name.startswith("type_") and "negative" not in name) or name == "x_avg_negative"
]
a_type_cols = [
    name for name in results.columns
    if (name.startswith("A_type_") and "negative" not in name) or name == "x_avg_negative"
]
type_cols3 = [
    name for name in results.columns
    if "count" not in name and "negative" in name
]
type_cols4 = [
    name for name in results.columns
    if name.startswith("type_") and "negative" not in name
]

# Aggregate the runs that share the same experiment/model/lr/method settings.
breakdown = pd.pivot_table(
    results,
    index=["experiment", "model", "lr", "method"],
    columns=[],
    aggfunc=cols,
)

# Widen the display limits so the whole table is rendered.
pd.set_option("display.max_rows", 150)
pd.set_option("display.max_columns", 150)
breakdown


[]


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,A_EM,A_EM,A_count_ape,A_count_ape,A_type_all_atomic,A_type_all_atomic,A_type_all_join,A_type_all_join,A_type_atomic_boolean_em,A_type_atomic_boolean_em,A_type_atomic_boolean_f1,A_type_atomic_boolean_f1,A_type_atomic_boolean_recall,A_type_atomic_boolean_recall,A_type_atomic_extractive_em,A_type_atomic_extractive_em,A_type_atomic_extractive_f1,A_type_atomic_extractive_f1,A_type_atomic_extractive_recall,A_type_atomic_extractive_recall,A_type_count,A_type_count,A_type_join_boolean_em,A_type_join_boolean_em,A_type_join_boolean_f1,A_type_join_boolean_f1,A_type_join_boolean_recall,A_type_join_boolean_recall,A_type_join_extractive_em,A_type_join_extractive_em,A_type_join_extractive_f1,A_type_join_extractive_f1,A_type_join_extractive_recall,A_type_join_extractive_recall,A_type_min/max,A_type_min/max,A_type_set,A_type_set,EM,EM
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
experiment,model,lr,method,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2,Unnamed: 33_level_2,Unnamed: 34_level_2,Unnamed: 35_level_2,Unnamed: 36_level_2,Unnamed: 37_level_2,Unnamed: 38_level_2,Unnamed: 39_level_2,Unnamed: 40_level_2,Unnamed: 41_level_2,Unnamed: 42_level_2,Unnamed: 43_level_2
ssg_final,t5-base,0.0004,both,0.868217,0.004944,24.936061,1.672212,0.89084,0.00904,0.691304,0.027152,0.90995,0.011202,0.962687,0.004422,0.996269,0.001974,0.847875,0.004223,0.92646,0.0045,0.989933,0.004439,0.806479,0.009598,0.755411,0.007498,0.893939,0.004329,0.97619,0.009919,0.561404,0.067521,0.817105,0.029979,0.986842,0.0,1.0,0.0,0.909677,0.000955,0.932969,0.001253
ssg_final,t5-base,0.0004,ds,0.867029,0.001569,26.342711,0.586007,0.89239,0.002329,0.685507,0.013283,0.914925,0.003253,0.964179,0.000746,0.996269,0.000746,0.841723,0.003493,0.921839,0.002847,0.987136,0.000969,0.800512,0.003383,0.78355,0.009919,0.901876,0.004506,0.974026,0.006494,0.486842,0.02279,0.786404,0.002321,0.986842,0.0,1.0,0.0,0.907696,0.001547,0.933806,0.001185
ssg_final,t5-base,0.0004,duo_both,0.82472,0.015811,129.118774,58.540919,0.957062,0.008304,0.866667,0.022222,0.971166,0.00184,0.976074,0.002075,0.986503,0.00184,0.918618,0.026878,0.92159,0.027514,0.928094,0.028131,0.576841,0.05538,0.914493,0.021447,0.921256,0.028115,0.93913,0.030123,0.737255,0.02449,0.763399,0.015849,0.780392,0.013585,0.846939,0.079696,0.676367,0.04025,0.753972,0.041289
ssg_final,t5-base,0.0004,duo_ds,0.934104,0.003107,25.117071,2.076415,0.984141,0.000686,0.910053,0.006608,0.988548,0.000937,0.989775,0.000937,0.996319,0.000613,0.972129,0.002554,0.973987,0.00307,0.978261,0.003344,0.804172,0.015343,0.926087,0.008696,0.929952,0.008855,0.943478,0.007531,0.866667,0.017971,0.869281,0.019345,0.870588,0.020377,1.0,0.0,0.923085,0.002631,0.972587,0.000397
ssg_final,t5-base,0.0004,duo_supervised,0.945773,0.001063,38.816926,3.341199,0.985955,0.002469,0.971609,0.003155,0.988569,0.001968,0.989658,0.001712,0.993672,0.000935,0.978818,0.003862,0.981791,0.004221,0.986622,0.003344,0.811744,0.001496,0.985632,0.002489,0.991379,0.002489,0.99569,0.0,0.933333,0.006792,0.938562,0.002264,0.941176,0.0,1.0,0.0,0.946188,0.002079,0.945619,0.000684
ssg_final,t5-base,0.0004,fix_both,0.910832,0.00664,26.55584,0.99328,0.975723,0.006284,0.7,0.041929,0.980348,0.004856,0.988972,0.00137,0.997264,0.001878,0.965324,0.011172,0.980425,0.001957,0.989933,0.00605,0.795823,0.009767,0.764069,0.048737,0.899711,0.012683,0.984848,0.009919,0.570175,0.033113,0.812061,0.035099,0.969298,0.007597,1.0,0.0,0.916053,0.005117,0.964665,0.001509
ssg_final,t5-base,0.0004,fix_ds,0.913015,0.000879,26.172208,0.29532,0.978994,0.002088,0.702899,0.015269,0.983831,0.00114,0.990796,0.00114,0.996766,0.000862,0.968121,0.00605,0.978859,0.003594,0.986018,0.003493,0.794544,0.005168,0.766234,0.019481,0.896104,0.009919,0.97619,0.007498,0.574561,0.007597,0.825,0.001519,0.97807,0.007597,1.0,0.0,0.919071,0.000668,0.964497,0.001263
ssg_final,t5-base,0.0004,fix_supervised,0.942161,0.004047,36.658031,1.187196,0.987268,0.003232,0.969506,0.001821,0.991342,0.001636,0.993129,0.001135,0.996289,0.000618,0.97542,0.00811,0.984213,0.0045,0.989209,0.001799,0.793178,0.010863,0.977011,0.002489,0.985632,0.002489,0.989943,0.002489,0.94902,0.006792,0.955817,0.002521,0.964706,0.0,1.0,0.0,0.94561,0.003419,0.945659,0.002227
ssg_final,t5-base,0.0004,supervised,0.89691,0.001814,42.530225,1.512459,0.927289,0.001218,0.806519,0.001821,0.966811,0.001889,0.979317,0.002271,0.994434,0.000618,0.81235,0.007269,0.91386,0.003577,0.974221,0.001038,0.802245,0.005235,0.867816,0.006584,0.939655,0.006584,0.99569,0.00431,0.639216,0.013585,0.844052,0.003644,0.964706,0.0,1.0,0.0,0.930091,0.001403,0.919733,0.000573
