In [None]:
import json
import numpy as np
import os
import pandas as pd
import random
from tqdm import tqdm
from transformers import T5Tokenizer
from unidecode import unidecode


# JSON file passed to get_tasks_list(); it is indexed by split name to obtain
# the task names that belong to that split.
TASKS_SPLITS = "data/custom_tasks_splits/train_non_mc_qa_test_mc.json"
# Destination TSV for each split. The keys of this dict are also the splits the
# generation loop iterates over.
OUTPUT_FILE = {
    "train": "data/train-train_non_mc_qa_test_mc.tsv",
    # NOTE: the "dev" entry is commented out, so no dev file is generated.
    #"dev": "data/dev-train_non_mc_qa_test_mc.tsv",
    "test": "data/test-train_non_mc_qa_test_mc.tsv"
}
# JSON file that receives the total number of generated examples per split.
COUNT_OUTPUT_FILE = "data/counts-train_non_mc_qa_test_mc.json"
# Root directory that is joined with a task name to locate that task's TSV files.
DATA_PATH = "data/crossfit"
# When a task prefix yields fewer usable examples than this, the generation loop
# tops it up with lines sampled from the prefix's "_test.tsv" file.
MIN_EXAMPLES_PER_PREFIX = 32
# Token budget: used as the tokenizer's model_max_length and, minus the prompt
# length of the task prefix, as the limit checked by is_input_valid().
INPUT_MAX_LEN = 1024
# Fixed seed so the random.choice / random.sample calls below are reproducible.
random.seed(0)


def read_prompt_dict(filename: str) -> dict:
    """Load the prompt table and map every task prefix to its prompt length.

    The file is parsed as a header-less, tab-separated table with the columns
    task_name, task_prefix, prompt, prompt_len and io_sep. If a task prefix
    appears on several rows, the value from the last such row is kept.
    """
    column_names = ["task_name", "task_prefix", "prompt", "prompt_len", "io_sep"]
    table = pd.read_csv(filename, header=None, sep="\t", names=column_names)
    # tolist() hands back plain Python scalars, one per row, in file order.
    prefixes = table["task_prefix"].tolist()
    prompt_lengths = table["prompt_len"].tolist()
    return dict(zip(prefixes, prompt_lengths))

# Task prefix -> prompt_len column of the prompt TSV; is_input_valid() subtracts
# this value from INPUT_MAX_LEN to get the token budget left for the input text.
PROMPT_DICT = read_prompt_dict("data/prompt/prompt.tsv")
# Tokenizer that get_n_tokens() uses to measure text length in tokens.
tokenizer = T5Tokenizer.from_pretrained("t5-base", model_max_length=INPUT_MAX_LEN)

def get_task_prefixes(data_path: str, task_name: str) -> list:
    """Returns all task prefixes (e.g., adversarialqa_32_13) of a task.

    Only ".tsv" files inside ``data_path/task_name`` are considered. A prefix is
    the file name up to (not including) its last underscore, or the empty string
    when the name has no underscore. Prefixes are unique and follow the sorted
    order of the file names they were first seen in.
    """
    task_dir = os.path.join(data_path, task_name)
    tsv_names = [name for name in sorted(os.listdir(task_dir)) if name.endswith(".tsv")]
    # dict.fromkeys drops repeats while keeping first-seen order.
    unique_prefixes = dict.fromkeys(name.rpartition("_")[0] for name in tsv_names)
    return list(unique_prefixes)

def get_tasks_list(filename, split_name):
    """Return the value stored under ``split_name`` in the JSON file ``filename``.

    Raises KeyError when the file has no entry for ``split_name``.
    """
    with open(filename, "r") as split_file:
        return json.load(split_file)[split_name]

def get_n_tokens(text: str) -> int:
    """Return how many input ids the module-level tokenizer produces for ``text``."""
    encoded = tokenizer(text)
    token_ids = encoded["input_ids"]
    return len(token_ids)

def is_input_valid(task_prefix: str, input_text: str) -> bool:
    """Tell whether ``input_text`` fits in the token budget of ``task_prefix``.

    The budget is INPUT_MAX_LEN minus the prompt length recorded for the prefix
    in PROMPT_DICT; an unknown prefix raises KeyError before any tokenization.
    """
    # Look the prefix up first so a bad prefix fails before the tokenizer runs.
    token_budget = INPUT_MAX_LEN - PROMPT_DICT[task_prefix]
    return get_n_tokens(input_text) <= token_budget


def _collect_examples(lines, task_name, prefix, data, targets_len):
    """Parse raw TSV lines and append every usable example to ``data``.

    Each line is transliterated with unidecode, stripped and split on tabs; the
    first field is the input and the remaining fields are candidate targets.
    For every line whose input passes is_input_valid(), one target is picked at
    random, a row ``[task_name, prefix, input, picked_target, *all_targets]`` is
    appended to ``data`` and the picked target's token count is appended to
    ``targets_len``. Both lists are modified in place.
    """
    for line in lines:
        fields = unidecode(line).strip().split("\t")
        # A usable row needs an input and at least one target; blank or
        # target-less lines would otherwise make random.choice() raise.
        if len(fields) < 2:
            continue
        if not is_input_valid(prefix, fields[0]):
            continue
        targets = fields[1:]
        target = random.choice(targets)
        data.append([task_name, prefix, fields[0], target] + targets)
        targets_len.append(get_n_tokens(target))


# One row per (split, task prefix): example count and longest target in tokens.
stats = []
# Token length of every selected target across all splits (used for plotting).
all_targets_len = []
for split in OUTPUT_FILE:
    print("Generating data for split: {}".format(split))
    task_names = get_tasks_list(TASKS_SPLITS, split)
    data = []
    for task_name in tqdm(task_names):
        task_dir = os.path.join(DATA_PATH, task_name)
        for prefix in get_task_prefixes(DATA_PATH, task_name):
            # The test split reads the prefix's test file, every other split its train file.
            filename = prefix + "_test.tsv" if split == "test" else prefix + "_train.tsv"
            with open(os.path.join(task_dir, filename), encoding="utf-8") as fin:
                lines = fin.readlines()
            targets_len = []
            _collect_examples(lines, task_name, prefix, data, targets_len)

            # If the number of examples per task prefix is less than the threshold, sample from test set.
            # NOTE(review): for the "test" split this re-samples the file that was
            # just read, and successive retries may pick the same line again, so
            # duplicate rows are possible -- confirm this is intended.
            if len(targets_len) < MIN_EXAMPLES_PER_PREFIX:
                # The file does not change between retries, so read it once.
                with open(os.path.join(task_dir, prefix + "_test.tsv"), encoding="utf-8") as fin:
                    test_lines = fin.readlines()
                n_retry = 10
                while test_lines and len(targets_len) < MIN_EXAMPLES_PER_PREFIX and n_retry > 0:
                    n_retry -= 1
                    n_missing = MIN_EXAMPLES_PER_PREFIX - len(targets_len)
                    # random.sample() raises ValueError when asked for more items
                    # than the population holds, so cap the request.
                    sampled = random.sample(test_lines, min(n_missing, len(test_lines)))
                    _collect_examples(sampled, task_name, prefix, data, targets_len)

            stats.append([split, task_name, prefix, len(targets_len), np.max(targets_len) if targets_len else 0])
            all_targets_len.extend(targets_len)

    # Save every split into a TSV file.
    df = pd.DataFrame(data)
    df.to_csv(OUTPUT_FILE[split], index=False, sep="\t", header=None)

stats_df = pd.DataFrame(stats, columns=["split", "task_name", "task_prefix", "n_examples", "max_target_len"])

# Save number of examples.
count_df = stats_df[["split", "n_examples"]].groupby(["split"]).sum().reset_index()
# The counts file uses "validation" as the name of the dev split.
count_df.loc[count_df.split == "dev", "split"] = "validation"
# Use a context manager so the counts file is flushed and closed deterministically.
with open(COUNT_OUTPUT_FILE, "w") as fout:
    json.dump(dict(zip(count_df.split, count_df.n_examples)), fout)

In [None]:
# Temporarily lift pandas' row and column display limits so the whole
# statistics table is rendered instead of a truncated preview.
show_everything = pd.option_context("display.max_rows", None, "display.max_columns", None)
with show_everything:
    display(stats_df)

In [None]:
# Longest selected target in tokens. np.max() raises on an empty sequence, so
# fall back to 0 when nothing was collected (same fallback as the per-prefix stats).
print("Max target len: {}".format(np.max(all_targets_len) if all_targets_len else 0))

import matplotlib.pyplot as plt
# Histogram with 40 bins, restricted to targets longer than 32 tokens.
n, bins, patches = plt.hist([x for x in all_targets_len if x > 32], 40)
plt.show()