In [2]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np
import matplotlib.pyplot as plt

BATCH_SIZE = 64
SEQ_LENGTH = 60       
HIDDEN_DIM = 64
STACK_DEPTH = SEQ_LENGTH
LEARNING_RATE = 1e-3
STEPS = 2000

# input
VOCAB_PAD, VOCAB_0, VOCAB_1, VOCAB_EQ = 0, 1, 2, 3
VOCAB_SIZE = 4

# stack
STACK_NULL, STACK_0, STACK_1 = 0, 1, 2
STACK_VOCAB_SIZE = 3

# memory actions
ACT_NOOP, ACT_PUSH_0, ACT_PUSH_1, ACT_POP = 0, 1, 2, 3
NUM_MEM_ACTIONS = 4

# buffer
OUT_NOOP, OUT_EMIT_0, OUT_EMIT_1 = 0, 1, 2
NUM_BUF_ACTIONS = 3

#controller states
STATE_READ = 0
STATE_WRITE = 1
NUM_STATES = 2

In [3]:
print("JAX Backend:", jax.default_backend())
print("Devices:", jax.devices())

JAX Backend: gpu
Devices: [CudaDevice(id=0)]


---
## data gen


In [4]:
def generate_rev_trace(key, batch_size, seq_length=SEQ_LENGTH):
    # seeding numpy RNG from JAX RNG
    seed = int(jax.random.randint(key, (), 0, 2**30))
    rng = np.random.default_rng(seed)
    
    lengths = rng.integers(1, seq_length + 1, size=batch_size)
    seq_len = 2 * seq_length + 1
    
    inputs = np.full((batch_size, seq_len), VOCAB_PAD, dtype=np.int32)
    tgt_act = np.full((batch_size, seq_len), ACT_NOOP, dtype=np.int32)
    tgt_buf = np.full((batch_size, seq_len), OUT_NOOP, dtype=np.int32)
    tgt_state = np.full((batch_size, seq_len), STATE_READ, dtype=np.int32)
    
    for i in range(batch_size):
        L = lengths[i]
        bits = rng.integers(1, 3, size=L) 
        
        #READ
        inputs[i, :L] = bits
        inputs[i, L] = VOCAB_EQ
        tgt_act[i, :L] = bits 
        tgt_state[i, :L] = STATE_READ
        tgt_state[i, L] = STATE_WRITE # Switch at =
        
        #WRITE
        pop_start, pop_end = L, L + L
        tgt_act[i, pop_start:pop_end] = ACT_POP
        tgt_state[i, pop_start:pop_end] = STATE_WRITE
        
        
        reversed_bits = bits[::-1]
        tgt_buf[i, pop_start+1 : pop_end+1] = reversed_bits
        
            
    return jnp.array(inputs), jnp.array(tgt_act), jnp.array(tgt_buf), jnp.array(tgt_state)

---
## Models

model 1: Hard stack RNN

In [5]:
def hard_update_stack(stack, ptr, action):
    # READ
    pop_ptr = jnp.maximum(0, ptr - 1)
    popped_val = stack[pop_ptr]
    
    # WRITE
    push_val = action 
    is_push = (action == ACT_PUSH_0) | (action == ACT_PUSH_1)
    is_pop = (action == ACT_POP)
    
    # UPDATE
    new_stack_push = stack.at[ptr].set(push_val)
    new_ptr_push = ptr + 1
    new_stack_pop = stack.at[pop_ptr].set(STACK_NULL)
    new_ptr_pop = pop_ptr
    
    stack = jnp.where(is_push, new_stack_push, stack)
    stack = jnp.where(is_pop, new_stack_pop, stack)
    ptr = jnp.where(is_push, new_ptr_push, ptr)
    ptr = jnp.where(is_pop, new_ptr_pop, ptr)
    
    r_t = jnp.where(is_pop, popped_val, STACK_NULL)
    return stack, ptr, r_t



class HardStackMachine(nn.Module):
    @nn.compact
    def __call__(self, x, true_act, true_s, use_forcing):
        batch_size, seq_len = x.shape
        
        # Init Hard State
        carry = (
            jnp.zeros((batch_size, STACK_DEPTH), dtype=jnp.int32), # Stack
            jnp.zeros((batch_size,), dtype=jnp.int32),             # Ptr
            jnp.zeros((batch_size,), dtype=jnp.int32),             # Reg
            jnp.zeros((batch_size,), dtype=jnp.int32)              # State
        )
        
        def cell(carry, inputs):
            stack, ptr, r_prev, s_prev = carry
            x_t, t_act, t_s, forcing = inputs
            
            # Embed
            flat = jnp.concatenate([
                nn.Embed(VOCAB_SIZE, HIDDEN_DIM)(x_t),
                jax.nn.one_hot(s_prev, NUM_STATES),
                jax.nn.one_hot(r_prev, STACK_VOCAB_SIZE)
            ], axis=-1)
            
            # Heads
            l_mem = nn.Dense(NUM_MEM_ACTIONS)(flat)
            l_buf = nn.Dense(NUM_BUF_ACTIONS)(flat)
            l_state = nn.Dense(NUM_STATES)(flat)
            
            pred_act = jnp.argmax(l_mem, axis=-1)
            pred_state = jnp.argmax(l_state, axis=-1)
            
            # Forcing
            action_to_exec = jnp.where(forcing > 0, t_act, pred_act)
            next_s = jnp.where(forcing > 0, t_s, pred_state)
            
            # Update
            stack, ptr, r_new = jax.vmap(hard_update_stack)(stack, ptr, action_to_exec)
            return (stack, ptr, r_new, next_s), (l_mem, l_buf, l_state)
        
        
        scan_in = (x, true_act, true_s, jnp.full((batch_size, seq_len), use_forcing))
        scan_layer = nn.scan(
            cell, 
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1, 
            out_axes=1)
        
        final_carry, out = scan_layer(carry, scan_in)
        return out

---
model 2: soft stack RNN

In [6]:
def soft_update_stack(stack, ptr_dist, action_probs):
    
    p_noop, p_p0, p_p1, p_pop = action_probs[0], action_probs[1], action_probs[2], action_probs[3]
    total_push = p_p0 + p_p1
    
    # 1. READ
    pop_ptr_dist = jnp.roll(ptr_dist, -1).at[-1].set(0.0)
    pop_ptr_dist = pop_ptr_dist.at[0].add(ptr_dist[0]) # Clamp 0
    read_vec = jnp.sum(stack * pop_ptr_dist[:, None], axis=0)
    
    # 2. WRITE
    eps = 1e-9
    val_p0, val_p1 = p_p0/(total_push+eps), p_p1/(total_push+eps)
    write_vec = jnp.array([0., 1., 0.]) * val_p0 + jnp.array([0., 0., 1.]) * val_p1
    
    # Gates
    write_gate = ptr_dist[:, None] * total_push
    pop_gate = pop_ptr_dist[:, None] * p_pop
    
    # Update Stack
    stack = stack * (1. - write_gate) + write_vec[None, :] * write_gate
    stack = stack * (1. - pop_gate) + jnp.array([1., 0., 0.])[None, :] * pop_gate
    
    # 3. MOVE POINTER
    push_ptr_dist = jnp.roll(ptr_dist, 1).at[0].set(0.0)
    new_ptr_dist = (p_noop * ptr_dist) + (total_push * push_ptr_dist) + (p_pop * pop_ptr_dist)
    
    r_t = read_vec * p_pop
    return stack, new_ptr_dist, r_t


class SoftStackMachine(nn.Module):
    @nn.compact
    def __call__(self, x, true_act, true_s, use_forcing):
        batch_size, seq_len = x.shape
        
        # Init Soft State
        carry = (
            jnp.zeros((batch_size, STACK_DEPTH, STACK_VOCAB_SIZE)), # Stack
            jnp.zeros((batch_size, STACK_DEPTH)).at[:,0].set(1.0),  # Ptr Dist
            jnp.zeros((batch_size, STACK_VOCAB_SIZE)),              # Reg Vec
            jnp.zeros((batch_size,), dtype=jnp.int32)               # State (Discrete)
        )
        
        def cell(carry, inputs):
            stack, ptr, r_prev, s_prev = carry
            x_t, t_act, t_s, forcing = inputs
            
            # Embed (Note: r_prev is vector now)
            flat = jnp.concatenate([
                nn.Embed(VOCAB_SIZE, HIDDEN_DIM)(x_t),
                jax.nn.one_hot(s_prev, NUM_STATES),
                nn.Dense(HIDDEN_DIM)(r_prev)
            ], axis=-1)
            
            l_mem = nn.Dense(NUM_MEM_ACTIONS)(flat)
            l_buf = nn.Dense(NUM_BUF_ACTIONS)(flat)
            l_state = nn.Dense(NUM_STATES)(flat)
            
            # Soft Action Mixing
            probs = nn.softmax(l_mem)
            t_onehot = jax.nn.one_hot(t_act, NUM_MEM_ACTIONS)
            # Interpolate between True One-Hot and Predicted Soft Probs
            # forcing shape needs strictly (Batch, 1) here
            f_gate = forcing[:, None]
            mixed_act = (f_gate * t_onehot) + ((1.0 - f_gate) * probs)
            
            nxt_s = jnp.where(forcing > 0, t_s, jnp.argmax(l_state, -1))
            
            # Soft Update
            stack, ptr, r_new = jax.vmap(soft_update_stack)(stack, ptr, mixed_act)
            return (stack, ptr, r_new, nxt_s), (l_mem, l_buf, l_state)

        # forcing needs to be floats for interpolation
        scan_in = (x, true_act, true_s, jnp.full((batch_size, seq_len), use_forcing, dtype=jnp.float32))
        _, out = nn.scan(cell, variable_broadcast="params", split_rngs={"params": False}, in_axes=1, out_axes=1)(carry, scan_in)
        return out

---
model 3: vanilla RNN

In [7]:
class VanillaRNN(nn.Module):
    @nn.compact
    def __call__(self, x, true_act, true_s, use_forcing):
        # Ignores stack inputs
        batch_size, seq_len = x.shape
        
        lstm = nn.LSTMCell()
        carry = (jnp.zeros((batch_size, HIDDEN_DIM)), jnp.zeros((batch_size, HIDDEN_DIM)))
        
        def cell(carry, x_t):
            new_c, new_h = lstm(carry, nn.Embed(VOCAB_SIZE, HIDDEN_DIM)(x_t))
            return (new_c, new_h), new_h

        _, hidden = nn.scan(cell, variable_broadcast="params", split_rngs={"params": False}, in_axes=1, out_axes=1)(carry, x)
        return nn.Dense(NUM_BUF_ACTIONS)(hidden)

---
model 4: reinforcement learning style?

In [8]:
class ReinforceMachine(nn.Module):
    @nn.compact
    def __call__(self, x, key):
        batch_size, seq_len = x.shape
        # Init Hard State + Keys
        carry = (
            jnp.zeros((batch_size, STACK_DEPTH), dtype=jnp.int32),
            jnp.zeros((batch_size,), dtype=jnp.int32),
            jnp.zeros((batch_size,), dtype=jnp.int32),
            jnp.zeros((batch_size,), dtype=jnp.int32),
            jax.random.split(key, batch_size) # Keys for sampling
        )
        
        def cell(carry, x_t):
            stack, ptr, r_prev, s_prev, rng = carry
            
            flat = jnp.concatenate([
                nn.Embed(VOCAB_SIZE, HIDDEN_DIM)(x_t),
                jax.nn.one_hot(s_prev, NUM_STATES),
                jax.nn.one_hot(r_prev, STACK_VOCAB_SIZE)
            ], axis=-1)
            
            l_mem, l_buf, l_state = nn.Dense(NUM_MEM_ACTIONS)(flat), nn.Dense(NUM_BUF_ACTIONS)(flat), nn.Dense(NUM_STATES)(flat)
            
            # SAMPLING
            rng, k1, k2 = jax.random.split(rng, 3)
            samp_act = jax.random.categorical(k1, l_mem)
            samp_s = jax.random.categorical(k2, l_state)
            
            # Log Probs
            lp_mem = jax.nn.log_softmax(l_mem)[samp_act]
            lp_s = jax.nn.log_softmax(l_state)[samp_s]
            
            # Update (Self-Driven)
            stack, ptr, r_new = jax.vmap(hard_update_stack)(stack, ptr, samp_act)
            return (stack, ptr, r_new, samp_s, rng), (l_buf, lp_mem, lp_s)

        _, out = nn.scan(cell, variable_broadcast="params", split_rngs={"params": False}, in_axes=1, out_axes=1)(carry, x)
        return out

---
## Training

In [13]:
def create_train_state(model, key):
    dummy_x = jnp.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=jnp.int32)
    # Init with dummies
    params = model.init(key, dummy_x, dummy_x, dummy_x, use_forcing=True)
    tx = optax.adam(LEARNING_RATE)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jax.jit
def train_step(state, batch):
    inputs, tgt_mem, tgt_buf, tgt_state = batch
    
    def loss_fn(params):
        # Run model (Teaching Forcing = True)
        out = state.apply_fn(params, inputs, tgt_mem, tgt_state, True)
        
        # Unpack outputs based on model type
        if isinstance(out, tuple): 
            l_mem, l_buf, l_state = out
            loss_m = optax.softmax_cross_entropy_with_integer_labels(l_mem, tgt_mem).mean()
            loss_s = optax.softmax_cross_entropy_with_integer_labels(l_state, tgt_state).mean()
        else: 
            l_buf = out
            loss_m, loss_s = 0., 0.
            
        loss_b = optax.softmax_cross_entropy_with_integer_labels(l_buf, tgt_buf).mean()
        
        # Metrics
        acc = jnp.mean(jnp.argmax(l_buf, -1) == tgt_buf)
        total_loss = loss_b + loss_m + loss_s
        return total_loss, acc

    (loss, acc), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    return state.apply_gradients(grads=grads), loss, acc

def run_experiment():
    key = jax.random.PRNGKey(42)
    models = {
        "Hard Stack": HardStackMachine(),
        "Soft Stack": SoftStackMachine(),
        "Vanilla RNN": VanillaRNN()
    }
    
    results = {}
    
    for name, model in models.items():
        print(f"\n--- Training {name} ---")
        key, init_k = jax.random.split(key)
        state = create_train_state(model, init_k)
        
        history = []
        for i in range(STEPS):
            key, batch_k = jax.random.split(key)
            batch = generate_rev_trace(batch_k, BATCH_SIZE)
            
            state, loss, acc = train_step(state, batch)
            
            if i % 100 == 0:
                print(f"Step {i:04d} | Acc: {acc:.2%}")
            history.append(acc)
        results[name] = history
        
    # Comparison Plot
    plt.figure(figsize=(10, 6))
    for name, hist in results.items():
        plt.plot(hist, label=name)
    plt.xlabel("Steps")
    plt.ylabel("Accuracy")
    plt.title("Stack vs Vanilla Performance")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

run_experiment()


--- Training Hard Stack ---


AttributeError: 'tuple' object has no attribute '_state'

In [10]:
run_comparison()


--- Training Hard Stack ---


AttributeError: 'tuple' object has no attribute '_state'