In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

from models.models import SupervisedModel
from tasks import decode_sequence, gen_train_sequence
from training import apply_grads, supervised_loss_and_grads, TrainState
from utils import tree_replace

In [2]:
VOCAB_SIZE = 39
SEQ_LEN = 7

learning_rate = 3e-4
tbptt_window = 40

vocab_size = VOCAB_SIZE
output_dim = VOCAB_SIZE
embedding_dim = 64
layer_sizes = [1024, 1024, 1024, 1024, 1024]
recurrent_layer_indices = [1]

In [3]:
rng = jax.random.PRNGKey(0)
model_key, train_key, rng = jax.random.split(rng, 3)

model = SupervisedModel(
    rng = model_key,
    vocab_size = vocab_size,
    embedding_dim = embedding_dim,
    layer_sizes = layer_sizes,
    output_dim = output_dim,
    recurrent_layer_indices = recurrent_layer_indices,
)

optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(model)

In [4]:
n = 20000
single_seq_len = (SEQ_LEN - 1) // 2
input_ids = jax.random.randint(rng, (n, single_seq_len), 1, vocab_size)
dividers = jnp.zeros((n, 1))
input_ids = jnp.concatenate([input_ids, dividers, input_ids], axis=1, dtype=int)
sequences = {
    'input_ids': input_ids[:, :-1],
    'target_ids': input_ids[:, 1:],
    'loss_mask': jnp.concat([
        jnp.zeros((n, single_seq_len - 1)), jnp.ones((n, single_seq_len + 1))
    ], axis=1).astype(int),
}
print(jax.tree.map(lambda x: x.shape, sequences))
print(jax.tree.map(lambda x: x[0], sequences))

# n = 2000
# gen_train_sequences = jax.vmap(gen_train_sequence, in_axes=(0, None, None, None, None, None))
# sequences = gen_train_sequences(
#     jax.random.split(jax.random.PRNGKey(0), n),
#     2, 2, 3, 26, 10,
# )

# print(jax.tree.map(lambda x: x.shape, sequences))
# print(jax.tree.map(lambda x: x[0], sequences))

{'input_ids': (20000, 6), 'loss_mask': (20000, 6), 'target_ids': (20000, 6)}
{'input_ids': Array([ 8, 32,  6,  0,  8, 32], dtype=int32), 'loss_mask': Array([0, 0, 1, 1, 1, 1], dtype=int32), 'target_ids': Array([32,  6,  0,  8, 32,  6], dtype=int32)}


In [5]:
# for k in sequences.keys():
#     if k == 'loss_mask':
#         print(f'{k}: {sequences[k][0]}')
#     else:
#         print(f'{k}: {decode_sequence(sequences[k][0])}')
#     print()

In [6]:
def loss_fn(model: eqx.Module, rnn_state, sequence):
    # input_tokens = sequence[:-1]
    # target_tokens = sequence[1:]
    # loss_mask = jnp.concatenate([jnp.zeros(single_seq_len - 1), jnp.ones(single_seq_len + 1)], axis=0)
    input_tokens = sequence['input_ids']
    target_tokens = sequence['target_ids']
    loss_mask = sequence['loss_mask']

    rnn_state, logits = model.forward_sequence(rnn_state, input_tokens)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, target_tokens)
    loss = loss * loss_mask
    return loss.sum() / loss_mask.sum()

value_grad_fn = eqx.filter_value_and_grad(loss_fn)
value_grad_fn = eqx.filter_jit(value_grad_fn)

In [7]:
grad_sum = jax.tree.map(lambda x: jnp.zeros_like(x), model)

for i in range(sequences['input_ids'].shape[0]):
    sequence = jax.tree.map(lambda x: x[i], sequences)
    
    # Calculate loss
    rnn_state = model.init_rnn_state()
    loss, grads = value_grad_fn(model, rnn_state, sequence)
    grad_sum = jax.tree.map(lambda x, y: x + y, grad_sum, grads)

    if i % 32 == 0:
        grad_sum = jax.tree.map(lambda x: x / 32, grad_sum)
        updates, opt_state = optimizer.update(grad_sum, opt_state, model)
        model = eqx.apply_updates(model, updates)
        grad_sum = jax.tree.map(lambda x: jnp.zeros_like(x), model)

    if i % 200 == 0:
        print('Loss:', loss)
    # print(jax.tree.leaves(jax.tree.map(lambda x, y: y - x, model, new_model)))
    # break   

Loss: 3.656991
Loss: 3.358118
Loss: 3.4352367
Loss: 3.036385
Loss: 2.8181643
Loss: 2.7609034
Loss: 2.7429702
Loss: 2.7409058
Loss: 2.7366726
Loss: 2.7695687
Loss: 2.7389858
Loss: 2.7403665
Loss: 2.764772
Loss: 2.7696602
Loss: 2.8110623
Loss: 2.6930664
Loss: 2.7036529
Loss: 2.8424745
Loss: 2.755174
Loss: 2.7199306
Loss: 2.7833457
Loss: 2.685863
Loss: 2.5764189
Loss: 2.668539
Loss: 2.3070753
Loss: 2.2298875
Loss: 2.2746902
Loss: 2.543953
Loss: 2.50347
Loss: 2.0879667
Loss: 2.693561
Loss: 2.4434566
Loss: 2.536697
Loss: 2.094209
Loss: 2.5198364
Loss: 2.4208665
Loss: 2.5156798
Loss: 2.5500512
Loss: 1.9125876
Loss: 2.496033
Loss: 2.2372282
Loss: 2.2250853
Loss: 2.238473
Loss: 2.6655438
Loss: 2.0474038
Loss: 2.4619071
Loss: 2.5652966
Loss: 2.2556965
Loss: 2.4294877
Loss: 2.1936798
Loss: 1.8801885
Loss: 1.7161896
Loss: 2.2408798
Loss: 1.9096909
Loss: 2.2188463
Loss: 1.8410776
Loss: 1.6144189
Loss: 1.9596899
Loss: 1.8165623
Loss: 1.9093534
Loss: 2.2095146
Loss: 2.1384342
Loss: 2.0848017
Loss: 2

In [8]:
print(model.forward_sequence(model.init_rnn_state(), jnp.array([8, 9, 8, 0, 8, 9]))[1].argmax(axis=1))

[0 0 0 8 8 8]
