In [1]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
%matplotlib inline
from overrides import overrides

In [2]:
import time
from contextlib import contextmanager

class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)

@contextmanager
def timer(name):
    t0 = time.time()
    yield
    print(f'[{name}] done in {time.time() - t0:.0f} s')
    
import functools
import traceback
import sys

def get_ref_free_exc_info():
    "Free traceback from references to locals/globals to avoid circular reference leading to gc.collect() unable to reclaim memory"
    type, val, tb = sys.exc_info()
    traceback.clear_frames(tb)
    return (type, val, tb)

def gpu_mem_restore(func):
    "Reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted"
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            type, val, tb = get_ref_free_exc_info() # must!
            raise type(val).with_traceback(tb) from None
    return wrapper

def ifnone(a: Any, alt: Any): return alt if a is None else a

In [3]:
# for papermill
testing = True
debugging = False
seed = 1
char_encoder = "fasttext"
computational_batch_size = 16
batch_size = 16
lr = 4e-3
lr_schedule = "slanted_triangular"
epochs = 6 if not testing else 1
hidden_sz = 128
dataset = "jigsaw"
n_classes = 6
max_seq_len = 512
download_data = False
ft_model_path = "../data/jigsaw/ft_model.txt"
max_vocab_size = 40000
dropouti = 0.2
dropoutw = 0.0
dropoute = 0.2
dropoutr = 0.3 # TODO: Implement
val_ratio = 0.0
use_augmented = False
freeze_embeddings = True
mixup_ratio = 0.0
discrete_mixup_ratio = 0.0
attention_bias = True
weight_decay = 0.
bias_init = True
neg_splits = 1
num_layers = 2
rnn_type = "lstm"
pooling_type = "augmented_multipool" # attention or multipool or augmented_multipool
model_type = "standard"
use_word_level_features = True
use_sentence_level_features = True
bucket = True
compute_thres_on_test = True
find_lr = False
permute_sentences = False
run_id = None

In [4]:
# TODO: Can we make this play better with papermill?
config = Config(
    testing=testing,
    debugging=debugging,
    seed=seed,
    char_encoder=char_encoder,
    computational_batch_size=computational_batch_size,
    batch_size=batch_size,
    lr=lr,
    lr_schedule=lr_schedule,
    epochs=epochs,
    hidden_sz=hidden_sz,
    dataset=dataset,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    ft_model_path=ft_model_path,
    max_vocab_size=max_vocab_size,
    dropouti=dropouti,
    dropoutw=dropoutw,
    dropoute=dropoute,
    dropoutr=dropoutr,
    val_ratio=val_ratio,
    use_augmented=use_augmented,
    freeze_embeddings=freeze_embeddings,
    attention_bias=attention_bias,
    weight_decay=weight_decay,
    bias_init=bias_init,
    neg_splits=neg_splits,
    num_layers=num_layers,
    rnn_type=rnn_type,
    pooling_type=pooling_type,
    model_type=model_type,
    use_word_level_features=use_word_level_features,
    use_sentence_level_features=use_sentence_level_features,
    mixup_ratio=mixup_ratio,
    discrete_mixup_ratio=discrete_mixup_ratio,
    bucket=bucket,
    compute_thres_on_test=compute_thres_on_test,
    permute_sentences=permute_sentences,
    find_lr=find_lr,
    run_id=run_id,
)

In [5]:
T = TypeVar("T")
TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]  # pylint: disable=invalid-name

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util
from allennlp.data.dataset_readers import DatasetReader

In [6]:
DATA_ROOT = Path("../data/jigsaw/")

# Dataset

In [7]:
from allennlp.data.fields import (TextField, SequenceLabelField, LabelField, 
                                  MetadataField, ArrayField)

class JigsawLMDatasetReader(DatasetReader):
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None, # TODO: Handle mapping from BERT
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len
        
    def _clean(self, x: str) -> str:
        """
        Maps a word to its desired output. Will leave as identity for now.
        In the future, will change to denoising operation.
        """
        return x

    @overrides
    def text_to_instance(self, tokens: List[str]) -> Instance:
        sentence_field = TextField([Token(x) for x in tokens],
                                   self.token_indexers)
        fields = {"input": sentence_field}
        output_sentence_field = TextField([Token(self._clean(x)) for x in tokens],
                                          self.token_indexers)
        fields["output"] = output_sentence_field
        return Instance(fields)
    
    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        df = pd.read_csv(file_path)
        if config.testing: df = df.head(1000)
        for i, row in df.iterrows():
            yield self.text_to_instance(
                self.tokenizer(row["comment_text"]),
            )

In [13]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper, ELMoTokenCharactersIndexer
from allennlp.data.token_indexers import SingleIdTokenIndexer

if config.char_encoder == "cnn":
    token_indexer = ELMoTokenCharactersIndexer()
else:
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=True) # Temporary

_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words

def tokenizer(x: str):
        return [w.text for w in
                _spacy_tok(x)[:config.max_seq_len]]

In [14]:
reader = JigsawLMDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer},
)
train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train_wo_val.csv",
                                                                          "val.csv",
                                                                          "test_proced.csv"])

1000it [00:04, 240.69it/s]
1000it [00:03, 325.28it/s]
1000it [00:03, 294.16it/s]


### Build Vocab

In [15]:
from allennlp.data.vocabulary import Vocabulary

In [16]:
full_ds = train_ds + test_ds + val_ds
vocab = Vocabulary.from_instances(full_ds, tokens_to_add={"tokens": ["[MASK]"]},
                                  max_vocab_size=config.max_vocab_size)

02/23/2019 15:59:19 - INFO - allennlp.data.vocabulary -   Fitting token dictionary from dataset.
100%|██████████| 3000/3000 [00:00<00:00, 8527.69it/s]


### Build word to indices mapping

In [17]:
if config.char_encoder == "cnn":
    # TODO: Speed up
    with timer("Building character indexes"):
        word_id_to_char_idxs = []
        freqs = []
        for w, freq in word_freqs.items():
            char_idxs = token_indexer.tokens_to_indices([Token(w)], None, "tokens")["tokens"][0]
            word_id_to_char_idxs.append(char_idxs)
            freqs.append(freqs)
        word_id_to_char_idxs = np.array(word_id_to_char_idxs)
        freqs = np.array(freqs)

    word_id_to_char_idxs = torch.LongTensor(word_id_to_char_idxs)

### Build Iterator

In [18]:
from allennlp.data.iterators import BucketIterator, DataIterator
iterator = BucketIterator(
        batch_size=config.batch_size, 
        biggest_batch_first=config.testing,
        sorting_keys=[("input", "num_tokens")],
)
iterator.index_with(vocab)

In [19]:
next(iterator(train_ds))["input"]["tokens"].shape

torch.Size([8, 512])

In [20]:
batch = next(iterator(train_ds))

# Building token embedder

In [21]:
if config.char_encoder == "cnn":
    from allennlp.modules.token_embedders.elmo_token_embedder import ElmoTokenEmbedder
    from allennlp.modules.elmo import _ElmoCharacterEncoder

    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
    char_encoder = _ElmoCharacterEncoder(
        options_file=options_file, 
        weight_file=weight_file,
        requires_grad=True
    )

# char_encoder

# sample_idxs = next(iterator(train_ds))["tokens"]["tokens"]

# char_encoder(sample_idxs)

# Masked Language Model

In [22]:
from pytorch_pretrained_bert.modeling import (BertConfig, BertForMaskedLM, 
                                              BertEncoder, BertPooler, BertOnlyMLMHead)

In [23]:
bert_config = BertConfig(
        config.max_vocab_size, hidden_size=512, num_attention_heads=4,
        num_hidden_layers=6, intermediate_size=512 * 4,
)

In [24]:
bert_config

{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "max_position_embeddings": 512,
  "num_attention_heads": 4,
  "num_hidden_layers": 6,
  "type_vocab_size": 2,
  "vocab_size": 40000
}

### The encoder

For now, temporarily use fasttext embeddings

In [25]:
from tqdm import tqdm
import warnings

def get_coefs(word,*arr): return word, np.asarray(arr, dtype='float32')

def get_fasttext_embeddings(model_path: str, vocab: Vocabulary):
    prog_bar = tqdm(open(model_path, encoding="utf8", errors='ignore'))
    prog_bar.set_description("Loading embeddings")
    embeddings_index = dict(get_coefs(*o.split(" ")) for o in prog_bar
                             if len(o)>100)
    all_embs = np.stack(embeddings_index.values())

    embeddings = np.zeros((bert_config.vocab_size + 5, 300))
    n_missing_tokens = 0
    prog_bar = tqdm(vocab.get_index_to_token_vocabulary().items())
    prog_bar.set_description("Creating matrix")
    for idx, token in prog_bar:
        if idx == 0: continue # keep padding as all zeros
        if idx == 1: continue # Treat unknown words as dropped words
        if token == "[MASK]":
            embeddings[idx, :] = np.random.randn(300) * 0.5
        if token not in embeddings_index:
            n_missing_tokens += 1
            if n_missing_tokens < 10:
                warnings.warn(f"Token {token} not in embeddings: did you change preprocessing?")
            if n_missing_tokens == 10:
                warnings.warn(f"More than {n_missing_tokens} missing, supressing warnings")
        else:
            embeddings[idx, :] = embeddings_index[token]
    
    if n_missing_tokens > 0:
        warnings.warn(f"{n_missing_tokens} in total are missing from embedding text file")
    return embeddings

In [26]:
ft_matrix = get_fasttext_embeddings(str(DATA_ROOT / "ft_model.txt"), vocab)

Loading embeddings: : 317458it [00:27, 11704.95it/s]
  # This is added back by InteractiveShellApp.init_path()
Creating matrix: 100%|██████████| 18864/18864 [00:00<00:00, 178514.85it/s]


In [27]:
def freeze(x):
    x.requires_grad = False
    if hasattr(x, "parameters"):
        for p in x.parameters: freeze(p)

In [48]:
from allennlp.modules import Embedding

In [51]:
class CustomEmbedding(Embedding):
    def __init__(self, num_embeddings, embedding_dim, bert_hidden_sz,
                 freeze_embeddings=False, dropout=0.1, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        if freeze_embeddings:
            freeze(self.weight)
        self.position_embeddings = nn.Embedding(config.max_seq_len, 
                                                bert_hidden_sz)
        self.linear = nn.Linear(embedding_dim, bert_hidden_sz,
                                bias=False) # Transform dimensions
        self.norm = LayerNorm(bert_hidden_sz)
        self.do = nn.Dropout(0.1)
        
    def forward(self, input_ids, token_type_ids):
        # We won't be using token types since we won't be predicting the next sentence
        seq_length = input_ids.size(1)
        position_ids = (torch.arange(seq_length, dtype=torch.long)
                             .to(input_ids.device)
                             .unsqueeze(0)
                             .expand_as(input_ids))
        word_embs = self.linear(super().forward(input_ids))
        position_embs = self.position_embeddings(position_ids)
        return self.do(self.norm(word_embs + position_embs))

In [52]:
sample_embs = CustomEmbedding(bert_config.vocab_size + 5, 
                              300,
                              bert_config.hidden_size,
                              weight=torch.FloatTensor(ft_matrix))

In [53]:
sample_embs

CustomEmbedding(
  (position_embeddings): Embedding(512, 512)
  (linear): Linear(in_features=300, out_features=512, bias=False)
  (norm): LayerNorm()
  (do): Dropout(p=0.1)
)

In [54]:
bert_encoder = BertEncoder(bert_config)

In [55]:
class CustomBert(nn.Module):
    def __init__(self, embeddings, encoder):
        super().__init__()
        self.embeddings = embeddings
        self.encoder = encoder
        
    def forward(self, input_ids, 
                token_type_ids=None, 
                attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=True)
        return encoded_layers

In [56]:
class BertMLMPooler(nn.Module):
    def forward(self, x: List[torch.FloatTensor]) -> torch.FloatTensor:
        return x[-1] # return final layer only

In [57]:
bert_model = CustomBert(sample_embs, bert_encoder)

In [58]:
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

In [59]:
import math
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class BertCustomLMPredictionHead(nn.Module):
    def __init__(self, config, embedding):
        super().__init__()
        # Projections
        self.dense = nn.Linear(config.hidden_size, embedding.shape[1])
        self.transform_act_fn = gelu
        self.LayerNorm = LayerNorm(embedding.shape[1])
        
        # Predictions
        self.decoder = nn.Linear(embedding.size(1), embedding.size(0), 
                                 bias=False)
        self.decoder.weight = embedding
        self.bias = nn.Parameter(torch.zeros(embedding.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        preds = self.decoder(hidden_states) + self.bias
        return preds

In [60]:
bert_mlm_head = BertCustomLMPredictionHead(bert_config, sample_embs.weight)

In [61]:
custom_model = nn.Sequential(
    bert_model,
    BertMLMPooler(),
    bert_mlm_head,
)

# Training

In [62]:
from allennlp.models import Model

In [63]:
class Masker(nn.Module):
    # TODO: Implement copying
    def __init__(self, vocab: Vocabulary, mask_rate: float=0.15):
        super().__init__()
        self.mask_id = vocab.get_token_index("[MASK]")
        self.mask_rate = mask_rate

    def forward(self, x: torch.LongTensor) -> torch.LongTensor:
        mask = torch.ones(x.shape, dtype=torch.uint8).bernoulli(1 - self.mask_rate)
        return x.masked_fill(mask, self.mask_id)

In [64]:
class MaskedLM(Model):
    def __init__(self, vocab: Vocabulary, model: nn.Module):
        super().__init__(vocab)
        self.masker = Masker(vocab)
        self.model = model
        self.loss = nn.CrossEntropyLoss()
    
    def forward(self, input: TensorDict, 
                output: TensorDict, **kwargs) -> TensorDict:
        x = self.masker(input["tokens"])
        logits = self.model(x)
        out_dict = {"loss": self.loss(logits.view((-1, logits.size(-1))),
                                                  output["tokens"].view((-1, )))}
        out_dict["logits"] = logits
        return out_dict

In [65]:
masked_lm = MaskedLM(vocab, custom_model)

In [66]:
out_dict = masked_lm(**batch)

In [67]:
out_dict

{'loss': tensor(38.7347, grad_fn=<NllLossBackward>),
 'logits': tensor([[[ 0.0000,  0.0000,  0.6921,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.9607,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.2008,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  2.2032,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.2016,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.8598,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.0000,  0.0000,  0.3411,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  2.5390,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.2102,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  2.4451,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.0797,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0840,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.0000,  0.0000, -1.3008,

In [68]:
from allennlp.training import Trainer

optimizer = torch.optim.Adam(masked_lm.parameters(), lr=config.lr)

trainer = Trainer(
    model=masked_lm,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=None,
    cuda_device=0 if torch.cuda.is_available() else -1,
    num_epochs=config.epochs,
)



In [69]:
trainer.train()

02/23/2019 16:08:53 - INFO - allennlp.training.trainer -   Beginning training.
02/23/2019 16:08:53 - INFO - allennlp.training.trainer -   Epoch 0/0
02/23/2019 16:08:53 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 5675.966464
02/23/2019 16:08:53 - INFO - allennlp.training.trainer -   Training
loss: 9.9748 ||: 100%|██████████| 63/63 [06:00<00:00,  2.77s/it] 
02/23/2019 16:14:53 - INFO - allennlp.training.trainer -   Validating
loss: 7.4051 ||: 100%|██████████| 63/63 [01:33<00:00,  5.44s/it]
02/23/2019 16:16:26 - INFO - allennlp.training.trainer -                     Training |  Validation
02/23/2019 16:16:26 - INFO - allennlp.training.trainer -   cpu_memory_MB |  5675.966  |       N/A
02/23/2019 16:16:26 - INFO - allennlp.training.trainer -   loss          |     9.975  |     7.405
02/23/2019 16:16:26 - INFO - allennlp.training.trainer -   Epoch duration: 00:07:33


{'peak_cpu_memory_MB': 5675.966464,
 'training_duration': '00:07:33',
 'training_start_epoch': 0,
 'training_epochs': 0,
 'epoch': 0,
 'training_loss': 9.974837212335496,
 'training_cpu_memory_MB': 5675.966464,
 'validation_loss': 7.405061123863099,
 'best_epoch': 0,
 'best_validation_loss': 7.405061123863099}

# Negative Sampling Loss

In [39]:
# Use a different decoder (TODO: Enable sharing of parameters)
# char_decoder = _ElmoCharacterEncoder(
#     options_file=options_file, 
#     weight_file=weight_file,
#     requires_grad=True
# )

In [70]:
class WordCorrection(nn.Module):
    """From the paper `Exploring the Limitations of Language Modeling`"""
    def __init__(self, hidden_sz: int, bottleneck_sz: int=128):
        super().__init__()
        self.l1 = nn.Linear(hidden_sz, bottleneck_sz)
        
    def forward(self, h: torch.FloatTensor, 
                corr: torch.FloatTensor):
        x = self.l1(h)
        return x @ corr

TODO: Implement importance sampling

In [71]:
def build_unigram_noise(freq):
    """build the unigram noise from a list of frequency
    Parameters:
        freq: a tensor of #occurrences of the corresponding index
    Return:
        unigram_noise: a torch.Tensor with size ntokens,
        elements indicate the probability distribution
    """
    total = freq.sum()
    noise = freq / total
    assert abs(noise.sum() - 1) < 0.001
    return noise ** 0.75 # slight modification

In [72]:
from torch.distributions import Categorical

In [73]:
cat = Categorical(probs=torch.rand(100))

In [159]:
class ImportanceSamplingLoss(nn.Module):
    def __init__(self, embedding_generator: nn.Module,
                 probs: np.ndarray, k=10):
        super().__init__()
        self.embedding_generator = embedding_generator
        self.sampler = Categorical(probs=probs)
        self._loss_func = nn.CrossEntropyLoss()
        self.k = k
    
    def get_negative_samples(self, bs) -> torch.LongTensor:
        neg = self.sampler.sample((bs, self.k, ))
        return neg
    
    def get_embeddings(self, idxs: torch.LongTensor) -> torch.FloatTensor:
        """Converts indexes into vectors"""
        return self.embedding_generator(idxs, None) # TODO: Implement general case
    
    def forward(self, y: torch.LongTensor, tgt):
        """
        Expects input of shape
        y: (batch * seq, feature_sz)
        tgt: (batch * seq, )
        """
        bs = y.size(0)
        neg_samples = self.get_negative_samples(bs)
        idxs = torch.cat([tgt.unsqueeze(1), neg_samples], dim=1)
        # TODO: More efficient implementation? (Ask TAs)
        dot_prods = torch.bmm(y.unsqueeze(1), 
                              self.get_embeddings(idxs).transpose(1, 2)).squeeze(1)
        return self._loss_func(dot_prods, torch.zeros(bs, dtype=torch.int64))

Testing

In [160]:
y = torch.randn((3, 7, 512)).view((-1, 512))
tgt = torch.randint(100, (3, 7)).view((-1, ))

In [161]:
neg_samples = cat.sample((y.size(0), 10))

In [162]:
idxs = torch.cat([tgt.unsqueeze(1), neg_samples], dim=1)

In [163]:
sample_embs(idxs, None).shape

torch.Size([21, 11, 512])

In [164]:
y.shape

torch.Size([21, 512])

In [165]:
torch.bmm(y.unsqueeze(1), sample_embs(idxs, None).transpose(1, 2)).squeeze(1).shape

torch.Size([21, 11])

In [166]:
loss = ImportanceSamplingLoss(
    sample_embs, probs=torch.rand(100),
)

In [167]:
loss(y, tgt)

tensor(27.9522, grad_fn=<NllLossBackward>)