# Sequence masking with PyTorch

Resources:
- Masking:
    - [Difference between `src_mask` and `src_key_padding_mask` in PyTorch Transformer layers (from StackOverflow)](https://stackoverflow.com/questions/62170439/difference-between-src-mask-and-src-key-padding-mask)
    - [UvA (University of Amsterdam) DL tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html)
    - [Judit Ács' blog post](https://juditacs.github.io/2018/12/27/masked-attention.html) (but watch out: **the attention matrix is not square!**)
- Masked language modeling:
    - [Kaggle notebook](https://www.kaggle.com/code/mojammel/masked-language-model-with-pytorch-transformer) (very similar to the PyTorch tutorial below)
    - [MLM with BERT blog post](https://towardsdatascience.com/masked-language-modelling-with-bert-7d49793e5d2c)
    - [PyTorch transformer tutorial](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)
    - [Tutorial with TensorFlow](https://keras.io/examples/nlp/masked_language_modeling/#create-bert-model-pretraining-model-for-masked-language-modeling) (this is actually a good reference, with tensor shapes etc.)
    - [Tutorial with PyTorch](https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial)

In [12]:
import sys
import torch

sys.path.append('../../modules/')

from models import TransformerClassifier, FFNN

%load_ext autoreload
%autoreload 2


## Data

Generate a vocabulary.

In [2]:
# Vocabulary size.
q = 8

vocab = torch.arange(q).to(dtype=torch.int64)
mask_idx = vocab.max() + 1

# Enlarge the vocabulary with the special token(s), here only <mask>.
vocab = torch.hstack([vocab, mask_idx.unsqueeze(0)])

vocab

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

Generate sequences of tokens.

**Note:** in this setting all sequences have the same length, so no padding is ever needed.

In [3]:
# Tree depth.
l = 8

batch_size = 64
seq_len = 2 ** l

sequences = torch.randint(q, (batch_size, seq_len))

sequences

tensor([[1, 1, 2,  ..., 4, 7, 2],
        [3, 1, 4,  ..., 7, 5, 0],
        [5, 5, 7,  ..., 7, 3, 1],
        ...,
        [5, 3, 2,  ..., 5, 5, 6],
        [2, 4, 4,  ..., 6, 7, 6],
        [2, 3, 0,  ..., 3, 4, 0]])

## Building attention masks

Given the sequences, generate trainable input embeddings for them (token embeddings only: we skip the positional encoding here since it isn't essential for this example).

In [4]:
hidden_dim = 128

embedding_layer = torch.nn.Embedding(
    num_embeddings=vocab.shape[0],
    embedding_dim=hidden_dim
)

input_sequence_embeddings = embedding_layer(sequences)

input_sequence_embeddings.shape

torch.Size([64, 256, 128])

Instantiate a Transformer encoder.

**Note:** by convention, we keep the batch dimension first (`batch_first=True`).

In [5]:
# A single encoder layer to be used in the full stack.
encoder_layer = torch.nn.TransformerEncoderLayer(
    d_model=hidden_dim,
    nhead=1,
    dim_feedforward=2048,
    batch_first=True
)

# Stack of encoder layers.
transformer_encoder = torch.nn.TransformerEncoder(
    encoder_layer,
    num_layers=1
)

encoder_output = transformer_encoder(input_sequence_embeddings)

encoder_output.shape



torch.Size([64, 256, 128])
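As for the `batch_first=True` convention noted above: it only changes the expected input layout, `(batch_size, seq_len, hidden_dim)` instead of PyTorch's default `(seq_len, batch_size, hidden_dim)`, not the computation itself. A quick sanity check, as a sketch with toy sizes (all names and sizes here are illustrative): two layers sharing the same weights, one batch-first and one sequence-first, should agree once the input is transposed. `dropout=0.0` is an assumption made only so that both forward passes are deterministic.

```python
import torch

torch.manual_seed(0)
B, L, D = 4, 10, 16  # toy batch size, sequence length, hidden size

# Two encoder layers that differ only in the expected input layout.
# dropout=0.0 makes the forward pass deterministic in training mode.
layer_bf = torch.nn.TransformerEncoderLayer(
    d_model=D, nhead=1, dim_feedforward=32, dropout=0.0, batch_first=True
)
layer_sf = torch.nn.TransformerEncoderLayer(
    d_model=D, nhead=1, dim_feedforward=32, dropout=0.0, batch_first=False
)
# batch_first does not change the parameter names, so the weights can be copied.
layer_sf.load_state_dict(layer_bf.state_dict())

x = torch.randn(B, L, D)                              # (batch, seq, hidden)
out_bf = layer_bf(x)                                  # (batch, seq, hidden)
out_sf = layer_sf(x.transpose(0, 1)).transpose(0, 1)  # fed as (seq, batch, hidden)

print(out_bf.shape, torch.allclose(out_bf, out_sf, atol=1e-5))
```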

Generate masked sequences from the original ones to perform masked language modeling (MLM): each token in each sequence is independently replaced by the `<mask>` token (index `mask_idx`) with probability `masking_rate`.

**Mask:** we pass the mask to the encoder as `src_key_padding_mask`, so it must have shape `(batch_size, seq_len)`. If boolean, it contains `True` at the masked positions (which attention then ignores as keys) and `False` elsewhere. In practice, it's obtained by simply comparing the masked sequences with the mask token index.

In [6]:
masked_sequences = sequences.clone()

masking_rate = 0.1

for i in range(sequences.shape[0]):
    for j in range(sequences.shape[1]):
        if torch.rand(1) < masking_rate:
            masked_sequences[i, j] = mask_idx

mask = (masked_sequences == mask_idx)
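The double loop above draws one random number per token in Python, which gets slow for larger batches. A vectorized sketch of the same masking scheme (the helper `mlm_mask` is hypothetical, not part of the notebook's `models` module): draw all uniforms at once, threshold them at `masking_rate` and fill with `masked_fill`. The random draws differ from the loop's, but each token is still masked independently with probability `masking_rate`.

```python
from typing import Optional, Tuple

import torch


def mlm_mask(
    sequences: torch.Tensor,
    mask_idx: int,
    masking_rate: float,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Independently replace each token by `mask_idx` with probability `masking_rate`.

    Returns the masked sequences and the boolean mask (True = masked).
    """
    # One uniform draw per token, all at once (no Python loops).
    mask = torch.rand(sequences.shape, generator=generator) < masking_rate
    # masked_fill returns a new tensor, so the original sequences are untouched.
    return sequences.masked_fill(mask, mask_idx), mask


# Usage with the same sizes as above (8 proper tokens, <mask> = 8).
g = torch.Generator().manual_seed(0)
demo_sequences = torch.randint(8, (64, 256), generator=g)
demo_masked, demo_mask = mlm_mask(demo_sequences, mask_idx=8, masking_rate=0.1, generator=g)

print(demo_mask.float().mean())  # close to masking_rate = 0.1
```

Since the original sequences never contain the mask index, the returned mask coincides with `masked_sequences == mask_idx` as computed in the cell above.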

**Questions:**
- Should we pass the masked sequences or the original ones (always along with the mask) to the encoder?

A: we should pass the masked sequences to the encoder, then use the decoder to produce logits for every token in every sequence and compare them with the ground truth (the non-masked sequences) via the cross-entropy loss.

What to do with the **masked embeddings**?

- If we pass the masked sequences, the `<mask>` token needs its own input embedding, i.e. it has to be modeled explicitly. Could we declare it as the padding token (`padding_idx`) when instantiating the `Embedding` layer for the input embeddings?

A: the `<mask>` token should be explicitly modeled as part of the vocabulary. It shouldn't be among the possible **predicted** tokens, though.
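On the `padding_idx` option mentioned in the question: in `torch.nn.Embedding`, the row at `padding_idx` is initialized to zeros and receives no gradient, so it stays a fixed vector during training. This is why declaring `<mask>` as padding would prevent the model from learning a representation for it. A small sketch comparing it with a regular embedding (toy sizes, names illustrative):

```python
import torch

torch.manual_seed(0)
toy_vocab_size, toy_dim, toy_mask_idx = 9, 4, 8  # 8 proper tokens + <mask> at index 8

emb_pad = torch.nn.Embedding(toy_vocab_size, toy_dim, padding_idx=toy_mask_idx)
emb_std = torch.nn.Embedding(toy_vocab_size, toy_dim)

tokens = torch.tensor([[1, toy_mask_idx, 3]])

for label, emb in [("padding_idx", emb_pad), ("regular", emb_std)]:
    emb(tokens).sum().backward()  # any scalar loss will do
    print(
        f"{label:>11}: <mask> vector all zeros = {bool((emb.weight[toy_mask_idx] == 0).all())}, "
        f"<mask> gradient all zeros = {bool((emb.weight.grad[toy_mask_idx] == 0).all())}"
    )
```

The `padding_idx` embedding should report `True` for both checks, the regular one `False` for both.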

- If we pass the original sequences, the masked tokens keep their original embeddings: is the mask enough to tell the model to ignore them?

A: the mask alone is not sufficient, so we should pass the masked sequences to the encoder.
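A sketch of why the mask alone doesn't suffice (toy sizes, random vectors standing in for token embeddings, `dropout=0.0` for determinism): `src_key_padding_mask` only prevents *other* positions from attending to the masked ones. The output at a masked position is still computed from its own input embedding (via its query and the residual connection), so with the original sequences the model could simply read the true token back.

```python
import torch

torch.manual_seed(0)
B, L, D = 2, 6, 16

layer = torch.nn.TransformerEncoderLayer(
    d_model=D, nhead=1, dim_feedforward=32, dropout=0.0, batch_first=True
)

x = torch.randn(B, L, D)                 # stand-in for the *original* token embeddings
kpm = torch.zeros(B, L, dtype=torch.bool)
kpm[:, 2] = True                         # position 2 is "masked" in every sequence

# Same input except for the embedding at the masked position,
# i.e. as if a different original token sat there.
x_alt = x.clone()
x_alt[:, 2] = torch.randn(B, D)

out = layer(x, src_key_padding_mask=kpm)
out_alt = layer(x_alt, src_key_padding_mask=kpm)

same_unmasked = torch.allclose(out[~kpm], out_alt[~kpm], atol=1e-6)
same_masked = torch.allclose(out[kpm], out_alt[kpm], atol=1e-6)
print("unmasked outputs unchanged:", same_unmasked)  # True: others can't see position 2
print("masked outputs unchanged:  ", same_masked)    # False: position 2 sees its own input
```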

- What about the gradient? Should we "disconnect" the masked token from the compute graph?

A: still not clear. The attention mask should avoid connecting the embeddings of the masked tokens with the loss, but it's just a guess (and what about residual connections?).
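The gradient question can at least be probed empirically. A toy sketch (embedding, one encoder layer with `src_key_padding_mask`, linear output head; sizes, names and the fixed masking pattern are illustrative assumptions) measuring how much gradient reaches the `<mask>` row of the embedding table when the loss is averaged over the masked positions versus the unmasked ones:

```python
import torch

torch.manual_seed(0)
n_proper, toy_mask_idx, D = 8, 8, 16     # 8 proper tokens, <mask> at index 8

emb = torch.nn.Embedding(n_proper + 1, D)
layer = torch.nn.TransformerEncoderLayer(
    d_model=D, nhead=1, dim_feedforward=32, dropout=0.0, batch_first=True
)
head = torch.nn.Linear(D, n_proper)      # predicts proper tokens only
ce = torch.nn.CrossEntropyLoss(reduction="none")

targets = torch.randint(n_proper, (2, 6))  # "original" sequences
is_masked = torch.zeros(2, 6, dtype=torch.bool)
is_masked[:, [1, 4]] = True                # fixed masking pattern for the demo
inputs = targets.masked_fill(is_masked, toy_mask_idx)


def mask_row_grad_norm(select: torch.Tensor) -> float:
    """Norm of the gradient on the <mask> embedding for a loss averaged over `select`."""
    emb.zero_grad()
    logits = head(layer(emb(inputs), src_key_padding_mask=is_masked))
    per_token = ce(logits.permute(0, 2, 1), targets)  # (batch, seq)
    per_token[select].mean().backward()
    return emb.weight.grad[toy_mask_idx].norm().item()


print("loss over masked positions:  ", mask_row_grad_norm(is_masked))
print("loss over unmasked positions:", mask_row_grad_norm(~is_masked))
```

Since unmasked positions never attend to masked keys, the second value is expected to be exactly zero, while the first should be nonzero: in this setup the `<mask>` embedding is trained only through the outputs at the masked positions themselves (their query, residual connection and feed-forward path).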

What about the **full model** (with a decoder as well)?

- How do we tell which tokens in the sequences are masked, i.e. which ones to predict? The decoder won't know anything about the masking, so it'll have no way to distinguish between a masked token and an unmasked one.

A: it's true that the decoder won't know explicitly which tokens are masked, but it will indirectly: we pass the masked sequences as input, and in the end we'll probably select only the loss terms corresponding to the predictions for the masked tokens.

- Should the decoder predict one masked token at a time (how would we select them? Mask them one by one?) or all masked tokens together (in which case, see the previous point)?

A: the decoder should predict for all the tokens, masked and non-masked, so that an input of shape `(batch_size, seq_len)` yields a final output of shape `(batch_size, seq_len, vocab_size)`, the last dimension holding the logits over the vocabulary (restricted to the tokens that can be predicted, e.g. excluding `<mask>`). This means that we effectively predict the entire sequence every time, and computing the cross-entropy loss against the non-masked sequences gives a tensor of loss values of shape `(batch_size, seq_len)`. We can then choose (it's not clear which is the right choice, though!) whether to compute the final loss as the mean over all the entries of this tensor or only over those corresponding to the masked tokens (by applying the mask to the tensor of loss values).

In [7]:
encoder_output_masked = transformer_encoder(
    embedding_layer(masked_sequences),
    src_key_padding_mask=mask
)

encoder_output_masked

tensor([[[-1.6878,  0.2216,  0.4210,  ...,  1.1461, -0.0612, -0.3649],
         [-1.7294,  0.3450,  0.3150,  ...,  1.2427, -0.0351, -0.5012],
         [-0.1237, -1.6333, -0.2061,  ...,  0.5824,  0.6188,  0.5238],
         ...,
         [-0.2125, -1.2664, -0.4888,  ...,  1.8632,  0.5314,  0.0293],
         [-3.1080,  0.1778, -0.1151,  ...,  0.0234, -1.3789, -0.0511],
         [-0.1504, -1.5409, -0.1529,  ...,  0.7163,  0.9293,  0.3315]],

        [[-0.3278,  0.1928,  0.7338,  ..., -0.4091,  0.5358, -0.3875],
         [-1.6437,  0.3576,  0.1702,  ...,  1.3893,  0.1358, -0.4145],
         [ 1.6211,  0.9031, -2.3291,  ...,  1.2343,  1.4894,  0.7451],
         ...,
         [-3.3954,  0.1630, -0.5129,  ...,  0.0550, -1.2563,  0.0566],
         [-0.5306, -0.6755,  1.5678,  ..., -0.7571, -0.3739,  1.0078],
         [ 0.2511,  0.1878,  0.0744,  ...,  0.7619,  0.0958, -2.2071]],

        [[-0.4421, -0.3665,  1.6370,  ..., -0.7502, -0.0875,  0.8530],
         [-0.2229, -0.2698,  1.5723,  ..., -0
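On the difference between `src_key_padding_mask` and `src_mask` (see the StackOverflow link in the resources): a key padding mask of shape `(batch_size, seq_len)` is equivalent to a per-sequence attention mask that blocks the masked key columns for every query row. A sketch using `torch.nn.MultiheadAttention` directly (toy sizes, two heads to make the head dimension visible), where a 3-D `attn_mask` must have shape `(batch_size * num_heads, seq_len, seq_len)`:

```python
import torch

torch.manual_seed(0)
B, L, D, H = 3, 5, 8, 2  # toy batch size, sequence length, embedding size, heads

mha = torch.nn.MultiheadAttention(embed_dim=D, num_heads=H, batch_first=True)
x = torch.randn(B, L, D)

key_padding_mask = torch.rand(B, L) < 0.3
key_padding_mask[:, 0] = False  # keep at least one visible key per sequence

# Per-sequence (B, L, L) mask: no query row can attend to a masked key column.
attn_mask = key_padding_mask[:, None, :].expand(B, L, L)
# Repeat it for every head: (B * H, L, L), sequence-major ordering.
attn_mask = attn_mask.repeat_interleave(H, dim=0)

out_kpm, _ = mha(x, x, x, key_padding_mask=key_padding_mask)
out_am, _ = mha(x, x, x, attn_mask=attn_mask)

print(torch.allclose(out_kpm, out_am, atol=1e-6))
```

`repeat_interleave` matches the ordering `MultiheadAttention` uses for the flattened batch-and-head dimension (all heads of sequence 0 first, then those of sequence 1, and so on).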

The final output layer maps each token in each sequence to a set of logits (or probabilities) over the "proper" vocabulary (excluding special tokens), so it has tensors of shape `(batch_size, seq_len, hidden_dim)` as input and outputs tensors of shape `(batch_size, seq_len, vocab_size)`.

In [8]:
output_layer = torch.nn.Linear(
    in_features=hidden_dim,
    out_features=q
)

output_logits = output_layer(encoder_output_masked)
output_probs = torch.nn.Softmax(dim=-1)(output_logits)

output_logits.shape, output_probs.shape, output_probs.sum(dim=-1)

(torch.Size([64, 256, 8]),
 torch.Size([64, 256, 8]),
 tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
        grad_fn=<SumBackward1>))

Observations on the loss function (**should masking be considered at this stage?**):
- We need the model (output layer) to output the predicted logits for each token in each sequence, i.e. a tensor of shape `(batch_size, seq_len, vocab_size)`, where the last dimension holds the logits over the vocabulary.
- PyTorch's `CrossEntropyLoss` takes the logits and the targets as input, the latter either as class indices or as class probabilities (e.g. one-hot vectors). Here, once the predicted logits are in the shape PyTorch expects (see the point below), the targets can be passed directly as class indices, with no one-hot encoding.
- Without any aggregation (`reduction='none'`), we get one loss value for each token in each sequence, i.e. a tensor of loss values of shape `(batch_size, seq_len)`. PyTorch expects the predicted logits to have shape `(batch_size, n_classes, [additional dims])`, so the last two dimensions of the output logits need to be swapped.
- **Guess:** probably the final loss should be computed only for the masked tokens, so we have to apply the mask to the loss tensor before aggregating the values.

In [78]:
loss_fn = torch.nn.CrossEntropyLoss(
    reduction='none'  # Default: 'mean'.
)

loss_tensor = loss_fn(
    torch.permute(output_logits, dims=(0, 2, 1)),
    sequences
)

loss_tensor, loss_tensor.shape

(tensor([[2.0316, 2.1514, 2.6266,  ..., 2.3532, 2.7209, 2.6511],
         [2.4452, 1.9075, 2.0654,  ..., 2.7135, 1.5421, 2.3395],
         [1.5090, 1.5543, 2.7127,  ..., 2.7202, 1.7771, 1.9542],
         ...,
         [1.5834, 2.4822, 2.6721,  ..., 2.6365, 1.6373, 1.5273],
         [2.7293, 2.3288, 2.1734,  ..., 1.4918, 2.7459, 1.7663],
         [2.6808, 2.4252, 2.3681,  ..., 1.8896, 2.4122, 2.3091]],
        grad_fn=<ViewBackward0>),
 torch.Size([64, 256]))
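Permuting the logits to `(batch_size, vocab_size, seq_len)` is one way to match what `CrossEntropyLoss` expects; another common one is flattening batch and sequence into a single dimension. A sketch with random toy tensors checking that both give the same per-token losses:

```python
import torch

torch.manual_seed(0)
B, L, C = 4, 7, 8                   # batch, sequence length, number of classes

logits = torch.randn(B, L, C)       # model output: (batch, seq, classes)
targets = torch.randint(C, (B, L))  # true token indices: (batch, seq)

ce = torch.nn.CrossEntropyLoss(reduction="none")

# Option 1: move the class dimension to position 1 -> (batch, classes, seq).
loss_permuted = ce(logits.permute(0, 2, 1), targets)        # (batch, seq)

# Option 2: flatten batch and sequence -> (batch * seq, classes).
loss_flat = ce(logits.reshape(-1, C), targets.reshape(-1))  # (batch * seq,)

print(loss_permuted.shape, torch.allclose(loss_permuted, loss_flat.view(B, L)))
```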

In [99]:
# Average the loss only over the masked tokens; drop the
# `[mask]` indexing (!) to average over the whole sequence.
loss = loss_tensor[mask].mean()

loss

tensor(2.1178, grad_fn=<MeanBackward0>)
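An alternative to `reduction='none'` plus boolean indexing: set the targets at the non-masked positions to the loss's `ignore_index` (`-100` by default) and keep the default `'mean'` reduction, which then averages only over the remaining (masked) tokens. This is the labeling convention used, for instance, by Hugging Face's MLM data collator. A sketch on random toy tensors showing that the two give the same value:

```python
import torch

torch.manual_seed(0)
B, L, C = 4, 16, 8  # batch, sequence length, number of classes

logits = torch.randn(B, L, C)
targets = torch.randint(C, (B, L))
is_masked = torch.rand(B, L) < 0.3
is_masked[0, 0] = True  # make sure at least one token is masked

# Approach used above: per-token losses, then average over masked positions only.
per_token = torch.nn.CrossEntropyLoss(reduction="none")(logits.permute(0, 2, 1), targets)
loss_indexed = per_token[is_masked].mean()

# Alternative: mark non-masked targets with ignore_index (-100 by default)
# and let the default 'mean' reduction skip them.
labels = targets.masked_fill(~is_masked, -100)
loss_ignored = torch.nn.CrossEntropyLoss()(logits.permute(0, 2, 1), labels)

print(loss_indexed, loss_ignored, torch.allclose(loss_indexed, loss_ignored))
```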

## Building a model for MLM

In [10]:
model = TransformerClassifier(
    seq_len=seq_len,
    embedding_size=hidden_dim,
    n_tranformer_layers=1,
    n_heads=1,
    vocab_size=vocab.shape[0],
    n_special_tokens=1,
    embedding_agg=None,
    decoder_hidden_sizes=[],
    decoder_activation='identity',  # 'identity' --> Output logits
    decoder_output_activation='softmax'
)

# Output shape: (batch_size, seq_len, vocab_size) (excluding
# the special tokens from the vocabulary, which shouldn't be
# predicted).
model(masked_sequences, src_key_padding_mask=mask)

tensor([[[0.1094, 0.1842, 0.1594,  ..., 0.0278, 0.1166, 0.1847],
         [0.0845, 0.2108, 0.1485,  ..., 0.0202, 0.1471, 0.1551],
         [0.2447, 0.0803, 0.0553,  ..., 0.0861, 0.2138, 0.0924],
         ...,
         [0.1541, 0.2808, 0.0412,  ..., 0.2141, 0.0656, 0.0513],
         [0.2992, 0.0872, 0.1047,  ..., 0.0527, 0.0574, 0.1500],
         [0.3020, 0.1788, 0.0449,  ..., 0.0920, 0.0599, 0.0727]],

        [[0.1809, 0.0874, 0.1018,  ..., 0.0678, 0.0632, 0.2991],
         [0.0956, 0.1852, 0.2221,  ..., 0.0157, 0.1728, 0.1224],
         [0.1112, 0.1424, 0.0613,  ..., 0.0654, 0.1803, 0.1098],
         ...,
         [0.1823, 0.0813, 0.0844,  ..., 0.0697, 0.1160, 0.1943],
         [0.2163, 0.0684, 0.1060,  ..., 0.0962, 0.0939, 0.1125],
         [0.1620, 0.1368, 0.0704,  ..., 0.1239, 0.1167, 0.1024]],

        [[0.1672, 0.0525, 0.1762,  ..., 0.0643, 0.2061, 0.1205],
         [0.1483, 0.0446, 0.1251,  ..., 0.0762, 0.1920, 0.1334],
         [0.2635, 0.0605, 0.0832,  ..., 0.0741, 0.1427, 0.

In [117]:
loss_fn = torch.nn.CrossEntropyLoss(
    reduction='none'
)

# Average the loss only over the masked tokens; drop the
# `[mask]` indexing (!) to average over the whole sequence.
loss = loss_fn(
    torch.permute(model(masked_sequences, src_key_padding_mask=mask), (0, 2, 1)),
    sequences
)[mask].mean()

loss

tensor(2.0824, grad_fn=<MeanBackward0>)
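A caveat on the cell above: the model is built with `decoder_output_activation='softmax'` and, judging by the printed values, appears to return probabilities, whereas `CrossEntropyLoss` expects unnormalized logits and applies `log_softmax` internally, so softmax ends up being applied twice. A toy sketch of the consequence with 8 classes (as `q` above): even a perfectly confident, correct prediction yields a loss of about 1.274 instead of 0. When a model outputs probabilities, `NLLLoss` on their logarithm is the matching loss (numerically less stable than `CrossEntropyLoss` on raw logits).

```python
import math

import torch

C = 8                       # number of predictable classes (q in the notebook)
target = torch.tensor([3])

# A perfectly confident and correct prediction, expressed as probabilities.
probs = torch.zeros(1, C)
probs[0, 3] = 1.0

# Mismatch: CrossEntropyLoss treats `probs` as logits and applies log_softmax again.
ce_on_probs = torch.nn.CrossEntropyLoss()(probs, target)

# Matching loss for probabilities: NLLLoss on their log (clamped to avoid log(0)).
nll_on_log_probs = torch.nn.NLLLoss()(torch.log(probs.clamp_min(1e-12)), target)

print(ce_on_probs.item(), math.log(C - 1 + math.e) - 1)  # both about 1.274
print(nll_on_log_probs.item())                           # about 0
```

Applied to the cell above, this would mean either passing `torch.log(...)` of the model output to `torch.nn.NLLLoss(reduction='none')`, or configuring `TransformerClassifier` to return raw logits, if its options allow that (not verified here).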