# Bart model

In [1]:
import math
import random

import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BartForConditionalGeneration,
    BartModel,
    BartTokenizer,
)

In [2]:
# Load the pretrained BART-large tokenizer and conditional-generation model.
# .eval() puts the model in evaluation mode since it is only used for inference here.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large', return_dict=True).eval()

In [25]:
tokenizer.mask_token

'<mask>'

In [107]:
# Single forward pass on a sentence that contains one <mask> token.
inputs = tokenizer("I've met <mask> is cute", return_tensors="pt")
outputs = model(**inputs)

# Inspect which fields the model output exposes.
outputs.keys()

odict_keys(['logits', 'encoder_last_hidden_state'])

In [108]:
# NOTE: despite the name, `probas` holds the raw logits (no softmax is applied here);
# squeeze() drops the size-1 axes, i.e. the batch axis for this single input.
probas = outputs.logits.squeeze()
# Highest-scoring token id at every position.
tokens = probas.argmax(1)
# Decode everything after the first predicted position, dropping special tokens.
tokenizer.decode(tokens[1:], skip_special_tokens=True)

"I've met him is cute"

In [134]:
def decode(masked_text: str, model, tokenizer):
    """Fill ``<mask>`` tokens using one forward pass and a per-position argmax.

    This is the naive approach: every output position independently takes its
    highest-scoring token, so a single ``<mask>`` can only ever be replaced by
    a single token.

    :param masked_text: text containing the tokenizer's mask token
    :param model: callable returning an object with a ``logits`` tensor of
        shape (batch, sequence, vocabulary) when called with the tokenizer output
    :param tokenizer: callable producing PyTorch tensors for ``model`` and
        offering ``decode(ids, skip_special_tokens=...)``
    :return: decoded text for all predicted positions after the first one
    """
    inputs = tokenizer(masked_text, return_tensors="pt")
    # Inference only: do not record the autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)
    # Index the batch axis explicitly; a bare squeeze() would also remove a
    # length-1 sequence axis and make argmax(1) fail.
    logits = outputs.logits[0]
    tokens = logits.argmax(1)
    # The first predicted position is skipped before decoding.
    return tokenizer.decode(tokens[1:], skip_special_tokens=True)

In [135]:
# Use the first review of the Yelp polarity training split as sample text.
yelp = load_dataset("yelp_polarity", split="train")
text = yelp[0]["text"]
text

Reusing dataset yelp_polarity (/home/przemyslaw/.cache/huggingface/datasets/yelp_polarity/plain_text/1.0.0/2b33212d89209ed1ea0522001bccc5f5a5c920dd9c326f3c828e67a22c51a98c)


"Unfortunately, the frustration of being Dr. Goldberg's patient is a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff.  It seems that his staff simply never answers the phone.  It usually takes 2 hours of repeated calling to get an answer.  Who has time for that or wants to deal with it?  I have run into this problem with many other doctors and I just don't get it.  You have office workers, you have patients with medical needs, why isn't anyone answering the phone?  It's incomprehensible and not work the aggravation.  It's with regret that I feel that I have to give Dr. Goldberg 2 stars."

In [181]:
text = "Unfortunately, <mask> Dr. Goldberg's patient is a <mask> of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff."

In [182]:
decode(text, model, tokenizer)

torch.Size([1, 33, 50265])


"Unfortunately, the experience. Goldberg's patient is a womanch the experience I've had with so many other doctors in NYC -- good doctor, terrible staff. I"

#### Better way

In [183]:
inputs = tokenizer(text, max_length=1024, return_tensors='pt')

# Generate the infilled sentence (this is mask filling, not summarization) with a 2-beam search.
summary_ids = model.generate(inputs['input_ids'], num_beams=2, max_length=100, early_stopping=True)[0]
# The first two generated ids are dropped before decoding.
tokenizer.decode(summary_ids[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)

"Unfortunately, the experience with Dr. Goldberg reminds me of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff."

In [178]:
inputs['input_ids']

tensor([[    0, 16861,     6,     5, 50264,   925,     4, 18835,    18,  3186,
            16,    10, 50264,     9,     5,   676,    38,   348,    56,    19,
            98,   171,    97,  3333,    11, 14415,   480,   205,  3299,     6,
          6587,   813,     4,     2]])

In [184]:
tokenizer.decode(inputs['input_ids'][0])

"<s>Unfortunately,<mask> Dr. Goldberg's patient is a<mask> of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff.</s>"

In [166]:
summary_ids

tensor([[    2, 16861, 16861,     6,     5,   676,     9,   925,     4, 18835,
            18,  3186,    16,    10,  5177, 16254,   119,     9,     5,   676,
            38,   348,    56,    19,    98,   171,    97,  3333,    11, 14415,
           480,   205,  3299,     6,  6587,   813,     4,     2]])

## Noising 

In [185]:
# Reference only: print the argument description of fairseq's DenoisingDataset.
# NOTE(review): DenoisingDataset is imported but not otherwise used in the visible cells.
from fairseq.data import DenoisingDataset
print(
"""
A wrapper around TokenBlockDataset for BART dataset.
Args:
    dataset (TokenBlockDataset): dataset to wrap
    sizes (List[int]): sentence lengths
    vocab (~fairseq.data.Dictionary): vocabulary
    mask_idx (int): dictionary index used for masked token
    mask_whole_words: only mask whole words. This should be a byte mask
        over vocab indices, indicating whether it is the beginning of a
        word. We will extend any mask to encompass the whole word.
    shuffle (bool, optional): shuffle the elements before batching.
      Default: ``True``
    seed: Seed for random number generator for reproducibility.
    args: argparse arguments.
"""
     )


A wrapper around TokenBlockDataset for BART dataset.
Args:
    dataset (TokenBlockDataset): dataset to wrap
    sizes (List[int]): sentence lengths
    vocab (~fairseq.data.Dictionary): vocabulary
    mask_idx (int): dictionary index used for masked token
    mask_whole_words: only mask whole words. This should be a byte mask
        over vocab indices, indicating whether it is the beginning of a
        word. We will extend any mask to encompass the whole word.
    shuffle (bool, optional): shuffle the elements before batching.
      Default: ``True``
    seed: Seed for random number generator for reproducibility.
    args: argparse arguments.



### How to noise the dataset?
According to the original paper, text infilling works best. That means we pick random positions in the text, sample a span length for each of them from a Poisson distribution ($\lambda = 3$ in the paper), and replace each span with a single `<mask>` token. The open questions are how to choose the start positions, how many of them there should be, etc.

In [189]:
text = "Unfortunately, the frustration of being Dr. Goldberg's patient is a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff."
text

"Unfortunately, the frustration of being Dr. Goldberg's patient is a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff."

#### Establish how many places to choose for masking
On average, we want a fixed fraction of the words to be masked. For now, we assume a Poisson(2) distribution for the number of words in each masked span (note that the cell below actually uses `lambda_ = 2.5`).

In [293]:
round(1.4)

1

In [352]:
# Poisson rate, i.e. the mean number of words replaced by one <mask>.
lambda_ = 2.5

words = np.array(text.split(), dtype='object')
# Target share of words to mask.
fraction = 0.2
n_mask = int(fraction*len(words))
# Each mask location covers lambda_ words on average, so about n_mask / lambda_
# locations are needed (and at least one).
n_places = max(1, round(n_mask / lambda_))
n_places
print(f"{len(words)} words, {n_mask} to mask ({fraction} fraction). {n_places} places to insert <mask>")

28 words, 5 to mask (0.2 fraction). 2 places to insert <mask>


#### Choose words to mask

In [413]:
# One span length per mask location; a length of 0 is possible and means
# "insert <mask> without removing any word".
lengths = np.random.poisson(lambda_, size=n_places)

# Distinct span start positions drawn from the whole sentence
# (previously hard-coded to the first 10 words regardless of sentence length).
places = np.sort(np.random.choice(len(words), size=n_places, replace=False))

# ends: span start -> index of the first word after the span.
ends = {start: start + length for start, length in zip(places, lengths)}
# to_mask: every word index covered by some span (may exceed len(words) near the end).
to_mask = {start + i for start, length in zip(places, lengths) for i in range(length)}

In [417]:
print(" ".join(words))

# Replace every sampled span with a single mask token.
# Every index in `to_mask` is dropped, and consecutive/overlapping spans collapse
# into one mask. The previous pointer-jumping loop emitted the first word of a
# span that started right after (or inside) another span even though that word
# was selected for masking.
masked_words = list()
for i, word in enumerate(words):
    needs_mask = len(masked_words) == 0 or masked_words[-1] != tokenizer.mask_token
    if i in to_mask:
        # Word is covered by a span: emit the mask once per run of masked words.
        if needs_mask:
            masked_words.append(tokenizer.mask_token)
    else:
        # A zero-length span starting here inserts a mask in front of the kept word.
        if i in ends and needs_mask:
            masked_words.append(tokenizer.mask_token)
        masked_words.append(word)

masked_text = " ".join(masked_words)
masked_text

Unfortunately, the frustration of being Dr. Goldberg's patient is a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff.


"Unfortunately, <mask> frustration of being Dr. Goldberg's <mask> a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff."

In [420]:
masked_text

"Unfortunately, <mask> frustration of being Dr. Goldberg's <mask> a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff."

In [430]:
inputs = tokenizer(masked_text, max_length=1024, return_tensors='pt')

# Regenerate the full sentence from the masked text (mask infilling, not summarization) with a 2-beam search.
summary_ids = model.generate(inputs['input_ids'], num_beams=2, max_length=512, early_stopping=True)[0]
# The first two generated ids are dropped before decoding.
generated_text = tokenizer.decode(summary_ids[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)

Because Facebook researchers sometimes shuffled sentences during pretraining, we see that BART sometimes shuffles them too, even though it wasn't really necessary here. Oh well.

In [433]:
inputs = tokenizer(masked_text, max_length=1024, return_tensors='pt')

# Same mask infilling as before, but with num_beams=1 instead of a 2-beam search.
summary_ids = model.generate(inputs['input_ids'], num_beams=1, max_length=512, early_stopping=True)[0]
# The first two generated ids are dropped before decoding.
generated_text = tokenizer.decode(summary_ids[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
# Check whether the model reproduced the original sentence exactly.
print(generated_text == text)
generated_text

True


"Unfortunately, the frustration of being Dr. Goldberg's patient is a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff."

Also, it seems that BART has seen our examples during pretraining. Otherwise, how would it be able to reconstruct the sentence so perfectly?

In [437]:
np.random.choice(10)

3

In [120]:
class BartAugmenter:
    """Text augmentation by BART text infilling.

    Random word spans of the input are replaced with the tokenizer's mask token
    and the seq2seq model regenerates the whole sentence from the masked text.
    Span positions and lengths are drawn from ``np.random``, so seed NumPy for
    reproducible masking.
    """

    def __init__(self, model=None, tokenizer=None, fraction: float = 0.2, min_mask: int = 1, max_mask: int = 100,
                 lambda_: float = 2.5, num_beams: int = 1, device=None):
        """
        :param model: huggingface/transformers seq2seq model
            e.g model = BartForConditionalGeneration.from_pretrained('facebook/bart-large', return_dict=True)
            (defaults to 'facebook/bart-large')
        :param tokenizer: huggingface/transformers tokenizer matching ``model``
            e.g tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
            (defaults to the 'facebook/bart-large' tokenizer)
        :param fraction: target fraction of words to mask; 0 disables augmentation
        :param min_mask: lower bound on the target number of masked words
        :param max_mask: upper bound on the target number of masked words
        :param lambda_: Poisson rate of the span length, i.e. the mean number of words
            replaced by one mask token
        :param num_beams: number of beams used for generation (1 = no beam search)
        :param device: torch.device; defaults to CUDA when available, otherwise CPU
        """
        # Fall back to CPU instead of failing on machines without CUDA.
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if model is None:
            model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large', return_dict=True)
        self.model = model.eval().to(self.device)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large', use_fast=False)
        self.tokenizer = tokenizer
        self.mask_token = tokenizer.mask_token
        self.min_mask = min_mask
        self.max_mask = max_mask
        self.fraction = fraction
        self.lambda_ = lambda_
        self.num_beams = num_beams

    def _mask_words(self, words):
        """Return ``words`` joined by spaces with random spans replaced by the mask token."""
        n_mask = max(self.min_mask, round(len(words) * self.fraction))
        n_mask = min(n_mask, self.max_mask)
        # Span starts are limited to a prefix of the text; the offset leaves room
        # because the length might increase after tokenization.
        max_masked_idx = max(1, min(self.tokenizer.model_max_length - 50, len(words)))
        # Never request more distinct start positions than are available.
        n_places = min(max_masked_idx, max(1, round(n_mask / self.lambda_)))

        places = np.sort(np.random.choice(max_masked_idx, size=n_places, replace=False))
        lengths = np.random.poisson(self.lambda_, size=n_places)
        starts = {int(start) for start in places}
        # Every word index covered by some span (a span may have length 0).
        to_mask = {int(start) + offset for start, length in zip(places, lengths) for offset in range(int(length))}

        masked_words = []
        for idx, word in enumerate(words):
            needs_mask = not masked_words or masked_words[-1] != self.mask_token
            if idx in to_mask:
                # Covered word: emit one mask per run of masked words, so adjacent
                # or overlapping spans collapse and no selected word leaks through.
                if needs_mask:
                    masked_words.append(self.mask_token)
            else:
                # A zero-length span inserts a mask in front of the kept word.
                if idx in starts and needs_mask:
                    masked_words.append(self.mask_token)
                masked_words.append(word)
        return " ".join(masked_words)

    def __call__(self, text: str):
        """Mask random spans of ``text`` and return the model's reconstruction.

        :param text: whitespace-separated input text
        :return: generated text; ``text`` itself when ``fraction`` is 0 or the
            text contains no words
        """
        if self.fraction == 0:
            return text

        words = text.split()
        if not words:
            # Nothing to mask; sampling positions from an empty range would raise.
            return text

        masked_text = self._mask_words(words)
        # Use the instance's own tokenizer/model (not notebook globals) and keep
        # the input within the 1024-token limit.
        inputs = self.tokenizer(masked_text, max_length=1024, truncation=True, return_tensors='pt')
        input_ids = inputs['input_ids'].to(self.device)

        # Generate seq2seq output
        with torch.no_grad():
            summary_ids = self.model.generate(
                input_ids,
                num_beams=self.num_beams,
                max_length=512,
                # early_stopping only applies to beam search.
                early_stopping=self.num_beams > 1,
            )[0]
        # The first two generated ids are dropped before decoding.
        generated_text = self.tokenizer.decode(
            summary_ids[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return generated_text

In [121]:
augmenter = BartAugmenter(model, tokenizer, device=torch.device('cpu'))

In [123]:
augmenter('I love you so much')

'I love much'