# Estimate sentence probability with BERT
## Calculating probability more properly:
P_f, P_b: Probability forward pass, backward pass, respectively
P_f = P(w_0) * P(w_1|w_0) * P(w_2|w_0, w_1) * ... * P(w_N)
P_b = P(w_N-1|w_N) * P(w_N-2|w_N-1, w_N) * ... * P(w_0|w_1, w_2, ... ,w_N)

P_f, P_b become smaller as the number of

In [62]:
import numpy as np
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Load pre-trained model (weights)
with torch.no_grad():
    model = BertForMaskedLM.from_pretrained('bert-large-uncased')
    model.eval()
    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

In [63]:
def print_top_predictions(probs, k=5):
    probs = probs.detach().numpy()
    top_indexes = np.argpartition(probs, -k)[-k:]
    sorted_indexes = top_indexes[np.argsort(-probs[top_indexes])]
    top_tokens = tokenizer.convert_ids_to_tokens(sorted_indexes)
    print(f"Ordered top predicted tokens: {top_tokens}")
    print(f"Ordered top predicted values: {probs[sorted_indexes]}")

In [140]:
a=[1,2,3,4]
a[0:-1] = [i for i in range(10,11)]
a
np.power(8,1/3)

2.0

In [142]:
def get_forward_prob(tokenized_input, sm, verbose=False):
    sent_len = len(tokenized_input)
    ids_input = tokenizer.convert_tokens_to_ids(tokenized_input)
    print(f"Processing sentence: {tokenized_input}")
    #print(f"Sentence ids: {ids_input}")
    sent_prob = 1
#     sum_lp = 0
    # Mask non-special tokens in forward direction; calculate their probabilities
    for i in range(1, len(tokenized_input) - 1): # Don't loop first and last tokens
        current_tokenized = tokenized_input[:]
        current_tokenized[i:-1] = [MASK_TOKEN for j in range(len(tokenized_input) - 1 - i)]
        if verbose: print(current_tokenized)
        masked_input = torch.tensor([tokenizer.convert_tokens_to_ids(current_tokenized)])
        outputs = model(masked_input)
        predictions = outputs[0]
        current_probs = sm(predictions[0, i]) # Softmax to get probabilities
        current_prob = current_probs[ids_input[i]] # Prediction for masked word
        sent_prob *= current_prob

#         sum_lp += np.log(current_prob.detach().numpy())

        print(f"Word: {tokenized_input[i]} \t Prob: {current_prob}")
        if verbose: print_top_predictions(current_probs)

    geom_mean_sent_prob = np.power(sent_prob.item(), 1/sent_len)  # Calculate geometric mean
    print(f"\nGeometric-mean sentence probability: {geom_mean_sent_prob}\n")
#     print(f"\nNormalized forward sentence prob: log(P(sentence)) / sent_length: {sum_lp / sent_len}\n")
#     return sum_lp / sent_len
    return geom_mean_sent_prob

In [143]:
def get_backward_prob(tokenized_input, sm, verbose=False):
    sent_len = len(tokenized_input)
    ids_input = tokenizer.convert_tokens_to_ids(tokenized_input)
    print(f"Processing sentence: {tokenized_input}")
    #print(f"Sentence ids: {ids_input}")
    #sent_prob = 1
    sum_lp = 0
    # Mask non-special tokens in backward direction; calculate their probabilities
    for i in reversed(range(1, len(tokenized_input) - 1)): # Don't loop first and last tokens
        current_tokenized = tokenized_input[:]
        current_tokenized[1:i+1] = [MASK_TOKEN for j in range(i)]
        if verbose: print(current_tokenized)
        masked_input = torch.tensor([tokenizer.convert_tokens_to_ids(current_tokenized)])
        outputs = model(masked_input)
        predictions = outputs[0]
        current_probs = sm(predictions[0, i]) # Softmax to get probabilities
        current_prob = current_probs[ids_input[i]] # Prediction for masked word
        #sent_prob *= current_prob

        sum_lp += np.log(current_prob.detach().numpy())

        print(f"Word: {tokenized_input[i]} \t Prob: {current_prob}")
        if verbose: print_top_predictions(current_probs)

    #print(f"\nSentence probability: {sent_prob.item()}\n")
    print(f"\nNormalized backward sentence prob: log(P(sentence)) / sent_length: {sum_lp / sent_len}\n")
    return sum_lp / sent_len

In [144]:
BOS_TOKEN = '[CLS]'
EOS_TOKEN = '[SEP]'
MASK_TOKEN = '[MASK]'

def get_sentence_prob(sentence, verbose=False):
    sm = torch.nn.Softmax(dim=0) # used to convert last hidden state to probs
    
    # Pre-process sentence, adding special tokens
    tokenized_input = tokenizer.tokenize(sentence)
    if tokenized_input[0] != BOS_TOKEN:
        tokenized_input.insert(0, BOS_TOKEN)
    if tokenized_input[-1] != EOS_TOKEN:
        tokenized_input.append(EOS_TOKEN)
    
    forward_prob = get_forward_prob(tokenized_input, sm, verbose=verbose)
    backward_prob = get_backward_prob(tokenized_input, sm, verbose=verbose)
    avg_prob = (forward_prob + backward_prob) / 2
    print(f"\nAverage normalized sentence prob: log(P(sentence)) / sent_length: {avg_prob}\n")
    
    


In [145]:
get_sentence_prob("He answered unequivocally.")
get_sentence_prob("He answered quickly.", verbose=True)

Processing sentence: ['[CLS]', 'he', 'answered', 'une', '##qui', '##vo', '##cal', '##ly', '.', '[SEP]']
Word: he 	 Prob: 0.0005044482531957328
Word: answered 	 Prob: 0.0002913153439294547
Word: une 	 Prob: 2.243901064957754e-07
Word: ##qui 	 Prob: 0.0005762826185673475
Word: ##vo 	 Prob: 0.05035709589719772
Word: ##cal 	 Prob: 0.9999289512634277
Word: ##ly 	 Prob: 0.9982821941375732
Word: . 	 Prob: 0.9998167157173157

Geometric-mean sentence probability: 0.015776195271175384

Processing sentence: ['[CLS]', 'he', 'answered', 'une', '##qui', '##vo', '##cal', '##ly', '.', '[SEP]']
Word: . 	 Prob: 0.7406781911849976
Word: ##ly 	 Prob: 0.0010952855227515101
Word: ##cal 	 Prob: 2.658714583958499e-05
Word: ##vo 	 Prob: 3.2233551792160142e-06
Word: ##qui 	 Prob: 0.0008089180919341743
Word: une 	 Prob: 0.9987058639526367
Word: answered 	 Prob: 0.010978045873343945
Word: he 	 Prob: 0.2814375162124634

Normalized backward sentence prob: log(P(sentence)) / sent_length: -4.319790983572602


Average

In [146]:
get_sentence_prob("The guy with small hands demanded a quid pro quo.")
get_sentence_prob("The guy with small hands demanded an exchange.")

Processing sentence: ['[CLS]', 'the', 'guy', 'with', 'small', 'hands', 'demanded', 'a', 'qui', '##d', 'pro', 'quo', '.', '[SEP]']
Word: the 	 Prob: 0.020117253065109253
Word: guy 	 Prob: 4.888691910309717e-05
Word: with 	 Prob: 0.0006999199977144599
Word: small 	 Prob: 4.887327304459177e-05
Word: hands 	 Prob: 0.0014287744415923953
Word: demanded 	 Prob: 2.932926236098865e-06
Word: a 	 Prob: 0.000463048490928486
Word: qui 	 Prob: 3.862149696942652e-06
Word: ##d 	 Prob: 0.23301775753498077
Word: pro 	 Prob: 3.778058089665137e-05
Word: quo 	 Prob: 0.9989466071128845
Word: . 	 Prob: 0.9850805401802063

Geometric-mean sentence probability: 0.002836646561912836

Processing sentence: ['[CLS]', 'the', 'guy', 'with', 'small', 'hands', 'demanded', 'a', 'qui', '##d', 'pro', 'quo', '.', '[SEP]']
Word: . 	 Prob: 0.8537910580635071
Word: quo 	 Prob: 6.576073587893916e-07
Word: pro 	 Prob: 0.004886653274297714
Word: ##d 	 Prob: 7.432691973008332e-07
Word: qui 	 Prob: 0.5241817235946655
Word: a 	 Pro

In [147]:
get_sentence_prob("This is a sentence.")
get_sentence_prob("This is a macrame.", verbose=False)
get_sentence_prob("This is a joke.", verbose=False)
get_sentence_prob("Are you kidding?", verbose=False)


Processing sentence: ['[CLS]', 'this', 'is', 'a', 'sentence', '.', '[SEP]']
Word: this 	 Prob: 0.00012813373177777976
Word: is 	 Prob: 0.3984246850013733
Word: a 	 Prob: 0.06797751039266586
Word: sentence 	 Prob: 0.00013198245142120868
Word: . 	 Prob: 0.966092050075531

Geometric-mean sentence probability: 0.046100048266749315

Processing sentence: ['[CLS]', 'this', 'is', 'a', 'sentence', '.', '[SEP]']
Word: . 	 Prob: 0.8030216097831726
Word: sentence 	 Prob: 9.854356903815642e-07
Word: a 	 Prob: 0.23075106739997864
Word: is 	 Prob: 0.06603307276964188
Word: this 	 Prob: 0.060409143567085266

Normalized backward sentence prob: log(P(sentence)) / sent_length: -3.0057408596788133


Average normalized sentence prob: log(P(sentence)) / sent_length: -1.479820405706032

Processing sentence: ['[CLS]', 'this', 'is', 'a', 'mac', '##ram', '##e', '.', '[SEP]']
Word: this 	 Prob: 0.00017661201127339154
Word: is 	 Prob: 0.13002026081085205
Word: a 	 Prob: 0.7468298077583313
Word: mac 	 Prob: 3.4899

In [148]:
get_sentence_prob("Rachel was wearing a lovely satin dress last night.")

Processing sentence: ['[CLS]', 'rachel', 'was', 'wearing', 'a', 'lovely', 'satin', 'dress', 'last', 'night', '.', '[SEP]']
Word: rachel 	 Prob: 1.345694727206137e-05
Word: was 	 Prob: 0.2645097076892853
Word: wearing 	 Prob: 3.774421929847449e-05
Word: a 	 Prob: 0.9133190512657166
Word: lovely 	 Prob: 0.0004488571430556476
Word: satin 	 Prob: 0.0010469065746292472
Word: dress 	 Prob: 0.02077450044453144
Word: last 	 Prob: 0.003023444674909115
Word: night 	 Prob: 0.9490013122558594
Word: . 	 Prob: 0.9899153709411621

Geometric-mean sentence probability: 0.019693121325489497

Processing sentence: ['[CLS]', 'rachel', 'was', 'wearing', 'a', 'lovely', 'satin', 'dress', 'last', 'night', '.', '[SEP]']
Word: . 	 Prob: 0.8182555437088013
Word: night 	 Prob: 0.0001843223290052265
Word: last 	 Prob: 0.05977332592010498
Word: dress 	 Prob: 2.075352995234425e-06
Word: satin 	 Prob: 1.9559594875317998e-05
Word: lovely 	 Prob: 8.846891432767734e-05
Word: a 	 Prob: 0.8966140151023865
Word: wearing 	 P

In [152]:
get_sentence_prob("Rachel was wearing a lovely satin dress last night.")
get_sentence_prob("Grandma was wearing a lovely satin dress last night.")
get_sentence_prob("Mother was wearing a lovely satin dress last night.")
get_sentence_prob("She was wearing a lovely satin dress last night.")
get_sentence_prob("He was wearing a lovely satin dress last night.")
get_sentence_prob("I was wearing a lovely satin dress last night.")
get_sentence_prob("Angela was wearing a lovely satin dress last night.")
get_sentence_prob("Roberta was wearing a lovely satin dress last night.")
get_sentence_prob("Running was wearing a lovely satin dress last night.")

Processing sentence: ['[CLS]', 'rachel', 'was', 'wearing', 'a', 'lovely', 'satin', 'dress', 'last', 'night', '.', '[SEP]']
Word: rachel 	 Prob: 1.345694727206137e-05
Word: was 	 Prob: 0.2645097076892853
Word: wearing 	 Prob: 3.774421929847449e-05
Word: a 	 Prob: 0.9133190512657166
Word: lovely 	 Prob: 0.0004488571430556476
Word: satin 	 Prob: 0.0010469065746292472
Word: dress 	 Prob: 0.02077450044453144
Word: last 	 Prob: 0.003023444674909115
Word: night 	 Prob: 0.9490013122558594
Word: . 	 Prob: 0.9899153709411621

Geometric-mean sentence probability: 0.019693121325489497

Processing sentence: ['[CLS]', 'rachel', 'was', 'wearing', 'a', 'lovely', 'satin', 'dress', 'last', 'night', '.', '[SEP]']
Word: . 	 Prob: 0.8182555437088013
Word: night 	 Prob: 0.0001843223290052265
Word: last 	 Prob: 0.05977332592010498
Word: dress 	 Prob: 2.075352995234425e-06
Word: satin 	 Prob: 1.9559594875317998e-05
Word: lovely 	 Prob: 8.846891432767734e-05
Word: a 	 Prob: 0.8966140151023865
Word: wearing 	 P

Word: night 	 Prob: 0.0001843223290052265
Word: last 	 Prob: 0.05977332592010498
Word: dress 	 Prob: 2.075352995234425e-06
Word: satin 	 Prob: 1.9559594875317998e-05
Word: lovely 	 Prob: 8.846891432767734e-05
Word: a 	 Prob: 0.8966140151023865
Word: wearing 	 Prob: 0.1172424703836441
Word: was 	 Prob: 0.9986034035682678
Word: angela 	 Prob: 2.9840008210157976e-05

Normalized backward sentence prob: log(P(sentence)) / sent_length: -4.795881876110798


Average normalized sentence prob: log(P(sentence)) / sent_length: -2.3918587135479314

Processing sentence: ['[CLS]', 'roberta', 'was', 'wearing', 'a', 'lovely', 'satin', 'dress', 'last', 'night', '.', '[SEP]']
Word: roberta 	 Prob: 3.867816758429399e-06
Word: was 	 Prob: 0.0005287343519739807
Word: wearing 	 Prob: 0.0002253064449178055
Word: a 	 Prob: 0.8714913725852966
Word: lovely 	 Prob: 0.0008630887023173273
Word: satin 	 Prob: 0.0011990604689344764
Word: dress 	 Prob: 0.018998507410287857
Word: last 	 Prob: 0.0027865259908139706
Word

In [153]:
get_sentence_prob("The man ate the steak.")
get_sentence_prob("The man who arrived late ate the steak with a glass of wine.")
get_sentence_prob("The steak was eaten by the man.")
get_sentence_prob("The stake ate the man.")

Processing sentence: ['[CLS]', 'the', 'man', 'ate', 'the', 'steak', '.', '[SEP]']
Word: the 	 Prob: 0.006482909899204969
Word: man 	 Prob: 0.0008543190779164433
Word: ate 	 Prob: 0.00030972290551289916
Word: the 	 Prob: 0.16035766899585724
Word: steak 	 Prob: 0.004650171846151352
Word: . 	 Prob: 0.9944341778755188

Geometric-mean sentence probability: 0.032588342448014056

Processing sentence: ['[CLS]', 'the', 'man', 'ate', 'the', 'steak', '.', '[SEP]']
Word: . 	 Prob: 0.7197229266166687
Word: steak 	 Prob: 8.405078006035183e-06
Word: the 	 Prob: 0.024790508672595024
Word: ate 	 Prob: 0.014627532102167606
Word: man 	 Prob: 4.835774234379642e-05
Word: the 	 Prob: 0.9430958032608032

Normalized backward sentence prob: log(P(sentence)) / sent_length: -3.7416474418714643


Average normalized sentence prob: log(P(sentence)) / sent_length: -1.8545295497117251

Processing sentence: ['[CLS]', 'the', 'man', 'who', 'arrived', 'late', 'ate', 'the', 'steak', 'with', 'a', 'glass', 'of', 'wine', '.'

In [154]:
get_sentence_prob("He was born in Berlin.")
get_sentence_prob("He was born in Santiago.")
get_sentence_prob("He was born in France.")
get_sentence_prob("He was born in window.")
get_sentence_prob("He was born in was.")


Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'berlin', '.', '[SEP]']
Word: he 	 Prob: 0.0006875116960145533
Word: was 	 Prob: 0.611086905002594
Word: born 	 Prob: 0.0001007601385936141
Word: in 	 Prob: 0.9955852031707764
Word: berlin 	 Prob: 0.02386103756725788
Word: . 	 Prob: 0.9999347925186157

Geometric-mean sentence probability: 0.0750414823747486

Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'berlin', '.', '[SEP]']
Word: . 	 Prob: 0.7197229266166687
Word: berlin 	 Prob: 7.354922854574397e-05
Word: in 	 Prob: 0.07671105861663818
Word: born 	 Prob: 0.0007875352748669684
Word: was 	 Prob: 0.999992847442627
Word: he 	 Prob: 0.7967859506607056

Normalized backward sentence prob: log(P(sentence)) / sent_length: -2.473491515967453


Average normalized sentence prob: log(P(sentence)) / sent_length: -1.199225016796352

Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'santiago', '.', '[SEP]']
Word: he 	 Prob: 0.0006875116960145533
Word: was 	 Prob: 0.6110

In [155]:
get_sentence_prob("I fed my cat some of it and he damn near passed out.")
get_sentence_prob("I fed my dog some of it and he damn near passed out.")
get_sentence_prob("I fed my window some of it and he damn near passed out.")
get_sentence_prob("I fed my the some of it and he damn near passed out.")

Processing sentence: ['[CLS]', 'i', 'fed', 'my', 'cat', 'some', 'of', 'it', 'and', 'he', 'damn', 'near', 'passed', 'out', '.', '[SEP]']
Word: i 	 Prob: 0.0005238083540461957
Word: fed 	 Prob: 1.7332116840407252e-05
Word: my 	 Prob: 0.012338080443441868
Word: cat 	 Prob: 0.00244920514523983
Word: some 	 Prob: 0.007367619778960943
Word: of 	 Prob: 1.723681816656608e-05
Word: it 	 Prob: 0.00012771011097356677
Word: and 	 Prob: 0.5446289777755737
Word: he 	 Prob: 0.0061583081260323524
Word: damn 	 Prob: 2.3060283638187684e-05
Word: near 	 Prob: 0.4169158637523651
Word: passed 	 Prob: 7.57075467845425e-05
Word: out 	 Prob: 0.9994945526123047
Word: . 	 Prob: 0.9991564750671387

Geometric-mean sentence probability: 0.006526921971877533

Processing sentence: ['[CLS]', 'i', 'fed', 'my', 'cat', 'some', 'of', 'it', 'and', 'he', 'damn', 'near', 'passed', 'out', '.', '[SEP]']
Word: . 	 Prob: 0.7609730958938599
Word: out 	 Prob: 0.0001479800557717681
Word: passed 	 Prob: 0.0015395785449072719
Word: 

In [129]:
print("Should have similar/high probs\n")
get_sentence_prob("I forgot to take my medicine.")
get_sentence_prob("I forgot to take my medicines.")
get_sentence_prob("I forgot to take my medication.")
get_sentence_prob("I forgot to take my pills.")
print("Should have low probs\n")
get_sentence_prob("I forgot to take my turn.")
get_sentence_prob("I forgot to take my medical.")
get_sentence_prob("I forgot to take my medically.")
get_sentence_prob("I forgot to take my turned.")

Should have similar/high probs

Processing sentence: ['[CLS]', 'i', 'forgot', 'to', 'take', 'my', 'medicine', '.', '[SEP]']
Word: i 	 Prob: 0.0010684080189093947
Word: forgot 	 Prob: 0.00016382183821406215
Word: to 	 Prob: 0.06681681424379349
Word: take 	 Prob: 0.0745348259806633
Word: my 	 Prob: 0.2779442071914673
Word: medicine 	 Prob: 0.0129512008279562
Word: . 	 Prob: 0.9979066848754883

Normalized forward sentence prob: log(P(sentence)) / sent_length: -2.9432893383006253

Processing sentence: ['[CLS]', 'i', 'forgot', 'to', 'take', 'my', 'medicine', '.', '[SEP]']
Word: . 	 Prob: 0.6934828162193298
Word: medicine 	 Prob: 4.02187870349735e-05
Word: my 	 Prob: 4.184951831120998e-05
Word: take 	 Prob: 0.012042323127388954
Word: to 	 Prob: 0.7947912812232971
Word: forgot 	 Prob: 0.007729522883892059
Word: i 	 Prob: 0.9906209707260132

Normalized backward sentence prob: log(P(sentence)) / sent_length: -3.343307657788197


Average normalized sentence prob: log(P(sentence)) / sent_length: 

In [132]:
get_sentence_prob("We will explore the elements used to construct sentences, and what parts of speech are used to expand and elaborate on them.")
get_sentence_prob("Wikipedia is a multilingual online encyclopedia created and maintained as an open collaboration project by a community of volunteer editors.")
get_sentence_prob("Once she gave her a little cap of red velvet, which suited her so well that she would never wear anything else.")

Processing sentence: ['[CLS]', 'we', 'will', 'explore', 'the', 'elements', 'used', 'to', 'construct', 'sentences', ',', 'and', 'what', 'parts', 'of', 'speech', 'are', 'used', 'to', 'expand', 'and', 'elaborate', 'on', 'them', '.', '[SEP]']
Word: we 	 Prob: 1.6642563423374668e-05
Word: will 	 Prob: 0.002434785244986415
Word: explore 	 Prob: 1.6068875993369147e-05
Word: the 	 Prob: 0.6480313539505005
Word: elements 	 Prob: 0.00011011799506377429
Word: used 	 Prob: 4.9860776925925165e-05
Word: to 	 Prob: 0.041129741817712784
Word: construct 	 Prob: 0.003627343336120248
Word: sentences 	 Prob: 7.629900551364699e-07
Word: , 	 Prob: 0.0999969020485878
Word: and 	 Prob: 0.08384449779987335
Word: what 	 Prob: 0.0009997623274102807
Word: parts 	 Prob: 0.005878846161067486
Word: of 	 Prob: 0.005989460740238428
Word: speech 	 Prob: 0.0002994531823787838
Word: are 	 Prob: 0.9055652022361755
Word: used 	 Prob: 0.10661078989505768
Word: to 	 Prob: 0.9451584219932556
Word: expand 	 Prob: 2.09432928386

In [None]:
# Load pre-trained model (weights)
with torch.no_grad():
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    model.eval()
    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
#get_sentence_prob("I fed my cat some of it and he damn near passed out")
get_sentence_prob("He was born in Berlin.")
get_sentence_prob("He was born in Santiago.")
get_sentence_prob("He was born in France.")
get_sentence_prob("He was born in window.")
get_sentence_prob("He was born in was.")

In [None]:
get_sentence_prob("I fed my cat some of it and he damn near passed out.")
get_sentence_prob("I fed my dog some of it and he damn near passed out.")
get_sentence_prob("I fed my window some of it and he damn near passed out.")
get_sentence_prob("I fed my the some of it and he damn near passed out.")

In [None]:
print("Should have similar/high probs\n")
get_sentence_prob("I forgot to take my medicine.")
get_sentence_prob("I forgot to take my medicines.")
get_sentence_prob("I forgot to take my medication.")
get_sentence_prob("I forgot to take my pills.")
print("Should have low probs\n")
get_sentence_prob("I forgot to take my turn.")
get_sentence_prob("I forgot to take my medical.")
get_sentence_prob("I forgot to take my medically.")
get_sentence_prob("I forgot to take my turned.")