# Root inference with padded sequences

In [None]:
import sys
import pickle
import numpy as np
import torch

sys.path.append('../../modules/')

from logger_tree_language import get_logger
from tree_language_vocab import TreeLanguageVocab
from tree_generation_nontrivial_topology import pad_sequence
from models import TransformerClassifier
from training import train_model
from plotting import plot_training_history

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = get_logger('root_inference_padding')

%load_ext autoreload
%autoreload 2

Load and preprocess data.

In [None]:
DATA_PATH = '../../data/topological_tree_data/topological_tree_data_4_5_1.0_0.0.pkl'

# The generation parameters are encoded in the file name as its last four
# underscore-separated fields: ..._<q>_<l>_<sigma>_<epsilon>.pkl
q_str, l_str, sigma_str, epsilon_str = DATA_PATH.split('.pkl')[0].split('_')[-4:]

q, l = int(q_str), int(l_str)
sigma, epsilon = float(sigma_str), float(epsilon_str)

# Maximum leaf-sequence length (2 ** l); presumably the leaf count of a full
# binary tree of depth l -- confirm against the data generator.
max_seq_len = 2 ** l

logger.info(f'Loading data from: {DATA_PATH}')
logger.info(f'q = {q} | l = {l} | sigma = {sigma} | epsilon = {epsilon}')

# NOTE: pickle can execute arbitrary code on load; only open trusted files.
with open(DATA_PATH, 'rb') as data_file:
    data = pickle.load(data_file)

In [None]:
vocab = TreeLanguageVocab(
    n_symbols=q,
    use_pad_token=True,
    use_mask_token=True
)

# Each entry of `data` is accessed as (leaf sequence, root symbol).
# `pad_sequence` fills each leaf sequence with the '<pad>' index (presumably up
# to `max_seq_len`), so that all rows can be stacked into one 2D array.
pad_token_idx = vocab('<pad>')
roots = np.array([t[1] for t in data])
leaves = np.vstack([
    pad_sequence(t[0], max_seq_len, pad_token_idx)
    for t in data
])
n_samples = leaves.shape[0]

# Shuffle leaves and roots jointly before splitting, so the test set is a
# random 20% of the data instead of whatever sits at the end of the file.
shuffled_indices = np.random.permutation(n_samples)
leaves = leaves[shuffled_indices]
roots = roots[shuffled_indices]

leaves = torch.from_numpy(leaves).to(device=device, dtype=torch.int64)
roots = torch.from_numpy(roots).to(device=device, dtype=torch.int64)

n_samples_test = int(0.2 * n_samples)
# Split on an explicit train size: slicing with `[:-n_samples_test]` would
# produce an empty training set whenever n_samples_test == 0.
n_samples_train = n_samples - n_samples_test

leaves_train = leaves[:n_samples_train]
leaves_test = leaves[n_samples_train:]

# Define and preprocess targets (one-hot encoded, as float32).
roots_train = torch.nn.functional.one_hot(
    roots[:n_samples_train], num_classes=q
).to(device=device, dtype=torch.float32)
roots_test = torch.nn.functional.one_hot(
    roots[n_samples_train:], num_classes=q
).to(device=device, dtype=torch.float32)

leaves_train.shape, roots_train.shape, leaves_test.shape, roots_test.shape

Define model.

In [None]:
embedding_size = 128
vocab_size = len(vocab)

# Keyword arguments for the root classifier. Keys are forwarded verbatim to
# TransformerClassifier, so their spelling must match its signature
# ('n_tranformer_layers' (sic) is kept as-is -- check models.py before renaming).
root_inference_model_params = {
    'seq_len': max_seq_len,
    'embedding_size': embedding_size,
    'n_tranformer_layers': 2,  # Good: 4
    'n_heads': 1,
    'vocab_size': vocab_size,
    'encoder_dim_feedforward': 2 * embedding_size,
    'positional_encoding': True,
    # We assume the special tokens correspond to the last `n_special_tokens` indices.
    'n_special_tokens': len(vocab.special_symbols),
    'embedding_agg': 'flatten',
    'decoder_hidden_sizes': [64],  # Good: [64]
    'decoder_activation': 'relu',  # Good: 'relu'
    'decoder_output_activation': 'identity',
}

# Instantiate the model and move it to the selected device.
root_inference_model = TransformerClassifier(**root_inference_model_params)
root_inference_model = root_inference_model.to(device=device)

Training.

In [None]:
# Optimisation settings passed to `train_model` in the next cell.
n_epochs, learning_rate = 5, 1e-3

In [None]:
# Cross-entropy against the one-hot root targets. The '<pad>' index is handed
# to train_model as `padding_token` (presumably so pad positions can be
# handled specially -- confirm in training.py).
pad_index = vocab('<pad>')
ce_loss = torch.nn.CrossEntropyLoss()

_, training_history = train_model(
    model=root_inference_model,
    training_data=(leaves_train, roots_train),
    test_data=(leaves_test, roots_test),
    n_epochs=n_epochs,
    loss_fn=ce_loss,
    learning_rate=learning_rate,
    batch_size=32,
    early_stopper=None,
    training_history=None,
    checkpointing_period_epochs=None,
    model_dir=None,
    checkpoint_id=None,
    tensorboard_log_dir=None,
    padding_token=pad_index,
)

plot_training_history(training_history)