In [1]:
import os
# Debug toggle: setting CUDA_LAUNCH_BLOCKING="1" makes CUDA launches synchronous,
# so device-side errors are reported at the call that caused them.
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ.update(TOKENIZERS_PARALLELISM="true")
from transformers import BartForConditionalGeneration,BartTokenizer
import torch

from transformers import pipeline

# BART-large fine-tuned for summarization, placed on the first CUDA device (device=0).
summarizer = pipeline(
    task="summarization",
    model="facebook/bart-large-cnn",
    device=0,
)

In [2]:
def getContextSummary(context, max_len=5000, token_limit=1024):
    """Summarise the leading part of ``context`` with the global ``summarizer`` pipeline.

    Args:
        context: Raw text to summarise.
        max_len: Number of leading *characters* of ``context`` sent to the model;
            also used to scale the requested summary length.
        token_limit: Upper bound (in tokens) on the generated summary length.
            Defaults to 1024, the maximum position count of facebook/bart-large-cnn.

    Returns:
        The generated summary string.
    """
    # max_len counts characters, not tokens: a 5000-char slice can still exceed the
    # model's 1024-token input window, so let the tokenizer truncate instead of crashing.
    max_length = min(max_len // 3 + 150, token_limit)
    # Keep min_length <= max_length even if a caller lowers token_limit below 150.
    min_length = min(150, max_length)
    outputs = summarizer(
        context[:max_len],
        max_length=max_length,
        min_length=min_length,
        do_sample=False,
        truncation=True,
    )
    return outputs[0]['summary_text']

In [1]:
from datasets import load_dataset_builder,load_dataset
# Dataset builder object (not referenced elsewhere in the visible cells).
convAI_data = load_dataset_builder('multidoc2dial')

# Load train / test / validation with the default config, skipping dataset verification.
train_data, test_data, dev_data = (
    load_dataset('multidoc2dial', split=split_name, ignore_verifications=True)
    for split_name in ('train', 'test', 'validation')
)

Downloading builder script:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/7.21k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/52.3k [00:00<?, ?B/s]

No config specified, defaulting to: multidoc2dial/multidoc2dial
No config specified, defaulting to: multidoc2dial/multidoc2dial


Downloading and preparing dataset multidoc2dial/multidoc2dial to /home/nlplab/.cache/huggingface/datasets/multidoc2dial/multidoc2dial/1.0.0/6e2a407c09eb478a5b80845fc04406bb9c9d4bea9f135dd7b8b7610a6b608d84...


Downloading data:   0%|          | 0.00/6.87M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/4201 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/21451 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5 [00:00<?, ? examples/s]

Dataset multidoc2dial downloaded and prepared to /home/nlplab/.cache/huggingface/datasets/multidoc2dial/multidoc2dial/1.0.0/6e2a407c09eb478a5b80845fc04406bb9c9d4bea9f135dd7b8b7610a6b608d84. Subsequent calls will reuse this data.


No config specified, defaulting to: multidoc2dial/multidoc2dial
Found cached dataset multidoc2dial (/home/nlplab/.cache/huggingface/datasets/multidoc2dial/multidoc2dial/1.0.0/6e2a407c09eb478a5b80845fc04406bb9c9d4bea9f135dd7b8b7610a6b608d84)
No config specified, defaulting to: multidoc2dial/multidoc2dial
Found cached dataset multidoc2dial (/home/nlplab/.cache/huggingface/datasets/multidoc2dial/multidoc2dial/1.0.0/6e2a407c09eb478a5b80845fc04406bb9c9d4bea9f135dd7b8b7610a6b608d84)


In [2]:
import re
from mosestokenizer import MosesDetokenizer, MosesSentenceSplitter
import tqdm
from nltk.util import ngrams
from nltk.tokenize import sent_tokenize
import random

# English Moses tools; `splitter` is used by finalOverlap for sentence splitting.
splitter = MosesSentenceSplitter("en")
detokenizer = MosesDetokenizer("en")

# raw_checks and finalOverlap choose among candidate windows with random.choice;
# seed the global RNG so a Restart & Run All reproduces the same processed JSON.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)


def normalize_whitespace(string):
    """Collapse each run of one repeated whitespace character into a single copy.

    Only identical characters are merged: ``"a   b"`` becomes ``"a b"``, whereas a
    mixed run such as a space followed by a tab is left as-is.
    """
    repeated_char = re.compile(r"(\s)\1{1,}")
    return repeated_char.sub(r"\1", string)


def cleanDocument(document):
    """Strip citation markers and list bullets from ``document`` and flatten it to one line.

    Steps:
      1. drop bracketed numeric references such as ``[3]`` or ``[12 ]``;
      2. drop enumeration / bullet prefixes (``1. ``, ``a) ``, ``• ``, ``A. ``, ``IV. ``);
         these patterns are unanchored, so they also match mid-sentence;
      3. turn every whitespace run (including newlines) into a single space and strip.

    Args:
        document: Text to clean.

    Returns:
        The cleaned, single-line text.
    """
    document = re.sub(r"\[\d+\s?\]", "", document)
    document = re.sub(r"(\d\.\s+|[a-z]\)\s+|•\s+|[A-Z]\.\s+|[IVX]+\.\s+)", "", document)
    # Newlines must become a space, not "": otherwise words on adjacent lines are
    # glued together ("Title\nBody" -> "TitleBody").
    document = re.sub(r"\s+", " ", document)
    return document.strip()


def raw_checks(answer, context, n=3):
    """Return a window of ``n`` consecutive paragraphs of ``context`` containing ``answer``.

    ``context`` is split into paragraphs on blank lines; if that yields a single
    chunk, it is split on single newlines instead. Windows are re-joined with the
    same separator, so each window is an exact substring of ``context.strip()``.
    Answers that cross a paragraph break are therefore still found.

    Args:
        answer: Answer span to look for (compared after ``strip()``).
        context: Full document text.
        n: Window size in paragraphs; clamped to the number of paragraphs.

    Returns:
        One matching window (stripped). When several windows match, one is picked
        with ``random.choice``. If no window of ``n`` paragraphs contains the answer
        (e.g. it spans more than ``n`` paragraphs), the whole stripped context is
        returned.

    Raises:
        ValueError: if ``answer`` does not occur in ``context`` at all.
    """
    text = context.strip()
    separator = "\n\n"
    paragraphs = text.split(separator)
    if len(paragraphs) == 1:
        separator = "\n"
        paragraphs = text.split(separator)

    # Clamp so that short documents still produce at least one window.
    n = max(1, min(n, len(paragraphs)))
    needle = answer.strip()
    windows = [separator.join(paragraphs[i:i + n]).strip()
               for i in range(len(paragraphs) - n + 1)]
    matches = [window for window in windows if needle in window]

    if not matches:
        if needle in text:
            return text
        raise ValueError("answer not found in context")
    # Only draw from the RNG when there is an actual choice to make.
    return random.choice(matches) if len(matches) > 1 else matches[0]


def finalOverlap(answer_context, document, window=5):
    """Pick ``window`` consecutive sentences of ``document`` that contain ``answer_context``.

    Sentences come from the Moses ``splitter`` first. If no Moses window contains
    the answer, NLTK's ``sent_tokenize`` is tried; a document with fewer than
    ``window`` NLTK sentences is returned whole. Each window is its sentences
    joined with a single space.

    Args:
        answer_context: Cleaned answer span to look for.
        document: Cleaned document text.
        window: Number of consecutive sentences per candidate (default 5).

    Returns:
        One matching window, chosen with ``random.choice`` when several match. If
        neither splitter yields a matching window (e.g. the answer spans more than
        ``window`` sentences), ``document`` itself is returned instead of raising
        ``IndexError``.
    """
    def _matching_windows(sentences):
        # Space-joined runs of `window` consecutive sentences that contain the answer.
        joined = (" ".join(sentences[i:i + window])
                  for i in range(len(sentences) - window + 1))
        return [candidate.strip() for candidate in joined if answer_context in candidate]

    def nltk_splitter_based():
        context_sentences = list(sent_tokenize(document))
        if len(context_sentences) < window:
            return [document]
        return _matching_windows(context_sentences)

    candidates = _matching_windows(list(splitter([document])))
    if not candidates:
        candidates = nltk_splitter_based()
    if not candidates:
        # The answer straddles more than `window` sentences; keep the whole document.
        return document
    # Only draw from the RNG when there is an actual choice to make.
    return random.choice(candidates) if len(candidates) > 1 else candidates[-1]

In [3]:
from dataclasses import dataclass
from dataclasses_json import dataclass_json

@dataclass_json
@dataclass
class Multidoc2DialData:
    """One processed MultiDoc2Dial turn, as populated by ``processDataset``.

    ``@dataclass_json`` provides ``.schema()``, which the export cell uses to dump
    lists of these records to JSON. Field order is also the positional ``__init__`` order.
    """
    document: str  # full example context passed through cleanDocument
    current_question: str  # part of dat['question'] before "[SEP]" (not stripped)
    conv_history: str  # part of dat['question'] after "[SEP]" (not stripped)
    answer_phrase: str  # first answer text after cleanDocument and strip()
    facts: str  # sentence window around the answer, as returned by finalOverlap
    utterance: str  # dat['utterance'] copied verbatim
    
    
    
    

In [7]:
def processDataset(data_pack):
    """Turn one MultiDoc2Dial split into a list of ``Multidoc2DialData`` records.

    For each example:
      1. ``raw_checks`` selects a paragraph window around the answer;
      2. ``cleanDocument`` cleans that window;
      3. ``finalOverlap`` narrows it to a sentence window.

    Examples whose first answer text does not appear verbatim in the context are
    skipped, and the number skipped is reported.

    Args:
        data_pack: A dataset split whose rows provide ``context``, ``question``
            (current turn and history separated by ``"[SEP]"``), ``answers`` and
            ``utterance`` fields.

    Returns:
        List of ``Multidoc2DialData``, one per kept example, in split order.

    Raises:
        AssertionError: if the cleaned answer is missing from the cleaned window
            ("Error: 1") or from the selected sentence window ("Error: 2").
    """
    not_present = []
    processed_data = []
    # Iterate over `data_pack` itself (not a global split) so every split is built
    # from its own rows.
    for idx, dat in enumerate(tqdm.tqdm(data_pack, desc="Processing: ")):
        document_context = dat['context']
        # partition() tolerates a missing or repeated "[SEP]" instead of failing to unpack.
        question, _, history = dat['question'].partition('[SEP]')
        answer_text = dat['answers']['text'][0]

        if answer_text not in document_context:
            not_present.append(idx)
            continue

        doc_context = raw_checks(answer_text, document_context)
        document = cleanDocument(doc_context)
        answer_context = cleanDocument(answer_text)
        assert answer_context in document, "Error: 1"
        context_sentences = finalOverlap(answer_context, document)
        assert answer_context.strip() in context_sentences.strip(), "Error: 2"
        processed_data.append(Multidoc2DialData(
            document=cleanDocument(document_context),
            current_question=question,
            conv_history=history,
            answer_phrase=answer_context.strip(),
            facts=context_sentences,
            utterance=dat['utterance'],
        ))

    if not_present:
        tqdm.tqdm.write(
            f"Skipped {len(not_present)} of {len(data_pack)} examples: "
            "answer text not found in context."
        )
    return processed_data

In [8]:
# Run the same preprocessing over each split, in train -> test -> dev order.
processed_train, processed_test, processed_dev = (
    processDataset(split) for split in (train_data, test_data, dev_data)
)


Processing: 100%|██████████| 21451/21451 [00:12<00:00, 1660.47it/s]
Processing: 100%|██████████| 4094/4094 [00:02<00:00, 1772.69it/s]
Processing: 100%|██████████| 4201/4201 [00:02<00:00, 1716.65it/s]


In [9]:
import json
# Write each processed split to processed_data/<split>_data.json. Create the output
# directory first so the cell also works on a fresh checkout.
os.makedirs("processed_data", exist_ok=True)
for split_name, records in (("train", processed_train),
                            ("test", processed_test),
                            ("dev", processed_dev)):
    out_path = os.path.join("processed_data", f"{split_name}_data.json")
    with open(out_path, "w", encoding="utf-8") as out_file:
        json.dump(Multidoc2DialData.schema().dump(records, many=True), out_file)