# Estimate sentence probability with XLNet
## Based and updated from https://github.com/huggingface/transformers/issues/37

In [3]:
import numpy as np
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

# Load pre-trained model (weights)
with torch.no_grad():
    model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased')
    model.eval()
    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=641.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=467042463.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=798011.0, style=ProgressStyle(descripti…




In [61]:
def print_top_predictions(probs, index, k=5):
    probs = probs.detach().numpy()
    top_indexes = np.argpartition(probs, -k)[-k:]
    sorted_indexes = top_indexes[np.argsort(-probs[top_indexes])]
    top_tokens = tokenizer.convert_ids_to_tokens(sorted_indexes)
    print(f"Ordered top predicted tokens: {top_tokens}")
    print(f"Ordered top predicted values: {probs[sorted_indexes]}")

In [62]:
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=True)).unsqueeze(0)  # We will predict the masked token
tokenized_input = tokenizer.convert_ids_to_tokens(input_ids[0])
print(tokenized_input)

['▁', 'Hello', ',', '▁my', '▁dog', '▁is', '▁very', '<mask>', '<sep>', '<cls>']


In [74]:
MASK_TOKEN = '<mask>'

def get_sentence_prob(sentence):
    sm = torch.nn.Softmax(dim=0) # used to convert last hidden state to probs
    
    # Pre-process sentence, adding special tokens
    input_ids = torch.tensor(tokenizer.encode(sentence, add_special_tokens=True)).unsqueeze(0)  # We will predict the masked token
    tokenized_input = tokenizer.convert_ids_to_tokens(input_ids[0])
    print(f"Processing sentence: {tokenized_input}")
    print(f"Processing sentence: {input_ids}")
    
    sent_prob = 1
    # Mask non-special tokens and calculate their probabilities
    for i in range(0,len(tokenized_input)-2): # Ignore final tokens
        current_tokenized = tokenized_input[:]
        current_tokenized[i] = MASK_TOKEN
        masked_input = torch.tensor([tokenizer.convert_tokens_to_ids(current_tokenized)])
        
        perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
        perm_mask[:, :, i] = 1.0 

        target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)
        target_mapping[0, 0, i] = 1.0
        
        outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
        predictions = outputs[0]
        print(predictions)
        
        current_probs = sm(predictions[0, -1]) # Softmax to get probabilities
        current_prob = current_probs[input_ids[0][i]] # Prediction for masked word
        sent_prob *= current_prob
        
        print(f"Word: {tokenized_input[i]} \t Prob: {current_prob}")
        #print_top_predictions(current_probs, ids_input[i])

    print(f"\nSentence probability: {sent_prob.item()}\n")
    return sent_prob

In [75]:
#get_sentence_prob("I fed my cat some of it and he damn near passed out")
get_sentence_prob("He was born in Berlin.")
get_sentence_prob("He was born in Santiago.")
get_sentence_prob("He was born in Chile.")
get_sentence_prob("He was born in window.")


Processing sentence: ['▁He', '▁was', '▁born', '▁in', '▁Berlin', '.', '<sep>', '<cls>']
Processing sentence: tensor([[  69,   30, 1094,   25, 4158,    9,    4,    3]])
tensor([[[-30.8018, -42.2576, -41.8839,  ..., -37.6388, -36.3895, -40.6825]]],
       grad_fn=<AddBackward0>)
Word: ▁He 	 Prob: 0.0003586793318390846
tensor([[[-26.9867, -39.1739, -38.9740,  ..., -32.0969, -32.5971, -36.2555]]],
       grad_fn=<AddBackward0>)
Word: ▁was 	 Prob: 0.006640327163040638
tensor([[[-35.5227, -43.8349, -43.6165,  ..., -36.6232, -40.9318, -45.7496]]],
       grad_fn=<AddBackward0>)
Word: ▁born 	 Prob: 9.768486052053049e-05
tensor([[[-34.4562, -44.3209, -43.8899,  ..., -36.0626, -38.1389, -43.3396]]],
       grad_fn=<AddBackward0>)
Word: ▁in 	 Prob: 0.17678529024124146
tensor([[[-29.7587, -45.5898, -45.4703,  ..., -42.9988, -34.7171, -37.9252]]],
       grad_fn=<AddBackward0>)
Word: ▁Berlin 	 Prob: 9.891665797567839e-08
tensor([[[-31.5731, -44.9830, -44.7597,  ..., -36.9262, -35.1503, -39.9347]]],


tensor(6.3891e-25, grad_fn=<MulBackward0>)

In [65]:
get_sentence_prob("I fed my cat some of it and he damn near passed out.")
get_sentence_prob("I fed my dog some of it and he damn near passed out.")
get_sentence_prob("I fed my window some of it and he damn near passed out.")
get_sentence_prob("I fed my the some of it and he damn near passed out.")

Processing sentence: ['▁I', '▁fed', '▁my', '▁cat', '▁some', '▁of', '▁it', '▁and', '▁he', '▁damn', '▁near', '▁passed', '▁out', '.', '<sep>', '<cls>']
Processing sentence: tensor([[  35, 8124,   94, 4777,  106,   20,   36,   21,   43, 7757,  479, 1400,
           78,    9,    4,    3]])
Word: ▁fed 	 Prob: 0.0006848873454146087
Word: ▁my 	 Prob: 0.0032348744571208954
Word: ▁cat 	 Prob: 7.678403198951855e-05
Word: ▁some 	 Prob: 0.0014662300236523151
Word: ▁of 	 Prob: 0.6216524243354797
Word: ▁it 	 Prob: 0.18512940406799316
Word: ▁and 	 Prob: 0.11187209188938141
Word: ▁he 	 Prob: 0.0001272847002837807
Word: ▁damn 	 Prob: 3.8606871385127306e-05
Word: ▁near 	 Prob: 0.0004729019710794091
Word: ▁passed 	 Prob: 0.006773877423256636
Word: ▁out 	 Prob: 0.001013346598483622
Word: . 	 Prob: 0.001979386666789651

Sentence probability: 1.0139868223547045e-34

Processing sentence: ['▁I', '▁fed', '▁my', '▁dog', '▁some', '▁of', '▁it', '▁and', '▁he', '▁damn', '▁near', '▁passed', '▁out', '.', '<sep>', '<cl

tensor(2.6907e-36, grad_fn=<MulBackward0>)