# JAX Bellman Filter Debug

This notebook provides a minimal working example to debug the JAX-based Bellman filter implementation.

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from functions.simulation import DFSV_params, simulate_DFSV
from functions.filters import DFSVBellmanFilter
from functions.filter_bellman import *
# Set random seed for reproducibility
np.random.seed(42)

# Create a minimal test model
K = 1  # Number of factors
N = 2  # Number of observed series

# Define model parameters
lambda_r = np.array([[1.0], [0.5]])  # Simple factor loadings
Phi_f = np.array([[0.9]])            # Factor persistence
Phi_h = np.array([[0.95]])           # Volatility persistence
mu = np.array([-1.0])                # Log-volatility mean
sigma2 = np.ones(N) * 0.1            # Measurement noise
Q_h = np.array([[0.2]])              # Volatility of log-volatility

# Create parameter object
params = DFSV_params(
    N=N,
    K=K,
    lambda_r=lambda_r,
    Phi_f=Phi_f,
    Phi_h=Phi_h,
    mu=mu,
    sigma2=sigma2,
    Q_h=Q_h
)

print("Model parameters created successfully")

Model parameters created successfully


In [2]:
# Generate synthetic data
T = 20  # Just a few time points for testing
y, factors, log_vols = simulate_DFSV(params, T=T, seed=42)

print("Data generated:")
print(f"y shape: {y.shape}")
print(f"factors shape: {factors.shape}")
print(f"log_vols shape: {log_vols.shape}")

Data generated:
y shape: (20, 2)
factors shape: (20, 1)
log_vols shape: (20, 1)


In [3]:
# Initialize the filter
try:
    bf = DFSVBellmanFilter_JAXOPT3(params)
    bf2 = DFSVBellmanFilter(params)
    print("\nBellman filter initialized")
    print("\nJAX parameters:")
    print(f"lambda_r shape: {bf.jax_lambda_r.shape}")
    print(f"sigma2 shape: {bf.jax_sigma2.shape}")
    print(f"mu shape: {bf.jax_mu.shape}")
    print(f"Phi_f shape: {bf.jax_Phi_f.shape}")
    print(f"Phi_h shape: {bf.jax_Phi_h.shape}")
    print(f"Q_h shape: {bf.jax_Q_h.shape}")
except Exception as e:
    print(f"Error during initialization: {e}")


Bellman filter initialized

JAX parameters:
lambda_r shape: (2, 1)
sigma2 shape: (2, 2)
mu shape: (1, 1)
Phi_f shape: (1, 1)
Phi_h shape: (1, 1)
Q_h shape: (1, 1)


In [4]:
# Initialize state and test prediction step
try:
    # Initialize state
    state0, cov0 = bf.initialize_state(y)
    print(f"Initial state shape: {state0.shape}")
    print(f"Initial covariance shape: {cov0.shape}")
    print(f"\nInitial state:\n{state0}")
    
    # Test prediction step
    predicted_state, predicted_cov = bf.predict(state0, cov0)
    print("\nPrediction step completed")
    print(f"Predicted state:\n{predicted_state}")
    
    # Verify covariance matrices
    try:
        np.linalg.cholesky(predicted_cov)
        print("\nPredicted covariance is positive definite ✓")
    except np.linalg.LinAlgError:
        print("\nWARNING: Predicted covariance is not positive definite!")
        print("Eigenvalues:", np.linalg.eigvals(predicted_cov))
except Exception as e:
    print(f"Error during prediction: {e}")

Initial state shape: (2, 1)
Initial covariance shape: (2, 2)

Initial state:
[[ 0.]
 [-1.]]

Prediction step completed
Predicted state:
[[ 0.]
 [-1.]]

Predicted covariance is positive definite ✓


In [5]:
# jax.profiler.start_trace(log_dir=logdir)
observation = y[0:1, :].T.reshape(-1, 1)
print(f"Observation shape: {observation.shape}")
print(f"Observation values:\n{observation}")

# Check JAX objective function
print("\nTesting JAX objective function...")
alpha_test = predicted_state.copy()
alpha_test[0]=0.7
# Convert inputs for JAX
jax_alpha = jnp.array(alpha_test)
jax_pred = jnp.array(predicted_state)
jax_I_pred = jnp.array(np.linalg.inv(predicted_cov))
jax_obs = jnp.array(observation)

# Test objective function with new parameter order (removed K, N params)
obj_val, grad_val = bf.obj_and_grad_fn(jax_alpha, jax_pred, jax_I_pred, jax_obs,
                            )
print(f"Objective value at predicted state: {float(obj_val)}")

# Test gradient with new parameter order
print(f"Gradient at predicted state:\n{np.array(grad_val)}")
# Perform update step
print("\nPerforming update step...")

updated_state, updated_cov, log_likelihood = bf.update(
    predicted_state, predicted_cov, observation
)


print("\nUpdate step completed")
print(f"Updated state:\n{updated_state}")
print(f"Log-likelihood: {log_likelihood}")

# Verify updated covariance
try:
    np.linalg.cholesky(updated_cov)
    print("\nUpdated covariance is positive definite ✓")
except np.linalg.LinAlgError:
    print("\nWARNING: Updated covariance is not positive definite!")
    print("Eigenvalues:", np.linalg.eigvals(updated_cov))
# jnp.array(updated_state).block_until_ready()
# jax.profiler.stop_trace()

Observation shape: (2, 1)
Observation values:
[[0.]
 [0.]]

Testing JAX objective function...
Objective value at predicted state: 0.3674175149372927
Gradient at predicted state:
[ 2.49708519 -0.16597887]

Performing update step...

Update step completed
Updated state:
[[ 0.       ]
 [-2.0256404]]
Log-likelihood: 1.0193823093408931

Updated covariance is positive definite ✓


In [7]:
print("Running full filter...")
filtered_states, filtered_covs, log_likelihood = bf.filter(y)
print("\nFilter completed successfully!")
print(f"Total log-likelihood: {log_likelihood}")
print(f"\nFiltered states shape: {filtered_states.shape}")
print(f"Filtered covs shape: {filtered_covs.shape}")

# Compare with true states
print("\nCorrelation with true states:")
factor_corr = jnp.corrcoef(filtered_states[:, 0], factors[:, 0])[0, 1]
vol_corr = jnp.corrcoef(filtered_states[:, 1], log_vols[:, 0])[0, 1]
print(f"Factor correlation: {factor_corr:.4f}")
print(f"Log-volatility correlation: {vol_corr:.4f}")

Running full filter...


Bellman Filter Progress: 100%|██████████| 20/20 [00:00<00:00, 103.41it/s]


Filter completed successfully!
Total log-likelihood: -71.77618614712559

Filtered states shape: (20, 2)
Filtered covs shape: (20, 2, 2)

Correlation with true states:
Factor correlation: 0.8462
Log-volatility correlation: -0.0283





In [8]:
import time
bf = DFSVBellmanFilter_JAXOPT3(params)
bf2= DFSVBellmanFilter(params)
# Time the JAX implementation
start_time = time.time()
filtered_states_jaxopt, filtered_covs_jaxopt, ll_jaxopt = bf.filter(y)
jaxopt_time = time.time() - start_time
print(f"Optimized implementation time: {jaxopt_time:.4f} seconds")

# Time the original implementation
start_time = time.time()
filtered_states_orig, filtered_covs_orig, ll_orig = bf2.filter(y)
orig_time = time.time() - start_time
print(f"Original implementation time: {orig_time:.4f} seconds")

# Calculate speedup
speedup = orig_time / jaxopt_time if jaxopt_time > 0 else float('inf')
print(f"Speedup factor: {speedup:.2f}x")

# Compare results for consistency
states_diff = np.max(np.abs(filtered_states_jaxopt - filtered_states_orig))
covs_diff = np.max(np.abs(filtered_covs_jaxopt - filtered_covs_orig))
ll_diff = (ll_jaxopt - ll_orig)

print("\nResult consistency check:")
print(f"Max states difference: {states_diff:.8f}")
print(f"Max covariance difference: {covs_diff:.8f}")
print(f"Log-likelihood difference: {ll_diff:.8f}")

Bellman Filter Progress: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


Optimized implementation time: 13.9001 seconds


Bellman Filter Progress: 100%|██████████| 20/20 [00:09<00:00,  2.15it/s]

Original implementation time: 9.3198 seconds
Speedup factor: 0.67x

Result consistency check:
Max states difference: 6.10226589
Max covariance difference: 36.73515139
Log-likelihood difference: -21.18375551





In [8]:
jnp.linalg.inv(filtered_covs_jaxopt)

Array([[[ 8.42994299e+00,  0.00000000e+00],
        [ 0.00000000e+00,  4.87500303e-01]],

       [[ 4.47641726e+00, -5.73141336e-02],
        [-5.73141336e-02,  5.05676295e-01]],

       [[ 3.66670180e-01,  1.93338940e-02],
        [ 1.93338940e-02,  5.03939783e-01]],

       [[ 1.40360797e+00, -1.41643269e-02],
        [-1.41643269e-02,  5.01916538e-01]],

       [[ 8.05080401e-01, -1.53932105e-01],
        [-1.53932105e-01,  6.28481883e-01]],

       [[ 3.49684596e-01, -1.45728293e-02],
        [-1.45728293e-02,  5.87990225e-01]],

       [[ 2.42566863e-01,  1.75582516e-02],
        [ 1.75582516e-02,  5.79677745e-01]],

       [[ 9.30879021e-02, -2.61147665e-02],
        [-2.61147665e-02,  5.81972547e-01]],

       [[ 3.63539107e-02, -2.10975798e-02],
        [-2.10975798e-02,  5.76576563e-01]],

       [[ 2.70491409e-02,  1.68112557e-03],
        [ 1.68112557e-03,  5.58899494e-01]],

       [[ 2.90100180e-02, -1.96890416e-02],
        [-1.96890416e-02,  5.62934918e-01]],

       [[ 

In [None]:
print(ll_orig)

-50.59243063976549
