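# Fine-tuning for root inference

Fine-tune an MLM-pretrained model to predict each sample's root from its sequence of leaves: the pretrained decoder is replaced by a new classification head, and only that head is trained.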

In [1]:
import os
import sys
import numpy as np
import torch

sys.path.append('../../modules/')

from logger_tree_language import get_logger
from pytorch_utilities import load_checkpoint
from models import replace_decoder_with_classification_head, freeze_encoder_weights
from training import train_model
from plotting import plot_training_history

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = get_logger('fine_tuning_root_inference')

%load_ext autoreload
%autoreload 2

In [2]:
PRETRAINING_DATA_PATH = '../../data/mlm_data/slrm_data/labeled_data_fixed_4_8_1.0_0.00000.npy'
DATA_PATH = '../../data/mlm_data/slrm_data/labeled_data_fixed_validation_4_8_1.0_0.00000.npy'
MODEL_DIR = '../../models/mlm_pretraining_1/'

# Seed the global PyTorch and NumPy RNGs so that the train/test shuffle below
# (global NumPy RNG) and any torch-RNG-based steps (e.g. head initialization,
# batching) are reproducible on Restart & Run All.
# Note: this is distinct from the data `seed` index chosen when loading data.
RANDOM_SEED = 0
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

Load the labeled data, shuffle it, and split it into training and test sets.

In [3]:
# NOTE(review): allow_pickle=True can execute code embedded in the file;
# only load trusted data files.
q, k, sigma, epsilon, roots_seeds, leaves_seeds, rho_seeds = np.load(DATA_PATH, allow_pickle=True)

# The last index corresponds to the seed that generated the
# data/transition tensors: select one.
seed = 0

# Number of samples, read from axis 1 of `leaves_seeds` (as in the transpose below).
n_samples = leaves_seeds.shape[1]

# Random sample order drawn from the global NumPy RNG.
shuffled_indices = np.random.permutation(n_samples)

# Select the chosen data seed and apply the same shuffle to roots and leaves
# so each (leaves, root) pair stays aligned.
roots = roots_seeds[:, seed][shuffled_indices]
leaves = leaves_seeds[..., seed].T[shuffled_indices, :]
rho = rho_seeds[..., seed]

assert roots.shape[0] == leaves.shape[0] == n_samples, 'roots and leaves have different numbers of samples'

# Train-test split: the last `n_samples_test` shuffled samples form the test set.
n_samples_test = 2000

# Guard the slice below: with n_samples_test == 0, `leaves[:-0]` would be an
# empty training set; with n_samples_test >= n_samples nothing is left to train on.
assert 0 < n_samples_test < n_samples, (
    f'n_samples_test must be in (0, {n_samples}), got {n_samples_test}'
)

leaves_train = leaves[:-n_samples_test, :]
roots_train = roots[:-n_samples_test]

leaves_test = leaves[-n_samples_test:, :]
roots_test = roots[-n_samples_test:]

logger.info(
    f'N training samples: {leaves_train.shape[0]}'
    f' | N test samples: {leaves_test.shape[0]}'
)

# Data preprocessing: move to the target device and cast to int64 (the integer
# dtype used for token indices and CrossEntropyLoss class targets) in one step.
leaves_train = torch.from_numpy(leaves_train).to(device=device, dtype=torch.int64)
leaves_test = torch.from_numpy(leaves_test).to(device=device, dtype=torch.int64)

roots_train = torch.from_numpy(roots_train).to(device=device, dtype=torch.int64)
roots_test = torch.from_numpy(roots_test).to(device=device, dtype=torch.int64)

2024-05-02 18:11:47,298 - fine_tuning_root_inference - INFO - N training samples: 48000 | N test samples: 2000


Load the most recent (highest-epoch) checkpoint of the pretrained model.

In [4]:
# Map each checkpoint's epoch (parsed from the trailing `_<epoch>.pt` in its
# filename) to its filename.
# - Matching the exact `.pt` suffix skips files like `.pth` or `.pt.bak`.
# - Looking files up by epoch avoids substring collisions (e.g. '0.pt' also
#   matches '..._epoch_6000.pt').
# - The lookup does not depend on os.listdir ordering.
checkpoint_files_by_epoch = {
    int(f.split('_')[-1].split('.')[0]): f
    for f in os.listdir(MODEL_DIR)
    if f.endswith('.pt')
}

if not checkpoint_files_by_epoch:
    raise FileNotFoundError(f'No .pt checkpoints found in {MODEL_DIR}')

checkpoint_epochs = sorted(checkpoint_files_by_epoch)

# Use the most recent (highest-epoch) checkpoint.
selected_checkpoint_epoch = checkpoint_epochs[-1]

checkpoint_id = checkpoint_files_by_epoch[selected_checkpoint_epoch]

logger.info(f'Selected checkpoint: {checkpoint_id}')

pretrained_model, _, training_history = load_checkpoint(
    MODEL_DIR,
    checkpoint_id,
    device=device
)

2024-05-02 18:11:47,538 - fine_tuning_root_inference - INFO - Selected checkpoint: mlm_pretraining_1_epoch_6000.pt


Replace the pretrained model's decoder with a new classification head, and freeze every module except that head (`decoder`).

In [5]:
# Head configuration: one output per root class (`q`), with the hidden-layer
# sizes and activation passed through to the project helper.
head_kwargs = dict(
    n_classes=q,
    device=device,
    head_hidden_dim=[64, 32],
    head_activation='relu',
)
classification_model = replace_decoder_with_classification_head(pretrained_model, **head_kwargs)

# Freeze every module except 'decoder' (the new head), so only the head is fine-tuned.
freeze_encoder_weights(classification_model, trainable_modules=['decoder'])

input_embedding | N parameters: 640 | Parameters trainable: False
positional_embedding | N parameters: 0 | Parameters trainable: True
encoder_layer | N parameters: 132480 | Parameters trainable: False
transformer_encoder | N parameters: 264960 | Parameters trainable: False
embedding_agg_layer | N parameters: 0 | Parameters trainable: True
decoder | N parameters: 2099428 | Parameters trainable: True


Measure the initial test accuracy, then fine-tune the model and plot its training history.

In [6]:
# Baseline accuracy of the freshly attached head on the test set.
# - Evaluate in eval mode, which disables dropout if the model has any.
# - Use no_grad so no autograd graph is built over all test samples.
# - Restore the model's previous train/eval mode afterwards so training below
#   starts from the same state as before.
was_training = classification_model.training
classification_model.eval()
try:
    with torch.no_grad():
        test_predictions = torch.argmax(classification_model(leaves_test), dim=-1)
finally:
    classification_model.train(was_training)

initial_accuracy = (
    test_predictions == roots_test
).to(dtype=torch.float32).mean().to(device='cpu')

logger.info(f'Initial validation accuracy: {initial_accuracy}')

# Fine-tuning hyperparameters.
n_epochs = 10
learning_rate = 1e-4
batch_size = 32

loss_fn = torch.nn.CrossEntropyLoss()

_, fine_tuning_training_history = train_model(
    model=classification_model,
    training_data=(leaves_train, roots_train),
    test_data=(leaves_test, roots_test),
    n_epochs=n_epochs,
    loss_fn=loss_fn,
    learning_rate=learning_rate,
    batch_size=batch_size,
    early_stopper=None
)

plot_training_history(fine_tuning_training_history)

2024-05-02 18:11:49,389 - fine_tuning_root_inference - INFO - Initial accuracy: 0.2510000169277191
2024-05-02 18:11:49,390 - fine_tuning_root_inference - INFO - Training model
100%|███████████████████| 10/10 [00:59<00:00,  5.94s/it, training_accuracy=tensor(0.1697), training_loss=tensor(1.3362), val_accuracy=tensor(0.2870, device='cuda:0'), val_loss=tensor(1.3652, device='cuda:0')]
2024-05-02 18:12:48,748 - fine_tuning_root_inference - INFO - Last epoch: 10
