In [None]:
import json
import numpy as np
import os
import pandas as pd
import random
import t5
import tensorflow.compat.v1 as tf
from tqdm import tqdm
from unidecode import unidecode


# Task-split configuration; every input/output path below is derived from it,
# so switching to another split only requires changing this one value.
SPLIT_NAME = "train_non_mc_qa_test_mc"
# JSON file mapping "train"/"test" to lists of task names.
TASKS_SPLITS = f"data/custom_tasks_splits/{SPLIT_NAME}.json"
OUTPUT_DIR = "data/v2"
OUTPUT_FILE = {
    "train": f"{OUTPUT_DIR}/train-{SPLIT_NAME}.tsv",
    "test": f"{OUTPUT_DIR}/test-{SPLIT_NAME}.tsv",
}
# Per-split example counts are written here as JSON.
COUNT_OUTPUT_FILE = f"{OUTPUT_DIR}/counts-{SPLIT_NAME}.json"
# Root directory holding one sub-directory of <prefix>_<split>.tsv files per task.
DATA_PATH = "data/crossfit"
# Token budget shared by a task's prompt and an example's input; examples whose
# input does not fit in (INPUT_MAX_LEN - prompt tokens) are dropped.
# NOTE(review): presumably chosen just under a 1024-token model limit — confirm.
INPUT_MAX_LEN = 1018
vocab = t5.data.get_default_vocabulary()
# Seed `random` so target sampling and test-example assignment are reproducible.
random.seed(0)

def get_n_tokens(text: str) -> int:
    """Return the number of token ids `text` encodes to under the module-level `vocab`.

    The text is wrapped in a TF string constant and encoded with `encode_tf`;
    the count is the length of the first dimension of the resulting ids.
    """
    encoded_ids = vocab.encode_tf(tf.constant(text))
    n_ids = encoded_ids.shape[0]
    return n_ids

def read_prompt_dict(filename: str) -> dict:
    """Map each task prefix in a prompt TSV to the token length of its prompt.

    The file is headerless and tab-separated with the columns
    task_name, task_prefix, prompt, prompt_len, io_sep. Only `task_prefix`
    and `prompt` are used: the length is recomputed with `get_n_tokens`
    instead of reading the stored `prompt_len` column.

    Args:
        filename: Path to the prompt TSV.

    Returns:
        dict mapping task_prefix -> number of prompt tokens. If a prefix
        occurs on several rows, the last row wins.
    """
    df = pd.read_csv(
        filename,
        header=None,
        sep="\t",
        names=["task_name", "task_prefix", "prompt", "prompt_len", "io_sep"],
        # Keep cells as literal text. By default pandas converts "", "NA",
        # "null", "None", ... to float NaN, which would then be tokenized.
        keep_default_na=False,
    )
    return {
        prefix: get_n_tokens(prompt)
        for prefix, prompt in zip(df["task_prefix"], df["prompt"])
    }

def get_task_prefixes(data_path: str, task_name: str) -> list:
    """List the distinct task prefixes (e.g. adversarialqa_32_13) found for a task.

    Every `.tsv` file in `<data_path>/<task_name>` contributes the part of its
    name before the last underscore. Files are visited in sorted order and each
    prefix is reported once, at its first occurrence.
    """
    task_dir = os.path.join(data_path, task_name)
    tsv_names = (name for name in sorted(os.listdir(task_dir)) if name.endswith(".tsv"))
    # dict.fromkeys keeps insertion order, giving an ordered de-duplication.
    unique_prefixes = dict.fromkeys("_".join(name.split("_")[:-1]) for name in tsv_names)
    return list(unique_prefixes)

def get_tasks_list(filename, split_name):
    """Load the JSON split file `filename` and return its entry for `split_name`.

    Raises KeyError if the split is not present in the file.
    """
    with open(filename, "r") as handle:
        return json.load(handle)[split_name]

def is_input_valid(task_prefix: str, input_text: str) -> bool:
    """Return True when `input_text` fits the token budget left after the prefix's prompt.

    The budget is INPUT_MAX_LEN minus the prompt length stored in PROMPT_DICT
    for `task_prefix` (KeyError if the prefix has no prompt entry).
    """
    remaining_budget = INPUT_MAX_LEN - PROMPT_DICT[task_prefix]
    return get_n_tokens(input_text) <= remaining_budget

# Prompt token length per task prefix, used by is_input_valid().
PROMPT_DICT = read_prompt_dict(filename="data/prompt/prompt.tsv")
# Rows of [split, task_name, task_prefix, n_examples, max_target_len],
# filled in by the per-split cells.
stats = list()

In [None]:
# Build the evaluation TSV: every test example of every prefix of each test task.
# Row layout: [task_name, prefix, input, sampled target, *all candidate targets].
task_names = get_tasks_list(TASKS_SPLITS, "test")
# Drop "test" rows from an earlier run so re-running this cell does not
# double-count examples in `stats`.
stats = [row for row in stats if row[0] != "test"]
data = []
for task_name in tqdm(task_names):
    for prefix in get_task_prefixes(DATA_PATH, task_name):
        test_path = os.path.join(DATA_PATH, task_name, prefix + "_test.tsv")
        # Explicit encoding so decoding does not depend on the machine's locale.
        # NOTE(review): assumes the data files are UTF-8 — confirm.
        with open(test_path, encoding="utf-8") as fin:
            lines = fin.readlines()
        targets_len = []
        for line in lines:
            text = unidecode(line).strip()
            if not text:
                # Blank line: there is no target, and random.choice([]) would raise.
                continue
            d = text.split("\t")
            if is_input_valid(prefix, d[0]):
                target = random.choice(d[1:])
                data.append([task_name, prefix, d[0], target] + d[1:])
                targets_len.append(get_n_tokens(target))
        stats.append(["test", task_name, prefix, len(targets_len), max(targets_len, default=0)])

os.makedirs(os.path.dirname(OUTPUT_FILE["test"]) or ".", exist_ok=True)
df = pd.DataFrame(data)
df.to_csv(OUTPUT_FILE["test"], index=False, sep="\t", header=None)

In [None]:
# Build the training TSV. Each training task contributes the *_train.tsv
# examples of every prefix, plus its test examples: each test line is assigned
# to one randomly chosen prefix, so every test example is added exactly once.
# Row layout: [task_name, prefix, input, sampled target, *all candidate targets].
task_names = get_tasks_list(TASKS_SPLITS, "train")
# Drop "train" rows from an earlier run so re-running this cell does not
# double-count examples in `stats`.
stats = [row for row in stats if row[0] != "train"]
data = []
for task_name in tqdm(task_names):
    prefixes = get_task_prefixes(DATA_PATH, task_name)
    if not prefixes:
        raise FileNotFoundError(f"No .tsv files found for task {task_name!r} in {DATA_PATH!r}")
    # NOTE(review): only the first prefix's test file is read, which assumes all
    # prefixes of a task share the same test set — confirm.
    # NOTE(review): assumes the data files are UTF-8 — confirm.
    with open(os.path.join(DATA_PATH, task_name, prefixes[0] + "_test.tsv"), encoding="utf-8") as fin:
        test_lines = fin.readlines()
    # test_prefixes[i] is the index into `prefixes` that test line i is assigned to.
    test_prefixes = np.array(random.choices(range(len(prefixes)), k=len(test_lines)))

    for prefix_idx, prefix in enumerate(prefixes):
        with open(os.path.join(DATA_PATH, task_name, prefix + "_train.tsv"), encoding="utf-8") as fin:
            lines = fin.readlines()
        # Process this prefix's train lines first, then the test lines assigned to it.
        lines += [test_lines[idx] for idx in np.where(test_prefixes == prefix_idx)[0]]

        targets_len = []
        for line in lines:
            text = unidecode(line).strip()
            if not text:
                # Blank line: there is no target, and random.choice([]) would raise.
                continue
            d = text.split("\t")
            if is_input_valid(prefix, d[0]):
                # Sample one candidate target for training; all candidates stay in the row.
                target = random.choice(d[1:])
                data.append([task_name, prefix, d[0], target] + d[1:])
                targets_len.append(get_n_tokens(target))

        stats.append(["train", task_name, prefix, len(targets_len), max(targets_len, default=0)])

os.makedirs(os.path.dirname(OUTPUT_FILE["train"]) or ".", exist_ok=True)
df = pd.DataFrame(data)
df.to_csv(OUTPUT_FILE["train"], index=False, sep="\t", header=None)

In [None]:
stats_df = pd.DataFrame(stats, columns=["split", "task_name", "task_prefix", "n_examples", "max_target_len"])

# Save the total number of examples per split as {split: count}.
count_df = stats_df[["split", "n_examples"]].groupby(["split"]).sum().reset_index()
# `with` guarantees the file is flushed and closed before any later cell reads it.
with open(COUNT_OUTPUT_FILE, "w") as fout:
    json.dump(dict(zip(count_df.split, count_df.n_examples)), fout)

In [None]:
# Render every row and column of the per-(split, task, prefix) statistics table.
with pd.option_context("display.max_columns", None, "display.max_rows", None):
    display(stats_df)