In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

from models.models import SupervisedModel
from tasks import decode_sequence, gen_train_sequence
from training import apply_grads, supervised_loss_and_grads, TrainState
from utils import tree_replace

In [2]:
VOCAB_SIZE = 39
SEQ_LEN = 7

learning_rate = 3e-4
tbptt_window = 40

vocab_size = VOCAB_SIZE
output_dim = VOCAB_SIZE
embedding_dim = 64
layer_sizes = [512, 512, 512, 512]
recurrent_layer_indices = [1, 2]

In [3]:
rng = jax.random.PRNGKey(0)
model_key, train_key, rng = jax.random.split(rng, 3)

model = SupervisedModel(
    rng = model_key,
    vocab_size = vocab_size,
    embedding_dim = embedding_dim,
    layer_sizes = layer_sizes,
    output_dim = output_dim,
    recurrent_layer_indices = recurrent_layer_indices,
)

optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(model)

In [4]:
n = 20000
single_seq_len = (SEQ_LEN - 1) // 2
input_ids = jax.random.randint(rng, (n, single_seq_len), 1, vocab_size)
dividers = jnp.zeros((n, 1))
input_ids = jnp.concatenate([input_ids, dividers, input_ids], axis=1, dtype=int)
sequences = {
    'input_ids': input_ids[:, :-1],
    'target_ids': input_ids[:, 1:],
    'loss_mask': jnp.concat([
        jnp.zeros((n, single_seq_len - 1)), jnp.ones((n, single_seq_len + 1))
    ], axis=1).astype(int),
}
print(jax.tree.map(lambda x: x.shape, sequences))
print(jax.tree.map(lambda x: x[0], sequences))

# n = 20000
# gen_train_sequences = jax.vmap(gen_train_sequence, in_axes=(0, None, None, None, None, None))
# sequences = gen_train_sequences(
#     jax.random.split(jax.random.PRNGKey(0), n),
#     2, 2, 3, 26, 10,
# )

# print(jax.tree.map(lambda x: x.shape, sequences))
# print(jax.tree.map(lambda x: x[0], sequences))

{'input_ids': (20000, 6), 'loss_mask': (20000, 6), 'target_ids': (20000, 6)}
{'input_ids': Array([ 8, 32,  6,  0,  8, 32], dtype=int32), 'loss_mask': Array([0, 0, 1, 1, 1, 1], dtype=int32), 'target_ids': Array([32,  6,  0,  8, 32,  6], dtype=int32)}


In [5]:
def loss_fn(model: eqx.Module, rnn_state, sequence):
    input_tokens = sequence['input_ids']
    target_tokens = sequence['target_ids']
    loss_mask = sequence['loss_mask']

    rnn_state, logits = model.forward_sequence(rnn_state, input_tokens)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, target_tokens)
    loss = loss * loss_mask
    return loss.sum() / loss_mask.sum()

value_grad_fn = eqx.filter_value_and_grad(loss_fn)
value_grad_fn = eqx.filter_jit(value_grad_fn)

In [6]:
# Count model params
import math
print(sum(jax.tree.leaves(jax.tree.map(lambda x: math.prod(x.shape), model))))

4516839


In [7]:
grad_sum = jax.tree.map(lambda x: jnp.zeros_like(x), model)
loss_hist = []

for i in range(sequences['input_ids'].shape[0]):
    sequence = jax.tree.map(lambda x: x[i], sequences)
    
    # Calculate loss
    rnn_state = model.init_rnn_state()
    loss, grads = value_grad_fn(model, rnn_state, sequence)
    loss_hist.append(loss)
    grad_sum = jax.tree.map(lambda x, y: x + y, grad_sum, grads)

    if i % 32 == 0:
        grad_sum = jax.tree.map(lambda x: x / 32, grad_sum)
        updates, opt_state = optimizer.update(grad_sum, opt_state, model)
        model = eqx.apply_updates(model, updates)
        grad_sum = jax.tree.map(lambda x: jnp.zeros_like(x), model)

    if i % 200 == 0:
        print(f'Loss: {jnp.mean(jnp.array(loss_hist))}')
        loss_hist = []
    # print(jax.tree.leaves(jax.tree.map(lambda x, y: y - x, model, new_model)))
    # break   

Loss: 3.6674246788024902
Loss: 3.6498289108276367
Loss: 3.5838427543640137
Loss: 3.4177143573760986
Loss: 3.25270938873291
Loss: 2.9945335388183594
Loss: 2.7831203937530518
Loss: 2.7485907077789307
Loss: 2.7455132007598877
Loss: 2.740182399749756
Loss: 2.7428295612335205
Loss: 2.740185260772705
Loss: 2.740895986557007
Loss: 2.7371928691864014
Loss: 2.734788179397583
Loss: 2.740821361541748
Loss: 2.7331976890563965
Loss: 2.7367303371429443
Loss: 2.7345945835113525
Loss: 2.73360276222229
Loss: 2.734902858734131
Loss: 2.7310333251953125
Loss: 2.7313005924224854
Loss: 2.724963665008545
Loss: 2.730478525161743
Loss: 2.720276355743408
Loss: 2.707827091217041
Loss: 2.6909213066101074
Loss: 2.666843891143799
Loss: 2.6484904289245605
Loss: 2.6150712966918945
Loss: 2.581392765045166
Loss: 2.4990134239196777
Loss: 2.4421441555023193
Loss: 2.356104612350464
Loss: 2.3025190830230713
Loss: 2.2855443954467773
Loss: 2.2544915676116943
Loss: 2.232795238494873
Loss: 2.174999713897705
Loss: 2.15241432189

In [8]:
print(model.forward_sequence(model.init_rnn_state(), jnp.array([8, 9, 8, 0, 8, 9]))[1].argmax(axis=1))

[0 0 0 8 9 5]
