In [14]:
from datasets import load_dataset

DATASET_NAME = "csqa"
# Pull every CommonsenseQA split from the Hugging Face hub. Because `split` is a
# list, the result is a list ordered like it: 0=train, 1=validation, 2=test.
dataset = load_dataset(
    path="tau/commonsense_qa",
    name="",
    split=["train", "validation", "test"],
)

In [15]:
# None -> export plain train/dev/test prompt files; an int N -> instead build
# EACH_HAS shuffled prompts for each of the first N validation questions.
NUM_PERMUTE_QUESTION = None
# Prompts generated per question in permute mode.
EACH_HAS = 50
# Number of answer slots the gold answer is rotated through in permute mode.
NUM_CHOICE = 5
# Permute mode assigns EACH_HAS / NUM_CHOICE prompts to each answer slot; an
# uneven split would silently skew the position distribution.
assert EACH_HAS % NUM_CHOICE == 0, "EACH_HAS must be a multiple of NUM_CHOICE"

In [16]:
# Number of answer options per validation question; the max and min are
# printed so a reader can confirm the count is constant across questions.
option_counts = [len(opts["label"]) for opts in dataset[1]["choices"]]
choice_len = max(option_counts)
print("max", choice_len)
print(min(option_counts))

max 5
5


In [17]:
# Peek at one validation record to see the fields get_prompt reads.
validation_split = dataset[1]
validation_split[0]

{'id': '1afa02df02c908a558b4036e80242fac',
 'question': 'A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?',
 'question_concept': 'revolving door',
 'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
  'text': ['bank', 'library', 'department store', 'mall', 'new york']},
 'answerKey': 'A'}

In [18]:
import numpy as np


def get_prompt(data, ques_index, location=-1, has_choice=False, is_test=False, rng=None):
    """Build a "question \\n (label) choice ..." prompt for one multiple-choice record.

    Args:
        data: Indexable collection of records (e.g. one split of the HF dataset).
            Each record needs "question" and "choices" ({"label": [...], "text": [...]}),
            plus "id" when ``is_test`` and "answerKey" otherwise. An optional
            "fact1" field is prepended to the prompt.
        ques_index: Index of the record inside ``data``.
        location: When > -1 (and not ``is_test``), the gold answer is placed at this
            slot and the remaining choices are shuffled around it; otherwise the
            original choice order is kept.
        has_choice: If True, candidates are rendered as "(A) text", else "( ) text".
        is_test: Treat the record as unlabelled and return ``(id, prompt)``.
        rng: Object with a ``permutation(n)`` method used for the shuffle; defaults
            to the global ``np.random`` state (pass a seeded Generator for
            reproducible output).

    Returns:
        ``(prompt, answer_text)`` for labelled data, ``(id, prompt)`` when
        ``is_test``, or None if the record could not be processed (a message
        naming the record is printed).
    """
    if rng is None:
        rng = np.random
    try:
        json_line = data[ques_index]
        question = json_line["question"]
        choices = json_line["choices"]
        labels = choices["label"]
        # Copy, so the remove/insert below never mutate the caller's record.
        choice_texts = list(choices["text"])
        answer_text = None
        if not is_test:
            answer_key = json_line["answerKey"][0]
            # Resolve the gold label (e.g. "C" or "3") against the record's own labels.
            answer_key_idx = labels.index(answer_key)
            answer_text = choice_texts[answer_key_idx]
            if location > -1:
                # Take the gold answer out, shuffle the distractors, then put the
                # answer back at the requested slot.
                del choice_texts[answer_key_idx]
                perm = rng.permutation(len(choice_texts))
                choice_texts = [choice_texts[j] for j in perm]
                choice_texts.insert(location, answer_text)
        candidates = " ".join(
            f"({label if has_choice else ' '}) {text}"
            for text, label in zip(choice_texts, labels)
        ).replace("\n", " ")

        fact = f"{json_line['fact1']}. " if "fact1" in json_line else ""
        # "\\n" is a literal backslash-n separator, not a real newline.
        prompt = f"{fact}{question} \\n {candidates}"
        if is_test:
            return json_line["id"], prompt
        return prompt, answer_text
    except Exception as exc:
        # Best effort: report the bad record and return None so a bulk export
        # keeps going (the old handler could itself crash on an unbound name).
        print(f"get_prompt: skipping record {ques_index}: {exc!r}")
        return None

In [19]:
import random
from tqdm import tqdm, trange
import pickle
import itertools

# Every ordering of the answer slots (retained for later cells).
all_permutes = list(itertools.permutations(range(choice_len)))
if NUM_PERMUTE_QUESTION:
    # Permute mode: for each of the first NUM_PERMUTE_QUESTION validation questions,
    # emit EACH_HAS prompts with the gold answer spread evenly over NUM_CHOICE slots.
    PERMUTE_SEED = 0
    np.random.seed(PERMUTE_SEED)  # make the distractor shuffles reproducible
    container = []
    for ques_index in trange(NUM_PERMUTE_QUESTION):
        for idx in range(EACH_HAS):
            # Integer bucketing: idx in [k*EACH_HAS/NUM_CHOICE, (k+1)*EACH_HAS/NUM_CHOICE) -> slot k.
            location = idx * NUM_CHOICE // EACH_HAS
            container.append(
                get_prompt(dataset[1], ques_index, location, has_choice=False)
            )
    with open(f"{DATASET_NAME}_test_permute.pkl", "wb") as fh:
        pickle.dump((NUM_PERMUTE_QUESTION, EACH_HAS, container), fh)
else:
    # Export mode: one pickle of prompts per split; `dataset` is ordered
    # train / validation / test, saved here as train / dev / test.
    for split_idx, name in enumerate(["train", "dev", "test"]):
        split = dataset[split_idx]
        prompts = [
            get_prompt(split, x, has_choice=False, is_test=(name == "test"))
            for x in trange(len(split))
        ]
        with open(f"{DATASET_NAME}_{name}.pkl", "wb") as fh:
            pickle.dump(prompts, fh)

100%|██████████| 9741/9741 [00:00<00:00, 17118.73it/s]
100%|██████████| 1221/1221 [00:00<00:00, 16970.35it/s]
100%|██████████| 1140/1140 [00:00<00:00, 17375.35it/s]
