In [27]:
from pathlib import Path
import tarfile
import json
import random

import pandas as pd
import datasets

## Load JSON-like, Compressed Data to Dataframe

In [19]:
# Read every high-school-level RACE item straight out of the compressed
# archive into a DataFrame with one row per passage file.
archive_path = Path("../../data/RACE.tar.gz")

# One dict per passage file: the parsed JSON payload plus its split name.
items = []

with tarfile.open(archive_path, "r:gz") as tar:
    for member in tar.getmembers():
        if not member.isfile() or not member.name.endswith(".txt"):
            continue

        # Member paths are expected to have exactly four components, of which
        # the second is the split and the third the difficulty level.
        _, split, lvl, _ = member.name.split("/")
        if lvl != "high":
            continue

        f = tar.extractfile(member)
        if f is None:
            continue

        # Close each extracted stream explicitly instead of leaking one
        # open handle per archive member.
        with f:
            item_dict = {"split": split}
            item_dict.update(json.load(f))
        items.append(item_dict)

df = pd.DataFrame(items)

# Raw string: "\d" inside a plain literal is an invalid escape sequence.
# expand=False makes the single capture group come back as a Series.
df["passage_id"] = df["id"].str.extract(r"(\d+)", expand=False)

# Two rows/passages have no questions or options
empty_options = df["options"].map(len) == 0
df = df[~empty_options]

In [20]:
# Explode questions into individual rows: the three list-valued columns are
# unpacked in lockstep, then renamed to their singular forms.

per_question_columns = ["answers", "options", "questions"]
singular_names = {
    "id": "filename",
    "answers": "answer",
    "questions": "question",
}

exploded = df.explode(per_question_columns)
exploded = exploded.reset_index(drop=True)
df = exploded.rename(columns=singular_names)
df.index.name = "idx"

df.to_parquet("../../data/RACE.parquet")

In [21]:
# All items have four options
options_per_question = df["options"].map(len)
options_per_question.value_counts()

options
4    69394
Name: count, dtype: int64

In [22]:
# Spot-check one randomly chosen question row (no random_state is passed,
# so a different row may be shown on each run).
df.sample(1)

Unnamed: 0_level_0,split,answer,options,question,article,filename,passage_id
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
34366,train,A,"[in need of help, not interested in studies, n...","After his tour to the school, the author found...","I'd always dreamed of exploring Africa, ever s...",high4561.txt,4561


## Convert to Contrastive Pairs

Create two rows per question: the correct answer, labeled 1, and one distractor sampled at random from the incorrect options, labeled 0.

In [30]:
def construct_pairs(df, seed=None):
    """Build one correct/incorrect answer pair for every question row.

    For each row, the option whose positional letter (A-D) equals
    ``row.answer`` is emitted with ``label=1``, followed by one distractor
    drawn at random from the remaining options with ``label=0``.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per question with columns ``split``, ``article``,
        ``question``, ``options`` (sequence of answer strings), ``answer``
        (option letter) and ``passage_id``. The index value of each row is
        used as its ``question_id``.
    seed : int, optional
        Seed for a private random generator, making the sampled distractors
        reproducible without touching the global ``random`` state. When
        omitted, the module-level ``random`` generator is used, so the
        result can still be controlled with ``random.seed``.

    Returns
    -------
    pandas.DataFrame
        Columns ``split``, ``passage``, ``question``, ``answer``, ``label``,
        ``passage_id`` and ``question_id``, with the index named
        ``answer_id``.

    Raises
    ------
    ValueError
        If a question has no incorrect option to sample from.
    """
    option_letters = ["A", "B", "C", "D"]

    # A dedicated generator keeps seeded calls deterministic and isolated;
    # unseeded calls keep drawing from the shared module-level generator.
    rng = random if seed is None else random.Random(seed)

    records = []
    for row in df.itertuples():
        incorrect_records = []
        for i, option in enumerate(row.options):
            is_correct = 1 if option_letters[i] == row.answer else 0

            record = {
                "split": row.split,
                "passage": row.article,
                "question": row.question,
                "answer": option,
                "label": is_correct,
                "passage_id": int(row.passage_id),
                "question_id": int(row.Index),
            }

            if is_correct:
                records.append(record)
            else:
                incorrect_records.append(record)

        if not incorrect_records:
            raise ValueError(
                f"question {row.Index!r} has no incorrect option to sample"
            )

        # Keep exactly one distractor so every question yields a single pair.
        records.append(rng.choice(incorrect_records))

    # Create the transformed dataframe
    transformed_df = pd.DataFrame(records)
    transformed_df.index.name = "answer_id"

    return transformed_df

# Seed the module-level RNG right before the stochastic step so the sampled
# distractors (and therefore the saved dataset) are identical on every
# Restart & Run All.
random.seed(42)

contrastive_df = construct_pairs(df)
contrastive_df

Unnamed: 0_level_0,split,passage,question,answer,label,passage_id,question_id
answer_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,test,The rain had continued for a week and the floo...,What did Nancy try to do before she fell over?,Protect her cows from being drowned,1,19432,0
1,test,The rain had continued for a week and the floo...,What did Nancy try to do before she fell over?,Look for a fallen tree trunk,0,19432,0
2,test,The rain had continued for a week and the floo...,The following are true according to the passag...,Nancy took hold of the rope and climbed into t...,1,19432,1
3,test,The rain had continued for a week and the floo...,The following are true according to the passag...,It was raining harder when Nancy managed to ge...,0,19432,1
4,test,The rain had continued for a week and the floo...,What did the local people do to help those in ...,They put up shelter for them in a school.,1,19432,2
...,...,...,...,...,...,...,...
138783,dev,How come it seems like every kid today is a wi...,"In the passage, parents are advised to _ .",stop kids taking part in any competition.,0,18939,69391
138784,dev,I clearly remember my mom telling me to drink ...,From what the author's mother did we may infer...,she knew her children would benefit from milk,1,11113,69392
138785,dev,I clearly remember my mom telling me to drink ...,From what the author's mother did we may infer...,she didn't like her daughter who didn't obey h...,0,11113,69392
138786,dev,I clearly remember my mom telling me to drink ...,"Based on the passage, which of the following i...",The author's sister will suffer from osteopros...,1,11113,69393


In [31]:
# Build one Dataset per split by filtering the pair table on its "split" column.
dd = datasets.DatasetDict({
    split_name: datasets.Dataset.from_pandas(
        contrastive_df[contrastive_df["split"] == split_name]
    )
    for split_name in ("train", "dev", "test")
})
dd["train"].features

{'split': Value(dtype='string', id=None),
 'passage': Value(dtype='string', id=None),
 'question': Value(dtype='string', id=None),
 'answer': Value(dtype='string', id=None),
 'label': Value(dtype='int64', id=None),
 'passage_id': Value(dtype='int64', id=None),
 'question_id': Value(dtype='int64', id=None),
 'answer_id': Value(dtype='int64', id=None)}

In [32]:
# Display the DatasetDict summary (feature names and row count per split).
dd

DatasetDict({
    train: Dataset({
        features: ['split', 'passage', 'question', 'answer', 'label', 'passage_id', 'question_id', 'answer_id'],
        num_rows: 124890
    })
    dev: Dataset({
        features: ['split', 'passage', 'question', 'answer', 'label', 'passage_id', 'question_id', 'answer_id'],
        num_rows: 6902
    })
    test: Dataset({
        features: ['split', 'passage', 'question', 'answer', 'label', 'passage_id', 'question_id', 'answer_id'],
        num_rows: 6996
    })
})

In [33]:
# Persist all three splits to disk next to the other data artifacts.
dd.save_to_disk("../../data/RACE_contrastive_pairs.hf")

Saving the dataset (0/1 shards):   0%|          | 0/124890 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6902 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6996 [00:00<?, ? examples/s]