In [None]:
import json
import numpy as np
import os
import pandas as pd
import random
from transformers import T5Tokenizer


# Name of the task-split configuration. Every split/output path below is
# derived from it so the paths cannot drift out of sync when switching configs.
SPLIT_CONFIG = "train_non_mrc_qa_test_mrc"

# JSON file whose "train"/"test" keys each list the task names for that split.
TASKS_SPLITS = "data/custom_tasks_splits/{}.json".format(SPLIT_CONFIG)
# Generated TSV per split (no header row).
OUTPUT_FILE = {
    split: "data/{}-{}.tsv".format(split, SPLIT_CONFIG)
    for split in ("train", "test")
}
# JSON file receiving the number of kept examples per split.
COUNT_OUTPUT_FILE = "data/counts-{}.json".format(SPLIT_CONFIG)
# Root directory containing one sub-directory of TSV files per task.
DATA_PATH = "data/crossfit"
# Token budget shared by a task's prompt and the input text
# (see is_input_valid); also passed to the tokenizer as model_max_length.
INPUT_MAX_LEN = 1024


def read_prompt_dict(filename: str) -> dict:
    """Map each task prefix to its prompt length, read from a prompt TSV file.

    The file has no header row; its tab-separated columns are, in order,
    task_name, task_prefix, prompt, prompt_len and io_sep. Only task_prefix
    and prompt_len are kept. If a prefix occurs on several rows, the value
    from the last such row wins.
    """
    column_names = ["task_name", "task_prefix", "prompt", "prompt_len", "io_sep"]
    prompts = pd.read_csv(filename, header=None, sep="\t", names=column_names)
    return dict(zip(prompts["task_prefix"], prompts["prompt_len"]))

# Prompt length per task prefix; is_input_valid subtracts it from the budget.
_PROMPT_FILE = "data/prompt/prompt.tsv"
PROMPT_DICT = read_prompt_dict(_PROMPT_FILE)
# Tokenizer used for all token counting in this notebook; its
# model_max_length is the same INPUT_MAX_LEN budget the filter enforces.
tokenizer = T5Tokenizer.from_pretrained(
    "t5-base",
    model_max_length=INPUT_MAX_LEN,
)

def get_task_prefixes(data_path: str, task_name: str) -> list:
    """List the distinct prefixes (e.g., adversarialqa_32_13) of a task's files.

    Every ``.tsv`` file in ``data_path/task_name`` contributes the part of its
    name before the last underscore (an empty string when there is none).
    Files are scanned in sorted order and repeats are dropped, so prefixes
    come back in order of first appearance.
    """
    task_dir = os.path.join(data_path, task_name)
    tsv_names = [name for name in sorted(os.listdir(task_dir)) if name.endswith(".tsv")]
    # rpartition yields "" for names without "_", matching "_".join(split[:-1]).
    # dict.fromkeys de-duplicates while preserving insertion order.
    seen = dict.fromkeys(name.rpartition("_")[0] for name in tsv_names)
    return list(seen)

def get_tasks_list(filename, split_name):
    """Return the task names stored under ``split_name`` in a splits JSON file.

    Raises KeyError when the file has no entry for ``split_name``.
    """
    with open(filename, "r") as handle:
        raw_json = handle.read()
    splits = json.loads(raw_json)
    return splits[split_name]

def get_n_tokens(text: str) -> int:
    """Return how many token ids the module-level ``tokenizer`` yields for ``text``.

    The tokenizer is called with its default arguments, so the count includes
    any special tokens it adds by default and no truncation is requested.
    """
    encoded = tokenizer(text)
    token_ids = encoded["input_ids"]
    return len(token_ids)

def is_input_valid(task_prefix: str, input_text: str) -> bool:
    """Tell whether ``input_text`` fits in the token budget left after the prompt.

    The budget is INPUT_MAX_LEN minus the prompt length recorded for
    ``task_prefix`` in PROMPT_DICT. An unknown prefix raises KeyError.
    """
    prompt_len = PROMPT_DICT[task_prefix]
    budget = INPUT_MAX_LEN - prompt_len
    input_len = get_n_tokens(input_text)
    return input_len <= budget


# Build one TSV per split. Each kept row is
# [task_name, prefix, input, sampled_target, all_targets]; inputs that do not
# fit the token budget (is_input_valid) are dropped.

# A seeded generator makes the choice of target among several references
# reproducible, so re-running this cell regenerates identical files.
SEED = 42
rng = random.Random(SEED)

n_examples = {}  # split name -> number of rows written for that split
targets_len = []  # token length of every sampled target, train and test combined
for split in ["train", "test"]:
    print("Generating data for split: {}".format(split))
    task_names = get_tasks_list(TASKS_SPLITS, split)
    data = []
    for task_name in task_names:
        print("Task: {}".format(task_name))
        prefixes = get_task_prefixes(DATA_PATH, task_name)
        for prefix in prefixes:
            path = os.path.join(DATA_PATH, task_name, prefix + "_" + split + ".tsv")
            with open(path, encoding="utf-8") as fin:
                for line in fin:
                    d = line.strip().split("\t")
                    # Blank or target-less lines have nothing to sample from;
                    # random.choice([]) would raise IndexError.
                    if len(d) < 2:
                        continue
                    if not is_input_valid(prefix, d[0]):
                        continue
                    target = rng.choice(d[1:])
                    data.append([task_name, prefix, d[0], target, d[1:]])
                    targets_len.append(get_n_tokens(target))
    n_examples[split] = len(data)
    df = pd.DataFrame(data)
    df.to_csv(OUTPUT_FILE[split], index=False, sep="\t", header=False)

# Context manager guarantees the counts file is flushed and closed.
with open(COUNT_OUTPUT_FILE, "w", encoding="utf-8") as fout:
    json.dump(n_examples, fout)

# np.max raises ValueError on an empty list.
if targets_len:
    print("Max target len: {}".format(np.max(targets_len)))
else:
    print("Max target len: n/a (no examples kept)")

In [None]:
import matplotlib.pyplot as plt

# Histogram (20 bins) of the token lengths of the sampled targets collected
# by the generation loop above, across both the train and test splits.
plt.hist(targets_len, bins=20)
plt.xlabel("Target length (tokens)")
plt.ylabel("Number of examples")
plt.title("Sampled target length distribution")
plt.show()