### References

BERT LM prediction
https://github.com/huggingface/transformers/blob/master/docs/source/quickstart.md#bert-example

Masking script
https://github.com/huggingface/pytorch-pretrained-BERT/blob/f9cde97b313c3218e1b29ea73a42414dfefadb40/examples/lm_finetuning/simple_lm_finetuning.py#L276-L301

In [1]:
import random
import torch
import copy
from transformers import BertTokenizer, BertForMaskedLM


In [2]:
from typing import List, Union

In [3]:
def mask_sentence_bertstyle(tokens, tokenizer):
    """
    Mask random tokens for the Language Model task with the probabilities of the original BERT paper.

    Every token is selected with 15% probability. A selected token is replaced by
    "[MASK]" 80% of the time, by a random vocabulary token 10% of the time, and is
    left unchanged the remaining 10% of the time. The structural markers "[CLS]"
    and "[SEP]" are never selected, so sentence boundaries survive masking even
    when the caller passes a sequence that already contains them.

    Note: ``tokens`` is modified in place and that same list object is returned.

    :param tokens: list of str, tokenized sentence (may already include "[CLS]"/"[SEP]").
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction.
        A label is the vocab id of the ORIGINAL token at a selected position and -1
        at every position that was not selected.
    """
    # Boundary markers must stay intact; masking them would destroy the input layout.
    protected_tokens = ("[CLS]", "[SEP]")
    # Built lazily (and only once) the first time a random replacement is needed.
    vocab_tokens = None
    output_label = []

    for i, token in enumerate(tokens):
        if token in protected_tokens:
            output_label.append(-1)
            continue

        prob = random.random()
        if prob >= 0.15:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)
            continue

        # Rescale to [0, 1) so the 80/10/10 split below can reuse the same draw.
        prob /= 0.15

        if prob < 0.8:
            # 80%: replace with the mask token
            tokens[i] = "[MASK]"
        elif prob < 0.9:
            # 10%: replace with a random vocabulary token
            if vocab_tokens is None:
                vocab_tokens = list(tokenizer.vocab)
            tokens[i] = random.choice(vocab_tokens)
        # remaining 10%: keep the current token

        # Record the id of the original token (this is what we will predict later).
        try:
            output_label.append(tokenizer.vocab[token])
        except KeyError:
            # For unknown words (should not occur with BPE vocab)
            output_label.append(tokenizer.vocab["[UNK]"])
            print("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))

    return tokens, output_label



In [4]:
class BERT_processor:
    """
    Small wrapper around a pre-trained BERT masked-LM and its tokenizer.

    ``prepare_input`` tokenizes and randomly masks a sentence; ``predict`` runs
    the model on the masked sentence and prints the original, masked and
    predicted tokens side by side.
    """

    def __init__(self, modeltype='bert-base-uncased'):
        """
        Load the tokenizer and the masked-LM weights and switch the model to eval mode.

        :param modeltype: name of the pre-trained model passed to ``from_pretrained``.
        """
        # Load pre-trained model tokenizer (vocabulary)
        self.tokenizer = BertTokenizer.from_pretrained(modeltype)
        # Load pre-trained model (weights)
        self.model = BertForMaskedLM.from_pretrained(modeltype)
        self.model.eval()

    def prepare_input(self, modelname: str, text: str):
        """
        Tokenize ``text``, wrap it in [CLS]/[SEP], mask it and build the model inputs.

        :param modelname: model family selector; only "BERT" is supported.
        :param text: raw sentence to process.
        :return: tuple ``(tokenized_text, masktokenized_text, input_tensor,
            mask_labels, segments_tensors)`` where the two tensors hold a single
            sequence each (the id lists are wrapped in an outer list).
        :raises ValueError: if ``modelname`` is not "BERT".
        """
        if modelname != "BERT":
            # Fail loudly instead of implicitly returning None, which callers
            # would only notice as an obscure tuple-unpacking error.
            raise ValueError(
                f"Unsupported modelname {modelname!r}; only 'BERT' is supported."
            )

        # Tokenize input
        tokenized_text = ['[CLS]'] + self.tokenizer.tokenize(text) + ['[SEP]']

        # The masking helper edits its argument in place, so give it a copy and
        # keep ``tokenized_text`` as the untouched reference.
        masktokenized_text, mask_labels = mask_sentence_bertstyle(
            list(tokenized_text), self.tokenizer
        )

        # Convert token to vocabulary indices
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(masktokenized_text)

        # Single-sentence input: every position belongs to segment 0.
        segments_ids = [0] * len(indexed_tokens)

        # Convert inputs to PyTorch tensors
        input_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])

        return (tokenized_text, masktokenized_text, input_tensor,
                mask_labels, segments_tensors)

    def predict(self, tokenized_text, masktokenized_text,
                input_tensor, mask_labels, segments_tensors=None):
        """
        Run the masked-LM on ``input_tensor`` and print predictions for the masked positions.

        :param tokenized_text: list of str, the original (unmasked) tokens.
        :param masktokenized_text: list of str, the masked tokens; accepted for
            interface compatibility but not read here (the masked sentence is
            recovered from ``input_tensor``).
        :param input_tensor: tensor of token ids; only the first sequence is used.
        :param mask_labels: list of int, > -1 at positions that were selected for masking.
        :param segments_tensors: optional tensor of segment ids passed as ``token_type_ids``.
        :return: None; results are printed.
        """
        input_sentence = self.tokenizer.convert_ids_to_tokens(input_tensor[0].tolist())
        print(tokenized_text)
        print(input_sentence)

        # Use the GPU when one is available, otherwise fall back to the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        input_tensor = input_tensor.to(device)
        if segments_tensors is not None:
            segments_tensors = segments_tensors.to(device)

        self.model.to(device)

        # Predict all tokens
        with torch.no_grad():
            if segments_tensors is not None:
                outputs = self.model(input_tensor, token_type_ids=segments_tensors)
            else:
                outputs = self.model(input_tensor)

            # The first output is indexed below as [sequence, position] -> scores.
            predictions = outputs[0]

        # Positions whose label marks them as selected for masking.
        mask_positions = [idx for idx, val in enumerate(mask_labels) if val > -1]
        print(mask_positions)

        # Set lookup keeps the filter linear in the sentence length.
        masked_lookup = set(mask_positions)
        actual_words = [tok for idx, tok in enumerate(tokenized_text) if idx in masked_lookup]
        fake_words = [input_sentence[mp] for mp in mask_positions]

        for mask_position in mask_positions:
            # Overwrite each masked slot with the model's top-scoring token.
            predicted_index = torch.argmax(predictions[0, mask_position]).item()
            predicted_token = self.tokenizer.convert_ids_to_tokens([predicted_index])[0]
            input_sentence[mask_position] = predicted_token

        print(input_sentence)

        print("Actual one is ",actual_words)
        print("Fake one is  ",fake_words)
        predicted_words = [input_sentence[mp] for mp in mask_positions]
        print("Predicted one is  ",predicted_words)

In [5]:
# Build the processor once; construction loads the tokenizer and the masked-LM weights.
processor = BERT_processor(modeltype='bert-base-uncased')

In [10]:
# Adjacent string literals are concatenated, so the sentence stays on several
# source lines without the line indentation leaking into the text (which is
# what a backslash continuation inside the quotes does).
text = (
    'The White House statement said that these changes would help '
    'protect the salaries of American workers and ensure that foreign '
    'labour coming into the US is high-skilled and do not undercut '
    'the United States labour market.'
)
# Tokenize, mask and tensorize the sentence for the BERT masked-LM.
(tokenized_text, masktokenized_text, input_tensor,
 mask_labels, segments_tensors) = processor.prepare_input('BERT', text)

In [11]:
# Run the model on the masked sentence and print original / masked / predicted tokens.
processor.predict(
    tokenized_text,
    masktokenized_text,
    input_tensor,
    mask_labels,
    segments_tensors,
)

['[CLS]', 'the', 'white', 'house', 'statement', 'said', 'that', 'these', 'changes', 'would', 'help', 'protect', 'the', 'salaries', 'of', 'american', 'workers', 'and', 'ensure', 'that', 'foreign', 'labour', 'coming', 'into', 'the', 'us', 'is', 'high', '-', 'skilled', 'and', 'do', 'not', 'under', '##cut', 'the', 'united', 'states', 'labour', 'market', '.', '[SEP]']
['[CLS]', 'the', 'white', 'house', 'statement', 'said', 'that', 'these', 'changes', 'would', 'help', 'protect', '[MASK]', 'salaries', 'of', '[MASK]', 'workers', 'and', 'ensure', '[MASK]', 'foreign', 'labour', 'coming', 'into', 'the', 'us', 'is', 'high', '-', '[MASK]', '[MASK]', '[MASK]', '[MASK]', 'under', '##cut', 'the', 'united', 'states', 'labour', 'market', '.', '[SEP]']
[3, 12, 15, 19, 20, 29, 30, 31, 32]
['[CLS]', 'the', 'white', 'house', 'statement', 'said', 'that', 'these', 'changes', 'would', 'help', 'protect', 'the', 'salaries', 'of', 'foreign', 'workers', 'and', 'ensure', 'that', 'foreign', 'labour', 'coming', 'into