In [None]:
import json
import numpy as np
import re
import random
import pandas as pd
from datasets import Dataset, DatasetDict
from collections import defaultdict

In [None]:
# Seed Python's `random` module so that the random.shuffle calls made later in
# this notebook (inside get_random_summary) give the same order on every run,
# provided the cells are executed top to bottom.
random.seed(111)

In [None]:
# Source JSON file, resolved relative to the current working directory.
data_path = "train_456-fixedIds.json"

# Pass the encoding explicitly: without it, open() falls back to the platform's
# locale encoding, so the same file could decode differently (or fail) on
# another machine.
with open(data_path, encoding="utf-8") as f:
    ds = json.load(f)

# Top-level "data" entry of the loaded JSON; the later cells iterate over
# ds["data"] directly.
data = ds["data"]

In [None]:
# Captures, non-greedily, whatever lies between a "<b>" and the next "<br>";
# generate_sentence_dict treats each capture as one sentence chunk.
sentence_pattern = re.compile(r'<b>(.*?)<br>')
# Applied to a single captured chunk: group 1 is the integer after "Sent ",
# group 2 is the text that follows ": </b>" up to the end of the chunk.
sentence_id_pattern = re.compile(r'Sent (\d+): </b>(.*?)$')

In [None]:
class SummaryHash:
    """Bucket summaries by their character length.

    Every bucket is a list of ``{"id": sample_id, "text": summary}`` records
    stored under the key ``len(summary)``.
    """

    def __init__(self):
        # length -> list of records; a bucket is created on its first insert.
        self.summary_hash = defaultdict(list)

    def add_summary(self, summary, sample_id):
        """Append ``summary`` (tagged with ``sample_id``) to the bucket for its length."""
        record = {"id": sample_id, "text": summary}
        self.summary_hash[len(summary)].append(record)

    def get_summary_by_length(self, length):
        """Return the bucket stored for ``length``, or a fresh empty list.

        The stored list object itself is returned (not a copy), and asking for
        a length that has no bucket does not create one.
        """
        if length in self.summary_hash:
            return self.summary_hash[length]
        return []

# Single shared index of candidate summaries; it is filled by the loop over
# ds["data"] further down in this cell and queried by get_random_summary.
summary_hasher = SummaryHash()

def generate_sentence_dict(text, sentence_pattern, sentence_id_pattern):
    """Map sentence ids to sentence text for one marked-up paragraph.

    Args:
        text: Paragraph string to scan.
        sentence_pattern: Compiled regex; every ``findall`` result on ``text``
            is treated as one candidate sentence chunk.
        sentence_id_pattern: Compiled regex applied to each chunk; group 1 must
            hold the integer sentence id and group 2 the sentence text.

    Returns:
        dict mapping the ``int`` sentence id to its text. Chunks that
        ``sentence_id_pattern`` does not match are skipped, and when an id
        occurs more than once the last occurrence wins.
    """
    output = {}

    for chunk in sentence_pattern.findall(text):
        sent_info = sentence_id_pattern.search(chunk)
        if sent_info is None:
            # The chunk is not of the "Sent N: </b>..." form (e.g. some other
            # bold fragment); skip it rather than fail on None.group().
            continue
        output[int(sent_info.group(1))] = sent_info.group(2)

    return output

# Index every "question + answer" summary by its length so that length-matched
# summaries from other samples can be drawn later. Only the questions and
# answers are needed for this pass, so the paragraph text is not parsed here.
for sample in ds["data"]:
    sample_id = sample["id"]
    for question in sample["paragraph"]["questions"]:
        for answer in question["answers"]:
            # [:-1] drops the last character of the question (presumably the
            # trailing "?" — TODO confirm every question ends with one).
            summary = question["question"][:-1] + " " + answer["text"]
            summary_hasher.add_summary(summary, sample_id)

In [None]:
def get_documents(sentence_dict, indeces):
    """Join the selected sentences into one ``" ||||| "``-separated string.

    Every value in ``indeces`` is incremented by one before it is looked up,
    i.e. the keys of ``sentence_dict`` are offset by one relative to
    ``indeces``. A key that is missing raises ``KeyError``.
    """
    selected = []
    for position in indeces:
        selected.append(sentence_dict[position + 1])

    return " ||||| ".join(selected)

def get_random_summary(summary_hasher, summary_lengths, sample_id):
    """Pick a random summary that belongs to a different sample.

    Lengths from ``min(summary_lengths)`` to ``max(summary_lengths)`` are tried
    in increasing order. For each length the candidates are shuffled and the
    first one whose ``"id"`` differs from ``sample_id`` is returned.

    Args:
        summary_hasher: Object exposing ``get_summary_by_length(length)`` that
            returns a list of ``{"id": ..., "text": ...}`` dicts.
        summary_lengths: Lengths of the summaries of the current question.
        sample_id: Id of the current sample; its own summaries are excluded.

    Returns:
        The text of the chosen summary, or ``None`` (after printing
        "not found") when ``summary_lengths`` is empty or no summary from
        another sample exists within the length range.
    """
    if len(summary_lengths) == 0:
        # min()/max() raise ValueError on an empty sequence; report this case
        # the same way as any other failed lookup.
        print("not found")
        return None

    for length in range(min(summary_lengths), max(summary_lengths) + 1):
        summary_candidates = summary_hasher.get_summary_by_length(length)
        # Shuffles, in place, the very list object handed back by the hasher.
        random.shuffle(summary_candidates)
        for summary in summary_candidates:
            if summary["id"] != sample_id:
                return summary["text"]

    print("not found")
    return None


# Build the flat list of records. For every question this emits one record per
# listed answer (isAnswer 1 or 0) plus, when one can be found, one record whose
# summary is drawn at random from a different sample (isAnswer -1).
parsed_dataset = []
for sample in ds["data"]:
    text = sample["paragraph"]["text"]
    sentence_dict = generate_sentence_dict(text, sentence_pattern, sentence_id_pattern)

    sample_id = sample["id"]
    for question in sample["paragraph"]["questions"]:
        question_id = question["idx"]

        sentences_used = question["sentences_used"]

        docs_str = get_documents(sentence_dict, sentences_used)

        summary_lengths = []
        for answer in question["answers"]:
            # [:-1] drops the last character of the question (presumably the
            # trailing "?" — TODO confirm every question ends with one).
            summary = question["question"][:-1] + " " + answer["text"]

            isAnswer = 1 if answer["isAnswer"] else 0
            new_entry = {
                "document": docs_str,
                "summary": summary,
                "isAnswer": isAnswer,
                "q_id": question_id,
                "sample_id": sample_id,
                "isMultisent": question["multisent"],
            }

            parsed_dataset.append(new_entry)
            summary_lengths.append(len(summary))

        if not summary_lengths:
            # A question without answers gives no length range to match
            # against (min()/max() on an empty list would raise ValueError).
            continue

        random_summary = get_random_summary(summary_hasher, summary_lengths, sample_id)
        if random_summary is None:
            # No summary from another sample was found; do not emit a record
            # whose "summary" field would be None.
            continue

        new_entry = {
            "document": docs_str,
            "summary": random_summary,
            "isAnswer": -1,
            "q_id": question_id,
            "sample_id": sample_id,
            "isMultisent": question["multisent"],
        }

        parsed_dataset.append(new_entry)

In [None]:
# Spot-check one built record (index 15 is presumably an arbitrary pick; this
# raises IndexError if fewer than 16 records were produced).
parsed_dataset[15]

In [None]:
# Sanity check: how many sentences were joined into each "document" string.
df_multirc_processed = pd.DataFrame(parsed_dataset)
# Splitting on the bare separator yields one piece more than the number of
# separators, i.e. the number of joined sentences.
df_multirc_processed["n_docs"] = df_multirc_processed["document"].apply(
    lambda doc: len(doc.split("|||||"))
)
df_multirc_processed["n_docs"].value_counts()

In [None]:
# Fresh frame built straight from the records (so without the helper "n_docs"
# column added to df_multirc_processed).
df = pd.DataFrame(parsed_dataset)

# Wrap the frame as the only split, "train", of a DatasetDict.
train_dataset = Dataset.from_pandas(df)
dataset_dict = dict(train=train_dataset)
hf_dataset = DatasetDict(dataset_dict)

hf_dataset

In [None]:
# Upload the DatasetDict to the Hugging Face Hub under the repo id "multiRC_MDS".
# NOTE(review): this is a network side effect and presumably requires a
# logged-in Hub account — confirm before running the notebook end to end.
hf_dataset.push_to_hub("multiRC_MDS")