In [1]:
import torch

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMaskedLM

torch_device=torch.device('cuda')

bert_model_mlm = BertForMaskedLM.from_pretrained('bert-base-uncased')
bert_model_mlm.eval()
bert_model_mlm.to(torch_device)

for param in bert_model_mlm.parameters():
    param.requires_grad = False

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

bert_id2tok = dict()
for tok, tok_id in bert_tokenizer.vocab.items():
    bert_id2tok[tok_id] = tok

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
#Parameters

MAX_BERT_LEN=256
MAX_COSINE_DIST=0.3
BERT_VOCAB_QTY=30000

num_threads=8
K=10

In [3]:
import numpy as np

ft_compiled_path = "../data/jigsaw/ft_model_bert_basic_tok.npy" # Embeddings generated from the vocabulary
fasttext_embeds = np.load(ft_compiled_path)

In [5]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from pytorch_pretrained_bert.tokenization import BasicTokenizer
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer
import re

#_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words
_bert_tok = BasicTokenizer(do_lower_case=True)

spacy_tokenizer = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False)

from allennlp.data.token_indexers import SingleIdTokenIndexer
token_indexer = SingleIdTokenIndexer(
    lowercase_tokens=True,
)

from itertools import groupby

def remove_url(s):
    return re.sub(r"http\S+", "", s)

def remove_extra_chars(s, max_qty=2):
    res = [c * min(max_qty, len(list(group_iter))) for c, group_iter in groupby(s)] 
    return ''.join(res)

def tokenizer(x: str):
    return [remove_extra_chars(w) for w in _bert_tok.tokenize(remove_url(x))]
    #return [w.text for w in _spacy_tok(x.lower())]

In [6]:
"n't" in bert_tokenizer.vocab

False

In [7]:
len(bert_tokenizer.vocab)

30522

In [8]:
bert_id2tok[2]

'[unused1]'

In [9]:
bert_tokenizer.vocab['sh']

14021

In [11]:
toks = bert_tokenizer.tokenize('[CLS] detete what the [MASK] are you doing here ? [SEP]')
toks

['[CLS]',
 'det',
 '##ete',
 'what',
 'the',
 '[MASK]',
 'are',
 'you',
 'doing',
 'here',
 '?',
 '[SEP]']

In [14]:
toks = spacy_tokenizer.split_words("[CLS] what the [MASK] are you don't here ? [SEP]")
toks

[[, CLS, ], what, the, [, MASK, ], are, you, do, n't, here, ?, [, SEP, ]]

In [16]:
tokenizer("don't  couldn't can't you're I'm sheeeet")

['don',
 "'",
 't',
 'couldn',
 "'",
 't',
 'can',
 "'",
 't',
 'you',
 "'",
 're',
 'i',
 "'",
 'm',
 'sheet']

In [29]:
bert_tokenizer.convert_tokens_to_ids(bert_tokenizer.tokenize("[CLS] [MASK] [SEP]"))

[101, 103, 102]

In [31]:
bert_tokenizer.convert_tokens_to_ids(bert_tokenizer.tokenize("mant"))

[2158, 2102]

In [44]:
from collections import namedtuple
# pos_oov is OOV index with respect to the original (not BERT) tokenizer!!!
UtterData = namedtuple('SentData', ['batch_sent_id', 'pos_oov', 'tok_ids', 'oov_token'])


In [109]:
# Sentence list contains an arrays of arrays (without [CLS] and [SEP] tokens)
def create_batch(bert_tokenizer, sent_list): 
    
    batch_data = []
    
    batch_max_seq_qty = 0
    
    for one_sent_tok_arr in sent_list:
    
        bert_toks = ["[CLS]"]
        
        for tok in one_sent_tok_arr:
            bert_toks.extend(bert_tokenizer.tokenize(tok))
            
        bert_toks.append("[SEP]")
        
        one_sent_ids = bert_tokenizer.convert_tokens_to_ids(bert_toks)
        batch_data.append(one_sent_ids)
    
        batch_max_seq_qty = max(batch_max_seq_qty, len(one_sent_ids))
        
    batch_qty = len(batch_data)
    
    tok_ids_batch = np.zeros( (batch_qty, batch_max_seq_qty), dtype=np.int64) # zero is a padding symbol
    tok_mask_batch = np.zeros( (batch_qty, batch_max_seq_qty), dtype=np.int64)
    for k in range(batch_qty):
        tok_ids = batch_data[k]
        tok_qty = len(tok_ids)
        tok_ids_batch[k, 0:tok_qty] = tok_ids
        tok_mask_batch[k, 0:tok_qty] = np.ones(tok_qty)
        
    return tok_ids_batch, tok_mask_batch

[[101, 2023, 2003, 1037, 3231, 102], [101, 2023, 2003, 1037, 27838, 3367, 102]]

In [112]:
import torch
import numpy as np
from collections import namedtuple
import sys


def get_bert_logits_for_words_batch(torch_device, bert_model_mlm, 
                                    tok_ids_batch, tok_mask_batch):
    
    print(tok_ids_batch)

    batch_qty, batch_seq_len = tok_ids_batch.shape
    
    word_ids = list(set([e for r in tok_ids_batch for e in r]))
    
    word_ids_logit_id = { word_ids[i] : i for i in range(len(word_ids)) }
    
    tok_ids_batch = torch.from_numpy(tok_ids_batch).to(device=torch_device) 
    tok_mask_batch = torch.from_numpy(tok_mask_batch).to(device=torch_device) 

    seg_ids = torch.zeros_like(tok_ids_batch, device=torch_device)

    # Main BERT model see modeling.py in https://github.com/huggingface/pytorch-pretrained-BERT
    bert = bert_model_mlm.bert 
    # cls is an instance of BertOnlyMLMHead (see https://github.com/huggingface/pytorch-pretrained-BERT)
    cls = bert_model_mlm.cls
    # predictions are of the type BertLMPredictionHead (see https://github.com/huggingface/pytorch-pretrained-BERT)
    predictions = cls.predictions
    transform = predictions.transform
   
    # We don't use the complete decoding matrix, but only selected rows
    word_ids = torch.from_numpy(np.array(word_ids, dtype=np.int64)).to(device=torch_device)
                                
    weight = predictions.decoder.weight[word_ids,:]
    bias = predictions.bias[word_ids]
    
    #print(bias[:10])


    # Transformations from the main BERT model
    sequence_output, _= bert(tok_ids_batch, 
                             seg_ids, 
                             attention_mask=tok_mask_batch, 
                             output_all_encoded_layers=False)
    # Transformations from the BertLMPredictionHead model with the restricted last layer
    hidden_states = transform(sequence_output)    
    logits_all = torch.nn.functional.linear(hidden_states, weight) + bias                            
                                        
    logits_all=logits_all.detach().cpu().numpy()
    tok_ids_batch = tok_ids_batch.cpu().numpy()
    
    avg_logits = np.zeros(batch_qty)
    logits_res = np.zeros( (batch_qty, batch_seq_len) )
    
    for r in range(batch_qty):

        qty = 0.0
        avg_logit = 0.0
    
        for c in range(batch_seq_len):
            word_id = tok_ids_batch[r, c]
            if word_id > 1:
                logits_res[r, c] = logits_all[r, c, word_ids_logit_id[word_id]]
                avg_logit += logits_res[r, c]
                qty += 1
                
        avg_logits[r] = avg_logit / qty

        
    return avg_logits, logits_res

In [140]:
batch_ids, batch_mask=create_batch(bert_tokenizer, [ 'What the fcuk are you [MASK] here?'.split(), 'What the fcuk are you [MASK] [MASK] here?'.split() ])
print(batch_ids.shape)
get_bert_logits_for_words_batch('cuda', bert_model_mlm, batch_ids, batch_mask)

(2, 12)
[[ 101 2054 1996 4429 6968 2024 2017  103 2182 1029  102    0]
 [ 101 2054 1996 4429 6968 2024 2017  103  103 2182 1029  102]]


(array([9.0451423 , 8.01771679]),
 array([[-5.64911366, 24.19732666, 14.6283226 , -1.78911757, 11.90443134,
         12.60415459, 16.94570732, -2.52056861, 12.2120018 , 23.39019966,
         -6.42677879,  0.        ],
        [-5.69555759, 23.93379021, 14.67827797, -1.88091636, 12.56620026,
         11.75910091, 17.70313454, -3.14499998, -3.27325583, 13.65651035,
         22.20828056, -6.29796362]]))

In [136]:
batch_ids, batch_mask=create_batch(bert_tokenizer, [ 'this is the best'.split()])
get_bert_logits_for_words_batch('cuda', bert_model_mlm, batch_ids, batch_mask)

[[ 101 2023 2003 1996 2190  102]]


(array([8.25289265]),
 array([[-6.02657032, 18.24155807, 19.6713829 , 18.3153038 ,  5.81563187,
         -6.49995041]]))

In [118]:
batch_ids, batch_mask=create_batch(bert_tokenizer, [ 'this is a test'.split(), 'this is a test test test'.split() ])
get_bert_logits_for_words_batch('cuda', bert_model_mlm, batch_ids, batch_mask)

[[ 101 2023 2003 1037 3231  102    0    0]
 [ 101 2023 2003 1037 3231 3231 3231  102]]


(array([7.16823435, 9.61663324]),
 array([[-6.08818054,  9.49454212, 20.47395706, 17.58529854,  7.42068577,
         -5.87689686,  0.        ,  0.        ],
        [-6.08461523,  5.79411411, 18.26132011, 18.86880875, 18.63547707,
         17.54462242,  9.32696629, -5.41362762]]))

In [24]:
word_ids = [ 10, 135, -1, 3]

word_ids_logit_id = { word_ids[i] : i for i in range(len(word_ids)) }

word_ids_logit_id

{10: 0, 135: 1, -1: 2, 3: 3}

In [34]:
bert_tokenizer.convert_tokens_to_ids(['hell', 'fuck', 'heck', 'devil', 'shit', 'doing', 
                                      'doin', 'wearing', 'making', 'thinking', 'all', 'test'])

[3109, 6616, 17752, 6548, 4485, 2725, 24341, 4147, 2437, 3241, 2035, 3231]

In [None]:
sent_list = ['What the fcuk are you doingg here?',
             'This is a *strangge* sentence']

batch_data_raw, tok_ids_batch = get_batch_data(torch_device, 
                                                tokenizer, 
                                                bert_tokenizer, 
                                                sent_list,
                                                MAX_BERT_LEN)

get_bert_preds_for_words_batch(torch_device,
                               bert_model_mlm, 
                               batch_data_raw, tok_ids_batch,
                               bert_tokenizer.convert_tokens_to_ids(['hell', 'fuck', 'heck', 'devil', 'shit', 'doing', 
                                                                  'doin', 'wearing', 'making', 'thinking']))