# Use transformers with padded sequences

In [None]:
import os
import sys
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
import torch

sys.path.append('../../modules/')

from logger_tree_language import get_logger
from tree_generation_nontrivial_topology import pad_sequence
from tree_language_vocab import TreeLanguageVocab
from models import TransformerClassifier

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Notebook-wide logger.
logger = get_logger('transformers_with_padding')

%load_ext autoreload
%autoreload 2

## Load and preprocess data

In [None]:
# Pickled dataset to load. The last four underscore-separated fields of the
# file name (before `.pkl`) are parsed below as q, l, sigma and epsilon.
DATA_PATH = '../../data/topological_tree_data/topological_tree_data_4_5_1.0_0.0.pkl'

Load data.

In [None]:
# The four trailing underscore-separated fields of the file name (with the
# `.pkl` extension stripped) hold the data parameters, in this order.
file_stem = DATA_PATH.split('.pkl')[0]
q, l, sigma, epsilon = file_stem.split('_')[-4:]

q, l = int(q), int(l)
sigma, epsilon = float(sigma), float(epsilon)

# Maximum sequence length, derived from `l`.
max_seq_len = 2 ** l

logger.info(f'Loading data from: {DATA_PATH}')
logger.info(f'q = {q} | l = {l} | sigma = {sigma} | epsilon = {epsilon}')

# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open(DATA_PATH, 'rb') as data_file:
    data = pickle.load(data_file)

Preprocess data and train-test split.

In [None]:
# Vocabulary with the `q` regular symbols plus the special <pad> and <mask>
# tokens.
vocab = TreeLanguageVocab(
    n_symbols=q,
    use_pad_token=True,
    use_mask_token=True
)

# Each element of `data` is indexed as (leaves, root): element 0 is the
# sequence that gets padded, element 1 is the classification target.
roots = np.array([t[1] for t in data])
leaves = np.vstack([
    pad_sequence(t[0], max_seq_len, vocab('<pad>'))
    for t in data
])

# Shuffle leaves and roots with the SAME permutation before splitting, so that
# the test set is a random subset rather than the tail of the stored data.
# (The permutation used to be computed but never applied.)
n_samples = leaves.shape[0]
shuffled_indices = np.random.permutation(n_samples)

leaves = leaves[shuffled_indices]
roots = roots[shuffled_indices]

leaves = torch.from_numpy(leaves).to(device=device).to(dtype=torch.int64)
roots = torch.from_numpy(roots).to(device=device).to(dtype=torch.int64)

# 80/20 train-test split. Slicing by an explicit train count (instead of
# negative indices) stays correct when the test set is empty: `x[:-0]` would
# yield an empty train set and `x[-0:]` would put everything in the test set.
n_samples_test = int(0.2 * n_samples)
n_samples_train = n_samples - n_samples_test

leaves_train = leaves[:n_samples_train, ...]
leaves_test = leaves[n_samples_train:, ...]

roots_train = roots[:n_samples_train, ...]
roots_test = roots[n_samples_train:, ...]

leaves_train.shape, roots_train.shape, leaves_test.shape, roots_test.shape

## Masked language modelling

MLM with padded sequences works as with "all complete" sequences, except for the facts that:
- At inference time we need to tell the model which tokens are padded in each sequence (via a mask).
- We should only mask tokens that are NOT padded.

**Notes:**
- When using padded sequences, a mask **MUST** be passed to the model during its forward pass to identify all the padded tokens in all the sequences in the batch.
- The mask must have the same shape as the data used in the forward pass, `(batch_size, max_seq_len)`, and must be boolean with `True` in the positions corresponding to the padding tokens and `False` everywhere else.
- The mask must be passed as the `src_key_padding_mask` argument of the model's forward pass (`__call__` method).

In [None]:
# Hyperparameters of the masked-language-modelling transformer.
embedding_size = 128
vocab_size = len(vocab)

# Keyword arguments for the model constructor. Values marked "Good" are
# settings noted as working well.
model_params = {
    'seq_len': max_seq_len,
    'embedding_size': embedding_size,
    'n_tranformer_layers': 2,  # Good: 4
    'n_heads': 1,
    'vocab_size': vocab_size,
    'encoder_dim_feedforward': 2 * embedding_size,
    'positional_encoding': True,
    # We assume the special tokens correspond to the last
    # `n_special_tokens` indices.
    'n_special_tokens': len(vocab.special_symbols),
    # NOTE(review): presumably keeps one output per token, as needed for
    # MLM - confirm in `models`.
    'embedding_agg': None,
    'decoder_hidden_sizes': [64],  # Good: [64]
    'decoder_activation': 'relu',  # Good: 'relu'
    'decoder_output_activation': 'identity',
}

# Build the model and move it to the selected device.
model = TransformerClassifier(**model_params)
model = model.to(device=device)

In [None]:
# Test of a forward pass with masking on the padded positions.
batch_size = 32

# Take the first `batch_size` test sequences and their root targets.
batch = leaves_test[:batch_size]
targets = roots_test[:batch_size]

# Boolean mask, `True` exactly where the batch holds the padding token.
pad_token_idx = vocab('<pad>')
batch_src_key_padding_mask = batch == pad_token_idx

model(
    batch,
    src_key_padding_mask=batch_src_key_padding_mask
).detach()

**Masked language modelling:** just add the usual MLM procedure on top of the previous one.

In [None]:
from masked_language_modeling import mask_sequences, compute_masked_accuracy

In [None]:
# Tensor shaped like the batch and filled with the <mask> token index.
mask_token_idx = vocab('<mask>')
mask_fill = torch.ones(batch.shape, dtype=batch.dtype) * mask_token_idx
mask_fill = mask_fill.to(device=device)

# Mask roughly 10% of the symbols; the padding mask is passed so that padded
# positions can be excluded from masking.
masked_batch, masked_symbols_mask = mask_sequences(
    sequences=batch,
    mask_rate=0.1,
    reshaped_mask_idx=mask_fill,
    device=device,
    src_key_padding_mask=batch_src_key_padding_mask,
    single_mask=False
)

# Note: having a fraction of masked symbols per sequence close to `mask_rate`
#       is more difficult if padded sequences are used, as effectively the
#       sequences we can actually mask are shorter than `max_seq_len`.
is_mask_token = masked_batch == mask_token_idx
masked_fraction_per_sequence = is_mask_token.to(dtype=torch.float32).mean(dim=-1)
masked_fraction_per_sequence.mean().detach()

In [None]:
# Display the batch after masking (notebook cell output).
masked_batch

In [None]:
# MLM inference: run the model on the masked batch, again flagging the padded
# positions through the key padding mask. The output is detached from the
# autograd graph since it is only inspected here.
mlm_output = model(masked_batch, src_key_padding_mask=batch_src_key_padding_mask)
masked_pred = mlm_output.detach()

masked_pred

In [None]:
# Per-position loss (no reduction) so that only the masked positions can be
# selected afterwards.
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')

# CrossEntropyLoss expects the class dimension at position 1, hence the swap
# of the last two axes of the predictions.
per_position_loss = loss_fn(masked_pred.permute(0, 2, 1), batch)

# Average the loss over the masked symbols only.
loss = per_position_loss[masked_symbols_mask].mean().detach().cpu().numpy()

masked_accuracy = compute_masked_accuracy(masked_pred, batch, masked_symbols_mask)
masked_accuracy = masked_accuracy.detach().cpu().numpy()

print(f'Loss: {loss} | Masked accuracy: {masked_accuracy}')

## Root inference

For root inference with padded sequences, we just need to tell the model which tokens are padded at inference time via a mask.

In [None]:
from model_evaluation_tree_language import compute_accuracy

In [None]:
# Hyperparameters of the root-inference (sequence classification) transformer.
embedding_size = 128
vocab_size = len(vocab)

# Keyword arguments for the model constructor. Values marked "Good" are
# settings noted as working well.
root_inference_model_params = {
    'seq_len': max_seq_len,
    'embedding_size': embedding_size,
    'n_tranformer_layers': 2,  # Good: 4
    'n_heads': 1,
    'vocab_size': vocab_size,
    'encoder_dim_feedforward': 2 * embedding_size,
    'positional_encoding': True,
    # We assume the special tokens correspond to the last
    # `n_special_tokens` indices.
    'n_special_tokens': len(vocab.special_symbols),
    # NOTE(review): presumably flattens the token embeddings into a single
    # vector per sequence before the decoder - confirm in `models`.
    'embedding_agg': 'flatten',
    'decoder_hidden_sizes': [64],  # Good: [64]
    'decoder_activation': 'relu',  # Good: 'relu'
    'decoder_output_activation': 'identity',
}

# Build the model and move it to the selected device.
root_inference_model = TransformerClassifier(**root_inference_model_params)
root_inference_model = root_inference_model.to(device=device)

In [None]:
# Root inference on the (unmasked) batch; the padded positions are flagged
# through the key padding mask. The output is detached from the autograd graph
# since it is only inspected here.
root_output = root_inference_model(batch, src_key_padding_mask=batch_src_key_padding_mask)
classif_pred = root_output.detach()

classif_pred

In [None]:
classification_loss_fn = torch.nn.CrossEntropyLoss()

# Can work with targets that are either one-hot encoded
# or with their natural encoding. Here the one-hot (float) encoding over the
# `q` classes is built once and used for both the loss and the accuracy.
one_hot_targets = torch.nn.functional.one_hot(targets, num_classes=q)
one_hot_targets = one_hot_targets.to(dtype=torch.float32)

classification_loss = classification_loss_fn(classif_pred, one_hot_targets).detach()
classification_accuracy = compute_accuracy(classif_pred, one_hot_targets).detach()

print(
    f"Classification loss: {classification_loss}"
    f" | Classification accuracy: {classification_accuracy}"
)