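# Debug CLDS Model - Figure 4 Experiment
This notebook debugs training of the conditionally linear dynamical system (CLDS) model on the Figure 4 calcium-imaging data. It compares the model's training and held-out log-likelihoods against a standard LDS baseline.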

In [None]:
import os
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
import sys
sys.path.insert(0, '..')

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import matplotlib.pyplot as plt

from dynamax.linear_gaussian_ssm.models import LinearGaussianConjugateSSM, ConditionallyLinearGaussianSSM
from dynamax.utils.utils import Tm_basis, rbf_basis

## Hyperparameters

In [None]:
# Model hyperparameters
state_dim = 5
L = 7  # number of basis functions
kappa = 0.2  # lengthscale for RBF, smoothness for Fourier
sigma = 1.0  # amplitude
basis_type = 'rbf'  # 'rbf' or 'fourier'
has_dynamics_bias = True

# Data parameters
# The data directory can be overridden with the CLDS_DATA_PATH environment
# variable so the notebook can run outside the original cluster; when the
# variable is unset the original path is used unchanged.
data_path = os.environ.get('CLDS_DATA_PATH', '/home/groups/swl1/hdlee/nast/neurips_2025')
block_size = 4  # consecutive trials grouped into one block for the train/test split
standardize = True  # z-score emissions with statistics of the training trials

# Training parameters
num_iters = 50  # reduced for debugging
seed = 2626  # seeds the block-level train/test split
model_seed = 2014  # seeds parameter initialization (shared by CLDS and LDS)

## Load Data

In [None]:
# Load emissions and conditions (calcium imaging data).
# Downstream cells unpack emissions as (num_trials, sequence_length, emission_dim)
# and index conditions by trial along axis 0.
emissions_path = os.path.join(data_path, 'data_calcium_v5.npy')
conditions_path = os.path.join(data_path, 'conditions_calcium_v5.npy')

emissions = jnp.load(emissions_path).astype(jnp.float64)
conditions = jnp.load(conditions_path).astype(int)

# Fail fast on malformed inputs: the split cell truncates both arrays by the
# emissions length, which would silently misalign trials if the counts differ.
assert emissions.ndim == 3, f"expected 3-D emissions, got shape {emissions.shape}"
assert conditions.shape[0] == emissions.shape[0], (
    f"trial count mismatch: {emissions.shape[0]} emissions vs {conditions.shape[0]} conditions"
)

print(f"Emissions shape: {emissions.shape}")
print(f"Conditions shape: {conditions.shape}")
print(f"Unique conditions: {jnp.unique(conditions)}")

In [None]:
def select_non_consecutive(key, a, b, k):
    """Randomly select k pairwise non-consecutive integers from [a, b].

    Greedy sampler: repeatedly draw a value uniformly from the remaining
    candidates, then remove that value and its two neighbours.

    Args:
        key: JAX PRNG key.
        a: inclusive lower bound of the candidate range.
        b: inclusive upper bound of the candidate range.
        k: number of integers to select.

    Returns:
        Sorted integer array of length k in which no two entries differ by 1.

    Raises:
        ValueError: if k exceeds ceil(n / 2) for n = b - a + 1 candidates, or
            if the greedy draws exhaust the candidates before k values are
            chosen (possible when k is close to that maximum).
    """
    n = b - a + 1
    if k > (n + 1) // 2:
        raise ValueError("Not enough non-consecutive numbers to select.")

    available = jnp.arange(a, b + 1)
    selected = []

    for _ in range(k):
        if len(available) == 0:
            # Each pick removes up to 3 candidates, so an unlucky sequence can
            # run out even though k passed the feasibility check above; without
            # this guard the next draw would index an empty array.
            raise ValueError(
                f"Greedy selection ran out of candidates after {len(selected)} of {k} picks; "
                "try a different key or a smaller k."
            )
        key, subkey = jr.split(key)
        idx = jr.randint(subkey, (), 0, len(available))
        choice = available[idx]
        selected.append(choice)

        # Remove choice and its neighbors so no later pick is adjacent to it.
        mask = (available != choice) & (available != choice - 1) & (available != choice + 1)
        available = available[mask]

    # Explicit dtype so k == 0 yields an empty *integer* array (usable as an
    # index) instead of jnp.array([])'s floating-point default.
    return jnp.sort(jnp.array(selected, dtype=available.dtype))

In [None]:
# Split data
num_conditions = len(np.unique(conditions))
num_blocks = len(emissions) // block_size
num_trials = num_blocks * block_size
# Drop trailing trials that do not fill a complete block (idempotent on re-run).
emissions = emissions[:num_trials]
conditions = conditions[:num_trials]

# Create train/test split at block level (non-consecutive test blocks).
# Test blocks are drawn from indices [6, num_blocks - 6], so the first six and
# last five blocks always stay in training; num_blocks // 4 blocks are held out.
block_masks = jnp.ones(num_blocks, dtype=bool)
num_test_blocks = num_blocks // 4
key = jr.PRNGKey(seed)
test_idx = select_non_consecutive(key, 6, num_blocks-6, num_test_blocks)
block_masks = block_masks.at[test_idx].set(False)  # False marks a test block

# Temporal indices (block-level): each trial gets its block index scaled to [0, 1].
block_id_nums = jnp.repeat(jnp.arange(num_blocks, dtype=float), block_size)
block_id_nums = block_id_nums / (num_blocks - 1)

trial_masks = jnp.repeat(block_masks, block_size)  # True = training trial
train_conditions = conditions[trial_masks]
test_conditions = conditions[~trial_masks]

# Standardize every trial using statistics of the training trials only.
# NOTE: despite its name, `train_obs` holds ALL trials; later cells select the
# training subset with `train_obs[trial_masks]`.
if standardize:
    train_obs_ = emissions[trial_masks]
    train_obs_mean = jnp.mean(train_obs_, axis=(0, 1), keepdims=True)
    train_obs_std = jnp.std(train_obs_, axis=(0, 1), keepdims=True)
    # A channel that is constant across the training trials has zero std; leave
    # it unscaled instead of dividing by zero and feeding inf/NaN into EM.
    train_obs_std = jnp.where(train_obs_std > 0, train_obs_std, 1.0)
    train_obs = (emissions - train_obs_mean) / train_obs_std
else:
    train_obs = emissions

_, sequence_length, emission_dim = train_obs.shape
test_obs = train_obs[~trial_masks]

# Each held-out block contributes exactly block_size test trials.
assert test_obs.shape[0] == len(test_idx) * block_size

print(f"Num blocks: {num_blocks}, Block size: {block_size}")
print(f"Num train trials: {trial_masks.sum()}, Num test trials: {(~trial_masks).sum()}")
print(f"Emission dim: {emission_dim}, Sequence length: {sequence_length}")
print(f"Unique block_id_nums (train): {len(jnp.unique(block_id_nums[trial_masks]))}")
print(f"Test block indices: {test_idx}")

## Initialize Model

In [None]:
# Create basis functions. basis_type is checked explicitly: any value other than
# 'rbf' or 'fourier' (e.g. a typo) raises instead of silently using Fourier.
if basis_type == 'rbf':
    basis_funcs = rbf_basis(L, M_conditions=1, sigma=sigma, kappa=kappa)
elif basis_type == 'fourier':
    # NOTE(review): the period is tied to kappa as 1 + 6*kappa; the rationale
    # for these constants is not documented here.
    period = 1.0 + 6.0 * kappa
    basis_funcs = Tm_basis(L, M_conditions=1, sigma=sigma, kappa=kappa, period=period)
else:
    raise ValueError(f"basis_type must be 'rbf' or 'fourier', got {basis_type!r}")

print(f"Number of basis functions: {len(basis_funcs)}")

# Initialize model. num_trials is the number of training trials, counted from
# the mask rather than by materialising the masked observation array.
num_train_trials = int(trial_masks.sum())
model = ConditionallyLinearGaussianSSM(
    state_dim=state_dim,
    emission_dim=emission_dim,
    num_conditions=num_conditions,
    has_dynamics_bias=has_dynamics_bias,
    torus_basis_funcs=basis_funcs,
    num_trials=num_train_trials,
)

key = jr.PRNGKey(model_seed)
params, props = model.initialize(key=key)

print(f"\nInitial emission weights shape: {params.emissions.weights.shape}")
print(f"  Expected: (L={len(basis_funcs)}, emission_dim={emission_dim}, state_dim={state_dim})")

## Diagnostic: Check Basis Functions

In [None]:
# Print the basis vector (first 5 entries and its sum) at a few normalized times.
print("Basis function values at different time points:")
t_values = [0.0, 0.25, 0.5, 0.75, 1.0]
for t_eval in t_values:
    phi = model.wpgs_C.evaluate_basis(t_eval)
    phi_total = phi.sum()
    print(f"  t={t_eval:.2f}: {phi[:5]}... (sum={phi_total:.3f})")

# Evaluate the basis on a dense grid over [0, 1] and plot (at most) the first 7 functions.
t_range = jnp.linspace(0, 1, 100)
phi_values = jnp.array([model.wpgs_C.evaluate_basis(t_eval) for t_eval in t_range])
n_plotted = min(len(basis_funcs), 7)

fig, ax = plt.subplots(figsize=(10, 4))
for basis_idx in range(n_plotted):
    ax.plot(t_range, phi_values[:, basis_idx], label=f'Basis {basis_idx}')
ax.set_xlabel('Time (normalized)')
ax.set_ylabel('Basis function value')
ax.set_title(f'{basis_type.upper()} Basis Functions (L={L}, kappa={kappa}, sigma={sigma})')
ax.legend()
ax.grid(True)
plt.show()

## Train Model

In [None]:
# Fit the CLDS by EM on the training trials only. block_id_nums supplies each
# trial's normalized block position in [0, 1].
train_emissions = train_obs[trial_masks]
train_block_ids = block_id_nums[trial_masks]

best_params, train_lps = model.fit_em(
    params=params,
    props=props,
    emissions=train_emissions,
    conditions=train_conditions,
    block_id_nums=train_block_ids,
    num_iters=num_iters,
    use_wandb=False,
)

final_train_lp = train_lps[-1]
print(f"\nFinal train log-likelihood: {final_train_lp:.2f}")

In [None]:
# Log-likelihood returned by fit_em at each EM iteration.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(train_lps)
ax.set(xlabel='EM Iteration', ylabel='Log-likelihood', title='CLDS Training Curve')
ax.grid(True)
plt.show()

## Diagnostic: Check Learned Weights

In [None]:
# Summarize the learned emission weights: mean |w| for each slice along the
# first axis (one slice per basis function), averaged over the other two axes.
W_C = best_params.emissions.weights
print(f"Emission weights shape: {W_C.shape}")
print(f"Mean absolute weight per basis function:")
mean_abs_weights = jnp.mean(jnp.abs(W_C), axis=(1, 2))
for basis_idx in range(len(mean_abs_weights)):
    print(f"  Basis {basis_idx}: {mean_abs_weights[basis_idx]:.4f}")

In [None]:
# Frobenius norm of the emission matrix C(t) at each of the probe times.
print("\nEmission matrix norm at different times:")
weights_C = best_params.emissions.weights
C_matrices = []
for t in t_values:
    C_matrices.append(model.wpgs_C(weights_C, t))
    print(f"  t={t:.2f}: ||C|| = {jnp.linalg.norm(C_matrices[-1]):.4f}")

# Relative change between the first and last probe times in t_values.
C_0, C_1 = C_matrices[0], C_matrices[-1]
rel_change = jnp.linalg.norm(C_1 - C_0) / jnp.linalg.norm(C_0)
print(f"\n||C(1) - C(0)|| / ||C(0)|| = {rel_change:.4f}")
print(f"  (This should be > 0.1 for meaningful time variation)")

In [None]:
# ||C(t)|| on the dense time grid.
C_norms = [jnp.linalg.norm(model.wpgs_C(best_params.emissions.weights, t)) for t in t_range]

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(t_range, C_norms)
ax.set_xlabel('Time (normalized)')
ax.set_ylabel('||C(t)||')
ax.set_title('Emission Matrix Norm Over Time')
ax.grid(True)

# Dashed red lines mark each held-out block's normalized position.
test_block_times = test_idx / (num_blocks - 1)
for t_test in test_block_times:
    ax.axvline(x=t_test, color='r', linestyle='--', alpha=0.3)
plt.show()

## Compare with LDS

In [None]:
# Baseline LDS with the same state_dim, conditions, dynamics bias, and
# initialization seed as the CLDS, fit on the same training trials.
lds_model = LinearGaussianConjugateSSM(
    state_dim=state_dim,
    emission_dim=emission_dim,
    num_conditions=num_conditions,
    has_dynamics_bias=has_dynamics_bias,
)

key = jr.PRNGKey(model_seed)  # same init seed as the CLDS
lds_params, lds_props = lds_model.initialize(key=key)

best_lds_params, lds_train_lps = lds_model.fit_em(
    params=lds_params,
    props=lds_props,
    emissions=train_obs[trial_masks],
    conditions=train_conditions,
    num_iters=num_iters,
    use_wandb=False,
)

lds_final_lp, clds_final_lp = lds_train_lps[-1], train_lps[-1]
print(f"LDS final train log-likelihood: {lds_final_lp:.2f}")
print(f"CLDS final train log-likelihood: {clds_final_lp:.2f}")
print(f"Difference (CLDS - LDS): {clds_final_lp - lds_final_lp:.2f}")

In [None]:
# Held-out evaluation: both models score the same test trials; the CLDS also
# receives each test trial's normalized block position.
test_block_ids = block_id_nums[~trial_masks]

clds_test_ll = model.batch_marginal_log_prob(
    best_params, test_obs, conditions=test_conditions, trial_ids=test_block_ids
)
lds_test_ll = lds_model.batch_marginal_log_prob(
    best_lds_params, test_obs, conditions=test_conditions
)

summary_lines = (
    f"\nTest Log-Likelihoods:",
    f"  LDS:  {lds_test_ll:.2f}",
    f"  CLDS: {clds_test_ll:.2f}",
    f"  Difference (CLDS - LDS): {clds_test_ll - lds_test_ll:.2f}",
)
print(*summary_lines, sep="\n")

# Flag the case where the more flexible CLDS scores worse on held-out data.
if clds_test_ll < lds_test_ll:
    print(f"\n  WARNING: CLDS test LL < LDS test LL (overfitting!)")

In [None]:
# Side-by-side: EM training curves (left) and held-out log-likelihoods (right).
fig, (ax_curves, ax_test) = plt.subplots(1, 2, figsize=(12, 4))

for curve, model_name in ((train_lps, 'CLDS'), (lds_train_lps, 'LDS')):
    ax_curves.plot(curve, label=model_name)
ax_curves.set(xlabel='EM Iteration', ylabel='Log-likelihood', title='Training Curves')
ax_curves.legend()
ax_curves.grid(True)

ax_test.bar(['LDS', 'CLDS'], [lds_test_ll, clds_test_ll])
ax_test.set(ylabel='Test Log-likelihood', title='Test Performance')
ax_test.grid(True)

fig.tight_layout()
plt.show()

## Diagnostic: Overfitting Analysis

In [None]:
# Per-block comparison on held-out data. Boolean masking preserves trial order
# and test_idx is sorted, so the i-th test block occupies rows
# [i * block_size, (i + 1) * block_size) of the test arrays.
print("Per-test-block log-likelihoods:")
print(f"{'Block':<8} {'Time':<8} {'CLDS':<12} {'LDS':<12} {'Diff':<12}")
print("-" * 52)

for i, block_idx in enumerate(test_idx):
    rows = slice(i * block_size, (i + 1) * block_size)
    block_obs = test_obs[rows]
    block_conds = test_conditions[rows]
    block_times = test_block_ids[rows]

    clds_block_ll = model.batch_marginal_log_prob(
        best_params, block_obs, conditions=block_conds, trial_ids=block_times
    )
    lds_block_ll = lds_model.batch_marginal_log_prob(
        best_lds_params, block_obs, conditions=block_conds
    )

    time_pos = float(block_idx) / (num_blocks - 1)
    diff = clds_block_ll - lds_block_ll
    print(f"{block_idx:<8} {time_pos:<8.3f} {clds_block_ll:<12.2f} {lds_block_ll:<12.2f} {diff:<12.2f}")

## Diagnostic: Regularization Analysis

In [None]:
# Print a short complexity/regularization report.
# NOTE(review): the coefficient "1.0" and "models.py line 1715" are hard-coded
# text, not read from the model; verify them against the current models.py.
n_basis = len(basis_funcs)
report_lines = [
    "Regularization Analysis:",
    f"  L (basis functions): {n_basis}",
    f"  state_dim: {state_dim}",
    f"  L * state_dim = {n_basis * state_dim}",
    f"  Current regularization coefficient: 1.0",
    f"\nTo reduce overfitting, increase regularization in models.py line 1715:",
    f"  ZTZ + 10.0 * jnp.eye(...)  # or 100.0",
    f"\nOr reduce model complexity:",
    f"  - Fewer basis functions (smaller L)",
    f"  - Larger kappa (smoother variation)",
]
for report_line in report_lines:
    print(report_line)

## Visualize Emission Matrix Changes

In [None]:
# Visualize how C(t) changes over normalized time: heatmaps of the first rows of
# the emission matrix at evenly spaced times, all on one shared colour scale.
n_timepoints = 5
n_rows_shown = 20  # emission dims displayed per panel
color_lim = None  # None -> symmetric limit derived from the data; set e.g. 0.5 to fix it

eval_times = [float(t) for t in jnp.linspace(0, 1, n_timepoints)]
C_snapshots = [
    model.wpgs_C(best_params.emissions.weights, t)[:n_rows_shown, :] for t in eval_times
]

# Fixed +/-0.5 limits silently saturate any weight larger than 0.5; derive a
# symmetric limit from the plotted values instead, shared by all panels.
if color_lim is None:
    color_lim = max(float(jnp.abs(C).max()) for C in C_snapshots) or 1.0

# Constrained layout supports a colorbar spanning several axes plus a suptitle;
# tight_layout does not (it warns and can misplace the axes).
# squeeze=False keeps `axes` indexable even when n_timepoints == 1.
fig, axes = plt.subplots(1, n_timepoints, figsize=(15, 3), squeeze=False, constrained_layout=True)
axes = axes[0]

for i, (t, C_t) in enumerate(zip(eval_times, C_snapshots)):
    im = axes[i].imshow(C_t, aspect='auto', cmap='RdBu_r', vmin=-color_lim, vmax=color_lim)
    axes[i].set_title(f't = {t:.2f}')
    axes[i].set_xlabel('Latent dim')
axes[0].set_ylabel(f'Emission dim (first {n_rows_shown})')

fig.colorbar(im, ax=axes, shrink=0.8)
fig.suptitle('Emission Matrix C(t) at Different Times')
plt.show()