# Use transformers with padded sequences

In [1]:
import os
import sys
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
import torch

sys.path.append('../../modules/')

from logger_tree_language import get_logger
from tree_generation_nontrivial_topology import pad_sequence
from tree_language_vocab import TreeLanguageVocab
from models import TransformerClassifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = get_logger('transformers_with_padding')

%load_ext autoreload
%autoreload 2

## Load and preprocess data

In [2]:
# Pickled dataset. Its file name encodes the generation parameters as
# topological_tree_data_<q>_<l>_<sigma>_<epsilon>.pkl, which are parsed below.
DATA_PATH = (
    '../../data/topological_tree_data/'
    'topological_tree_data_4_5_1.0_0.0.pkl'
)

Load data.

In [3]:
# The file name encodes the generation parameters:
#   topological_tree_data_<q>_<l>_<sigma>_<epsilon>.pkl
# Parse them from the base name only (extension stripped with splitext), so
# directory names or a '.pkl' substring elsewhere in the path cannot interfere.
data_file_stem = os.path.splitext(os.path.basename(DATA_PATH))[0]
q, l, sigma, epsilon = data_file_stem.split('_')[-4:]

q = int(q)  # Number of symbols; used below as the vocabulary's `n_symbols`.
l = int(l)
sigma = float(sigma)
epsilon = float(epsilon)
# Padded sequence length. NOTE(review): presumably `l` is the tree depth and
# the trees are binary, so a sequence has at most 2**l leaves — confirm.
max_seq_len = 2 ** l

logger.info(f'Loading data from: {DATA_PATH}')
logger.info(f'q = {q} | l = {l} | sigma = {sigma} | epsilon = {epsilon}')

# NOTE: pickle.load can execute arbitrary code; only load trusted, locally
# generated data files.
with open(DATA_PATH, 'rb') as f:
    data = pickle.load(f)

2024-05-30 18:31:25,080 - transformers_with_padding - INFO - Loading data from: ../../data/topological_tree_data/topological_tree_data_4_5_1.0_0.0.pkl
2024-05-30 18:31:25,080 - transformers_with_padding - INFO - q = 4 | l = 5 | sigma = 1.0 | epsilon = 0.0


Preprocess the data (pad the leaf sequences to a common length) and split it into train and test sets.

In [4]:
vocab = TreeLanguageVocab(
    n_symbols=q,
    use_pad_token=True
)
pad_token_idx = vocab('<pad>')

# Each item of `data` is indexed as (leaf sequence, root). Leaf sequences are
# padded with the <pad> token up to `max_seq_len` so they can be stacked.
roots = np.array([t[1] for t in data])
leaves = np.vstack([
    pad_sequence(t[0], max_seq_len, pad_token_idx)
    for t in data
])

# Convert to int64 tensors on the target device in a single `.to` call.
leaves = torch.from_numpy(leaves).to(device=device, dtype=torch.int64)
roots = torch.from_numpy(roots).to(device=device, dtype=torch.int64)

# Shuffle BEFORE splitting. The previous version computed a permutation but
# never applied it, so the test set was simply the last 20% of the file. That
# is biased if the pickle is stored in a structured order. A dedicated seeded
# generator keeps the split reproducible.
SPLIT_SEED = 42
TEST_FRACTION = 0.2

n_samples = leaves.shape[0]
shuffled_indices = np.random.default_rng(SPLIT_SEED).permutation(n_samples)

n_samples_test = int(TEST_FRACTION * n_samples)
# Slice by the train count rather than `[:-n_samples_test]`, which would give
# an empty train set when n_samples_test == 0.
n_samples_train = n_samples - n_samples_test

train_indices = torch.from_numpy(shuffled_indices[:n_samples_train]).to(device=device)
test_indices = torch.from_numpy(shuffled_indices[n_samples_train:]).to(device=device)

leaves_train = leaves[train_indices]
leaves_test = leaves[test_indices]

roots_train = roots[train_indices]
roots_test = roots[test_indices]

leaves_train.shape, roots_train.shape, leaves_test.shape, roots_test.shape

(torch.Size([80000, 32]),
 torch.Size([80000]),
 torch.Size([20000, 32]),
 torch.Size([20000]))

## Modelling

**WARNING:** for now we only run a sanity check that resembles masked language modelling but does not mask any input tokens (only the padded positions are masked). The goal is just to verify that everything works end to end.

In [5]:
embedding_size = 128
vocab_size = len(vocab)

# Hyperparameters of the transformer classifier. The keys must match the
# `TransformerClassifier` keyword arguments exactly (including the spelling
# `n_tranformer_layers`).
model_params = {
    'seq_len': max_seq_len,
    'embedding_size': embedding_size,
    'n_tranformer_layers': 2,  # 4 was previously noted as a good value.
    'n_heads': 1,
    'vocab_size': vocab_size,
    'encoder_dim_feedforward': 2 * embedding_size,
    'positional_encoding': True,
    # Assumes the special tokens occupy the last `n_special_tokens` vocabulary indices.
    'n_special_tokens': len(vocab.special_tokens),
    'embedding_agg': None,
    'decoder_hidden_sizes': [64],  # [64] was noted as a good value.
    'decoder_activation': 'relu',  # 'relu' was noted as a good value.
    'decoder_output_activation': 'identity',
}

# Build the model and move it to the selected device.
model = TransformerClassifier(**model_params).to(device=device)



In [51]:
# Sanity check: a forward pass on one test batch, masking the padded positions.
batch_size = 32

batch = leaves_test[:batch_size, ...]
# True wherever the token is <pad>. NOTE(review): presumably forwarded to the
# transformer encoder so padded positions are ignored — confirm in
# TransformerClassifier.
batch_src_key_padding_mask = (batch == vocab('<pad>'))

# Inference only: torch.no_grad() avoids building an autograd graph that would
# just be discarded. NOTE(review): model.eval() is not called, so any dropout
# inside the model is still active; call it first if deterministic outputs are
# required.
with torch.no_grad():
    batch_logits = model(batch, src_key_padding_mask=batch_src_key_padding_mask)

batch_logits

tensor([[[-1.3238e-01,  1.9102e-02,  2.4874e-01,  1.1229e-01],
         [-7.1113e-02,  2.3200e-01,  1.9844e-01,  6.2113e-02],
         [ 1.3027e-01,  2.9371e-01,  1.4724e-01, -1.7555e-02],
         ...,
         [-2.2940e-01,  2.1863e-01,  2.1749e-02,  4.1221e-01],
         [-2.2846e-01, -2.8933e-02,  1.2085e-01,  1.8709e-01],
         [ 7.7312e-02, -4.8308e-02,  3.2073e-02,  7.1469e-02]],

        [[-8.7766e-02,  2.5789e-01, -1.2684e-02,  1.2546e-01],
         [-1.5359e-01, -8.1707e-02,  2.2614e-01,  2.3379e-01],
         [-7.7169e-02,  1.6883e-01, -2.7284e-02, -1.2796e-01],
         ...,
         [-1.7184e-01,  9.0511e-02, -4.4769e-02,  5.2114e-02],
         [ 6.7503e-02, -2.5153e-02,  9.7739e-02,  3.6951e-02],
         [ 1.0356e-01,  2.4701e-02,  1.0044e-01, -2.6803e-02]],

        [[-1.0182e-01,  2.4339e-01,  1.2259e-01, -1.3709e-05],
         [-1.0311e-01,  1.6431e-01,  2.1993e-01,  1.2202e-01],
         [ 3.1087e-02,  2.9377e-01, -4.9692e-02,  2.1498e-01],
         ...,
         