In [2]:
import itertools
import operator
import sys
import os
import numpy as np
from tqdm.notebook import tqdm
sys.path.append("../src")
from glob import glob
import pandas as pd
import json
from matplotlib import pyplot as plt

from neuraldb.scoring.r_precision import f1


In [3]:
# Path prefix of the experiment output directories to scan, and the name of
# the metrics file to look for beneath them.
search_root = "/checkpoint/jth/job_staging/neuraldb_expts/experiment=oracle_d3"
checkpoint_name = "metrics_test.json"

# "<root>*" matches every directory sharing the prefix; "**" (with
# recursive=True) then descends to any depth below it.
files = glob(search_root + "*/**/" + checkpoint_name, recursive=True)

print(len(files))

3


In [4]:

def expand(idx, chunk):
    """Turn one path component into a list of ``key=value`` strings.

    Args:
        idx: Position of the component within the path. Currently unused;
            kept so existing callers that pass it keep working.
        chunk: A single ``/``-separated component of a file path.

    Returns:
        A list of ``key=value`` strings:

        * ``seed-N``      -> ``["seed=N"]``
        * ``a=1,b=2``     -> ``["a=1", "b=2"]``
        * ``a=1``         -> ``["a=1"]``
        * anything else   -> ``[]``

        Comma-separated pieces that contain no ``=`` (including the empty
        piece left by a trailing comma) are dropped, because callers unpack
        every returned item into a key and a value.
    """
    if chunk.startswith("seed-"):
        # Strip only the leading prefix; a later "seed-" inside the name is
        # part of the value and must be left alone.
        return ["seed={}".format(chunk.replace("seed-", "", 1))]
    if "," in chunk:
        return [part for part in chunk.split(",") if "=" in part]
    if "=" in chunk:
        return [chunk]

    return []

# Build one metadata record per metrics file. The hyper-parameters are read
# from the "key=value" parts of the file's path (as produced by expand()).
experiments = []
for file in files:
    path_parts = file.split("/")
    chunks = itertools.chain.from_iterable(
        expand(idx, part) for idx, part in enumerate(path_parts)
    )

    # Split on the first "=" only so a value that itself contains "=" stays
    # intact, and skip anything without "=" instead of failing to unpack it.
    data = dict(chunk.split("=", 1) for chunk in chunks if "=" in chunk)
    data["file"] = file
    data['dir'] = os.path.dirname(file)
    experiments.append(data)


print(len(experiments))

3


In [None]:
from collections import defaultdict

NULL_ANSWER = "[NULL_ANSWER]"
LIST_SEP = "[LIST]"

# Query types scored by plain string equality of each (prediction, gold) pair.
EXACT_MATCH_TYPES = {"atomic_boolean", "join_boolean", "atomic_extractive", "join_extractive"}


def _answer_sets(answers):
    """Return ``(gold, pred)`` sets of stripped, non-null answers.

    ``answers`` is a list of ``(prediction, gold)`` string pairs.
    """
    pred = {p.strip() for p, _ in answers if p != NULL_ANSWER}
    gold = {g.strip() for _, g in answers if g != NULL_ANSWER}
    return gold, pred


def _value_len(item):
    """Sort key for ``(key, value)`` pairs: the length of the value."""
    return len(item[1])


for experiment in tqdm(experiments):
    # The metrics file is read line by line; every line is a JSON document
    # whose ['test']['raw'] entry is a list of raw records.
    all_raw = []
    with open(experiment['file']) as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines rather than failing in json.loads.
                continue
            partial_results = json.loads(line)
            all_raw.extend(partial_results['test']['raw'])

    # rec[2] is averaged as the per-record score
    # (presumably an exact-match flag -- TODO confirm against the writer).
    experiment["EM"] = np.mean([rec[2] for rec in all_raw])
    experiment["raw"] = all_raw

    # gold[query_type][query_input] -> list of (prediction, gold) pairs.
    # Records where both sides are the null answer are left out.
    gold = defaultdict(lambda: defaultdict(list))
    for instance in all_raw:
        if instance[0] != NULL_ANSWER or instance[1] != NULL_ANSWER:
            gold[instance[3]["query_type"]][instance[3]["query"]["input"]].append((instance[0], instance[1]))

    aem = 0
    aem_count = 0

    scores = defaultdict(int)
    counts = defaultdict(int)

    for t, questions in gold.items():

        if t in EXACT_MATCH_TYPES:
            # One scored item per (prediction, gold) pair.
            for answers in questions.values():
                for answer in answers:
                    aem_count += 1
                    counts[t] += 1

                    if answer[0] == answer[1]:
                        aem += 1
                        scores[t] += 1

        elif t == "min/max":
            # One scored item per question.
            for answers in questions.values():
                argmin_aggr_gold = {}
                argmin_aggr_pred = {}
                for answer in answers:
                    if answer[0] != NULL_ANSWER and LIST_SEP in answer[0]:
                        key, value = answer[0].split(LIST_SEP, maxsplit=1)
                        argmin_aggr_pred[key.strip()] = value.strip()

                    # Gold answers are expected to always contain the list
                    # separator; a malformed one fails loudly here.
                    if answer[1] != NULL_ANSWER:
                        key, value = answer[1].split(LIST_SEP, maxsplit=1)
                        argmin_aggr_gold[key.strip()] = value.strip()

                aem_count += 1
                counts[t] += 1

                # Both sides must be non-empty before picking extremes; an
                # empty gold with a non-empty prediction scores 0 instead of
                # raising IndexError.
                if argmin_aggr_gold and argmin_aggr_pred:
                    # NOTE(review): values are strings, so keys are ranked by
                    # the string length of their value, and a repeated key
                    # keeps only its last value. Confirm this is the intended
                    # ranking (vs. number of items per key).
                    # min()/max() return the first extreme element, matching
                    # the first element of a stable (reverse) sort.
                    min_key_gold = min(argmin_aggr_gold.items(), key=_value_len)[0]
                    min_key_pred = min(argmin_aggr_pred.items(), key=_value_len)[0]
                    max_key_gold = max(argmin_aggr_gold.items(), key=_value_len)[0]
                    max_key_pred = max(argmin_aggr_pred.items(), key=_value_len)[0]

                    # Agreement on either the min key or the max key counts.
                    if min_key_gold == min_key_pred or max_key_gold == max_key_pred:
                        aem += 1
                        scores[t] += 1

        elif t == "set":
            # One scored item per question, scored by f1 over the answer sets.
            for answers in questions.values():
                set_gold, set_pred = _answer_sets(answers)
                question_score = f1(set_gold, set_pred)

                aem_count += 1
                counts[t] += 1
                aem += question_score
                scores[t] += question_score

        elif t == "count":
            # One scored item per question: correct when the number of
            # distinct predicted answers equals the number of distinct gold ones.
            for answers in questions.values():
                set_gold, set_pred = _answer_sets(answers)
                question_score = 1 if len(set_gold) == len(set_pred) else 0

                aem_count += 1
                counts[t] += 1
                aem += question_score
                scores[t] += question_score

    for k, v in counts.items():
        experiment["A_type_{}".format(k)] = scores[k] / v

    # With nothing scorable, report NaN (as np.mean does for EM on empty
    # input) instead of raising ZeroDivisionError.
    experiment["A_EM"] = aem / aem_count if aem_count else float("nan")

In [6]:
# One row per experiment; any metric an experiment lacks is filled with 0.
results = pd.DataFrame(experiments).fillna(0)

# Aggregation spec for the pivot table: mean for the score-like columns
# ("EM" and anything prefixed A_/prop_/type_), max for the
# count_type_negative* columns.
cols = {
    col: [np.mean]
    for col in results.columns
    if col == "EM" or col.startswith(("A_", "prop_", "type_"))
}
cols.update(
    {col: [np.max] for col in results.columns if col.startswith("count_type_negative")}
)

# Named groups of columns selected by prefix / substring.
breakdown_cols = [col for col in results.columns if col.startswith("prop_")]
type_cols = [
    col
    for col in results.columns
    if (col.startswith("type_") and "negative" not in col) or col.startswith("x")
]
print(type_cols)
type_cols2 = [
    col
    for col in results.columns
    if (col.startswith("type_") and "negative" not in col) or col == "x_avg_negative"
]
a_type_cols = [
    col
    for col in results.columns
    if (col.startswith("A_type_") and "negative" not in col) or col == "x_avg_negative"
]
type_cols3 = [col for col in results.columns if "count" not in col and "negative" in col]
type_cols4 = [
    col for col in results.columns if col.startswith("type_") and "negative" not in col
]

# Aggregate per (experiment, model, lr) combination.
breakdown = pd.pivot_table(
    results,
    index=["experiment", "model", "lr"],
    columns=[],
    aggfunc=cols,
)

# Raise the display limit so the whole table is rendered.
pd.options.display.max_rows = 150
breakdown


[]


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,A_EM,A_type_atomic_boolean,A_type_atomic_extractive,A_type_count,A_type_join_boolean,A_type_join_extractive,A_type_min/max,A_type_set,EM
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean
experiment,model,lr,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
oracle_d3,t5-base,0.0004,0.987229,0.98673,0.984221,0.98773,0.989637,0.988095,1.0,0.998208,0.985377
