In [1]:
# settings for seamlessly running on colab
# Records whether the notebook runs on Google Colab in the IS_COLAB environment
# variable (the string "True" or "False"); the %%bash cell below and the
# DATA_ROOT cell branch on that value.
import os
try:
    # presumably only importable on Colab; any ImportError takes the fallback branch
    from google.colab import drive
    drive.mount('/content/gdrive')
    os.environ["IS_COLAB"] = "True"
except ImportError:
    os.environ["IS_COLAB"] = "False"

In [2]:
%%bash
if [ "$IS_COLAB" = "True" ]; then
    pip install git+https://github.com/facebookresearch/fastText.git
    pip install torch
    pip install torchvision
    pip install allennlp
    pip install dnspython
    pip install jupyter_slack
    pip install git+https://github.com/keitakurita/Better_LSTM_PyTorch.git
fi

In [3]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
%matplotlib inline
from overrides import overrides

In [4]:
import time
from contextlib import contextmanager

class Config(dict):
    """Dictionary whose entries are also exposed as instance attributes."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror each keyword onto the instance so ``cfg.key`` works next to ``cfg["key"]``.
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def set(self, key, val):
        """Store ``val`` under ``key`` as a dict item and as an attribute."""
        self[key] = val
        setattr(self, key, val)

@contextmanager
def timer(name):
    """Context manager that prints the wall-clock seconds spent inside its body.

    The report is printed after the ``yield``, so nothing is printed when the
    body raises.
    """
    started_at = time.time()
    yield
    elapsed = time.time() - started_at
    print(f'[{name}] done in {elapsed:.0f} s')
    
import functools
import traceback
import sys

def get_ref_free_exc_info():
    """Return ``sys.exc_info()`` with the traceback's frames cleared.

    Frees the traceback from references to locals/globals to avoid circular
    references that would leave gc.collect() unable to reclaim memory.

    Returns:
        The ``(type, value, traceback)`` triple of the exception currently
        being handled (all ``None`` when called outside an ``except`` block).
    """
    exc_type, exc_val, exc_tb = sys.exc_info()
    traceback.clear_frames(exc_tb)
    return (exc_type, exc_val, exc_tb)

def gpu_mem_restore(func):
    """Reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted.

    Wraps ``func`` so that any exception it raises is re-raised after the
    frames referenced by its traceback have been cleared.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper with the same signature and return value as ``func``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:  # deliberately bare: interruptions (KeyboardInterrupt) must be handled too
            _, val, tb = get_ref_free_exc_info() # must!
            # Re-raise the original instance. Rebuilding it as ``type(val)`` dropped
            # its attributes and failed with TypeError for exception classes whose
            # constructor needs several arguments (e.g. UnicodeDecodeError).
            raise val.with_traceback(tb) from None
    return wrapper

def ifnone(a: Any, alt: Any):
    "Return `a` unless it is None, in which case return `alt`."
    return a if a is not None else alt

In [5]:
# for papermill
# Run parameters. They are copied into `config` in the next cell.
# Note: later cells rebind the names `char_encoder` and `loss` to objects, so
# read these two from `config` rather than from the bare names.
testing = True
debugging = True
seed = 1
char_encoder = "cnn"
computational_batch_size = 64
batch_size = 64
loss = "is"
lr = 4e-3
lr_schedule = "slanted_triangular"
# a single epoch when `testing` is on
epochs = 6 if not testing else 1
hidden_sz = 128
dataset = "jigsaw"
n_classes = 6
max_seq_len = 128
download_data = False
ft_model_path = "../data/jigsaw/ft_model.txt"
max_vocab_size = 40000
dropouti = 0.2
dropoutw = 0.0
dropoute = 0.2
dropoutr = 0.3 # TODO: Implement
val_ratio = 0.0
use_augmented = False
freeze_embeddings = True
mixup_ratio = 0.0
discrete_mixup_ratio = 0.0
attention_bias = True
weight_decay = 0.
bias_init = True
neg_splits = 1
num_layers = 2
rnn_type = "lstm"
pooling_type = "augmented_multipool" # attention or multipool or augmented_multipool
model_type = "standard"
use_word_level_features = True
use_sentence_level_features = True
bucket = True
compute_thres_on_test = True
find_lr = False
permute_sentences = False
run_id = None

In [6]:
# TODO: Can we make this play better with papermill?
# Bundle the run parameters into one Config object (dict items + attributes).
# `download_data` is not copied here; the download cell reads the bare name.
config = Config(
    testing=testing,
    debugging=debugging,
    seed=seed,
    char_encoder=char_encoder,
    computational_batch_size=computational_batch_size,
    batch_size=batch_size,
    loss=loss,
    lr=lr,
    lr_schedule=lr_schedule,
    epochs=epochs,
    hidden_sz=hidden_sz,
    dataset=dataset,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    ft_model_path=ft_model_path,
    max_vocab_size=max_vocab_size,
    dropouti=dropouti,
    dropoutw=dropoutw,
    dropoute=dropoute,
    dropoutr=dropoutr,
    val_ratio=val_ratio,
    use_augmented=use_augmented,
    freeze_embeddings=freeze_embeddings,
    attention_bias=attention_bias,
    weight_decay=weight_decay,
    bias_init=bias_init,
    neg_splits=neg_splits,
    num_layers=num_layers,
    rnn_type=rnn_type,
    pooling_type=pooling_type,
    model_type=model_type,
    use_word_level_features=use_word_level_features,
    use_sentence_level_features=use_sentence_level_features,
    mixup_ratio=mixup_ratio,
    discrete_mixup_ratio=discrete_mixup_ratio,
    bucket=bucket,
    compute_thres_on_test=compute_thres_on_test,
    permute_sentences=permute_sentences,
    find_lr=find_lr,
    run_id=run_id,
)

In [7]:
T = TypeVar("T")
TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]  # pylint: disable=invalid-name

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util
from allennlp.data.dataset_readers import DatasetReader

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [8]:
# Choose the data directory for the current environment.
if os.environ["IS_COLAB"] == "True":
    # Relative to the working directory; the drive itself is mounted at /content/gdrive.
    DATA_ROOT = Path("./gdrive/My Drive/Colab_Workspace/Colab Notebooks/data") / config.dataset
    # Config is a dict that mirrors its items as attributes: go through set() so
    # config["ft_model_path"] and config.ft_model_path stay in sync.
    config.set("ft_model_path", str(DATA_ROOT / "ft_model.txt"))
else:
    DATA_ROOT = Path("../data") / config.dataset

In [9]:
import subprocess
# Optionally fetch any missing data files from S3 with the AWS CLI.
if download_data:
    if config.val_ratio > 0.0:
        fnames = ["train_wo_val.csv", "test_proced.csv", "val.csv", "ft_model.txt"]
    else:
        fnames = ["train.csv", "test_proced.csv", "ft_model.txt"]
    if config.use_augmented or config.discrete_mixup_ratio > 0.0: fnames.append("train_extra.csv")
    # Make sure the target directory exists so each file lands inside it.
    DATA_ROOT.mkdir(parents=True, exist_ok=True)
    for fname in fnames:
        if not (DATA_ROOT / fname).exists():
            # Pass an argument list instead of a shell string: the Colab DATA_ROOT
            # contains spaces, which an unquoted shell command would split.
            result = subprocess.run(
                ["aws", "s3", "cp",
                 f"s3://nnfornlp/raw_data/jigsaw/{fname}",
                 str(DATA_ROOT / fname)],
                stdout=subprocess.PIPE,
            )
            print(result.stdout)
            if result.returncode != 0:
                print(f"Download of {fname} failed with exit code {result.returncode}")

# Dataset

In [10]:
from allennlp.data.fields import (TextField, SequenceLabelField, LabelField, 
                                  MetadataField, ArrayField)

class JigsawLMDatasetReader(DatasetReader):
    """Reads the ``comment_text`` column of a CSV file into language-model instances.

    Each instance carries two text fields built from the same tokens: ``input``
    (indexed with ``token_indexers``) and ``output`` (indexed with
    ``output_token_indexers``, which falls back to ``token_indexers``).
    ``max_seq_len`` is stored but not used by the methods of this class.
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer]=None, # TODO: Handle mapping from BERT
                 output_token_indexers: Dict[str, TokenIndexer]=None,
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers
        if output_token_indexers:
            self.output_token_indexers = output_token_indexers
        else:
            self.output_token_indexers = token_indexers
        self.max_seq_len = max_seq_len

    def _clean(self, x: str) -> str:
        """
        Maps a word to its desired output. Currently the identity;
        intended to become a denoising operation later.
        """
        return x

    @overrides
    def text_to_instance(self, tokens: List[str]) -> Instance:
        """Build an instance with an ``input`` field and a ``_clean``-ed ``output`` field."""
        input_tokens = [Token(word) for word in tokens]
        input_field = TextField(input_tokens, self.token_indexers)
        target_tokens = [Token(self._clean(word)) for word in tokens]
        target_field = TextField(target_tokens, self.output_token_indexers)
        return Instance({"input": input_field, "output": target_field})

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per CSV row (only the first 1000 rows when ``config.testing``)."""
        frame = pd.read_csv(file_path)
        if config.testing:
            frame = frame.head(1000)
        for _, row in frame.iterrows():
            tokens = self.tokenizer(row["comment_text"])
            yield self.text_to_instance(tokens)

In [11]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper, ELMoTokenCharactersIndexer
from allennlp.data.token_indexers import SingleIdTokenIndexer

# Character-level indexing for the CNN encoder, otherwise lowercased single word ids.
if config.char_encoder == "cnn":
    token_indexer = ELMoTokenCharactersIndexer()
else:
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=True) # Temporary

# Bound method of a spaCy-backed splitter; the `tokenizer` function below reads
# `.text` from each element it returns.
_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words

def tokenizer(x: str):
    """Split ``x`` with spaCy, keep at most ``config.max_seq_len - 2`` words and
    wrap them in ``[CLS]`` / ``[SEP]`` markers."""
    word_budget = config.max_seq_len - 2
    words = [w.text for w in _spacy_tok(x)[:word_budget]]
    return ["[CLS]", *words, "[SEP]"]

In [12]:
# Targets are always word ids: when the input indexer is not a SingleIdTokenIndexer
# a separate, case-preserving one is created for the output field.
if not isinstance(token_indexer, SingleIdTokenIndexer):
    output_token_indexer = SingleIdTokenIndexer(lowercase_tokens=False)
else:
    output_token_indexer = token_indexer
reader = JigsawLMDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer},
    output_token_indexers={"words": output_token_indexer}
)
# NOTE(review): these three files are read regardless of config.val_ratio, whereas the
# download cell fetches train.csv (not train_wo_val.csv / val.csv) when val_ratio == 0.0.
train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train_wo_val.csv",
                                                                          "val.csv",
                                                                          "test_proced.csv"])

1000it [00:07, 129.51it/s]
1000it [00:05, 170.73it/s]
1000it [00:09, 100.95it/s]


In [13]:
vars(train_ds[2].fields["input"])

{'tokens': [[CLS],
  How,
  can,
  my,
  comment,
  on,
  my,
  own,
  talk,
  page,
  WP,
  :,
  DISRUPT,
  wikipedia,
  ?,
  [SEP]],
 '_token_indexers': {'tokens': <allennlp.data.token_indexers.elmo_indexer.ELMoTokenCharactersIndexer at 0x12b37db38>},
 '_indexed_tokens': None,
 '_indexer_name_to_indexed_token': None}

In [14]:
vars(train_ds[2].fields["output"])

{'tokens': [[CLS],
  How,
  can,
  my,
  comment,
  on,
  my,
  own,
  talk,
  page,
  WP,
  :,
  DISRUPT,
  wikipedia,
  ?,
  [SEP]],
 '_token_indexers': {'words': <allennlp.data.token_indexers.single_id_token_indexer.SingleIdTokenIndexer at 0x12b37da58>},
 '_indexed_tokens': None,
 '_indexer_name_to_indexed_token': None}

### Build Vocab

In [15]:
from allennlp.data.vocabulary import Vocabulary

In [16]:
# Build one vocabulary over train, test and validation instances, adding the
# "[MASK]" token to the "tokens" namespace and capping the size at config.max_vocab_size.
full_ds = train_ds + test_ds + val_ds
vocab = Vocabulary.from_instances(full_ds, tokens_to_add={"tokens": ["[MASK]"]},
                                  max_vocab_size=config.max_vocab_size)

100%|██████████| 3000/3000 [00:00<00:00, 5661.09it/s]


In [17]:
vocab.get_vocab_size()

17977

In [18]:
config.set("vocab_sz", vocab.get_vocab_size())

### Build frequencies

TODO: Implement fast sampling

In [19]:
# Per-token frequencies indexed by vocabulary id, normalised to sum to 1 and then
# raised to the power 2/3 (so the final values no longer sum to 1).
# Reads the private `_retained_counter` attribute of the vocabulary.
if config.char_encoder == "cnn" and config.loss == "is":
    freqs = np.zeros(vocab.get_vocab_size())
    for w, c in vocab._retained_counter["tokens"].items():
        freqs[vocab.get_token_index(w)] = c
    freqs /= freqs.sum()
    freqs **= 2 / 3

### Build Iterator

In [20]:
from allennlp.data.iterators import BucketIterator, DataIterator
# Batches are bucketed by the number of tokens in the "input" field.
# `biggest_batch_first` is switched on only when config.testing is set.
iterator = BucketIterator(
        batch_size=config.batch_size, 
        biggest_batch_first=config.testing,
        sorting_keys=[("input", "num_tokens")],
)
# The iterator needs the vocabulary to turn tokens into ids.
iterator.index_with(vocab)

In [21]:
next(iterator(train_ds))["input"]["tokens"].shape

torch.Size([40, 128, 50])

In [22]:
batch = next(iterator(train_ds))

In [23]:
batch

{'input': {'tokens': tensor([[[259,  92,  68,  ..., 261, 261, 261],
           [259,  35, 260,  ..., 261, 261, 261],
           [259,  74, 116,  ..., 261, 261, 261],
           ...,
           [259,  50,  52,  ..., 261, 261, 261],
           [259,  35, 260,  ..., 261, 261, 261],
           [259,  92,  84,  ..., 261, 261, 261]],
  
          [[259,  92,  68,  ..., 261, 261, 261],
           [259,  74, 117,  ..., 261, 261, 261],
           [259, 116, 109,  ..., 261, 261, 261],
           ...,
           [259,  74, 103,  ..., 261, 261, 261],
           [259, 115, 102,  ..., 261, 261, 261],
           [259,  92,  84,  ..., 261, 261, 261]],
  
          [[259,  92,  68,  ..., 261, 261, 261],
           [259,  35, 260,  ..., 261, 261, 261],
           [259,  81, 115,  ..., 261, 261, 261],
           ...,
           [259, 106, 111,  ..., 261, 261, 261],
           [259,  99, 122,  ..., 261, 261, 261],
           [259,  92,  84,  ..., 261, 261, 261]],
  
          ...,
  
          [[259,  92,

### Build word to indices mapping

In [24]:
if config.char_encoder == "cnn":
    from tqdm import tqdm
    # TODO: Speed up
    # Builds a (vocab_sz, 50) table mapping each word id to ELMo-style character ids.
    # See allennlp/data/token_indexers/elmo_indexer.py
    with timer("Building character indexes"):
        max_word_len = 50                       # width of the character axis
        pad_id, bow_id, eow_id = 261, 259, 260  # padding / beginning-of-word / end-of-word
        max_bytes = max_word_len - 2            # room left once both markers are placed
        # Integer table from the start; row 0 (padding word) stays all zeros.
        word_id_to_char_idxs = np.zeros((config.vocab_sz, max_word_len), dtype=np.int64)
        for w, idx in tqdm(vocab.get_token_to_index_vocabulary().items()):
            if idx > 0:
                # Truncate up front so long words keep all 48 bytes and the end-of-word
                # marker always lands right after the last byte, even for an empty word.
                encoded = w.encode("utf-8", "ignore")[:max_bytes]
                row = word_id_to_char_idxs[idx]
                row[:] = pad_id
                row[0] = bow_id
                for pos, byte in enumerate(encoded, start=1):
                    row[pos] = int(byte) + 1    # byte values are shifted by one
                row[len(encoded) + 1] = eow_id

    word_id_to_char_idxs = torch.LongTensor(word_id_to_char_idxs)

100%|██████████| 17977/17977 [00:00<00:00, 42401.28it/s]

[Building character indexes] done in 0 s





In [25]:
if config.char_encoder == "cnn":
    # Character ids for the "[MASK]" token, laid out like the rows of
    # word_id_to_char_idxs: 259 at position 0, the bytes (+1), then 260, padded with 261.
    mask_char_ids = torch.ones(50, dtype=torch.int64) * 261
    mask_char_ids[0] = 259  # beginning-of-word marker (previously left as padding)
    mask_bytes = "[MASK]".encode("utf-8")
    for pos, byte in enumerate(mask_bytes, start=1):
        mask_char_ids[pos] = int(byte) + 1
    mask_char_ids[len(mask_bytes) + 1] = 260  # end-of-word marker

Testing

In [26]:
if config.char_encoder == "cnn":
    print((batch["input"]["tokens"][0] == word_id_to_char_idxs[batch["output"]["words"][0]]).all().item())

1


# Bert configuration

In [27]:
from pytorch_pretrained_bert.modeling import (BertConfig, BertForMaskedLM, 
                                              BertEncoder, BertPooler, BertOnlyMLMHead)

# A small BERT: 4 layers, 4 attention heads, hidden size 32, feed-forward size 128.
# NOTE(review): the vocabulary size passed here is the configured cap
# (config.max_vocab_size), not the size of the vocabulary actually built (config.vocab_sz).
bert_config = BertConfig(
        config.max_vocab_size, hidden_size=32, num_attention_heads=4,
        num_hidden_layers=4, intermediate_size=32 * 4,
)

bert_config

{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 32,
  "initializer_range": 0.02,
  "intermediate_size": 128,
  "max_position_embeddings": 512,
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "type_vocab_size": 2,
  "vocab_size": 40000
}

# Building token embedder

In [28]:
# Character-CNN token encoder built from pretrained ELMo weights (left trainable).
if config.char_encoder == "cnn":
    from allennlp.modules.token_embedders.elmo_token_embedder import ElmoTokenEmbedder
    from allennlp.modules.elmo import _ElmoCharacterEncoder

    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
    _inner_char_encoder = _ElmoCharacterEncoder(
        options_file=options_file, 
        weight_file=weight_file,
        requires_grad=True
    )
    class ElmoEncoder(nn.Module):
        """Thin wrapper returning only the token embeddings of the inner encoder."""
        def __init__(self, _inner):
            super().__init__()
            self._inner = _inner
        def forward(self, *args):
            # TODO: Stop Elmo encoder from adding SoS and EoS tokens
            # Until then, drop the first and last positions along dim 1.
            return self._inner(*args)["token_embedding"][:, 1:-1, :]
        def get_output_dim(self):
            return self._inner.get_output_dim()
    # Rebinds `char_encoder` from the parameter string to the module;
    # config.char_encoder still holds the string.
    char_encoder = ElmoEncoder(_inner_char_encoder)

# char_encoder

# sample_idxs = next(iterator(train_ds))["tokens"]["tokens"]

# char_encoder(sample_idxs)
    config.set("embedding_sz", char_encoder.get_output_dim())

Bag of character n-grams using fastText (currently consumes too much memory to be practical)

In [29]:
# Bag-of-character-n-grams encoder backed by a fastText binary model.
if config.char_encoder == "boc":
    from fastText import load_model
    from torch.nn.modules.sparse import EmbeddingBag
    
    # TODO: Reduce size of fasttext model binary
    class FastTextEmbeddingBag(EmbeddingBag):
        """EmbeddingBag whose weights are copied from a fastText model's input matrix."""
        def __init__(self, model_path):
            # The model is loaded first because its input-matrix shape sizes the bag.
            self.model = load_model(model_path)
            input_matrix = self.model.get_input_matrix()
            input_matrix_shape = input_matrix.shape
            super().__init__(input_matrix_shape[0], input_matrix_shape[1])
            self.weight.data.copy_(torch.FloatTensor(input_matrix))

        def forward(self, words):
            """Embed each word in `words` from its fastText subword indices."""
            word_subinds = np.empty([0], dtype=np.int64)
            word_offsets = [0]
            for word in words:
                _, subinds = self.model.get_subwords(word)
                word_subinds = np.concatenate((word_subinds, subinds))
                word_offsets.append(word_offsets[-1] + len(subinds))
            # Offsets mark where each word's bag starts, so the final running total is dropped.
            word_offsets = word_offsets[:-1]
            ind = torch.LongTensor(word_subinds)
            offsets = torch.LongTensor(word_offsets)
            return super().forward(ind, offsets)
    
    char_encoder = FastTextEmbeddingBag(str(DATA_ROOT / "ft_model.bin"))
    config.set("embedding_sz", 300)

Simple word-level embeddings

In [30]:
# Plain word-level embedding table of width 300.
if config.char_encoder == "fasttext":
    from allennlp.modules import Embedding

    # NOTE(review): the matrix is random (scaled normal noise); no pretrained
    # fastText vectors are loaded in this cell.
    ft_matrix = np.random.randn(bert_config.vocab_size, 300) * 0.3
    char_encoder = Embedding(bert_config.vocab_size, 300, weight=torch.FloatTensor(ft_matrix))
    config.set("embedding_sz", 300)

In [31]:
def freeze(x):
    """Turn off gradients for ``x`` and, if it has a ``parameters()`` method,
    for every parameter that method yields.

    Args:
        x: a parameter-like object or a module-like object exposing ``parameters()``.
    """
    x.requires_grad = False
    params = getattr(x, "parameters", None)
    if callable(params):
        # `parameters` is a method: it has to be called, iterating over the
        # bound method itself raises TypeError.
        for p in params():
            freeze(p)

In [32]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension in the TF style
    (epsilon inside the square root), with a learnable scale and shift."""

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalised = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalised + self.bias

In [33]:
class EmbeddingWPositionEmbs(nn.Module):
    """Embeds, then maps the embeddings to the bert_hidden_sz for processing.

    Args:
        word_emb: module turning input ids into word embeddings.
        embedding_dim: size of the embeddings produced by ``word_emb``.
        bert_hidden_sz: size the embeddings are projected to.
        freeze_embeddings: when true, ``freeze`` is applied to ``word_emb``.
        dropout: dropout probability applied to the normalised output.
    """
    def __init__(self, word_emb: nn.Module, 
                 embedding_dim,
                 bert_hidden_sz, 
                 freeze_embeddings=False,
                 dropout=0.1):
        super().__init__()
        self.word_emb = word_emb
        if freeze_embeddings: freeze(self.word_emb)
        # One learned position vector per position, up to config.max_seq_len.
        self.position_embeddings = nn.Embedding(config.max_seq_len, 
                                                bert_hidden_sz)
        self.linear = nn.Linear(embedding_dim, bert_hidden_sz,
                                bias=False) # Transform dimensions
        self.norm = LayerNorm(bert_hidden_sz)
        # Honour the `dropout` argument (it used to be ignored in favour of a fixed 0.1).
        self.do = nn.Dropout(dropout)
    
    def get_word_embs(self, input_ids):
        """Word embeddings projected to ``bert_hidden_sz``, without position information."""
        return self.linear(self.word_emb(input_ids))
    
    def forward(self, input_ids):
        """Return dropout(norm(projected word embeddings + position embeddings))."""
        # We won't be using token types since we won't be predicting the next sentence
        bs, seq_length, *_ = input_ids.shape
        position_ids = (torch.arange(seq_length, dtype=torch.long)
                             .to(input_ids.device)
                             .unsqueeze(0)
                             .expand((bs, seq_length)))
        word_embs = self.get_word_embs(input_ids)
        position_embs = self.position_embeddings(position_ids)
        return self.do(self.norm(word_embs + position_embs))

In [34]:
# Embedding front-end of the model: token encoder + position embeddings.
# NOTE(review): `freeze_embeddings` is left at its default (False) here;
# config.freeze_embeddings is not passed — confirm this is intended.
sample_embs = EmbeddingWPositionEmbs(
    char_encoder,
    config.embedding_sz,
    bert_config.hidden_size,
)

In [35]:
sample_embs

EmbeddingWPositionEmbs(
  (word_emb): ElmoEncoder(
    (_inner): _ElmoCharacterEncoder(
      (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
      (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
      (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
      (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
      (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
      (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
      (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
      (_highways): Highway(
        (_layers): ModuleList(
          (0): Linear(in_features=2048, out_features=4096, bias=True)
        )
      )
      (_projection): Linear(in_features=2048, out_features=128, bias=True)
    )
  )
  (position_embeddings): Embedding(128, 32)
  (linear): Linear(in_features=128, out_features=32, bias=False)
  (norm): LayerNorm()
  (do): Dropout(p=0.1)
)

# Masked Language Model

### The encoder

In [36]:
bert_encoder = BertEncoder(bert_config)

In [37]:
class CustomBert(nn.Module):
    """Runs a custom embedding module followed by a BERT-style encoder."""

    def __init__(self, embeddings, encoder):
        super().__init__()
        self.embeddings = embeddings
        self.encoder = encoder
        
    def forward(self, input_ids, 
                token_type_ids=None, 
                attention_mask=None):
        """Embed ``input_ids`` and return every encoded layer from the encoder.

        When ``attention_mask`` is omitted every position is attended to.
        """
        # Inputs with more than two dimensions carry an extra trailing axis;
        # index 0 of it yields a tensor shaped like the per-token ids.
        has_extra_axis = len(input_ids.shape) > 2
        per_token = input_ids[:, :, 0] if has_extra_axis else input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(per_token)
        if token_type_ids is None:
            # Kept for interface parity; the value is not used below.
            token_type_ids = torch.ones_like(per_token)

        # Insert two singleton axes ([batch_size, 1, 1, to_seq_length]) so the mask
        # broadcasts to [batch_size, num_heads, from_seq_length, to_seq_length], and
        # match the parameter dtype (fp16 compatibility).
        param_dtype = next(self.parameters()).dtype
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        # 1 -> 0.0 (attend), 0 -> -10000.0 (masked): added to the raw scores before
        # the softmax, which effectively removes the masked positions.
        additive_mask = (1.0 - additive_mask) * -10000.0

        embedded = self.embeddings(input_ids)
        return self.encoder(embedded,
                            additive_mask,
                            output_all_encoded_layers=True)

In [38]:
class BertMLMPooler(nn.Module):
    """Selects the last entry from a list of encoded layers."""
    def forward(self, x: List[torch.FloatTensor]) -> torch.FloatTensor:
        final_layer = x[-1]
        return final_layer

In [39]:
bert_model = CustomBert(sample_embs, bert_encoder)

In [40]:
import math
def gelu(x):
    """Gaussian Error Linear Unit in its erf form: ``x * 0.5 * (1 + erf(x / sqrt(2)))``."""
    gate = 1.0 + torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * gate

class BertCustomLMPredictionHead(nn.Module):
    """Prediction head for the masked language model.

    With ``output_logits`` set, hidden states go through dense -> gelu -> LayerNorm
    and a bias-free decoder plus a separate bias vector. Otherwise the hidden
    states are returned unchanged and no decoder is created.
    """
    def __init__(self, config, out_sz, vocab_sz, 
                 embedding, output_logits=True):
        super().__init__()
        # Projection layers
        self.dense = nn.Linear(config.hidden_size, out_sz)
        self.transform_act_fn = gelu
        self.LayerNorm = LayerNorm(out_sz)

        # Prediction layers exist only when logits are requested
        self.output_logits = output_logits
        if not self.output_logits:
            return
        self.decoder = nn.Linear(out_sz, vocab_sz, bias=False)
        if embedding is not None:
            # Reuse the supplied weight as the decoder weight
            self.decoder.weight = embedding
        self.bias = nn.Parameter(torch.zeros(vocab_sz))

    def forward(self, hidden_states):
        if not self.output_logits:
            return hidden_states
        projected = self.transform_act_fn(self.dense(hidden_states))
        normalised = self.LayerNorm(projected)
        return self.decoder(normalised) + self.bias

In [41]:
output_logits = config.loss != "is"
bert_mlm_head = BertCustomLMPredictionHead(config=bert_config, 
                                           out_sz=config.embedding_sz, 
                                           vocab_sz=config.vocab_sz,
                                           embedding=(sample_embs.word_emb.weight 
                                                      if config.char_encoder == "fasttext" 
                                                      else None),
                                           output_logits=output_logits,
)

In [42]:
custom_model = nn.Sequential(
    bert_model,
    BertMLMPooler(),
    bert_mlm_head,
)

### The decoder

In [43]:
# Callable mapping target word ids to vectors of size bert_config.hidden_size:
# a separate embedding table unless the input side is the word-level ("fasttext")
# embedding, in which case the input embedding function is reused.
if config.char_encoder != "fasttext":
    from allennlp.modules import Embedding
    output_embs = Embedding(config.vocab_sz, bert_config.hidden_size)
else:
    output_embs = sample_embs.get_word_embs

# Loss Functions

In [44]:
def build_unigram_noise(freq):
    """Build the unigram noise weights from token counts.

    Parameters:
        freq: a tensor of #occurrences of the corresponding index
    Return:
        unigram_noise: the normalised frequencies raised to the power 0.75
        (after this step the values no longer sum to 1)
    """
    noise = freq / freq.sum()
    assert abs(noise.sum() - 1) < 0.001
    smoothed = noise ** 0.75  # slight modification
    return smoothed

In [45]:
from torch.distributions import Categorical

Masked Cross Entropy

In [46]:
class MaskedCrossEntropyLoss(nn.Module):
    """Cross entropy averaged over the positions selected by an optional mask."""

    def __init__(self):
        super().__init__()
        self._loss = nn.CrossEntropyLoss(reduction="none")

    def forward(self, preds, tgts, mask=None) -> torch.tensor:
        """Return the mean loss; with ``mask``, only positions where it is non-zero count."""
        if mask is None:
            return self._loss(preds, tgts).mean()
        # Flatten to (positions, classes) / (positions,) before the per-position loss.
        flat_preds = preds.view((-1, preds.size(-1)))
        flat_tgts = tgts.view((-1, ))
        per_position = self._loss(flat_preds, flat_tgts)
        n_elements = mask.sum()
        weights = mask.view((-1, )).float()
        return (per_position * weights).sum() / n_elements

Importance Sampling

In [47]:
class UniformSampler:
    """Draws integers uniformly from the half-open range ``[min_, max_)``."""

    def __init__(self, min_, max_):
        self.min_ = min_
        self.max_ = max_

    def sample(self, shape):
        """Return a tensor of the given shape filled with random integers."""
        lower, upper = self.min_, self.max_
        return torch.randint(low=lower, high=upper, size=shape)

In [48]:
class ImportanceSamplingLoss(nn.Module):
    """Sampled softmax-style loss: the target competes against ``k`` random words.

    Args:
        embedding_generator: callable mapping word indices to vectors.
        probs: currently unused; negatives are drawn uniformly (see TODO below).
        k: number of negative samples per position.
    """
    def __init__(self, embedding_generator,
                 probs: np.ndarray, k=10):
        super().__init__()
        self.embedding_generator = embedding_generator
        # TODO: Should this be according to the unigram probability?
        # Or should it be uniform?
        # Uniform over ids 1 .. vocab size - 1 (reads the notebook-level `vocab`).
        self.sampler = UniformSampler(1, vocab.get_vocab_size())
#         self.sampler = Categorical(probs=probs)
        # TODO: Compute samples in advance
        self._loss_func = MaskedCrossEntropyLoss()
        self.k = k
    
    def get_negative_samples(self, bs) -> torch.LongTensor:
        """Draw a (bs, k) tensor of negative word indices."""
        neg = self.sampler.sample((bs, self.k, )) # TODO: Speed up??
        return neg
    
    def get_embeddings(self, idxs: torch.LongTensor) -> torch.FloatTensor:
        """Converts indexes into vectors"""
        return self.embedding_generator(idxs) # TODO: Implement general case
    
    def forward(self, y: torch.LongTensor, tgt, mask=None):
        """
        Expects input of shape
        y: (batch * seq, feature_sz)
        tgt: (batch * seq, )
        Inputs with an extra leading dimension are flattened to that shape first.
        """
        if len(y.shape) > 2:            
            y = y.view((-1, y.size(-1))) # (batch * seq, feature_sz)
            tgt = tgt.view((-1, )) # (batch * seq, )
        bs = y.size(0)
        # Samples are drawn on the CPU; move them next to the targets so the
        # concatenation also works when training on a GPU.
        neg_samples = self.get_negative_samples(bs).to(tgt.device)
        idxs = torch.cat([tgt.unsqueeze(1), neg_samples], dim=1)
        # TODO: More efficient implementation?
        # y: (batch * seq, feature_sz)
        # embeddings: (batch * seq, 1 + k, feature_sz)
        dot_prods = torch.einsum("bkf,bf->bk", self.get_embeddings(idxs), y)
        # The true word sits in column 0 of `idxs`, so the correct class is always 0.
        true_class = torch.zeros(bs, dtype=torch.int64, device=dot_prods.device)
        return self._loss_func(dot_prods, true_class,
                               mask=mask)

Testing

In [49]:
y = torch.randn((3, 7, bert_config.hidden_size)).view((-1, bert_config.hidden_size))
tgt = torch.randint(100, (3, 7)).view((-1, ))

In [50]:
cat = Categorical(probs=torch.rand(100, ))

In [51]:
neg_samples = cat.sample((y.size(0), 10))

In [52]:
idxs = torch.cat([tgt.unsqueeze(1), neg_samples], dim=1)

In [53]:
output_embs(idxs).shape

torch.Size([21, 11, 32])

In [54]:
y.shape

torch.Size([21, 32])

In [55]:
loss = ImportanceSamplingLoss(
    output_embs, probs=torch.rand(100),
)
loss(y, tgt)

tensor(2.4042, grad_fn=<MeanBackward1>)

In [56]:
torch.bmm(y.unsqueeze(1), loss.get_embeddings(idxs).transpose(1, 2)).squeeze(1).shape

torch.Size([21, 11])

In [57]:
loss.get_embeddings(idxs).shape

torch.Size([21, 11, 32])

# Decoders

In [58]:
class WordCorrection(nn.Module):
    """From the paper `Exploring the Limitations of Language Modeling`.

    Projects the hidden state to a bottleneck and multiplies it with ``corr``.
    """
    def __init__(self, hidden_sz: int, bottleneck_sz: int=128):
        super().__init__()
        self.l1 = nn.Linear(hidden_sz, bottleneck_sz)

    def forward(self, h: torch.FloatTensor, 
                corr: torch.FloatTensor):
        bottleneck = self.l1(h)
        return torch.matmul(bottleneck, corr)

In [59]:
# Use a different decoder (TODO: Enable sharing of parameters)
# char_decoder = _ElmoCharacterEncoder(
#     options_file=options_file, 
#     weight_file=weight_file,
#     requires_grad=True
# )

# Training

In [60]:
from allennlp.models import Model

In [61]:
class Masker(nn.Module):
    """Randomly corrupts input tokens for masked language modelling.

    In training mode, roughly ``noise_rate * mask_rate`` of the positions are
    replaced by the mask token and ``noise_rate * replace_rate`` by a random
    word. In evaluation mode the input is returned unchanged.
    """
    def __init__(self, vocab: Vocabulary, 
                 noise_rate: float=0.15,
                 mask_rate=0.8,
                 replace_rate=0.1):
        super().__init__()
        self.vocab = vocab
        self.vocab_sz = vocab.get_vocab_size()
        if config.char_encoder == "cnn":
            # Character ids of "[MASK]" with two leading singleton axes for broadcasting.
            self.mask_id = mask_char_ids.unsqueeze(0).unsqueeze(1)
        else:
            self.mask_id = vocab.get_token_index("[MASK]")
        self.noise_rate = noise_rate
        self.mask_rate = mask_rate
        self.replace_rate = replace_rate
        
    def create_mask(self, shape, ones_ratio, dtype=torch.uint8):
        """Return a tensor of ``shape`` whose entries are 1 with probability ``ones_ratio``."""
        return (torch.ones(shape, dtype=dtype, 
                           requires_grad=False)
                     .bernoulli(ones_ratio))

    def get_random_input_ids(self, shape):
        """Returns randomly sampled input ids: character ids of random words for the
        "cnn" encoder, random word ids for "fasttext" (None for any other encoder)."""
        if config.char_encoder == "cnn":
            rint = torch.randint(self.vocab_sz, shape)
            return word_id_to_char_idxs[rint]
        elif config.char_encoder == "fasttext":
            return torch.randint(self.vocab_sz, shape)

    def forward(self, x: torch.LongTensor) -> torch.LongTensor:
        char_level = len(x.shape) > 2
        # `self.training` is the mode flag; `self.train` is a method and always
        # truthy, which used to apply the noise in evaluation mode as well.
        if self.training and self.noise_rate > 0:
            with torch.no_grad(): # no grads required here
                mask_shape = x.shape[:-1] if char_level else x.shape
                # Masks and replacement ids are created on the CPU, so move them
                # to the device of the input before combining.
                mask = self.create_mask(mask_shape, 
                                        self.noise_rate * self.mask_rate).to(x.device)
                if config.char_encoder == "cnn":
                    x = torch.where(mask.unsqueeze(2), self.mask_id.to(x.device), x)
                else:
                    x = x.masked_fill(mask, self.mask_id)
                
                if self.replace_rate > 0.:
                    # this is technically incorrect, since we might overwrite the mask tokens
                    # but I guess it will do for now
                    mask = self.create_mask(mask_shape,
                                            self.noise_rate * self.replace_rate).to(x.device)
                    x = torch.where(mask.unsqueeze(2) if config.char_encoder == "cnn" else mask, 
                                    self.get_random_input_ids(mask_shape).to(x.device),
                                    x)
        return x

In [62]:
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy

class MaskedLM(Model):
    """Masked language model: noise the input tokens, then predict the target words.

    Wraps a `Masker` (input corruption), an arbitrary `model` that maps the
    noised ids to per-position outputs, and a `loss` callable invoked as
    ``loss(logits, targets, mask=mask)``.
    """
    def __init__(self, vocab: Vocabulary, model: nn.Module,
                loss: nn.Module, noise_rate=0.8):
        """
        Args:
            vocab: vocabulary handed to the base `Model` and to the `Masker`.
            model: indexable module (``model[-1]`` is read) producing the
                per-position outputs.
            loss: callable taking ``(logits, targets, mask=...)``.
            noise_rate: forwarded to `Masker` as its ``noise_rate``.
        """
        super().__init__(vocab)
        self.masker = Masker(vocab, noise_rate=noise_rate)
        self.accuracy = CategoricalAccuracy()
        self.model = model
        self.loss = loss
    
    @property
    def outputs_logits(self) -> bool:
        """Value of the ``output_logits`` flag on the last stage of `model`.

        Accuracy is only tracked when this is truthy.
        """
        return self.model[-1].output_logits
    
    def forward(self, input: TensorDict, 
                output: TensorDict, **kwargs) -> TensorDict:
        """Noise the inputs, run the model and compute the loss.

        Args:
            input: text-field tensors; ``input["tokens"]`` holds the ids that
                get noised and fed to the model.
            output: target tensors; ``output["words"]`` holds the gold ids.
            **kwargs: accepted and ignored.

        Returns:
            Dict with ``"loss"`` and ``"logits"``.
        """
        mask = get_text_field_mask(input)
        x = self.masker(input["tokens"])
        tgt = output["words"]
        
        logits = self.model(x)
        out_dict = {"loss": self.loss(logits, tgt, mask=mask)}
        out_dict["logits"] = logits
        if self.outputs_logits:
            # Update the running accuracy with the same padding mask the loss
            # uses, so padded positions are not counted. The metric call only
            # accumulates state and returns None, so its result is not put in
            # the output dict; read the value through get_metrics().
            self.accuracy(logits, tgt, mask=mask)
        return out_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated accuracy, or an empty dict when none is tracked."""
        if self.outputs_logits:
            return {"accuracy": self.accuracy.get_metric(reset)}
        else:
            return {}

In [63]:
# Pick the training loss named by config.loss; MaskedLM invokes it as
# loss(logits, targets, mask=mask).
if config.loss == "masked_crossentropy":
    loss = MaskedCrossEntropyLoss()
elif config.loss == "crossentropy":
    _loss = nn.CrossEntropyLoss()
    def ce(y, t, mask=None): 
        """Plain cross entropy over every position; `mask` is accepted but ignored."""
        # flatten to (N, num_classes) logits and (N,) targets
        return _loss(y.view((-1, y.size(-1))), t.view((-1, )))
    loss = ce
elif config.loss == "is":
    # TODO: Implement masking
    loss = ImportanceSamplingLoss(output_embs, None) # temporarily sample at random
else:
    raise ValueError(
        f"Unsupported config.loss {config.loss!r}; "
        "expected 'masked_crossentropy', 'crossentropy' or 'is'"
    )
masked_lm = MaskedLM(vocab, custom_model, loss, noise_rate=0.15)

Sanity checks

In [64]:
# Sanity check: run only the noising step on one batch of input token ids.
tokens = masked_lm.masker(batch["input"]["tokens"])

In [65]:
# Feed the noised tokens through the first two stages of the wrapped model only.
hidden_states = masked_lm.model[:2](tokens)

In [66]:
# Apply the stage at index 2 to the intermediate states and display its output.
masked_lm.model[2](hidden_states)

tensor([[[-1.5804,  0.7760, -1.8739,  ...,  1.9227,  0.7227,  0.1324],
         [ 0.3603, -0.3518,  0.5121,  ..., -0.1061,  1.1555, -1.0039],
         [-1.2652, -1.1676, -1.7268,  ...,  0.4276,  0.7936, -0.5868],
         ...,
         [-0.3311,  0.4366, -0.5052,  ...,  0.3579, -0.2515, -0.4568],
         [ 0.2791,  0.7221, -1.7120,  ...,  0.4504,  0.3753, -1.3002],
         [ 1.1276, -0.1795, -1.3316,  ...,  0.4769,  0.0311, -1.1109]],

        [[-1.8727,  0.9486, -2.4024,  ...,  1.2721,  1.2755,  0.1641],
         [-0.0389, -0.4296,  0.4552,  ...,  0.2502,  0.8206, -1.0632],
         [-0.7303, -0.9013, -0.9299,  ...,  0.5584,  1.2745, -0.7620],
         ...,
         [-0.6082,  0.5045, -1.7468,  ...,  1.2611, -0.4222, -0.4912],
         [ 0.5867,  0.8731, -1.3594,  ...,  0.8824,  1.8493, -1.1965],
         [ 0.2679,  0.2569, -1.1916,  ...,  0.6775,  0.1391, -0.8199]],

        [[-1.4472,  1.0144, -1.9252,  ...,  2.3477,  1.1999, -0.2477],
         [ 0.1438, -0.3137,  0.5217,  ..., -0

In [67]:
# Full forward pass on one batch; displays the returned dict (loss, logits, ...).
masked_lm(**batch)

{'loss': tensor(2.3974, grad_fn=<DivBackward0>),
 'logits': tensor([[[-1.5808,  1.1482, -1.7822,  ...,  2.1955,  1.0737,  0.1318],
          [-0.2795, -0.2390,  0.3636,  ...,  0.2533,  0.9747, -0.7541],
          [-1.4023, -1.1716, -1.8004,  ...,  0.6374,  0.5905, -0.6558],
          ...,
          [-0.5546,  0.6915, -0.9092,  ...,  1.2913, -0.2356, -0.1399],
          [ 0.2118,  0.3902, -1.2860,  ...,  0.6158,  1.3520, -1.6078],
          [ 0.5129,  0.2748, -0.9802,  ...,  0.8152,  0.9959, -1.2413]],
 
         [[-1.7395,  0.0238, -1.9076,  ...,  2.2626,  1.2561,  0.0561],
          [-0.0229, -0.0825, -0.0552,  ...,  0.2150,  1.0042, -1.3747],
          [-0.8470, -0.9412, -1.2945,  ...,  0.7761,  1.0303, -0.8591],
          ...,
          [-0.2673,  0.4812, -1.2947,  ...,  1.0498, -0.8741, -0.3677],
          [ 0.5930,  0.7765, -1.0648,  ...,  0.5665,  1.6461, -1.2059],
          [ 0.3977,  0.3008, -0.7681,  ...,  0.5883,  0.5202, -1.3566]],
 
         [[-1.3890, -0.1364, -1.7042,  ..

Sanity checks

In [68]:
from allennlp.training import Trainer

# Adam over every parameter of the masked LM.
optimizer = torch.optim.Adam(masked_lm.parameters(), lr=0.01)

# Single-epoch trainer for a quick check: no serialization directory is given,
# patience is left unset (so no early stopping), and GPU 0 is used when CUDA is
# available, otherwise -1 is passed as the device.
trainer = Trainer(
    model=masked_lm,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=None,
    cuda_device=0 if torch.cuda.is_available() else -1,
    num_epochs=1,
)

You provided a validation dataset but patience was set to None, meaning that early stopping is disabled


In [69]:
# Run the single training epoch; the returned metrics dict is displayed below.
trainer.train()

loss: 1.4233 ||: 100%|██████████| 16/16 [13:54<00:00, 51.92s/it] 
loss: 0.9495 ||: 100%|██████████| 16/16 [01:19<00:00,  7.94s/it]


{'best_epoch': 0,
 'peak_cpu_memory_MB': 4138.745856,
 'training_duration': '00:15:14',
 'training_start_epoch': 0,
 'training_epochs': 0,
 'epoch': 0,
 'training_loss': 1.423305545002222,
 'training_cpu_memory_MB': 4138.745856,
 'validation_loss': 0.9494792744517326,
 'best_validation_loss': 0.9494792744517326}

In [70]:
# Index of the highest-scoring entry along dim 2 of the logits, per position.
masked_lm(**batch)["logits"].argmax(2)

tensor([[17, 17, 17,  ..., 17, 17, 17],
        [17, 17, 17,  ..., 17, 17, 17],
        [17, 17, 17,  ..., 17, 17, 17],
        ...,
        [17, 17, 17,  ..., 17, 17, 17],
        [17, 17, 17,  ..., 17, 17, 17],
        [17, 17, 17,  ..., 17, 17, 17]])

For the importance sampling (IS) loss, we need to aggregate the output embeddings at the end of training in order to make predictions.

In [71]:
import math

vocab_sz = vocab.get_vocab_size()
embedding_sz = bert_config.hidden_size

if config.loss == "is":
    # Build a (vocab_sz, embedding_sz) matrix of output embeddings by querying
    # the loss module for `bs` vocabulary ids at a time.
    bs = 64
    output_embedding_matrix = torch.zeros(vocab_sz, embedding_sz)
    num_batches = math.ceil(vocab_sz / bs)
    # Only the embedding values are needed. Without no_grad, copying tensors
    # that require grad into the matrix in place would make the matrix itself
    # part of the autograd graph and keep that graph alive.
    with torch.no_grad():
        for i in range(num_batches):
            # the last chunk may be shorter than bs
            start,end = i*bs, min(((i+1)*bs), vocab_sz)
            idxs = torch.arange(start=start, end=end).unsqueeze(0)
            output_embedding_matrix[start:end, :] = loss.get_embeddings(idxs)

# Manually Check Outputs

TODO: Implement manual checks for negative sampling loss as well

In [72]:
def to_np(t):
    """Convert a tensor to a NumPy array, detached from autograd and moved to the CPU."""
    detached = t.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

In [73]:
def to_words(arr):
    """Decode a tensor of vocabulary indices into space-joined token strings.

    A one-dimensional input yields a single string; higher-dimensional input
    yields a (nested) list with one entry per slice along the first axis.
    """
    if len(arr.shape) <= 1:
        indices = to_np(arr)
        tokens = (vocab.get_token_from_index(idx) for idx in indices)
        return " ".join(tokens)
    # recurse over the leading axis
    return [to_words(row) for row in arr]

In [74]:
def get_preds(model, batch: TensorDict):
    """Return the predicted index for every position of `batch`.

    The model's "logits" output is argmaxed over dim 2. For the "is" loss the
    model output is first multiplied with the transposed
    ``output_embedding_matrix`` (presumably turning hidden vectors into
    per-vocabulary scores -- confirm against the loss definition).

    The forward pass runs under ``torch.no_grad()`` since only the argmax is
    needed. The model's train/eval mode is left as the caller set it.
    """
    with torch.no_grad():
        logits = model(**batch)["logits"]
        if config.loss == "is":
            # keep both operands on the same device before the matmul
            embeddings = output_embedding_matrix.to(logits.device)
            logits = logits @ embeddings.transpose(0, 1)
        return logits.argmax(2)

In [75]:
# Inspect the logits shape for one batch (presumably batch x sequence x classes).
masked_lm(**batch)["logits"].shape

torch.Size([40, 128, 32])

In [76]:
def cstm_pprint(x):
    """Print the strings in `x`, separated from each other by a blank line."""
    separator = "\n\n"
    text = separator.join(x)
    print(text)

In [77]:
# Show the first five target sequences of the batch decoded back to text.
cstm_pprint(to_words(batch["output"]["words"])[:5])

[CLS] " Israeli " " Apartheid " " Article How dare you identify a very clear and reasonable allegation of bias as " " vandalism " " . I 'm disgusted with your behaviour , and your willingness to suppress arguments that do n't mesh with your ideological foundation . Considering that the page I edited deals with human rights , I find it very salient that you are more than willing to suppress the freedom of expression by turning this site into a dictatorship of the obsessive over the intelligent . -edit : I noticed that you removed my statement quite quickly . Is there some explanation for succumbing to such cowardice and refusing to address me directly ? —Preceding unsigned comment added by 138.40.153.43 " [SEP]

[CLS] It slows trains down because the existing speed limits already have a significant margin of safety built in so the engineer can exceed them by small amounts and still remain safe . With PTC such fluctuations over the posted limit will trigger an overspeed alarm with the re

In [78]:
# Show the model's decoded predictions for the first five sequences of the same batch.
cstm_pprint(to_words(get_preds(masked_lm, batch))[:5])

this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this

this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this

In [79]:
# Debug only: train three more single-epoch rounds and print the first five
# decoded predictions for `batch` after each round.
if config.debugging:
    for i in range(3):
        # A fresh Trainer is built every round, but the same model and
        # optimizer objects are passed in, so training continues from the
        # previous round's state.
        trainer = Trainer(
            model=masked_lm,
            optimizer=optimizer,
            iterator=iterator,
            train_dataset=train_ds,
            validation_dataset=val_ds,
            serialization_dir=None,
            cuda_device=0 if torch.cuda.is_available() else -1,
            num_epochs=1,
        )
        trainer.train()
        cstm_pprint(to_words(get_preds(masked_lm, batch))[:5])

You provided a validation dataset but patience was set to None, meaning that early stopping is disabled
loss: 0.6985 ||: 100%|██████████| 16/16 [03:56<00:00, 10.56s/it]
loss: 1.0746 ||: 100%|██████████| 16/16 [01:19<00:00,  8.11s/it]
You provided a validation dataset but patience was set to None, meaning that early stopping is disabled
  0%|          | 0/16 [00:00<?, ?it/s]

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

loss: 0.6535 ||: 100%|██████████| 16/16 [03:57<00:00, 13.05s/it]
loss: 1.0837 ||: 100%|██████████| 16/16 [01:26<00:00,  8.46s/it]
You provided a validation dataset but patience was set to None, meaning that early stopping is disabled
  0%|          | 0/16 [00:00<?, ?it/s]

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

loss: 0.6538 ||: 100%|██████████| 16/16 [04:16<00:00, 18.29s/it]
loss: 1.1413 ||: 100%|██████████| 16/16 [01:24<00:00,  7.95s/it]


this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this

this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this this

# Predict and Evaluate