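# Fine-tuning for root inference

Fine-tune an encoder pretrained with masked language modeling (MLM) to predict the root symbol of a tree from its sequence of leaves.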

In [None]:
import os
import sys
import numpy as np
import torch

sys.path.append('../../modules/')

from logger_tree_language import get_logger
from pytorch_utilities import load_checkpoint, count_model_params
from models import replace_decoder_with_classification_head, freeze_encoder_weights
from training import train_model
from model_evaluation_tree_language import compute_accuracy
from plotting import plot_training_history

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = get_logger('fine_tuning_root_inference')

%load_ext autoreload
%autoreload 2

In [None]:
# Data and model locations, relative to this notebook's directory.
_SLRM_DATA_DIR = '../../data/mlm_data/slrm_data/'

# Pretraining dataset.
PRETRAINING_DATA_PATH = _SLRM_DATA_DIR + 'labeled_data_fixed_4_8_1.0_0.00000.npy'
# Labeled dataset loaded below for fine-tuning.
DATA_PATH = _SLRM_DATA_DIR + 'labeled_data_fixed_validation_4_8_1.0_0.00000.npy'
# Directory holding the MLM-pretraining checkpoints.
MODEL_DIR = '../../models/mlm_pretraining_2/'

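## Load data

Load the labeled data and select the data generated by one seed. Then shuffle it and split it into training and test sets.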

In [None]:
# NOTE(review): `allow_pickle=True` lets `np.load` execute arbitrary code
# embedded in the file; only load trusted, locally generated data.
q, k, sigma, epsilon, roots_seeds, leaves_seeds, rho_seeds = np.load(DATA_PATH, allow_pickle=True)

# The last index corresponds to the seed that generated the
# data/transition tensors: select one.
seed = 0

# Seed of the shuffling permutation (independent of `seed` above). With it,
# the train/test split is identical across kernel restarts.
SHUFFLE_SEED = 0

n_samples = leaves_seeds.shape[1]
shuffle_rng = np.random.default_rng(SHUFFLE_SEED)
shuffled_indices = shuffle_rng.permutation(n_samples)

roots = roots_seeds[:, seed][shuffled_indices]

# After the transpose, `leaves` has one row per sample.
leaves = leaves_seeds[..., seed].T[shuffled_indices, :]
rho = rho_seeds[..., seed]

# Train-test split: the last `n_samples_test` shuffled samples are held out.
# Guard the bounds: `[:-0]` would silently yield an empty training set.
n_samples_test = 2000
assert 0 < n_samples_test < n_samples, (
    f'n_samples_test ({n_samples_test}) must be between 1 and {n_samples - 1}'
)

leaves_train = leaves[:-n_samples_test, :]
roots_train = roots[:-n_samples_test]

leaves_test = leaves[-n_samples_test:, :]
roots_test = roots[-n_samples_test:]

logger.info(
    f'N training samples: {leaves_train.shape[0]}'
    f' | N test samples: {leaves_test.shape[0]}'
)

# Data preprocessing: leaves become int64 token-id tensors on `device`.
leaves_train = torch.from_numpy(leaves_train).to(device=device, dtype=torch.int64)
leaves_test = torch.from_numpy(leaves_test).to(device=device, dtype=torch.int64)

# Roots become one-hot float32 vectors over the `q` classes.
# `torch.nn.CrossEntropyLoss` accepts such class-probability targets
# (PyTorch >= 1.10).
roots_train = torch.nn.functional.one_hot(
    torch.from_numpy(roots_train).to(dtype=torch.int64), num_classes=q
).to(device=device, dtype=torch.float32)
roots_test = torch.nn.functional.one_hot(
    torch.from_numpy(roots_test).to(dtype=torch.int64), num_classes=q
).to(device=device, dtype=torch.float32)

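## Load the pretrained model

Load the checkpoint with the highest epoch number from `MODEL_DIR` and plot its pretraining history.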

In [None]:
# Map epoch number -> checkpoint filename in a single pass. Filenames are
# assumed to end in `_<epoch>.pt` (as parsed below) — TODO confirm against
# the code that saves the checkpoints.
# Matching on the exact `.pt` suffix excludes `.pth` or `*.pt.tmp` files.
# Looking the file up by epoch avoids substring collisions such as
# '5.pt' matching 'model_15.pt'.
checkpoint_files = {
    int(f.split('_')[-1].split('.')[0]): f
    for f in os.listdir(MODEL_DIR)
    if f.endswith('.pt')
}

if not checkpoint_files:
    raise FileNotFoundError(f'No .pt checkpoints found in {MODEL_DIR}')

checkpoint_epochs = sorted(checkpoint_files)

# Use the latest checkpoint; change this to any value in `checkpoint_epochs`.
selected_checkpoint_epoch = checkpoint_epochs[-1]

checkpoint_id = checkpoint_files[selected_checkpoint_epoch]

logger.info(f'Selected checkpoint: {checkpoint_id}')

pretrained_model, _, training_history = load_checkpoint(
    MODEL_DIR,
    checkpoint_id,
    device=device
)

plot_training_history(training_history)

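## Classification head

Replace the pretrained model's decoder with a new classification head over the `q` root classes. Then freeze the encoder so that only the `decoder` module (the new head) is trained.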

In [None]:
# Configuration of the new classification head.
head_config = dict(
    embedding_agg='flatten',
    head_hidden_dim=[128, 64, 32],
    head_activation='relu',
    head_output_activation='identity',
    head_batch_normalization=False,
    head_dropout_p=None,
)

classification_model = replace_decoder_with_classification_head(
    pretrained_model, n_classes=q, device=device, **head_config
)

# Drop the notebook-level name for the pretrained model; from here on only
# `classification_model` is used.
del pretrained_model

# Freeze the encoder weights, leaving the `decoder` module trainable.
freeze_encoder_weights(classification_model, trainable_modules=['decoder'])

n_model_params = count_model_params(classification_model)
logger.info(f'Total number of parameters: {n_model_params}')

In [None]:
# Display the model's repr (for a torch `nn.Module`, its tree of submodules).
classification_model

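## Sanity checks

Run the not-yet-fine-tuned classification model on one batch of training data. Check the output shape and compute its accuracy.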

In [None]:
batch_size = 32

check_leaves = leaves_train[:batch_size]
check_roots = roots_train[:batch_size]

# Sanity-check forward pass. No autograd graph is needed. Eval mode keeps
# the output deterministic if the encoder contains dropout. The previous
# train/eval mode is restored afterwards so that later training is unaffected.
was_training = classification_model.training
classification_model.eval()
try:
    with torch.no_grad():
        # Shape: (batch_size, q).
        pred = classification_model(check_leaves)
finally:
    classification_model.train(was_training)

assert pred.shape == (check_leaves.shape[0], q), (
    f'Unexpected prediction shape: {tuple(pred.shape)}'
)

print(pred.shape)

compute_accuracy(pred, check_roots)

## Fine-tuning

In [None]:
# Empty metric history for the fine-tuning run, with one independent list per
# metric. It is passed to `train_model` as `training_history`.
fine_tuning_training_history = {
    metric: []
    for metric in (
        'training_loss',
        'val_loss',
        'training_accuracy',
        'val_accuracy',
        'learning_rate',
    )
}

In [None]:
n_epochs = 25

loss_fn = torch.nn.CrossEntropyLoss()

# Fine-tuning hyperparameters and options, gathered in one place.
fine_tuning_config = dict(
    n_epochs=n_epochs,
    loss_fn=loss_fn,
    learning_rate=1e-4,
    batch_size=32,
    early_stopper=None,
    tensorboard_log_dir=None,
)

# The current history is passed in and the returned one replaces it. Re-running
# this cell presumably keeps training the same `classification_model` and
# extends the same history — restart from the head-building cell for a fresh run.
_, fine_tuning_training_history = train_model(
    model=classification_model,
    training_data=(leaves_train, roots_train),
    test_data=(leaves_test, roots_test),
    training_history=fine_tuning_training_history,
    **fine_tuning_config,
)

plot_training_history(fine_tuning_training_history)

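Hypotheses to investigate:
- Pre-training with MLM does not help in our case.
- We got some of the hyperparameters wrong.
- Use a `<cls>` token for classification instead of flattening the token embeddings (`embedding_agg='flatten'`)?
- Wrong data regime: not enough training data?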