In [1]:
import datasets
from tools.dumps import get_filename_path
from transformers.data.processors import squad
from transformers import AutoTokenizer

# Tokenizer for the "bert-base-uncased" checkpoint. Note that `tokenizer` is
# re-bound further down in this notebook (with use_fast=False), so later cells
# see that second instance, not this one.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Hand the tokenizer to the SQuAD processor module.
# NOTE(review): presumably this stores it as module-level state used by the
# per-example feature conversion - confirm against the transformers source.
squad.squad_convert_example_to_features_init(tokenizer)

In [2]:
# Load the "rc" configuration of the trivia_qa dataset; the cache directory is
# resolved through the project helper get_filename_path.
trivia_ds = datasets.load_dataset("trivia_qa", "rc", cache_dir=get_filename_path("trivia_qa/cache"))

Reusing dataset trivia_qa (/nfs/colla-framework/src/data/trivia_qa/cache/trivia_qa/rc/1.1.0/e734e28133f4d9a353af322aa52b9f266f6f27cbf2f072690a1694e577546b0d)


In [3]:
import nltk

from io import StringIO
import pandas as pd

# Make sure the Punkt data is available locally before loading the model below.
nltk.download('punkt')
# English Punkt sentence splitter; split_join_context() calls its .tokenize()
# to break each paragraph into sentences.
sent_tokenize = nltk.data.load('tokenizers/punkt/english.pickle')

# Re-bind `tokenizer` to an instance loaded with use_fast=False. This replaces
# the tokenizer created in the first cell for every cell executed after this one.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased", use_fast=False)


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
# Token budget read by split_join_context(): it stops collecting word tokens as
# soon as its running token list reaches this length.
max_num_tokens = 800

def split_join_context(text: str):
    """
    Re-tokenise a text so that its tokens are separated by single spaces.

    Copycat of triviaqa's select_relevant_portion(): the text is split into
    paragraphs on newlines, each paragraph into sentences, each sentence into
    word tokens. Collection stops once ``max_num_tokens`` entries were gathered.

    :param text the raw text, paragraphs separated by newlines
    :returns the collected tokens joined by spaces, with surrounding whitespace stripped
    """
    # We assume that answers are not made of multiple sentences
    def collect_tokens():
        tokens = []
        for paragraph in text.split('\n'):
            for sentence in sent_tokenize.tokenize(paragraph):
                for token in nltk.word_tokenize(sentence):
                    tokens.append(token)
                    if len(tokens) >= max_num_tokens:
                        # Budget reached: stop right here, without closing the paragraph.
                        return tokens
            # Paragraph boundary marker. It counts towards the budget, but the
            # budget is only checked again after the next word is appended.
            tokens.append('\n')
        return tokens

    return ' '.join(collect_tokens()).strip()

def encode(examples):
    """
    Convert a batch of TriviaQA rows into SQuAD-style features.

    The context of an example is built from its wiki pages: every page is passed
    through split_join_context(), the pages are concatenated (each followed by a
    space) and the concatenation is passed through split_join_context() once more.
    The answer position is looked up with ``str.find``; examples whose answer is
    not found in the context are flagged as impossible.

    :param examples a batch with the columns "question_id", "question",
        "entity_pages" (each entry holding "wiki_context") and "answer"
        (each entry holding "value")
    :returns a dict mapping column name to a list with one entry per example:
        "id", "answers", "context", "question" plus the retained feature
        attributes; a feature column holds a list per example, or ``[]`` when no
        feature came back for that example
    """
    qas_id = examples["question_id"]
    question_text = examples["question"]
    wiki_pages = [pages["wiki_context"] for pages in examples["entity_pages"]]
    answers = [entry["value"] for entry in examples["answer"]]

    # Zipping with the answers keeps the number of contexts bounded by the number
    # of answers; the answer itself is not needed while building the context.
    context_text = []
    for pages, _ in zip(wiki_pages, answers):
        joined_pages = "".join(split_join_context(page) + " " for page in pages)
        context_text.append(split_join_context(joined_pages))

    answers_start = [ctx.find(ans) for ctx, ans in zip(context_text, answers)]
    answers_block = [
        {"text": [ans], "answer_start": [pos]}
        for ans, pos in zip(answers, answers_start)
    ]
    # str.find returns -1 when the answer does not occur in the context.
    is_impossible = [pos == -1 for pos in answers_start]

    squad_examples = [
        # The question id doubles as the title.
        squad.SquadExample(qid, question, ctx, ans, pos, qid, block, impossible)
        for qid, question, ctx, ans, pos, block, impossible in zip(
            qas_id,
            question_text,
            context_text,
            answers,
            answers_start,
            answers_block,
            is_impossible,
        )
    ]

    features = squad.squad_convert_examples_to_features(
        squad_examples,
        tokenizer,
        512,   # max_seq_length
        128,   # doc_stride
        384,   # max_query_length
        True   # is_training
    )

    per_example = pd.DataFrame({
        "id": qas_id,
        "answers": answers,
        "context": context_text,
        "question": question_text
    }).set_index("id")

    unused_columns = ["example_index", "token_is_max_context",
                      "encoding", "token_to_orig_map", "cls_index", "paragraph_len",
                      "unique_id"]
    per_feature = pd.DataFrame([vars(feature) for feature in features]).drop(
        unused_columns, axis=1)

    # Collect all features belonging to one question into lists.
    per_question = per_feature.groupby(
        "qas_id", as_index=False).agg(list).set_index("qas_id")

    # the HF processor silently removes problematic rows. In this case we have no choice but
    # to add singletons for now and filter them out later
    merged = per_example.join(per_question, how="left").reset_index()

    encoded = {}
    for column, cells in merged.to_dict().items():
        # A cell that is not equal to itself is NaN, i.e. the left join found no
        # features for that example; it becomes an empty list.
        encoded[column] = [[] if cell != cell else cell for cell in cells.values()]

    return encoded

# map() returns the transformed dataset instead of changing trivia_ds, so the
# result is bound to a name (like the other map() cells in this notebook) rather
# than being reachable only through the cell's Out[] history. The trailing bare
# expression keeps the cell's displayed output.
trivia_squad_features = trivia_ds.map(encode, batched=True)
trivia_squad_features

convert squad examples to features: 100%|██████████| 2/2 [00:00<00:00, 31.66it/s]
add example index and unique id: 100%|██████████| 2/2 [00:00<00:00, 18517.90it/s]


HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))




KeyboardInterrupt: 

In [43]:
# Request the fast (Rust-backed) tokenizer explicitly. The loading option is
# `use_fast`; the previously passed `is_fast=True` is not an option of
# AutoTokenizer.from_pretrained, so it did not select the implementation.
fast_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

In [135]:
from copy import deepcopy
import numpy as np

def encode_closed_book(
    batch,
    is_question: bool,
    dropout_rate: float
):
    """
    Build masked inputs and their unmasked targets for closed-book training.

    Question mode (``is_question=True``): every question is paired with its
    answer. "input_ids" holds the question paired with one mask token per answer
    token, "output_ids" holds the same pair with the real answer token ids. Both
    are padded to a length of 512 (no truncation is requested). ``dropout_rate``
    is not used in this mode.

    Context mode (``is_question=False``): the wiki pages of every example are
    concatenated (each page followed by a space) and tokenized with
    ``fast_tokenizer`` without special tokens. "output_ids" holds the unmasked
    ids; in "input_ids" each token is independently replaced by the mask token
    with probability ``dropout_rate``, drawn from NumPy's global random state.

    :param batch a batch with the columns "question" and "answer" (entries
        holding "value") in question mode, or "entity_pages" (entries holding
        "wiki_context") in context mode
    :param is_question selects question mode (True) or context mode (False)
    :param dropout_rate probability of masking a context token
    :returns the tokenizer's batch encoding with an additional "output_ids" entry
    """
    if is_question:
        questions = batch["question"]
        answers = [b["value"] for b in batch["answer"]]

        answer_ids = tokenizer.batch_encode_plus(
            answers, add_special_tokens=False)["input_ids"]

        masked_answer_ids = [[tokenizer.mask_token_id] * len(ids)
                             for ids in answer_ids]

        # Targets: encode (question, answer ids) pairs, mirroring the masked call
        # below. The answers used to be passed as the second positional argument
        # of batch_encode_plus, which is `add_special_tokens`, so they were never
        # encoded and "output_ids" did not line up with the masked "input_ids".
        # NOTE(review): passing pre-computed id lists as the second pair element
        # relies on the Python (slow) tokenizer path - confirm `tokenizer` is the
        # use_fast=False instance when this runs.
        tokenized = tokenizer.batch_encode_plus(
            list(zip(questions, answer_ids)),
            padding="max_length",
            max_length=512
        )

        tokenized_masked = tokenizer.batch_encode_plus(
            list(zip(questions, masked_answer_ids)),
            padding="max_length",
            max_length=512
        )

        tokenized_masked["output_ids"] = tokenized["input_ids"]

        return tokenized_masked

    else:
        context = [entity_page["wiki_context"] for entity_page in batch["entity_pages"]]
        # One string per example: every page followed by a single space.
        context_text = ["".join(page + " " for page in pages) for pages in context]

        tokenized = fast_tokenizer.batch_encode_plus(
            context_text,
            add_special_tokens=False,
        )
        unmasked_input_ids = tokenized["input_ids"]
        # No copy needed: the unmasked lists are never modified below (masking
        # works on fresh arrays) and "input_ids" is re-bound to new lists.
        tokenized["output_ids"] = unmasked_input_ids

        masked_sentences = []

        for sentence in unmasked_input_ids:
            masked_sentence = np.array(sentence)
            # Boolean mask: True (= replace by the mask token) with probability dropout_rate.
            toks_to_mask = np.random.choice([True, False], len(sentence),
                                            p=(dropout_rate, 1.0 - dropout_rate))
            masked_sentence[toks_to_mask] = fast_tokenizer.mask_token_id
            masked_sentences.append(masked_sentence.tolist())

        tokenized["input_ids"] = masked_sentences

        return tokenized

def encode_closed_book_question(batch):
    """
    Run encode_closed_book() in question mode on a batch.

    :param batch the batch handed over by the dataset's map()
    :returns whatever encode_closed_book() returns for question mode
    """
    return encode_closed_book(batch, is_question=True, dropout_rate=0.0)

def encode_closed_book_context(batch, dropout_rate=0.15):
    """
    Run encode_closed_book() in context mode on a batch.

    :param batch the batch handed over by the dataset's map()
    :param dropout_rate probability of masking a token, forwarded unchanged
    :returns whatever encode_closed_book() returns for context mode
    """
    return encode_closed_book(batch, is_question=False, dropout_rate=dropout_rate)

In [None]:
# Masked wiki contexts for closed-book training, using the wrapper's default
# dropout_rate of 0.15. The masking draws from NumPy's global random state.
# NOTE(review): no seed is set in the visible part of this notebook, so the
# masks will differ between runs - confirm whether that is intended.
contexts_closed_book = trivia_ds.map(encode_closed_book_context, batched=True)

HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))

In [47]:
# Question/answer pairs with the answer tokens masked, encoded batch-wise via
# encode_closed_book_question.
questions_closed_book = trivia_ds.map(encode_closed_book_question, batched=True)

HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))




KeyboardInterrupt: 