# In [6]:
import json
import os
import pandas as pd
import random


# JSON file read by get_tasks_list(); the script below looks up the "train"
# and "test" keys in it to obtain the task names for each split.
TASKS_SPLITS = "data/custom_tasks_splits/train_classification_test_classification.json"
# Destination TSV for each split, keyed by the same split names the script
# iterates over.
OUTPUT_FILE = {
    "train": "data/train-train_classification_test_classification.tsv",
    "test": "data/test-train_classification_test_classification.tsv"
}
# Root directory; each task's files are looked up in DATA_PATH/<task_name>/.
DATA_PATH = "data/crossfit"


def get_task_prefixes(data_path: str, task_name: str) -> list:
    """Collect the distinct file prefixes found in a task's directory.

    Only ``.tsv`` files under ``data_path/task_name`` are considered. A prefix
    is the filename with its last underscore-separated piece removed (e.g.
    ``adversarialqa_32_13`` for ``adversarialqa_32_13_train.tsv``); a filename
    without any underscore yields the empty string.

    Returns:
        The unique prefixes, ordered by first appearance in the sorted
        directory listing.
    """
    task_dir = os.path.join(data_path, task_name)
    tsv_names = (name for name in sorted(os.listdir(task_dir)) if name.endswith(".tsv"))
    # dict keys keep insertion order, giving an order-preserving de-duplication.
    unique = dict.fromkeys(name.rpartition("_")[0] for name in tsv_names)
    return list(unique)

def get_tasks_list(filename, split_name):
    """Return the value stored under ``split_name`` in a JSON file.

    Args:
        filename: Path to a JSON file whose top level is an object.
        split_name: Key to look up in that object (the script uses
            ``"train"`` and ``"test"``).

    Returns:
        Whatever the JSON object holds under ``split_name``; the caller
        iterates over it as a collection of task names.

    Raises:
        KeyError: If ``split_name`` is not a key of the JSON object.
    """
    # JSON text is UTF-8 by specification; do not rely on the locale encoding.
    with open(filename, "r", encoding="utf-8") as fin:
        split_dict = json.load(fin)
    return split_dict[split_name]

# For each split, gather one row per non-blank line of every
# <prefix>_<split>.tsv file of every task listed for that split, then write the
# rows to OUTPUT_FILE[split] as a header-less, tab-separated file.
# Output columns: task name, file prefix, first tab-separated field of the
# line, and one field picked at random from the remaining fields
# (presumably the input text and one of its candidate targets -- confirm
# against the data format).
# NOTE: no random seed is set in this block, so the picked field may differ
# from run to run.
for split in ["train", "test"]:
    task_names = get_tasks_list(TASKS_SPLITS, split)
    data = []
    for task_name in task_names:
        for prefix in get_task_prefixes(DATA_PATH, task_name):
            path = os.path.join(DATA_PATH, task_name, prefix + "_" + split + ".tsv")
            # Read as UTF-8 explicitly so the result does not depend on the
            # platform's locale encoding (pandas writes UTF-8 by default).
            with open(path, encoding="utf-8") as fin:
                for line_no, line in enumerate(fin, start=1):
                    d = line.strip().split("\t")
                    if d == [""]:
                        # Blank line (e.g. a trailing newline): nothing to
                        # sample from, so skip it instead of crashing.
                        continue
                    if len(d) < 2:
                        # random.choice([]) would raise an opaque IndexError;
                        # report where the malformed line is instead.
                        raise ValueError(
                            f"{path}:{line_no}: expected at least two "
                            f"tab-separated fields, got {len(d)}"
                        )
                    data.append([task_name, prefix, d[0], random.choice(d[1:])])
    df = pd.DataFrame(data)
    df.to_csv(OUTPUT_FILE[split], index=False, sep="\t", header=None)