In [49]:
import allennlp
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer

# Bound `split_words` method of an AllenNLP spaCy word splitter (model
# `en_core_web_sm`, POS tagging disabled); `tokenizer` further down wraps it.
_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words

In [50]:
from typing import *
from overrides import overrides

In [51]:
import numpy as np

In [52]:
# for papermill
# Parameters cell: module-level defaults for this notebook. Most of them are
# copied into `config` in a later cell; `ft_compiled_path`, `data_vocab_path`
# and `bert_oov_map_path` are used directly as globals by the save cells.
testing = False # set to False when running experiments
debugging = False
seed = 1
use_bt = False
computational_batch_size = 128
batch_size = 128
lr = 4e-3
lr_schedule = "slanted_triangular"
epochs = 5 if not testing else 1 # a single epoch when `testing` is on
hidden_sz = 128
dataset = "jigsaw"
n_classes = 6
max_seq_len = 512 # the tokenizer keeps at most this many tokens per text
download_data = False
ft_model_path = "../data/jigsaw/ft_model.txt" # text embedding file read by get_fasttext_embeddings
ft_compiled_path = "../data/jigsaw/ft_compiled.npy" # Embeddings generated from the vocabulary
data_vocab_path = "../data/jigsaw/data_vocab.bin" # the pickled Vocabulary is written here
max_vocab_size = 400000
dropouti = 0.2
dropoutw = 0.0
dropoute = 0.1
dropoute_max = None
dropoutr = 0.3 # TODO: Implement
val_ratio = 0.0
use_focal_loss = False
focal_loss_alpha = 1.
focal_loss_gamma = 2.
use_augmented = False
freeze_embeddings = True
mixup_ratio = 0.0
discrete_mixup_ratio = 0.0
attention_bias = True
use_attention_aux = False
weight_decay = 0.
bias_init = True
neg_splits = 1
num_layers = 2
rnn_type = "lstm"
rnn_residual = False
pooling_type = "augmented_multipool" # attention or multipool or augmented_multipool
model_type = "standard"
cache_elmo_embeddings = True
use_word_level_features = False
use_sentence_level_features = False
bucket = True
compute_thres_on_test = True
find_lr = False
permute_sentences = False # when True, the tokenizer's output list is shuffled
run_id = "error_analysis"
bert_oov_map_path = "../data/jigsaw/bert_oov_map.bin" # the pickled OOV -> BERT token map is written here

In [53]:
class Config(dict):
    """Settings container: a dict whose entries are mirrored as instance attributes.

    Every keyword given to the constructor, and every pair passed to `set`, is
    stored both as a dictionary item and as an attribute of the same name.
    """

    def __init__(self, **kwargs):
        dict.__init__(self, **kwargs)
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def set(self, key, val):
        """Store `val` under `key` as a dict item and as an attribute."""
        dict.__setitem__(self, key, val)
        setattr(self, key, val)

In [54]:
# TODO: Can we make this play better with papermill?
# Bundle the parameter-cell values into a single `Config` so later cells can read
# them as `config.<name>`. `download_data`, `ft_compiled_path`, `data_vocab_path`
# and `bert_oov_map_path` are NOT copied in here.
config = Config(
    testing=testing,
    debugging=debugging,
    seed=seed,
    use_bt=use_bt,
    computational_batch_size=computational_batch_size,
    batch_size=batch_size,
    lr=lr,
    lr_schedule=lr_schedule,
    epochs=epochs,
    hidden_sz=hidden_sz,
    dataset=dataset,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    ft_model_path=ft_model_path,
    max_vocab_size=max_vocab_size,
    dropouti=dropouti,
    dropoutw=dropoutw,
    dropoute=dropoute,
    dropoute_max=dropoute_max,
    dropoutr=dropoutr,
    val_ratio=val_ratio,
    use_focal_loss=use_focal_loss,
    focal_loss_alpha=focal_loss_alpha,
    focal_loss_gamma=focal_loss_gamma,
    use_augmented=use_augmented,
    freeze_embeddings=freeze_embeddings,
    attention_bias=attention_bias,
    use_attention_aux=use_attention_aux,
    weight_decay=weight_decay,
    bias_init=bias_init,
    neg_splits=neg_splits,
    num_layers=num_layers,
    rnn_type=rnn_type,
    rnn_residual=rnn_residual,
    pooling_type=pooling_type,
    model_type=model_type,
    cache_elmo_embeddings=cache_elmo_embeddings,
    use_word_level_features=use_word_level_features,
    use_sentence_level_features=use_sentence_level_features,
    mixup_ratio=mixup_ratio,
    discrete_mixup_ratio=discrete_mixup_ratio,
    bucket=bucket,
    compute_thres_on_test=compute_thres_on_test,
    permute_sentences=permute_sentences,
    find_lr=find_lr,
    run_id=run_id,
)

In [55]:
from pathlib import Path

# Directory holding the CSV files for the selected dataset (`../data/<config.dataset>`).
DATA_ROOT = Path("../data") / config.dataset


In [56]:
# Maps a dataset name to the object registered for it via `@register(name)`.
reader_registry = {}


def register(name: str):
    """Return a decorator that files its target in `reader_registry` under `name`.

    The decorated object is returned unchanged, so the decorator has no effect
    on the definition itself.
    """
    def _store(target: Callable):
        reader_registry.update({name: target})
        return target

    return _store

In [57]:
import gc, csv

from allennlp.data.fields import TextField, SequenceLabelField, LabelField, MetadataField, ArrayField
import string
# Lowercase ASCII letters; anything else counts as "non-alphabetic" below.
alphabet = set(string.ascii_lowercase)

# Sentence-level feature extractors: each maps the token list of one example to a
# float. The list is currently empty (the only candidate is commented out).
sentence_level_features: List[Callable[[List[str]], float]] = [
#     lambda x: (np.log1p(len(x)) - 3.628) / 1.065, # stat computed on train set
]

# Word-level feature extractors: each maps a single token string to a number.
word_level_features: List[Callable[[str], float]] = [
    # 1 when lowercasing leaves the token unchanged, else 0.
    lambda x: 1 if (x.lower() == x) else 0,
    # Fraction of characters (after lowercasing) that are not in `alphabet`.
    # An empty token yields 0.0 instead of raising ZeroDivisionError.
    lambda x: (len([c for c in x.lower() if c not in alphabet]) / len(x)) if x else 0.0,
]

def proc(x: str) -> str:
    """Normalise the case of one token according to `config.model_type`.

    The token is lowercased when the model type is exactly "standard" or
    contains "uncased"; otherwise it is returned as given.
    """
    model_type = config.model_type
    if model_type == "standard" or "uncased" in model_type:
        return x.lower()
    return x

class MemoryOptimizedTextField(TextField):
    """`TextField` subclass that takes plain token strings and drops them once indexed.

    NOTE(review): `__init__` does not call `TextField.__init__`; it assigns the
    attributes itself. Presumably these names mirror the parent's private
    attributes in the installed AllenNLP version -- confirm when upgrading.
    """
    @overrides
    def __init__(self, tokens: List[str], token_indexers: Dict[str, TokenIndexer]) -> None:
        self.tokens = tokens
        self._token_indexers = token_indexers
        self._indexed_tokens: Optional[Dict[str, TokenList]] = None
        self._indexer_name_to_indexed_token: Optional[Dict[str, List[str]]] = None
        # skip checks for tokens
    @overrides
    def index(self, vocab):
        """Index through the parent implementation, then discard the raw tokens."""
        super().index(vocab)
        # After this point `self.tokens` is None, so nothing that needs the raw
        # tokens can be used on this field any more.
        self.tokens = None # empty tokens

@register("jigsaw")
class JigsawDatasetReader(DatasetReader):
    """Dataset reader that turns the rows of a Jigsaw CSV file into `Instance`s.

    Each instance has four fields: "tokens" (a `MemoryOptimizedTextField`) plus
    "word_level_features", "sentence_level_features" and "label" (`ArrayField`s).
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None, # TODO: Handle mapping from BERT
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        """
        Args:
            tokenizer: maps a raw text to a list of token strings
                (defaults to whitespace splitting).
            token_indexers: indexers for the text field; when falsy, a single
                `SingleIdTokenIndexer` under the key "tokens" is used.
            max_seq_len: stored on the reader. It is not applied anywhere in
                this class; the default is read from `config` at definition time.
        """
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[str], id: str,
                         labels: np.ndarray) -> Instance:
        """Build one `Instance` from tokenised text and its label vector.

        Args:
            tokens: token strings of one example.
            id: row identifier; accepted for interface compatibility, not stored.
            labels: label array for the example.
        """
        # The text field sees case-normalised tokens (see `proc`) ...
        sentence_field = MemoryOptimizedTextField([proc(x) for x in tokens],
                                   self.token_indexers)
        fields = {"tokens": sentence_field}

        # ... while the hand-crafted features are computed on the raw tokens,
        # so casing information is still available to them.
        wl_feats = np.array([[func(w) for func in word_level_features] for w in tokens])
        fields["word_level_features"] = ArrayField(array=wl_feats)

        sl_feats = np.array([func(tokens) for func in sentence_level_features])
        fields["sentence_level_features"] = ArrayField(array=sl_feats)

        label_field = ArrayField(array=labels)
        fields["label"] = label_field

        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per data row of the CSV file at `file_path`.

        The first row is skipped as a header. Rows must have 8 columns
        (id, text, labels...) or 9 columns (one extra leading column that is
        ignored); the label columns are parsed as integers.

        Raises:
            ValueError: if a row has any other number of columns.
        """
        # newline="" is what the csv module requires so that line breaks inside
        # quoted fields reach the parser untranslated; the explicit encoding
        # avoids depending on the platform default for non-ASCII comments.
        with open(file_path, newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            for i, line in enumerate(reader):
                if len(line) == 9:
                    _, id_, text, *labels = line
                elif len(line) == 8:
                    id_, text, *labels = line
                else: raise ValueError(f"line has {len(line)} values")
                yield self.text_to_instance(
                    self.tokenizer(text),
                    id_, np.array([int(x) for x in labels]),
                )
                # In testing mode stop after the row with index 1000 has been yielded.
                if config.testing and i == 1000: break

In [58]:
from tqdm import tqdm
import warnings

def get_coefs(word, *arr):
    """Return `(word, vector)`, where `vector` is `arr` converted to a float32 array."""
    vector = np.asarray(arr, dtype='float32')
    return word, vector

def get_fasttext_embeddings(vocab: Vocabulary):
    """Build an embedding matrix for `vocab` from the text file at `config.ft_model_path`.

    Each line of the file longer than 100 characters is parsed as
    `word v1 v2 ...` (space separated). The result has `config.vocab_size + 5`
    rows and `config.embedding_dim` columns (300 when that setting is absent).
    Rows 0 and 1 stay all-zero, rows of tokens found in the file hold their
    vector, "[MASK]" gets a random vector when the file has none for it, and
    every other token stays all-zero and is reported through `warnings.warn`.

    Returns:
        The embedding matrix as a float64 NumPy array.
    """
    emb_dim = getattr(config, "embedding_dim", 300)

    # Read inside `with` so the (large) file is closed as soon as parsing is done.
    with open(config.ft_model_path, encoding="utf8", errors='ignore') as ft_file:
        prog_bar = tqdm(ft_file)
        prog_bar.set_description("Loading embeddings")
        embeddings_index = dict(get_coefs(*o.split(" ")) for o in prog_bar
                                if len(o) > 100)
    # The former `np.stack(embeddings_index.values())` was removed: its result
    # was never used, and NumPy no longer accepts a dict view there.

    embeddings = np.zeros((config.vocab_size + 5, emb_dim))
    n_missing_tokens = 0
    prog_bar = tqdm(vocab.get_index_to_token_vocabulary().items())
    prog_bar.set_description("Creating matrix")
    for idx, token in prog_bar:
        if idx == 0: continue # keep padding as all zeros
        if idx == 1: continue # Treat unknown words as dropped words
        if token in embeddings_index:
            embeddings[idx, :] = embeddings_index[token]
        elif token == "[MASK]":
            # Random init for the mask token; it is no longer also counted as missing.
            embeddings[idx, :] = np.random.randn(emb_dim) * 0.5
        else:
            n_missing_tokens += 1
            if n_missing_tokens < 10:
                warnings.warn(f"Token {token} not in embeddings: did you change preprocessing?")
            if n_missing_tokens == 10:
                warnings.warn(f"More than {n_missing_tokens} missing, supressing warnings")

    if n_missing_tokens > 0:
        warnings.warn(f"{n_missing_tokens} in total are missing from embedding text file")
    return embeddings

In [59]:
import random
from functools import wraps

def maybeshuffle(_tokenize):
    """Decorator: shuffle the tokenizer's output in place when `config.permute_sentences` is set.

    `config.permute_sentences` is read on every call, so toggling it affects
    tokenizers that are already decorated. The wrapper keeps the wrapped
    function's name and docstring via `functools.wraps`.

    Args:
        _tokenize: callable returning a mutable sequence (a list of tokens).

    Returns:
        A callable with the same arguments that returns that same sequence.
    """
    @wraps(_tokenize)
    def func(*args, **kwargs):
        arr = _tokenize(*args, **kwargs)
        if config.permute_sentences:
            random.shuffle(arr)
        return arr
    return func

In [60]:
# Word-level tokenizer and indexer for this notebook -- NOT the BERT wordpiece pipeline.
token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)


@maybeshuffle
def tokenizer(x: str):
    """Split `x` with the spaCy splitter and return the text of at most `config.max_seq_len` tokens."""
    kept = _spacy_tok(x)[:config.max_seq_len]
    return [piece.text for piece in kept]


In [61]:
# Reader wired to the spaCy tokenizer and the lowercasing single-id indexer defined above.
reader = JigsawDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer}
)

In [62]:
# Read the three splits from DATA_ROOT. The generator is consumed by the tuple
# unpacking, so the files are read in the order listed.
train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train_wo_val.csv",
                                                                          "val.csv",
                                                                          "test_proced.csv"])

151592it [03:05, 815.43it/s] 
7979it [00:07, 1026.20it/s]
63978it [01:18, 815.40it/s] 


In [63]:
import pickle
# Persist the training instances. NOTE: the path is hard-coded to the jigsaw
# directory rather than derived from DATA_ROOT.
with open("../data/jigsaw/train_ds.bin", "wb") as of:
    pickle.dump(train_ds, of)

In [64]:
import pickle
# Persist the validation instances (path hard-coded to the jigsaw directory).
with open("../data/jigsaw/val_ds.bin", "wb") as of:
    pickle.dump(val_ds, of)

In [65]:
import pickle
# Persist the test instances (path hard-coded to the jigsaw directory).
with open("../data/jigsaw/test_ds.bin", "wb") as of:
    pickle.dump(test_ds, of)

In [15]:
# All instances from every split; the vocabulary below is built over this, test set included.
full_ds = train_ds + val_ds + test_ds

In [16]:
from itertools import groupby
def remove_extra_chars(s, max_qty=2):
    """Collapse every run of a repeated character in `s` to at most `max_qty` copies.

    Example: with the default limit, "fuuuuck" becomes "fuuck".
    """
    pieces = []
    for char, run in groupby(s):
        run_len = sum(1 for _ in run)
        pieces.append(char * min(max_qty, run_len))
    return ''.join(pieces)

In [17]:
# Sanity check: a string without repeated characters comes back unchanged.
remove_extra_chars('abcdef')

'abcdef'

In [18]:
from pytorch_pretrained_bert.tokenization import BertTokenizer

# Pretrained uncased BERT tokenizer; only its `.vocab` is used in the cells below.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [19]:
# Vocabulary over train + val + test instances, limited via `config.max_vocab_size`.
vocab = Vocabulary.from_instances(full_ds, max_vocab_size=config.max_vocab_size)

100%|██████████| 223549/223549 [00:18<00:00, 11866.52it/s]


In [20]:
# List every vocabulary token. `token_arr` is not referenced again in the cells shown here.
prog_bar = tqdm(vocab.get_index_to_token_vocabulary().items())
token_arr = [token for idx, token in prog_bar]

100%|██████████| 305140/305140 [00:00<00:00, 2142998.14it/s]


In [21]:
# `vocab_size` must be set first: get_fasttext_embeddings reads it from `config`
# to size the embedding matrix.
config.set("vocab_size", min(vocab.get_vocab_size(), config.max_vocab_size))
config.set("embedding_dim", 300)
fasttext_embeds = get_fasttext_embeddings(vocab)

Loading embeddings: : 317458it [00:30, 10274.49it/s]
  # This is added back by InteractiveShellApp.init_path()
Creating matrix: 100%|██████████| 305140/305140 [00:01<00:00, 236711.16it/s]


In [22]:
from scipy.spatial.distance import cosine, euclidean
from numpy.linalg import norm
# Sanity check: cosine similarity between the embeddings of a misspelling and of the
# usual spelling, computed by hand and via SciPy (the two printed values should agree).
w1 = fasttext_embeds[vocab.get_token_index('fcuk')]
w2 = fasttext_embeds[vocab.get_token_index('fuck')]
print(np.sum(w1 * w2)/norm(w1)/norm(w2), 1 - cosine(w1, w2))

0.38523468070826167 0.3852346807082616


In [23]:
# Is this misspelling present in BERT's vocabulary?
'fcuk' in bert_tokenizer.vocab

False

In [24]:
# BERT's tokens (a dict keys view) and the dataset vocabulary's tokens (a set).
bert_vocab_toks = bert_tokenizer.vocab.keys()
vocab_toks = set( [w for idx, w in vocab.get_index_to_token_vocabulary().items() ])

In [25]:
# Sizes of the dataset vocabulary and of the BERT vocabulary.
len(vocab_toks), len(bert_vocab_toks)

(305140, 30522)

In [26]:

# Dataset-vocabulary indices of the BERT tokens that the dataset vocabulary also has.
bert_vocab_ids = []

for tok in bert_vocab_toks:
    tok_id = vocab.get_token_index(tok)
    # Indices 0 and 1 are skipped; the embedding builder treats them as padding
    # and unknown. NOTE(review): presumably get_token_index returns 1 for a
    # token the vocabulary does not contain -- confirm.
    if tok_id > 1:
        bert_vocab_ids.append(tok_id)
        
bert_vocab_ids = np.array(bert_vocab_ids)

In [27]:
# Shape of the embedding sub-matrix for the tokens selected in `bert_vocab_ids`.
fasttext_embeds[bert_vocab_ids].shape

(22778, 300)

In [28]:
import nmslib, time

# Index-construction settings for the nmslib HNSW index built below.
M = 25
efC = 200

num_threads = 0
# NOTE: this dict is only printed here; the cell that calls createIndex rebuilds
# it without the 'post' entry.
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC, 'post' : 0}
print('Index-time parameters', index_time_params)

Index-time parameters {'M': 25, 'indexThreadQty': 0, 'efConstruction': 200, 'post': 0}


In [29]:

# The space name should match the space name
# used for brute-force search
space_name='cosinesimil'


# Initialize the library, specify the space and the vector type, and add the data points.
# Only the embeddings selected by `bert_vocab_ids` are added, so the neighbour ids
# returned by queries are positions into `bert_vocab_ids`, not vocabulary indices.
index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR) 
index.addDataPointBatch(fasttext_embeds[bert_vocab_ids])

22778

In [30]:
# Create an index
# The wall-clock time of construction is measured and printed.
start = time.time()
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC}
index.createIndex(index_time_params) 
end = time.time() 
print('Index-time parameters', index_time_params)
print('Indexing time = %f' % (end-start))

Index-time parameters {'M': 25, 'indexThreadQty': 0, 'efConstruction': 200}
Indexing time = 5.314431


In [31]:
# Setting query-time parameters
# K is the number of neighbours requested per query in the cells below.
efS = 1000
K=10
query_time_params = {'efSearch': efS}
print('Setting query-time parameters', query_time_params)
index.setQueryTimeParams(query_time_params)

Setting query-time parameters {'efSearch': 1000}


In [32]:
# One-row query matrix for a single misspelled token, used to spot-check the index.
tok_id = vocab.get_token_index('fuuck')
query_arr = [fasttext_embeds[tok_id]]
query_matrix = np.array(query_arr)
K=10
query_matrix.shape, tok_id

((1, 300), 100538)

In [33]:
# Querying
# NOTE: `num_threads` is 0 in this notebook, so the "adjusted for thread number"
# figure in the message below always prints as 0.
query_qty = query_matrix.shape[0]
start = time.time() 
nbrs = index.knnQueryBatch(query_matrix, k = K, num_threads = num_threads)
end = time.time() 
print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' % 
      (end-start, float(end-start)/query_qty, num_threads*float(end-start)/query_qty))

kNN time total=0.005025 (sec), per query=0.005025 (sec), per query adjusted for thread number=0.000000 (sec)


In [34]:
# Neighbour ids and the matching distances for the single query.
nbrs[0][0], nbrs[0][1]

(array([11084, 21732, 19477, 15291, 22556, 17013, 18107, 20259, 20725,
        16667], dtype=int32),
 array([0.47294343, 0.48340237, 0.4915681 , 0.50250036, 0.50813234,
        0.509331  , 0.5209646 , 0.53238195, 0.5377827 , 0.53826004],
       dtype=float32))

In [35]:
# Translate neighbour positions back to tokens: position -> vocabulary index -> token.
for i in nbrs[0][0]:
    print(vocab.get_token_from_index(bert_vocab_ids[i]))

ninja
minions
yuki
ryu
gunslinger
akira
sakura
mikey
godzilla
knocks


In [36]:

# Dataset-vocabulary tokens that are absent from BERT's vocabulary.
oov_toks = vocab_toks - bert_vocab_toks
query_arr = []
query_toks = []
for tok in oov_toks:
    # The embedding is looked up for the token with character runs collapsed,
    # while the original token is what gets recorded in `query_toks`.
    # NOTE(review): if the collapsed form is not in `vocab`, the lookup presumably
    # yields the unknown index, whose embedding row is all zeros -- confirm.
    tok_id = vocab.get_token_index(remove_extra_chars(tok))
    query_arr.append(fasttext_embeds[tok_id])
    query_toks.append(tok)
    
query_matrix = np.array(query_arr)
query_matrix.shape

(282362, 300)

In [37]:
# Setting query-time parameters
# Same values as the earlier query-time cell, repeated before the full OOV batch.
efS = 1000
K=10
query_time_params = {'efSearch': efS}
print('Setting query-time parameters', query_time_params)
index.setQueryTimeParams(query_time_params)

Setting query-time parameters {'efSearch': 1000}


In [38]:
# Querying
# Runs the whole OOV query matrix; `nbrs` and `query_qty` are overwritten here
# and consumed by the mapping cell below.
query_qty = query_matrix.shape[0]
start = time.time() 
nbrs = index.knnQueryBatch(query_matrix, k = K, num_threads = num_threads)
end = time.time() 
print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' % 
      (end-start, float(end-start)/query_qty, num_threads*float(end-start)/query_qty))

kNN time total=256.942548 (sec), per query=0.000910 (sec), per query adjusted for thread number=0.000000 (sec)


In [39]:
OOV_THRESHOLD=0.3 # We care only about very similar items
# Maps an out-of-BERT-vocabulary token to the first neighbour token whose
# distance is within the threshold; tokens with no such neighbour get no entry.
oov_map = dict()

for i in range(query_qty):
    nbr_ids = nbrs[i][0]
    nbr_dist = nbrs[i][1]
    query_tok = query_toks[i]
    
    for k in range(len(nbr_ids)):
        nbr_id = nbr_ids[k]
        # Neighbour id -> dataset-vocabulary index -> token string.
        tok_id = bert_vocab_ids[nbr_id]
        nbr_tok = vocab.get_token_from_index(tok_id)
        # Query tokens are outside BERT's vocabulary and neighbours are inside
        # it, so the two should never be the same token.
        assert(query_tok != nbr_tok)
        if nbr_dist[k] <= OOV_THRESHOLD:
            oov_map[query_tok] = nbr_tok
            break
             
        

In [40]:
# Display the OOV-token -> BERT-token mapping.
oov_map

{'refraction': 'wavelengths',
 'requests#requests': 'requests',
 'upadhyay': 'sharma',
 'atherosclerotic': 'inflammation',
 'fabricators': 'fabrication',
 'macedonicity': 'macedonia',
 'humanitarianism': 'humanitarian',
 'veterinarians': 'veterinary',
 'chellappa': 'menon',
 'judgemental': 'judgement',
 'elearning': 'learning',
 'taman': 'jalan',
 'humour~': 'humour',
 'переломне': 'и',
 'spacing': 'spaced',
 'hellenization': 'hellenistic',
 'lighthouses': 'lighthouse',
 'neudeutschland': 'deutschland',
 'uplands': 'upland',
 'platform==': 'platforms',
 'wikikikipedia': 'wikipedia',
 'interpretation?==': 'interpretation',
 'settlin': 'settling',
 'heliskiing': 'skiing',
 'psychopathy': 'schizophrenia',
 'giganotosaurus': 'dinosaurs',
 'romanians==': 'romanian',
 'hypophyseal': 'congenital',
 'senater': 'senate',
 'http://usability.gov/pdfs/chapter11.pdf': 'pdf',
 'знайти': 'в',
 '4-electron': 'electron',
 'свои': 'и',
 'kammerphilharmonie': 'philharmonic',
 'payable': 'payment',
 'beak

In [41]:
import pickle

# Persist the OOV map at `bert_oov_map_path`.
with open(bert_oov_map_path, 'wb') as of:
    pickle.dump(oov_map, of)

In [43]:
# Show the path the map was written to.
bert_oov_map_path

'../data/jigsaw/bert_oov_map.bin'

In [45]:
# Round-trip check: reload the map written above into `tmp`.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
with open(bert_oov_map_path, 'rb') as f:
    tmp = pickle.load(f)

In [48]:
import pickle

# Persist the Vocabulary object at `data_vocab_path`.
with open(data_vocab_path, 'wb') as of:
    pickle.dump(vocab, of)


# Save data to use in other notebooks
# (the embedding matrix built from the vocabulary, written in .npy format)
np.save(ft_compiled_path, fasttext_embeds)