In [1]:
# settings for seamlessly running on colab
import os
try:
    from google.colab import drive
    drive.mount('/content/gdrive')
    os.environ["IS_COLAB"] = "True"
except ImportError:
    os.environ["IS_COLAB"] = "False"

In [2]:
%%bash
# Install the notebook's dependencies, but only when running on Colab
# (IS_COLAB is exported by the first cell through os.environ).
if [ "$IS_COLAB" = "True" ]; then
    pip install git+https://github.com/facebookresearch/fastText.git
    pip install torch
    pip install torchvision
    pip install --upgrade git+https://github.com/keitakurita/allennlp@develop
    pip install dnspython
    pip install jupyter_slack
    pip install git+https://github.com/keitakurita/Better_LSTM_PyTorch.git
    # Clone apex only when there is no checkout yet. The previous test was
    # inverted: it cloned when the directory already existed and skipped the
    # clone on a fresh machine, so the `cd apex` below failed.
    if [ ! -d "apex" ]; then
      git clone https://github.com/NVIDIA/apex.git
    fi
    cd apex && python setup.py install --cpp_ext --cuda_ext
fi

In [3]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
%matplotlib inline
from overrides import overrides

In [4]:
import time
from contextlib import contextmanager

class Config(dict):
    """A dict whose entries are mirrored as instance attributes.

    Every keyword passed to the constructor is stored both as a dictionary
    item and as an attribute, so `cfg["lr"]` and `cfg.lr` read the same value.
    Use `set` to add or change an entry so that both views stay in sync.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for name, value in kwargs.items():
            setattr(self, name, value)

    def set(self, key, val):
        """Store `val` under `key` as both a dict item and an attribute."""
        self[key] = val
        setattr(self, key, val)

@contextmanager
def timer(name):
    """Context manager that prints how long the wrapped block took.

    The message is printed only when the block finishes without raising;
    the elapsed time is rounded to whole seconds.
    """
    started_at = time.time()
    yield
    elapsed = time.time() - started_at
    print(f'[{name}] done in {elapsed:.0f} s')
    
import functools
import traceback
import sys

def get_ref_free_exc_info():
    "Free traceback from references to locals/globals to avoid circular reference leading to gc.collect() unable to reclaim memory"
    # Must be called while an exception is being handled (inside `except`).
    exc_type, exc_val, exc_tb = sys.exc_info()
    # Drops the local variables held by every finished frame in the traceback.
    traceback.clear_frames(exc_tb)
    return (exc_type, exc_val, exc_tb)

def gpu_mem_restore(func):
    "Reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted"
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        # BaseException on purpose: interruptions (KeyboardInterrupt) must be
        # handled too, exactly like the bare `except:` this replaces.
        except BaseException:
            _, exc_val, exc_tb = get_ref_free_exc_info() # must!
            # Re-raise the original exception object. Building a new one with
            # `type(val)` fails for exceptions whose constructor needs several
            # arguments (e.g. UnicodeDecodeError) and double-wraps the message
            # for the rest.
            raise exc_val.with_traceback(exc_tb) from None
    return wrapper

def ifnone(a: Any, alt: Any):
    """Return `a` unless it is None, in which case return `alt`."""
    if a is None:
        return alt
    return a

In [5]:
# for papermill
testing = False
debugging = False
seed = 1
char_encoder = "cnn"
tie_weights = False
computational_batch_size = 8
batch_size = 32
loss = "is"
num_neg_samples = 1280
lr = 1e-4
lr_schedule = "slanted_triangular"
epochs = 1
hidden_sz = 128
num_attention_heads = 4
num_hidden_layers = 6
dataset = "jigsaw_ext"
hashed_vocab = True
softmax = "cnn_softmax"
denoise = True
put_embeddings_on_gpu = True
n_classes = 6
max_seq_len = 128
download_data = False
ft_model_path = "../data/jigsaw/ft_model.txt"
max_vocab_size = 500000
num_buckets = 500_000 # TODO: add to configurable parameters
min_freq = 3
dropouti = 0.2
dropoutw = 0.0
dropoute = 0.2
dropoutr = 0.3 # TODO: Implement
val_ratio = 0.0
use_augmented = False
freeze_embeddings = True
mixup_ratio = 0.0
discrete_mixup_ratio = 0.0
attention_bias = True
weight_decay = 0.
bias_init = True
neg_splits = 1
num_layers = 2
rnn_type = "lstm"
pooling_type = "augmented_multipool" # attention or multipool or augmented_multipool
model_type = "standard"
use_word_level_features = False
use_sentence_level_features = False
bucket = True
compute_thres_on_test = True
find_lr = False
permute_sentences = False
construct_vocab = False
mask_rate = 0.15
run_id = None

In [6]:
# TODO: Can we make this play better with papermill?
config = Config(
    testing=testing,
    debugging=debugging,
    seed=seed,
    char_encoder=char_encoder,
    tie_weights=tie_weights,
    computational_batch_size=computational_batch_size,
    batch_size=batch_size,
    loss=loss,
    num_neg_samples=num_neg_samples,
    lr=lr,
    lr_schedule=lr_schedule,
    epochs=epochs,
    hidden_sz=hidden_sz,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
    dataset=dataset,
    hashed_vocab=hashed_vocab,
    softmax=softmax,
    denoise=denoise,
    put_embeddings_on_gpu=put_embeddings_on_gpu,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    ft_model_path=ft_model_path,
    max_vocab_size=max_vocab_size,
    num_buckets=num_buckets,
    min_freq=min_freq,
    dropouti=dropouti,
    dropoutw=dropoutw,
    dropoute=dropoute,
    dropoutr=dropoutr,
    val_ratio=val_ratio,
    use_augmented=use_augmented,
    freeze_embeddings=freeze_embeddings,
    attention_bias=attention_bias,
    weight_decay=weight_decay,
    bias_init=bias_init,
    neg_splits=neg_splits,
    num_layers=num_layers,
    rnn_type=rnn_type,
    pooling_type=pooling_type,
    model_type=model_type,
    use_word_level_features=use_word_level_features,
    use_sentence_level_features=use_sentence_level_features,
    mixup_ratio=mixup_ratio,
    discrete_mixup_ratio=discrete_mixup_ratio,
    bucket=bucket,
    compute_thres_on_test=compute_thres_on_test,
    permute_sentences=permute_sentences,
    find_lr=find_lr,
    construct_vocab=construct_vocab,
    mask_rate=mask_rate,
    run_id=run_id,
)

In [7]:
T = TypeVar("T")
TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]  # pylint: disable=invalid-name

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util
from allennlp.data.dataset_readers import DatasetReader

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [8]:
if os.environ["IS_COLAB"] != "True":
    DATA_ROOT = Path("../data") / config.dataset
else:
    DATA_ROOT = Path("./gdrive/My Drive/Colab_Workspace/Colab Notebooks/data") / config.dataset
    config.ft_model_path = str(DATA_ROOT / "ft_model.txt")

In [9]:
import subprocess
# Fetch any data file that is missing under DATA_ROOT from S3.
# Runs only when the `download_data` notebook parameter is enabled.
if download_data:
    if config.val_ratio > 0.0:
        fnames = ["train_wo_val.csv", "test_proced.csv", "val.csv", "ft_model.txt"]
    else:
        fnames = ["train.csv", "test_proced.csv", "ft_model.txt"]
    if config.use_augmented or config.discrete_mixup_ratio > 0.0: fnames.append("train_extra.csv")
    for fname in fnames:
        if not (DATA_ROOT / fname).exists():
            # Pass an argument list instead of a shell string: DATA_ROOT can
            # contain spaces (the Colab path includes "My Drive"), which would
            # split the destination into several shell words.
            # NOTE: without a shell, a missing `aws` executable now raises
            # FileNotFoundError instead of silently printing b''.
            completed = subprocess.run(
                ["aws", "s3", "cp", f"s3://nnfornlp/raw_data/jigsaw/{fname}", str(DATA_ROOT)],
                stdout=subprocess.PIPE,
            )
            print(completed.stdout)

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset

In [11]:
from allennlp.data.fields import (TextField, SequenceLabelField, LabelField, 
                                  MetadataField, ArrayField)

class MemoryOptimizedTextField(TextField):
    """TextField variant whose cached index data can be dropped with `deindex`.

    The constructor assigns the attributes directly and does not call
    `TextField.__init__`, so whatever validation the parent constructor
    performs is skipped.
    """
    @overrides
    def __init__(self, tokens: List[str], token_indexers: Dict[str, TokenIndexer]) -> None:
        """Store the tokens and indexers; the field starts out un-indexed."""
        self.tokens = tokens
        self._token_indexers = token_indexers
        # Both caches stay None until `index` is called.
        # NOTE(review): `TokenList` is not defined in the visible part of the
        # notebook; this is harmless only because annotations on attribute
        # targets inside a function are not evaluated.
        self._indexed_tokens: Optional[Dict[str, TokenList]] = None
        self._indexer_name_to_indexed_token: Optional[Dict[str, List[str]]] = None
        # skip checks for tokens
    @overrides
    def index(self, vocab):
        """Index the field with `vocab`; simply delegates to the parent class."""
        super().index(vocab)
    
    def deindex(self):
        """Reset both index caches to None, returning the field to its un-indexed state."""
        self._indexed_tokens = None
        self._indexer_name_to_indexed_token = None

In [12]:
class JigsawLMDatasetReader(DatasetReader):
    """Eager reader that turns the `comment_text` column of a CSV into instances.

    Each instance has an "input" field with the tokens and an "output" field
    with the same tokens passed through `_clean`. `max_seq_len` is stored on
    the reader but is not used by any method of this class.
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer]=None, # TODO: Handle mapping from BERT
                 output_token_indexers: Dict[str, TokenIndexer]=None,
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers
        # Targets are indexed like the inputs unless dedicated indexers are given.
        self.output_token_indexers = output_token_indexers or token_indexers
        self.max_seq_len = max_seq_len

    def _clean(self, x: str) -> str:
        """
        Target for a single input word. Currently the identity mapping;
        intended to become a denoising operation later.
        """
        return x

    @overrides
    def text_to_instance(self, tokens: List[str]) -> Instance:
        """Build an instance with "input" and "output" text fields from `tokens`."""
        fields = {
            "input": MemoryOptimizedTextField(list(tokens), self.token_indexers),
            "output": MemoryOptimizedTextField(
                [self._clean(tok) for tok in tokens],
                self.output_token_indexers),
        }
        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per CSV row (only the first 1000 rows when `config.testing`)."""
        frame = pd.read_csv(file_path)
        if config.testing:
            frame = frame.head(1000)
        for _, record in frame.iterrows():
            yield self.text_to_instance(self.tokenizer(record["comment_text"]))

In [13]:
import csv
class JigsawDenoiseDatasetReader(DatasetReader):
    """Reads (id, clean text, noised text) CSV rows into denoising instances.

    The noised tokens become the "input" field and the clean tokens the
    "output" field. `max_seq_len` is stored on the reader but is not used by
    any method of this class.
    """
    def __init__(self,
                 lazy: bool=True,
                 tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer]=None, # TODO: Handle mapping from BERT
                 output_token_indexers: Dict[str, TokenIndexer]=None,
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=lazy)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers
        # Targets are indexed like the inputs unless dedicated indexers are given.
        self.output_token_indexers = output_token_indexers or token_indexers
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[str], noise_tokens: List[str]) -> Instance:
        """Build an instance: `noise_tokens` as "input", `tokens` as "output"."""
        fields = {
            "input": MemoryOptimizedTextField(list(noise_tokens), self.token_indexers),
            "output": MemoryOptimizedTextField(list(tokens), self.output_token_indexers),
        }
        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per data row; stops after 1000 rows when `config.testing`."""
        with open(file_path) as handle:
            rows = csv.reader(handle)
            next(rows)  # the first row is skipped (header)
            for row_idx, row in enumerate(rows):
                if config.testing and row_idx == 1000:
                    break
                _, clean_text, noisy_text = row
                clean_tokens = self.tokenizer(clean_text)
                noisy_tokens = self.tokenizer(noisy_text)
                # Both sides must line up token by token.
                assert len(clean_tokens) == len(noisy_tokens)
                yield self.text_to_instance(clean_tokens, noisy_tokens)

In [14]:
class WebbaseReader(DatasetReader):
    """Reads tab-separated (clean text, noised text) lines into denoising instances.

    The noised tokens become the "input" field and the clean tokens the
    "output" field. Lines that do not have exactly two tab-separated columns
    are skipped. `max_seq_len` is stored on the reader but is not used by any
    method of this class.
    """
    def __init__(self,
                 lazy: bool=True,
                 tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer]=None, # TODO: Handle mapping from BERT
                 output_token_indexers: Dict[str, TokenIndexer]=None,
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=lazy)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers
        # Targets are indexed like the inputs unless dedicated indexers are given.
        self.output_token_indexers = output_token_indexers or token_indexers
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[str], noise_tokens: List[str]) -> Instance:
        """Build an instance: `noise_tokens` as "input", `tokens` as "output"."""
        sentence_field = MemoryOptimizedTextField(
            [x for x in noise_tokens],
            self.token_indexers)
        fields = {"input": sentence_field}
        output_sentence_field = MemoryOptimizedTextField(
            [x for x in tokens],
            self.output_token_indexers)
        fields["output"] = output_sentence_field
        return Instance(fields)
    
    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per well-formed line; stops at line 1000 when `config.testing`."""
        with open(file_path) as f:
            reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
            for i, line in enumerate(reader):
                if config.testing and i == 1000: break
                # Skip malformed lines explicitly. The previous bare
                # `except: continue` around the unpacking also swallowed
                # KeyboardInterrupt and hid the reason for the skip.
                if len(line) != 2:
                    continue
                text, noised_text = line
                text = self.tokenizer(text)
                noised_text = self.tokenizer(noised_text)
                # Both sides must line up token by token.
                assert len(text) == len(noised_text)
                yield self.text_to_instance(
                    text, noised_text,
                )

In [15]:
def replace_repeats(s, max_reps=2):
    """Shorten every run of one repeated character to at most `max_reps` copies.

    The first character of a run is always kept, so a run is never removed
    completely. Returns the shortened string.
    """
    kept = []
    prev_char, run_len = None, 0
    for ch in s:
        if ch != prev_char:
            # A new run starts: always keep its first character.
            prev_char, run_len = ch, 1
            kept.append(ch)
        elif run_len < max_reps:
            kept.append(ch)
            run_len += 1
    return "".join(kept)

In [16]:
import re
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper, ELMoTokenCharactersIndexer
from allennlp.data.token_indexers import SingleIdTokenIndexer

if config.char_encoder == "cnn":
    token_indexer = ELMoTokenCharactersIndexer()
else:
    token_indexer = SingleIdTokenIndexer(
        lowercase_tokens=config.softmax != "cnn_softmax") # Temporary

_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words


# Raw string: `\w` and `\d` are regex classes, not Python string escapes.
# In a normal string literal they are invalid escape sequences and trigger a
# warning on recent Python versions. The compiled pattern is unchanged.
url_ptrn = re.compile(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+")
def proc(x):
    """Normalise a piece of text before it is used as a token.

    Steps, in order: lowercase, replace every match of `url_ptrn` with the
    literal "url", then shorten repeated characters with `replace_repeats`.
    """
    x = x.lower()
    x = url_ptrn.sub("url", x)
    x = replace_repeats(x)
    return x

if config.denoise:
    def tokenizer(x: str):
        return ["[CLS]"] + [proc(w) for w in
                x.split()[:config.max_seq_len - 2]] + ["[SEP]"]
else:
    _spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words
    def tokenizer(x: str):
        return ["[CLS]"] + [proc(w.text) for w in
                _spacy_tok(x)[:config.max_seq_len - 2]] + ["[SEP]"]

In [17]:
if config.loss == "is" and config.softmax == "cnn_softmax" and not config.construct_vocab:
    output_token_indexer = token_indexer
else:
    output_token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)

input_token_indexers = {"tokens": token_indexer}

if config.dataset == "jigsaw":
    if config.denoise:
        reader = JigsawDenoiseDatasetReader(
            tokenizer=tokenizer,
            token_indexers=input_token_indexers,
            output_token_indexers={"tokens": output_token_indexer}
        )
        train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in [
            "train_side_by_side.csv",
            "val_side_by_side.csv",
            "val_side_by_side.csv",
        ])
    else:
        reader = JigsawLMDatasetReader(
            tokenizer=tokenizer,
            token_indexers=input_token_indexers,
            output_token_indexers={"tokens": output_token_indexer}
        )
        train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train_wo_val.csv",
                                                                                  "val.csv",
                                                                                  "test_proced.csv"])
elif config.dataset == "webbase":
    assert config.hashed_vocab or config.char_encoder == "cnn"
    assert config.denoise
    if config.denoise:
        reader = WebbaseReader(
            tokenizer=tokenizer,
            token_indexers=input_token_indexers,
            output_token_indexers={"tokens": output_token_indexer}
        )
        train_ds = (reader.read(DATA_ROOT / "all_denoise.txt"))
        val_ds = None
elif config.dataset == "jigsaw_ext":
    assert config.denoise
    reader = JigsawDenoiseDatasetReader(
            tokenizer=tokenizer,
            token_indexers=input_token_indexers,
            output_token_indexers={"tokens": output_token_indexer}
        )
    train_ds = reader.read(DATA_ROOT / "train_denoise.csv")
    val_ds = None
else:
    raise ValueError(f"Invalid dataset {config.dataset}")

### Build Vocab

In [18]:
from allennlp.data.vocabulary import Vocabulary

In [19]:
import fnv
class HashedVocabulary(Vocabulary):
    """Vocabulary that maps tokens to indices by hashing instead of a lookup table.

    Tokens listed in `protected` keep their fixed index. Every other token is
    hashed with 32-bit FNV into one of the remaining buckets, which start
    right after the protected indices.
    """
    def __init__(self, *args, 
                 protected={"[pad]": 0, "[cls]": 2, "[sep]": 3, "[MASK]": 1},
                 num_buckets=config.num_buckets,
                 **kwargs,
                ):
        super().__init__(*args, **kwargs)
        # Copy the mapping: the default dict object is created once and would
        # otherwise be shared (and mutable) across every instance.
        self.protected = dict(protected)
        if num_buckets <= len(self.protected):
            # With no bucket left the modulo in get_token_index would divide
            # by zero (or by a negative number).
            raise ValueError(
                f"num_buckets ({num_buckets}) must be larger than the number "
                f"of protected tokens ({len(self.protected)})")
        # Buckets left for hashing once the protected indices are reserved.
        self.num_buckets = (num_buckets - len(self.protected))
        
    def get_token_index(self, w, namespace="tokens"):
        """Return the index of `w`; `namespace` is accepted but not used."""
        if w in self.protected: return self.protected[w]
        else:
            hash_val = fnv.hash(w.encode("utf-8"), bits=32)
            # Shift past the protected indices so hashed tokens never collide with them.
            return (hash_val % self.num_buckets) + len(self.protected)
    
    def get_vocab_size(self, namespace="tokens"):
        """Return the total number of indices; raises ValueError for namespaces containing "char"."""
        if "char" not in namespace: return self.num_buckets + len(self.protected)
        else: raise ValueError(f"{namespace} is not a valid namespace")

In [20]:
if config.hashed_vocab and not config.char_encoder == "cnn":
    vocab = HashedVocabulary()
else:
    if config.construct_vocab:
        full_ds = train_ds + test_ds + val_ds
        vocab = Vocabulary.from_instances(full_ds, tokens_to_add={"tokens": ["[MASK]"]},
                                          min_count={"tokens": 2},
                                          max_vocab_size=config.max_vocab_size)
        vocab.save_to_files(DATA_ROOT / "bert_vocab")
    else:
        vocab = Vocabulary.from_files(DATA_ROOT / "bert_vocab")

In [21]:
vocab.get_vocab_size("tokens")

98251

### Build frequencies

TODO: Implement fast sampling

In [22]:
if config.loss == "is":
    if config.construct_vocab:
        freqs = np.zeros(vocab.get_vocab_size())
        for w, c in vocab._retained_counter["tokens"].items():
            freqs[vocab.get_token_index(w)] = c
        freqs /= freqs.sum()
        freqs **= 2 / 3
        # renormalize
        freqs /= freqs.sum()
        np.save(DATA_ROOT / "bert_vocab_counts.npy", freqs)
    else:
        freqs = np.load(DATA_ROOT / ".." / "jigsaw" / "bert_vocab_counts.npy")

In [23]:
config.set("vocab_sz", vocab.get_vocab_size("tokens"))

In [24]:
config.vocab_sz

98251

In [25]:
assert config.vocab_sz == freqs.shape[0]

### Build Iterator

In [26]:
from allennlp.data.iterators import BucketIterator, DataIterator, BasicIterator, MemoryOptimizedIterator
if config.bucket:
    iterator = BucketIterator(
            batch_size=config.computational_batch_size, 
            biggest_batch_first=config.testing,
            sorting_keys=[("input", "num_tokens")],
            max_instances_in_memory=config.batch_size * 2,
    )
else:
    iterator = MemoryOptimizedIterator(
        batch_size=config.computational_batch_size,
        max_instances_in_memory=config.batch_size * 2,
    )
iterator.index_with(vocab)

In [27]:
batch = next(iterator(train_ds))

In [28]:
batch["input"]["tokens"].max()

tensor(261)

### Build word to indices mapping

TODO: Output this constructed dictionary to disk and load in the future

In [29]:
import fnv
def fnv_hash(w):
    """Return the 32-bit FNV hash of the UTF-8 encoding of the string `w`."""
    encoded = w.encode("utf-8")
    return fnv.hash(encoded, bits=32)

In [30]:
def generate_char_ngrams(w, range_=range(2, 6)):
    """Yield the character n-grams of `w` wrapped in "<" and ">" markers.

    For every size in `range_` (in order, so shorter n-grams come first) all
    windows of that size are produced from left to right.
    """
    padded = "<" + w + ">"
    total = len(padded)
    for size in range_: # prioritize smaller n-grams
        for start in range(total):
            end = start + size
            if end <= total:
                yield padded[start:end]

In [31]:
from tqdm import tqdm
# Build the lookup table from word id to its sub-token representation.
#   "cnn":     word id -> 50 character ids in the ELMo layout
#   "subword": word id -> padded list of unique-word / hashed n-gram ids
if config.char_encoder == "cnn":
    # TODO: Speed up
    # See allennlp/data/token_indexers/elmo_indexer.py
    # Ids as used by this notebook: 259 begins a word, 260 ends it, 261 pads
    # the row, and every byte value is stored shifted by +1. Row 0 (idx == 0)
    # is left as all zeros.
    with timer("Building character indexes"):
        # Integer table from the start (it is turned into a LongTensor below).
        word_id_to_char_idxs = np.zeros((config.vocab_sz, 50), dtype=np.int64)
        for w, idx in tqdm(vocab.get_token_to_index_vocabulary().items()):
            # TODO: Check for start/end of word symbols
            if idx > 0:
                # 50 slots minus the begin and end markers leaves room for 48
                # bytes. Truncating up front fixes two problems of the old
                # loop: long words were cut at 47 bytes with the end marker
                # written to slot 49 (leaving a padding gap at slot 48), and
                # an empty word reused the loop index of the previous word.
                # presumably this mirrors ELMoCharacterMapper - TODO confirm.
                char_bytes = w.encode("utf-8")[:48]
                word_id_to_char_idxs[idx, :] = 261
                word_id_to_char_idxs[idx, 0] = 259
                for i, c in enumerate(char_bytes):
                    word_id_to_char_idxs[idx, i+1] = int(c) + 1
                # End marker goes directly after the last byte.
                word_id_to_char_idxs[idx, len(char_bytes) + 1] = 260

    word_id_to_char_idxs = torch.LongTensor(word_id_to_char_idxs).to(device)
elif config.char_encoder == "subword":
    word_to_uid = {}
    uid = 2 # reserve @@PADDING@@ and @@UNKNOWN@@
    for word, freq in tqdm(vocab._retained_counter["tokens"].items()):
        # Only words seen more than min_freq times get an id of their own.
        if freq > config.min_freq:
            word_to_uid[word] = uid; uid += 1
        else:
            word_to_uid[word] = -1
    num_uniq_ids = uid
    offset = num_uniq_ids # these slots are reserved for unique words in the embedding matrix
    with timer("Building word to subword indices mapping"):
        subword_ids = [[] for _ in vocab.get_token_to_index_vocabulary()]
        freq_counter = vocab._retained_counter["tokens"]
        for word, idx in tqdm(vocab.get_token_to_index_vocabulary().items()):
            if word == "@@PADDING@@" or word == "@@UNKNOWN@@" or word == "[MASK]":
                # Special tokens map to their own vocabulary index only.
                subword_ids[idx] = [idx]
            else:
                uid = word_to_uid[word]
                uniq_idx = [uid] if uid > -1 else [] # only give unique index to words with sufficient frequency
                # Hashed n-gram ids live after the unique-word slots.
                subword_ids[idx] = uniq_idx + list([fnv_hash(x) % config.num_buckets + offset
                                                    for x in generate_char_ngrams(word)])
        maxlen = int(np.percentile([len(a) for a in subword_ids], 95)) # 95th-percentile
        word_id_to_subword_ids = torch.zeros(len(subword_ids), maxlen, dtype=torch.long)
        word_id_to_num_subwords = torch.zeros(len(subword_ids), 1, dtype=torch.float)
        print(word_id_to_subword_ids.shape)
        for i, idxs in tqdm(enumerate(subword_ids)):
            for j, id_ in enumerate(idxs):
                # Ids beyond maxlen are dropped; shorter rows stay zero-padded.
                if j >= maxlen: break
                word_id_to_subword_ids[i, j] = id_
            # The count is the untruncated number of ids.
            word_id_to_num_subwords[i] = len(idxs)

100%|██████████| 98251/98251 [00:00<00:00, 192921.48it/s]


[Building character indexes] done in 1 s


In [32]:
word_id_to_char_idxs[:10]

tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0],
        [259,  65,  65,  86,  79,  76,  79,  80,  88,  79,  65,  65, 260, 261,
         261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261,
         261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261,
         261, 261, 261, 261, 261, 261, 261, 261],
        [259,  47, 260, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261,
         261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261,
         261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 261,
         261, 261, 261, 261, 261, 261, 261, 261],
        [259, 117, 105, 102, 260, 261, 261, 261, 261, 261, 261, 261, 261, 261,
         261, 261, 261, 261, 261, 261, 261, 261, 261, 261, 2

In [33]:
# Character ids for the "[MASK]" token, in the same 50-slot layout as the
# rows of `word_id_to_char_idxs`: 259 begins the word, each byte is stored
# as value + 1, 260 ends the word and 261 pads the rest.
if config.char_encoder == "cnn":
    mask_char_ids = torch.ones(50, dtype=torch.int64).to(device) * 261
    # The begin-of-word marker was missing here (slot 0 stayed at the padding
    # id), unlike every row built for the vocabulary table.
    mask_char_ids[0] = 259
    mask_bytes = "[MASK]".encode("utf-8")
    for i, c in enumerate(mask_bytes):
        mask_char_ids[i+1] = int(c) + 1
    mask_char_ids[len(mask_bytes) + 1] = 260

# Bert configuration

In [34]:
from pytorch_pretrained_bert.modeling import (BertConfig, BertForMaskedLM, 
                                              BertEncoder, BertPooler, BertOnlyMLMHead)

bert_config = BertConfig(
        config.max_vocab_size, 
        hidden_size=config.hidden_sz, 
        num_attention_heads=config.num_attention_heads,
        num_hidden_layers=config.num_hidden_layers, 
        intermediate_size=config.hidden_sz * config.num_attention_heads,
        max_position_embeddings=config.max_seq_len,
)

bert_config

{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "max_position_embeddings": 128,
  "num_attention_heads": 4,
  "num_hidden_layers": 6,
  "type_vocab_size": 2,
  "vocab_size": 500000
}

# Building token embedder

In [35]:
if config.char_encoder == "cnn":
    from allennlp.modules.token_embedders.elmo_token_embedder import ElmoTokenEmbedder
    from allennlp.modules.elmo import _ElmoCharacterEncoder

    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
    _inner_char_encoder = _ElmoCharacterEncoder(
        options_file=options_file, 
        weight_file=weight_file,
        requires_grad=True
    )
    class ElmoEncoder(nn.Module):
        """Thin wrapper around a character encoder that returns token embeddings only.

        `_inner` is called with the forwarded arguments and must return a
        mapping with a "token_embedding" entry and provide `get_output_dim()`.
        """
        def __init__(self, _inner):
            super().__init__()
            self._inner = _inner
        def forward(self, *args):
            """Return `token_embedding` without its first and last position along dim 1."""
            # TODO: Stop Elmo encoder from adding SoS and EoS tokens
            # The [:, 1:-1, :] slice removes those two extra positions.
            return self._inner(*args)["token_embedding"][:, 1:-1, :]
        def get_output_dim(self):
            """Return the output dimension reported by the wrapped encoder."""
            return self._inner.get_output_dim()
    char_encoder = ElmoEncoder(_inner_char_encoder)

# char_encoder

# sample_idxs = next(iterator(train_ds))["tokens"]["tokens"]

# char_encoder(sample_idxs)
    config.set("embedding_sz", char_encoder.get_output_dim())

Bag of char-ngrams using fastText (too much memory consumption for now...)

In [36]:
from torch.nn.modules.sparse import EmbeddingBag

In [37]:
if config.char_encoder == "subword":
    from torch.nn.modules.sparse import EmbeddingBag
    
    # A hack to prevent these embeddings from being put on the GPU
    global_emb_bag = EmbeddingBag(offset + config.num_buckets, 250, mode="sum")
    if config.put_embeddings_on_gpu and torch.cuda.is_available():
        global_emb_bag.cuda()

    class SubwordEncoder(nn.Module):
        """Embeds word ids as a normalised sum of their subword embeddings.

        Relies on notebook globals: `word_id_to_subword_ids` and
        `word_id_to_num_subwords` (lookup tables), `global_emb_bag` (the
        embedding table, kept outside the module) and `device`.
        """
        def __init__(self, num_embeddings, 
                     embedding_sz, 
                     embeddings_on_gpu: bool=config.put_embeddings_on_gpu):
            # NOTE: `num_embeddings` and `embedding_sz` are accepted but not
            # used in this constructor.
            super().__init__()
            self._embeddings_on_gpu = embeddings_on_gpu
            # Lookup tables go to the GPU only if requested and available.
            self._device = torch.device("cuda:0" if torch.cuda.is_available() and embeddings_on_gpu else "cpu")
            self._subword_encoding = word_id_to_subword_ids.to(self._device)
            self._subword_count = word_id_to_num_subwords.to(self._device)
            # Width of the padded subword-id table.
            self.n_subwords_per_word = self._subword_encoding.size(1)

        def forward(self, 
                    tsr: torch.LongTensor, # (batch, seq) or # (batch)
                   ) -> torch.FloatTensor: # (batch, seq, feat) or # (batch, feat)
            """Look up the subword ids of each word id in `tsr` and pool their embeddings."""
            # TODO: can I use offsets in a differentiable manner??
            if len(tsr.shape) > 1:
                bs, seq = tsr.size(0), tsr.size(1)
                subword_ids = self._subword_encoding[tsr]
                bag_of_embs = global_emb_bag(subword_ids.view((-1, self.n_subwords_per_word)) # need to convert to 2D
                                      ).view((bs, seq, -1)) # reshape to 3d
            else:
                subword_ids = self._subword_encoding[tsr]
                bag_of_embs = global_emb_bag(subword_ids)
            # Divide the pooled embedding by the subword count; the +0.1
            # keeps the divisor away from zero.
            norm_factor = self._subword_count[tsr] + 0.1
            embs = bag_of_embs / norm_factor
            if self._embeddings_on_gpu:
                return embs
            else:
                # Tables live on the CPU: move the result to the global `device`.
                return embs.to(device)
    
    char_encoder = SubwordEncoder(offset + config.num_buckets, 250)
    config.set("embedding_sz", 250)

In [38]:
char_encoder

ElmoEncoder(
  (_inner): _ElmoCharacterEncoder(
    (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
    (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
    (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
    (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
    (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
    (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
    (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
    (_highways): Highway(
      (_layers): ModuleList(
        (0): Linear(in_features=2048, out_features=4096, bias=True)
      )
    )
    (_projection): Linear(in_features=2048, out_features=128, bias=True)
  )
)

Simple word-level embeddings

In [39]:
if config.char_encoder == "fasttext":
    from allennlp.modules import Embedding

    ft_matrix = np.random.randn(bert_config.vocab_size, 300) * 0.3
    char_encoder = Embedding(bert_config.vocab_size, 300, weight=torch.FloatTensor(ft_matrix))
    config.set("embedding_sz", 300)

In [40]:
def freeze(x):
    """Disable gradient computation for `x`.

    `x` may be a single tensor/parameter or a module. For a module, every
    parameter it yields from `parameters()` gets `requires_grad = False`.
    """
    x.requires_grad = False
    params = getattr(x, "parameters", None)
    if callable(params):
        # `parameters` is a method and has to be called; iterating the bound
        # method itself raised "TypeError: 'method' object is not iterable".
        for p in params():
            p.requires_grad = False

In [41]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift."""
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise `x` along its last dimension, then apply weight and bias."""
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias

In [42]:
class EmbeddingWPositionEmbs(nn.Module):
    """Embeds, then maps the embeddings to the bert_hidden_sz for processing

    Args:
        word_emb: module that maps token ids to embeddings of size `embedding_dim`.
        embedding_dim: size of the embeddings produced by `word_emb`.
        bert_hidden_sz: size the embeddings are projected to.
        freeze_embeddings: when true, `word_emb` is passed to `freeze`.
        dropout: dropout probability applied to the normalised sum.
    """
    def __init__(self, word_emb: nn.Module, 
                 embedding_dim,
                 bert_hidden_sz, 
                 freeze_embeddings=False,
                 dropout=0.1):
        super().__init__()
        self.word_emb = word_emb
        if freeze_embeddings: freeze(self.word_emb)
        # One learned vector per position, up to config.max_seq_len positions.
        self.position_embeddings = nn.Embedding(config.max_seq_len, 
                                                bert_hidden_sz)
        self.linear = nn.Linear(embedding_dim, bert_hidden_sz,
                                bias=False) # Transform dimensions
        self.norm = LayerNorm(bert_hidden_sz)
        # Use the `dropout` argument; it was ignored before in favour of a
        # hard-coded 0.1 (the default is unchanged).
        self.do = nn.Dropout(dropout)
    
    def get_word_embs(self, input_ids):
        """Return the word embeddings projected to `bert_hidden_sz`."""
        return self.linear(self.word_emb(input_ids))
    
    def forward(self, input_ids):
        """Return dropout(layer_norm(word embeddings + position embeddings))."""
        # We won't be using token types since we won't be predicting the next sentence
        bs, seq_length, *_ = input_ids.shape
        # Build the position ids on the device of the position table itself
        # rather than on the notebook-global `device`, so the lookup works
        # wherever this module has been moved.
        position_ids = (torch.arange(seq_length, dtype=torch.long,
                                     device=self.position_embeddings.weight.device)
                             .unsqueeze(0)
                             .expand((bs, seq_length)))
        word_embs = self.get_word_embs(input_ids)
        position_embs = self.position_embeddings(position_ids)
        normed = self.norm(word_embs + position_embs)
        return self.do(normed)

In [43]:
sample_embs = EmbeddingWPositionEmbs(
    char_encoder,
    config.embedding_sz,
    bert_config.hidden_size,
)

In [44]:
if torch.cuda.is_available(): sample_embs.cuda()

In [45]:
# Intended to keep parameters whose name contains "bag" on the CPU.
# NOTE(review): the value returned by `param.to(...)` is discarded, so this
# loop presumably leaves the parameters where they are - confirm and assign
# the result back if the move is really wanted.
for nm, param in sample_embs.named_parameters():
    if "bag" in nm: param.to(torch.device("cpu"))

# Masked Language Model

### The encoder

In [46]:
bert_encoder = BertEncoder(bert_config)

In [47]:
class CustomBert(nn.Module):
    """Runs an embedding module followed by a BERT-style encoder.

    Args:
        embeddings: module mapping `input_ids` to the encoder's input.
        encoder: module called as
            `encoder(embeddings, attention_mask, output_all_encoded_layers=True)`.
    """
    def __init__(self, embeddings, encoder):
        super().__init__()
        self.embeddings = embeddings
        self.encoder = encoder
        
    def forward(self, input_ids, 
                token_type_ids=None, 
                attention_mask=None):
        """Encode `input_ids` and return whatever the encoder returns.

        `input_ids` is (batch, seq) or, with one extra trailing dimension,
        (batch, seq, chars). `token_type_ids` is accepted for interface
        compatibility but is not used. `attention_mask` is 1 for positions to
        attend to and 0 for masked ones; it defaults to all ones.
        """
        # Reference parameter: gives the dtype (fp16 compatibility) and the
        # device the mask has to live on.
        ref_param = next(self.parameters())
        if attention_mask is None:
            # For 3D inputs take one entry per token.
            per_token = input_ids[:, :, 0] if len(input_ids.shape) > 2 else input_ids
            attention_mask = torch.ones_like(per_token)
        # (The former default for token_type_ids was computed but never used,
        # so it has been dropped.)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        # The mask follows the model's own device instead of the notebook-global
        # `device`.
        extended_attention_mask = extended_attention_mask.to(device=ref_param.device,
                                                             dtype=ref_param.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=True)
        return encoded_layers

In [48]:
class BertMLMPooler(nn.Module):
    """Reduce a sequence of per-layer encoder outputs to the last one."""

    def forward(self, x: List[torch.FloatTensor]) -> torch.FloatTensor:
        """Return the final element of ``x`` (the last layer's output)."""
        final_layer = x[-1]
        return final_layer

In [49]:
# Combine the input embedding module and the BERT encoder into one model.
bert_model = CustomBert(sample_embs, bert_encoder)

In [50]:
import math
def gelu(x):
    """Gaussian Error Linear Unit, computed exactly with the error function.

    Evaluates ``x * 0.5 * (1 + erf(x / sqrt(2)))`` element-wise on a tensor.
    """
    half_x = x * 0.5
    erf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * erf_term

class BertCustomLMPredictionHead(nn.Module):
    """Masked-LM prediction head with an optional vocabulary projection.

    With ``output_logits=True`` the hidden states go through
    dense -> gelu -> LayerNorm -> decoder (+ bias) to produce vocabulary
    scores. With ``output_logits=False`` ``forward`` returns its input
    untouched; the dense and LayerNorm layers are still created but unused.

    Args:
        config: Object exposing ``hidden_size`` (input width of the dense layer).
        out_sz: Output width of the dense layer and input width of the decoder.
        vocab_sz: Number of output scores per position.
        embedding: Optional weight assigned to ``decoder.weight`` to tie the
            output projection to an existing embedding; ``None`` keeps the
            decoder's own weight.
        output_logits: Whether ``forward`` projects to vocabulary scores.
    """

    def __init__(self, config, out_sz, vocab_sz, 
                 embedding, output_logits=True):
        super().__init__()
        self.output_logits = output_logits

        # Projection stack (always built, regardless of output_logits).
        self.dense = nn.Linear(config.hidden_size, out_sz)
        self.transform_act_fn = gelu
        self.LayerNorm = LayerNorm(out_sz)

        # Vocabulary projection, only needed when logits are requested.
        if output_logits:
            self.decoder = nn.Linear(out_sz, vocab_sz, bias=False)
            if embedding is not None:
                # Tie the output projection to the supplied embedding weight.
                self.decoder.weight = embedding
            self.bias = nn.Parameter(torch.zeros(vocab_sz))

    def forward(self, hidden_states):
        """Return vocabulary scores, or ``hidden_states`` as-is when logits are disabled."""
        if not self.output_logits:
            return hidden_states
        projected = self.transform_act_fn(self.dense(hidden_states))
        normalized = self.LayerNorm(projected)
        return self.decoder(normalized) + self.bias

In [51]:
# The head projects to vocabulary logits for every loss except "is".
output_logits = config.loss != "is"
# The decoder weight is tied to the input word-embedding weight only when the
# "fasttext" encoder is used; otherwise no embedding is passed in.
bert_mlm_head = BertCustomLMPredictionHead(config=bert_config, 
                                           out_sz=config.embedding_sz, 
                                           vocab_sz=config.vocab_sz,
                                           embedding=(sample_embs.word_emb.weight 
                                                      if config.char_encoder == "fasttext" 
                                                      else None),
                                           output_logits=output_logits,
)

In [52]:
# Full pipeline: encoder -> keep only the final layer -> prediction head.
custom_model = nn.Sequential(
    bert_model,
    BertMLMPooler(),
    bert_mlm_head,
)

### The decoder

In [53]:
class ElmoDecoder(nn.Module):
    """Produce output embeddings by encoding ids and applying a linear layer.

    Args:
        enc: Module applied to the (character) ids first.
        dec: Linear layer applied to the encoder output.
        word_correction_dim: When > 0, an extra per-word embedding of this
            size is projected to ``dec.out_features`` and added to the output.

    TODO: Add word correction
    """

    def __init__(self, enc: nn.Module, dec: nn.Linear, word_correction_dim: int=0):
        super().__init__()
        self._enc, self._dec = enc, dec
        self.word_correction_dim = word_correction_dim
        if word_correction_dim > 0:
            # Correction table is sized from the notebook-level config.
            self.word_correction = nn.Embedding(config.max_vocab_size,
                                                word_correction_dim)
            self.back_projection = nn.Linear(word_correction_dim,
                                             self._dec.out_features,
                                             bias=False)

    def forward(self, idxs, lookup_char_idxs=True):
        """Embed ``idxs``.

        With ``lookup_char_idxs=True`` the ids index the notebook-level
        ``word_id_to_char_idxs`` table first; otherwise they are fed to the
        encoder directly. A 1-D input gets a leading dimension added, and the
        result always has ``squeeze(0)`` applied.
        """
        if idxs.dim() == 1:
            idxs = idxs.unsqueeze(0)
        char_idxs = word_id_to_char_idxs[idxs] if lookup_char_idxs else idxs

        encoded = self._enc(char_idxs)
        output = self._dec(encoded).squeeze(0)
        if not self.word_correction_dim > 0:
            return output
        corr = self.back_projection(self.word_correction(idxs.squeeze(0)))
        return output + corr

In [54]:
# Inspect the word-embedding submodule of the input embeddings.
sample_embs.word_emb

ElmoEncoder(
  (_inner): _ElmoCharacterEncoder(
    (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
    (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
    (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
    (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
    (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
    (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
    (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
    (_highways): Highway(
      (_layers): ModuleList(
        (0): Linear(in_features=2048, out_features=4096, bias=True)
      )
    )
    (_projection): Linear(in_features=2048, out_features=128, bias=True)
  )
)

In [55]:
# Pick the callable/module that turns word ids into output embeddings,
# depending on the configured encoder, softmax and loss.
if config.char_encoder == "fasttext":
    output_embs = sample_embs.get_word_embs
elif config.char_encoder == "cnn" and config.softmax == "cnn_softmax" and config.loss == "is":
    if not config.tie_weights:
        # Untied: build a separate character encoder for the output side.
        _inner_char_decoder = _ElmoCharacterEncoder(
                options_file=options_file, 
                weight_file=weight_file,
                requires_grad=True
        )
        char_decoder = ElmoEncoder(_inner_char_decoder)
    else:
        # Tied: reuse the input-side character encoder.
        char_decoder = sample_embs.word_emb
    # share just the linear transformation with the input embeddings (sample_embs.linear)
    output_embs = ElmoDecoder(char_decoder, sample_embs.linear)
    if torch.cuda.is_available(): output_embs.cuda()
elif config.char_encoder == "subword":
    output_embs = sample_embs.get_word_embs # tie input and output weights
else:
    # Fallback: a standalone output embedding table with no preset weights.
    from allennlp.modules import Embedding
    mtrx = None
    output_embs = Embedding(config.vocab_sz, bert_config.hidden_size, weight=mtrx)
    if torch.cuda.is_available(): output_embs.cuda()

In [56]:
class WordCorrection(nn.Module):
    """Bottleneck projection followed by a product with a correction matrix.

    From the paper `Exploring the Limitations of Language Modeling`.

    Args:
        hidden_sz: Width of the incoming hidden states.
        bottleneck_sz: Width the hidden states are projected down to.
    """

    def __init__(self, hidden_sz: int, bottleneck_sz: int=128):
        super().__init__()
        self.l1 = nn.Linear(hidden_sz, bottleneck_sz)

    def forward(self, h: torch.FloatTensor, 
                corr: torch.FloatTensor):
        """Project ``h`` through the bottleneck and matrix-multiply by ``corr``."""
        bottleneck = self.l1(h)
        return torch.matmul(bottleneck, corr)

# Loss Functions

Masked Cross Entropy

In [57]:
class MaskedCrossEntropyLoss(nn.Module):
    """Cross entropy averaged over the positions selected by a mask.

    Without a mask this is the plain mean of the per-element cross entropy.
    With a mask, predictions and targets are flattened so the last dimension
    of ``preds`` is the class dimension, and the per-position losses are
    averaged over the positions where the mask is non-zero.
    """

    def __init__(self):
        super().__init__()
        # Keep per-element losses so the mask can be applied before reducing.
        self._loss = nn.CrossEntropyLoss(reduction="none")

    def forward(self, preds, tgts, mask=None) -> torch.Tensor:
        """Compute the (optionally masked) mean cross entropy.

        Args:
            preds: Class scores with the class dimension last.
            tgts: Target class indices, one per position of ``preds``.
            mask: Optional 0/1 tensor with one entry per position; positions
                with 0 are excluded from the average.

        Returns:
            A scalar tensor. When the mask selects no position at all the
            result is 0 rather than NaN.
        """
        if mask is None:
            return self._loss(preds, tgts).mean()

        # reshape (not view) so non-contiguous inputs are accepted too.
        per_position = self._loss(preds.reshape(-1, preds.size(-1)),
                                  tgts.reshape(-1))
        # Guard the denominator: an all-zero mask would otherwise yield 0/0.
        n_elements = mask.sum().clamp(min=1)
        return (per_position * mask.reshape(-1).float()).sum() / n_elements

Importance Sampling

In [58]:
class UniformSampler:
    """Draw integer ids uniformly from ``[min_, max_)``."""

    def __init__(self, min_, max_):
        self.min_ = min_
        self.max_ = max_

    def sample(self, shape):
        """Return a tensor of the given shape filled with uniformly drawn integers."""
        lower, upper = self.min_, self.max_
        return torch.randint(low=lower, high=upper, size=shape)

class UnigramSampler:
    """Draw ids with replacement according to a vector of weights."""

    def __init__(self, probs):
        self.probs = probs

    @staticmethod
    def _prod(x):
        """Multiply all entries of ``x`` together (1 for an empty sequence)."""
        return functools.reduce(lambda acc, a: acc * a, x, 1)

    def sample(self, shape):
        """Return a tensor of the given shape holding ids sampled from ``self.probs``."""
        n_draws = self._prod(shape)
        flat = torch.multinomial(self.probs, n_draws, replacement=True)
        return flat.view(shape)

In [59]:
class ImportanceSamplingLoss(nn.Module):
    """Sampled-softmax style loss over one positive and ``k`` shared negatives.

    Each hidden vector is scored (dot product) against the embedding of its
    target and against ``k`` negative embeddings that are sampled once and
    shared by the whole batch. The target sits at index 0 of the scores, and
    the masked cross entropy against class 0 is returned.

    Args:
        embedding_generator: Callable mapping an index tensor to embeddings.
        probs: Sampling weights for the negative samples, one per word id.
        k: Number of negative samples per forward pass.
    """

    def __init__(self, embedding_generator,
                 probs: np.ndarray, k=config.num_neg_samples):
        super().__init__()
        self.k = k
        self.embedding_generator = embedding_generator
        # TODO: Should this be according to the unigram probability?
        # Or should it be uniform?
        self.sampler = UnigramSampler(probs=torch.FloatTensor(probs))
        # TODO: Compute samples in advance
        self._loss_func = MaskedCrossEntropyLoss()

    def get_negative_samples(self) -> torch.LongTensor:
        """Sample ``k`` negative word ids."""
        return self.sampler.sample((self.k, )) # TODO: Speed up??

    def get_embeddings(self, idxs: torch.LongTensor) -> torch.FloatTensor:
        """Map indices to vectors via the embedding generator."""
        return self.embedding_generator(idxs) # TODO: Implement general case

    def forward(self, y: torch.LongTensor, tgt, mask=None):
        """
        Expects input of shape
        y: (batch * seq, feature_sz)
        tgt: (batch * seq, )

        Inputs with more than two dimensions are flattened to that layout.
        """
        if y.dim() > 2:
            y = y.view((-1, y.size(-1))) # (batch * seq, emb_sz)
            tgt = tgt.view((-1, )) # (batch * seq, )
        bs, emb_sz = y.shape[0], y.shape[1]

        positives = self.get_embeddings(tgt) # (bs, emb_sz)
        # one set of negatives is shared by every row of the batch
        neg_samples = self.get_negative_samples().to(y.device) # (k, )
        negatives = self.get_embeddings(neg_samples) # (k, emb_sz)

        candidates = torch.cat(
            [
                positives.unsqueeze(1), # (bs, 1, emb_sz)
                negatives.unsqueeze(0).expand(bs, self.k, emb_sz), # (bs, k, emb_sz)
            ],
            dim=1,
        ) # (bs, k+1, emb_sz)
        scores = torch.einsum("bkf,bf->bk", candidates, y)
        # the positive candidate is always at index 0
        gold = torch.zeros(bs, dtype=torch.int64).to(y.device)
        return self._loss_func(scores, gold, mask=mask)

In [60]:
class ImportanceSamplingCNNLoss(nn.Module):
    """Sampled-softmax style loss whose targets are character ids.

    Works like a dot-product loss over one positive and ``k`` shared negative
    embeddings, except that the positive embedding is built directly from the
    target's character ids (``lookup_char_idxs=False``), while the negatives
    are sampled word ids passed to the embedding generator as-is.

    Args:
        embedding_generator: Callable mapping an index tensor (plus keyword
            options) to embeddings.
        probs: Sampling weights for the negative samples, one per word id.
        k: Number of negative samples per forward pass.
    """

    def __init__(self, embedding_generator,
                 probs: np.ndarray, k=config.num_neg_samples):
        super().__init__()
        self.k = k
        self.embedding_generator = embedding_generator
        # TODO: Should this be according to the unigram probability?
        # Or should it be uniform?
        self.sampler = UnigramSampler(probs=torch.FloatTensor(probs))
        # TODO: Compute samples in advance
        self._loss_func = MaskedCrossEntropyLoss()

    def get_negative_samples(self) -> torch.LongTensor:
        """Sample ``k`` negative word ids."""
        return self.sampler.sample((self.k, )) # TODO: Speed up??

    def get_embeddings(self, idxs: torch.LongTensor, **kwargs) -> torch.FloatTensor:
        """Map indices to vectors, forwarding keyword options to the generator."""
        return self.embedding_generator(idxs, **kwargs) # TODO: Implement general case

    def forward(self, y: torch.LongTensor, tgt, mask=None):
        """
        Expects input of shape
        y: (batch * seq, feature_sz)
        tgt: (batch * seq, chars)

        ``y`` with more than two dimensions is flattened to that layout.
        """
        if y.dim() > 2:
            y = y.view((-1, y.size(-1))) # (batch * seq, emb_sz)
        bs, emb_sz = y.shape[0], y.shape[1]

        # positives come straight from the target character ids
        positives = self.get_embeddings(tgt, lookup_char_idxs=False).view(bs, -1) # (bs, emb_sz)
        # one set of negatives is shared by every row of the batch
        neg_samples = self.get_negative_samples().to(y.device) # (k, )
        negatives = self.get_embeddings(neg_samples) # (k, emb_sz)

        candidates = torch.cat(
            [
                positives.unsqueeze(1), # (bs, 1, emb_sz)
                negatives.unsqueeze(0).expand(bs, self.k, emb_sz), # (bs, k, emb_sz)
            ],
            dim=1,
        ) # (bs, k+1, emb_sz)
        scores = torch.einsum("bkf,bf->bk", candidates, y)
        # the positive candidate is always at index 0
        gold = torch.zeros(bs, dtype=torch.int64).to(y.device)
        return self._loss_func(scores, gold, mask=mask)

Testing

In [61]:
# Dummy inputs for exercising the loss: 3*7 random hidden vectors flattened to
# (21, emb_sz), and random integer ids in [0, 100) of shape (3, 7, 50).
emb_sz = bert_config.hidden_size
y = torch.randn((3, 7, emb_sz)).view((-1, emb_sz)).to(device)
tgt = torch.randint(100, (3, 7, 50)).to(device)

In [62]:
# Inspect the dummy target ids.
tgt

tensor([[[39, 87, 76,  ..., 90, 45, 19],
         [84, 92, 16,  ..., 38, 94, 94],
         [40, 59, 97,  ..., 92, 60, 20],
         ...,
         [ 6, 28, 23,  ...,  6, 36, 13],
         [35, 95, 22,  ..., 39,  0, 62],
         [70, 38, 45,  ..., 95, 52, 43]],

        [[81, 53, 80,  ..., 25, 11, 81],
         [26, 34, 40,  ..., 29, 32, 82],
         [36, 60, 61,  ..., 30, 90, 89],
         ...,
         [74, 84, 16,  ..., 30, 72, 29],
         [98, 41, 66,  ..., 87, 99, 31],
         [50, 47, 89,  ..., 50, 53, 24]],

        [[41, 59, 60,  ..., 36, 67,  7],
         [ 4, 43, 68,  ..., 88, 97, 17],
         [96, 34, 44,  ..., 14, 83, 35],
         ...,
         [93, 28, 39,  ..., 59, 53, 12],
         [24, 34,  4,  ..., 32, 19, 48],
         [17, 95, 20,  ...,  6, 23, 68]]], device='cuda:0')

In [63]:
# Smoke test, only for the "is" loss: run the CNN importance-sampling loss on
# the dummy inputs with 10 negatives drawn from 100 random sampling weights.
if config.loss == "is":
    loss = ImportanceSamplingCNNLoss(
        output_embs, k=10, probs=torch.rand(100),
    )
    loss(y, tgt)

# Training

In [64]:
from allennlp.models import Model

In [65]:
class Masker(nn.Module):
    """Randomly replace input positions with the mask token.

    Each word position is masked independently with probability
    ``noise_rate * mask_rate``. For the "cnn" character encoder the mask is a
    row of character ids and whole words are replaced; otherwise a single
    ``[MASK]`` vocabulary index is filled in.

    Args:
        vocab: Vocabulary used for its size and the ``[MASK]`` index.
        noise_rate: Base probability of corrupting a position.
        mask_rate: Multiplier applied to ``noise_rate`` for masking.
        replace_rate: Stored but not used by ``forward``.
    """

    def __init__(self, vocab: Vocabulary, 
                 noise_rate: float=config.mask_rate,
                 mask_rate=1.0,
                 replace_rate=0.0):
        super().__init__()
        self.vocab = vocab
        self.vocab_sz = vocab.get_vocab_size()
        uses_char_cnn = config.char_encoder == "cnn"
        # Character-level mask gets two leading dims so it broadcasts over
        # the first two dimensions of the input.
        self.mask_id = (mask_char_ids.unsqueeze(0).unsqueeze(1) if uses_char_cnn
                        else vocab.get_token_index("[MASK]"))
        self.noise_rate = noise_rate
        self.mask_rate = mask_rate
        self.replace_rate = replace_rate

    def create_mask(self, shape, ones_ratio, dtype=torch.uint8):
        """Return a tensor of ``shape`` whose entries are 1 with probability ``ones_ratio``."""
        base = torch.ones(shape, dtype=dtype, requires_grad=False)
        return base.bernoulli(ones_ratio)

    def get_random_input_ids(self, shape):
        """Return randomly sampled input ids of the given shape.

        For the "cnn" encoder the sampled word ids are converted to character
        ids; for "fasttext" and "subword" the word ids are returned directly.
        Any other encoder yields ``None``.
        """
        encoder_kind = config.char_encoder
        if encoder_kind == "cnn":
            word_ids = torch.randint(self.vocab_sz, shape)
            return word_id_to_char_idxs[word_ids]
        if encoder_kind in ("fasttext", "subword"):
            return torch.randint(self.vocab_sz, shape)
        return None

    def forward(self, x: torch.LongTensor) -> torch.LongTensor:
        """Return ``x`` with a random subset of positions replaced by the mask id."""
        if not self.noise_rate > 0:
            return x
        # character-level features are still masked one whole word at a time
        char_level = len(x.shape) > 2
        with torch.no_grad(): # masking needs no gradients
            mask_shape = x.shape[:-1] if char_level else x.shape
            mask = self.create_mask(mask_shape,
                                    self.noise_rate * self.mask_rate).to(x.device)
            if config.char_encoder == "cnn":
                return torch.where(mask.unsqueeze(2), self.mask_id, x)
            return x.masked_fill(mask, self.mask_id)

In [66]:
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy

class MaskedLM(Model):
    """Masked language model: mask the inputs, run the model, score the outputs.

    Args:
        vocab: Vocabulary passed to the base model and to the masker.
        model: Sequential model whose last stage exposes ``output_logits``.
        loss: Callable taking ``(predictions, targets, mask=...)``.
        noise_rate: Masking probability handed to the ``Masker``.
    """

    def __init__(self, vocab: Vocabulary, model: nn.Module,
                loss: nn.Module, noise_rate=0.8):
        super().__init__(vocab)
        self.masker = Masker(vocab, noise_rate=noise_rate)
        self.accuracy = CategoricalAccuracy()
        self.model = model
        self.loss = loss
    
    @property
    def outputs_logits(self) -> bool:
        """Whether the final stage of ``self.model`` emits vocabulary logits."""
        return self.model[-1].output_logits
    
    def forward(self, input: TensorDict, 
                output: TensorDict, **kwargs) -> TensorDict:
        """Run one forward pass.

        Args:
            input: Text-field tensors; ``input["tokens"]`` is masked and fed
                to the model.
            output: Text-field tensors; ``output["tokens"]`` are the targets.
            **kwargs: Ignored.

        Returns:
            A dict with ``"loss"`` and ``"logits"``, plus an ``"accuracy"``
            entry (the return value of the metric call) when the model
            outputs logits.
        """
        mask = get_text_field_mask(input).to(device)
        x = self.masker(input["tokens"])
        tgt = output["tokens"].to(device)
        
        logits = self.model(x)
        out_dict = {"loss": self.loss(logits, tgt, mask=mask)}
        out_dict["logits"] = logits
        if self.outputs_logits:
            # Pass the text-field mask so padded positions do not count
            # towards the accuracy, matching what the loss already does.
            out_dict["accuracy"] = self.accuracy(logits, tgt, mask)
        return out_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated accuracy, or an empty dict without logits."""
        if self.outputs_logits:
            return {"accuracy": self.accuracy.get_metric(reset)}
        else:
            return {}

In [67]:
# Select the training loss from config.loss, then build the masked LM around it.
if config.loss == "masked_crossentropy":
    loss = MaskedCrossEntropyLoss()
elif config.loss == "crossentropy":
    _loss = nn.CrossEntropyLoss()
    def ce(y, t, mask=None): 
        """Plain cross entropy over flattened positions; ``mask`` is accepted but ignored."""
        return _loss(y.view((-1, y.size(-1))), t.view((-1, )))
    loss = ce
elif config.loss == "is":
    if config.softmax == "cnn_softmax":
        loss = ImportanceSamplingCNNLoss(output_embs, freqs)
    else:# TODO: Implement masking
        loss = ImportanceSamplingLoss(output_embs, freqs)
else:
    # Name the offending value so a misconfiguration is easy to diagnose.
    raise ValueError(
        f"Unsupported config.loss {config.loss!r}; expected one of "
        "'masked_crossentropy', 'crossentropy' or 'is'"
    )
masked_lm = MaskedLM(vocab, custom_model, loss, noise_rate=0.15)
if torch.cuda.is_available(): masked_lm.cuda() # Is this different from to(device)?

In [68]:
# Save the current model weights under DATA_ROOT.
torch.save(masked_lm.state_dict(), DATA_ROOT / "masked_lm_tmp2.pth")

Sample

In [69]:
# Move the sample batch to GPU 0 when CUDA is available, otherwise keep it on
# the CPU (-1) so this cell also runs on machines without a GPU.
batch = nn_util.move_to_device(batch, 0 if torch.cuda.is_available() else -1)

In [70]:
# Replay the steps of MaskedLM.forward by hand on one batch so the
# intermediate tensors can be inspected in the following cells.
input = batch["input"]
output = batch["output"]
mask = get_text_field_mask(input).to(device)
x = masked_lm.masker(input["tokens"])
tgt = output["tokens"].to(device)

logits = masked_lm.model(x)

In [71]:
# Inspect the raw model output for the sample batch.
logits

tensor([[[ 1.8803e+00, -1.1189e+00,  3.8443e-01,  ..., -3.1873e-01,
           5.9588e-01,  3.8869e-01],
         [-3.4913e-01, -4.6801e-01, -1.1681e+00,  ..., -1.2559e+00,
           1.6917e+00,  8.2842e-01],
         [ 5.6573e-01, -1.6289e+00, -1.1664e+00,  ..., -6.6965e-01,
           7.6219e-01, -1.2412e+00],
         ...,
         [ 5.0758e-01, -1.5967e+00, -8.9374e-01,  ..., -1.3117e+00,
           8.4707e-01, -2.8985e-01],
         [-1.6587e+00, -9.1444e-01, -6.7633e-02,  ...,  3.6648e-01,
           5.6736e-01, -8.1231e-01],
         [ 1.0118e+00, -2.8515e+00, -1.1834e+00,  ...,  9.1839e-01,
           8.8668e-01,  3.9653e-01]],

        [[ 1.0310e+00, -1.2212e+00, -3.2921e-01,  ..., -9.9011e-01,
           8.0284e-01, -1.0964e-01],
         [-1.5519e+00, -2.9082e-01, -1.0774e+00,  ..., -1.1716e+00,
           1.7685e+00,  6.5217e-01],
         [ 1.0063e+00, -2.3794e+00, -1.2714e+00,  ..., -4.5537e-01,
           4.2940e-01, -7.8346e-01],
         ...,
         [ 3.0946e-01, -1

In [72]:
# Step through the start of the CNN importance-sampling loss by hand:
# flatten the model output and embed the targets from their character ids.
# NOTE(review): the lookup_char_idxs keyword presumes `loss` is an
# ImportanceSamplingCNNLoss -- confirm config.loss/config.softmax.
y = logits
if len(y.shape) > 2:            
    y = y.view((-1, y.size(-1))) # (batch * seq, emb_sz)
bs, emb_sz = y.size(0), y.size(1)
pos_embeddings = loss.get_embeddings(tgt, lookup_char_idxs=False).view(bs, -1) # (bs, emb_sz)

In [73]:
# Inspect the embeddings computed for the target words.
pos_embeddings

tensor([[ 0.1372,  0.0569, -0.2525,  ...,  0.0218, -0.1455, -0.0987],
        [ 0.0607, -0.5274, -0.3160,  ...,  0.5322,  0.2186, -0.1910],
        [ 0.4555,  0.4182, -0.9137,  ...,  0.0938,  0.0448, -0.2436],
        ...,
        [ 0.2326, -0.5339, -0.0772,  ...,  0.4814, -0.0768, -0.1773],
        [ 0.4555,  0.4182, -0.9137,  ...,  0.0938,  0.0448, -0.2436],
        [ 0.0492, -0.0012, -0.2530,  ...,  0.2310, -0.2063, -0.0053]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [74]:
# share negative samples across the batch
# (drawn once from the loss' sampler and moved to the device of y)
neg_samples = (loss.get_negative_samples()
               .to(y.device)) # (k, )

In [75]:
# Inspect the sampled negative word ids.
neg_samples

tensor([  245,   523, 37057,  ...,  2793,  4597, 54890], device='cuda:0')

In [76]:
# Largest sampled negative id.
neg_samples.max()

tensor(98015, device='cuda:0')

In [77]:
# Size of the word-id -> character-id table, for comparison with the largest sampled id.
word_id_to_char_idxs.shape

torch.Size([98251, 50])

In [78]:
# Embed the negatives, stack them behind the positive embedding of every row,
# and score each candidate against the model output with a dot product.
neg_embeddings = loss.get_embeddings(neg_samples) # (k, emb_sz)
embs = torch.cat([
    pos_embeddings.unsqueeze(1), # (bs, 1, emb_sz)
    neg_embeddings.unsqueeze(0).expand(bs, loss.k, emb_sz) # (bs, k, emb_sz)
], dim=1) # (bs, k+1, emb_sz)
dot_prods = torch.einsum("bkf,bf->bk", embs, y)

Sanity checks

In [79]:
# Re-inspect the model output for the sample batch.
logits

tensor([[[ 1.8803e+00, -1.1189e+00,  3.8443e-01,  ..., -3.1873e-01,
           5.9588e-01,  3.8869e-01],
         [-3.4913e-01, -4.6801e-01, -1.1681e+00,  ..., -1.2559e+00,
           1.6917e+00,  8.2842e-01],
         [ 5.6573e-01, -1.6289e+00, -1.1664e+00,  ..., -6.6965e-01,
           7.6219e-01, -1.2412e+00],
         ...,
         [ 5.0758e-01, -1.5967e+00, -8.9374e-01,  ..., -1.3117e+00,
           8.4707e-01, -2.8985e-01],
         [-1.6587e+00, -9.1444e-01, -6.7633e-02,  ...,  3.6648e-01,
           5.6736e-01, -8.1231e-01],
         [ 1.0118e+00, -2.8515e+00, -1.1834e+00,  ...,  9.1839e-01,
           8.8668e-01,  3.9653e-01]],

        [[ 1.0310e+00, -1.2212e+00, -3.2921e-01,  ..., -9.9011e-01,
           8.0284e-01, -1.0964e-01],
         [-1.5519e+00, -2.9082e-01, -1.0774e+00,  ..., -1.1716e+00,
           1.7685e+00,  6.5217e-01],
         [ 1.0063e+00, -2.3794e+00, -1.2714e+00,  ..., -4.5537e-01,
           4.2940e-01, -7.8346e-01],
         ...,
         [ 3.0946e-01, -1

In [80]:
# Use a masking probability of 0.05 when config.denoise is set, 0.15 otherwise.
masked_lm.masker.noise_rate = 0.05 if config.denoise else 0.15

In [81]:
# Re-inspect the positive (target) embeddings.
pos_embeddings

tensor([[ 0.1372,  0.0569, -0.2525,  ...,  0.0218, -0.1455, -0.0987],
        [ 0.0607, -0.5274, -0.3160,  ...,  0.5322,  0.2186, -0.1910],
        [ 0.4555,  0.4182, -0.9137,  ...,  0.0938,  0.0448, -0.2436],
        ...,
        [ 0.2326, -0.5339, -0.0772,  ...,  0.4814, -0.0768, -0.1773],
        [ 0.4555,  0.4182, -0.9137,  ...,  0.0938,  0.0448, -0.2436],
        [ 0.0492, -0.0012, -0.2530,  ...,  0.2310, -0.2063, -0.0053]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [82]:
# In testing mode, move the sample batch to GPU 0 if available, else CPU (-1).
if config.testing:
    batch = nn_util.move_to_device(batch, 0 if torch.cuda.is_available() else -1)

#     tokens = masked_lm.masker(batch["input"]["tokens"])

#     hidden_states = masked_lm.model[:2](tokens)

#     masked_lm.model[2](hidden_states)

In [83]:
# In testing mode, run one full forward pass on the sample batch.
if config.testing: masked_lm(**batch)

In [84]:
# Recompute the embeddings of the previously sampled negatives.
neg_embeddings = loss.get_embeddings(neg_samples) # (k, emb_sz)

Sanity checks

In [85]:
# Show which loss object (and embedding generator) get_embeddings is bound to.
loss.get_embeddings

<bound method ImportanceSamplingCNNLoss.get_embeddings of ImportanceSamplingCNNLoss(
  (embedding_generator): ElmoDecoder(
    (_enc): ElmoEncoder(
      (_inner): _ElmoCharacterEncoder(
        (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
        (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
        (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
        (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
        (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
        (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
        (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
        (_highways): Highway(
          (_layers): ModuleList(
            (0): Linear(in_features=2048, out_features=4096, bias=True)
          )
        )
        (_projection): Linear(in_features=2048, out_features=128, bias=True)
      )
    )
    (_dec): Linear(in_features=128, out_features=128, bias=False)
  )
  (_los

In [86]:
# Inspect the input embedding module.
sample_embs

EmbeddingWPositionEmbs(
  (word_emb): ElmoEncoder(
    (_inner): _ElmoCharacterEncoder(
      (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
      (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
      (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
      (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
      (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
      (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
      (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
      (_highways): Highway(
        (_layers): ModuleList(
          (0): Linear(in_features=2048, out_features=4096, bias=True)
        )
      )
      (_projection): Linear(in_features=2048, out_features=128, bias=True)
    )
  )
  (position_embeddings): Embedding(128, 128)
  (linear): Linear(in_features=128, out_features=128, bias=False)
  (norm): LayerNorm()
  (do): Dropout(p=0.1)
)

In [87]:
# Inspect the output embedding module.
output_embs

ElmoDecoder(
  (_enc): ElmoEncoder(
    (_inner): _ElmoCharacterEncoder(
      (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
      (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
      (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
      (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
      (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
      (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
      (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
      (_highways): Highway(
        (_layers): ModuleList(
          (0): Linear(in_features=2048, out_features=4096, bias=True)
        )
      )
      (_projection): Linear(in_features=2048, out_features=128, bias=True)
    )
  )
  (_dec): Linear(in_features=128, out_features=128, bias=False)
)

In [88]:
# Inspect the embeddings of the negative samples.
neg_embeddings

tensor([[ 0.2023, -0.6875, -0.2014,  ...,  0.2349, -0.1344, -0.1402],
        [-0.7778, -0.0124,  0.6745,  ..., -0.2378, -0.0466, -0.1360],
        [-0.1739,  0.2081, -0.0993,  ...,  0.1709, -0.2481, -0.0325],
        ...,
        [-0.2379, -0.6720,  0.2089,  ...,  0.4555, -0.1540,  0.0879],
        [-0.2309,  0.0147, -0.1290,  ...,  0.1137, -0.5712,  0.3739],
        [ 0.2628, -0.2134,  0.0961,  ...,  0.2700, -0.3750,  0.1704]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)

In [89]:
# Rebuild the candidate stack (positive first, then the shared negatives) and
# score every candidate against the model output.
embs = torch.cat([
    pos_embeddings.unsqueeze(1), # (bs, 1, emb_sz)
    neg_embeddings.unsqueeze(0).expand(bs, loss.k, emb_sz) # (bs, k, emb_sz)
], dim=1) # (bs, k+1, emb_sz)
dot_prods = torch.einsum("bkf,bf->bk", embs, y)

In [90]:
# Inspect the loss module.
loss

ImportanceSamplingCNNLoss(
  (embedding_generator): ElmoDecoder(
    (_enc): ElmoEncoder(
      (_inner): _ElmoCharacterEncoder(
        (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
        (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
        (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
        (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
        (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
        (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
        (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
        (_highways): Highway(
          (_layers): ModuleList(
            (0): Linear(in_features=2048, out_features=4096, bias=True)
          )
        )
        (_projection): Linear(in_features=2048, out_features=128, bias=True)
      )
    )
    (_dec): Linear(in_features=128, out_features=128, bias=False)
  )
  (_loss_func): MaskedCrossEntropyLoss(
    (_loss): CrossEntropy

In [91]:
from allennlp.training import Callback
import gc

class GCCallback(Callback):
    """Run the garbage collector every ``period`` batches to prevent memory errors."""

    def __init__(self, period: int=10000):
        self._period = period

    def on_batch_end(self, data):
        """Collect garbage when the 1-based batch count is a multiple of the period."""
        batches_done = data["batches_this_epoch"] + 1
        if batches_done % self._period != 0:
            return
        gc.collect()

In [92]:
from allennlp.training import Trainer, TrainerWithCallbacks

# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(masked_lm.parameters(), lr=config.lr)
# No serialization directory in testing mode.
SER_DIR = (DATA_ROOT / "bert_ckpts_cnn_denoise_is_cnnsoftmax") if not config.testing else None
# Use GPU 0 only if CUDA is available and, for the "subword" encoder, only
# when config.put_embeddings_on_gpu is set; otherwise -1.
cdevice = 0 if torch.cuda.is_available() and (config.char_encoder != "subword" or config.put_embeddings_on_gpu) else -1

# trainer = TrainerWithCallbacks(
#     model=masked_lm,
#     optimizer=optimizer,
#     iterator=iterator,
#     train_dataset=train_ds,
#     validation_dataset=val_ds,
#     gradient_accumulation_steps=config.batch_size // config.computational_batch_size,
#     serialization_dir=SER_DIR,
#     cuda_device=cdevice,
#     patience=1,
#     num_epochs=config.epochs,
#     grad_clipping=5.,
#     callbacks=[GCCallback()],
# )

In [93]:
# masked_lm.load_state_dict(torch.load(DATA_ROOT / "masked_lm_tmp2.pth"))

In [94]:
# trainer.train()

In [95]:
# Load a previously saved checkpoint from a sibling "jigsaw" directory.
tmp_state_dict = torch.load(DATA_ROOT / ".." / "jigsaw" / "bert_ckpts_cnn_sm_denoise_is_tst_2" / "best.th")

In [96]:
import warnings

In [97]:
# List the parameter names stored in the loaded checkpoint.
tmp_state_dict.keys()

odict_keys(['model.0.embeddings.word_emb._inner._char_embedding_weights', 'model.0.embeddings.word_emb._inner.char_conv_0.weight', 'model.0.embeddings.word_emb._inner.char_conv_0.bias', 'model.0.embeddings.word_emb._inner.char_conv_1.weight', 'model.0.embeddings.word_emb._inner.char_conv_1.bias', 'model.0.embeddings.word_emb._inner.char_conv_2.weight', 'model.0.embeddings.word_emb._inner.char_conv_2.bias', 'model.0.embeddings.word_emb._inner.char_conv_3.weight', 'model.0.embeddings.word_emb._inner.char_conv_3.bias', 'model.0.embeddings.word_emb._inner.char_conv_4.weight', 'model.0.embeddings.word_emb._inner.char_conv_4.bias', 'model.0.embeddings.word_emb._inner.char_conv_5.weight', 'model.0.embeddings.word_emb._inner.char_conv_5.bias', 'model.0.embeddings.word_emb._inner.char_conv_6.weight', 'model.0.embeddings.word_emb._inner.char_conv_6.bias', 'model.0.embeddings.word_emb._inner._highways._layers.0.weight', 'model.0.embeddings.word_emb._inner._highways._layers.0.bias', 'model.0.embed

In [98]:
def load_parital_state_dict(model, dct):
    """Copy every matching entry of ``dct`` into ``model``'s state, in place.

    Args:
        model: Module whose ``state_dict()`` entries are overwritten.
        dct: Mapping from parameter/buffer name to the tensor to load.

    Returns:
        The set of names in ``model.state_dict()`` that were NOT loaded,
        either because ``dct`` has no such key or because copying failed
        (e.g. a shape mismatch); the latter also emits a warning.
    """
    non_added_params = set()
    for name, param in model.state_dict().items():
        if name not in dct:
            non_added_params.add(name)
            continue
        try:
            param.data.copy_(dct[name])
        except Exception:
            # Best effort: report and skip entries that cannot be copied, but
            # let KeyboardInterrupt/SystemExit propagate.
            warnings.warn(f"Failed to load {name} even though key exists")
            non_added_params.add(name)
    return non_added_params

In [99]:
# Load the checkpoint into the model; the result lists names that were not loaded.
load_parital_state_dict(masked_lm, tmp_state_dict)

set()

In [100]:
# Checkpoint directory for this run; None in testing mode.
SER_DIR = (DATA_ROOT / "bert_ckpts_cnn_denoise_is_cnnsoftmax") if not config.testing else None

In [101]:
# Resume from the best checkpoint of this run. SER_DIR is None in testing
# mode, in which case there is nothing to load.
if SER_DIR is not None:
    masked_lm.load_state_dict(torch.load(SER_DIR / "best.th"))

In [102]:
# Build the trainer: up to 5 epochs with patience 1, gradient clipping at 5,
# a periodic garbage-collection callback, and gradient accumulation sized so
# that config.batch_size is reached from config.computational_batch_size steps.
trainer = TrainerWithCallbacks(
    model=masked_lm,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    gradient_accumulation_steps=config.batch_size // config.computational_batch_size,
    serialization_dir=SER_DIR,
    cuda_device=cdevice,
    patience=1,
    num_epochs=5,
    grad_clipping=5.,
    callbacks=[GCCallback()],
    deindex=True,
    model_save_interval=3600,
)

In [103]:
# Run training and keep whatever the trainer returns.
metrics = trainer.train()

loss: 0.9456 ||: : 11623it [51:17,  3.94it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

loss: 0.9469 ||: : 41611it [3:00:54,  4.39it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

loss: 0.9459 ||: : 54356it [3:54:46,  2.99it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limi

loss: 0.9426 ||: : 345727it [25:07:22,  2.74it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

loss: 0.9426 ||: : 358680it [26:03:59,  4.42it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

loss: 0.9425 ||: : 371641it [26:56:57,  3.62it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_r

KeyboardInterrupt: 

In [104]:
# Overwrite best.th with the current weights (skipped when SER_DIR is None).
if SER_DIR is not None:
    torch.save(masked_lm.state_dict(), SER_DIR / "best.th")

In [105]:
# Reload the weights from best.th (skipped when SER_DIR is None).
if SER_DIR is not None:
    masked_lm.load_state_dict(torch.load(SER_DIR / "best.th"))

In [None]:
# trainer = TrainerWithCallbacks(
#     model=masked_lm,
#     optimizer=optimizer,
#     iterator=iterator,
#     train_dataset=train_ds,
#     validation_dataset=val_ds,
#     gradient_accumulation_steps=config.batch_size // config.computational_batch_size,
#     serialization_dir=SER_DIR,
#     cuda_device=cdevice,
#     patience=1,
#     num_epochs=config.epochs + 5,
#     grad_clipping=5.,
#     callbacks=[GCCallback()],
# )

In [None]:
# trainer.train()

In [None]:
# Use GPU 0 only if CUDA is available and, for the "subword" encoder, only
# when config.put_embeddings_on_gpu is set; otherwise -1.
cdevice = 0 if torch.cuda.is_available() and (config.char_encoder != "subword" or config.put_embeddings_on_gpu) else -1

# Move the sample batch to that device and run one forward pass.
batch = nn_util.move_to_device(batch, cdevice)
out_dict =masked_lm(**batch)

For IS loss, we need to aggregate the embeddings at the end of training to make predictions

In [None]:
# del optimizer

In [None]:
# import math

# vocab_sz = vocab.get_vocab_size()
# embedding_sz = bert_config.hidden_size

# if config.loss == "is":
#     bs = 1
#     output_embedding_matrix = torch.zeros(vocab_sz, embedding_sz)
#     num_batches = math.ceil(vocab_sz / bs)
#     for i in range(num_batches):
#         start,end = i*bs, min(((i+1)*bs), vocab_sz)
#         idxs = torch.arange(start=start, end=end).unsqueeze(0)
#         output_embedding_matrix[start:end, :] = loss.get_embeddings(idxs.to(device)).cpu()

# Manually Check Outputs

TODO: Implement manual checks for negative sampling loss as well

In [None]:
def to_np(t):
    """Convert a tensor to a NumPy array, detached from the graph and on the CPU."""
    detached = t.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

In [None]:
# Inspect the sample batch.
batch

In [None]:
def to_words(arr):
    """Turn a tensor of token ids into text using the notebook-level ``vocab``.

    A 1-D tensor becomes one space-joined string; tensors with more
    dimensions become correspondingly nested lists of such strings.
    """
    if len(arr.shape) > 1:
        # Recurse over the leading dimension until rows are 1-D.
        return list(map(to_words, arr))
    token_ids = to_np(arr)
    return " ".join(vocab.get_token_from_index(i) for i in token_ids)

In [None]:
def get_preds(model, batch: TensorDict):
    """Return the arg-max prediction along dimension 2 for a batch.

    For the "is" loss the model output is first moved to the CPU and
    multiplied by the transposed notebook-level ``output_embedding_matrix``
    before taking the arg-max; otherwise the arg-max is taken over the
    model's ``"logits"`` directly.

    NOTE(review): in this part of the notebook ``output_embedding_matrix`` is
    only assigned inside a commented-out cell -- confirm it is defined before
    calling this with ``config.loss == "is"``.
    """
    uses_sampled_loss = config.loss == "is"
    logits = model(**batch)["logits"]
    if not uses_sampled_loss:
        return logits.argmax(2)
    scores = logits.cpu() @ output_embedding_matrix.transpose(0, 1)
    return scores.argmax(2)

In [None]:
# Check the shape of the model output for the sample batch.
masked_lm(**batch)["logits"].shape

In [None]:
def cstm_pprint(x):
    """Print the strings in ``x`` separated by blank lines and log them to disk.

    The same text is appended to ``outputs.txt`` inside ``SER_DIR``. When
    ``SER_DIR`` is None (no serialization directory configured) the text is
    only printed.
    """
    s = "\n\n".join(x)
    print(s)
    if SER_DIR is None:
        # Nothing to write to; building the path from None would raise.
        return
    with open(SER_DIR / "outputs.txt", "at") as f:
        f.write(s + "\n\n")

In [None]:
# Show the target text of the first five examples in the batch.
cstm_pprint(to_words(batch["output"]["tokens"])[:5])

In [None]:
# Show the predicted text for the first five examples in the batch.
cstm_pprint(to_words(get_preds(masked_lm, batch)[:5]))

In [None]:
# Debugging aid: train for one epoch three times in a row, printing the
# predictions for the first five examples of the sample batch after each run.
if config.debugging:
    for i in range(3):
        trainer = Trainer(
            model=masked_lm,
            optimizer=optimizer,
            iterator=iterator,
            train_dataset=train_ds,
            validation_dataset=val_ds,
            serialization_dir=SER_DIR,
            cuda_device=0 if torch.cuda.is_available() else -1,
            num_epochs=1,
        )
        trainer.train()
        cstm_pprint(to_words(get_preds(masked_lm, batch))[:5])

In [None]:
# Run a forward pass on the sample batch and display the output dict.
masked_lm(**batch)

In [None]:
# Run the forward pass again; masking is random, so the output may differ between runs.
masked_lm(**batch)

In [None]:
# Inspect the output dict captured from the earlier forward pass.
out_dict

In [None]:
# Show the predicted text for the first five examples in the batch.
cstm_pprint(to_words(get_preds(masked_lm, batch))[:5])

# Predict and Evaluate

In [None]:
# Inspect the model output for the sample batch.
masked_lm(**batch)["logits"]