In [67]:
import jax
import jax.numpy as jnp
import os
import sys

# Directory of the local xlstm-jax checkout, put on sys.path so the `xlstm_jax`
# imports in this notebook resolve. Set the XLSTM_JAX_ROOT environment variable
# to point at a checkout elsewhere; the default is the original hard-coded path.
XLSTM_JAX_ROOT = os.environ.get(
    "XLSTM_JAX_ROOT", '/Users/mariana/Documents/research/xlstm-jax'
)
# Guarded so that re-running this cell does not pile up duplicate entries.
if XLSTM_JAX_ROOT not in sys.path:
    sys.path.append(XLSTM_JAX_ROOT)

# Force CPU (optional): sets the JAX config option "jax_platform_name" to "cpu".
# NOTE(review): presumably this has to run before any JAX array is created in
# the kernel to take effect — confirm against the JAX configuration docs.
jax.config.update("jax_platform_name", "cpu")

from xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple import (
    parallel_stabilized_simple,
    recurrent_step_stabilized_simple,
)

# Smoke test of the parallel mLSTM backend on small random inputs.
# Sizes, in the order used for the q/k/v shape below.
B, NH, S, DH = 1, 2, 8, 16
qkv_shape = (B, NH, S, DH)
gate_shape = (B, NH, S, 1)

# q, k and v are drawn from independent PRNG keys built from seeds 0, 1 and 2.
q, k, v = (
    jax.random.normal(jax.random.key(seed), qkv_shape) for seed in range(3)
)

# All-zero input-gate and forget-gate tensors, one value per (batch, head, step).
igate = jnp.zeros(gate_shape)
fgate = jnp.zeros(gate_shape)

# The result is expected to have the same shape as q/k/v: (B, NH, S, DH).
y = parallel_stabilized_simple(q, k, v, igate, fgate)
print(y.shape)

(1, 2, 8, 16)


In [68]:
from dataclasses import dataclass
from flax import linen as nn

from xlstm_jax.models.xlstm_clean.components.init import small_init
from xlstm_jax.models.xlstm_clean.xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig
from xlstm_jax.models.xlstm_clean.blocks.mlstm.block import xLSTMBlockConfig, mLSTMBlockConfig
from xlstm_jax.models.xlstm_clean.blocks.mlstm.cell import mLSTMBackendNameAndKwargs
from xlstm_jax.models.xlstm_clean.blocks.mlstm.layer import mLSTMCellConfig, mLSTMLayerConfig
from xlstm_jax.models.xlstm_clean.components.feedforward import FeedForwardConfig

In [69]:
@dataclass
class xLSTMTabModelConfig(xLSTMBlockStackConfig):
    """Configuration for ``xLSTMTabModel``.

    Extends ``xLSTMBlockStackConfig`` with the fields declared below.
    ``xLSTMTabModel`` also reads ``context_length``, ``dropout`` and ``_dtype``
    from its config; those are not declared here, so presumably they come from
    the parent config — confirm.
    """
    # Output width of the ``token_embedding`` Dense layer and width of each
    # positional-embedding row in ``xLSTMTabModel``.
    embedding_dim: int = 16
    # Not read by ``xLSTMTabModel`` in this file.
    # NOTE(review): purpose not visible from here — confirm whether it is used elsewhere.
    tie_weights: bool = False
    # Not read by ``xLSTMTabModel`` in this file.
    # NOTE(review): presumably consumed by optimizer setup — confirm.
    weight_decay_on_embedding: bool = False
    # When True, ``xLSTMTabModel`` applies dropout right after adding the
    # positional embedding.
    add_embedding_dropout: bool = True
    # Output width of the ``pred_head`` Dense layer in ``xLSTMTabModel``.
    output_dim: int = 1

class xLSTMTabModel(nn.Module):
    """xLSTM block stack wrapped with a dense input projection and a dense head.

    The forward pass projects the input's last axis to ``embedding_dim``, adds
    a learned positional embedding, optionally applies dropout, runs the
    ``xLSTMBlockStack`` and maps the result to ``output_dim`` values with a
    bias-free Dense layer.
    """

    config: xLSTMTabModelConfig

    @nn.compact
    def __call__(self, x: jax.Array, train: bool = False) -> jax.Array:
        """Run the model on a batch of sequences.

        Args:
            x: Input features. Axis 1 is used as the sequence axis and the last
                axis is projected by the ``token_embedding`` Dense layer.
                NOTE(review): presumably shaped (batch, seq, features) — confirm.
            train: When True, dropout is active (``deterministic=False``); the
                flag is also forwarded to the block stack.

        Returns:
            The ``pred_head`` output, whose last axis has size
            ``config.output_dim``. The head is configured with float32 as its
            computation dtype.

        Raises:
            ValueError: If ``x.shape[1]`` exceeds ``config.context_length``,
                i.e. there are more positions than positional-embedding rows.
        """
        cfg = self.config

        # Array shapes are static Python ints, so this check also works when
        # the module is traced. Without it, the slice below would silently
        # truncate and the addition would fail with an opaque broadcast error
        # (or broadcast silently when context_length == 1).
        seq_len = x.shape[1]
        if seq_len > cfg.context_length:
            raise ValueError(
                f"Sequence length {seq_len} exceeds context_length "
                f"{cfg.context_length}; no positional embedding exists for the "
                "extra positions."
            )

        x = nn.Dense(
            features=cfg.embedding_dim,
            kernel_init=small_init(cfg.embedding_dim),
            dtype=cfg._dtype,
            name="token_embedding",
        )(x)

        # One learned row per position, up to context_length positions.
        pos_emb = self.param(
            "pos_embedding",
            nn.initializers.normal(stddev=0.02),
            (cfg.context_length, cfg.embedding_dim),
        )
        # The parameter keeps the initializer's dtype, which need not match the
        # activation dtype produced by the Dense layer above. Cast before
        # adding so the sum does not promote the activations to a wider dtype.
        x = x + pos_emb[:seq_len].astype(x.dtype)

        if cfg.add_embedding_dropout:
            x = nn.Dropout(rate=cfg.dropout)(x, deterministic=not train)

        x = xLSTMBlockStack(config=cfg, name="xlstm_block_stack")(x, train=train)

        logits = nn.Dense(
            features=cfg.output_dim,
            kernel_init=small_init(cfg.embedding_dim),
            use_bias=False,
            dtype=jnp.float32,
            name="pred_head",
        )(x)
        return logits