# Estimate sentence probability with BERT
## Adapted from https://github.com/huggingface/transformers/issues/37, with bugs fixed and updated to the newest `transformers` version

In [1]:
import numpy as np
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Load the pre-trained masked-LM (weights) and its matching tokenizer.
# No torch.no_grad() wrapper here: no_grad only disables gradient tracking for
# code executed inside its `with` block, and loading weights / calling eval()
# records no autograd graph, so wrapping them was a no-op that wrongly suggested
# later inference would be gradient-free. Gradient tracking has to be disabled
# around the forward passes that score sentences instead.
model = BertForMaskedLM.from_pretrained('bert-large-uncased')
# eval() switches off dropout so scoring the same sentence twice is deterministic.
model.eval()
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

In [2]:
def print_top_predictions(probs, k=5):
    """Print the ``k`` most probable vocabulary tokens and their probabilities.

    Uses the module-level ``tokenizer`` to map vocabulary ids back to tokens.

    Args:
        probs: Tensor of probabilities over the vocabulary for one masked
            position; presumably 1-D (vocab_size,) as passed by the sentence
            scorer -- TODO confirm for other callers.
        k: How many predictions to show. Clamped to ``[0, vocab_size]`` so that
            ``k=0`` prints nothing (instead of the entire vocabulary) and a too
            large ``k`` does not raise.
    """
    k = max(0, min(k, probs.shape[-1]))
    # topk already returns values in descending order, so no extra sort is needed.
    top_values, top_indexes = torch.topk(probs.detach().cpu(), k)
    top_tokens = tokenizer.convert_ids_to_tokens(top_indexes.tolist())
    print(f"Ordered top predicted tokens: {top_tokens}")
    print(f"Ordered top predicted values: {top_values.numpy()}")

In [12]:
BOS_TOKEN = '[CLS]'
EOS_TOKEN = '[SEP]'
MASK_TOKEN = '[MASK]'

def get_sentence_prob(sentence, verbose=False):
    """Return BERT's average per-token masked log-probability for ``sentence``.

    Each token between the leading '[CLS]' and the trailing '[SEP]' is replaced
    by '[MASK]' in turn. The module-level ``model`` is run on that masked
    sequence, and the log-probability it assigns to the original token at the
    masked position is accumulated. The sum is divided by the number of scored
    tokens, giving a length-normalized pseudo-log-likelihood (closer to 0 means
    more plausible to the model). It is not a true joint sentence probability.

    Prints the tokenized sentence, each token's probability and the final score.

    Args:
        sentence: Raw text, tokenized with the module-level ``tokenizer``.
            '[CLS]' / '[SEP]' are added when not already present.
        verbose: Also print every masked input and the top predictions at each
            masked position.

    Returns:
        The mean log-probability per scored token, as a Python float.

    Raises:
        ValueError: If the sentence yields no tokens to score (e.g. it is empty).
    """
    tokenized_input = tokenizer.tokenize(sentence)
    if not tokenized_input:
        raise ValueError(f"No tokens to score in sentence: {sentence!r}")
    if tokenized_input[0] != BOS_TOKEN:
        tokenized_input.insert(0, BOS_TOKEN)
    if tokenized_input[-1] != EOS_TOKEN:
        tokenized_input.append(EOS_TOKEN)
    ids_input = tokenizer.convert_tokens_to_ids(tokenized_input)
    print(f"Processing sentence: {tokenized_input}")

    # Only positions strictly between [CLS] and [SEP] are scored. Normalizing by
    # this count (rather than the raw tokenize() length) stays correct even when
    # the input text already contained [CLS]/[SEP].
    positions = range(1, len(tokenized_input) - 1)
    if not positions:
        raise ValueError(f"No tokens to score in sentence: {sentence!r}")

    sum_lp = 0.0
    # Pure inference: without no_grad every forward pass would record an
    # autograd graph, wasting memory and time.
    with torch.no_grad():
        for i in positions:
            current_tokenized = list(tokenized_input)
            current_tokenized[i] = MASK_TOKEN
            if verbose:
                print(current_tokenized)
            masked_input = torch.tensor([tokenizer.convert_tokens_to_ids(current_tokenized)])
            # Scores over the vocabulary at the masked position.
            logits = model(masked_input)[0][0, i]
            # log_softmax is evaluated stably in log space; log(softmax(x)) in
            # float32 can underflow to log(0) = -inf for very unlikely tokens.
            log_probs = torch.log_softmax(logits, dim=-1)
            token_lp = log_probs[ids_input[i]].item()
            sum_lp += token_lp
            print(f"Word: {tokenized_input[i]} \t Prob: {np.exp(token_lp)}")
            if verbose:
                print_top_predictions(log_probs.exp())

    avg_lp = sum_lp / len(positions)
    print(f"\nNormalized sentence prob: log(P(sentence)) / sent_length: {avg_lp}\n")
    return avg_lp

In [15]:
# Compare the normalized score of a simple sentence, a longer elaborated
# version, a passive paraphrase, and an implausible variant ("stake" instead of
# "steak", with subject and object swapped).
# Each call prints its own per-token breakdown; in a notebook only the last
# call's return value is additionally echoed as the cell result.
get_sentence_prob("The man ate the steak.")
get_sentence_prob("The man who arrived late ate the steak with a glass of wine.")
get_sentence_prob("The steak was eaten by the man.")
get_sentence_prob("The stake ate the man.")

Processing sentence: ['[CLS]', 'the', 'man', 'ate', 'the', 'steak', '.', '[SEP]']
Word: the 	 Prob: 0.9430958032608032
Word: man 	 Prob: 0.15097321569919586
Word: ate 	 Prob: 0.11828337609767914
Word: the 	 Prob: 0.10330334305763245
Word: steak 	 Prob: 0.004455209709703922
Word: . 	 Prob: 0.9944341778755188

Normalized sentence prob: log(P(sentence)) / sent_length: -1.9622100537332396

Processing sentence: ['[CLS]', 'the', 'man', 'who', 'arrived', 'late', 'ate', 'the', 'steak', 'with', 'a', 'glass', 'of', 'wine', '.', '[SEP]']
Word: the 	 Prob: 0.9333303570747375
Word: man 	 Prob: 0.06445129215717316
Word: who 	 Prob: 0.9256716966629028
Word: arrived 	 Prob: 0.10185236483812332
Word: late 	 Prob: 0.003638619789853692
Word: ate 	 Prob: 0.15281958878040314
Word: the 	 Prob: 0.005081328563392162
Word: steak 	 Prob: 0.013520020060241222
Word: with 	 Prob: 0.37167930603027344
Word: a 	 Prob: 0.9855746030807495
Word: glass 	 Prob: 0.8360260725021362
Word: of 	 Prob: 0.9999445676803589
Word: 

-3.6060804301135554

In [4]:
# Probe a single slot: only the word after "born in" varies. It is two cities
# (Berlin, Santiago), a country (France), then the ungrammatical fillers
# "window" (a common noun without an article) and "was" (a verb).
#get_sentence_prob("I fed my cat some of it and he damn near passed out")
get_sentence_prob("He was born in Berlin.")
get_sentence_prob("He was born in Santiago.")
get_sentence_prob("He was born in France.")
get_sentence_prob("He was born in window.")
get_sentence_prob("He was born in was.")


Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'berlin', '.', '[SEP]']
Word: he 	 Prob: 0.7967859506607056
Word: was 	 Prob: 0.9999992847442627
Word: born 	 Prob: 0.9977497458457947
Word: in 	 Prob: 0.9979470372200012
Word: berlin 	 Prob: 0.02355594001710415
Word: . 	 Prob: 0.9999347925186157

Normalized sentence prob: log(P(sentence)) / sent_length: -0.6633200494943973

Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'santiago', '.', '[SEP]']
Word: he 	 Prob: 0.7612152695655823
Word: was 	 Prob: 0.9999862909317017
Word: born 	 Prob: 0.9960402250289917
Word: in 	 Prob: 0.997549831867218
Word: santiago 	 Prob: 0.0008775214664638042
Word: . 	 Prob: 0.9998825788497925

Normalized sentence prob: log(P(sentence)) / sent_length: -1.2196333849568266

Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'france', '.', '[SEP]']
Word: he 	 Prob: 0.7930527329444885
Word: was 	 Prob: 0.9999958276748657
Word: born 	 Prob: 0.9916587471961975
Word: in 	 Prob: 0.9998917579650

-3.478431378123787

In [5]:
# Probe the object of "fed my": two animals (cat, dog), an inanimate noun
# (window), and a determiner ("the") that makes the sentence ungrammatical.
get_sentence_prob("I fed my cat some of it and he damn near passed out.")
get_sentence_prob("I fed my dog some of it and he damn near passed out.")
get_sentence_prob("I fed my window some of it and he damn near passed out.")
get_sentence_prob("I fed my the some of it and he damn near passed out.")

Processing sentence: ['[CLS]', 'i', 'fed', 'my', 'cat', 'some', 'of', 'it', 'and', 'he', 'damn', 'near', 'passed', 'out', '.', '[SEP]']
Word: i 	 Prob: 0.9958482980728149
Word: fed 	 Prob: 0.15173447132110596
Word: my 	 Prob: 0.09394065290689468
Word: cat 	 Prob: 0.00496681360527873
Word: some 	 Prob: 0.5760889053344727
Word: of 	 Prob: 0.9998531341552734
Word: it 	 Prob: 0.7705156207084656
Word: and 	 Prob: 0.8786880373954773
Word: he 	 Prob: 0.5077316761016846
Word: damn 	 Prob: 0.9221678972244263
Word: near 	 Prob: 0.9851245284080505
Word: passed 	 Prob: 0.8588521480560303
Word: out 	 Prob: 0.999413013458252
Word: . 	 Prob: 0.9991564750671387

Normalized sentence prob: log(P(sentence)) / sent_length: -0.8163514484040206

Processing sentence: ['[CLS]', 'i', 'fed', 'my', 'dog', 'some', 'of', 'it', 'and', 'he', 'damn', 'near', 'passed', 'out', '.', '[SEP]']
Word: i 	 Prob: 0.9969239830970764
Word: fed 	 Prob: 0.15674538910388947
Word: my 	 Prob: 0.13328789174556732
Word: dog 	 Prob: 0.

-3.137562180068926

In [6]:
# Near-synonyms of the object (medicine / medicines / medication / pills)
# versus fillers with the wrong sense or part of speech.
# NOTE(review): "take my turn" is itself a fluent phrase, so it may not
# actually score low -- it tests meaning in context, not grammaticality.
print("Should have similar/high probs\n")
get_sentence_prob("I forgot to take my medicine.")
get_sentence_prob("I forgot to take my medicines.")
get_sentence_prob("I forgot to take my medication.")
get_sentence_prob("I forgot to take my pills.")
print("Should have low probs\n")
get_sentence_prob("I forgot to take my turn.")
get_sentence_prob("I forgot to take my medical.")
get_sentence_prob("I forgot to take my medically.")
get_sentence_prob("I forgot to take my turned.")

Should have similar/high probs

Processing sentence: ['[CLS]', 'i', 'forgot', 'to', 'take', 'my', 'medicine', '.', '[SEP]']
Word: i 	 Prob: 0.9906209707260132
Word: forgot 	 Prob: 0.006364563480019569
Word: to 	 Prob: 0.9999978542327881
Word: take 	 Prob: 0.9722825884819031
Word: my 	 Prob: 0.8687838912010193
Word: medicine 	 Prob: 0.016525747254490852
Word: . 	 Prob: 0.9979066848754883

Normalized sentence prob: log(P(sentence)) / sent_length: -1.334305136331815

Processing sentence: ['[CLS]', 'i', 'forgot', 'to', 'take', 'my', 'medicines', '.', '[SEP]']
Word: i 	 Prob: 0.977046549320221
Word: forgot 	 Prob: 0.04139229655265808
Word: to 	 Prob: 0.9999853372573853
Word: take 	 Prob: 0.8683955669403076
Word: my 	 Prob: 0.6597922444343567
Word: medicines 	 Prob: 0.0007449064869433641
Word: . 	 Prob: 0.989775538444519

Normalized sentence prob: log(P(sentence)) / sent_length: -1.5681947589606093

Processing sentence: ['[CLS]', 'i', 'forgot', 'to', 'take', 'my', 'medication', '.', '[SEP]']

-3.024814240185411

In [7]:
# Switch to the smaller bert-base-uncased model and its matching tokenizer,
# replacing the module-level `model` / `tokenizer` used by the scoring code.
# No torch.no_grad() wrapper here: no_grad only disables gradient tracking for
# code executed inside its `with` block, and loading weights / calling eval()
# records no autograd graph, so wrapping them was a no-op that wrongly suggested
# later inference would be gradient-free. Gradient tracking has to be disabled
# around the forward passes that score sentences instead.
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
# eval() switches off dropout so scoring the same sentence twice is deterministic.
model.eval()
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [8]:
# Rerun the birthplace probes (only the word after "born in" varies) with
# whichever model is currently loaded, to compare against the earlier run.
#get_sentence_prob("I fed my cat some of it and he damn near passed out")
get_sentence_prob("He was born in Berlin.")
get_sentence_prob("He was born in Santiago.")
get_sentence_prob("He was born in France.")
get_sentence_prob("He was born in window.")
get_sentence_prob("He was born in was.")

Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'berlin', '.', '[SEP]']
Word: he 	 Prob: 0.8305719494819641
Word: was 	 Prob: 0.9999191761016846
Word: born 	 Prob: 0.9969794750213623
Word: in 	 Prob: 0.994476854801178
Word: berlin 	 Prob: 0.031274765729904175
Word: . 	 Prob: 0.9961835741996765

Normalized sentence prob: log(P(sentence)) / sent_length: -0.6105087432479195

Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'santiago', '.', '[SEP]']
Word: he 	 Prob: 0.7335055470466614
Word: was 	 Prob: 0.9996452331542969
Word: born 	 Prob: 0.9062364101409912
Word: in 	 Prob: 0.9894782900810242
Word: santiago 	 Prob: 0.00146165257319808
Word: . 	 Prob: 0.9926425814628601

Normalized sentence prob: log(P(sentence)) / sent_length: -1.159146633008883

Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'france', '.', '[SEP]']
Word: he 	 Prob: 0.8246520161628723
Word: was 	 Prob: 0.999785840511322
Word: born 	 Prob: 0.9959775805473328
Word: in 	 Prob: 0.9995790123939514

-3.2969138608993185

In [9]:
# Rerun the "fed my ___" probes (cat, dog, window, "the") with whichever model
# is currently loaded, to compare against the earlier run.
get_sentence_prob("I fed my cat some of it and he damn near passed out.")
get_sentence_prob("I fed my dog some of it and he damn near passed out.")
get_sentence_prob("I fed my window some of it and he damn near passed out.")
get_sentence_prob("I fed my the some of it and he damn near passed out.")

Processing sentence: ['[CLS]', 'i', 'fed', 'my', 'cat', 'some', 'of', 'it', 'and', 'he', 'damn', 'near', 'passed', 'out', '.', '[SEP]']
Word: i 	 Prob: 0.9571204781532288
Word: fed 	 Prob: 0.1349359005689621
Word: my 	 Prob: 0.0650763139128685
Word: cat 	 Prob: 0.002580614760518074
Word: some 	 Prob: 0.7807718515396118
Word: of 	 Prob: 0.9992412328720093
Word: it 	 Prob: 0.6963055729866028
Word: and 	 Prob: 0.8459337949752808
Word: he 	 Prob: 0.1988241821527481
Word: damn 	 Prob: 0.9891019463539124
Word: near 	 Prob: 0.9948875308036804
Word: passed 	 Prob: 0.778436541557312
Word: out 	 Prob: 0.9847735166549683
Word: . 	 Prob: 0.9980348944664001

Normalized sentence prob: log(P(sentence)) / sent_length: -0.9582437386769536

Processing sentence: ['[CLS]', 'i', 'fed', 'my', 'dog', 'some', 'of', 'it', 'and', 'he', 'damn', 'near', 'passed', 'out', '.', '[SEP]']
Word: i 	 Prob: 0.9549459218978882
Word: fed 	 Prob: 0.11256841570138931
Word: my 	 Prob: 0.10771261900663376
Word: dog 	 Prob: 0.0

-2.901533915943998

In [10]:
# Rerun the "take my ___" probes with whichever model is currently loaded:
# near-synonyms first, then fillers with the wrong sense or part of speech.
# NOTE(review): "take my turn" is itself a fluent phrase, so it may not
# actually score low.
print("Should have similar/high probs\n")
get_sentence_prob("I forgot to take my medicine.")
get_sentence_prob("I forgot to take my medicines.")
get_sentence_prob("I forgot to take my medication.")
get_sentence_prob("I forgot to take my pills.")
print("Should have low probs\n")
get_sentence_prob("I forgot to take my turn.")
get_sentence_prob("I forgot to take my medical.")
get_sentence_prob("I forgot to take my medically.")
get_sentence_prob("I forgot to take my turned.")

Should have similar/high probs

Processing sentence: ['[CLS]', 'i', 'forgot', 'to', 'take', 'my', 'medicine', '.', '[SEP]']
Word: i 	 Prob: 0.9732086658477783
Word: forgot 	 Prob: 0.00880364328622818
Word: to 	 Prob: 0.9999449253082275
Word: take 	 Prob: 0.6397031545639038
Word: my 	 Prob: 0.4191467761993408
Word: medicine 	 Prob: 0.01624431647360325
Word: . 	 Prob: 0.9715754985809326

Normalized sentence prob: log(P(sentence)) / sent_length: -1.4607050482222153

Processing sentence: ['[CLS]', 'i', 'forgot', 'to', 'take', 'my', 'medicines', '.', '[SEP]']
Word: i 	 Prob: 0.959055483341217
Word: forgot 	 Prob: 0.0183637123554945
Word: to 	 Prob: 0.9998666048049927
Word: take 	 Prob: 0.21368975937366486
Word: my 	 Prob: 0.25066184997558594
Word: medicines 	 Prob: 0.0008340253843925893
Word: . 	 Prob: 0.9502522945404053

Normalized sentence prob: log(P(sentence)) / sent_length: -2.0152105095060375

Processing sentence: ['[CLS]', 'i', 'forgot', 'to', 'take', 'my', 'medication', '.', '[SEP]'

-2.841679330333136