# Switch the head of a pretrained model

__Objective:__ load a model pretrained on masked language modeling (MLM) and replace its decoder head with a newly initialized classification head, keeping the weights of everything except the new head frozen.

In [176]:
import os
import sys
from copy import deepcopy
import torch

sys.path.append('../../modules/')

from pytorch_utilities import load_checkpoint
from plotting import plot_training_history
from models import FFNN
from training import train_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [124]:
# Directory holding the checkpoints (`*.pt` files) of the MLM pretraining run.
MODEL_DIR = "../../models/mlm_pretraining_1/"

Select the model checkpoint to load.

In [125]:
# Map each checkpoint's epoch (parsed from the `..._<epoch>.pt` filename) to its
# filename. The selected checkpoint is then looked up exactly instead of via a
# substring match (with which e.g. epoch 0 would also match `..._10.pt`).
checkpoint_files = {
    int(f.split('_')[-1].split('.')[0]): f
    for f in os.listdir(MODEL_DIR)
    if f.endswith('.pt')
}

if not checkpoint_files:
    raise FileNotFoundError(f'No `.pt` checkpoint found in {MODEL_DIR}')

checkpoint_epochs = sorted(checkpoint_files)

# Latest checkpoint by default; pick another element of `checkpoint_epochs`
# to load an earlier one.
selected_checkpoint_epoch = checkpoint_epochs[-1]

checkpoint_id = checkpoint_files[selected_checkpoint_epoch]

Load model.

In [126]:
# `load_checkpoint` yields three objects, unpacked here as the model, an
# optimizer and the training history; only the model is needed below.
model, optimizer, training_history = load_checkpoint(
    MODEL_DIR, checkpoint_id, device=device
)

# To inspect the pretraining curves, uncomment:
# plot_training_history(training_history)



In [127]:
# Quick test on the output shape (assuming q=4 and l=8): a batch of 32 random
# token sequences whose length is the one the model was built for.
test_batch = torch.randint(0, 4, (32, model.seq_len), device=device)

# No gradients are needed for a shape check.
with torch.no_grad():
    test_output = model(test_batch)

test_output.shape

torch.Size([32, 256, 4])

Structure of the model (a `TransformerClassifier`, i.e. essentially an encoder-only transformer with a fully-connected decoder), listed in order (each item corresponds to an attribute):
- `input_embedding` (a trainable `Embedding` layer).
- `positional_embedding` (a `PositionalEncoding`, if positional encoding is used, otherwise the identity operation).
- `transformer_encoder` (a stack of `TransformerEncoderLayer`s).
- `embedding_agg_layer` (an operation on the latent representations provided by the encoder).
- `decoder` (a fully-connected network mapping the chosen aggregation of the tokens' latent representations into the output).

**Note:** the parameter count is inflated by the **single** `TransformerEncoderLayer` that is instantiated and then passed to the `TransformerEncoder` constructor. Because it is stored on the model as the `encoder_layer` attribute, its parameters are counted as parameters of the model, even though only the copies of it held by the `TransformerEncoder` are actually used.

In [128]:
def _n_params(module):
    """Total number of scalar parameters of `module`, its children included."""
    return sum(p.numel() for p in module.parameters())


submodules = [
    'input_embedding',
    'positional_embedding',
    'encoder_layer',
    'transformer_encoder',
    'embedding_agg_layer',
    'decoder',
]

for name in submodules:
    print(name, _n_params(getattr(model, name)))

n_params_total = _n_params(model)

print('\nTotal n parameters (with overcounting of `TransformerEncoderLayer`):', n_params_total)
print(
    '\nTotal n parameters (without overcounting of `TransformerEncoderLayer`):',
    n_params_total - _n_params(model.encoder_layer)
)

input_embedding 640
positional_embedding 0
encoder_layer 132480
transformer_encoder 264960
embedding_agg_layer 0
decoder 8516

Total n parameters (with overcounting of `TransformerEncoderLayer`): 406596

Total n parameters (without overcounting of `TransformerEncoderLayer`): 274116


Replace the aggregation operation and the decoder with new ones and freeze the weights of the encoder part.

In [169]:
def replace_decoder_with_classification_head(
        original_model,
        n_classes,
        device,
        head_hidden_dim=(64,),
        head_activation='relu',
        head_output_activation='softmax'
    ):
    """
    Return a copy of `original_model` whose decoder is replaced by a newly
    initialized classification head.

    `original_model` itself is not modified: it is deep-copied first.

    Parameters
    ----------
    original_model : torch.nn.Module
        Pretrained model. It must expose the `seq_len` and `embedding_size`
        attributes and the `embedding_agg_layer` / `decoder` submodules.
    n_classes : int
        Number of outputs of the new head.
    device : torch.device
        Device the new head is moved to.
    head_hidden_dim : sequence of int
        Sizes of the hidden layers of the new head (list or tuple; it is
        never mutated).
    head_activation : str
        Activation of the hidden layers, forwarded to `FFNN`.
    head_output_activation : str
        Output activation, forwarded to `FFNN`. Defaults to 'softmax', which
        was previously hard-coded. NOTE(review): `torch.nn.CrossEntropyLoss`
        expects unnormalized logits, so when training with it a head without
        output softmax is usually preferable. Check which value `FFNN`
        accepts for "no activation" (TODO confirm).

    Returns
    -------
    torch.nn.Module
        The modified copy. Its token representations are flattened and mapped
        to `n_classes` outputs.
    """
    model = deepcopy(original_model)

    # Replace the aggregation operation from `None` (no aggregation, as
    # needed for MLM) to `flatten`: (..., seq_len, embedding_size) is
    # flattened into (..., seq_len * embedding_size).
    model.embedding_agg = 'flatten'
    model.embedding_agg_layer = torch.nn.Flatten(start_dim=-2, end_dim=-1)

    decoder_input_dim = model.seq_len * model.embedding_size

    # Replace the decoder with a FFNN of appropriate size. `list(...)` accepts
    # any sequence and builds a fresh list, so the argument is never aliased.
    model.decoder = FFNN(
        dims=[decoder_input_dim] + list(head_hidden_dim) + [n_classes],
        activation=head_activation,
        output_activation=head_output_activation,
        batch_normalization=False,
        concatenate_last_dim=False
    ).to(device=device)

    return model


def freeze_encoder_weights(model, trainable_modules=('decoder',)):
    """
    Freeze the weights of every direct submodule of `model` whose name does
    not appear in `trainable_modules`, and print a per-submodule summary.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose direct children are frozen in place.
    trainable_modules : str or iterable of str
        Names of the direct children whose `requires_grad` flags are left
        untouched. A single name may be passed as a plain string.
    """
    # A bare string would otherwise be treated as a collection of characters.
    if isinstance(trainable_modules, str):
        trainable_modules = [trainable_modules]
    trainable_modules = set(trainable_modules)

    for name, submodule in model.named_children():
        params = list(submodule.parameters())

        if name not in trainable_modules:
            for p in params:
                p.requires_grad = False

        # Note: `all` of an empty sequence is True, so parameter-free
        # submodules are reported as trainable.
        print(
            name,
            f'| N parameters: {sum(p.numel() for p in params)}',
            '| Parameters trainable:',
            all(p.requires_grad for p in params)
        )

In [209]:
# Hyperparameters of the new classification head.
head_kwargs = dict(
    n_classes=4,
    head_hidden_dim=[64],
    head_activation='relu',
)

# Build the classifier from the pretrained model, then freeze everything
# except the new head.
classification_model = replace_decoder_with_classification_head(
    model, device=device, **head_kwargs
)

freeze_encoder_weights(classification_model, trainable_modules=['decoder'])

input_embedding | N parameters: 640| Parameters trainable: False
positional_embedding | N parameters: 0| Parameters trainable: True
encoder_layer | N parameters: 132480| Parameters trainable: False
transformer_encoder | N parameters: 264960| Parameters trainable: False
embedding_agg_layer | N parameters: 0| Parameters trainable: True
decoder | N parameters: 2097476| Parameters trainable: True


In [210]:
# Output shapes before vs. after swapping the head: the pretrained model
# yields one output per token, the classifier one output per sequence.
pretrained_output_shape = model(test_batch).shape
classification_output_shape = classification_model(test_batch).shape

pretrained_output_shape, classification_output_shape

(torch.Size([32, 256, 4]), torch.Size([32, 4]))

## Check on the frozen weights

Check that, after some epochs of (fake) fine-tuning, the frozen weights of the fine-tuned model still have exactly the same values as in the original pretrained model.

**Note:** to see the check fail, comment out the call to `freeze_encoder_weights` above and rerun the cells below. No weights are then kept frozen, so the comparison should report `False` for the submodules whose weights get updated during training.

In [211]:
# Fake data (random tokens and random labels). The point is only to run a few
# optimization steps and then check which weights changed, not to learn anything.
# A dedicated, seeded generator makes the data reproducible across runs.
FAKE_BATCH_SIZE = 32
VOCAB_SIZE = 4  # q: number of distinct token values (assumed).
N_CLASSES = 4   # Must match `n_classes` of the classification head.

fake_data_rng = torch.Generator().manual_seed(0)


def make_fake_batch(seq_len):
    """Return random `(tokens, labels)` tensors, both moved to `device`."""
    tokens = torch.randint(0, VOCAB_SIZE, (FAKE_BATCH_SIZE, seq_len), generator=fake_data_rng)
    labels = torch.randint(0, N_CLASSES, (FAKE_BATCH_SIZE,), generator=fake_data_rng)
    return tokens.to(device=device), labels.to(device=device)


training_batch, training_labels = make_fake_batch(classification_model.seq_len)
test_batch, test_labels = make_fake_batch(classification_model.seq_len)

In [212]:
# Short (fake) fine-tuning run. Only the parameters left trainable by
# `freeze_encoder_weights` should change.
# NOTE(review): `torch.nn.CrossEntropyLoss` expects unnormalized logits, while
# the head is built with a softmax output activation. This is harmless for the
# frozen-weights check but worth revisiting for real fine-tuning.
classification_model, fine_tuning_history = train_model(
    classification_model,
    training_data=(training_batch, training_labels),
    test_data=(test_batch, test_labels),
    n_epochs=10,
    loss_fn=torch.nn.CrossEntropyLoss(),
    learning_rate=0.001,
    # Size the batch on the *training* data (presumably one batch per epoch).
    batch_size=training_batch.shape[0]
)

2024-05-02 13:52:30,519 - train_model - INFO - Training model
100%|██████████████████████████| 10/10 [00:00<00:00, 101.84it/s, training_accuracy=tensor(0.), training_loss=tensor(1.0207), val_accuracy=tensor(0., device='cuda:0'), val_loss=tensor(1.3029, device='cuda:0')]
2024-05-02 13:52:30,620 - train_model - INFO - Last epoch: 10


In [213]:
# For every submodule with parameters (except the new, trainable head), check
# that the fine-tuned weights are identical to the pretrained ones. Submodules
# are matched by name rather than by position, and a mismatch in the number of
# parameter tensors counts as "changed" instead of being silently truncated
# by `zip`.
pretrained_children = dict(model.named_children())

for name, finetuned_submodule in classification_model.named_children():
    finetuned_params = list(finetuned_submodule.parameters())

    if name == 'decoder' or sum(p.numel() for p in finetuned_params) == 0:
        continue

    pretrained_params = list(pretrained_children[name].parameters())

    weights_unchanged = (
        len(finetuned_params) == len(pretrained_params)
        and all(
            torch.equal(p_finetuned, p_pretrained)
            for p_finetuned, p_pretrained in zip(finetuned_params, pretrained_params)
        )
    )

    print(name, weights_unchanged)

input_embedding True
encoder_layer True
transformer_encoder True
