In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import math
import flax.linen as nn
from dataclasses import dataclass, field

In [3]:
import sys
from pathlib import Path

# Make the repository root importable so the `src.*` packages resolve.
# The root is derived from the working directory rather than a hard-coded,
# machine-specific absolute path. This assumes the notebook runs from a direct
# subdirectory of the repo root -- the same assumption behind the relative
# '../data' and '../conf/model.yaml' paths used further down.
REPO_ROOT = Path("..").resolve()
if str(REPO_ROOT) not in sys.path:  # keep re-runs of this cell idempotent
    sys.path.append(str(REPO_ROOT))

from src.data.containers import ShardedDataset, DataLoader
from src.data.time_features import compute_batch_time_features

from src.tsf import xLSTMTSModel, ModelConfig


In [3]:
# Read the sharded dataset from ../data and pull a single batch to inspect shapes.
dataset = ShardedDataset(root_path=Path('..', 'data'))
loader = DataLoader(
    dataset,
    batch_size=8,
    future_length=64,
)
batch = next(loader)

print(f"history: {batch.history.shape}, future: {batch.future.shape}")


history: (8, 1984, 1, 1), future: (8, 64, 1, 1)


In [4]:
# Time features for the history / future windows.
# NOTE(review): the real computation below is disabled and random placeholders
# are used instead; the reason is not recorded here. Neither array is used by
# any later cell shown in this notebook (the model smoke test draws its own
# t_hist / t_future), so this cell only previews the expected shapes.
# history_tf, future_tf = compute_batch_time_features(
#     start=batch.start,
#     history_length=batch.history_length,
#     future_length=batch.future_length,
#     batch_size=batch.batch_size,
#     frequency=batch.frequency,
# )
# print(f"history_tf: {history_tf.shape}, future_tf: {future_tf.shape}")

N_TIME_FEATURES = 6  # width of the placeholder time-feature vectors

# Derive (batch, length) from the loaded batch instead of hard-coding 8/1984/64,
# so the placeholders stay consistent if the loader settings change.
history_tf = jax.random.normal(
    jax.random.PRNGKey(43),
    (*batch.history.shape[:2], N_TIME_FEATURES),
)
future_tf = jax.random.normal(
    jax.random.PRNGKey(44),
    (*batch.future.shape[:2], N_TIME_FEATURES),
)
print(f"history_tf: {history_tf.shape}, future_tf: {future_tf.shape}")



arr1 shape: (8, 1984, 6), arr2 shape: (8, 64, 6)


In [5]:
# New config system: shared values given at the top level (e.g. num_heads)
# propagate automatically into the nested sub-configs.
from src.config import make_config

cfg = make_config(num_heads=2, head_embedding_dim=8, n_layers=4)

# All nested configs get the same num_heads:
print(f"top-level:    num_heads={cfg.num_heads}")
print(f"block:        num_heads={cfg.weaving_block.num_heads}")
print(f"layer:        num_heads={cfg.weaving_block.layer.num_heads}")
print(f"cell:         num_heads={cfg.weaving_block.layer.mlstm_cell.num_heads}")

# Turn the demonstration into a check so a propagation regression fails loudly
# (generator expression: no loop variable leaks into the notebook namespace).
assert all(
    sub.num_heads == cfg.num_heads
    for sub in (
        cfg.weaving_block,
        cfg.weaving_block.layer,
        cfg.weaving_block.layer.mlstm_cell,
    )
), "num_heads was not propagated to every nested config"

top-level:    num_heads=2
block:        num_heads=2
layer:        num_heads=2
cell:         num_heads=2


In [7]:
# Convert the OmegaConf config into the model's dataclass configs.
# The block/model configs are imported from `src.tsf` -- the same module that
# `xLSTMTSModel` and `ModelConfig` come from at the top of the notebook.
# If top-level `tsf` and `src.tsf` are the same file (as the names suggest),
# importing it as `tsf` creates a second module object whose classes are
# distinct from the `src.tsf` ones, and its `ModelConfig` would silently
# shadow the earlier import.
# NOTE(review): `model.*` is likewise imported as a top-level package here;
# confirm whether it should be `src.model.*` to match the rest of the notebook.
from model.recurrent_lstm_cell import mLSTMWeavingCellConfig
from model.recurrent_lstm_layer import mLSTMWeavingLayerConfig
from src.tsf import WeavingBlockLSTMConfig, ModelConfig


def cfg_to_model_config(cfg):
    """Convert an OmegaConf model config into a ``ModelConfig`` dataclass.

    Builds the config tree bottom-up (cell -> layer -> block -> model) using
    only the *top-level* fields of ``cfg``; the nested ``cfg.weaving_block``
    sub-configs are not read.

    Args:
        cfg: OmegaConf config exposing ``head_embedding_dim``, ``num_heads``,
            ``n_layers``, ``dropout``, ``dtype``, ``input_dim``,
            ``embedding_dim`` and ``output_dim``.

    Returns:
        ModelConfig: the assembled model configuration.
    """
    # Inner (all-heads) width: head dim x number of heads (DH * NH).
    inner_dim = cfg.head_embedding_dim * cfg.num_heads

    cell = mLSTMWeavingCellConfig(
        embedding_dim=inner_dim,
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
    )
    layer = mLSTMWeavingLayerConfig(
        embedding_dim=cfg.head_embedding_dim,
        num_heads=cfg.num_heads,
        dropout=cfg.dropout,
        dtype=cfg.dtype,
        mlstm_cell=cell,
    )
    block = WeavingBlockLSTMConfig(
        n_layers=cfg.n_layers,
        embedding_dim=cfg.head_embedding_dim,
        num_heads=cfg.num_heads,
        weaving_layer_config=layer,
    )
    return ModelConfig(
        input_dim=cfg.input_dim,
        embedding_dim=cfg.embedding_dim,
        head_embedding_dim=cfg.head_embedding_dim,
        num_heads=cfg.num_heads,
        n_layers=cfg.n_layers,
        output_dim=cfg.output_dim,
        weaving_block_config=block,
    )

model_config = cfg_to_model_config(cfg)
model_config

ModelConfig(input_dim=1, embedding_dim=32, head_embedding_dim=8, num_heads=2, n_layers=4, output_dim=9, weaving_block_config=WeavingBlockLSTMConfig(n_layers=4, embedding_dim=8, num_heads=2, weaving_layer_config=mLSTMWeavingLayerConfig(conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=2, embedding_dim=8, bias=False, dropout=0.1, dtype='bfloat16', _num_blocks=1, _inner_embedding_dim=16, mlstm_cell=mLSTMWeavingCellConfig(embedding_dim=16, num_heads=2, dtype='bfloat16'))))

In [8]:
# Or load the config from YAML; the same num_heads propagation applies.
from src.config import load_config
cfg_from_yaml = load_config('../conf/model.yaml')
print(f"From YAML: num_heads={cfg_from_yaml.num_heads}, all nested={cfg_from_yaml.weaving_block.layer.mlstm_cell.num_heads}")
# Fail loudly if the YAML path does not propagate num_heads down to the cell.
assert cfg_from_yaml.weaving_block.layer.mlstm_cell.num_heads == cfg_from_yaml.num_heads, (
    "num_heads was not propagated to the mLSTM cell config"
)

From YAML: num_heads=2, all nested=2


In [12]:
# Smoke test: build the model from the YAML config and run one forward pass on
# random inputs sized like the loaded batch.
model_config = cfg_to_model_config(cfg_from_yaml)
model = xLSTMTSModel(config=model_config)

n_channels = 2
input_dim = model_config.input_dim  # keep x consistent with the configured model
B = batch.batch_size
S_h, S_f = batch.history_length, batch.future_length

# One independent PRNG key per random draw: reusing a key (e.g. the same key
# for an input tensor and for parameter init) makes those values correlated.
k_x, k_hist, k_future, k_init = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (B, S_h, n_channels, input_dim))
t_hist = jax.random.normal(k_hist, (B, S_h, 1))
t_future = jax.random.normal(k_future, (B, S_f, 1))

params = model.init(k_init, x, t_hist, t_future)
preds = model.apply(params, x, t_hist, t_future)
# Predictions cover the future window for every channel.
assert preds.shape[:3] == (B, S_f, n_channels), preds.shape
preds.shape

(8, 64, 2, 9)

In [14]:
# Inspect the weaving-layer config that was actually built for the model.
# NOTE(review): the recorded output shows the nested mlstm_cell with
# num_heads=4 (embedding_dim=32) while the layer reports num_heads=2, even
# though cfg_to_model_config passes cfg.num_heads to both. Either the layer
# config rewrites its cell config after construction (TODO confirm in
# model/recurrent_lstm_layer.py), or this output is stale from an earlier
# version of the code.
model_config.weaving_block_config.weaving_layer_config

mLSTMWeavingLayerConfig(conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=2, embedding_dim=8, bias=False, dropout=0.0, dtype='bfloat16', _num_blocks=1, _inner_embedding_dim=32, mlstm_cell=mLSTMWeavingCellConfig(embedding_dim=32, num_heads=4, dtype='bfloat16'))

In [None]:
from model.recurrent_lstm_cell import mLSTMWeavingCellConfig
# mLSTMWeavingLayer is imported here explicitly because no other cell in this
# notebook defines or imports it. It is assumed to live next to its config in
# model.recurrent_lstm_layer -- confirm.
from model.recurrent_lstm_layer import mLSTMWeavingLayer, mLSTMWeavingLayerConfig

# ===============================================
# Super important: the layer's _inner_embedding_dim
# must equal DH * NH (head dim x number of heads).
# TODO: implement the weaving layer (if not already
# done in model.recurrent_lstm_layer).
# ===============================================

# Config
input_dim = 24
B, S, DH = 2, 5, 24
NH = 3
config = mLSTMWeavingLayerConfig(embedding_dim=DH, num_heads=NH)

# Test: random input plus zero-initialised recurrent state
# (one DH x DH matrix state C, normaliser n and stabiliser m per head).
x = jax.random.normal(jax.random.PRNGKey(0), (B, S, input_dim))
c_state = jnp.zeros((B, NH, DH, DH))
n_state = jnp.zeros((B, NH, DH, 1))
m_state = jnp.zeros((B, NH, 1, 1))

layer = mLSTMWeavingLayer(config=config)
params = layer.init(jax.random.PRNGKey(1), x, c_state, n_state, m_state)

y, (c, n, m) = layer.apply(params, x, c_state, n_state, m_state)

print(f"Input: {x.shape} -> Output: {y.shape}")
# The layer should preserve the sequence shape and return states shaped like its inputs.
assert y.shape == x.shape, y.shape
assert (c.shape, n.shape, m.shape) == (c_state.shape, n_state.shape, m_state.shape), (
    c.shape, n.shape, m.shape,
)

Input: (2, 5, 24) -> Output: (2, 5, 24)


In [None]:
# Shapes of the recurrent states (C, n, m) returned by the layer test in the
# cell above; expected to match the zero-initialised input states.
c.shape, n.shape, m.shape

((2, 3, 24, 24), (2, 3, 24, 1), (2, 3, 1, 1))

In [None]:
# Model-level smoke test mirroring the layer test above.
# NOTE(review): `WeavingLSTM` is not defined or imported anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. Its execution
# count is None, so the recorded output presumably comes from an earlier
# session. TODO import it from wherever it lives (possibly tsf -- confirm).
# NOTE(review): the state tensors below are sized from the local NH / DH, not
# from `config`. A default `ModelConfig()` need not have num_heads == 3 or a
# head dimension of 24. TODO derive them from the config once WeavingLSTM's
# expected state layout is confirmed.
# Config
input_dim = 24
B, S, DH = 2, 5, 24
NH = 3
config = ModelConfig()

# Test
x = jax.random.normal(jax.random.PRNGKey(0), (B, S, input_dim))
c_state = jnp.zeros((B, NH, DH, DH))
n_state = jnp.zeros((B, NH, DH, 1))
m_state = jnp.zeros((B, NH, 1, 1))

model = WeavingLSTM(config=config)
params = model.init(jax.random.PRNGKey(1), x, c_state, n_state, m_state)

y, (c, n, m) = model.apply(params, x, c_state, n_state, m_state)

print(f"Input: {x.shape} -> Output: {y.shape}")

Input: (2, 5, 24) -> Output: (2, 5, 24)
