#### This prepares the `tapas_on_visqa_inputs.pkl` as the input to the TaPas model
- Place this notebook in the repo root before running it, since it requires some of the main tool's components.
- `visqa_dataset.pkl` must already exist before running.

In [None]:
import pickle

In [None]:
# Directory (relative to the current working directory) that holds the original
# VisQA dataset; the cells below read their input from it and write their
# output back into it.
data_path = "./benchmarks/VisQA/shared/"

In [None]:
# Load the VisQA benchmark collection; later cells index it by position and
# read each entry by key.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only load files from a trusted source.
with open("{}/visqa_dataset.pkl".format(data_path), "rb") as f:
    dt = pickle.load(f)

In [None]:
def format_table(arg_df):
    """Render a table as plain pipe-separated text.

    The table is first rendered through ``arg_df.to_markdown`` using the
    ``jira`` table format (no index column, numbers left-aligned), and the
    result is then simplified in three steps: doubled bars become single
    bars, the bar pair around every line break is dropped, and any bars left
    at the very start or end of the text are stripped.

    Args:
        arg_df: object exposing a pandas-style ``to_markdown`` method.

    Returns:
        The simplified table text, one table row per line.
    """
    jira_text = arg_df.to_markdown(index=False, tablefmt="jira", numalign="left")
    # the jira format is expected to wrap header cells in "||"; flatten those
    single_bars = jira_text.replace("||", "|")
    # remove the closing bar of one row and the opening bar of the next
    joined_rows = single_bars.replace("|\n|", "\n")
    return joined_rows.strip("|")

In [None]:
# Build one (short_id, query, table text) triple per benchmark entry.
tapas_inputs = []
for idx in range(len(dt)):
    record = dt[idx]
    # TaPas receives the rendered table serialised as plain text
    table_text = format_table(record["rendered_table"])
    tapas_inputs.append((record["short_id"], record["query"], table_text))

In [None]:
# Persist the (short_id, query, table text) triples next to the dataset; this
# file is the input handed to the TaPas model.
with open("{}/tapas_on_visqa_inputs.pkl".format(data_path), "wb") as f:
    pickle.dump(tapas_inputs, f)

In [None]:
# alternative old inputs construction (for debugging only)
# tapas_inputs = []
# for i in range(len(dt)):
#     p = dt[i]
#     str_table = format_table(p["table"])
#     tapas_inputs.append((p["short_id"], p["query"], str_table))

In [None]:
# with open("{}/tapas_on_visqa_inputs_old.pkl".format(data_path), "wb") as f:
#     pickle.dump(tapas_inputs, f)

#### This processes `tapas_on_visqa_outputs.pkl` to generate `tapas_on_visqa_dataset.pkl`
- Applies the top-k strategy and generates candidate outputs
- Builds on `visqa_dataset.pkl` and merges the TaPas results into it

In [None]:
import pickle
import numpy as np
import pandas as pd

from trinity.utils.visqa import normalize_table, parse_value
from trinity.utils.visqa_strategy import strategy_TaPas_A, strategy_TaPas_B, strategy_TaPas_C

In [None]:
# Directory (relative to the current working directory) that holds the original
# VisQA dataset; the TaPas log/outputs are read from it and the merged dataset
# is written back into it.
data_path = "./benchmarks/VisQA/shared/"

In [None]:
def interpret_answer(arg_line):
    """Turn one TaPas answer line into a list of answer values.

    The line is a ", "-separated list of operands, optionally prefixed by an
    aggregation ("COUNT of ", "SUM of " or "AVERAGE of "). Every operand is
    converted with ``parse_value`` before use.

    Args:
        arg_line: the answer line as a string.

    Returns:
        A list of answers:
        * aggregated line -> one-element list holding the count, sum or mean;
        * single operand  -> that operand, or "<no answer>" when it is a
          blank string;
        * several operands -> the operands sorted by their string form.

    Note:
        Summing or averaging operands that cannot be added raises (e.g.
        TypeError); this function does not catch it.
    """
    def _operands(text):
        return [parse_value(token) for token in text.split(", ")]

    aggregations = (
        ("COUNT of ", len),
        ("SUM of ", sum),
        ("AVERAGE of ", lambda values: sum(values) / len(values)),
    )
    for prefix, reduce_fn in aggregations:
        if arg_line.startswith(prefix):
            return [reduce_fn(_operands(arg_line[len(prefix):]))]

    # no ops
    values = _operands(arg_line)
    if not values:
        # defensive only: str.split with a separator always yields >= 1 item
        return ["<no answer>"]
    if len(values) > 1:
        # order by string form so mixed-type answers are still sortable
        return sorted(values, key=str)

    only = values[0]
    if isinstance(only, str) and only.strip() == "":
        return ["<no answer>"]
    return [only]
        
def extract_answers_from_logs(arg_logs):
    """Collect one interpreted answer per evaluation recorded in a TaPas log.

    Every line starting with "Evaluation finished" marks one evaluation. When
    the following line starts with ">", the line after that one is taken as
    the answer line and passed to ``interpret_answer``.

    Args:
        arg_logs: sequence of log lines (e.g. the result of ``readlines()``).

    Returns:
        A list with one entry per "Evaluation finished" line, in log order:
        * the list returned by ``interpret_answer`` on success;
        * ["<type error>"] when ``interpret_answer`` raises TypeError;
        * ["<tapas exception>"] when the ">" line is absent, or when the log
          ends before the expected ">" / answer lines (truncated log).
    """
    tmp_answers = []
    n_lines = len(arg_logs)
    for i, line in enumerate(arg_logs):
        if not line.startswith("Evaluation finished"):
            continue
        # Bounds checks: the marker may be the last (or second to last) line
        # of a truncated log; indexing blindly would raise IndexError.
        has_prompt = i + 1 < n_lines and arg_logs[i + 1].startswith(">")
        if has_prompt and i + 2 < n_lines:
            try:
                tmp_answers.append(interpret_answer(arg_logs[i + 2]))
            except TypeError:
                tmp_answers.append(["<type error>"])
        else:
            # TaPas exception/error (or the record was cut off)
            tmp_answers.append(["<tapas exception>"])
    return tmp_answers

In [None]:
# Reload the VisQA benchmark entries that the TaPas results are merged into.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only load files from a trusted source.
with open("{}/visqa_dataset.pkl".format(data_path), "rb") as dataset_file:
    dt = pickle.load(dataset_file)

# Raw TaPas log -> one interpreted answer list per "Evaluation finished" line.
with open("{}/tapas_on_visqa_outputs.log".format(data_path), "r") as log_file:
    raw_log_lines = log_file.readlines()
tapas_logs = extract_answers_from_logs(raw_log_lines)

with open("{}/tapas_on_visqa_outputs.pkl".format(data_path), "rb") as outputs_file:
    all_outputs = pickle.load(outputs_file)
# Keep only the entries stored at odd positions (1, 3, 5, ...).
# NOTE(review): presumably the even positions do not hold per-benchmark
# predictions — confirm against the script that produced the file.
tapas_outputs = [all_outputs[pos] for pos in range(len(all_outputs)) if pos % 2 != 0]

In [None]:
# Sanity check: there must be exactly one TaPas output and one interpreted log
# answer per benchmark entry, because the cells below index all of them with
# the same position.
assert len(tapas_outputs)==len(dt)
assert len(tapas_logs)==len(dt)
# Bare expression: displayed as the cell result in the notebook.
len(tapas_outputs)

In [None]:
# For every benchmark, collect the predicted aggregation operator together with
# the selected cells as (row, column, probability) triples, most probable first.
tapas_parsed_outputs = []
for idx, outputs in enumerate(tapas_outputs):
    if len(outputs) == 0:
        # no outputs, could be something wrong?
        print("# warning: no output for i={}".format(idx))
        tapas_parsed_outputs.append((0, []))  # (aggr, cpps)
        continue

    prediction = outputs[0]  # always at 0 since we pass 1 benchmark to TaPas at a time
    selected = prediction["probabilities"] > 0  # mask of all cells with prob>0
    # NOTE(review): the -1 suggests row/column ids are 1-based in the TaPas
    # output and are shifted to 0-based here — confirm.
    cells = [
        (
            prediction["row_ids"][k] - 1,
            prediction["column_ids"][k] - 1,
            prediction["probabilities"][k],
        )
        for k in range(len(selected))
        if selected[k]
    ]
    cells.sort(key=lambda cell: cell[2], reverse=True)
    tapas_parsed_outputs.append((prediction["pred_aggr"], cells))  # (aggr, cpps)

In [None]:
# Attach to every benchmark entry its "expected_output" table (a one-column
# "ANSWER" table built from "repr_answer") and its "candidate_outputs" (the
# answer interpreted from the TaPas log plus the results of strategies A, B, C).
for i in range(len(dt)):
    print("\r# processing {}/{}".format(i, len(dt)), end="")
    record = dt[i]
    answer = record["repr_answer"]

    # Shape the answer as rows of a single-column table.
    if isinstance(answer, list):
        # one row per list element
        answer_rows = np.asarray([answer]).T
    elif isinstance(answer, (int, float, str)):
        # a scalar becomes a 1x1 table
        answer_rows = [[answer]]
    else:
        raise NotImplementedError("Unsupported type of answer, got: {}.".format(type(answer)))

    dt[i]["expected_output"] = normalize_table(
        pd.DataFrame.from_records(answer_rows, columns=["ANSWER"])
    )

    original_answer = tapas_logs[i]
    parsed = tapas_parsed_outputs[i]
    rendered = record["rendered_table"]
    # strategies are invoked in the order A, B, C
    outputs_a = strategy_TaPas_A(parsed, rendered)
    outputs_b = strategy_TaPas_B(parsed, rendered)
    probs_c, outputs_c = strategy_TaPas_C(parsed, rendered)
    dt[i]["candidate_outputs"] = {
        "TaPas_original": original_answer,
        "TaPas_A": outputs_a,
        "TaPas_B": outputs_b,
        "TaPas_C": outputs_c,
        "TaPas_probs_C": probs_c,
    }


In [None]:
# Save the benchmark entries, now extended with "expected_output" and
# "candidate_outputs", as a new pickle next to the original dataset.
with open("{}/tapas_on_visqa_dataset.pkl".format(data_path), "wb") as f:
    pickle.dump(dt, f)