In [None]:
# settings for seamlessly running on colab
# Detect Google Colab by attempting the Colab-only import; on success, mount
# Google Drive at /content/gdrive. The outcome is recorded in the IS_COLAB
# environment variable ("True"/"False") so that later cells (including the
# %%bash cell) can branch on it.
import os
try:
    from google.colab import drive
    drive.mount('/content/gdrive')
    os.environ["IS_COLAB"] = "True"
except ImportError:
    # The google.colab package is unavailable: not running on Colab.
    os.environ["IS_COLAB"] = "False"

In [None]:
%%bash
# Install the notebook's dependencies, but only when running on Colab
# (IS_COLAB is exported by the preceding Python cell).
if [ "$IS_COLAB" = "True" ]; then
    pip install git+https://github.com/facebookresearch/fastText.git
    pip install torch
    pip install torchvision
    pip install --upgrade git+https://github.com/keitakurita/allennlp@develop
    pip install dnspython
    pip install jupyter_slack
    pip install git+https://github.com/keitakurita/Better_LSTM_PyTorch.git
    # Clone apex only when it is not already present. The previous test was
    # inverted (`-d`), so a fresh machine never cloned and the `cd` failed.
    if [ ! -d "apex" ]; then
      git clone https://github.com/NVIDIA/apex.git
    fi
    cd apex && python setup.py install --cpp_ext --cuda_ext
fi

In [None]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
%matplotlib inline
from overrides import overrides

In [None]:
import time
from contextlib import contextmanager

class Config(dict):
    """Dictionary whose entries are mirrored as instance attributes.

    Every keyword given to the constructor, and every pair stored through
    `set`, is written both as a dict item and as an attribute of the same name.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def set(self, key, val):
        """Store `val` under `key` as a dict item and as an attribute."""
        self.__setitem__(key, val)
        setattr(self, key, val)

@contextmanager
def timer(name):
    """Context manager that prints how many seconds the wrapped block took.

    The message is printed only when the block finishes without raising.
    """
    started_at = time.time()
    yield
    elapsed = time.time() - started_at
    print(f'[{name}] done in {elapsed:.0f} s')
    
import functools
import traceback
import sys

def get_ref_free_exc_info():
    "Return `sys.exc_info()` after clearing the traceback's frames, so locals held by those frames no longer keep memory alive"
    exc_type, exc_val, exc_tb = sys.exc_info()
    # Drop the local variables referenced by every frame of the traceback.
    traceback.clear_frames(exc_tb)
    return exc_type, exc_val, exc_tb

def gpu_mem_restore(func):
    "Decorator: on any exception (including interruption), re-raise a fresh exception whose traceback frames were cleared, so GPU RAM can be reclaimed"
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except:
            # The frames must be cleared before re-raising.
            exc_type, exc_val, exc_tb = get_ref_free_exc_info()
            # Build a new exception of the same class from the old one and
            # attach the cleaned traceback; `from None` hides the original context.
            fresh_exc = exc_type(exc_val)
            raise fresh_exc.with_traceback(exc_tb) from None
        return result
    return wrapper

def ifnone(a: Any, alt: Any):
    "Return `a`, or `alt` when `a` is None"
    if a is None:
        return alt
    return a

In [None]:
# for papermill
# Every name below is a tunable run parameter; the next cell copies most of
# them into `config`. Note that `char_encoder` and `loss` are rebound later in
# the notebook to model/loss objects, so code after this point should read
# the string settings from `config`, not from these globals.
testing = True
debugging = False
seed = 1
# Values checked by later cells: "cnn", "subword" or "fasttext".
char_encoder = "subword"
computational_batch_size = 16
batch_size = 64
# Values checked by later cells: "is", "masked_crossentropy" or "crossentropy".
loss = "is"
num_neg_samples = 1280
lr = 1e-4
lr_schedule = "slanted_triangular"
# NOTE(review): both branches evaluate to 1, so `testing` has no effect here — confirm intended.
epochs = 1 if not testing else 1
hidden_sz = 128 if not testing else 32
num_attention_heads = 4
num_hidden_layers = 4
dataset = "jigsaw"
softmax = "cnn_softmax"
n_classes = 6
max_seq_len = 128
download_data = False
ft_model_path = "../data/jigsaw/ft_model.txt"
max_vocab_size = 300000
dropouti = 0.2
dropoutw = 0.0
dropoute = 0.2
dropoutr = 0.3 # TODO: Implement
val_ratio = 0.0
use_augmented = False
freeze_embeddings = True
mixup_ratio = 0.0
discrete_mixup_ratio = 0.0
attention_bias = True
weight_decay = 0.
bias_init = True
neg_splits = 1
num_layers = 2
rnn_type = "lstm"
pooling_type = "augmented_multipool" # attention or multipool or augmented_multipool
model_type = "standard"
use_word_level_features = True
use_sentence_level_features = True
bucket = True
compute_thres_on_test = True
find_lr = False
permute_sentences = False
run_id = None

In [None]:
# TODO: Can we make this play better with papermill?
# Bundle the run parameters into a single Config object (dict + attribute
# access). `download_data` is not copied into the config; the download cell
# reads the global directly.
config = Config(
    testing=testing,
    debugging=debugging,
    seed=seed,
    char_encoder=char_encoder,
    computational_batch_size=computational_batch_size,
    batch_size=batch_size,
    loss=loss,
    num_neg_samples=num_neg_samples,
    lr=lr,
    lr_schedule=lr_schedule,
    epochs=epochs,
    hidden_sz=hidden_sz,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
    dataset=dataset,
    softmax=softmax,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    ft_model_path=ft_model_path,
    max_vocab_size=max_vocab_size,
    dropouti=dropouti,
    dropoutw=dropoutw,
    dropoute=dropoute,
    dropoutr=dropoutr,
    val_ratio=val_ratio,
    use_augmented=use_augmented,
    freeze_embeddings=freeze_embeddings,
    attention_bias=attention_bias,
    weight_decay=weight_decay,
    bias_init=bias_init,
    neg_splits=neg_splits,
    num_layers=num_layers,
    rnn_type=rnn_type,
    pooling_type=pooling_type,
    model_type=model_type,
    use_word_level_features=use_word_level_features,
    use_sentence_level_features=use_sentence_level_features,
    mixup_ratio=mixup_ratio,
    discrete_mixup_ratio=discrete_mixup_ratio,
    bucket=bucket,
    compute_thres_on_test=compute_thres_on_test,
    permute_sentences=permute_sentences,
    find_lr=find_lr,
    run_id=run_id,
)

In [None]:
# Generic type variable and the alias used to annotate batches: a mapping from
# field name to either a tensor or a nested {name: tensor} mapping.
T = TypeVar("T")
TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]  # pylint: disable=invalid-name

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util
from allennlp.data.dataset_readers import DatasetReader

In [None]:
# Resolve the dataset directory: relative to the notebook when running
# locally, inside the Google Drive workspace when running on Colab.
if os.environ["IS_COLAB"] != "True":
    DATA_ROOT = Path("../data") / config.dataset
else:
    DATA_ROOT = Path("./gdrive/My Drive/Colab_Workspace/Colab Notebooks/data") / config.dataset
    # Use Config.set so the dict entry and the attribute stay in sync; a plain
    # attribute assignment left config["ft_model_path"] pointing at the old path.
    config.set("ft_model_path", str(DATA_ROOT / "ft_model.txt"))

In [None]:
import subprocess
# Optionally fetch any missing data file from S3 using the AWS CLI.
if download_data:
    if config.val_ratio > 0.0:
        fnames = ["train_wo_val.csv", "test_proced.csv", "val.csv", "ft_model.txt"]
    else:
        fnames = ["train.csv", "test_proced.csv", "ft_model.txt"]
    if config.use_augmented or config.discrete_mixup_ratio > 0.0: fnames.append("train_extra.csv")
    for fname in fnames:
        if not (DATA_ROOT / fname).exists():
            # Pass the command as an argument list without a shell: the Colab
            # DATA_ROOT contains spaces, which the previous unquoted shell
            # string split into several arguments. The destination is the
            # explicit file path so the result does not depend on DATA_ROOT
            # already existing as a directory.
            result = subprocess.run(
                ["aws", "s3", "cp",
                 f"s3://nnfornlp/raw_data/jigsaw/{fname}",
                 str(DATA_ROOT / fname)],
                stdout=subprocess.PIPE)
            print(result.stdout)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset

In [None]:
from allennlp.data.fields import (TextField, SequenceLabelField, LabelField, 
                                  MetadataField, ArrayField)

class MemoryOptimizedTextField(TextField):
    """TextField variant that drops its raw tokens once they have been indexed.

    The constructor sets the attributes directly and does not call the parent
    constructor, so any validation the parent performs on `tokens` is skipped.
    """
    @overrides
    def __init__(self, tokens: List[str], token_indexers: Dict[str, TokenIndexer]) -> None:
        self.tokens = tokens
        self._token_indexers = token_indexers
        # NOTE(review): `TokenList` is not imported in the visible cells; the
        # annotation is never evaluated inside a function body, so this does
        # not fail at runtime.
        self._indexed_tokens: Optional[Dict[str, TokenList]] = None
        self._indexer_name_to_indexed_token: Optional[Dict[str, List[str]]] = None
        # skip checks for tokens
    @overrides
    def index(self, vocab):
        """Index the tokens with `vocab` via the parent class, then release the raw tokens."""
        super().index(vocab)
        self.tokens = None # empty tokens

In [None]:
class JigsawLMDatasetReader(DatasetReader):
    """Reads the `comment_text` column of a CSV file into language-model instances.

    Each instance carries two text fields: "input" (the tokens as given) and
    "output" (the tokens after `_clean`), indexed with `token_indexers` and
    `output_token_indexers` respectively.
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer]=None, # TODO: Handle mapping from BERT
                 output_token_indexers: Dict[str, TokenIndexer]=None,
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers
        # Without dedicated output indexers, targets reuse the input indexers.
        self.output_token_indexers = output_token_indexers or token_indexers
        self.max_seq_len = max_seq_len

    def _clean(self, x: str) -> str:
        """
        Map a token to its target form. Currently the identity; intended to
        become a denoising operation later.
        """
        return x

    @overrides
    def text_to_instance(self, tokens: List[str]) -> Instance:
        """Build an instance with an "input" field and an "output" field from `tokens`."""
        input_field = MemoryOptimizedTextField(list(tokens), self.token_indexers)
        target_tokens = list(map(self._clean, tokens))
        output_field = MemoryOptimizedTextField(target_tokens, self.output_token_indexers)
        return Instance({"input": input_field, "output": output_field})

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per CSV row (only the first 1000 rows when `config.testing`)."""
        frame = pd.read_csv(file_path)
        if config.testing:
            frame = frame.head(1000)
        for _, record in frame.iterrows():
            yield self.text_to_instance(self.tokenizer(record["comment_text"]))

In [None]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper, ELMoTokenCharactersIndexer
from allennlp.data.token_indexers import SingleIdTokenIndexer

# Choose how input tokens are indexed: ELMo-style character ids for the "cnn"
# encoder, otherwise a single id per token (lowercased unless the softmax
# setting is "cnn_softmax").
if config.char_encoder == "cnn":
    token_indexer = ELMoTokenCharactersIndexer()
else:
    token_indexer = SingleIdTokenIndexer(
        lowercase_tokens=config.softmax != "cnn_softmax") # Temporary

# spaCy word splitter without POS tagging; used by `tokenizer` below.
_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words

def tokenizer(x: str):
    """Split `x` with spaCy, truncate, and wrap the words in "[CLS]" ... "[SEP]".

    At most `config.max_seq_len - 2` words are kept so that the result,
    including the two markers, never exceeds `config.max_seq_len` tokens.
    """
    word_budget = config.max_seq_len - 2
    words = [tok.text for tok in _spacy_tok(x)[:word_budget]]
    return ["[CLS]", *words, "[SEP]"]

In [None]:
# Targets are always single word ids: reuse the input indexer when it already
# is a SingleIdTokenIndexer, otherwise create a dedicated (lowercasing) one.
if not isinstance(token_indexer, SingleIdTokenIndexer):
    output_token_indexer = SingleIdTokenIndexer(lowercase_tokens=True) # lowercase for now (original note was left unfinished)
else:
    output_token_indexer = token_indexer
    
# Inputs are indexed under the key "tokens", targets under the key "words".
reader = JigsawLMDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer},
    output_token_indexers={"words": output_token_indexer}
)
# NOTE(review): these three files are read regardless of config.val_ratio,
# whereas the download cell fetches "train.csv" (and no "val.csv") when
# val_ratio == 0 — confirm the files exist for that setting.
train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train_wo_val.csv",
                                                                          "val.csv",
                                                                          "test_proced.csv"])

In [None]:
vars(train_ds[2].fields["input"])

In [None]:
vars(train_ds[2].fields["output"])

### Build Vocab

In [None]:
from allennlp.data.vocabulary import Vocabulary

In [None]:
# Build the vocabulary over train, test and validation instances together,
# registering the "[MASK]" token explicitly in the "tokens" namespace and
# capping the vocabulary at config.max_vocab_size.
full_ds = train_ds + test_ds + val_ds
vocab = Vocabulary.from_instances(full_ds, tokens_to_add={"tokens": ["[MASK]"]},
                                  max_vocab_size=config.max_vocab_size)

In [None]:
vocab.get_vocab_size()

In [None]:
# Record the actual vocabulary size; later cells read it as config.vocab_sz.
config.set("vocab_sz", vocab.get_vocab_size())

### Build frequencies

TODO: Implement fast sampling

In [None]:
if config.loss == "is":
    # Distribution used to draw negative samples for the importance-sampling
    # loss: token counts, normalized, smoothed with a 2/3 power, renormalized.
    freqs = np.zeros(vocab.get_vocab_size())
    for w, c in vocab._retained_counter["tokens"].items():
        # Accumulate instead of assigning: several counted tokens can resolve
        # to the same index (for example tokens that fall back to the
        # out-of-vocabulary index), and assignment kept only the last count.
        freqs[vocab.get_token_index(w)] += c
    freqs /= freqs.sum()
    freqs **= 2 / 3
    # renormalize
    freqs /= freqs.sum()

### Build Iterator

In [None]:
from allennlp.data.iterators import BucketIterator, DataIterator, BasicIterator, MemoryOptimizedIterator
# Batches hold config.computational_batch_size instances; at most
# 2 * config.batch_size instances are kept in memory at once.
if config.bucket:
    # Bucket by the number of tokens in the "input" field.
    iterator = BucketIterator(
            batch_size=config.computational_batch_size, 
            biggest_batch_first=config.testing,
            sorting_keys=[("input", "num_tokens")],
            max_instances_in_memory=config.batch_size * 2,
    )
else:
    iterator = MemoryOptimizedIterator(
        batch_size=config.computational_batch_size,
        max_instances_in_memory=config.batch_size * 2,
    )
# The iterator needs the vocabulary to turn tokens into ids.
iterator.index_with(vocab)

In [None]:
next(iterator(train_ds))["input"]["tokens"].shape

In [None]:
# Keep one training batch around; the sanity-check cell near the end reuses it.
batch = next(iterator(train_ds))

In [None]:
batch

### Build word to indices mapping

In [None]:
import fnv

TODO: Output this constructed dictionary to disk and load in the future

In [None]:
def fnv_hash(w):
    """Return the 32-bit hash that the `fnv` package computes for the UTF-8 bytes of `w`."""
    encoded = w.encode("utf-8")
    return fnv.hash(encoded, bits=32)

In [None]:
def generate_char_ngrams(w, range_=range(2, 6)):
    """Yield the character n-grams of `w` wrapped in "<" and ">" markers.

    For each n in `range_` (in the given order) the n-grams are produced from
    left to right, so with the default range the shortest n-grams come first.
    """
    padded = "<" + w + ">"
    total = len(padded)
    for n in range_: # prioritize smaller n-grams
        for start in range(total):
            if start + n > total:
                # No later start position can fit an n-gram of this size either.
                break
            yield padded[start:start + n]

In [None]:
if config.char_encoder == "cnn":
    from tqdm import tqdm
    # TODO: Speed up
    # See allennlp/data/token_indexers/elmo_indexer.py
    # Each vocabulary entry becomes a row of 50 character ids using the layout
    # below: byte value + 1 for characters, 259 = beginning of word,
    # 260 = end of word, 261 = padding. Row 0 (index 0) is left all zeros.
    max_word_len = 50
    max_word_bytes = max_word_len - 2 # room for the begin/end markers
    with timer("Building character indexes"):
        word_id_to_char_idxs = np.zeros((config.vocab_sz, max_word_len), dtype=np.int64)
        for w, idx in tqdm(vocab.get_token_to_index_vocabulary().items()):
            if idx > 0:
                # Truncate up front so the end-of-word marker always directly
                # follows the last kept byte. The previous loop left a padding
                # gap before the marker for long words and relied on a loop
                # variable that is stale (or undefined) for an empty word.
                word_bytes = w.encode("utf-8")[:max_word_bytes]
                word_id_to_char_idxs[idx, :] = 261
                word_id_to_char_idxs[idx, 0] = 259
                for pos, c in enumerate(word_bytes, start=1):
                    word_id_to_char_idxs[idx, pos] = int(c) + 1
                word_id_to_char_idxs[idx, len(word_bytes) + 1] = 260

    word_id_to_char_idxs = torch.from_numpy(word_id_to_char_idxs)
elif config.char_encoder == "subword":
    config.set("num_buckets", 50000) # TODO: add to configurable parameters
    # Hashed n-gram ids live after the word ids, in [offset, offset + num_buckets).
    offset = vocab.get_vocab_size()
    with timer("Building word to subword indices mapping"):
        subword_ids = [[] for _ in vocab.get_token_to_index_vocabulary()]
        for word, idx in vocab.get_token_to_index_vocabulary().items():
            if idx < 2 or word == "[MASK]":
                # Indices 0 and 1 and the mask token are represented by their own id only.
                subword_ids[idx] = [idx]
            else:
                subword_ids[idx] = [idx] + [fnv_hash(x) % config.num_buckets + offset
                                            for x in generate_char_ngrams(word)]
        maxlen = int(np.percentile([len(a) for a in subword_ids], 99)) # 99th-percentile
        # Fill the table on the CPU one row at a time and move it to the
        # device once, instead of issuing one device write per element.
        word_id_to_subword_ids = torch.zeros(len(subword_ids), maxlen, dtype=torch.long)
        for i, idxs in enumerate(subword_ids):
            kept = idxs[:maxlen] # ids beyond maxlen are dropped
            word_id_to_subword_ids[i, :len(kept)] = torch.tensor(kept, dtype=torch.long)
        word_id_to_subword_ids = word_id_to_subword_ids.to(device)

In [None]:
if config.char_encoder == "cnn":
    # Character ids for the "[MASK]" token, in the same layout as the rows of
    # word_id_to_char_idxs: 259 = beginning of word, byte value + 1 for each
    # character, 260 = end of word, 261 = padding.
    mask_bytes = "[MASK]".encode("utf-8")
    mask_char_ids = torch.ones(50, dtype=torch.int64).to(device) * 261
    # The beginning-of-word marker was missing here although every row of
    # word_id_to_char_idxs carries it.
    mask_char_ids[0] = 259
    for i, c in enumerate(mask_bytes):
        mask_char_ids[i+1] = int(c) + 1
    mask_char_ids[len(mask_bytes) + 1] = 260

# Bert configuration

In [None]:
from pytorch_pretrained_bert.modeling import (BertConfig, BertForMaskedLM, 
                                              BertEncoder, BertPooler, BertOnlyMLMHead)

# Transformer hyper-parameters for the encoder. The first positional argument
# is the configured vocabulary cap (config.max_vocab_size), not the size of the
# vocabulary actually built (config.vocab_sz).
bert_config = BertConfig(
        config.max_vocab_size, 
        hidden_size=config.hidden_sz, 
        num_attention_heads=config.num_attention_heads,
        num_hidden_layers=config.num_hidden_layers, 
        intermediate_size=config.hidden_sz * config.num_attention_heads,
        max_position_embeddings=config.max_seq_len,
)

bert_config

# Building token embedder

In [None]:
if config.char_encoder == "cnn":
    from allennlp.modules.token_embedders.elmo_token_embedder import ElmoTokenEmbedder
    from allennlp.modules.elmo import _ElmoCharacterEncoder

    # Options and weights of the ELMo character encoder to load.
    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
    _inner_char_encoder = _ElmoCharacterEncoder(options_file=options_file,
                                                weight_file=weight_file,
                                                requires_grad=True)

    class ElmoEncoder(nn.Module):
        """Wraps a character encoder and returns its "token_embedding" output
        without the first and last positions along dimension 1."""
        def __init__(self, _inner):
            super().__init__()
            self._inner = _inner

        def forward(self, *args):
            # TODO: Stop Elmo encoder from adding SoS and EoS tokens
            encoded = self._inner(*args)
            token_embedding = encoded["token_embedding"]
            return token_embedding[:, 1:-1, :]

        def get_output_dim(self):
            return self._inner.get_output_dim()

    # From here on `char_encoder` is the encoder module; the string setting
    # remains available as config.char_encoder.
    char_encoder = ElmoEncoder(_inner_char_encoder)
    config.set("embedding_sz", char_encoder.get_output_dim())

Bag of char-ngrams using fastText (too much memory consumption for now...)

In [None]:
from torch.nn.modules.sparse import EmbeddingBag

In [None]:
if config.char_encoder == "subword":
    from torch.nn.modules.sparse import EmbeddingBag

    class SubwordEncoder(nn.Module):
        """Embeds a word id as the smoothed average of its subword embeddings.

        The subword ids of every word are looked up in the module-level
        `word_id_to_subword_ids` table, summed by an EmbeddingBag, and divided
        by the number of non-zero ids plus 0.1.
        """
        def __init__(self, num_embeddings, embedding_sz):
            super().__init__()
            self._subword_encoding = word_id_to_subword_ids
            self.bag = EmbeddingBag(num_embeddings, embedding_sz, mode="sum")
            self.n_subwords_per_word = self._subword_encoding.size(1)

        def forward(self, 
                    tsr: torch.LongTensor, # (batch, seq) or # (batch)
                   ) -> torch.FloatTensor: # (batch, seq, feat) or # (batch, feat)
            # TODO: can I use offsets in a differentiable manner??
            subword_ids = self._subword_encoding[tsr]
            if tsr.dim() > 1:
                bs, seq = tsr.size(0), tsr.size(1)
                # EmbeddingBag takes 2D input: flatten, embed, restore to 3D.
                flat_ids = subword_ids.view((-1, self.n_subwords_per_word))
                bag_of_embs = self.bag(flat_ids).view((bs, seq, -1))
            else:
                bag_of_embs = self.bag(subword_ids)
            # Count of non-zero subword ids per word; 0.1 is added as smoothing.
            n_nonzero = (subword_ids > 0).float().sum(dim=-1, keepdim=True)
            norm_factor = n_nonzero + 0.1
            return bag_of_embs / norm_factor
        
    # From here on `char_encoder` is the encoder module; the string setting
    # remains available as config.char_encoder.
    char_encoder = SubwordEncoder(vocab.get_vocab_size() + config.num_buckets, 300)
    config.set("embedding_sz", 300)

Simple word-level embeddings

In [None]:
if config.char_encoder == "fasttext":
    from allennlp.modules import Embedding

    # Word-level embedding table initialized from a random normal matrix
    # scaled by 0.3 (no pretrained vectors are loaded in this cell).
    ft_matrix = np.random.randn(bert_config.vocab_size, 300) * 0.3
    char_encoder = Embedding(bert_config.vocab_size, 300, weight=torch.FloatTensor(ft_matrix))
    config.set("embedding_sz", 300)

In [None]:
def freeze(x):
    """Turn off gradient tracking for `x`.

    `x` may be a tensor/parameter, or any object exposing `parameters()` (such
    as an `nn.Module`), in which case every parameter it yields is frozen too.
    """
    x.requires_grad = False
    if hasattr(x, "parameters"):
        # `parameters` is a method and has to be called; iterating over the
        # bound method itself raised a TypeError.
        for p in x.parameters(): freeze(p)

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, TF style (epsilon inside the square root)."""
    def __init__(self, hidden_size, eps=1e-12):
        """Create a learnable scale (ones) and shift (zeros) of size `hidden_size`."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Biased variance (mean of squared deviations) along the last dimension.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias

In [None]:
class EmbeddingWPositionEmbs(nn.Module):
    """Embeds, then maps the embeddings to the bert_hidden_sz for processing

    Args:
        word_emb: module turning input ids into embeddings of size `embedding_dim`.
        embedding_dim: feature size produced by `word_emb`.
        bert_hidden_sz: feature size expected by the encoder.
        freeze_embeddings: when true, `freeze` is applied to `word_emb`.
        dropout: dropout probability applied to the normalized sum of word and
            position embeddings.
    """
    def __init__(self, word_emb: nn.Module, 
                 embedding_dim,
                 bert_hidden_sz, 
                 freeze_embeddings=False,
                 dropout=0.1):
        super().__init__()
        self.word_emb = word_emb
        if freeze_embeddings: freeze(self.word_emb)
        # One learned vector per position, up to config.max_seq_len positions.
        self.position_embeddings = nn.Embedding(config.max_seq_len, 
                                                bert_hidden_sz)
        self.linear = nn.Linear(embedding_dim, bert_hidden_sz,
                                bias=False) # Transform dimensions
        self.norm = LayerNorm(bert_hidden_sz)
        # Honour the `dropout` argument; it used to be ignored in favour of a
        # hard-coded 0.1 (which is also the default, so default behavior is unchanged).
        self.do = nn.Dropout(dropout)
    
    def get_word_embs(self, input_ids):
        """Return the word embeddings projected to `bert_hidden_sz` (no positions, norm or dropout)."""
        return self.linear(self.word_emb(input_ids))
    
    def forward(self, input_ids):
        # We won't be using token types since we won't be predicting the next sentence
        # Only the first two dimensions (batch, sequence) matter for positions;
        # any trailing dimensions of input_ids are ignored here.
        bs, seq_length, *_ = input_ids.shape
        position_ids = (torch.arange(seq_length, dtype=torch.long,
                                     device=input_ids.device)
                             .unsqueeze(0)
                             .expand((bs, seq_length)))
        word_embs = self.get_word_embs(input_ids)
        position_embs = self.position_embeddings(position_ids)
        return self.do(self.norm(word_embs + position_embs))

In [None]:
# Embedding layer feeding the encoder.
# NOTE(review): config.freeze_embeddings and the dropout settings in `config`
# are not passed here, so the constructor defaults apply — confirm intended.
sample_embs = EmbeddingWPositionEmbs(
    char_encoder,
    config.embedding_sz,
    bert_config.hidden_size,
)

In [None]:
# Move the embedding layer's parameters to the GPU when one is available.
if torch.cuda.is_available(): sample_embs.cuda()

# Masked Language Model

### The encoder

In [None]:
# Transformer encoder stack configured by bert_config.
bert_encoder = BertEncoder(bert_config)

In [None]:
class CustomBert(nn.Module):
    """Runs an embedding module followed by an encoder.

    Args:
        embeddings: module mapping `input_ids` to the encoder's input.
        encoder: module called as
            `encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=True)`.
    """
    def __init__(self, embeddings, encoder):
        super().__init__()
        self.embeddings = embeddings
        self.encoder = encoder
        
    def forward(self, input_ids, 
                token_type_ids=None, 
                attention_mask=None):
        """Return whatever the encoder returns for the embedded `input_ids`.

        `token_type_ids` is accepted for interface compatibility but is not
        used; the default that used to be computed for it was never read and
        has been removed. When `attention_mask` is None every position is
        attended to.
        """
        if attention_mask is None:
            if len(input_ids.shape) > 2:
                # Inputs with a trailing dimension: one mask entry per (batch, position).
                attention_mask = torch.ones_like(input_ids[:, :, 0])
            else:
                attention_mask = torch.ones_like(input_ids)
        
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=True)
        return encoded_layers

In [None]:
class BertMLMPooler(nn.Module):
    """Selects the final entry from a list of per-layer encoder outputs."""
    def forward(self, x: List[torch.FloatTensor]) -> torch.FloatTensor:
        final_layer = x[-1] # final layer only
        return final_layer

In [None]:
bert_model = CustomBert(sample_embs, bert_encoder)

In [None]:
import math
def gelu(x):
    """Gaussian Error Linear Unit computed with the error function: x * 0.5 * (1 + erf(x / sqrt(2)))."""
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)

class BertCustomLMPredictionHead(nn.Module):
    """Prediction head of the masked language model.

    With `output_logits=True` the hidden states are projected to `out_sz`,
    passed through gelu and layer norm, and decoded to `vocab_sz` logits. With
    `output_logits=False` the hidden states are returned untouched.
    """
    def __init__(self, config, out_sz, vocab_sz, 
                 embedding, output_logits=True):
        super().__init__()
        # Projection layers (created in both modes).
        self.dense = nn.Linear(config.hidden_size, out_sz)
        self.transform_act_fn = gelu
        self.LayerNorm = LayerNorm(out_sz)

        # Decoder layers exist only when logits are requested.
        self.output_logits = output_logits
        if not self.output_logits:
            return
        self.decoder = nn.Linear(out_sz, vocab_sz, bias=False)
        if embedding is not None:
            # Tie the decoder weight to the supplied embedding parameter.
            self.decoder.weight = embedding
        self.bias = nn.Parameter(torch.zeros(vocab_sz))

    def forward(self, hidden_states):
        if not self.output_logits:
            return hidden_states
        projected = self.transform_act_fn(self.dense(hidden_states))
        normalized = self.LayerNorm(projected)
        return self.decoder(normalized) + self.bias

In [None]:
# With the importance-sampling loss ("is") the head returns hidden states and
# the loss module scores them itself; every other loss needs vocabulary logits.
output_logits = config.loss != "is"
# The decoder weight is tied to the word embedding only for the "fasttext" encoder.
bert_mlm_head = BertCustomLMPredictionHead(config=bert_config, 
                                           out_sz=config.embedding_sz, 
                                           vocab_sz=config.vocab_sz,
                                           embedding=(sample_embs.word_emb.weight 
                                                      if config.char_encoder == "fasttext" 
                                                      else None),
                                           output_logits=output_logits,
)

In [None]:
# Full model: encoder -> keep the final layer's output -> prediction head.
# MaskedLM reads `output_logits` from the last element of this Sequential.
custom_model = nn.Sequential(
    bert_model,
    BertMLMPooler(),
    bert_mlm_head,
)

### The decoder

In [None]:
class ElmoDecoder(nn.Module):
    """Produces output embeddings for word ids from their character ids.

    Word ids are looked up in the module-level `word_id_to_char_idxs` table,
    encoded with `enc` and projected with `dec`. When `word_correction_dim` is
    positive, a learned per-word correction is added to the result.

    TODO: Add word correction
    """
    def __init__(self, enc: nn.Module, dec: nn.Linear, word_correction_dim: int=0):
        super().__init__()
        self._enc = enc
        self._dec = dec
        self.word_correction_dim = word_correction_dim
        if word_correction_dim > 0:
            self.word_correction = nn.Embedding(config.max_vocab_size,
                                                word_correction_dim)
            self.back_projection = nn.Linear(word_correction_dim,
                                             self._dec.out_features,
                                             bias=False)

    def forward(self, idxs):
        # A 1-D id vector is treated as a batch of one sequence.
        if len(idxs.shape) == 1:
            idxs = idxs.unsqueeze(0)
        encoded = self._enc(word_id_to_char_idxs[idxs])
        decoded = self._dec(encoded).squeeze(0)
        if self.word_correction_dim > 0:
            correction = self.back_projection(self.word_correction(idxs.squeeze(0)))
            return decoded + correction
        return decoded

In [None]:
# Pick the callable that turns target word ids into output embeddings.
if config.char_encoder == "fasttext":
    # Reuse the input embedding + projection.
    output_embs = sample_embs.get_word_embs
elif config.char_encoder == "cnn" and config.softmax == "cnn_softmax" and config.loss == "is":
    # A second character encoder, loaded from the same options/weights files,
    # serves as the decoder.
    _inner_char_decoder = _ElmoCharacterEncoder(
            options_file=options_file, 
            weight_file=weight_file,
            requires_grad=True
    )
    char_decoder = ElmoEncoder(_inner_char_decoder)
    # share just the linear transformation with the input embedding layer
    output_embs = ElmoDecoder(char_decoder, sample_embs.linear)
    if torch.cuda.is_available(): output_embs.cuda()
elif config.char_encoder == "subword":
    output_embs = sample_embs.get_word_embs # tie input and output weights
else:
    # Fallback: an independent word-level embedding table with no initial weights supplied.
    from allennlp.modules import Embedding
    mtrx = None
    output_embs = Embedding(config.vocab_sz, bert_config.hidden_size, weight=mtrx)
    if torch.cuda.is_available(): output_embs.cuda()

In [None]:
class WordCorrection(nn.Module):
    """From the paper `Exploring the Limitations of Language Modeling`

    Projects hidden states to a bottleneck and multiplies them with a
    correction matrix.
    """
    def __init__(self, hidden_sz: int, bottleneck_sz: int=128):
        super().__init__()
        self.l1 = nn.Linear(hidden_sz, bottleneck_sz)

    def forward(self, h: torch.FloatTensor, 
                corr: torch.FloatTensor):
        bottleneck = self.l1(h)
        return torch.matmul(bottleneck, corr)

# Loss Functions

Masked Cross Entropy

In [None]:
class MaskedCrossEntropyLoss(nn.Module):
    """Cross entropy averaged over the positions selected by an optional mask.

    Without a mask, the plain mean of the per-element loss is returned. With a
    mask, predictions and targets are flattened, each element's loss is
    weighted by the mask, and the sum is divided by the mask's sum.
    """
    def __init__(self):
        super().__init__()
        self._loss = nn.CrossEntropyLoss(reduction="none")

    def forward(self, preds, tgts, mask=None) -> torch.tensor:
        if mask is None:
            return self._loss(preds, tgts).mean()
        # Flatten to (N, classes) and (N,) so the mask can be applied per element.
        flat_preds = preds.view((-1, preds.size(-1)))
        flat_tgts = tgts.view((-1, ))
        per_element = self._loss(flat_preds, flat_tgts)
        n_elements = mask.sum()
        weights = mask.view((-1, )).float()
        return (per_element * weights).sum() / n_elements

Importance Sampling

In [None]:
class UniformSampler:
    """Draws integer ids uniformly from [min_, max_)."""
    def __init__(self, min_, max_):
        self.min_ = min_
        self.max_ = max_

    def sample(self, shape):
        """Return a tensor of the given shape filled with uniformly drawn ids."""
        return torch.randint(self.min_, self.max_, shape)

class UnigramSampler:
    """Draws ids with replacement according to the given probability weights."""
    def __init__(self, probs):
        self.probs = probs

    @staticmethod
    def _prod(x):
        """Product of all entries of `x` (1 for an empty sequence)."""
        total = 1
        for dim in x:
            total = total * dim
        return total

    def sample(self, shape):
        """Return a tensor of the given shape of ids drawn from `probs`."""
        n_draws = self._prod(shape)
        flat = torch.multinomial(self.probs, n_draws, replacement=True)
        return flat.view(shape)

In [None]:
class ImportanceSamplingLoss(nn.Module):
    """Sampled-softmax style loss over the target and `k` shared negative samples.

    Args:
        embedding_generator: callable mapping a tensor of word ids to their
            output embeddings.
        probs: sampling weights over word ids used to draw negative samples.
        k: number of negative samples drawn per forward pass.
    """
    def __init__(self, embedding_generator,
                 probs: np.ndarray, k=config.num_neg_samples):
        super().__init__()
        self.embedding_generator = embedding_generator
        # TODO: Should this be according to the unigram probability?
        # Or should it be uniform?
        self.sampler = UnigramSampler(probs=torch.FloatTensor(probs))
        # TODO: Compute samples in advance
        self._loss_func = MaskedCrossEntropyLoss()
        self.k = k
    
    def get_negative_samples(self) -> torch.LongTensor:
        """Draw `k` negative word ids from the sampler."""
        neg = self.sampler.sample((self.k, )) # TODO: Speed up??
        return neg
    
    def get_embeddings(self, idxs: torch.LongTensor) -> torch.FloatTensor:
        """Converts indexes into vectors"""
        return self.embedding_generator(idxs) # TODO: Implement general case
    
    def forward(self, y: torch.LongTensor, tgt, mask=None):
        """
        Expects input of shape
        y: (batch * seq, feature_sz)
        tgt: (batch * seq, )
        Inputs with more than two dimensions are flattened to that shape first.

        Returns the cross entropy of classifying the target (class 0) against
        the negative samples (classes 1..k), averaged via `mask` when given.
        """
        if len(y.shape) > 2:            
            y = y.view((-1, y.size(-1))) # (batch * seq, emb_sz)
            tgt = tgt.view((-1, )) # (batch * seq, s)
        bs = y.size(0)
        pos_embeddings = self.get_embeddings(tgt) # (bs, emb_sz)
        # share negative samples across the batch
        neg_samples = (self.get_negative_samples()
                       .to(y.device)) # (k, )
        neg_embeddings = self.get_embeddings(neg_samples) # (k, emb_sz)
        # Score the target and the negatives directly. This yields the same
        # dot products as concatenating the embeddings into a
        # (bs, k+1, emb_sz) tensor first, without materializing that tensor.
        pos_scores = (pos_embeddings * y).sum(dim=-1, keepdim=True) # (bs, 1)
        neg_scores = y @ neg_embeddings.t() # (bs, k)
        dot_prods = torch.cat([pos_scores, neg_scores], dim=1) # (bs, k+1)
        # The true word always sits in column 0.
        targets = torch.zeros(bs, dtype=torch.int64, device=y.device)
        return self._loss_func(dot_prods, targets, mask=mask)

Testing

In [None]:
# Random smoke-test inputs: 3 x 7 hidden vectors and target ids below 100,
# both flattened to one row per token.
emb_sz = bert_config.hidden_size
y = torch.randn((3, 7, emb_sz)).view((-1, emb_sz)).to(device)
tgt = torch.randint(100, (3, 7)).view((-1, )).to(device)

In [None]:
tgt

In [None]:
word_id_to_subword_ids[tgt].max()

In [None]:
# Smoke test of the importance-sampling loss with 10 negatives drawn from
# random weights over 100 ids. This rebinds the global `loss`; the training
# cell below assigns it again before use.
loss = ImportanceSamplingLoss(
    output_embs, k=10, probs=torch.rand(100),
)
loss(y, tgt)

# Training

In [None]:
from allennlp.models import Model

In [None]:
class Masker(nn.Module):
    """Adds masked-LM noise to a batch of input ids.

    A position is replaced by the mask id with probability
    `noise_rate * mask_rate`; then, independently, a position is replaced by a
    random id with probability `noise_rate * replace_rate`.
    """
    def __init__(self, vocab: Vocabulary, 
                 noise_rate: float=0.15,
                 mask_rate=0.8,
                 replace_rate=0.1):
        super().__init__()
        self.vocab = vocab
        self.vocab_sz = vocab.get_vocab_size()
        if config.char_encoder == "cnn":
            # Character-level inputs: the mask is a vector of character ids,
            # given two leading singleton dimensions for broadcasting.
            self.mask_id = mask_char_ids.unsqueeze(0).unsqueeze(1)
        else:
            self.mask_id = vocab.get_token_index("[MASK]")
        self.noise_rate = noise_rate
        self.mask_rate = mask_rate
        self.replace_rate = replace_rate
        
    def create_mask(self, shape, ones_ratio, dtype=torch.uint8):
        """Return a tensor of `shape` whose entries are 1 with probability `ones_ratio`, else 0."""
        return (torch.ones(shape, dtype=dtype, 
                           requires_grad=False)
                     .bernoulli(ones_ratio))

    def get_random_input_ids(self, shape):
        """Returns randomly sampled input ids of the given shape.

        Character ids of random words for the "cnn" encoder, random word ids
        for "fasttext" and "subword"; returns None for any other encoder.
        """
        if config.char_encoder == "cnn":
            rint = torch.randint(self.vocab_sz, shape)
            return word_id_to_char_idxs[rint]
        elif config.char_encoder == "fasttext" or config.char_encoder == "subword":
            return torch.randint(self.vocab_sz, shape)

    def forward(self, x: torch.LongTensor) -> torch.LongTensor:
        """Return `x` with noise applied; `x` is returned unchanged when `noise_rate` is 0."""
        char_level = len(x.shape) > 2 # using character-level features, but mask at word-level
        if self.noise_rate > 0:
            with torch.no_grad(): # no grads required here
                # Masks are drawn per word, so drop the character dimension if present.
                mask_shape = x.shape[:-1] if char_level else x.shape
                mask = self.create_mask(mask_shape, 
                                        self.noise_rate * self.mask_rate).to(x.device)
                if config.char_encoder == "cnn":
                    x = torch.where(mask.unsqueeze(2), self.mask_id, x)                
                else:
                    x = x.masked_fill(mask, self.mask_id)
                
                if self.replace_rate > 0.:
                    # this is technically incorrect, since we might overwrite the mask tokens
                    # but I guess it will do for now
                    mask = self.create_mask(mask_shape,
                                            self.noise_rate * self.replace_rate).to(x.device)
                    x = torch.where(mask.unsqueeze(2) if config.char_encoder == "cnn" else mask, 
                                    self.get_random_input_ids(mask_shape).to(x.device),
                                    x)
        return x

In [None]:
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy

class MaskedLM(Model):
    """Masked language model: noises the input tokens and scores the wrapped
    model's outputs against the target words.

    Args:
        vocab: vocabulary passed to the base `Model` and to the `Masker`.
        model: network applied to the noised input ids. It is indexed as
            `model[-1]`, and that last stage's `output_logits` attribute
            decides whether accuracy is tracked.
        loss: callable invoked as `loss(logits, tgt, mask=mask)`.
        noise_rate: forwarded to `Masker`.
    """
    def __init__(self, vocab: Vocabulary, model: nn.Module,
                loss: nn.Module, noise_rate=0.8):
        super().__init__(vocab)
        self.masker = Masker(vocab, noise_rate=noise_rate)
        self.accuracy = CategoricalAccuracy()
        self.model = model
        self.loss = loss

    @property
    def outputs_logits(self) -> bool:
        """Whether the last stage of `model` reports that it outputs logits."""
        return self.model[-1].output_logits

    def forward(self, input: TensorDict,
                output: TensorDict, **kwargs) -> TensorDict:
        """Noise the input, run the model and compute the loss.

        Returns a dict with `"loss"` and `"logits"` (the raw model output).
        When the model outputs logits, the accuracy metric is updated as a
        side effect; read it through `get_metrics`.
        """
        mask = get_text_field_mask(input)
        x = self.masker(input["tokens"])
        tgt = output["words"]

        logits = self.model(x)
        out_dict = {"loss": self.loss(logits, tgt, mask=mask), "logits": logits}
        if self.outputs_logits:
            # Pass the padding mask so padded positions do not count towards
            # accuracy. The metric call only updates internal state, so its
            # result is deliberately not stored in the output dict.
            self.accuracy(logits, tgt, mask)
        return out_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running accuracy, or an empty dict when the model does
        not output logits."""
        if self.outputs_logits:
            return {"accuracy": self.accuracy.get_metric(reset)}
        else:
            return {}

In [None]:
# Build the training criterion named by `config.loss`, then wrap the model.
if config.loss == "masked_crossentropy":
    loss = MaskedCrossEntropyLoss()
elif config.loss == "crossentropy":
    _loss = nn.CrossEntropyLoss()
    def ce(y, t, mask=None):
        """Cross entropy over all flattened positions.

        `mask` is accepted so the call signature matches the other losses,
        but it is ignored here.
        """
        return _loss(y.view((-1, y.size(-1))), t.view((-1, )))
    loss = ce
elif config.loss == "is":
    # TODO: Implement masking
    loss = ImportanceSamplingLoss(output_embs, freqs)
else:
    raise ValueError(
        f"Unknown config.loss {config.loss!r}; expected one of "
        "'masked_crossentropy', 'crossentropy' or 'is'"
    )
masked_lm = MaskedLM(vocab, custom_model, loss, noise_rate=0.15)
# `.cuda()` moves parameters and buffers to the current CUDA device, i.e. it
# has the same effect as `.to("cuda")`.
if torch.cuda.is_available(): masked_lm.cuda()

Sanity checks

In [None]:
# Smoke test: push the sample batch through the masker and the model stage by
# stage. Nothing is displayed; the cell only shows whether each stage runs.
if config.testing:
    # device index 0 when CUDA is available, otherwise -1
    batch = nn_util.move_to_device(batch, 0 if torch.cuda.is_available() else -1)

    tokens = masked_lm.masker(batch["input"]["tokens"])

    # first two stages of the model
    hidden_states = masked_lm.model[:2](tokens)

    # third stage; the result is discarded
    masked_lm.model[2](hidden_states)

In [None]:
# End-to-end check: run the whole MaskedLM on the sample batch and print its output dict.
if config.testing: print(masked_lm(**batch))

Training setup: garbage-collection callback and trainer

In [None]:
from allennlp.training import Callback
import gc

class GCCallback(Callback):
    """Trainer callback that triggers garbage collection at a fixed batch
    interval, to help prevent memory errors."""
    def __init__(self, period: int=1):
        # number of batches between two collections
        self._period = period

    def on_batch_end(self, data):
        """Collect garbage when `batches_this_epoch + 1` is a multiple of the period."""
        batches_done = data["batches_this_epoch"] + 1
        if batches_done % self._period != 0:
            return
        gc.collect()

In [None]:
from allennlp.training import Trainer, TrainerWithCallbacks

# Adam over all MaskedLM parameters, with the learning rate taken from the config.
optimizer = torch.optim.Adam(masked_lm.parameters(), lr=config.lr)

trainer = TrainerWithCallbacks(
    model=masked_lm,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    # how many computational batches make up one nominal batch (integer division)
    gradient_accumulation_steps=config.batch_size // config.computational_batch_size,
    # testing runs pass None instead of a checkpoint directory
    serialization_dir=DATA_ROOT / "bert_ckpts_cnn_softmax" if not config.testing else None,
    # device index 0 when CUDA is available, otherwise -1
    cuda_device=0 if torch.cuda.is_available() else -1,
    num_epochs=config.epochs,
    # collect garbage after every batch (default period of 1)
    callbacks=[GCCallback()],
)

In [None]:
from allennlp.commands.find_learning_rate import search_learning_rate, _save_plot

In [None]:
# Learning-rate search, currently disabled: everything below `pass` is
# commented out, so this cell does nothing.
if config.testing:
    pass
#     lrs_, losses_ = search_learning_rate(trainer, num_batches=300)

#     n = -100
#     plt.ylabel("loss")
#     plt.xlabel('learning rate (log10 scale)')
#     plt.xscale('log')
#     plt.plot(lrs_[:n], losses_[:n])

In [None]:
# Run training with the trainer built earlier in the notebook.
trainer.train()

In [None]:
# Move the sample batch to the training device and show the argmax of the
# model's "logits" output along dim 2.
batch = nn_util.move_to_device(batch, 0 if torch.cuda.is_available() else -1)
masked_lm(**batch)["logits"].argmax(2)

For IS loss, we need to aggregate the embeddings at the end of training to make predictions

In [None]:
import math

vocab_sz = vocab.get_vocab_size()
embedding_sz = bert_config.hidden_size

if config.loss == "is":
    # Collect the output embedding of every vocabulary id into one CPU matrix,
    # querying the loss module `bs` ids at a time.
    bs = 16
    output_embedding_matrix = torch.zeros(vocab_sz, embedding_sz)
    num_batches = math.ceil(vocab_sz / bs)
    # This is inference only. Without no_grad, slice-assigning tensors that
    # require grad would make the matrix part of an autograd graph that is
    # kept alive for every batch.
    with torch.no_grad():
        for i in range(num_batches):
            # the last batch may be shorter than `bs`
            start,end = i*bs, min(((i+1)*bs), vocab_sz)
            idxs = torch.arange(start=start, end=end).unsqueeze(0)
            output_embedding_matrix[start:end, :] = loss.get_embeddings(idxs.to(device)).cpu()

# Manually Check Outputs

TODO: Implement manual checks for negative sampling loss as well

In [None]:
def to_np(t):
    """Convert a tensor to a NumPy array: detach it from autograd, move it to
    the CPU, then call `.numpy()`."""
    detached = t.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

In [None]:
def to_words(arr):
    """Turn token ids back into text using the global `vocab`.

    An input with at most one dimension becomes a single space-joined string.
    An input with more dimensions is processed recursively along its first
    axis, giving a (nested) list of strings.
    """
    if len(arr.shape) <= 1:
        ids = to_np(arr)
        tokens = [vocab.get_token_from_index(i) for i in ids]
        return " ".join(tokens)
    return [to_words(row) for row in arr]

In [None]:
def get_preds(model, batch: TensorDict):
    """Return the argmax prediction along dim 2 for `batch`.

    With the "is" loss, the model's "logits" output is first moved to the CPU
    and multiplied by the transposed global `output_embedding_matrix`.
    Otherwise the argmax is taken over the "logits" output directly.
    """
    uses_is_loss = config.loss == "is"
    logits = model(**batch)["logits"]
    if not uses_is_loss:
        return logits.argmax(2)
    scores = logits.cpu() @ output_embedding_matrix.transpose(0, 1)
    return scores.argmax(2)

In [None]:
# Show the shape of the model's "logits" output for the sample batch.
masked_lm(**batch)["logits"].shape

In [None]:
def cstm_pprint(x):
    """Print the strings in `x`, separated from one another by a blank line."""
    separator = "\n\n"
    text = separator.join(x)
    print(text)

In [None]:
# Print the first five target sequences of the sample batch as text.
cstm_pprint(to_words(batch["output"]["words"])[:5])

In [None]:
# Print the model's predictions for the first five sequences of the sample batch as text.
cstm_pprint(to_words(get_preds(masked_lm, batch)[:5]))

In [None]:
# Debugging aid: train for one epoch three times over and print the first five
# predictions for the sample batch after each round.
# NOTE: this rebinds the global `trainer` to a plain `Trainer` (no callbacks,
# no serialization directory), replacing the one created earlier.
if config.debugging:
    for i in range(3):
        trainer = Trainer(
            model=masked_lm,
            optimizer=optimizer,
            iterator=iterator,
            train_dataset=train_ds,
            validation_dataset=val_ds,
            serialization_dir=None,
            cuda_device=0 if torch.cuda.is_available() else -1,
            num_epochs=1,
        )
        trainer.train()
        cstm_pprint(to_words(get_preds(masked_lm, batch))[:5])

# Predict and Evaluate