In [1]:
import json
import os
from os import listdir, path
from os.path import isfile, join
from functools import partial

"""
Script for editing the "Categories" field (renaming or adding categories) for all the tasks of the given datasets.
"""

def get_files(dataset_names=("mctaco",), tasks_path='../tasks/'):
    """Return the sorted paths of task files whose name mentions any given dataset.

    Args:
        dataset_names: iterable of dataset-name substrings to look for in each
            file name (e.g. ["mctaco", "quoref"]). A single string is treated as
            one dataset name rather than as a sequence of characters.
        tasks_path: directory holding the task files; it is not searched
            recursively.

    Returns:
        A sorted list of paths (``tasks_path`` joined with the file name) for the
        regular files whose name contains at least one of ``dataset_names``.
    """
    # Without this, get_files("mctaco") would match on single characters.
    if isinstance(dataset_names, str):
        dataset_names = (dataset_names,)
    matches = [
        join(tasks_path, file_name)
        for file_name in listdir(tasks_path)
        # Cheap name test first; only stat the entries that could match.
        if any(dataset_name in file_name for dataset_name in dataset_names)
        and isfile(join(tasks_path, file_name))
    ]
    return sorted(matches)

In [2]:
# Collect every task file that belongs to one of the datasets being edited, then show the list.
files = get_files(dataset_names=["mctaco", "quoref", "cosmosqa", "drop"])
files

['../tasks/task001_quoref_question_generation.json',
 '../tasks/task002_quoref_answer_generation.json',
 '../tasks/task003_mctaco_question_generation_event_duration.json',
 '../tasks/task004_mctaco_answer_generation_event_duration.json',
 '../tasks/task005_mctaco_wrong_answer_generation_event_duration.json',
 '../tasks/task006_mctaco_question_generation_transient_stationary.json',
 '../tasks/task007_mctaco_answer_generation_transient_stationary.json',
 '../tasks/task008_mctaco_wrong_answer_generation_transient_stationary.json',
 '../tasks/task009_mctaco_question_generation_event_ordering.json',
 '../tasks/task010_mctaco_answer_generation_event_ordering.json',
 '../tasks/task011_mctaco_wrong_answer_generation_event_ordering.json',
 '../tasks/task012_mctaco_question_generation_absolute_timepoint.json',
 '../tasks/task013_mctaco_answer_generation_absolute_timepoint.json',
 '../tasks/task014_mctaco_wrong_answer_generation_absolute_timepoint.json',
 '../tasks/task015_mctaco_question_generat

In [3]:
def add_categories(category, data):
    """Append ``category`` to ``data["Categories"]`` unless it is already there.

    Mutates ``data`` in place and also returns it, so it can be used as
    ``partial(add_categories, "...")`` for ``modify_files``.

    Args:
        category: the category label to add.
        data: a task dict; if it has no "Categories" key, one is created as an
            empty list first.

    Returns:
        The same ``data`` object.
    """
    categories = data.setdefault("Categories", [])
    # Skip duplicates so re-running the notebook over already-updated files is a no-op.
    if category not in categories:
        categories.append(category)
    return data

In [4]:
def rename_categories(old_to_new_map, data):
    """Rename entries of ``data["Categories"]`` in place using a lookup table.

    Args:
        old_to_new_map: mapping from an old category name to its replacement,
            for example
            {
                "old category name 1": "new category name 1",
                "old category name 2": "new category name 2"
            }
            Categories that are not keys of the map are left untouched.
        data: task dict whose "Categories" list is updated element by element.

    Returns:
        The same ``data`` object.
    """
    categories = data["Categories"]
    for position in range(len(categories)):
        current = categories[position]
        if current in old_to_new_map:
            categories[position] = old_to_new_map[current]
    return data

In [5]:
# Quick self-check of rename_categories before it is applied to real task files.
def test_rename_categories():
    """Mapped category names are replaced; names absent from the map are kept."""
    mapping = {"old1": "new1", "old2": "new2"}
    result = rename_categories(mapping, {"Categories": ["old1", "old2", "old3"]})
    expected = {'Categories': ['new1', 'new2', 'old3']}
    assert result == expected


test_rename_categories()

In [6]:
# Re-display the task files that the loop at the end of the notebook will rewrite.
files

['../tasks/task001_quoref_question_generation.json',
 '../tasks/task002_quoref_answer_generation.json',
 '../tasks/task003_mctaco_question_generation_event_duration.json',
 '../tasks/task004_mctaco_answer_generation_event_duration.json',
 '../tasks/task005_mctaco_wrong_answer_generation_event_duration.json',
 '../tasks/task006_mctaco_question_generation_transient_stationary.json',
 '../tasks/task007_mctaco_answer_generation_transient_stationary.json',
 '../tasks/task008_mctaco_wrong_answer_generation_transient_stationary.json',
 '../tasks/task009_mctaco_question_generation_event_ordering.json',
 '../tasks/task010_mctaco_answer_generation_event_ordering.json',
 '../tasks/task011_mctaco_wrong_answer_generation_event_ordering.json',
 '../tasks/task012_mctaco_question_generation_absolute_timepoint.json',
 '../tasks/task013_mctaco_answer_generation_absolute_timepoint.json',
 '../tasks/task014_mctaco_wrong_answer_generation_absolute_timepoint.json',
 '../tasks/task015_mctaco_question_generat

In [7]:
# Old category name -> new category name, consumed by rename_categories.
# The "Answer Generation" family is re-parented under "Question Answering",
# and "Contextual Question Generation" under "Question Generation".
old_to_new_map = dict([
    ("Answer Generation",
     "Question Answering"),
    ("Answer Generation -> Commonsense Question Answering",
     "Question Answering -> Commonsense Question Answering"),
    ("Answer Generation -> Contextual Question Answering",
     "Question Answering -> Contextual Question Answering"),
    ("Answer Generation -> Extractive",
     "Question Answering -> Contextual Question Answering -> Extractive"),
    ("Answer Generation -> Abstractive",
     "Question Answering -> Contextual Question Answering -> Abstractive"),
    ("Answer Generation -> Fill in the Blank",
     "Question Answering -> Fill in the Blank"),
    ("Answer Generation -> Multiple Choice Question Answering",
     "Question Answering -> Multiple Choice Question Answering"),
    ("Answer Generation -> Open Question Answering",
     "Question Answering -> Open Question Answering"),
    ("Incorrect Answer Generation",
     "Question Answering -> Incorrect Answer Generation"),
    ("Contextual Question Generation",
     "Question Generation -> Contextual Question Generation"),
])

# The single-argument manipulation applied to every task file.
# To append a category instead of renaming, use the commented alternative.
date_manipulate_func = partial(rename_categories, old_to_new_map)
# date_manipulate_func = partial(add_categories, "new category 123!!")

In [8]:
def modify_files(file, date_manipulate_func):
    """Load a task JSON file, apply a manipulation to it, and write it back in place.

    Args:
        file: path to a task ``.json`` file; it is read and written as UTF-8.
        date_manipulate_func: callable taking the loaded JSON object. It may
            mutate the object in place; if it returns something other than
            ``None``, that return value is what gets written back.

    The output is pretty-printed with 4-space indentation, keeps non-ASCII
    characters unescaped, and ends with a newline. The new content is fully
    serialized before the file is touched, then written to a sibling temp file
    and swapped in. An exception (failing manipulation, unserializable result,
    I/O error) therefore leaves the original file intact.
    """
    with open(file, "r", encoding="utf-8") as f:
        data = json.load(f)

    result = date_manipulate_func(data)
    if result is not None:
        data = result

    # Serialize first: if this raises, nothing on disk has been modified yet.
    modified_json = json.dumps(data, indent=4, ensure_ascii=False)

    tmp_file = os.fspath(file) + ".tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            f.write(modified_json + "\n")
        # Overwrites the destination; on POSIX the rename is atomic.
        os.replace(tmp_file, file)
    finally:
        # The temp file only survives here if writing or replacing failed.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

In [9]:
# Optional dry run: uncomment to modify only the first task file before processing all of them.
# modify_files(files[0], date_manipulate_func)

In [10]:
# Apply date_manipulate_func (currently the category rename built from old_to_new_map)
# to every task file in `files`, rewriting each file in place.
for file in files:
    modify_files(file, date_manipulate_func)