#### This notebook prepares `visqa_dataset.pkl` for the VisQA dataset by packing up the necessary information
- Place this notebook in the repo root before running it, since it imports components from the main tool.

In [None]:
# Path to the original VisQA dataset. `qadata.json` and the `dataset/`
# folder are read relative to it, and `visqa_dataset.pkl` is written into it.
data_path = "./benchmarks/VisQA/shared/"

In [None]:
import json
import pickle
import pandas as pd
pd.set_option('display.max_rows', 10)
from io import StringIO
from trinity.utils.visqa import vl_to_svg, svg_to_table, normalize_table, grouped_svg_to_table, vl_to_grouped_svg, parse_value

In [None]:
# Load the question/answer records of the VisQA dataset into `dt`.
qadata_path = "{}/qadata.json".format(data_path)
with open(qadata_path, "r") as qadata_file:
    dt = json.load(qadata_file)

In [None]:
def get_spec_name(arg_dataset, arg_chart_name):
    """Return the spec name of a chart for the given data series.

    Args:
        arg_dataset: Name of the data series the chart belongs to.
        arg_chart_name: Chart name as recorded in the QA data.

    Returns:
        For the "kong" and "d3" series, the second "_"-separated token of
        the chart name. For "vega-lite-example-gallery" and any series whose
        name contains "wikitables", the chart name unchanged.

    Raises:
        NotImplementedError: If the data series is not one of the above.
        IndexError: If a "kong"/"d3" chart name contains no "_".
    """
    if arg_dataset in ("kong", "d3"):
        return arg_chart_name.split("_")[1]
    if arg_dataset == "vega-lite-example-gallery" or "wikitables" in arg_dataset:
        return arg_chart_name
    raise NotImplementedError("Unsupported data series, got: {}, {}.".format(arg_dataset, arg_chart_name))

def get_repr_answer(arg_refined_answer):
    """Turn a refined answer string into a list of parsed values.

    An answer without "|" is parsed as a single value and returned as a
    one-element list. Otherwise the answer is split on "|", each part is
    passed through `parse_value`, and the parsed values are returned sorted
    by their string form.
    """
    if "|" not in arg_refined_answer:
        # single-valued answer
        return [parse_value(arg_refined_answer)]

    parsed_parts = [parse_value(part) for part in arg_refined_answer.split("|")]
    # sort on the string form so values of different types can be ordered
    parsed_parts.sort(key=str)
    return parsed_parts

# Build one benchmark record per question in `dt`.
# cached_utils maps (data_series, spec_name) to the rendered artifacts of
# that chart, so a chart shared by several questions is rendered only once.
cached_utils = {}
benchmark_collections = []
total_items = len(dt)
for item_no, dkey in enumerate(dt, start=1):
    print("\r# processing {}/{}, {}".format(item_no, total_items, dkey), end="")

    entry = dt[dkey]
    bitem = {
        "id": dkey,
        "short_id": dkey[:6],
        "query": entry["question"],
        # original answer from the dataset
        "answer": entry["answer"],
        # refined_answer starts out equal to the original answer; answers
        # known to be wrong are corrected afterwards by the patch cell
        "refined_answer": entry["answer"],
    }
    # repr_answer is for quick accuracy comparison; its data structure
    # should support equality comparison
    bitem["repr_answer"] = get_repr_answer(bitem["refined_answer"])

    bitem["data_series"] = entry["dataset"]
    bitem["spec_name"] = get_spec_name(entry["dataset"], entry["chartName"])

    # load the full chart spec
    spec_path = "{}/dataset/{}/specs/{}.json".format(data_path, bitem["data_series"], bitem["spec_name"])
    with open(spec_path, "r") as spec_file:
        bitem["spec"] = json.load(spec_file)

    # the data file name is the last path segment of the spec's data url
    data_file_name = bitem["spec"]["data"]["url"].split("/")[-1]
    bitem["csv_name"] = data_file_name
    if data_file_name.endswith(".csv"):
        bitem["csv_name"] = data_file_name.replace(".csv", "")
        csv_path = "{}/dataset/{}/data/{}.csv".format(data_path, bitem["data_series"], bitem["csv_name"])
        with open(csv_path, "r") as csv_file:
            raw_lines = csv_file.readlines()
        # trim surrounding whitespace, then drop trailing delimiters
        cleaned_lines = [raw_line.strip().rstrip(",") for raw_line in raw_lines]
        bitem["csv"] = "\n".join(cleaned_lines)
        bitem["table"] = normalize_table(pd.read_csv(StringIO(bitem["csv"])))
    elif data_file_name.endswith(".json"):
        # some of the data comes in json format; convert it to csv
        bitem["csv_name"] = data_file_name.replace(".json", "")
        json_path = "{}/dataset/{}/data/{}.json".format(data_path, bitem["data_series"], bitem["csv_name"])
        with open(json_path, "r") as json_file:
            json_records = json.load(json_file)
        bitem["table"] = normalize_table(pd.DataFrame.from_records(json_records))
        bitem["csv"] = bitem["table"].to_csv(index=False)
    else:
        raise NotImplementedError("Unsupported data file type, got: {}.".format(bitem["csv_name"]))

    # replace the spec's data url with the inline table records
    del bitem["spec"]["data"]["url"]
    bitem["spec"]["data"]["values"] = bitem["table"].to_dict("records")

    # render the chart, or reuse the cached rendering of the same chart
    cache_key = (bitem["data_series"], bitem["spec_name"])
    if cache_key not in cached_utils:
        bitem["grouped_svg"] = vl_to_grouped_svg(bitem["spec"])
        bitem["rendered_table"] = grouped_svg_to_table(bitem["grouped_svg"], bitem["spec"])
        cached_utils[cache_key] = {
            "grouped_svg": bitem["grouped_svg"],
            "rendered_table": bitem["rendered_table"],
            # kept for debugging only
            "table": bitem["table"],
        }
        # NOTE: merging `table` with `rendered_table` (to prevent data loss
        # due to spec filters, e.g., kong:15) is currently disabled.
    else:
        cached_entry = cached_utils[cache_key]
        bitem["grouped_svg"] = cached_entry["grouped_svg"]
        bitem["rendered_table"] = cached_entry["rendered_table"]

    # re-order the rendered columns: columns shared with the source table
    # first (in source-table order), then the remaining rendered columns
    rendered_columns = bitem["rendered_table"].columns
    column_order = [col for col in bitem["table"].columns if col in rendered_columns]
    for col in rendered_columns:
        if col not in column_order:
            column_order.append(col)
    bitem["rendered_table"] = bitem["rendered_table"][column_order]

    # put to collection
    benchmark_collections.append(bitem)

print("\n# done")

In [None]:
# additional patches to apply to data, see README or Google Sheets for details
# Each pair is (position in benchmark_collections, corrected answer).
# yes, refined_answers are all strings (including the numeric ones)
refined_answer_patches = [
    (74, "Hindus"),
    (83, "Hindus"),
    (244, "40221893"),
    (288, "Aug 1 2004"),
    (348, "Aug"),
    (353, "Jul"),
    (354, "Feb"),
    (355, "Aug"),
    (438, "2915000"),
    (520, "gojō|yoshino"),
    (537, "north-east skåne"),
    (541, "north-east skåne"),
    (616, "x-Houston Rockets|x-San Antonio Spurs"),
]

for patch_index, patched_answer in refined_answer_patches:
    patched_item = benchmark_collections[patch_index]
    patched_item["refined_answer"] = patched_answer
    # keep the comparison form in sync with the corrected answer
    patched_item["repr_answer"] = get_repr_answer(patched_answer)

In [None]:
# Sanity check: number of benchmark records collected above.
len(benchmark_collections)

In [None]:
# Persist the collected benchmark records next to the original dataset.
output_pickle_path = "{}/visqa_dataset.pkl".format(data_path)
with open(output_pickle_path, "wb") as output_file:
    pickle.dump(benchmark_collections, output_file)

### inspect the data

In [None]:
import pickle
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 10)

from trinity.utils.visqa import normalize_table

In [None]:
# Path to the original VisQA dataset; the packed `visqa_dataset.pkl`
# is read back from this directory below.
data_path = "./benchmarks/VisQA/shared/"

In [None]:
# Reload the packed benchmark records into `dt`.
packed_dataset_path = "{}/visqa_dataset.pkl".format(data_path)
with open(packed_dataset_path, "rb") as packed_file:
    dt = pickle.load(packed_file)

In [None]:
# Print the source table and the rendered table (with dtypes) of every
# cached chart. `cached_utils` is not part of the pickle; it is only
# available when the preparation loop has been run in this session.
for chart_key, chart_cache in cached_utils.items():
    print("==============================")
    print(chart_key)
    print(chart_cache["table"])
    rendered = chart_cache["rendered_table"]
    print(rendered.dtypes)
    print(rendered)

In [None]:
# Columns of the source data table cached for the ('kong', '18') chart.
cached_utils[('kong', '18')]["table"].columns

In [None]:
# Columns of the rendered table cached for the ('kong', '18') chart.
cached_utils[('kong', '18')]["rendered_table"].columns

In [None]:
# Print the dtypes and contents of every record's rendered table.
for i, record in enumerate(dt):
    print("i={}".format(i))
    rendered_table = record["rendered_table"]
    print("{}".format(rendered_table.dtypes))
    print("{}".format(rendered_table))