In [1]:
import math

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

from models.models import SupervisedModel
from models.xlstm import xLSTM
from tasks import decode_sequence, gen_train_sequence

In [2]:
VOCAB_SIZE = 39
SEQ_LEN = 7

learning_rate = 3e-4
tbptt_window = 40

vocab_size = VOCAB_SIZE

In [3]:
rng = jax.random.PRNGKey(0)
model_key, train_key, rng = jax.random.split(rng, 3)


model = xLSTM(
    vocab_size = vocab_size,
    hidden_dim = 512,
    n_blocks = 2,
    n_heads = 4,
    ms_ratio = (1, 1),
    mlstm_kwargs = None,
    slstm_kwargs = {'use_conv': True},
    penultimate_norm = True,
    key = model_key,
)

optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(model)

# Count model params
print(f'Number of model params: {sum(jax.tree.leaves(jax.tree.map(lambda x: math.prod(x.shape), model)))}')

Number of model params: 16357087


In [4]:
# n = 20000
# single_seq_len = (SEQ_LEN - 1) // 2
# input_ids = jax.random.randint(rng, (n, single_seq_len), 1, vocab_size)
# dividers = jnp.zeros((n, 1))
# input_ids = jnp.concatenate([input_ids, dividers, input_ids], axis=1, dtype=int)
# sequences = {
#     'input_ids': input_ids[:, :-1],
#     'target_ids': input_ids[:, 1:],
#     'loss_mask': jnp.concat([
#         jnp.zeros((n, single_seq_len - 1)), jnp.ones((n, single_seq_len + 1))
#     ], axis=1).astype(int),
# }
# print(jax.tree.map(lambda x: x.shape, sequences))
# print(jax.tree.map(lambda x: x[0], sequences))

n = 20000
gen_train_sequences = jax.vmap(gen_train_sequence, in_axes=(0, None, None, None, None, None))
sequences = gen_train_sequences(
    jax.random.split(jax.random.PRNGKey(0), n),
    2, 2, 3, 26, 10,
)

print(jax.tree.map(lambda x: x.shape, sequences))
print(jax.tree.map(lambda x: x[0], sequences))

{'input_ids': (20000, 34), 'loss_mask': (20000, 34), 'target_ids': (20000, 34)}
{'input_ids': Array([21, 11, 36, 32, 27, 37, 23, 15, 36, 31, 34, 37,  3,  1, 36, 31, 33,
       38,  3,  1, 36, 31, 33, 37, 23, 15, 36, 31, 34, 37, 21, 11, 36, 32],      dtype=int32), 'loss_mask': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1.],      dtype=float32), 'target_ids': Array([11, 36, 32, 27, 37, 23, 15, 36, 31, 34, 37,  3,  1, 36, 31, 33, 38,
        3,  1, 36, 31, 33, 37, 23, 15, 36, 31, 34, 37, 21, 11, 36, 32, 27],      dtype=int32)}


In [5]:
def loss_fn(model: eqx.Module, rnn_state, sequence):
    input_tokens = sequence['input_ids']
    target_tokens = sequence['target_ids']
    loss_mask = sequence['loss_mask']

    rnn_state, logits = model.forward_sequence(rnn_state, input_tokens)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, target_tokens)
    loss = loss * loss_mask
    return loss.sum() / loss_mask.sum()

value_grad_fn = eqx.filter_value_and_grad(loss_fn)
value_grad_fn = eqx.filter_jit(value_grad_fn)

In [6]:
grad_sum = jax.tree.map(lambda x: jnp.zeros_like(x), model)
loss_hist = []

for i in range(sequences['input_ids'].shape[0]):
    sequence = jax.tree.map(lambda x: x[i], sequences)
    
    # Calculate loss
    rnn_state = model.init_rnn_state()
    loss, grads = value_grad_fn(model, rnn_state, sequence)
    loss_hist.append(loss)
    grad_sum = jax.tree.map(lambda x, y: x + y, grad_sum, grads)

    if i % 32 == 0:
        grad_sum = jax.tree.map(lambda x: x / 32, grad_sum)
        updates, opt_state = optimizer.update(grad_sum, opt_state, model)
        model = eqx.apply_updates(model, updates)
        grad_sum = jax.tree.map(lambda x: jnp.zeros_like(x), model)

    if i % 200 == 0:
        print(f'Loss: {jnp.mean(jnp.array(loss_hist))}')
        loss_hist = []
    # print(jax.tree.leaves(jax.tree.map(lambda x, y: y - x, model, new_model)))
    # break   

Loss: 3.4923510551452637
Loss: 3.362978458404541
Loss: 2.546983480453491
Loss: 2.3057403564453125
Loss: 2.2646918296813965
Loss: 2.2025532722473145
Loss: 2.0822091102600098
Loss: 1.9879021644592285
Loss: 1.9218788146972656
Loss: 1.871410608291626
Loss: 1.8016164302825928
Loss: 1.7356141805648804
Loss: 1.6483252048492432
Loss: 1.5167373418807983
Loss: 1.4584559202194214
Loss: 1.3738263845443726
Loss: 1.2687947750091553
Loss: 1.1708855628967285
Loss: 1.019388198852539
Loss: 0.8592879772186279
Loss: 0.6810572147369385
Loss: 0.426394522190094
Loss: 0.25848618149757385
Loss: 0.21052886545658112
Loss: 0.19102875888347626
Loss: 0.2147325724363327
Loss: 0.15186169743537903
Loss: 0.16175296902656555
Loss: 0.1293259710073471
Loss: 0.1615724265575409
Loss: 0.1087224930524826
Loss: 0.1046803817152977
Loss: 0.09115424752235413
Loss: 0.0863632932305336
Loss: 0.10024692118167877
Loss: 0.05959255248308182
Loss: 0.056992292404174805
Loss: 0.07310356199741364
Loss: 0.07619929313659668
Loss: 0.0904267653