In [1]:
import torch

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMaskedLM

#torch_device=torch.device('cuda')
torch_device=torch.device('cpu')

bert_model_mlm = BertForMaskedLM.from_pretrained('bert-base-uncased')
bert_model_mlm.eval()
bert_model_mlm.to(torch_device)

for param in bert_model_mlm.parameters():
    param.requires_grad = False

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

bert_id2tok = dict()
for tok, tok_id in bert_tokenizer.vocab.items():
    bert_id2tok[tok_id] = tok
    

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
import spacy
spacy_nlp=spacy.load('en_core_web_sm')

In [3]:
#Parameters

MAX_BERT_LEN=256
MAX_COSINE_DIST=0.3
BERT_VOCAB_QTY=30000

num_threads=8
K=10

In [4]:
import numpy as np

ft_compiled_path = "../data/jigsaw/ft_model_bert_basic_tok.npy" # Embeddings generated from the vocabulary
fasttext_embeds = np.load(ft_compiled_path)

In [5]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from pytorch_pretrained_bert.tokenization import BasicTokenizer
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer
import re

#_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words
_bert_tok = BasicTokenizer(do_lower_case=True)

spacy_tokenizer = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False)

from allennlp.data.token_indexers import SingleIdTokenIndexer
token_indexer = SingleIdTokenIndexer(
    lowercase_tokens=True,
)

from itertools import groupby

def remove_url(s):
    return re.sub(r"http\S+", "", s)

def remove_extra_chars(s, max_qty=2):
    res = [c * min(max_qty, len(list(group_iter))) for c, group_iter in groupby(s)] 
    return ''.join(res)

def tokenizer(x: str):
    return [remove_extra_chars(w) for w in _bert_tok.tokenize(remove_url(x))]
    #return [w.text for w in _spacy_tok(x.lower())]

In [6]:
tokenizer("This is a test don't do this a thome!")

['this', 'is', 'a', 'test', 'don', "'", 't', 'do', 'this', 'a', 'thome', '!']

In [7]:
"n't" in bert_tokenizer.vocab

False

In [8]:
len(bert_tokenizer.vocab)

30522

In [9]:
bert_id2tok[2]

'[unused1]'

In [10]:
bert_tokenizer.vocab['sh']

14021

In [11]:
# Returns arrays of arrays if there's an OOV word or an empty array instead
# Each array element is a tuple: 
# position of OOV word (with respect to the original tokenizer), sent for BERT tokenizer
def get_bert_masked_inputs(toks, bert_tokenizer):
    res = []
    
    oov_pos = []
    bert_vocab = bert_tokenizer.vocab
    
    for i in range(len(toks)):
        if toks[i] not in bert_vocab:
            oov_pos.append(i)
            

    for pos in oov_pos:
        res.append( (pos, '[CLS] %s [MASK] %s [SEP]' % 
                     (' '.join(toks[0:pos]), ' '.join(toks[pos+1:])) ) )
        
    return res

In [12]:
#get_bert_masked_inputs(tokenizer('This is a *strangge* sentence.'), bert_tokenizer)

In [13]:
toks = bert_tokenizer.tokenize('[CLS] detete what the [MASK] are you doing here ? [SEP]')
toks

['[CLS]',
 'det',
 '##ete',
 'what',
 'the',
 '[MASK]',
 'are',
 'you',
 'doing',
 'here',
 '?',
 '[SEP]']

In [14]:
toks = spacy_tokenizer.split_words("[CLS] what the [MASK] are you don't here ? [SEP]")
toks

[[, CLS, ], what, the, [, MASK, ], are, you, do, n't, here, ?, [, SEP, ]]

In [17]:
doc = spacy_nlp("[CLS] what the [MASK] are you don't here sh#t fcuk? [SEP]")
print([token.text for token in doc])


['[', 'CLS', ']', 'what', 'the', '[', 'MASK', ']', 'are', 'you', 'do', "n't", 'here', 'sh#t', 'fcuk', '?', '[', 'SEP', ']']


In [18]:
qty=0
for w in spacy_nlp.vocab:
    if w.text.find('\n') >=0:
        print('##%s##' % w.text)
        qty +=1
print(qty)

##
##
1


In [19]:
len(spacy_nlp.vocab)

57852

In [20]:
tokenizer("don't  couldn't can't you're I'm sheeeet")

['don',
 "'",
 't',
 'couldn',
 "'",
 't',
 'can',
 "'",
 't',
 'you',
 "'",
 're',
 'i',
 "'",
 'm',
 'sheet']

In [21]:
bert_tokenizer.tokenize("don't  couldn't can't you're I'm fcuk")

['don',
 "'",
 't',
 'couldn',
 "'",
 't',
 'can',
 "'",
 't',
 'you',
 "'",
 're',
 'i',
 "'",
 'm',
 'fc',
 '##uk']

In [22]:
bert_tokenizer.tokenize("tesst")

['tess', '##t']

In [23]:
tokenizer("You ' re right. It ' s a miracle! You'd been deceived!") # 've', 're', 's', 'd', 'll'

['you',
 "'",
 're',
 'right',
 '.',
 'it',
 "'",
 's',
 'a',
 'miracle',
 '!',
 'you',
 "'",
 'd',
 'been',
 'deceived',
 '!']

In [24]:
from collections import namedtuple
# pos_oov is OOV index with respect to the original (not BERT) tokenizer!!!
UtterData = namedtuple('SentData', ['batch_sent_id', 'sent_pos_oov', 'bert_pos_oov', 'tok_ids', 'oov_token'])

def get_batch_data(torch_device, bert_tokenizer, sent_list, max_len=MAX_BERT_LEN):
    
    batch_data_raw = []
    batch_max_seq_qty = 0
    batch_sent_id = -1
    for sent_toks in sent_list:
        batch_sent_id += 1
        for sent_pos_oov, text in get_bert_masked_inputs(sent_toks, bert_tokenizer):
            # To accurately get what is the position of [MASK] according
            # to BERT tokenizer, we need to re-tokenize the sentence using
            # the BERT tokenizer
            all_bert_toks = bert_tokenizer.tokenize(text)
            bert_toks = all_bert_toks[0:max_len] # 512 is the max. Bert seq. length

            tok_ids = bert_tokenizer.convert_tokens_to_ids(bert_toks)
            bert_pos_oov = None
            for i in range(len(bert_toks)):
                if bert_toks[i] == '[MASK]':
                    bert_pos_oov = i
                    break
            assert(bert_pos_oov is not None or len(all_bert_toks) > max_len)
            if bert_pos_oov is not None:
                tok_qty = len(tok_ids)
                batch_max_seq_qty = max(batch_max_seq_qty, tok_qty)
                batch_data_raw.append( 
                    UtterData(batch_sent_id=batch_sent_id, 
                              sent_pos_oov=sent_pos_oov, 
                              bert_pos_oov=bert_pos_oov,
                              tok_ids=tok_ids, 
                              oov_token=sent_toks[sent_pos_oov]))
            
    batch_qty = len(batch_data_raw)
    tok_ids_batch = np.zeros( (batch_qty, batch_max_seq_qty), dtype=np.int64) # zero is a padding symbol
    tok_mask_batch = np.zeros( (batch_qty, batch_max_seq_qty), dtype=np.int64)
    for k in range(batch_qty):
        tok_ids = batch_data_raw[k].tok_ids
        tok_qty = len(tok_ids)
        tok_ids_batch[k, 0:tok_qty] = tok_ids
        tok_mask_batch[k, 0:tok_qty] = np.ones(tok_qty)
                   
    tok_ids_batch = torch.from_numpy(tok_ids_batch).to(device=torch_device) 
    
    return batch_data_raw, tok_ids_batch, tok_mask_batch

In [25]:
import torch
import numpy as np
from collections import namedtuple
import sys

BertPredProbs = namedtuple('BertPred', ['batch_sent_id', 'sent_pos_oov', 'bert_pos_oov', 'logits'])

def get_bert_preds_for_words_batch(torch_device, bert_model_mlm, 
                                   batch_data_raw, tok_ids_batch, tok_mask_batch, # comes from get_batch_data
                                   word_ids # a list of IDS for which we generate logits
                                   ):

    seg_ids = torch.zeros_like(tok_ids_batch, device=torch_device)
    
    batch_qty = len(batch_data_raw)
    
    # Main BERT model see modeling.py in https://github.com/huggingface/pytorch-pretrained-BERT
    bert = bert_model_mlm.bert 
    # cls is an instance of BertOnlyMLMHead (see https://github.com/huggingface/pytorch-pretrained-BERT)
    cls = bert_model_mlm.cls
    # predictions are of the type BertLMPredictionHead (see https://github.com/huggingface/pytorch-pretrained-BERT)
    predictions = cls.predictions
    transform = predictions.transform
   
    # We don't use the complete decoding matrix, but only selected rows
    word_ids = torch.from_numpy(np.array(word_ids, dtype=np.int64)).to(device=torch_device)
    tok_mask_batch = torch.from_numpy(np.array(tok_mask_batch, dtype=np.int64)).to(device=torch_device)
                                
    weight = predictions.decoder.weight[word_ids,:]
    bias = predictions.bias[word_ids]

    # Transformations from the main BERT model
    sequence_output, _= bert(tok_ids_batch, seg_ids, attention_mask=tok_mask_batch, output_all_encoded_layers=False)
    # Transformations from the BertLMPredictionHead model with the restricted last layer
    hidden_states = transform(sequence_output)    
    logits = torch.nn.functional.linear(hidden_states, weight) + bias                            
                                        
    logits=logits.detach().cpu().numpy()
    
    res = []
    
    for k in range(batch_qty):
        e = batch_data_raw[k]
        bert_pos_oov = e.bert_pos_oov         
        res.append( BertPredProbs(batch_sent_id = batch_data_raw[k].batch_sent_id,
                                  bert_pos_oov = bert_pos_oov,
                                  sent_pos_oov = e.sent_pos_oov,
                                  logits = logits[k, bert_pos_oov]
                            ) 
                  )
        
    return res

In [29]:
bert_tokenizer.convert_tokens_to_ids(['hell', 'fuck', 'heck', 'devil', 'shit', 'doing', 
                                      'doin', 'wearing', 'making', 'thinking', 'all', 'tess', '##t'])

[3109,
 6616,
 17752,
 6548,
 4485,
 2725,
 24341,
 4147,
 2437,
 3241,
 2035,
 15540,
 2102]

In [30]:
bert_tokenizer.tokenize("shiiit shit crap flllyyy")

['shi', '##ii', '##t', 'shit', 'crap', 'fl', '##lly', '##y', '##y']

In [31]:
sent_list = ['you are just flailing making up shiiit on the flllyyy .'.split()]

batch_data_raw, tok_ids_batch, tok_mask_batch = get_batch_data(torch_device, 
                                                bert_tokenizer, 
                                                sent_list,
                                                MAX_BERT_LEN)

for s in batch_data_raw:
    print(s)


bert_ids_to_make_pred = bert_tokenizer.convert_tokens_to_ids(
    ['fl', '##ailing', 'flaming', 'shi', '##ii', '##t', 'shit', 'crap', 'fl', '##lly', '##y', '##y', 'fly'])    
    
get_bert_preds_for_words_batch(torch_device,
                               bert_model_mlm, 
                               batch_data_raw, tok_ids_batch, tok_mask_batch,
                               bert_ids_to_make_pred)

SentData(batch_sent_id=0, sent_pos_oov=3, bert_pos_oov=4, tok_ids=[101, 2017, 2024, 2074, 103, 2437, 2039, 11895, 6137, 2102, 2006, 1996, 13109, 9215, 2100, 2100, 1012, 102], oov_token='flailing')
SentData(batch_sent_id=0, sent_pos_oov=6, bert_pos_oov=8, tok_ids=[101, 2017, 2024, 2074, 13109, 29544, 2437, 2039, 103, 2006, 1996, 13109, 9215, 2100, 2100, 1012, 102], oov_token='shiiit')
SentData(batch_sent_id=0, sent_pos_oov=9, bert_pos_oov=13, tok_ids=[101, 2017, 2024, 2074, 13109, 29544, 2437, 2039, 11895, 6137, 2102, 2006, 1996, 103, 1012, 102], oov_token='flllyyy')


[BertPred(batch_sent_id=0, sent_pos_oov=3, bert_pos_oov=4, logits=array([-2.2550676 , -6.9771633 , -2.9452157 , -6.796351  , -5.1999493 ,
        -0.16598895, -1.003738  , -3.052899  , -2.2550676 , -1.2968991 ,
        -0.26938516, -0.26938516, -2.8959308 ], dtype=float32)),
 BertPred(batch_sent_id=0, sent_pos_oov=6, bert_pos_oov=8, logits=array([-0.14233172, -5.738283  , -2.9242346 , -5.991565  , -7.521471  ,
        -1.7374325 ,  4.7143626 ,  2.8124328 , -0.14233172, -3.4299695 ,
        -0.6083918 , -0.6083918 , -0.99990255], dtype=float32)),
 BertPred(batch_sent_id=0, sent_pos_oov=9, bert_pos_oov=13, logits=array([-2.0155    , -2.015716  , -1.2033051 , -1.7736413 , -2.8350475 ,
        -1.4686524 ,  2.0290556 ,  0.31409857, -2.0155    , -1.7462821 ,
        -1.1763945 , -1.1763945 ,  4.2806478 ], dtype=float32))]

In [32]:
#np.mean([-1.6624606, -11.621088 ,  -3.8026168])

In [36]:
sent_list = ['What the fcuk are you doingg here ?'.lower().split(),
             'This is a *strangge* sentence'.lower().split(),
             'This is a bbbbbbest'.lower().split()]

batch_data_raw, tok_ids_batch, tok_mask_batch = get_batch_data(torch_device, 
                                                bert_tokenizer, 
                                                sent_list,
                                                MAX_BERT_LEN)

for s in batch_data_raw:
    print(s)


bert_ids_to_make_pred = bert_tokenizer.convert_tokens_to_ids(
    ['hell', 'fuck', 'heck', 'devil', 'shit', 'doing', 'doin', 
     'thinking', 'strange', 'rest', 'test', 'tess', '##t'])    
    
get_bert_preds_for_words_batch(torch_device,
                               bert_model_mlm, 
                               batch_data_raw, tok_ids_batch, tok_mask_batch,
                               bert_ids_to_make_pred)

SentData(batch_sent_id=0, sent_pos_oov=2, bert_pos_oov=3, tok_ids=[101, 2054, 1996, 103, 2024, 2017, 2725, 2290, 2182, 1029, 102], oov_token='fcuk')
SentData(batch_sent_id=0, sent_pos_oov=5, bert_pos_oov=7, tok_ids=[101, 2054, 1996, 4429, 6968, 2024, 2017, 103, 2182, 1029, 102], oov_token='doingg')
SentData(batch_sent_id=1, sent_pos_oov=3, bert_pos_oov=4, tok_ids=[101, 2023, 2003, 1037, 103, 6251, 102], oov_token='*strangge*')
SentData(batch_sent_id=2, sent_pos_oov=3, bert_pos_oov=4, tok_ids=[101, 2023, 2003, 1037, 103, 102], oov_token='bbbbbbest')


[BertPred(batch_sent_id=0, sent_pos_oov=2, bert_pos_oov=3, logits=array([ 1.5889928e+01,  1.3050140e+01,  1.1907359e+01,  1.0735191e+01,
         7.7608805e+00,  1.1303816e+00, -3.6249363e-01, -1.2756933e+00,
         1.4766751e-02, -2.3537908e-01, -2.4631200e+00,  4.4421670e-01,
        -9.1704965e-01], dtype=float32)),
 BertPred(batch_sent_id=0, sent_pos_oov=5, bert_pos_oov=7, logits=array([-6.5866718e+00, -3.7194700e+00, -4.6205916e+00, -7.7321978e+00,
        -1.2913905e-03,  2.0477596e+01,  1.1890603e+01,  9.2444258e+00,
        -9.0952718e-01, -5.5600286e+00, -4.1842346e+00, -7.5178404e+00,
        -1.5601634e+00], dtype=float32)),
 BertPred(batch_sent_id=1, sent_pos_oov=3, bert_pos_oov=4, logits=array([-2.5699182 , -2.195084  , -1.8034788 , -4.066658  , -0.28844202,
        -0.7897885 , -0.93698156, -0.15000674,  1.9507252 ,  0.99364173,
         3.9929423 , -1.4381317 , -1.2424338 ], dtype=float32)),
 BertPred(batch_sent_id=2, sent_pos_oov=3, bert_pos_oov=4, logits=array([-3.16

In [None]:
#bert_model_mlm.to('cpu')
#torch.zeros(3,device=torch.device("cuda"))

In [37]:
bert_tokenizer.tokenize('б')

['б']

In [38]:
from typing import *
from overrides import overrides
import allennlp
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer
from allennlp.data.fields import TextField, SequenceLabelField, LabelField, MetadataField, ArrayField
class MemoryOptimizedTextField(TextField):
    @overrides
    def __init__(self, tokens: List[str], token_indexers: Dict[str, TokenIndexer]) -> None:
        self.tokens = tokens
        self._token_indexers = token_indexers
        self._indexed_tokens: Optional[Dict[str, TokenList]] = None
        self._indexer_name_to_indexed_token: Optional[Dict[str, List[str]]] = None
        # skip checks for tokens
    @overrides
    def index(self, vocab):
        super().index(vocab)
        self.tokens = None # empty tokens


In [39]:
def count_lowercase(s):
  return sum([int(c == c.lower()) for c in s])


def get_canon_case_map(nlp):
  dict_tmp = dict()

  for t in nlp.vocab:
    dict_tmp[t.text.lower()] = set()

  for t in nlp.vocab:
    dict_tmp[t.text.lower()].add(t.text)

  dict_map = dict()

  for key, tset in dict_tmp.items():
    lst = [(count_lowercase(s), s) for s in tset]
    lst.sort(reverse=True)
    choice_str = lst[0][1]
    dict_map[key] = choice_str

  return dict_map

In [40]:
import numpy as np

# returned word list can contain upper-case letters
def read_embeds_and_words_subset(file_name, word_map):
  word_list, embed_list = [], []

  with open(file_name, encoding="utf8") as f:
    for line in f:
      line = line.strip()
      if not line:
        continue
      fld = line.split()
      w = fld[0]
      if w in word_map:
        word_list.append(word_map[w])
        embed_list.append(np.array([float(s) for s in fld[1:]]))

  return word_list, np.vstack(embed_list)

In [41]:
spacy_nlp = spacy.load("en_core_web_sm",
                     disable=['parser', 'ner', 'pos'])

spacy_word_map = get_canon_case_map(spacy_nlp)

In [42]:
list(spacy_nlp("This haven't true."))

[This, have, n't, true, .]

In [70]:
 "shiit" in spacy_nlp.vocab

False

In [44]:
import os

THIS FUNCTION IS WRONG, IT NEEDS TO CREATE GLOBAL WORD->EMBED MAPS, but later we need to index
only SPACY embeds.

word_arr, embed_arr = read_embeds_and_words_subset(os.path.join('../data/jigsaw/', 'ft_basic_toks.txt'), 
                                                   spacy_word_map)
print('Read %d spacy words from the fasttext-dictionary file' % len(word_arr))

Read 49601 spacy words from the fasttext-dictionary file


In [54]:
word_to_embed_map = { word_arr[i].lower() : embed_arr[i] for i in range(len(word_arr)) }

In [71]:
'gillette' in spacy_word_map
word_to_embed_map['shiit']

KeyError: 'shiit'

In [56]:
import nmslib, time

def create_embed_index(embeds, M = 30, efC = 200, efS = 200):
  num_threads = 0
  index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC, 'post': 0}
  print('Index-time parameters', index_time_params)

  # Space name should correspond to the space name
  # used for brute-force search
  space_name = 'cosinesimil'

  # Intitialize the library, specify the space, the type of the vector and add data points
  index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR)
  index.addDataPointBatch(embeds)

  # Create an index
  start = time.time()
  index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC}
  index.createIndex(index_time_params)
  end = time.time()
  print('Index-time parameters', index_time_params)
  print('Indexing time = %f' % (end - start))

  # Setting query-time parameters
  query_time_params = {'efSearch': efS}
  print('Setting query-time parameters', query_time_params)
  index.setQueryTimeParams(query_time_params)

  return index


def create_word_index(words, M = 30, efC = 200, efS = 200):
  num_threads = 0
  index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC, 'post': 0}
  print('Index-time parameters', index_time_params)

  # Space name should correspond to the space name
  # used for brute-force search
  space_name = 'leven'

  # Intitialize the library, specify the space, the type of the vector and add data points
  index = nmslib.init(method='hnsw', space=space_name, dtype=nmslib.DistType.INT, data_type=nmslib.DataType.OBJECT_AS_STRING)
  index.addDataPointBatch(list(words))

  # Create an index
  start = time.time()
  index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC}
  index.createIndex(index_time_params)
  end = time.time()
  print('Index-time parameters', index_time_params)
  print('Indexing time = %f' % (end - start))

  # Setting query-time parameters
  query_time_params = {'efSearch': efS}
  print('Setting query-time parameters', query_time_params)
  index.setQueryTimeParams(query_time_params)

  return index

def query_index(index, K, query_arr, num_threads=0, efS=200):
  # Querying
  query_time_params = {'efSearch': efS}
  print('Setting query-time parameters', query_time_params)
  index.setQueryTimeParams(query_time_params)

  query_qty = len(query_arr)
  start = time.time()
  res = index.knnQueryBatch(query_arr, k=K, num_threads=num_threads)
  end = time.time()
  print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' %
        (end - start, float(end - start) / query_qty, num_threads * float(end - start) / query_qty))

  return res

In [57]:
SURFACE_SIMIL = 'surface'
GLOBAL_SIMIL = 'global'
CONTEXT_SIMIL = 'context'

def add_pooled_res(res, word, simil, res_key):
    if not word in res:
        res[word] = dict()
    res[word][res_key] = simil


def get_pooled_neighbors(embed_index, word_index, 
                     word_arr, word_to_embed_map,
                     K, query_word, num_threads=0, efS=200):
    tmp_res = dict()

    ids_word, dists = query_index(word_index, 2*K, [query_word], num_threads=num_threads, efS=efS)[0]
    qty_word = len(ids)
    for i in range(qty_word):
        w_res = word_arr[ids_word[i]]
        max_len = float(max(len(w_res), len(query_word)))
        simil = 1 - float(dists[i])/max_len
        add_pooled_res(tmp_res, w_res, simil, SURFACE_SIMIL)
        
    ids_embed, dists = query_index(embed_index, 2*K, np.array([word_to_embed_map[query_word]]), 
                                  num_threads=num_threads, efS=efS)[0]
    qty_embed = len(ids)
    for i in range(qty_embed):
        w_res = word_arr[ids_embed[i]]
        simil = 1 - 0.5 * dists[i]
        add_pooled_res(tmp_res, w_res, simil, GLOBAL_SIMIL)
    
    final_res = dict()
    
    i_word = 0
    i_embed = 0
    while (i_word < qty_word or i_embed < qty_embed) and len(final_res) < K:
        if i_word < qty_word:
            w_res = word_arr[ids_word[i_word]]
            if not w_res in seen:
                final_res[w_res] = tmp_res[w_res]                
            i_word += 1
        if i_embed < qty_embed:
            w_res = word_arr[ids_embed[i_embed]]
            if not w_res in seen:
                final_res[w_res] = tmp_res[w_res] 
            i_embed += 1
    
    return final_res

In [58]:
from allennlp.data.fields import  ArrayField, MetadataField, TextField
from allennlp.data.token_indexers import TokenIndexer
from overrides import overrides

class MemoryOptimizedTextField(TextField):

  @overrides
  def __init__(self, tokens: List[str], token_indexers: Dict[str, TokenIndexer]) -> None:
    self.tokens = tokens
    self._token_indexers = token_indexers
    self._indexed_tokens: Optional[Dict[str, TokenList]] = None
    self._indexer_name_to_indexed_token: Optional[Dict[str, List[str]]] = None
    # skip checks for tokens

  @overrides
  def index(self, vocab):
    super().index(vocab)
    self.tokens = None  # empty tokens

In [59]:
from typing import *
from allennlp.data import Instance
from allennlp.data.fields import  ArrayField, MetadataField, TextField

#data_vocab_path = "../data/jigsaw/data_vocab.bin"
#vocab=pickle.load(open(data_vocab_path,'rb'))

def get_spacy_vocab_instances(nlp) -> Iterator[Instance]:
  words = set([t.text.lower().strip() for t in nlp.vocab])

  fields = {}

  for w in words:
    w = w.strip()
    if w and w.find('\n') < 0:
      fields["tokens"] = MemoryOptimizedTextField([w], {"tokens": SingleIdTokenIndexer()})
      yield Instance(fields)


In [60]:
spacy_word_arr = []
spacy_embed_arr = []

from allennlp.data.vocabulary import Vocabulary

vocab = Vocabulary.from_instances(get_spacy_vocab_instances(spacy_nlp))

for idx, token in vocab.get_index_to_token_vocabulary().items():
    if idx > 1:
        spacy_word_arr.append(token)
        spacy_embed_arr.append(word_to_embed_map[token])
    

49584it [00:00, 68003.04it/s]


In [None]:
print(vocab)

In [61]:
for idx, token in vocab.get_index_to_token_vocabulary().items():
    if idx > 5: break
    print(token)

@@PADDING@@
@@UNKNOWN@@
arbitration
actuary
recalculated
59


In [62]:
embed_index = create_embed_index(embed_arr)

Index-time parameters {'M': 30, 'indexThreadQty': 0, 'efConstruction': 200, 'post': 0}
Index-time parameters {'M': 30, 'indexThreadQty': 0, 'efConstruction': 200}
Indexing time = 15.615927
Setting query-time parameters {'efSearch': 200}


In [63]:
word_index = create_word_index(word_arr)

Index-time parameters {'M': 30, 'indexThreadQty': 0, 'efConstruction': 200, 'post': 0}
Index-time parameters {'M': 30, 'indexThreadQty': 0, 'efConstruction': 200}
Indexing time = 11.862540
Setting query-time parameters {'efSearch': 200}


In [64]:
res = query_index(word_index, 10, ['shiiit'])

Setting query-time parameters {'efSearch': 200}
kNN time total=0.003677 (sec), per query=0.003677 (sec), per query adjusted for thread number=0.000000 (sec)


In [65]:
res

[(array([ 2304,  5072,  6422, 23011,   320, 32017, 36871, 40168, 40247,
         40869], dtype=int32),
  array([2, 2, 2, 2, 2, 3, 3, 3, 3, 3], dtype=int32))]

In [66]:
for ids, dists in res:
    qty = len(ids)
    for i in range(qty):
        print(word_arr[ids[i]], dists[i])

spirit 2
shift 2
shirt 2
shrift 2
shit 2
chitin 3
Zhiyi 3
Zhimin 3
Bailit 3
Zhilin 3


In [68]:
res = query_index(embed_index, 10, np.array([word_to_embed_map['shiit']]))

KeyError: 'shiiit'

In [None]:
for ids, dists in res:
    qty = len(ids)
    for i in range(qty):
        print(word_arr[ids[i]], dists[i])

In [None]:
test_file = "../data/jigsaw/test_basic.jsonl"

In [None]:
import time, gc
import copy
from scipy.special import softmax


DEBUG_PRINT=False


all_src_sents = []

with open(test_file) as f
  for line in f:
    obj = json.loads(line)
    all_src_sents.append(obj["tokens"])
    
t0 = time.time()
preds = []

all_dst_sents = []

batch_qty_step = 20

for batch_start_sent_id in range(0, len(all_src_sents), batch_qty_step):
    print('Batch start', batch_start_sent_id)

    batch_qty = min(batch_qty_step, len(all_src_sents) - batch_start_sent_id)

    batch_sents = [all_src_sents[k] for k in range(batch_start_sent_id,
                                                   batch_start_sent_id + batch_qty)]

    # batch_data raw contains elements
    # UtterData = namedtuple('SentData', ['batch_sent_id', 'sent_pos_oov', 'bert_pos_oov', tok_ids', 'oov_token')
    #
    # tok_ids_batch is a Tensor with padded Bert-specific token IDs ready
    # to be fed into a BERT model
    batch_data_raw, tok_ids_batch = get_batch_data(torch_device,
                                                 bert_tokenizer,
                                                 batch_sents, 
                                                 MAX_BERT_LEN)

    query_arr = [e.oov_token for e in batch_data_raw]
    
    


    neighb_tok_ids = []
    
    preds = get_bert_preds_for_words_batch(torch_device,
                                           bert_model_mlm,
                                           batch_data_raw, tok_ids_batch,
                                           neighb_tok_ids)

    assert(len(preds) == query_qty)
    for qid in range(query_qty):
        e = batch_data_raw[qid]
        glob_sent_id = batch_start_sent_id + e.batch_sent_id
        assert(batch_sents[e.batch_sent_id] == all_src_sents[glob_sent_id])
        if is_apost_token(e.oov_token) or e.oov_token == "n't":
            # Thing's like "I don't" or "You're" are tokenized as do "I do n't" or "You 're'"
            pass # TODO fix this
        elif query_tok_oov_id[qid] > 1:  
            # Let's map neighbor IDs from each queries to respective 
            # logits from the prediction set
            logit_map = dict() # from Bert-specific token IDs to predicted logits
            assert(len(preds[qid].logits) == len(neighb_tok_ids))
            for i in range(len(neighb_tok_ids)):
                logit_map[neighb_tok_ids[i]] = preds[qid].logits[i]

            e = batch_data_raw[qid]
            if DEBUG_PRINT: 
                print(all_src_sents[glob_sent_id])
                print("### OOV ###", e.oov_token)
                print([bert_id2tok[bert_tok_id] for bert_tok_id in nbrs[qid][0]])

            nbrs_sel_logits = []
            nbrs_sel_toks = []
            nbrs_sel_dists = []

            nbrs_ids = nbrs[qid][0]
            nbrs_dist = nbrs[qid][1]

            #print('Logit map:', logit_map)
            #print('neighb_tok_ids', neighb_tok_ids)

            nqty = len(nbrs_ids)
            for t in range(nqty):
                bert_tok_id = nbrs_ids[t]
                # nid is Bert-speicifc token ID
                if not bert_tok_id in neighb_tok_ids:
                    if DEBUG_PRINT: 
                        print('Missing %s distance %g ' 
                              % (bert_id2tok[bert_tok_id],
                                 nbrs_dist[t]))
                else:
                    if nbrs_dist[t] < MAX_COSINE_DIST:
                        nbrs_sel_logits.append(logit_map[bert_tok_id])
                        nbrs_sel_toks.append(bert_id2tok[bert_tok_id]) 
                        nbrs_sel_dists.append(nbrs_dist[t])

            if nbrs_sel_logits:
                nbrs_softmax = softmax(np.array(nbrs_sel_logits))
                nbrs_simil = 1 - np.array(nbrs_sel_dists)
                nbrs_simil_adj = nbrs_softmax * nbrs_simil 

                best_tok_id = np.argmax(nbrs_simil_adj)

                #print("batch sent id:",e.batch_sent_id, e.pos_oov, best_tok_id)
                #print(replace_dict[e.batch_sent_id])
                assert(not e.pos_oov in replace_dict[e.batch_sent_id])
                replace_dict[e.batch_sent_id][e.pos_oov] = nbrs_sel_toks[best_tok_id]

                if DEBUG_PRINT: 
                    print('Selected info, best_tok:', nbrs_sel_toks[best_tok_id])
                    for k in range(len(nbrs_sel_logits)):
                        print(nbrs_sel_toks[k], nbrs_softmax[k], 
                              nbrs_sel_dists[k], nbrs_simil_adj[k])
            else:
                if DEBUG_PRINT: print('Nothing found!')

            #if DEBUG_PRINT: print(preds[qid])
            if DEBUG_PRINT: 
                print("====================================================================")



    #gc.collect()
    #torch.cuda.empty_cache()
    for k in range(0, batch_qty):
        src_sent = batch_sents[k]
        rd = replace_dict[k]
        #print('Replacement dict:', rd)
        dst_sent = replace_by_patterns(tokenizer, src_sent, rd)
        all_dst_sents.append(dst_sent)
        if DEBUG_PRINT:
            print("====================================================================")
            print('Replacement dict:', rd)
            print(src_sent)
            print('------------')
            print(dst_sent)
            print("====================================================================")

    #break

t1 = time.time()
print('# of src sentences:', len(all_src_sents), 
      "# of dst sentences:", len(all_dst_sents),
      ' time elapsed:', t1 - t0)
src_data['comment_text'] = all_dst_sents
src_data.to_csv(dst_file, index=False)

In [None]:
#src_data[src_data["toxic"]==1].head(20)

In [None]:
#fl = pd.read_csv(src_file)

In [None]:
#fl[fl["toxic"]==1].head(20)

In [None]:
'flaming' in nlp.vocab