# FAIL... arithmetic average is clearly not the way


# Estimate sentence probability with BERT
## Calculating probability more properly:
P_f, P_b: Probability forward pass, backward pass, respectively
P_f = P(w_0) * P(w_1|w_0) * P(w_2|w_0, w_1) * ... * P(w_N)
P_b = P(w_N-1|w_N) * P(w_N-2|w_N-1, w_N) * ... * P(w_0|w_1, w_2, ... ,w_N)

P_f, P_b become smaller as the sentence length increases, hence, I try normalizing them by sentence length.
Here I use an arithmetic mean

Finally, the sentence probability P(S) is the geometric mean of forward and backwards probabilities:
```
P(S) = (mean(P_f(S)) * mean(P_b(S))) ^ (1/2)
```

In [1]:
import numpy as np
import torch
from transformers import BertTokenizer, BertForMaskedLM

BOS_TOKEN = '[CLS]'
EOS_TOKEN = '[SEP]'
MASK_TOKEN = '[MASK]'

# Load pre-trained model (weights)
with torch.no_grad():
    model = BertForMaskedLM.from_pretrained('bert-large-uncased')
    model.eval()
    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

In [2]:
def print_top_predictions(probs, k=5):
    probs = probs.detach().numpy()
    top_indexes = np.argpartition(probs, -k)[-k:]
    sorted_indexes = top_indexes[np.argsort(-probs[top_indexes])]
    top_tokens = tokenizer.convert_ids_to_tokens(sorted_indexes)
    print(f"Ordered top predicted tokens: {top_tokens}")
    print(f"Ordered top predicted values: {probs[sorted_indexes]}")

In [3]:
def get_sentence_prob(sentence, verbose=False):
    # Pre-process sentence, adding special tokens
    tokenized_input = tokenizer.tokenize(sentence)
    if tokenized_input[0] != BOS_TOKEN:
        tokenized_input.insert(0, BOS_TOKEN)
    if tokenized_input[-1] != EOS_TOKEN:
        tokenized_input.append(EOS_TOKEN)
    sent_len = len(tokenized_input)
    ids_input = tokenizer.convert_tokens_to_ids(tokenized_input)
    print(f"Processing sentence: {tokenized_input}\n")
    
    sm = torch.nn.Softmax(dim=0) # used to convert last hidden state to probs
    
    # Trying with arithmetic average
    sent_prob_forward = 0
    sent_prob_backwards = 0
    # Mask non-special tokens in forward direction; calculate their probabilities
    for i in range(1, len(tokenized_input) - 1): # Don't loop first and last tokens
        probs_forward = get_directional_prob(sm, tokenized_input, i, 'forward', verbose=verbose)
        probs_backwards = get_directional_prob(sm, tokenized_input, i, 'backwards', verbose=verbose)
        prob_forward = probs_forward[ids_input[i]] # Prediction for masked word 
        prob_backwards = probs_backwards[ids_input[i]] # Prediction for masked word 
        sent_prob_forward += prob_forward
        sent_prob_backwards += prob_backwards

        print(f"Word: {tokenized_input[i]} \t Prob_forward: {prob_forward}; Prob_backwards: {prob_backwards}")

    sent_prob_forward = sent_prob_forward.detach().numpy() / sent_len
    sent_prob_backwards = sent_prob_backwards.detach().numpy() / sent_len
    print(f"Geometric-mean forward sentence probability: {sent_prob_forward}")
    print(f"Geometric-mean backward sentence probability: {sent_prob_backwards}\n")
    
    # Obtain geometric average of forward and backward probs
    geom_mean_sent_prob = np.sqrt(sent_prob_forward * sent_prob_backwards)
    print(f"Average normalized sentence prob: {geom_mean_sent_prob}\n")
    return geom_mean_sent_prob

In [4]:
def get_directional_prob(sm, tokenized_input, i, direction, verbose=False):
    current_tokens = tokenized_input[:]
    if direction == 'backwards':
        current_tokens[1:i+1] = [MASK_TOKEN for j in range(i)]
    elif direction == 'forward':
        current_tokens[i:-1] = [MASK_TOKEN for j in range(len(tokenized_input) - 1 - i)]
    else:
        print("Direction can only be 'forward' or 'backwards'")
        exit()
    if verbose: 
        print()
        print(current_tokens)
        
    masked_input = torch.tensor([tokenizer.convert_tokens_to_ids(current_tokens)])
    predictions = model(masked_input)
    predictions = predictions[0]
    probs = sm(predictions[0, i]) # Softmax to get probabilities
    if verbose: 
        print_top_predictions(probs)
    
    return probs # Model redictions

In [5]:
get_sentence_prob("He answered unequivocally.")
get_sentence_prob("He answered quickly.", verbose=True)

Processing sentence: ['[CLS]', 'he', 'answered', 'une', '##qui', '##vo', '##cal', '##ly', '.', '[SEP]']

Word: he 	 Prob_forward: 0.0005044482531957328; Prob_backwards: 0.2814375162124634
Word: answered 	 Prob_forward: 0.0002913153439294547; Prob_backwards: 0.010978045873343945
Word: une 	 Prob_forward: 2.243901064957754e-07; Prob_backwards: 0.9987058639526367
Word: ##qui 	 Prob_forward: 0.0005762826185673475; Prob_backwards: 0.0008089180919341743
Word: ##vo 	 Prob_forward: 0.05035709589719772; Prob_backwards: 3.2233551792160142e-06
Word: ##cal 	 Prob_forward: 0.9999289512634277; Prob_backwards: 2.658714583958499e-05
Word: ##ly 	 Prob_forward: 0.9982821941375732; Prob_backwards: 0.0010952855227515101
Word: . 	 Prob_forward: 0.9998167157173157; Prob_backwards: 0.7406781911849976
Geometric-mean forward sentence probability: 0.30497572422027586
Geometric-mean backward sentence probability: 0.20337338447570802

Average normalized sentence prob: 0.249046070472127

Processing sentence: ['[CL

0.18735233926749056

In [6]:
get_sentence_prob("The guy with small hands demanded a quid pro quo.")
get_sentence_prob("The guy with small hands demanded an exchange.")

Processing sentence: ['[CLS]', 'the', 'guy', 'with', 'small', 'hands', 'demanded', 'a', 'qui', '##d', 'pro', 'quo', '.', '[SEP]']

Word: the 	 Prob_forward: 0.020117253065109253; Prob_backwards: 0.6742717027664185
Word: guy 	 Prob_forward: 4.888691910309717e-05; Prob_backwards: 0.0006307517760433257
Word: with 	 Prob_forward: 0.0006999199977144599; Prob_backwards: 0.06413324922323227
Word: small 	 Prob_forward: 4.887327304459177e-05; Prob_backwards: 0.008030521683394909
Word: hands 	 Prob_forward: 0.0014287744415923953; Prob_backwards: 7.560867379652336e-05
Word: demanded 	 Prob_forward: 2.932926236098865e-06; Prob_backwards: 1.76298769360983e-07
Word: a 	 Prob_forward: 0.000463048490928486; Prob_backwards: 0.06345833837985992
Word: qui 	 Prob_forward: 3.862149696942652e-06; Prob_backwards: 0.5241817235946655
Word: ##d 	 Prob_forward: 0.23301775753498077; Prob_backwards: 7.432691973008332e-07
Word: pro 	 Prob_forward: 3.778058089665137e-05; Prob_backwards: 0.004886653274297714
Word: qu

0.12314628185524801

In [7]:
get_sentence_prob("This is a sentence.")
get_sentence_prob("This is a macrame.", verbose=False)
get_sentence_prob("This is a joke.", verbose=False)
get_sentence_prob("Are you kidding me?", verbose=False)


Processing sentence: ['[CLS]', 'this', 'is', 'a', 'sentence', '.', '[SEP]']

Word: this 	 Prob_forward: 0.00012813373177777976; Prob_backwards: 0.060409143567085266
Word: is 	 Prob_forward: 0.3984246850013733; Prob_backwards: 0.06603307276964188
Word: a 	 Prob_forward: 0.06797751039266586; Prob_backwards: 0.23075106739997864
Word: sentence 	 Prob_forward: 0.00013198245142120868; Prob_backwards: 9.854356903815642e-07
Word: . 	 Prob_forward: 0.966092050075531; Prob_backwards: 0.8030216097831726
Geometric-mean forward sentence probability: 0.20467919962746756
Geometric-mean backward sentence probability: 0.16574512209211076

Average normalized sentence prob: 0.1841862615179811

Processing sentence: ['[CLS]', 'this', 'is', 'a', 'mac', '##ram', '##e', '.', '[SEP]']

Word: this 	 Prob_forward: 0.00017661201127339154; Prob_backwards: 0.055456843227148056
Word: is 	 Prob_forward: 0.13002026081085205; Prob_backwards: 0.1228228434920311
Word: a 	 Prob_forward: 0.7468298077583313; Prob_backwards:

0.285582080896641

In [8]:
get_sentence_prob("Rachel was wearing a lovely satin dress last night.")

Processing sentence: ['[CLS]', 'rachel', 'was', 'wearing', 'a', 'lovely', 'satin', 'dress', 'last', 'night', '.', '[SEP]']

Word: rachel 	 Prob_forward: 1.345694727206137e-05; Prob_backwards: 0.0008810244617052376
Word: was 	 Prob_forward: 0.2645097076892853; Prob_backwards: 0.9986034035682678
Word: wearing 	 Prob_forward: 3.774421929847449e-05; Prob_backwards: 0.1172424703836441
Word: a 	 Prob_forward: 0.9133190512657166; Prob_backwards: 0.8966140151023865
Word: lovely 	 Prob_forward: 0.0004488571430556476; Prob_backwards: 8.846891432767734e-05
Word: satin 	 Prob_forward: 0.0010469065746292472; Prob_backwards: 1.9559594875317998e-05
Word: dress 	 Prob_forward: 0.02077450044453144; Prob_backwards: 2.075352995234425e-06
Word: last 	 Prob_forward: 0.003023444674909115; Prob_backwards: 0.05977332592010498
Word: night 	 Prob_forward: 0.9490013122558594; Prob_backwards: 0.0001843223290052265
Word: . 	 Prob_forward: 0.9899153709411621; Prob_backwards: 0.8182555437088013
Geometric-mean forwar

0.2511898221404551

In [9]:
get_sentence_prob("Rachel was wearing a lovely satin dress last night.")
get_sentence_prob("Grandma was wearing a lovely satin dress last night.")
get_sentence_prob("Mother was wearing a lovely satin dress last night.")
get_sentence_prob("She was wearing a lovely satin dress last night.")
get_sentence_prob("He was wearing a lovely satin dress last night.")
get_sentence_prob("I was wearing a lovely satin dress last night.")
get_sentence_prob("Angela was wearing a lovely satin dress last night.")
get_sentence_prob("Roberta was wearing a lovely satin dress last night.")
get_sentence_prob("Running was wearing a lovely satin dress last night.")

Processing sentence: ['[CLS]', 'rachel', 'was', 'wearing', 'a', 'lovely', 'satin', 'dress', 'last', 'night', '.', '[SEP]']

Word: rachel 	 Prob_forward: 1.345694727206137e-05; Prob_backwards: 0.0008810244617052376
Word: was 	 Prob_forward: 0.2645097076892853; Prob_backwards: 0.9986034035682678
Word: wearing 	 Prob_forward: 3.774421929847449e-05; Prob_backwards: 0.1172424703836441
Word: a 	 Prob_forward: 0.9133190512657166; Prob_backwards: 0.8966140151023865
Word: lovely 	 Prob_forward: 0.0004488571430556476; Prob_backwards: 8.846891432767734e-05
Word: satin 	 Prob_forward: 0.0010469065746292472; Prob_backwards: 1.9559594875317998e-05
Word: dress 	 Prob_forward: 0.02077450044453144; Prob_backwards: 2.075352995234425e-06
Word: last 	 Prob_forward: 0.003023444674909115; Prob_backwards: 0.05977332592010498
Word: night 	 Prob_forward: 0.9490013122558594; Prob_backwards: 0.0001843223290052265
Word: . 	 Prob_forward: 0.9899153709411621; Prob_backwards: 0.8182555437088013
Geometric-mean forwar

Word: roberta 	 Prob_forward: 3.867816758429399e-06; Prob_backwards: 1.2631488971237559e-05
Word: was 	 Prob_forward: 0.0005287343519739807; Prob_backwards: 0.9986034035682678
Word: wearing 	 Prob_forward: 0.0002253064449178055; Prob_backwards: 0.1172424703836441
Word: a 	 Prob_forward: 0.8714913725852966; Prob_backwards: 0.8966140151023865
Word: lovely 	 Prob_forward: 0.0008630887023173273; Prob_backwards: 8.846891432767734e-05
Word: satin 	 Prob_forward: 0.0011990604689344764; Prob_backwards: 1.9559594875317998e-05
Word: dress 	 Prob_forward: 0.018998507410287857; Prob_backwards: 2.075352995234425e-06
Word: last 	 Prob_forward: 0.0027865259908139706; Prob_backwards: 0.05977332592010498
Word: night 	 Prob_forward: 0.9439453482627869; Prob_backwards: 0.0001843223290052265
Word: . 	 Prob_forward: 0.9862309098243713; Prob_backwards: 0.8182555437088013
Geometric-mean forward sentence probability: 0.23552272717158
Geometric-mean backward sentence probability: 0.2408996820449829

Average no

0.23146321625894048

In [10]:
get_sentence_prob("The man ate the steak.")
get_sentence_prob("The man who arrived late ate the steak with a glass of wine.")
get_sentence_prob("The steak was eaten by the man.")
get_sentence_prob("The stake ate the man.")

Processing sentence: ['[CLS]', 'the', 'man', 'ate', 'the', 'steak', '.', '[SEP]']

Word: the 	 Prob_forward: 0.006482909899204969; Prob_backwards: 0.9430958032608032
Word: man 	 Prob_forward: 0.0008543190779164433; Prob_backwards: 4.835774234379642e-05
Word: ate 	 Prob_forward: 0.00030972290551289916; Prob_backwards: 0.014627532102167606
Word: the 	 Prob_forward: 0.16035766899585724; Prob_backwards: 0.024790508672595024
Word: steak 	 Prob_forward: 0.004650171846151352; Prob_backwards: 8.405078006035183e-06
Word: . 	 Prob_forward: 0.9944341778755188; Prob_backwards: 0.7197229266166687
Geometric-mean forward sentence probability: 0.1458861231803894
Geometric-mean backward sentence probability: 0.2127866894006729

Average normalized sentence prob: 0.17618917441504126

Processing sentence: ['[CLS]', 'the', 'man', 'who', 'arrived', 'late', 'ate', 'the', 'steak', 'with', 'a', 'glass', 'of', 'wine', '.', '[SEP]']

Word: the 	 Prob_forward: 0.08750123530626297; Prob_backwards: 0.93333035707473

0.16776947385712362

In [11]:
get_sentence_prob("He was born in Berlin.")
get_sentence_prob("He was born in Santiago.")
get_sentence_prob("He was born in France.")
get_sentence_prob("He was born in window.")
get_sentence_prob("He was born in was.")


Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'berlin', '.', '[SEP]']

Word: he 	 Prob_forward: 0.0006875116960145533; Prob_backwards: 0.7967859506607056
Word: was 	 Prob_forward: 0.611086905002594; Prob_backwards: 0.999992847442627
Word: born 	 Prob_forward: 0.0001007601385936141; Prob_backwards: 0.0007875352748669684
Word: in 	 Prob_forward: 0.9955852031707764; Prob_backwards: 0.07671105861663818
Word: berlin 	 Prob_forward: 0.02386103756725788; Prob_backwards: 7.354922854574397e-05
Word: . 	 Prob_forward: 0.9999347925186157; Prob_backwards: 0.7197229266166687
Geometric-mean forward sentence probability: 0.3289070129394531
Geometric-mean backward sentence probability: 0.3242592215538025

Average normalized sentence prob: 0.326574848969319

Processing sentence: ['[CLS]', 'he', 'was', 'born', 'in', 'santiago', '.', '[SEP]']

Word: he 	 Prob_forward: 0.0006875116960145533; Prob_backwards: 0.7612152695655823
Word: was 	 Prob_forward: 0.611086905002594; Prob_backwards: 0.99992

0.3145921109671407

In [12]:
get_sentence_prob("I fed my cat some of it and he damn near passed out.")
get_sentence_prob("I fed my dog some of it and he damn near passed out.")
get_sentence_prob("I fed my window some of it and he damn near passed out.")
get_sentence_prob("I fed my the some of it and he damn near passed out.")

Processing sentence: ['[CLS]', 'i', 'fed', 'my', 'cat', 'some', 'of', 'it', 'and', 'he', 'damn', 'near', 'passed', 'out', '.', '[SEP]']

Word: i 	 Prob_forward: 0.0005238083540461957; Prob_backwards: 0.9958482980728149
Word: fed 	 Prob_forward: 1.7332116840407252e-05; Prob_backwards: 0.17128999531269073
Word: my 	 Prob_forward: 0.012338080443441868; Prob_backwards: 0.10841826349496841
Word: cat 	 Prob_forward: 0.00244920514523983; Prob_backwards: 1.8786481632560026e-06
Word: some 	 Prob_forward: 0.007367619778960943; Prob_backwards: 0.0040093897841870785
Word: of 	 Prob_forward: 1.723681816656608e-05; Prob_backwards: 0.06822692602872849
Word: it 	 Prob_forward: 0.00012771011097356677; Prob_backwards: 2.4723083697608672e-05
Word: and 	 Prob_forward: 0.5446289777755737; Prob_backwards: 0.5079375505447388
Word: he 	 Prob_forward: 0.0061583081260323524; Prob_backwards: 0.048276130110025406
Word: damn 	 Prob_forward: 2.3060283638187684e-05; Prob_backwards: 0.7008697390556335
Word: near 	 Pr

0.17541330560253213

In [None]:
print("Should have similar/high probs\n")
get_sentence_prob("I forgot to take my medicine.")
get_sentence_prob("I forgot to take my medicines.")
get_sentence_prob("I forgot to take my medication.")
get_sentence_prob("I forgot to take my pills.")
print("Should have low probs\n")
get_sentence_prob("I forgot to take my turn.")
get_sentence_prob("I forgot to take my medical.")
get_sentence_prob("I forgot to take my medically.")
get_sentence_prob("I forgot to take my turned.")

In [6]:
get_sentence_prob("We will explore the elements used to construct sentences, and what parts of speech are used to expand and elaborate on them.")
get_sentence_prob("Wikipedia is a multilingual online encyclopedia created and maintained as an open collaboration project by a community of volunteer editors.")
get_sentence_prob("Once she gave her a little cap of red velvet, which suited her so well that she would never wear anything else.")

Processing sentence: ['[CLS]', 'we', 'will', 'explore', 'the', 'elements', 'used', 'to', 'construct', 'sentences', ',', 'and', 'what', 'parts', 'of', 'speech', 'are', 'used', 'to', 'expand', 'and', 'elaborate', 'on', 'them', '.', '[SEP]']

Word: we 	 Prob_forward: 1.6642563423374668e-05; Prob_backwards: 0.01786622405052185
Word: will 	 Prob_forward: 0.002434785244986415; Prob_backwards: 0.06851402670145035
Word: explore 	 Prob_forward: 1.6068875993369147e-05; Prob_backwards: 0.0008845478878356516
Word: the 	 Prob_forward: 0.6480313539505005; Prob_backwards: 0.5417339205741882
Word: elements 	 Prob_forward: 0.00011011799506377429; Prob_backwards: 8.632599201519042e-05
Word: used 	 Prob_forward: 4.9860776925925165e-05; Prob_backwards: 0.7221975922584534
Word: to 	 Prob_forward: 0.041129741817712784; Prob_backwards: 0.9706997871398926
Word: construct 	 Prob_forward: 0.003627343336120248; Prob_backwards: 5.39941611350514e-05
Word: sentences 	 Prob_forward: 7.629900551364699e-07; Prob_backw

0.1678972461534877

In [None]:
# Load pre-trained model (weights)
with torch.no_grad():
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    model.eval()
    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
#get_sentence_prob("I fed my cat some of it and he damn near passed out")
get_sentence_prob("He was born in Berlin.")
get_sentence_prob("He was born in Santiago.")
get_sentence_prob("He was born in France.")
get_sentence_prob("He was born in window.")
get_sentence_prob("He was born in was.")

In [None]:
get_sentence_prob("I fed my cat some of it and he damn near passed out.")
get_sentence_prob("I fed my dog some of it and he damn near passed out.")
get_sentence_prob("I fed my window some of it and he damn near passed out.")
get_sentence_prob("I fed my the some of it and he damn near passed out.")

In [None]:
print("Should have similar/high probs\n")
get_sentence_prob("I forgot to take my medicine.")
get_sentence_prob("I forgot to take my medicines.")
get_sentence_prob("I forgot to take my medication.")
get_sentence_prob("I forgot to take my pills.")
print("Should have low probs\n")
get_sentence_prob("I forgot to take my turn.")
get_sentence_prob("I forgot to take my medical.")
get_sentence_prob("I forgot to take my medically.")
get_sentence_prob("I forgot to take my turned.")