In [3]:
import jax
import jax.numpy as jnp
from flax.training import train_state
import optax
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from functools import partial

from constants import *
from data_gen import generate_rev_trace, generate_fixed_batch, generate_rev_trace_old
from models import StackRNN


def create_train_state(model, key, learning_rate, dummy_input):
    params = model.init(key, dummy_input)['params']
    tx = optax.chain(
        optax.adam(learning_rate)
    )
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


def masked_loss(logits, targets, mask):
    """Masked softmax cross-entropy loss."""
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
    masked_loss = loss * mask
    return masked_loss.sum() / jnp.maximum(mask.sum(), 1e-9)

def train_step(state, batch):
    inputs, targets, mask = batch
    
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, x=inputs)
        loss = masked_loss(logits, targets, mask)
        print(jnp.argmax(logits, -1))  
        acc = ((jnp.argmax(logits, -1) == targets) * mask).sum() / jnp.maximum(mask.sum(), 1e-9)
        return loss, acc
        
    (loss, acc), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss, acc

In [4]:
generate_rev_trace(1,6)

(Array([[2, 2, 2, 2, 1, 2, 3, 2, 1, 2, 2, 2, 2, 0]], dtype=int32),
 Array([[2, 2, 2, 1, 2, 3, 2, 1, 2, 2, 2, 2, 4, 0]], dtype=int32),
 Array([[0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.]], dtype=float32))

In [5]:
generate_rev_trace_old(1,6)

(Array([[1, 2, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32),
 Array([[1, 2, 1, 2, 3, 3, 3, 3, 0, 0, 0, 0, 0]], dtype=int32),
 Array([[0, 0, 0, 0, 0, 2, 1, 2, 1, 0, 0, 0, 0]], dtype=int32),
 Array([[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0]], dtype=int32))

In [3]:

model = StackRNN()
TRAINING_SEQ_LEN = SEQ_LENGTH
key = jax.random.PRNGKey(42)
dummy_input = jnp.zeros((BATCH_SIZE, 2 * TRAINING_SEQ_LEN + 2), dtype=jnp.int32)
state = create_train_state(model, key, LEARNING_RATE, dummy_input)
fixed_batch = generate_rev_trace(1, 6)
print(fixed_batch)
for step in range(3001):
    # Always use the same batch
    state, loss, acc = train_step(state, fixed_batch)
    
    if step % 100 == 0: # Print more frequently
        print(f"Step {step} | Train Loss: {loss:.4f} | Train Acc: {acc:.2%}")

(Array([[2, 1, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32), Array([[1, 3, 1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32), Array([[0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32))
[[3 4 2 2 2 4 4 4 2 2 2 2 4 2]]
Step 0 | Train Loss: 1.4706 | Train Acc: 33.33%
[[3 4 2 2 2 4 4 4 2 2 2 2 4 2]]
[[3 4 2 2 2 4 4 4 2 2 2 2 4 2]]
[[3 4 2 2 4 4 4 4 2 2 2 2 4 2]]
[[3 4 2 2 4 4 4 4 2 2 2 2 4 2]]
[[3 4 2 2 4 4 4 4 2 2 2 2 4 2]]
[[3 4 2 2 4 4 4 4 2 2 2 2 4 2]]
[[3 4 2 2 4 4 4 4 2 2 2 2 4 2]]
[[4 4 2 2 4 4 4 4 2 2 2 2 4 2]]
[[4 1 1 2 4 4 4 4 2 2 2 2 4 2]]
[[4 1 1 2 4 4 4 4 1 2 2 2 4 2]]
[[4 1 1 2 4 4 4 4 1 2 2 2 4 2]]
[[4 1 1 2 4 4 4 4 1 2 2 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 2 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 2 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 1 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 1 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 1 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 1 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 1 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 1 2 4 4]]
[[4 1 1 2 4 4 4 4 1 2 1 4 4 4]]
[[4 1 1 2 4 4 4 4 1 2 1 4 4 4]]
[

KeyboardInterrupt: 