In [1]:
from functools import partial
from typing import List

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax

from copy_task import init_ccopy_task, step_supervised_ccopy_task
from models import SupervisedLSTMModel
from training import train_on_sequence

In [2]:
# Copy-task configuration; these three values are forwarded to init_ccopy_task
# when the environment is created.
vocab_size = 10
min_seq_len, max_seq_len = 1, 10

# Number of environment steps rolled out per generated training sequence.
# The same value is also handed to the training step as its window length.
tbptt_window = 40

# Root PRNG key (seed 0); the other keys used in this notebook derive from it.
rng = jax.random.PRNGKey(0)

In [3]:
# Peel off a dedicated key for parameter initialisation; the remainder stays in rng.
rng, param_gen_key = jax.random.split(rng, 2)
model = SupervisedLSTMModel(param_gen_key, vocab_size=vocab_size)

# Adam optimiser. Its state is initialised from the array leaves of the model
# only (eqx.filter with eqx.is_array drops every non-array field).
tx = optax.adam(learning_rate=3e-4)
opt_state = tx.init(eqx.filter(model, eqx.is_array))
tx_update_fn = tx.update

# Initial recurrent state, plus an all-ones int32 vector of one window length.
# NOTE(review): x is not referenced again in the visible cells — confirm it is
# still needed.
rnn_state = model.init_rnn_state()
x = jnp.ones(shape=(tbptt_window,), dtype=jnp.int32)

In [4]:
# gen_train_sequence advances the environment through this alias.
step_env = step_supervised_ccopy_task

# Build the initial environment state from the task configuration.
# NOTE(review): rng is consumed here directly rather than being split first;
# split it again before drawing any further randomness from rng.
env_state = init_ccopy_task(rng, vocab_size, min_seq_len, max_seq_len)

@partial(jax.jit, backend='cpu')
def gen_train_sequence(env_state):
    """Roll the environment forward for one training window.

    The environment is advanced ``tbptt_window`` times with ``jax.lax.scan``.
    Each call to ``step_env`` must return ``(new_state, (input_ids,
    target_ids, loss_mask))``; scan collects the per-step outputs.

    Note that ``step_env`` and ``tbptt_window`` are read from the enclosing
    notebook namespace, so they are fixed at the time this function is traced.

    Args:
        env_state: Current environment state (the scan carry).

    Returns:
        A pair ``(env_state, sequence)`` where ``env_state`` is the state after
        the final step and ``sequence`` is a dict with the keys
        ``'input_ids'``, ``'target_ids'`` and ``'loss_mask'``.
    """

    def advance(carry, _unused):
        # scan passes a per-step input as the second argument; none is used here.
        return step_env(carry)

    env_state, collected = jax.lax.scan(advance, env_state, length=tbptt_window)
    input_ids, target_ids, loss_mask = collected

    sequence = dict(
        input_ids=input_ids,
        target_ids=target_ids,
        loss_mask=loss_mask,
    )

    # Treat any Python list in the tree as a single leaf and convert every leaf
    # to a jax array.
    sequence = jax.tree.map(
        jnp.array,
        sequence,
        is_leaf=lambda node: isinstance(node, List),
    )
    return env_state, sequence

In [5]:
# Training-loop settings. The step budget is effectively open-ended: the
# recorded run below was stopped by a keyboard interrupt long before reaching it.
num_train_steps = 1_000_000
# Steps between loss reports. The same value is the size of the averaging
# window, so the printed number is the mean loss since the previous report.
log_every = 100

# Compile the training step for the GPU backend.
train_step = eqx.filter_jit(train_on_sequence, backend='gpu')

# Start from a fresh recurrent state so that re-running this cell does not
# continue from the state left behind by a previous run.
rnn_state = model.init_rnn_state()

loss_hist = []
for i in range(num_train_steps):
    # Generate the next window of training data and advance the environment.
    env_state, sequence = gen_train_sequence(env_state)

    # One optimisation step; the recurrent state is carried across windows.
    opt_state, model, rnn_state, loss = train_step(
        model, opt_state, tx_update_fn, rnn_state, sequence, tbptt_window)
    loss_hist.append(loss)

    if i % log_every == 0:
        # Keep only the most recent window so the history stays bounded, then
        # report its mean.
        loss_hist = loss_hist[-log_every:]
        print(np.mean(loss_hist))


2.3052506
2.1723282
2.0925398
2.0864456
2.0861359
2.0773766
2.0200624
1.9224962
1.791669
1.711072
1.5965805
1.5484114
1.4695032
1.4575665


KeyboardInterrupt: 