In [1]:
import torch

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMaskedLM

torch_device=torch.device('cuda')

bert_model_mlm = BertForMaskedLM.from_pretrained('bert-base-uncased')
bert_model_mlm.eval()
bert_model_mlm.to(torch_device)

for param in bert_model_mlm.parameters():
    param.requires_grad = False

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

bert_id2tok = dict()
for tok, tok_id in bert_tokenizer.vocab.items():
    bert_id2tok[tok_id] = tok

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
import pickle
import numpy as np

ft_compiled_path = "../data/jigsaw/ft_compiled.npy" # Embeddings generated from the vocabulary
data_vocab_path = "../data/jigsaw/data_vocab.bin"

MAX_BERT_LEN=128
MAX_COSINE_DIST=0.3

num_threads=8
K=10

In [3]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter

_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words

def tokenizer(s: str):
    s = s.lower()
    return [w.text for w in _spacy_tok(s)]

In [4]:
'世' in bert_tokenizer.vocab

True

In [5]:
bert_id2tok[2]

'[unused1]'

In [6]:
bert_tokenizer.vocab['test']

3231

In [7]:
# Returns arrays of arrays if there's an OOV word or an empty array instead
# Each array element is a tuple: 
# position of OOV word (with respect to the original tokenizer), sent for BERT tokenizer
def get_bert_masked_inputs(toks, bert_tokenizer):
    res = []
    
    oov_pos = []
    bert_vocab = bert_tokenizer.vocab
    
    for i in range(len(toks)):
        if toks[i] not in bert_vocab:
            oov_pos.append(i)
            

    for pos in oov_pos:
        res.append( (pos, '[CLS] %s [MASK] %s [SEP]' % 
                     (' '.join(toks[0:pos]), ' '.join(toks[pos+1:])) ) )
        
    return res

In [8]:
get_bert_masked_inputs(tokenizer('This is a *strangge* sentence.'), bert_tokenizer)

[(4, '[CLS] this is a * [MASK] * sentence . [SEP]')]

In [9]:
toks = bert_tokenizer.tokenize('[CLS] what the [MASK] are you doing here ? [SEP]')
toks

['[CLS]', 'what', 'the', '[MASK]', 'are', 'you', 'doing', 'here', '?', '[SEP]']

In [10]:
from collections import namedtuple
# pos_bert is OOV index with respect to the original (not BERT) tokenizer!!!
UtterData = namedtuple('SentData', ['sent_id', 'pos_bert', 'tok_ids', 'oov_token'])

def get_batch_data(torch_device, tokenizer, bert_tokenizer, sent_list, max_len=MAX_BERT_LEN):
    
    batch_data_raw = []
    batch_max_seq_qty = 0
    sent_id = -1
    for sent in sent_list:
        sent_id += 1
        sent_toks = tokenizer(sent)
        for sent_oov_pos, text in get_bert_masked_inputs(sent_toks, bert_tokenizer):
            # To accurately get what is the position of [MASK] according
            # to BERT tokenizer, we need to re-tokenize the sentence using
            # the BERT tokenizer
            all_bert_toks = bert_tokenizer.tokenize(text)
            bert_toks = all_bert_toks[0:max_len] # 512 is the max. Bert seq. length

            tok_ids = bert_tokenizer.convert_tokens_to_ids(bert_toks)
            pos_bert = None
            for i in range(len(bert_toks)):
                if bert_toks[i] == '[MASK]':
                    pos_bert = i
                    break
            assert(pos_bert is not None or len(all_bert_toks) > max_len)
            if pos_bert is not None:
                tok_qty = len(tok_ids)
                batch_max_seq_qty = max(batch_max_seq_qty, tok_qty)
                batch_data_raw.append( 
                    UtterData(sent_id=sent_id, 
                              pos_bert=pos_bert, 
                              tok_ids=tok_ids, 
                              oov_token=sent_toks[sent_oov_pos]) )
            
    batch_qty = len(batch_data_raw)
    tok_ids_batch = np.zeros( (batch_qty, batch_max_seq_qty), dtype=np.int64) # zero is a padding symbol
    for k in range(batch_qty):
        tok_ids = batch_data_raw[k].tok_ids
        tok_ids_batch[k, 0:len(tok_ids)] = tok_ids
        
                   
    tok_ids_batch = torch.from_numpy(tok_ids_batch).to(device=torch_device) 
    
    return batch_data_raw, tok_ids_batch

In [11]:
import torch
import numpy as np
from collections import namedtuple

BertPred = namedtuple('BertPred', ['sent_id', 'pos_bert', 'probs', 'toks'])

def get_bert_top_preds_batch(torch_device, bert_model_mlm, tokenizer, bert_tokenizer, sent_list, k, max_len=128):
    
    batch_data_raw, tok_ids_batch = get_batch_data(torch_device, 
                                                    tokenizer, 
                                                    bert_tokenizer, 
                                                    sent_list,
                                                    max_len)
    seg_ids = torch.zeros_like(tok_ids_batch, device=torch_device)

    batch_qty = len(batch_data_raw)
    
    model_out = bert_model_mlm(tok_ids_batch, seg_ids)
        
    t=torch.topk(torch.nn.functional.softmax(model_out, dim=2), k=k,dim=2)
    
    probs=t[0].detach().cpu().numpy()
    preds=t[1].cpu().numpy()
    
    res = []
    
    for k in range(batch_qty):
        
        pos_bert = batch_data_raw[k].pos_bert         
        res.append( BertPred(sent_id = batch_data_raw[k].sent_id,
                             pos_bert = pos_bert,
                             probs = probs[k, pos_bert],
                             toks = bert_tokenizer.convert_ids_to_tokens(preds[k, pos_bert]) 
                            ) 
                  )
    torch.cuda.empty_cache()
        
    return res

In [12]:
get_bert_top_preds_batch(torch_device,
                         bert_model_mlm, tokenizer, bert_tokenizer, 
                         ['What the fcuk are you doingg here?',
                          'This is a *strangge* sentence'], 5)

[BertPred(sent_id=0, pos_bert=3, probs=array([9.2102903e-01, 5.3823169e-02, 1.7165799e-02, 5.3161602e-03,
        2.7156417e-04], dtype=float32), toks=['hell', 'fuck', 'heck', 'devil', 'shit']),
 BertPred(sent_id=0, pos_bert=7, probs=array([9.9966323e-01, 1.8645465e-04, 3.9865798e-05, 1.3824779e-05,
        1.3223614e-05], dtype=float32), toks=['doing', 'doin', 'wearing', 'making', 'thinking']),
 BertPred(sent_id=1, pos_bert=5, probs=array([0.13034092, 0.07515147, 0.05074421, 0.05045162, 0.03581695],
       dtype=float32), toks=['a', '*', '.', '-', 'b'])]

In [13]:
import torch
import numpy as np
from collections import namedtuple
import sys

BertPredProbs = namedtuple('BertPred', ['sent_id', 'pos_bert', 'logits'])

def get_bert_preds_for_words_batch(torch_device, bert_model_mlm, 
                                   batch_data_raw, tok_ids_batch, # comes from get_batch_data
                                   word_ids, # a list of IDS for which we generate logits
                                   max_len=MAX_BERT_LEN):

    seg_ids = torch.zeros_like(tok_ids_batch, device=torch_device)
    
    batch_qty = len(batch_data_raw)
    
    # Main BERT model see modeling.py in https://github.com/huggingface/pytorch-pretrained-BERT
    bert = bert_model_mlm.bert 
    # cls is an instance of BertOnlyMLMHead (see https://github.com/huggingface/pytorch-pretrained-BERT)
    cls = bert_model_mlm.cls
    # predictions are of the type BertLMPredictionHead (see https://github.com/huggingface/pytorch-pretrained-BERT)
    predictions = cls.predictions
    transform = predictions.transform
   
    # We don't use the complete decoding matrix, but only selected rows
    word_ids = torch.from_numpy(np.array(word_ids, dtype=np.int64)).to(device=torch_device)
                                
    weight = predictions.decoder.weight[word_ids,:]
    bias = predictions.bias[word_ids]

    # Transformations from the main BERT model
    sequence_output, _= bert(tok_ids_batch, seg_ids, attention_mask=None, output_all_encoded_layers=False)
    # Transformations from the BertLMPredictionHead model with the restricted last layer
    hidden_states = transform(sequence_output)    
    logits = torch.nn.functional.linear(hidden_states, weight) + bias                            
                                        
    logits=logits.detach().cpu().numpy()
    
    res = []
    
    for k in range(batch_qty):
        
        pos_bert = batch_data_raw[k].pos_bert         
        res.append( BertPredProbs(sent_id = batch_data_raw[k].sent_id,
                             pos_bert = pos_bert,
                             logits = logits[k, pos_bert]
                            ) 
                  )
                                
    torch.cuda.empty_cache()
        
    return res

In [14]:
bert_tokenizer.convert_tokens_to_ids(['hell', 'fuck', 'heck', 'devil', 'shit', 'doing', 
                                      'doin', 'wearing', 'making', 'thinking', 'all', 'test'])

[3109, 6616, 17752, 6548, 4485, 2725, 24341, 4147, 2437, 3241, 2035, 3231]

In [15]:
sent_list = ['What the fcuk are you doingg here?',
             'This is a *strangge* sentence']

batch_data_raw, tok_ids_batch = get_batch_data(torch_device, 
                                                tokenizer, 
                                                bert_tokenizer, 
                                                sent_list,
                                                MAX_BERT_LEN)

get_bert_preds_for_words_batch(torch_device,
                               bert_model_mlm, 
                               batch_data_raw, tok_ids_batch,
                               bert_tokenizer.convert_tokens_to_ids(['hell', 'fuck', 'heck', 'devil', 'shit', 'doing', 
                                                                  'doin', 'wearing', 'making', 'thinking']))

[BertPred(sent_id=0, pos_bert=3, logits=array([15.889928 , 13.05014  , 11.907355 , 10.735188 ,  7.7608795,
         1.1303823, -0.362491 , -2.6011868, -0.1715025, -1.2756928],
       dtype=float32)),
 BertPred(sent_id=0, pos_bert=7, logits=array([-6.5866728e+00, -3.7194698e+00, -4.6205878e+00, -7.7321987e+00,
        -1.2921356e-03,  2.0477596e+01,  1.1890611e+01,  1.0347941e+01,
         9.2888851e+00,  9.2444267e+00], dtype=float32)),
 BertPred(sent_id=1, pos_bert=5, logits=array([-3.2651033 , -0.14229655, -4.819803  , -4.6907854 , -0.7839051 ,
        -1.4695988 , -3.3453448 , -5.069694  , -2.1900396 , -2.5339675 ],
       dtype=float32))]

In [16]:
#bert_model_mlm.to('cpu')
#torch.zeros(3,device=torch.device("cuda"))

In [17]:
bert_tokenizer.tokenize('б')

['б']

In [18]:
from typing import *
from overrides import overrides
import allennlp
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer
from allennlp.data.fields import TextField, SequenceLabelField, LabelField, MetadataField, ArrayField
class MemoryOptimizedTextField(TextField):
    @overrides
    def __init__(self, tokens: List[str], token_indexers: Dict[str, TokenIndexer]) -> None:
        self.tokens = tokens
        self._token_indexers = token_indexers
        self._indexed_tokens: Optional[Dict[str, TokenList]] = None
        self._indexer_name_to_indexed_token: Optional[Dict[str, List[str]]] = None
        # skip checks for tokens
    @overrides
    def index(self, vocab):
        super().index(vocab)
        self.tokens = None # empty tokens


In [19]:
import pickle
import numpy as np

ft_compiled_path = "../data/jigsaw/ft_compiled.npy" # Embeddings generated from the vocabulary
data_vocab_path = "../data/jigsaw/data_vocab.bin"

In [20]:
fasttext_embeds = np.load(ft_compiled_path)
vocab=pickle.load(open(data_vocab_path,'rb'))

In [21]:
bert_vocab_toks = bert_tokenizer.vocab.keys()
vocab_toks = set( [w for idx, w in vocab.get_index_to_token_vocabulary().items() ])
len(vocab_toks), len(bert_vocab_toks)

(305140, 30522)

In [22]:
# These are BERT vocabulary word IDs in the *MAIN vocabulary!*, not IDs in the BERT vocab.
bert_vocab_ids = []

for tok in bert_vocab_toks:
    tok_id = vocab.get_token_index(tok)
    if tok_id > 1:
        bert_vocab_ids.append(tok_id)
        
bert_vocab_ids = np.array(bert_vocab_ids)
fasttext_embeds[bert_vocab_ids].shape

(22778, 300)

In [23]:
import nmslib, time

M = 30
efC = 200

num_threads = 0
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC, 'post' : 0}
print('Index-time parameters', index_time_params)

# Space name should correspond to the space name 
# used for brute-force search
space_name='cosinesimil'


# Intitialize the library, specify the space, the type of the vector and add data points 
index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR) 
index.addDataPointBatch(fasttext_embeds[bert_vocab_ids], bert_vocab_ids)

# Create an index
start = time.time()
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC}
index.createIndex(index_time_params) 
end = time.time() 
print('Index-time parameters', index_time_params)
print('Indexing time = %f' % (end-start))


Index-time parameters {'M': 30, 'indexThreadQty': 0, 'efConstruction': 200, 'post': 0}
Index-time parameters {'M': 30, 'indexThreadQty': 0, 'efConstruction': 200}
Indexing time = 4.955989


In [24]:
# Setting query-time parameters
efS = 200
K=10
query_time_params = {'efSearch': efS}
print('Setting query-time parameters', query_time_params)
index.setQueryTimeParams(query_time_params)

Setting query-time parameters {'efSearch': 200}


In [25]:
import pickle
with open("../data/jigsaw/val_ds.bin", "rb") as f:
    val_ds = pickle.load(f)

In [26]:
import time, gc
t0 = time.time()
preds = []
sents = [' '.join(t['tokens']) for t in val_ds]
    #print(s)
    #preds.append(get_bert_top_preds(tokenizer, bert_tokenizer, s, 2))
    #preds.append(get_bert_masked_inputs(tokenizer(s), bert_tokenizer, sent))


batch_qty = 25

for sent_id in range(0, len(sents), batch_qty):
    print('Batch start', sent_id)
    
    sent_list = sents[sent_id:sent_id + batch_qty]
    
    # batch_data raw contains elements
    # UtterData = namedtuple('SentData', ['sent_id', 'pos_bert', 'tok_ids', 'oov_token'])
    # NOTE: pos_bert is OOV index with respect to the original (not BERT) tokenizer!!!
    #
    # tok_ids_batch is a Tensor with padded Bert-specific token IDs ready
    # to be fed into a BERT model
    batch_data_raw, tok_ids_batch = get_batch_data(torch_device,
                                                 tokenizer, bert_tokenizer,
                                                 sent_list, 
                                                 MAX_BERT_LEN)
    
    query_arr = []
    query_tok_oov_id = []
    
    for e in batch_data_raw: 
        w = e.oov_token
        wid = vocab.get_token_index(w)
        query_arr.append(fasttext_embeds[wid])
        query_tok_oov_id.append(wid)
        
    query_arr = np.array(query_arr)
    query_matrix = np.array(query_arr)
    query_qty = query_matrix.shape[0]
    
    print('Query matrix shape:', query_matrix.shape)
    
    start = time.time() 
    # nbrs is array of tuples (neighbor array, distance array)
    # For cosine, the distance is 1 - cosine similarity
    # k-NN search returns Bert-specific token IDs
    nbrs = index.knnQueryBatch(query_matrix, k = K, num_threads = num_threads)
    end = time.time() 
    print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' % 
          (end-start, float(end-start)/query_qty, num_threads*float(end-start)/query_qty))
    
    neighb_tok_ids=set()
    
    for qid in range(query_qty):
        if query_tok_oov_id[qid] > 1:
            nbrs_ids = nbrs[qid][0]
            nbrs_dist = nbrs[qid][1]
            
            nqty = len(nbrs_ids)
            for t in range(nqty):
                if nbrs_dist[t] < MAX_COSINE_DIST:
                    neighb_tok_ids.add(nbrs_ids[t])
                
    
    
    neighb_tok_ids = list(neighb_tok_ids)
    
    preds = get_bert_preds_for_words_batch(torch_device,
                                           bert_model_mlm,
                                           batch_data_raw, tok_ids_batch,
                                           neighb_tok_ids)
    
    assert(len(preds) == query_qty)
    for qid in range(query_qty):
        if query_tok_oov_id[qid] > 1:  
            # Let's map neighbor IDs from each queries to respective 
            # logits from the prediction set
            logit_map = dict() # from Bert-specific token IDs to predicted logits
            assert(len(preds[qid].logits) == len(neighb_tok_ids))
            for i in range(len(neighb_tok_ids)):
                logit_map[neighb_tok_ids[i]] = preds[qid].logits[i]

            e = batch_data_raw[qid]
            print(sent_list[e.sent_id])
            print("### OOV ###", e.oov_token)
            print([vocab.get_token_from_index(bert_vocab_ids[i]) for i in nbrs[qid][0]])

            nbr_logits = []
            
            nbrs_ids = nbrs[qid][0]
            nbrs_dist = nbrs[qid][1]
            
            #print('Logit map:', logit_map)
            #print('neighb_tok_ids', neighb_tok_ids)
            
            nqty = len(nbrs_ids)
            for t in range(nqty):
                nid = nbrs_ids[t]
                # nid is Bert-speicifc token ID
                if not nid in neighb_tok_ids:
                    print('Missing %s distance %g ' 
                          % (bert_id2tok[nid],
                            nbrs_dist[t]))
                else:
                    #print()
                    nbr_logits.append(logit_map[nid])


            #print(preds[qid])
            print("====================================================================")

    
    
    gc.collect()
    torch.cuda.empty_cache()
    
    break
    
t1 = time.time()
print('# of sentences:', len(sents), ' time elapsed:', t1 - t0)

Batch start 0
Query matrix shape: (86, 300)
kNN time total=0.024678 (sec), per query=0.000287 (sec), per query adjusted for thread number=0.000000 (sec)


RuntimeError: CUDA error: device-side assert triggered

In [None]:
#torch_device=torch.device('cuda')
#bert_model_mlm.to(torch_device)