In [1]:
# settings for seamlessly running on colab
import os
try:
    from google.colab import drive
    drive.mount('/content/gdrive')
    os.environ["IS_COLAB"] = "True"
except ImportError:
    os.environ["IS_COLAB"] = "False"

In [2]:
%%bash
if [ "$IS_COLAB" = "True" ]; then
    pip install git+https://github.com/facebookresearch/fastText.git
    pip install torch
    pip install torchvision
    pip install allennlp
    pip install dnspython
    pip install jupyter_slack
    pip install git+https://github.com/keitakurita/Better_LSTM_PyTorch.git
fi

In [3]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
%matplotlib inline
from overrides import overrides

In [4]:
import time
from contextlib import contextmanager

class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)

@contextmanager
def timer(name):
    t0 = time.time()
    yield
    print(f'[{name}] done in {time.time() - t0:.0f} s')
    
import functools
import traceback
import sys

def get_ref_free_exc_info():
    "Free traceback from references to locals/globals to avoid circular reference leading to gc.collect() unable to reclaim memory"
    type, val, tb = sys.exc_info()
    traceback.clear_frames(tb)
    return (type, val, tb)

def gpu_mem_restore(func):
    "Reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted"
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            type, val, tb = get_ref_free_exc_info() # must!
            raise type(val).with_traceback(tb) from None
    return wrapper

def ifnone(a: Any, alt: Any): return alt if a is None else a

In [5]:
# for papermill
testing = True
debugging = False
seed = 1
char_encoder = "fasttext"
computational_batch_size = 64
batch_size = 64
lr = 4e-3
lr_schedule = "slanted_triangular"
epochs = 6 if not testing else 1
hidden_sz = 128
dataset = "jigsaw"
n_classes = 6
max_seq_len = 128
download_data = False
ft_model_path = "../data/jigsaw/ft_model.txt"
max_vocab_size = 40000
dropouti = 0.2
dropoutw = 0.0
dropoute = 0.2
dropoutr = 0.3 # TODO: Implement
val_ratio = 0.0
use_augmented = False
freeze_embeddings = True
mixup_ratio = 0.0
discrete_mixup_ratio = 0.0
attention_bias = True
weight_decay = 0.
bias_init = True
neg_splits = 1
num_layers = 2
rnn_type = "lstm"
pooling_type = "augmented_multipool" # attention or multipool or augmented_multipool
model_type = "standard"
use_word_level_features = True
use_sentence_level_features = True
bucket = True
compute_thres_on_test = True
find_lr = False
permute_sentences = False
run_id = None

In [6]:
# TODO: Can we make this play better with papermill?
config = Config(
    testing=testing,
    debugging=debugging,
    seed=seed,
    char_encoder=char_encoder,
    computational_batch_size=computational_batch_size,
    batch_size=batch_size,
    loss="crossentropy",
    lr=lr,
    lr_schedule=lr_schedule,
    epochs=epochs,
    hidden_sz=hidden_sz,
    dataset=dataset,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    ft_model_path=ft_model_path,
    max_vocab_size=max_vocab_size,
    dropouti=dropouti,
    dropoutw=dropoutw,
    dropoute=dropoute,
    dropoutr=dropoutr,
    val_ratio=val_ratio,
    use_augmented=use_augmented,
    freeze_embeddings=freeze_embeddings,
    attention_bias=attention_bias,
    weight_decay=weight_decay,
    bias_init=bias_init,
    neg_splits=neg_splits,
    num_layers=num_layers,
    rnn_type=rnn_type,
    pooling_type=pooling_type,
    model_type=model_type,
    use_word_level_features=use_word_level_features,
    use_sentence_level_features=use_sentence_level_features,
    mixup_ratio=mixup_ratio,
    discrete_mixup_ratio=discrete_mixup_ratio,
    bucket=bucket,
    compute_thres_on_test=compute_thres_on_test,
    permute_sentences=permute_sentences,
    find_lr=find_lr,
    run_id=run_id,
)

In [7]:
T = TypeVar("T")
TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]  # pylint: disable=invalid-name

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util
from allennlp.data.dataset_readers import DatasetReader

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [8]:
if os.environ["IS_COLAB"] != "True":
    DATA_ROOT = Path("../data") / config.dataset
else:
    DATA_ROOT = Path("./gdrive/My Drive/Colab_Workspace/Colab Notebooks/data") / config.dataset
    config.ft_model_path = str(DATA_ROOT / "ft_model.txt")

In [9]:
import subprocess
if download_data:
    if config.val_ratio > 0.0:
        fnames = ["train_wo_val.csv", "test_proced.csv", "val.csv", "ft_model.txt"]
    else:
        fnames = ["train.csv", "test_proced.csv", "ft_model.txt"]
    if config.use_augmented or config.discrete_mixup_ratio > 0.0: fnames.append("train_extra.csv")
    for fname in fnames:
        if not (DATA_ROOT / fname).exists():
            print(subprocess.Popen([f"aws s3 cp s3://nnfornlp/raw_data/jigsaw/{fname} {str(DATA_ROOT)}"],
                                   shell=True, stdout=subprocess.PIPE).stdout.read())

# Dataset

In [10]:
from allennlp.data.fields import (TextField, SequenceLabelField, LabelField, 
                                  MetadataField, ArrayField)

class JigsawLMDatasetReader(DatasetReader):
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None, # TODO: Handle mapping from BERT
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len
        
    def _clean(self, x: str) -> str:
        """
        Maps a word to its desired output. Will leave as identity for now.
        In the future, will change to denoising operation.
        """
        return x

    @overrides
    def text_to_instance(self, tokens: List[str]) -> Instance:
        sentence_field = TextField([Token(x) for x in tokens],
                                   self.token_indexers)
        fields = {"input": sentence_field}
        output_sentence_field = TextField([Token(self._clean(x)) for x in tokens],
                                          self.token_indexers)
        fields["output"] = output_sentence_field
        return Instance(fields)
    
    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        df = pd.read_csv(file_path)
        if config.testing: df = df.head(1000)
        for i, row in df.iterrows():
            yield self.text_to_instance(
                self.tokenizer(row["comment_text"]),
            )

In [11]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper, ELMoTokenCharactersIndexer
from allennlp.data.token_indexers import SingleIdTokenIndexer

if config.char_encoder == "cnn":
    token_indexer = ELMoTokenCharactersIndexer()
else:
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=True) # Temporary

_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words

def tokenizer(x: str):
        return ["[CLS]"] + [w.text for w in
                _spacy_tok(x)[:config.max_seq_len - 2]] + ["[SEP]"]

In [12]:
reader = JigsawLMDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer},
)
train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train_wo_val.csv",
                                                                          "val.csv",
                                                                          "test_proced.csv"])

1000it [00:03, 313.20it/s]
1000it [00:02, 441.27it/s]
1000it [00:02, 370.82it/s]


In [13]:
vars(train_ds[2].fields["input"])

{'tokens': [[CLS],
  How,
  can,
  my,
  comment,
  on,
  my,
  own,
  talk,
  page,
  WP,
  :,
  DISRUPT,
  wikipedia,
  ?,
  [SEP]],
 '_token_indexers': {'tokens': <allennlp.data.token_indexers.single_id_token_indexer.SingleIdTokenIndexer at 0x1a2c7fe358>},
 '_indexed_tokens': None,
 '_indexer_name_to_indexed_token': None}

### Build Vocab

In [14]:
from allennlp.data.vocabulary import Vocabulary

In [15]:
full_ds = train_ds + test_ds + val_ds
vocab = Vocabulary.from_instances(full_ds, tokens_to_add={"tokens": ["[MASK]"]},
                                  max_vocab_size=config.max_vocab_size)

100%|██████████| 3000/3000 [00:00<00:00, 12427.40it/s]


### Build word to indices mapping

In [16]:
if config.char_encoder == "cnn":
    # TODO: Speed up
    # See allennlp/data/token_indexers/elmo_indexer.py
    with timer("Building character indexes"):
        word_id_to_char_idxs = []
        freqs = []
        for w, freq in word_freqs.items():
            # TODO: Check for start/end of word symbols
            char_idxs = [258] + [int(c) for c in w.encode("utf-8")] + [259]
            word_id_to_char_idxs.append([x + 1 for x in char_idxs])
            freqs.append(freqs)
        word_id_to_char_idxs = np.array(word_id_to_char_idxs)
        freqs = np.array(freqs)

    word_id_to_char_idxs = torch.LongTensor(word_id_to_char_idxs)

### Build frequencies

TODO: Implement fast sampling

In [17]:
freqs = np.zeros(vocab.get_vocab_size())
for w, c in vocab._retained_counter["tokens"].items():
    freqs[vocab.get_token_index(w)] = c
freqs /= freqs.sum()
freqs **= 2 / 3

### Build Iterator

In [18]:
from allennlp.data.iterators import BucketIterator, DataIterator
iterator = BucketIterator(
        batch_size=config.batch_size, 
        biggest_batch_first=config.testing,
        sorting_keys=[("input", "num_tokens")],
)
iterator.index_with(vocab)

In [19]:
next(iterator(train_ds))["input"]["tokens"].shape

torch.Size([40, 128])

In [20]:
batch = next(iterator(train_ds))

In [21]:
batch

{'input': {'tokens': tensor([[   8,  243, 8678,  ...,  250, 8684,    9],
          [   8,    5,   17,  ...,    0,    0,    0],
          [   8,    5,   62,  ...,    3,  801,    9],
          ...,
          [   8,    5,   44,  ...,  153,    2,    9],
          [   8,    5,   62,  ...,    2,   25,    9],
          [   8,   38,  913,  ...,  135,   26,    9]])},
 'output': {'tokens': tensor([[   8,  243, 8678,  ...,  250, 8684,    9],
          [   8,    5,   17,  ...,    0,    0,    0],
          [   8,    5,   62,  ...,    3,  801,    9],
          ...,
          [   8,    5,   44,  ...,  153,    2,    9],
          [   8,    5,   62,  ...,    2,   25,    9],
          [   8,   38,  913,  ...,  135,   26,    9]])}}

# Building token embedder

In [22]:
if config.char_encoder == "cnn":
    from allennlp.modules.token_embedders.elmo_token_embedder import ElmoTokenEmbedder
    from allennlp.modules.elmo import _ElmoCharacterEncoder

    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
    char_encoder = _ElmoCharacterEncoder(
        options_file=options_file, 
        weight_file=weight_file,
        requires_grad=True
    )

# char_encoder

# sample_idxs = next(iterator(train_ds))["tokens"]["tokens"]

# char_encoder(sample_idxs)

In [23]:
if config.char_encoder == "boc":
    from torch.nn.modules.sparse import EmbeddingBag
    # TODO: Customize
    class FastTextEmbeddingBag(EmbeddingBag):
        def __init__(self, model_path):
            self.model = load_model(model_path)
            input_matrix = self.model.get_input_matrix()
            input_matrix_shape = input_matrix.shape
            super().__init__(input_matrix_shape[0], input_matrix_shape[1])
            self.weight.data.copy_(torch.FloatTensor(input_matrix))

        def forward(self, words):
            word_subinds = np.empty([0], dtype=np.int64)
            word_offsets = [0]
            for word in words:
                _, subinds = self.model.get_subwords(word)
                word_subinds = np.concatenate((word_subinds, subinds))
                word_offsets.append(word_offsets[-1] + len(subinds))
            word_offsets = word_offsets[:-1]
            ind = Variable(torch.LongTensor(word_subinds))
            offsets = Variable(torch.LongTensor(word_offsets))
            return super().forward(ind, offsets)

# Masked Language Model

In [24]:
from pytorch_pretrained_bert.modeling import (BertConfig, BertForMaskedLM, 
                                              BertEncoder, BertPooler, BertOnlyMLMHead)

In [25]:
bert_config = BertConfig(
        config.max_vocab_size, hidden_size=32, num_attention_heads=4,
        num_hidden_layers=4, intermediate_size=32 * 4,
)

In [26]:
bert_config

{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 32,
  "initializer_range": 0.02,
  "intermediate_size": 128,
  "max_position_embeddings": 512,
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "type_vocab_size": 2,
  "vocab_size": 40000
}

### The encoder

For now, temporarily use fasttext embeddings

In [27]:
from tqdm import tqdm
import warnings

def get_coefs(word,*arr): return word, np.asarray(arr, dtype='float32')

def get_fasttext_embeddings(model_path: str, vocab: Vocabulary):
    prog_bar = tqdm(open(model_path, encoding="utf8", errors='ignore'))
    prog_bar.set_description("Loading embeddings")
    embeddings_index = dict(get_coefs(*o.split(" ")) for o in prog_bar
                             if len(o)>100)
    all_embs = np.stack(embeddings_index.values())

    embeddings = np.zeros((bert_config.vocab_size + 5, 300))
    n_missing_tokens = 0
    prog_bar = tqdm(vocab.get_index_to_token_vocabulary().items())
    prog_bar.set_description("Creating matrix")
    for idx, token in prog_bar:
        if idx == 0: continue # keep padding as all zeros
        if idx == 1: continue # Treat unknown words as dropped words
        if token == "[MASK]":
            embeddings[idx, :] = np.random.randn(300) * 0.5
        if token not in embeddings_index:
            n_missing_tokens += 1
            if n_missing_tokens < 10:
                warnings.warn(f"Token {token} not in embeddings: did you change preprocessing?")
            if n_missing_tokens == 10:
                warnings.warn(f"More than {n_missing_tokens} missing, supressing warnings")
        else:
            embeddings[idx, :] = embeddings_index[token]
    
    if n_missing_tokens > 0:
        warnings.warn(f"{n_missing_tokens} in total are missing from embedding text file")
    return embeddings

In [28]:
ft_matrix = np.random.randn(bert_config.vocab_size + 5, 300) * 0.3
#get_fasttext_embeddings(str(DATA_ROOT / "ft_model.txt"), vocab)

In [29]:
def freeze(x):
    x.requires_grad = False
    if hasattr(x, "parameters"):
        for p in x.parameters: freeze(p)

In [30]:
from allennlp.modules import Embedding

In [31]:
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

In [32]:
class CustomEmbedding(Embedding):
    def __init__(self, num_embeddings, embedding_dim, bert_hidden_sz,
                 freeze_embeddings=False, dropout=0.1, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        if freeze_embeddings:
            freeze(self.weight)
        self.position_embeddings = nn.Embedding(config.max_seq_len, 
                                                bert_hidden_sz)
        self.linear = nn.Linear(embedding_dim, bert_hidden_sz,
                                bias=False) # Transform dimensions
        self.norm = LayerNorm(bert_hidden_sz)
        self.do = nn.Dropout(0.1)
        
    def forward(self, input_ids, token_type_ids):
        # We won't be using token types since we won't be predicting the next sentence
        seq_length = input_ids.size(1)
        position_ids = (torch.arange(seq_length, dtype=torch.long)
                             .to(input_ids.device)
                             .unsqueeze(0)
                             .expand_as(input_ids))
        word_embs = self.linear(super().forward(input_ids))
        position_embs = self.position_embeddings(position_ids)
        return self.do(self.norm(word_embs + position_embs))

In [33]:
sample_embs = CustomEmbedding(bert_config.vocab_size + 5, 
                              300,
                              bert_config.hidden_size,
                              weight=torch.FloatTensor(ft_matrix))

In [34]:
sample_embs

CustomEmbedding(
  (position_embeddings): Embedding(128, 32)
  (linear): Linear(in_features=300, out_features=32, bias=False)
  (norm): LayerNorm()
  (do): Dropout(p=0.1)
)

In [35]:
bert_encoder = BertEncoder(bert_config)

In [36]:
class CustomBert(nn.Module):
    def __init__(self, embeddings, encoder):
        super().__init__()
        self.embeddings = embeddings
        self.encoder = encoder
        
    def forward(self, input_ids, 
                token_type_ids=None, 
                attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=True)
        return encoded_layers

In [37]:
class BertMLMPooler(nn.Module):
    def forward(self, x: List[torch.FloatTensor]) -> torch.FloatTensor:
        return x[-1] # return final layer only

In [38]:
bert_model = CustomBert(sample_embs, bert_encoder)

In [39]:
import math
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class BertCustomLMPredictionHead(nn.Module):
    def __init__(self, config, embedding, output_logits=True):
        super().__init__()
        # Projections
        self.dense = nn.Linear(config.hidden_size, embedding.shape[1])
        self.transform_act_fn = gelu
        self.LayerNorm = LayerNorm(embedding.shape[1])
        
        # Predictions
        self.output_logits = output_logits
        if self.output_logits:
            self.decoder = nn.Linear(embedding.size(1), embedding.size(0), 
                                     bias=False)
            self.decoder.weight = embedding
            self.bias = nn.Parameter(torch.zeros(embedding.size(0)))

    def forward(self, hidden_states):
        if self.output_logits:
            hidden_states = self.dense(hidden_states)
            hidden_states = self.transform_act_fn(hidden_states)
            preds = self.LayerNorm(hidden_states)
            preds = self.decoder(preds) + self.bias
        else:
            preds = hidden_states
        return preds

In [None]:
bert_mlm_head = BertCustomLMPredictionHead(bert_config, sample_embs.weight,
                                           output_logits=True)

In [None]:
custom_model = nn.Sequential(
    bert_model,
    BertMLMPooler(),
    bert_mlm_head,
)

# Negative Sampling Loss

In [None]:
# Use a different decoder (TODO: Enable sharing of parameters)
# char_decoder = _ElmoCharacterEncoder(
#     options_file=options_file, 
#     weight_file=weight_file,
#     requires_grad=True
# )

In [None]:
class WordCorrection(nn.Module):
    """From the paper `Exploring the Limitations of Language Modeling`"""
    def __init__(self, hidden_sz: int, bottleneck_sz: int=128):
        super().__init__()
        self.l1 = nn.Linear(hidden_sz, bottleneck_sz)
        
    def forward(self, h: torch.FloatTensor, 
                corr: torch.FloatTensor):
        x = self.l1(h)
        return x @ corr

In [None]:
def build_unigram_noise(freq):
    """build the unigram noise from a list of frequency
    Parameters:
        freq: a tensor of #occurrences of the corresponding index
    Return:
        unigram_noise: a torch.Tensor with size ntokens,
        elements indicate the probability distribution
    """
    total = freq.sum()
    noise = freq / total
    assert abs(noise.sum() - 1) < 0.001
    return noise ** 0.75 # slight modification

In [None]:
from torch.distributions import Categorical

In [None]:
class ImportanceSamplingLoss(nn.Module):
    def __init__(self, embedding_generator: nn.Module,
                 probs: np.ndarray, k=10):
        super().__init__()
        self.embedding_generator = embedding_generator
        # TODO: Should this be according to the unigram probability?
        # Or should it be uniform?
        self.sampler = Categorical(probs=probs)
        # TODO: Compute samples in advance
        self._loss_func = nn.CrossEntropyLoss()
        self.k = k
    
    def get_negative_samples(self, bs) -> torch.LongTensor:
        neg = self.sampler.sample((bs, self.k, )) # TODO: Speed up??
        return neg
    
    def get_embeddings(self, idxs: torch.LongTensor) -> torch.FloatTensor:
        """Converts indexes into vectors"""
        return self.embedding_generator(idxs, None) # TODO: Implement general case
    
    def forward(self, y: torch.LongTensor, tgt):
        """
        Expects input of shape
        y: (batch * seq, feature_sz)
        tgt: (batch * seq, )
        """
        bs = y.size(0)
        neg_samples = self.get_negative_samples(bs)
        idxs = torch.cat([tgt.unsqueeze(1), neg_samples], dim=1)
        # TODO: More efficient implementation? (Ask TAs)
        dot_prods = torch.bmm(y.unsqueeze(1), 
                              self.get_embeddings(idxs).transpose(1, 2)).squeeze(1)
        return self._loss_func(dot_prods, torch.zeros(bs, dtype=torch.int64))

Testing

In [None]:
y = torch.randn((3, 7, bert_config.hidden_size)).view((-1, bert_config.hidden_size))
tgt = torch.randint(100, (3, 7)).view((-1, ))

In [None]:
cat = Categorical(probs=torch.rand(100, ))

In [None]:
neg_samples = cat.sample((y.size(0), 10))

In [None]:
idxs = torch.cat([tgt.unsqueeze(1), neg_samples], dim=1)

In [None]:
sample_embs(idxs, None).shape

torch.Size([21, 11, 32])

In [None]:
y.shape

torch.Size([21, 32])

In [None]:
torch.bmm(y.unsqueeze(1), sample_embs(idxs, None).transpose(1, 2)).squeeze(1).shape

torch.Size([21, 11])

In [None]:
loss = ImportanceSamplingLoss(
    sample_embs, probs=torch.rand(100),
)
loss(y, tgt)

tensor(9.5355, grad_fn=<NllLossBackward>)

# Training

In [None]:
from allennlp.models import Model

In [None]:
class Masker(nn.Module):
    def __init__(self, vocab: Vocabulary, 
                 noise_rate: float=0.15,
                 mask_rate=0.8,
                 replace_rate=0.1):
        super().__init__()
        self.vocab = vocab
        self.vocab_sz = vocab.get_vocab_size()
        self.mask_id = vocab.get_token_index("[MASK]")
        self.noise_rate = noise_rate
        self.mask_rate = mask_rate
        self.replace_rate = replace_rate
        
    def create_mask(self, x, ones_ratio, dtype=torch.uint8):
        return (torch.ones(x.shape, dtype=dtype, 
                           requires_grad=False)
                     .bernoulli(ones_ratio))

    def forward(self, x: torch.LongTensor) -> torch.LongTensor:
        if self.train and self.noise_rate > 0:
            mask = self.create_mask(x, self.noise_rate * self.mask_rate)
            x = x.masked_fill(mask, self.mask_id)
            if self.replace_rate > 0.:
                # this is techinically incorrect, since we might overwrite the mask tokens
                # but I guess it will do for now
                mask = self.create_mask(x, 1 - self.noise_rate * self.replace_rate,
                                        dtype=x.dtype)
                # TODO: Should we really sample uniformly?
                x = x * mask + torch.randint(self.vocab_sz, x.shape) * (1 - mask)
        return x

In [None]:
class MaskedCrossEntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self._loss = nn.CrossEntropyLoss(reduction="none")

    def forward(self, preds, tgts, mask=None) -> torch.tensor:
        if mask is None:
            return self._loss(preds, tgts).mean()
        else:
            # Is this reshaping really necessary? Seems like there would be a more elegant solution
            loss = self._loss(preds.view((-1, preds.size(-1))),
                              tgts.view((-1, )))
            n_elements = mask.sum()
            return (loss * mask.view((-1, )).float()).sum() / n_elements

In [None]:
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy

class MaskedLM(Model):
    def __init__(self, vocab: Vocabulary, model: nn.Module,
                loss: nn.Module, noise_rate=0.8):
        super().__init__(vocab)
        self.masker = Masker(vocab, noise_rate=noise_rate)
        self.accuracy = CategoricalAccuracy()
        self.model = model
        self.loss = loss
    
    @property
    def outputs_logits(self) -> bool:
        return self.model[-1].output_logits
    
    def forward(self, input: TensorDict, 
                output: TensorDict, **kwargs) -> TensorDict:
        mask = get_text_field_mask(input)
        x = self.masker(input["tokens"])
        logits = self.model(x)
        out_dict = {"loss": self.loss(logits, output["tokens"], mask=mask)}
        out_dict["logits"] = logits
        if self.outputs_logits:
            out_dict["accuracy"] = self.accuracy(logits, output["tokens"])
        return out_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        if self.outputs_logits:
            return {"accuracy": self.accuracy.get_metric(reset)}
        else:
            return {}

In [None]:
if config.loss == "masked_crossentropy":
    loss = MaskedCrossEntropyLoss()
elif config.loss == "crossentropy":
    _loss = nn.CrossEntropyLoss()
    def ce(y, t, mask=None): 
        return _loss(y.view((-1, y.size(-1))), t.view((-1, )))
    loss = ce
elif config.loss == "is":
    # TODO: Implement masking
    loss = ImportanceSamplingLoss(sample_embs, torch.FloatTensor(freqs))
else:
    raise ValueError("AAAAAAAAAAAAA")
masked_lm = MaskedLM(vocab, custom_model, loss, noise_rate=0.15)

Sanity checks

In [None]:
tokens = masked_lm.masker(batch["input"]["tokens"])

In [None]:
hidden_states = masked_lm.model[:2](tokens)

In [None]:
hidden_states

tensor([[[ 7.9257e-01,  3.3017e-01, -1.9341e+00,  ...,  2.2470e+00,
           8.4122e-04, -1.4128e+00],
         [ 7.8955e-01, -5.5334e-01, -8.7381e-01,  ...,  1.2595e+00,
          -1.3412e-01, -1.7660e+00],
         [ 2.9464e+00,  9.1522e-02,  6.1226e-01,  ...,  1.1923e+00,
           1.6033e+00, -5.0738e-01],
         ...,
         [ 3.5568e-01, -2.5305e+00,  7.2234e-01,  ...,  1.5066e+00,
          -3.3546e-01, -1.9514e-01],
         [ 1.9070e+00, -1.4500e+00,  6.1093e-01,  ..., -1.5661e-01,
           1.2132e+00, -3.2763e-01],
         [-1.5238e+00, -5.9411e-01,  5.3762e-01,  ...,  1.0543e-01,
          -1.8073e+00,  5.3456e-01]],

        [[ 7.9084e-01,  2.6551e-01, -2.2540e+00,  ...,  2.0175e+00,
          -1.1645e-01, -1.2852e+00],
         [ 5.6060e-01, -4.4838e-01, -2.5647e-01,  ...,  1.2451e+00,
          -4.1708e-01, -1.5916e+00],
         [ 2.6489e+00,  3.6799e-01,  3.7464e-01,  ...,  1.1297e+00,
           2.3959e+00, -6.0559e-01],
         ...,
         [-3.3035e-01, -2

In [None]:
masked_lm.model[2](hidden_states)

tensor([[[  2.3681,   2.3155,  -1.8459,  ...,   4.3093,  -0.5433,   4.8540],
         [ -1.8181,  -0.3749,  -0.1832,  ...,   5.3945,   2.9493,  -3.0373],
         [ -5.8419,   0.0406,   6.6328,  ...,  -5.5296,  -0.8188,   7.1563],
         ...,
         [  1.1825,   3.1556,   6.2975,  ...,   3.0837, -10.9357,   7.4305],
         [  5.6680,   4.5227,   3.6175,  ...,  -1.2479,  -7.5419,   6.1736],
         [ -5.2398,   4.8604,   5.4095,  ...,   9.5512,   8.7998,  -1.7505]],

        [[ -0.6243,   4.0400,  -2.1453,  ...,   2.2620,   0.6276,   4.3817],
         [  3.9622,  -2.1417,  -1.3703,  ...,  10.5944,   1.8270,  -5.2715],
         [ -2.1944,  -0.5615,   4.8347,  ...,  -5.5999,  -1.3051,   6.3563],
         ...,
         [  4.8018,   1.1742,   2.9209,  ...,   5.4262,  -2.5404,   5.8397],
         [  6.8405,   4.9000,   3.9885,  ...,  -1.9911,  -8.4598,   3.8191],
         [ -6.5504,   7.3194,   1.9782,  ...,   2.4225,   7.8605,  -0.0655]],

        [[ -0.3907,   3.9685,  -2.0745,  ...

Sanity checks

In [None]:
from allennlp.training import Trainer

optimizer = torch.optim.Adam(masked_lm.parameters(), lr=0.01)

trainer = Trainer(
    model=masked_lm,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=None,
    cuda_device=0 if torch.cuda.is_available() else -1,
    num_epochs=1,
)

You provided a validation dataset but patience was set to None, meaning that early stopping is disabled


In [None]:
trainer.train()

accuracy: 0.0473, loss: 12.3293 ||: 100%|██████████| 16/16 [01:22<00:00,  3.84s/it]
accuracy: 0.1600, loss: 7.7848 ||: 100%|██████████| 16/16 [00:33<00:00,  3.38s/it]


{'best_epoch': 0,
 'peak_cpu_memory_MB': 2642.382848,
 'training_duration': '00:01:56',
 'training_start_epoch': 0,
 'training_epochs': 0,
 'epoch': 0,
 'training_accuracy': 0.047279228855721396,
 'training_loss': 12.32925209403038,
 'training_cpu_memory_MB': 2642.382848,
 'validation_accuracy': 0.1599829889112903,
 'validation_loss': 7.784813791513443,
 'best_validation_accuracy': 0.1599829889112903,
 'best_validation_loss': 7.784813791513443}

In [None]:
masked_lm(**batch)["logits"]

tensor([[[ 1.2464e+01,  4.8041e-01,  9.9983e+00,  ..., -2.2404e-01,
          -1.3467e+00, -1.4002e+00],
         [ 1.0214e+01, -4.7934e-01,  1.1212e+01,  ...,  1.1572e+00,
          -1.2680e+00, -2.0720e+00],
         [ 1.1160e+01, -6.3993e-01,  1.0217e+01,  ...,  1.4380e-01,
          -7.9035e-01, -1.1506e+00],
         ...,
         [ 1.1929e+01,  5.9776e-01,  1.0373e+01,  ...,  9.4615e-01,
          -1.8114e+00, -2.3422e+00],
         [ 1.2395e+01,  2.2977e-01,  9.2476e+00,  ..., -7.4081e-01,
          -8.9187e-01, -9.6294e-01],
         [ 1.4495e+01,  1.3644e+00,  8.9255e+00,  ..., -8.7293e-01,
          -1.1923e+00, -8.4349e-01]],

        [[ 1.2144e+01,  5.2355e-01,  1.0376e+01,  ..., -4.4624e-01,
          -1.2708e+00, -1.6591e+00],
         [ 9.9819e+00, -3.7882e-01,  1.1470e+01,  ...,  1.2626e+00,
          -1.3926e+00, -2.4874e+00],
         [ 1.3420e+01,  4.3167e-01,  7.9241e+00,  ...,  7.5914e-01,
          -1.2522e+00, -1.0451e+00],
         ...,
         [ 1.4919e+01,  1

In [None]:
hidden_states = masked_lm.model[:2](tokens)

In [None]:
masked_lm(**batch)["logits"].argmax(2)

tensor([[ 7,  5,  4,  ...,  5, 43,  9],
        [ 7,  5,  0,  ...,  0,  0,  0],
        [ 7,  5,  5,  ..., 43,  7,  9],
        ...,
        [ 7,  5, 35,  ...,  9,  5,  9],
        [ 7,  5,  5,  ...,  7,  7,  9],
        [ 7,  5,  5,  ...,  5,  4,  9]])

In [None]:
trainer = Trainer(
    model=masked_lm,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=None,
    cuda_device=0 if torch.cuda.is_available() else -1,
    num_epochs=1,
)

You provided a validation dataset but patience was set to None, meaning that early stopping is disabled


In [None]:
trainer.train()

accuracy: 0.0852, loss: 7.4609 ||:  31%|███▏      | 5/16 [00:37<01:19,  7.25s/it]

In [None]:
masked_lm(**batch)["logits"].argmax(2)

In [None]:
trainer = Trainer(
    model=masked_lm,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=None,
    cuda_device=0 if torch.cuda.is_available() else -1,
    num_epochs=3,
)

In [None]:
trainer.train()

# Manually Check Outputs

TODO: Implement manual checks for negative sampling loss as well

In [None]:
def to_words(arr):
    arr = to_np(arr)
    

# Predict and Evaluate