# Estimate sentence probability with BERT
## Adapted from https://github.com/huggingface/transformers/issues/37 and updated for the current `transformers` API

In [6]:
import numpy as np
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Load the pre-trained masked-language model (weights) and switch it to
# evaluation mode (disables dropout for inference).
# NOTE: loading is deliberately NOT wrapped in `torch.no_grad()`. That context
# manager only affects tensor operations executed inside it, so wrapping the
# loading code would not disable gradient tracking for forward passes that
# run afterwards.
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

# Load the matching pre-trained tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [64]:
# Special token strings used when scoring a sentence.
BOS_TOKEN = "[CLS]"    # expected as the first token of the sequence
EOS_TOKEN = "[SEP]"    # expected as the last token of the sequence
MASK_TOKEN = "[MASK]"  # substituted for the token currently being scored

def score(sentence):
    """Estimate a sentence's probability with the masked language model.

    Every non-special token is replaced in turn by ``MASK_TOKEN``; the
    softmax probability the model assigns to the original token at that
    position is read off, and the per-token probabilities are multiplied.

    The tokenized sentence, its ids, and one line per scored token are
    printed as a side effect.

    Args:
        sentence: Text to score. ``BOS_TOKEN`` / ``EOS_TOKEN`` are added when
            the tokenization does not already start / end with them.

    Returns:
        A 0-dim ``torch.Tensor`` holding the product of the per-token
        probabilities. It is computed under ``torch.no_grad()``, so it
        carries no autograd graph. When there is nothing to score (e.g. an
        empty sentence) the result is ``tensor(1.)``.
    """
    tokenized_input = tokenizer.tokenize(sentence)
    # `not tokenized_input` covers the empty-sentence case, where indexing
    # [0] would otherwise raise IndexError.
    if not tokenized_input or tokenized_input[0] != BOS_TOKEN:
        tokenized_input.insert(0, BOS_TOKEN)
    if tokenized_input[-1] != EOS_TOKEN:
        tokenized_input.append(EOS_TOKEN)
    ids_input = tokenizer.convert_tokens_to_ids(tokenized_input)
    print(f"Processing sentence: {tokenized_input}")
    print(f"Sentence ids: {ids_input}")

    # Loop-invariant: look the mask id up once instead of re-converting the
    # whole token list on every iteration.
    mask_id = tokenizer.convert_tokens_to_ids(MASK_TOKEN)

    # Start from a tensor so the return type is the same even when the loop
    # below never runs.
    sent_prob = torch.tensor(1.0)

    # Inference only: skip autograd bookkeeping so no graph is built or kept
    # alive by the returned value.
    with torch.no_grad():
        # Mask each non-special token (skip first and last positions) and
        # read the probability of the original token at that position.
        for i in range(1, len(ids_input) - 1):
            masked_ids = list(ids_input)
            masked_ids[i] = mask_id
            outputs = model(torch.tensor([masked_ids]))
            predictions = outputs[0]  # first output element holds the logits
            current_probs = torch.softmax(predictions[0, i], dim=0)
            current_prob = current_probs[ids_input[i]]  # prob. of the masked word
            sent_prob = sent_prob * current_prob

            print(f"Word: {tokenized_input[i]} \t Prob: {current_prob}")

    return sent_prob

In [65]:
score("I fed my cat some of it and he damn near passed out")

Processing sentence: ['[CLS]', 'i', 'fed', 'my', 'cat', 'some', 'of', 'it', 'and', 'he', 'damn', 'near', 'passed', 'out', '[SEP]']
Sentence ids: [101, 1045, 7349, 2026, 4937, 2070, 1997, 2009, 1998, 2002, 4365, 2379, 2979, 2041, 102]
Word: i 	 Prob: 0.9639779925346375
Word: fed 	 Prob: 0.13039857149124146
Word: my 	 Prob: 0.06296561658382416
Word: cat 	 Prob: 0.0026130476035177708
Word: some 	 Prob: 0.7925142645835876
Word: of 	 Prob: 0.9992772936820984
Word: it 	 Prob: 0.7004916667938232
Word: and 	 Prob: 0.8272475004196167
Word: he 	 Prob: 0.17570537328720093
Word: damn 	 Prob: 0.9917864203453064
Word: near 	 Prob: 0.9969384670257568
Word: passed 	 Prob: 0.550364077091217
Word: out 	 Prob: 6.173141997578568e-08


tensor(5.6021e-14, grad_fn=<MulBackward0>)