In [30]:
import sys
import os
import time
import jaxopt
import numpy as np
import jax
import jax.numpy as jnp
from rich import print
from functools import partial
from jax import jit
from sympy import true
# Plot comparison
import matplotlib.pyplot as plt
# Add parent directory to path to import our modules
# Import parameter classes
import marimo as mo
from functions.simulation import DFSV_params, simulate_DFSV
from functions.jax_params import DFSVParamsDataclass, dfsv_params_to_dict
from functions.bellman_filter import DFSVBellmanFilter
from jaxopt import OptaxSolver
import optax

# Enable 64-bit precision for better numerical stability
jax.config.update("jax_enable_x64", True)

In [31]:
def create_simple_model():
    """Create a simple DFSV model with one factor."""
    # Define model dimensions
    N = 3  # Number of observed series
    K = 1  # Number of factors

    # Factor loadings
    lambda_r = np.array([[0.9], [0.6], [0.3]])

    # Factor persistence
    Phi_f = np.array([[0.95]])

    # Log-volatility persistence
    Phi_h = np.array([[0.98]])

    # Long-run mean for log-volatilities
    mu = np.array([-1.0])
    # Idiosyncratic variance (diagonal)
    sigma2 = np.array([0.1, 0.1, 0.1])

    # Log-volatility noise covariance
    Q_h = np.array([[0.05]])

    # Create parameter object
    params = DFSV_params(
        N=N,
        K=K,
        lambda_r=lambda_r,
        Phi_f=Phi_f,
        Phi_h=Phi_h,
        mu=mu,
        sigma2=sigma2,
        Q_h=Q_h,
    )

    return params


def create_training_data(params, T=100, seed=42):
    """Generate simulated data for training."""
    returns, factors, log_vols = simulate_DFSV(params, T=T, seed=seed)
    return returns, factors, log_vols


@partial(jit, static_argnames=["filter"])
def bellman_objective(params_unconstrained, y, filter, N, K):
    """
    Compute the Bellman objective function for the DFSV model with numerical safeguards.

    Parameters
    ----------
    params : DFSV_params
        Unconstrained Model parameters.
    y : np.ndarray
        Observed data.
    filter : DFSVBellmanFilter
        Bellman filter object.
    N : int
        Number of observed series.
    K : int
        Number of factors.

    Returns
    -------
    float
        The Bellman objective value.
    """
    # Create correct dataclass
    params_unconstrained = DFSVParamsDataclass.from_dict(params_unconstrained, N, K)
    
    # Apply proper constraints to parameters to ensure stability
    # 1. Use tanh to constrain persistence parameters between -1 and 1 for stability
    Phi_f_stable = jnp.tanh(params_unconstrained.Phi_f)
    Phi_h_stable = jnp.tanh(params_unconstrained.Phi_h)
    
    # 2. Use softplus for sigma^2 and Q_h to ensure positivity
    # softplus(x) = log(1 + exp(x)) is a smooth positive function
    sigma2_pos = jax.nn.softplus(params_unconstrained.sigma2)
    Q_h_pos = jax.nn.softplus(params_unconstrained.Q_h)
    
    # 3. Create the new param object with constrained parameters
    constrained_params = params_unconstrained.replace(
        Phi_f=Phi_f_stable,
        Phi_h=Phi_h_stable,
        sigma2=sigma2_pos,
        Q_h=Q_h_pos,
    )
    
    # 4. Run the bellman filter with exception handling
    try:
        ll = DFSVBellmanFilter.jit_log_likelihood_of_params(filter, constrained_params, y)
        # Add a safeguard against extreme values
        ll = jnp.clip(ll, -1e10, 1e10)
        return -ll
    except Exception as e:
        # Return a very high value if there's an exception
        return 1e10

In [32]:
# Create a simple model
params = create_simple_model()
# Generate training data
returns, factors, log_vols = create_training_data(params, T=200)
# Create a Bellman filter object
filter = DFSVBellmanFilter(params.N, params.K)
# Create a JAX-compatible parameter object
jax_params = DFSVParamsDataclass.from_dfsv_params(params)
# Perturb the parameters
jax_params = jax_params.replace(
    lambda_r=jax_params.lambda_r
    + 0.1 * jax.random.normal(jax.random.PRNGKey(0), jax_params.lambda_r.shape),
    Phi_f=jax_params.Phi_f
    + 0.1 * jax.random.normal(jax.random.PRNGKey(1), jax_params.Phi_f.shape),
    Phi_h=jax_params.Phi_h
    + 0.1 * jax.random.normal(jax.random.PRNGKey(2), jax_params.Phi_h.shape),
    mu=jax_params.mu
    + 0.1 * jax.random.normal(jax.random.PRNGKey(3), jax_params.mu.shape),
    # sigma2=jax_params.sigma2
    # + 0.1 * jax.random.normal(jax.random.PRNGKey(4), jax_params.sigma2.shape),
    # Q_h=jax_params.Q_h
    # + 0.1 * jax.random.normal(jax.random.PRNGKey(5), jax_params.Q_h.shape),
)

JAX functions successfully precompiled


# Creating Realistic Starting Parameters

Instead of perturbing the original parameters, we'll create a set of realistic starting parameters that are different from the true parameters but still reasonable for the DFSV model.

In [33]:
def create_realistic_starting_params(true_params):
    """
    Create a set of realistic starting parameters for estimation.
    
    Parameters
    ----------
    true_params : DFSV_params
        The true parameters used for simulation.
        
    Returns
    -------
    DFSVParamsDataclass
        A realistic set of starting parameters for optimization.
    """
    # Extract dimensions
    N = true_params.N
    K = true_params.K
    
    # Create parameter estimates that are reasonably different from true values
    # but still in a plausible range
    
    # Factor loadings - increase magnitude but keep sign pattern
    lambda_r = jnp.array(true_params.lambda_r) * 0.7 + 0.2
    
    # Factor persistence - typically high but less than true
    Phi_f = jnp.array(true_params.Phi_f) * 0.9
    
    # Log-volatility persistence - typically high but less than true
    Phi_h = jnp.array(true_params.Phi_h) * 0.95
    
    # Long-run mean for log-volatilities - slightly higher than true
    mu = jnp.array(true_params.mu) + 0.3
    
    # Idiosyncratic variance - higher than true value
    sigma2 = jnp.array(true_params.sigma2)
    
    # Log-volatility noise covariance - higher than true
    Q_h = jnp.array(true_params.Q_h)
    
    # Create JAX parameter object
    start_params = DFSVParamsDataclass(
        N=N,
        K=K,
        lambda_r=lambda_r,
        Phi_f=Phi_f,
        Phi_h=Phi_h,
        mu=mu,
        sigma2=sigma2,
        Q_h=Q_h
    )
    
    return start_params

# Create realistic starting parameters
starting_params = create_realistic_starting_params(params)

# Print comparison of true vs. starting parameters
print("True parameters:")
print(f"lambda_r:\n{params.lambda_r}")
print(f"Phi_f:\n{params.Phi_f}")
print(f"Phi_h:\n{params.Phi_h}")
print(f"mu:\n{params.mu}")
print(f"sigma2:\n{np.diag(params.sigma2)}")
print(f"Q_h:\n{params.Q_h}")
print("\nStarting parameters:")
print(f"lambda_r:\n{starting_params.lambda_r}")
print(f"Phi_f:\n{starting_params.Phi_f}")
print(f"Phi_h:\n{starting_params.Phi_h}")
print(f"mu:\n{starting_params.mu}")
print(f"sigma2:\n{jnp.diag(starting_params.sigma2)}")
print(f"Q_h:\n{starting_params.Q_h}")

In [34]:
# Use the starting parameters for optimization
# Convert to dictionary for optimization
param_dict, N, K = dfsv_params_to_dict(starting_params)

# Ensure all parameter values are float64 for differentiation
for key in param_dict:
    if isinstance(param_dict[key], (int, np.integer)):
        param_dict[key] = float(param_dict[key])
    elif isinstance(param_dict[key], np.ndarray):
        param_dict[key] = param_dict[key].astype(np.float64)
    elif isinstance(param_dict[key], jnp.ndarray):
        param_dict[key] = param_dict[key].astype(jnp.float64)

# Define objective function for this specific problem
def objective(params):
    # Add gradient clipping for numerical stability
    val = bellman_objective(params, returns, filter, N, K)
    return jnp.nan_to_num(val, nan=1e10, posinf=1e10, neginf=-1e10)  # Replace NaN/inf with finite values

def check_params(params):
    """Check if parameter values are valid and print diagnostics"""
    param_obj = DFSVParamsDataclass.from_dict(params, N, K)
    issues = []
    
    # Check for NaN or inf values
    for name, value in [("lambda_r", param_obj.lambda_r), 
                       ("Phi_f", param_obj.Phi_f),
                       ("Phi_h", param_obj.Phi_h),
                       ("mu", param_obj.mu),
                       ("sigma2", param_obj.sigma2),
                       ("Q_h", param_obj.Q_h)]:
        if jnp.any(jnp.isnan(value)) or jnp.any(jnp.isinf(value)):
            issues.append(f"{name} contains NaN or Inf values")
        
        # Also check for non-float types
        if not jnp.issubdtype(value.dtype, jnp.floating):
            issues.append(f"{name} has non-floating type: {value.dtype}")
    
    # Check for extreme values
    if jnp.any(jnp.abs(param_obj.Phi_f) > 5):
        issues.append(f"Phi_f has extreme values: {param_obj.Phi_f}")
    if jnp.any(jnp.abs(param_obj.Phi_h) > 5):
        issues.append(f"Phi_h has extreme values: {param_obj.Phi_h}")
    
    if issues:
        print("Parameter issues found:")
        for issue in issues:
            print(f"- {issue}")
        return False
    return True

@jit
def objective_with_logging(params):
    val = objective(params)
    grad = jax.grad(objective)(params)
    # Flatten and compute norm
    flat_grads, _ = jax.tree_flatten(grad)
    grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in flat_grads))
    # Clip gradient norm if it's too large
    grad_norm = jnp.minimum(grad_norm, 1e10)
    return val, grad_norm

# Define which parameters to optimize
mask = {
    "lambda_r": True,
    "Phi_f": True,
    "Phi_h": True,
    "mu": True,
    "sigma2": True,  # Keep optimizing sigma2
    "Q_h": True,      # Keep optimizing Q_h
}

# Create optimizer with a lower learning rate for better stability
opt = optax.adam(learning_rate=1e-3)
# Add gradient clipping
opt = optax.chain(optax.clip_by_global_norm(1.0), opt)
masked_optimizer = optax.masked(opt, mask=mask)
solver = OptaxSolver(opt=masked_optimizer, fun=objective, maxiter=100, tol=1e-6, verbose=True)

In [35]:
# Run the optimization with the new starting parameters
print("Starting optimization with realistic parameters...")

# Check parameters before starting optimization
check_params(param_dict)

# Ensure all dictionary values are float
for key in param_dict:
    if isinstance(param_dict[key], (int, np.integer)):
        print(f"Converting {key} from {type(param_dict[key])} to float")
        param_dict[key] = float(param_dict[key])

# Print initial objective value
try:
    initial_objective, initial_grad_norm = objective_with_logging(param_dict)
    print(f"Initial objective value: {initial_objective:.4f}, Gradient norm: {initial_grad_norm:.4f}")
except Exception as e:
    print(f"Error computing initial objective: {e}")
    print("Debugging parameter types:")
    for key, val in param_dict.items():
        print(f"{key}: type={type(val)}, dtype={getattr(val, 'dtype', None)}")

# Custom step-by-step optimization loop for better control
current_params = param_dict
optimizer_state = solver.optimizer.init(current_params)
start_time = time.time()

# Try running with custom safeguards
for i in range(100):  # Max 100 iterations
    try:
        # Compute the value and gradient
        value, grad = jax.value_and_grad(objective)(current_params)
        
        # Check for NaN or Inf in gradient
        flat_grads, _ = jax.tree_flatten(grad)
        has_bad_grads = any(jnp.any(jnp.isnan(g) | jnp.isinf(g)) for g in flat_grads)
        
        if has_bad_grads:
            print(f"Iteration {i}: Found NaN/Inf in gradient, stopping optimization")
            break
            
        # Update parameters
        updates, optimizer_state = solver.optimizer.update(grad, optimizer_state)
        current_params = optax.apply_updates(current_params, updates)
        
        # Compute objective and gradient norm for logging
        obj_value, grad_norm = objective_with_logging(current_params)
        print(f"Iter: {i} Objective Value: {obj_value:.6f} Gradient Norm: {grad_norm:.6f}")
        
        # Check convergence
        if grad_norm < 1e-6:
            print("Converged based on gradient norm")
            break
            
        # Check if parameters are reasonable
        if not check_params(current_params):
            print("Stopping due to parameter issues")
            break
            
    except Exception as e:
        print(f"Error in iteration {i}: {e}")
        break

end_time = time.time()
print(f"Optimization took {end_time - start_time:.2f} seconds")

# Use the final parameters from our custom loop
final_dict = current_params
optimized_params = DFSVParamsDataclass.from_dict(final_dict, N, K)

# Compare true, starting, and optimized parameters
print("\nParameter Comparison:")
print("-" * 50)
print(f"{'Parameter':<10} {'True':<15} {'Starting':<15} {'Optimized':<15}")
print("-" * 50)

# Lambda values (first element)
print(f"lambda_r[0] {params.lambda_r[0][0]:<15.4f} {starting_params.lambda_r[0][0]:<15.4f} {optimized_params.lambda_r[0][0]:<15.4f}")

# Phi values
print(f"Phi_f     {params.Phi_f[0][0]:<15.4f} {starting_params.Phi_f[0][0]:<15.4f} {optimized_params.Phi_f[0][0]:<15.4f}")
print(f"Phi_h     {params.Phi_h[0][0]:<15.4f} {starting_params.Phi_h[0][0]:<15.4f} {optimized_params.Phi_h[0][0]:<15.4f}")

# mu value
print(f"mu        {params.mu[0]:<15.4f} {starting_params.mu[0]:<15.4f} {optimized_params.mu[0]:<15.4f}")

# sigma2 (first element)
print(f"sigma2[0] {params.sigma2[0,0]:<15.4f} {starting_params.sigma2[0,0]:<15.4f} {optimized_params.sigma2[0,0]:<15.4f}")

# Q_h value
print(f"Q_h       {params.Q_h[0,0]:<15.4f} {starting_params.Q_h[0,0]:<15.4f} {optimized_params.Q_h[0,0]:<15.4f}")

AttributeError: 'OptaxSolver' object has no attribute 'optimizer'

In [None]:
# Compare original vs optimized parameters
print("Comparing filter output with original vs optimized parameters")
def stablize_matrix(matrix):
    norm = jnp.linalg.norm(matrix, ord=2)
    return matrix / (1.0 + norm)
final_dict= result.params
# Convert optimized parameters back to standard format if needed
optimized_params = DFSVParamsDataclass.from_dict(final_dict,N,K)
optimized_params = optimized_params.replace(sigma2=jnp.exp(optimized_params.sigma2),
                                            Phi_f=jnp.tanh(optimized_params.Phi_f),
                                            Phi_h=jnp.tanh(optimized_params.Phi_h),
                                            Q_h=jnp.exp(optimized_params.Q_h))
# optimized_params = DFSVParamsDataclass(
#     N=3,
#     K=1,
#     lambda_r=jnp.array([[1.8], [1.2], [0.58]]),
#     Phi_f=jnp.array([[0.934]]),
#     Phi_h=jnp.array([[0.967]]),
#     mu=jnp.array([0.15]),
#     sigma2=jnp.array([0.114, 0.09, 0.097]),
#     Q_h=jnp.array([[0.023]]),
# )
print(optimized_params)