# MLM training starting from a trained classification model

__Objective:__ start from a model trained for classification (root inference) and swap its head for one that can perform masked language modeling (MLM), so that the model can be trained further with MLM. This is essentially the reverse of the usual pre-training → fine-tuning workflow.

In [None]:
import sys
import torch

sys.path.append('../../modules/')

from utilities import read_data
from pytorch_utilities import load_checkpoint
from models import TransformerClassifier, replace_classification_head_with_decoder
from model_evaluation_tree_language import compute_accuracy
from masked_language_modeling import mask_sequences, compute_masked_accuracy

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

%load_ext autoreload
%autoreload 2

In [None]:
# Training and validation data files. NOTE(review): the numeric suffix presumably
# encodes the generation parameters returned by `read_data` (q, k, sigma, epsilon)
# — TODO confirm.
DATA_PATH, VALIDATION_DATA_PATH = (
    '../../data/mlm_data/slrm_data/labeled_data_fixed_4_8_1.0_0.00000.npy',
    '../../data/mlm_data/slrm_data/labeled_data_fixed_validation_4_8_1.0_0.00000.npy',
)

Load data.

In [None]:
# Load training and validation sets. The first four return values
# (`q`, `k`, `sigma`, `epsilon`) of the validation file are discarded; they are
# assumed to match the training ones.
q, k, sigma, epsilon, roots_train, leaves_train, rho = read_data(DATA_PATH, seed=0)
_, _, _, _, roots_test, leaves_test, rho_validation = read_data(VALIDATION_DATA_PATH, seed=0)

# Keep only a subset of the validation set. Leaves and roots must be truncated
# together so that each sequence stays paired with its own root label.
n_test = 5000
leaves_test = leaves_test[:n_test, :]
roots_test = roots_test[:n_test]

# Data preprocessing.
# Leaves: integer token sequences fed to the model.
leaves_train = torch.from_numpy(leaves_train).to(device=device, dtype=torch.int64)
leaves_test = torch.from_numpy(leaves_test).to(device=device, dtype=torch.int64)

# Roots: one-hot (float32) targets over the `q` root symbols.
roots_train = torch.nn.functional.one_hot(
    torch.from_numpy(roots_train).to(dtype=torch.int64), num_classes=q
).to(device=device, dtype=torch.float32)
roots_test = torch.nn.functional.one_hot(
    torch.from_numpy(roots_test).to(dtype=torch.int64), num_classes=q
).to(device=device, dtype=torch.float32)

Generate a vocabulary.

In [None]:
# Vocabulary of the symbols seen by the classifier: 0, ..., q-1 (int64 token ids).
vocab = torch.arange(q, dtype=torch.int64)

vocab

Instantiate a classification model (untrained, for simplicity). In practice, a trained model would be loaded with `load_checkpoint`.

**Note:** the classification model doesn't know anything about the `<mask>` token.

In [None]:
# Sizes derived from the data: each input sequence has 2**k leaves, and the
# vocabulary holds one id per symbol.
seq_len = int(2 ** k)
embedding_size = 128
vocab_size = len(vocab)

# Transformer classifier producing q logits per input sequence.
classification_model = TransformerClassifier(
    vocab_size=vocab_size,
    seq_len=seq_len,
    # Special tokens (if any) are assumed to occupy the last `n_special_tokens` indices.
    n_special_tokens=0,
    positional_encoding=True,
    embedding_size=embedding_size,
    embedding_agg='mean',
    n_tranformer_layers=2,  # 4 was noted as a good value.
    n_heads=1,
    encoder_dim_feedforward=2 * embedding_size,
    decoder_hidden_sizes=[64],  # Noted as a good value.
    decoder_activation='relu',  # Noted as a good value.
    decoder_output_activation='identity',
).to(device=device)

Test the untrained classification model on the validation data (accuracy should be $\sim 1/q$).

In [None]:
# Output shape: (batch_shape, q) (q logits for each sample in the batch).
# Forward pass on the validation leaves, without gradient tracking.
# Output shape: (batch_size, q) (q logits for each sample in the batch).
with torch.no_grad():
    val_pred = classification_model(leaves_test)

# The classifier predicts the root of each sequence, so accuracy is measured
# against the one-hot root labels, not against the input leaves. Slicing the
# targets keeps them aligned with the (possibly truncated) evaluated inputs.
val_pred.shape, compute_accuracy(val_pred, roots_test[:leaves_test.shape[0]])

Replace the classification head with an MLM head.

See: https://discuss.pytorch.org/t/expand-an-existing-embedding-and-linear-layer-nan-loss-value/55670/2

Notes:
- Assuming the vocabulary used for the classification model did not contain the `<mask>` token, when switching to an MLM model the input embedding layer is enlarged by one extra token (the mask). The weights corresponding to the other tokens (already seen) are copied from the original layer.
- We assume the classification model has seen symbols `0, ..., q-1` and that the mask token corresponds to symbol `q`.

In [None]:
# Swap the classification head for an MLM decoder head. Per the assumptions
# stated in this notebook, symbols 0, ..., q-1 keep their learned embeddings
# and the `<mask>` token is expected at index q.
mlm_model = replace_classification_head_with_decoder(
    original_model=classification_model,
    device=device,
    n_classes=q,
    decoder_hidden_dim=[64],
)

In [None]:
# Output shape: (batch_size, seq_len, embedding_size).
with torch.no_grad():
    mlm_pred = mlm_model(leaves_test)

mlm_pred.shape

In [None]:
# The function assumes that the mask token corresponds to value `q`.
# Index of the `<mask>` token: the first id after the classifier's vocabulary,
# i.e. `q` since `vocab` is 0, ..., q-1. `mask_sequences` assumes the mask token
# has value `q`. `vocab.max()` avoids the Python-level element-by-element
# iteration of the builtin `max`, and `int(...)` gives a plain Python integer.
mask_idx = int(vocab.max()) + 1

leaves_test_masked, mask = mask_sequences(
    sequences=leaves_test,
    mask_rate=0.5,
    # Same shape/dtype/device as `leaves_test`, filled with the mask id.
    reshaped_mask_idx=torch.full_like(leaves_test, mask_idx),
    device=device,
    single_mask=False,
)

# NOTE(review): output presumably holds per-token logits of shape
# (batch_size, seq_len, n_classes) — TODO confirm.
# Under `no_grad` the output is not tracked, so no `.detach()` is needed.
with torch.no_grad():
    mlm_pred_masked = mlm_model(leaves_test_masked)

mlm_pred_masked.shape, compute_masked_accuracy(mlm_pred_masked, leaves_test, mask)