In [2]:
import transformers
from transformers import pipeline, AutoTokenizer, AutoModelForMaskedLM
import numpy as np
import torch
from pprint import pprint
import logging
import re

transformers.logging.set_verbosity_error()

def print_header(header):
    print("\n\n----------------------------------------------------------")
    print(header)
    print("----------------------------------------------------------")

BERT = 'bert-large-uncased-whole-word-masking'



In [10]:
def setup(model):
    tokenizer = AutoTokenizer.from_pretrained(model)
    model = AutoModelForMaskedLM.from_pretrained(model)
    bert = pipeline("fill-mask", model=model, tokenizer=tokenizer)
    mask = bert.tokenizer.mask_token
    return bert, mask

def runBERT(bert, sentences, fillers):
    filler_results = {}
    top_preds = {}
    for sentence in sentences:
        filler_results[sentence] = {}
        for filler in fillers:
            filler_results[sentence][filler] = bert(sentence, targets=[filler])[0]["score"]
    return filler_results

In [13]:
sentences = [
    'The key [MASK] on the table', 
    'They keys [MASK] on the table', 
    'The key to the cabinets [MASK] on the table', 
    'The keys to the cabinet [MASK] on the table',
    'Yet the ratio of men who survive to the women and children who survive [MASK] not clear in this story',
    'The roses in the vase by the door [MASK] red']
fillers = ['is', 'are']
bert, mask = setup(BERT)
filler_results = runBERT(bert, sentences, fillers)

print_header('Filler Results')
pprint(filler_results, sort_dicts=False)



----------------------------------------------------------
Filler Results
----------------------------------------------------------
{'The key [MASK] on the table': {'is': 0.39408349990844727,
                                 'are': 0.0034097961615771055},
 'They keys [MASK] on the table': {'is': 5.7286724768346176e-05,
                                   'are': 0.32688844203948975},
 'The key to the cabinets [MASK] on the table': {'is': 0.2598051428794861,
                                                 'are': 0.005281456280499697},
 'The keys to the cabinet [MASK] on the table': {'is': 0.0019026300869882107,
                                                 'are': 0.682071328163147},
 'Yet the ratio of men who survive to the women and children who survive [MASK] not clear in this story': {'is': 0.983444094657898,
                                                                                                           'are': 0.00609173160046339},
 'The roses in the vase by the door 