In [1]:
import sys
sys.path.append('/Users/mariana/Documents/research/xlstm-jax')
from xlstm_jax.models.xlstm_clean.blocks.mlstm.backend import  recurrent_step_stabilized_simple

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import jax
import jax.numpy as jnp
import numpy as np
import math

In [4]:
# Shapes for a single recurrent step: batch B, heads NH, head dimension DH.
B, NH, DH = 2, 1, 16

# Fixed-seed PRNG key split into 8 sub-keys, one per random tensor below.
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 8)

# Recurrent state: cell (matrix) state, normalizer state and stabilizer state.
c_state = jax.random.normal(keys[0], (B, NH, DH, DH))
n_state = jax.random.normal(keys[1], (B, NH, DH, 1))
m_state = jax.random.normal(keys[2], (B, NH, 1, 1))

# One time step of queries / keys / values, each (B, NH, 1, DH).
q, k, v = (jax.random.normal(keys[idx], (B, NH, 1, DH)) for idx in (3, 4, 5))

# Input / forget gate pre-activations, each (B, NH, 1, 1).
igate_preact, fgate_preact = (jax.random.normal(keys[idx], (B, NH, 1, 1)) for idx in (6, 7))

eps = 1e-6

# Echo every tensor's shape as one "name: shape" line.
print("\n".join(
    f"{name}: {tensor.shape}"
    for name, tensor in (
        ("c_state", c_state),
        ("n_state", n_state),
        ("m_state", m_state),
        ("q", q),
        ("k", k),
        ("v", v),
        ("igate_preact", igate_preact),
        ("fgate_preact", fgate_preact),
    )
))

c_state: (2, 1, 16, 16)
n_state: (2, 1, 16, 1)
m_state: (2, 1, 1, 1)
q: (2, 1, 1, 16)
k: (2, 1, 1, 16)
v: (2, 1, 1, 16)
igate_preact: (2, 1, 1, 1)
fgate_preact: (2, 1, 1, 1)


In [5]:
# One stabilized step of the xlstm-jax backend on the single-step tensors above;
# returns the hidden state and the updated (c, n, m) state triple.
hidden_state, (c_state_new, n_state_new, m_state_new) = recurrent_step_stabilized_simple(
    c_state,
    n_state,
    m_state,
    q,
    k,
    v,
    igate_preact,
    fgate_preact,
    eps,
)

In [6]:
# Sequence-level stabilized mLSTM recurrence driven by jax.lax.scan.
#   inputs  q, k, v:                    (B, NH, S, DH)
#   inputs  igate_preact, fgate_preact: (B, NH, S, 1)
#   output  hidden states:              (B, NH, S, DH)
# Each scan step handles one time slice (B, NH, 1, DH) with the same arithmetic as a
# single call f(c_state, n_state, m_state, q, k, v, igate_preact, fgate_preact, eps).

def weaving_recurrent_lsmt(q, k, v, igate_preact, fgate_preact, eps=1e-6, initial_state=None):
    """Run the stabilized mLSTM recurrence over the sequence axis with ``jax.lax.scan``.

    Args:
        q, k, v: Queries, keys and values, each of shape (B, NH, S, DH).
        igate_preact, fgate_preact: Input / forget gate pre-activations of shape
            (B, NH, S, 1); the squeezed shape (B, NH, S) is also accepted.
        eps: Added to the normalizer denominator to avoid division by zero.
        initial_state: Optional ``(c_state, n_state, m_state)`` to start from, e.g. the
            final state returned by a previous call (for chunked processing).
            Expected shapes are (B, NH, DH, DH), (B, NH, DH, 1) and (B, NH, 1, 1);
            n_state / m_state may omit the trailing singleton axis.
            Defaults to an all-zero state.

    Returns:
        ``(h, (c_state, n_state, m_state))``: ``h`` has shape (B, NH, S, DH); the
        final states have shapes (B, NH, DH, DH), (B, NH, DH, 1) and (B, NH, 1, 1).

    Raises:
        ValueError: If a gate or initial-state array does not have the expected shape.
    """
    B, NH, S, DH = q.shape

    # lax.scan requires each step to return the carry with exactly the dtype it was
    # given. The step promotes the carry with every input, so the carry dtype must
    # already be that promoted type. The float32 floor keeps float32 state for
    # low-precision inputs. Unlike a bare jnp.zeros default, it also avoids silently
    # switching to float64 state when jax_enable_x64 is on.
    dtype_sources = [q, k, v, igate_preact, fgate_preact, jnp.float32]
    if initial_state is not None:
        dtype_sources.extend(initial_state)
    state_dtype = jnp.result_type(*dtype_sources)

    def _as_state(x, shape, name):
        # Cast to the carry dtype and restore an omitted trailing singleton axis.
        x = jnp.asarray(x, dtype=state_dtype)
        if x.ndim == len(shape) - 1:
            x = x[..., None]
        if x.shape != shape:
            raise ValueError(f"{name} must have shape {shape}, got {x.shape}")
        return x

    def _gate_to_time_major(gate, name):
        # (B, NH, S[, 1]) -> (S, B, NH, 1, 1); each scan slice then broadcasts
        # against the (B, NH, 1, 1) stabilizer state.
        gate = jnp.asarray(gate)
        if gate.ndim == 3:
            gate = gate[..., None]
        if gate.shape != (B, NH, S, 1):
            raise ValueError(f"{name} must have shape {(B, NH, S, 1)}, got {gate.shape}")
        return jnp.moveaxis(gate, 2, 0)[..., None]

    if initial_state is None:
        c_state = jnp.zeros((B, NH, DH, DH), dtype=state_dtype)
        n_state = jnp.zeros((B, NH, DH, 1), dtype=state_dtype)
        m_state = jnp.zeros((B, NH, 1, 1), dtype=state_dtype)
    else:
        c_init, n_init, m_init = initial_state
        c_state = _as_state(c_init, (B, NH, DH, DH), "c_state")
        n_state = _as_state(n_init, (B, NH, DH, 1), "n_state")
        m_state = _as_state(m_init, (B, NH, 1, 1), "m_state")

    def recurrent_step(carry, xs):
        c_state, n_state, m_state = carry
        q_t, k_t, v_t, fgate_t, igate_t = xs  # q/k/v: (B, NH, DH); gates: (B, NH, 1, 1)

        q_t = q_t[:, :, None, :]  # (B, NH, 1, DH) row vector
        k_t = k_t[:, :, :, None]  # (B, NH, DH, 1) column vector
        v_t = v_t[:, :, None, :]  # (B, NH, 1, DH) row vector

        # Forget gate in log space; m tracks a running max used to keep exp() bounded.
        log_fg_act = jax.nn.log_sigmoid(fgate_t)  # (B, NH, 1, 1)
        m_state_new = jnp.maximum(log_fg_act + m_state, igate_t)  # (B, NH, 1, 1)
        fg_act = jnp.exp(log_fg_act + m_state - m_state_new)  # (B, NH, 1, 1)
        ig_act = jnp.exp(igate_t - m_state_new)  # (B, NH, 1, 1)

        k_scaled = k_t / math.sqrt(DH)

        c_state_new = fg_act * c_state + ig_act * (k_scaled @ v_t)  # (B, NH, DH, DH)
        n_state_new = fg_act * n_state + ig_act * k_scaled  # (B, NH, DH, 1)

        h_num = q_t @ c_state_new  # (B, NH, 1, DH)
        qn_dotproduct = q_t @ n_state_new  # (B, NH, 1, 1)
        # The normalizer magnitude is lower-bounded by exp(-m_state_new).
        h_denom = jnp.maximum(jnp.abs(qn_dotproduct), jnp.exp(-m_state_new)) + eps
        h = (h_num / h_denom)[:, :, 0, :]  # (B, NH, DH)

        return (c_state_new, n_state_new, m_state_new), h

    # lax.scan iterates over the leading axis, so move S to the front.
    xs = (
        jnp.moveaxis(q, 2, 0),  # (S, B, NH, DH)
        jnp.moveaxis(k, 2, 0),  # (S, B, NH, DH)
        jnp.moveaxis(v, 2, 0),  # (S, B, NH, DH)
        _gate_to_time_major(fgate_preact, "fgate_preact"),  # (S, B, NH, 1, 1)
        _gate_to_time_major(igate_preact, "igate_preact"),  # (S, B, NH, 1, 1)
    )
    final_state, h_seq = jax.lax.scan(recurrent_step, (c_state, n_state, m_state), xs)

    # (S, B, NH, DH) -> (B, NH, S, DH)
    return jnp.moveaxis(h_seq, 0, 2), final_state

In [7]:
# Sequence-level inputs for weaving_recurrent_lsmt, which scans over the S axis:
#   q, k, v:                    (B, NH, S, DH)
#   igate_preact, fgate_preact: (B, NH, S, 1)
S = 5
NH = 3
B = 2
DH = 16

# Same seed as the single-step cell, re-derived here so this cell does not rely
# on a `keys` variable left over from another cell.
seq_keys = jax.random.split(jax.random.PRNGKey(42), 8)

# Example recurrent states with the shapes carried by the scan. The call below
# does not take them: it starts from an all-zero state.
c_state = jax.random.normal(seq_keys[0], (B, NH, DH, DH))  # (B, NH, DH, DH)
n_state = jax.random.normal(seq_keys[1], (B, NH, DH, 1))   # (B, NH, DH, 1)
m_state = jax.random.normal(seq_keys[2], (B, NH, 1, 1))    # (B, NH, 1, 1)

# Full-sequence inputs; each scan step sees one (B, NH, DH) slice of q / k / v.
q = jax.random.normal(seq_keys[3], (B, NH, S, DH))
k = jax.random.normal(seq_keys[4], (B, NH, S, DH))
v = jax.random.normal(seq_keys[5], (B, NH, S, DH))
igate_preact = jax.random.normal(seq_keys[6], (B, NH, S, 1))
fgate_preact = jax.random.normal(seq_keys[7], (B, NH, S, 1))
eps = 1e-6

out, (c_state_new, n_state_new, m_state_new) = weaving_recurrent_lsmt(q, k, v, igate_preact, fgate_preact, eps)

# Guard the expected output / final-state shapes.
assert out.shape == (B, NH, S, DH)
assert c_state_new.shape == (B, NH, DH, DH)
assert n_state_new.shape == (B, NH, DH, 1)
assert m_state_new.shape == (B, NH, 1, 1)

In [8]:
# Shapes of the scan output followed by the three final recurrent states.
tuple(arr.shape for arr in (out, c_state_new, n_state_new, m_state_new))

((2, 3, 5, 16), (2, 3, 16, 16), (2, 3, 16, 1), (2, 3, 1, 1))

## Test mLSTMWeavingCell

In [10]:
sys.path.insert(0, '/Users/mariana/Documents/research/nanoTempoPFN/src')
from model.recurrent_lstm_cell import mLSTMWeavingCell, mLSTMWeavingCellConfig

# Config
B, S, D, NH = 2, 5, 8, 3
config = mLSTMWeavingCellConfig(embedding_dim=D, num_heads=NH)

# Random inputs: (B, S, NH * D), i.e. a single flat feature axis (not (B, NH, S, D)).
key = jax.random.PRNGKey(0)
q, k, v = [jax.random.normal(kk, (B, S, D*NH)) for kk in jax.random.split(key, 3)]

# Init and run
cell = mLSTMWeavingCell(config=config)
params = cell.init(jax.random.PRNGKey(1), q, k, v)
h_out, (c, n, m) = cell.apply(params, q, k, v)

print(f"Input shape: {q.shape}")
print(f"Output h: {h_out.shape}")
print(f"States c, n, m: {c.shape}, {n.shape}, {m.shape}")

# Shape checks matching the recorded run of this cell: the output keeps the input
# layout and the states come back with per-head size D.
assert h_out.shape == q.shape
assert c.shape == (B, NH, D, D)
assert n.shape == (B, NH, D, 1)
assert m.shape == (B, NH, 1, 1)

SHAPES igate_preact, fgate_preact = (2, 5, 3), (2, 5, 3)
SHAPES igate_preact, fgate_preact = (2, 5, 3), (2, 5, 3)
Input shape: (2, 5, 24)
Output h: (2, 5, 24)
States c, n, m: (2, 3, 8, 8), (2, 3, 8, 1), (2, 3, 1, 1)


In [22]:
from xlstm_jax.models.xlstm_clean.components.linear_headwise import LinearHeadwiseExpand, LinearHeadwiseExpandConfig
from xlstm_jax.models.xlstm_clean.components.init import small_init

# Config: with num_heads == in_features each head maps a single input feature
# (the recorded params show a (32, 1, 1) kernel).
B, S = 2, 5
in_features = 32
num_heads = 32
embedding_dim = 64  # for small_init

# Random input: (B, S, in_features)
key = jax.random.PRNGKey(0)
x_mlstm = jax.random.normal(key, (B, S, in_features))

# Create and run
v_proj = LinearHeadwiseExpand(
    config=LinearHeadwiseExpandConfig(
        in_features=in_features,
        num_heads=num_heads,
        bias=True,
        dtype="float32",
    ),
    kernel_init=small_init(embedding_dim),
    name="v_proj",
)

params = v_proj.init(jax.random.PRNGKey(1), x_mlstm)
v = v_proj.apply(params, x_mlstm)

print(f"Input shape: {x_mlstm.shape}")
print(f"Output shape: {v.shape}")
print(f"Params: {jax.tree_util.tree_map(lambda x: x.shape, params)}")

# For this configuration the projection preserves the (B, S, in_features) shape,
# as in the recorded run.
assert v.shape == (B, S, in_features)

Input shape: (2, 5, 32)
Output shape: (2, 5, 32)
Params: {'params': {'bias': (32,), 'kernel': (32, 1, 1)}}
