In [1]:
import json
import os
import pathlib

from tqdm import tqdm

from gpt3 import MewlCaptioner

In [2]:
# Shared captioner instance. `test_task` looks up a method on it named after
# each task (getattr by task name) and calls that method on every episode
# directory to obtain (contexts, query) captions.
captioner = MewlCaptioner()

In [3]:
def test_task(captioner, dataset_path, task_name):
    dataset_path = pathlib.Path(dataset_path) / task_name
    all_episodes = dataset_path.glob('*/')
    all_episodes = sorted(all_episodes, key=lambda x: int(x.name))

    all_datasets = []

    caption_func = getattr(captioner, task_name)

    for i, epi in tqdm(enumerate(all_episodes), total=len(all_episodes)):
        prompt = ""
        contexts, query = caption_func(epi)
        for context in contexts:
            prompt += f"Context: {context[0]}\nName: {context[1]}\n\n"

        prompt += f"Context: {query[0]}\nName: "

        query_description, choices, answer = query

        label = choices.index(answer)

        all_datasets.append({
            "start_phrase": prompt,
            "choices": choices,
            "label": label,
            "task": task_name
        })

    return all_datasets

In [6]:
# Root of the MEWL dataset, containing one sub-directory per split
# (train/, val/, test/). Set the MEWL_DATASET_ROOT environment variable to
# point elsewhere instead of editing this machine-specific default.
dataset_path = pathlib.Path(os.environ.get("MEWL_DATASET_ROOT", "/home/guangyuan/MEWL"))

In [7]:
from model.consts import task_names

# Write one JSON file per task. Each file is ./dataset/<task>.json and maps
# "train" / "val" / "test" to the records produced by `test_task`.
output_dir = pathlib.Path("./dataset")
# Create the output directory up front. open() would otherwise raise
# FileNotFoundError on a fresh checkout where ./dataset does not exist.
output_dir.mkdir(parents=True, exist_ok=True)

for task in task_names:
    dataset = {
        split: test_task(captioner, dataset_path / split, task)
        for split in ("train", "val", "test")
    }
    with open(output_dir / f"{task}.json", "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=4)

100%|██████████| 3000/3000 [00:00<00:00, 7303.80it/s]
100%|██████████| 600/600 [00:00<00:00, 9153.64it/s]
100%|██████████| 600/600 [00:00<00:00, 7475.57it/s]
100%|██████████| 3000/3000 [00:00<00:00, 4200.57it/s]
100%|██████████| 600/600 [00:00<00:00, 4284.91it/s]
100%|██████████| 600/600 [00:00<00:00, 4050.50it/s]
100%|██████████| 3000/3000 [00:00<00:00, 6675.68it/s]
100%|██████████| 600/600 [00:00<00:00, 6658.86it/s]
100%|██████████| 600/600 [00:00<00:00, 6574.18it/s]
100%|██████████| 3000/3000 [00:03<00:00, 758.20it/s]
100%|██████████| 600/600 [00:00<00:00, 739.81it/s]
100%|██████████| 600/600 [00:00<00:00, 718.76it/s]
100%|██████████| 3000/3000 [00:00<00:00, 4195.26it/s]
100%|██████████| 600/600 [00:00<00:00, 4187.67it/s]
100%|██████████| 600/600 [00:00<00:00, 4077.30it/s]
100%|██████████| 3000/3000 [00:00<00:00, 6241.86it/s]
100%|██████████| 600/600 [00:00<00:00, 6147.60it/s]
100%|██████████| 600/600 [00:00<00:00, 6139.62it/s]
100%|██████████| 3000/3000 [00:00<00:00, 6458.41it/s]
1