In [145]:
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMaskedLM

bert_model_mlm = BertForMaskedLM.from_pretrained('bert-base-uncased')
# The model is only used for inference in this notebook: switch to evaluation
# mode so dropout is disabled and repeated predictions are deterministic.
bert_model_mlm.eval()
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Optional: uncomment to run the model on GPU. Only the model is moved; the
# tokenizer is plain Python and has no device. Input tensors must then be
# created on the same device as the model.
#bert_model_mlm.cuda()

# Reverse lookup table: vocabulary id -> word-piece string.
bert_id2tok = {tok_id: tok for tok, tok_id in bert_tokenizer.vocab.items()}

In [146]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter

# Word-level splitter backed by the spaCy 'en_core_web_sm' model, with POS tagging switched off.
_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words

def tokenizer(s: str):
    """Lower-case ``s`` and split it into words with the spaCy word splitter.

    Returns a list containing the ``.text`` of every token the splitter yields.
    """
    lowered = s.lower()
    words = []
    for token in _spacy_tok(lowered):
        words.append(token.text)
    return words

In [147]:
# Check whether the misspelling 'doingg' is itself an entry in the BERT vocabulary.
'doingg' in bert_tokenizer.vocab

False

In [173]:
def get_bert_masked_inputs(tokenizer, bert_tokenizer, orig_text):
    """Build one masked BERT input sentence per out-of-vocabulary (OOV) token.

    Args:
        tokenizer: callable mapping a string to a sequence of token strings.
        bert_tokenizer: object exposing a ``vocab`` container; a token is
            treated as OOV when it is not a member of ``vocab``.
        orig_text: the raw sentence to process.

    Returns:
        A list of ``(pos, text)`` tuples, one per OOV token, in order of
        appearance. ``pos`` is the index of the OOV token with respect to
        ``tokenizer`` (not the BERT tokenizer). ``text`` is the sentence with
        only that token replaced by ``[MASK]``, wrapped in ``[CLS]`` /
        ``[SEP]``; any other OOV tokens are left untouched in ``text``.
        The list is empty when every token is in the vocabulary.
    """
    toks = list(tokenizer(orig_text))
    bert_vocab = bert_tokenizer.vocab

    res = []
    for pos, tok in enumerate(toks):
        if tok in bert_vocab:
            continue
        # Join a single flat list so that masking the first or the last token
        # does not leave doubled spaces next to [CLS] / [SEP].
        masked = ['[CLS]'] + toks[:pos] + ['[MASK]'] + toks[pos + 1:] + ['[SEP]']
        res.append((pos, ' '.join(masked)))

    return res

In [174]:
# Example: one masked sentence is produced for each token that is missing from the BERT vocabulary.
get_bert_masked_inputs(tokenizer, bert_tokenizer, 'What the fcuk are you doingg here?')

[(2, '[CLS] what the [MASK] are you doingg here ? [SEP]'),
 (5, '[CLS] what the fcuk are you [MASK] here ? [SEP]')]

In [175]:
# Check how the BERT tokenizer splits a sentence containing the [CLS], [MASK] and [SEP] markers.
toks = bert_tokenizer.tokenize('[CLS] what the [MASK] are you doing here ? [SEP]')
toks

['[CLS]', 'what', 'the', '[MASK]', 'are', 'you', 'doing', 'here', '?', '[SEP]']

In [176]:
import torch
def get_bert_top_preds(tokenizer, bert_tokenizer, sent, k):
    """Suggest the ``k`` best masked-LM replacements for every OOV token in ``sent``.

    Each OOV token (as reported by ``get_bert_masked_inputs``) is masked in
    turn and the module-level ``bert_model_mlm`` is asked to fill the mask.

    Args:
        tokenizer: word tokenizer forwarded to ``get_bert_masked_inputs``.
        bert_tokenizer: BERT tokenizer used to tokenize the masked sentence
            and to convert between tokens and ids.
        sent: the raw input sentence.
        k: number of candidates to return per OOV token.

    Returns:
        A list of ``(pos, candidates)`` tuples where ``pos`` is the position
        of the OOV token with respect to ``tokenizer`` and ``candidates`` is
        the result of ``bert_tokenizer.convert_ids_to_tokens`` on the ``k``
        highest-scoring ids, best first.
    """
    # Create the inputs on whatever device the model lives on, so the function
    # keeps working if the model is moved to GPU.
    device = next(bert_model_mlm.parameters()).device

    res = []
    for pos, text in get_bert_masked_inputs(tokenizer, bert_tokenizer, sent):
        # `pos` counts tokens of the original tokenizer. The BERT tokenizer may
        # split the sentence differently, so [MASK] has to be located again in
        # the BERT tokenization.
        toks = bert_tokenizer.tokenize(text)
        assert '[MASK]' in toks
        pos_bert = toks.index('[MASK]')
        print(pos, text)

        tok_ids = torch.tensor([bert_tokenizer.convert_tokens_to_ids(toks)],
                               dtype=torch.long, device=device)
        # Single-sentence input: every token gets segment id 0.
        seg_ids = torch.zeros_like(tok_ids)

        # Inference only, so skip building the autograd graph.
        with torch.no_grad():
            scores = bert_model_mlm(tok_ids, seg_ids)

        # Select the masked position explicitly instead of squeezing: a blanket
        # squeeze() would also drop the top-k axis when k == 1 and leave a
        # scalar where a sequence of ids is expected.
        top_ids = torch.topk(scores[0, pos_bert], k=k)[1].tolist()
        res.append((pos, bert_tokenizer.convert_ids_to_tokens(top_ids)))

    return res
    

In [177]:
# Example: top-5 BERT suggestions for each token that is missing from the BERT vocabulary.
get_bert_top_preds(tokenizer, bert_tokenizer, 'What the fcuk are you doingg here?', 5)

2 [CLS] what the [MASK] are you doingg here ? [SEP]
5 [CLS] what the fcuk are you [MASK] here ? [SEP]


[(2, ['hell', 'fuck', 'heck', 'devil', '...']),
 (5, ['doing', 'doin', 'saying', 'thinking', 'wearing'])]