In [6]:
import os
import pandas as pd
import json
from tqdm import tqdm
import re

In [7]:
# Gold labels that parse_mnli keeps; records carrying any other gold_label are skipped.
label_set = {
    "neutral",
    "contradiction",
    "entailment",
}

In [8]:
def parse_mnli(input_filepath:str):
    """Convert an NLI ``.jsonl`` file into a CSV of sentence pairs and labels.

    Every non-blank line of the input is parsed as one JSON object. Records
    whose ``gold_label`` is not in the module-level ``label_set`` are skipped.
    For the remaining records, ``sentence1`` and ``sentence2`` are stripped and
    have runs of spaces collapsed to a single space; they are then written,
    together with ``pairID`` and the label, to a CSV file with the columns
    ``pairID, sentence1, sentence2, label`` (no index column).

    The CSV is saved next to the input file: the input path with its final
    extension replaced by ``.csv`` (``.csv`` is appended when the file name
    has no extension).

    Args:
        input_filepath: path of the JSON-lines file containing bulk data

    Returns:
        None. The result is the CSV file written to disk.

    Raises:
        OSError: if the input cannot be opened or the output cannot be written.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``gold_label``, ``pairID``, ``sentence1``
            or ``sentence2``.
    """
    # Compiled once per call instead of being looked up for every sentence.
    multi_space = re.compile(" +")

    def trim(text):
        # Drop surrounding whitespace, then squeeze repeated spaces.
        return multi_space.sub(" ", text.strip())

    # splitext only touches the final extension of the last path component,
    # so dotted directory names (e.g. "multinli_1.0/") and extension-less
    # file names are both handled, independent of the path separator.
    stem, _ = os.path.splitext(input_filepath)
    output_filepath = f"{stem}.csv"

    rows = []
    # Explicit encoding so reading does not depend on the platform locale;
    # the context manager guarantees the handle is closed.
    with open(input_filepath, encoding="utf-8") as source:
        # readlines() keeps the sized list so tqdm can still show a total.
        for line in tqdm(source.readlines()):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            item = json.loads(line)
            if item["gold_label"] not in label_set:
                continue
            rows.append([item["pairID"], trim(item["sentence1"]), trim(item["sentence2"]), item["gold_label"]])

    frame = pd.DataFrame(data=rows, columns=["pairID", "sentence1", "sentence2", "label"])
    frame.to_csv(output_filepath, index=False)

In [9]:
# Convert each MultiNLI split; the CSV is written next to its .jsonl source.
for mnli_path in (
    "multinli_1.0/multinli_1.0_train.jsonl",
    "multinli_1.0/multinli_1.0_dev_matched.jsonl",
    "multinli_1.0/multinli_1.0_dev_mismatched.jsonl",
):
    parse_mnli(mnli_path)

100%|██████████| 392702/392702 [00:05<00:00, 75069.95it/s]
100%|██████████| 10000/10000 [00:00<00:00, 71742.14it/s]
100%|██████████| 10000/10000 [00:00<00:00, 73503.04it/s]


In [10]:
# Convert each SNLI split; the CSV is written next to its .jsonl source.
for snli_path in (
    "snli_1.0/snli_1.0_dev.jsonl",
    "snli_1.0/snli_1.0_test.jsonl",
    "snli_1.0/snli_1.0_train.jsonl",
):
    parse_mnli(snli_path)

100%|██████████| 10000/10000 [00:00<00:00, 81442.64it/s]
100%|██████████| 10000/10000 [00:00<00:00, 88390.27it/s]
100%|██████████| 550152/550152 [00:05<00:00, 95167.08it/s] 
