In [15]:
# Import necessary libraries
import jax.numpy as jnp
from jax import random, grad, jit, lax
from jax.scipy.linalg import inv, svd, eigh, det
from jax.numpy.linalg import norm
from tqdm import tqdm
from sklearn.datasets import make_spd_matrix
from jax_models import Lorenz96
from jax_models import visualize_observations, Lorenz96, generate_true_states, generate_gc_localization_matrix
#from jax_filters import ensrf_steps
import jax
import matplotlib.pyplot as plt
#from jax_vi import KL_gaussian, log_likelihood
from jax.tree_util import Partial

def create_stable_matrix(n, key):
    # Generate a symmetric random matrix
    A = random.normal(key, (n, n))
    A = (A + A.T) / 2
    # Ensure the matrix has a spectral radius < 1 for stability
    eigenvalues, eigenvectors = eigh(A)
    scaled_eigenvalues = eigenvalues / (jnp.abs(eigenvalues).max() + 0.1)  # Scale eigenvalues to ensure stability
    A_stable = eigenvectors @ jnp.diag(scaled_eigenvalues) @ eigenvectors.T
    return A_stable

key = random.PRNGKey(0)
n = 40
Q = 0.1 * jnp.eye(n)  # Process noise covariance
R_matrix = 0.5 * jnp.eye(n)#make_spd_matrix(n)  # Generating a symmetric positive definite matrix for R
R = jnp.array(R_matrix)  # Observation noise covariance
H = jnp.eye(n)  # Observation matrix
initial_state = random.normal(random.PRNGKey(0), (n,))  # Initial state
observation_interval = 1
n_ensemble = 20

A_stable = create_stable_matrix(n, key)

M = A_stable
def state_transition_function(x):
    return jnp.dot(M, x)

def jacobian_function(x):
    return M

jac_func = Partial(jacobian_function)
A_step = Partial(state_transition_function)

num_steps = 100

In [2]:
A_stable #first number should be - 0.004985

Array([[-0.00498583, -0.13649331, -0.02905141, ...,  0.04102162,
        -0.0355752 ,  0.11369595],
       [-0.13649331,  0.01910555, -0.03708115, ..., -0.03015236,
         0.0199717 , -0.09355967],
       [-0.02905142, -0.03708115,  0.09458236, ..., -0.04205253,
        -0.07820393, -0.11111479],
       ...,
       [ 0.04102163, -0.03015236, -0.04205254, ..., -0.01375088,
         0.06516878, -0.07160039],
       [-0.03557519,  0.0199717 , -0.07820394, ...,  0.06516878,
         0.13236032,  0.04353936],
       [ 0.11369595, -0.09355966, -0.11111479, ..., -0.07160041,
         0.04353936, -0.20437019]], dtype=float32)

In [46]:
observations, true_states = generate_true_states(key, num_steps, n, initial_state, H, Q, R, A_step, observation_interval+ 3)

In [47]:
observations.mean(axis=1)

Array([-0.04309407,         nan,         nan,         nan,  0.06407922,
               nan,         nan,         nan, -0.09228519,         nan,
               nan,         nan,  0.10621283,         nan,         nan,
               nan,  0.15251684,         nan,         nan,         nan,
       -0.208201  ,         nan,         nan,         nan,  0.1527056 ,
               nan,         nan,         nan, -0.2755527 ,         nan,
               nan,         nan,  0.28428125,         nan,         nan,
               nan, -0.08476555,         nan,         nan,         nan,
        0.19532718,         nan,         nan,         nan, -0.18693644,
               nan,         nan,         nan, -0.11860991,         nan,
               nan,         nan,  0.01496275,         nan,         nan,
               nan, -0.00263208,         nan,         nan,         nan,
        0.12697113,         nan,         nan,         nan,  0.03223038,
               nan,         nan,         nan,  0.05954238,      

In [40]:
import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.linalg import inv
import numpy as np  # For handling NaN checks with JAX arrays
from jax_filters import ensrf_steps

@jit
def kalman_filter_step(carry, input):
    m_prev, C_prev, M, H, Q, R, observation_interval = carry
    y_curr = input
    m_pred = M @ m_prev
    C_pred = M @ C_prev @ M.T + Q
    def update():
        S = H @ C_pred @ H.T + R
        K = C_pred @ H.T @ inv(S)
        y_hat = H @ m_pred
        m_update = m_pred + K @ (y_curr - y_hat)
        C_update = (jnp.eye(C_prev.shape[0]) - K @ H) @ C_pred
        return m_update, C_update
    def no_update():
        return m_pred, C_pred
    is_observation_available = jnp.logical_not(jnp.any(jnp.isnan(y_curr)))
    m_update, C_update = lax.cond(is_observation_available, update, no_update)

    return (m_update, C_update, M, H, Q, R, observation_interval), (m_update, C_update)

def apply_kalman_filter(y, m0, C0, M, H, Q, R, observation_interval):
    carry_init = (m0, C0, M, H, Q, R, observation_interval)
    _, (ms, Cs) = lax.scan(kalman_filter_step, carry_init, y)
    return ms, Cs

ms, Cs = apply_kalman_filter(observations, initial_state, Q, A_stable, H, Q, R,0)



In [48]:

@jit
def kalman_filter_step(carry, input):
    m_prev, C_prev, state_transition_function, jacobian_function, H, Q, R = carry
    y_curr = input
    m_pred = state_transition_function(m_prev)
    F_jac = jacobian_function(m_prev)
    C_pred = F_jac @ C_prev @ F_jac.T + Q
    
    def update():
        S = H @ C_pred @ H.T + R
        K = C_pred @ H.T @ inv(S)
        y_hat = H @ m_pred
        m_update = m_pred + K @ (y_curr - y_hat)
        C_update = (jnp.eye(H.shape[1]) - K @ H) @ C_pred
        return m_update, C_update
    
    def no_update():
        return m_pred, C_pred
    
    is_observation_available = jnp.logical_not(jnp.any(jnp.isnan(y_curr)))
    m_update, C_update = lax.cond(is_observation_available, update, no_update)

    return (m_update, C_update, state_transition_function, jacobian_function, H, Q, R), (m_update, C_update)

@jit
def apply_kalman_filter(y, m0, C0, state_transition_function, jacobian_function, H, Q, R):
    carry_init = (m0, C0, state_transition_function, jacobian_function, H, Q, R)
    _, (ms, Cs) = lax.scan(kalman_filter_step, carry_init, y)
    return ms, Cs



In [52]:
from jax_filters import kalman_filter_process
ms1, Cs1, ks1 =kalman_filter_process(A_step, jac_func, m0, C0, observations, H, Q, R)


In [53]:
ms, Cs = apply_kalman_filter(observations, m0, C0, A_step, jac_func, H, Q, R)

In [57]:
ms1


Array([[ 0.07836241, -0.7641989 , -0.12098181, ..., -0.20468104,
        -0.53091204,  0.32743043],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       ...,
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan]], dtype=float32)

In [5]:
from jax_vi import KL_gaussian, log_likelihood, KL_sum


def var_cost(inflation, ensemble_init, observations, H, Q, R, localization_matrix, key, num_steps, J0):
   
    states, covariances = ensrf_steps(A_step, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, localization_matrix, inflation, key)

    ensemble_mean = jnp.mean(states, axis=-1)  # Taking the mean across the ensemble members dimension
    key, *subkeys = random.split(key, num=N+1)
    kl_sum = KL_sum(ensemble_mean, covariances, n, A_step, Q, key, N)
    def inner_map(subkey):
        return log_likelihood(random.multivariate_normal(subkey, ensemble_mean, covariances), observations, H, R, num_steps, J0)  # Sometimes the covariances are negative definite. Fix
    cost = kl_sum - jnp.nanmean(jax.lax.map(inner_map, jnp.vstack(subkeys)))
    return cost

In [20]:
J = num_steps
var_cost(1.3, ensemble_init, observations, H, Q, R, localization_matrix, key, J, J0)

Array(8204.043, dtype=float32)

In [6]:
from IPython.display import clear_output
from jax import grad
from tqdm.notebook import tqdm
import jax.numpy as jnp
from jax import random
import properscoring



# Modification: Use grad to compute the gradient with respect to the inflation parameter
var_cost_grad = grad(var_cost, argnums=0)

J0 = 0
inflation_opt = 1.1  # Example starting value for inflation
alpha = 1e-6  # Learning rate
key = random.PRNGKey(0)  # Random key
N = 10  # Number of MC samples
m0 = initial_state
C0 = Q  # Initial covariance, assuming Q is your process noise covariance
localization_matrix = generate_gc_localization_matrix(n, 15)# jnp.ones((n, n)) # 
n_ensemble  = 20
ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T

crpss = []
rmses=[]
inflations = []
true_filter_divergences = []

# from jax import config
# config.update("jax_debug_nans", True)

for i in tqdm(range(50)):
    key, subkey = random.split(key)
        
    # Gradient descent step for inflation parameter
    grad_inflation = var_cost_grad(inflation_opt, ensemble_init, observations, H, Q, R, localization_matrix, subkey, num_steps, J0)
    inflation_opt -= alpha * grad_inflation  # Update inflation parameter
    
    inflations.append(inflation_opt)

    states, covariances = ensrf_steps(A_step, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, localization_matrix, inflation, key)
    
    ensemble_mean = jnp.mean(states, axis=-1)  # Taking the mean across the ensemble members dimension
    #true_filter_divergences.append(KL_gaussian(ensemble_mean.T, covariances, ms.T, Cs))

    rmse = jnp.sqrt(jnp.mean((ensemble_mean - true_states)**2))
    rmses.append(rmse)
    crps = properscoring.crps_ensemble(true_states, states).mean(axis=1).mean()
    crpss.append(crps)
    #clear_output(wait=True)
    
    print(inflation_opt, crps)

  0%|          | 0/50 [00:00<?, ?it/s]

nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.1315303498848616
nan 0.13153

In [51]:
ms

Array([[ 0.07836241, -0.7641989 , -0.12098181, ..., -0.20468104,
        -0.53091204,  0.32743043],
       [ 0.12314163, -0.52652967, -0.26787078, ..., -0.44446802,
        -0.34339535,  0.09057826],
       [ 0.21506685, -0.34971774,  0.22789443, ..., -0.23412657,
        -0.08423281,  0.46471524],
       ...,
       [ 0.60597956, -0.00444032, -0.76824206, ...,  0.15328997,
        -0.3658951 , -0.06035931],
       [ 0.06862107,  0.02553357,  0.6577666 , ...,  0.07941943,
        -0.88178587,  0.00518188],
       [-0.05369499,  0.08855119, -0.27075857, ..., -0.10694164,
        -0.12803587,  0.00748018]], dtype=float32)