In [7]:
# Import necessary libraries
import jax.numpy as jnp
from jax import random, grad, jit, lax
from jax.scipy.linalg import inv, svd, eigh, det
from jax.numpy.linalg import norm
from tqdm import tqdm
from sklearn.datasets import make_spd_matrix
from jax_models import Lorenz96
from jax_models import visualize_observations, Lorenz96, generate_true_states, generate_gc_localization_matrix
#from jax_filters import ensrf_steps
import jax
import matplotlib.pyplot as plt
#from jax_vi import KL_gaussian, log_likelihood
from jax.tree_util import Partial

def create_stable_matrix(n, key):
    # Generate a symmetric random matrix
    A = random.normal(key, (n, n))
    A = (A + A.T) / 2
    
    # Ensure the matrix has a spectral radius < 1 for stability
    eigenvalues, eigenvectors = eigh(A)
    scaled_eigenvalues = eigenvalues / (jnp.abs(eigenvalues).max() + 0.1)  # Scale eigenvalues to ensure stability
    A_stable = eigenvectors @ jnp.diag(scaled_eigenvalues) @ eigenvectors.T
    
    return A_stable

key = random.PRNGKey(0)
n = 40
Q = 0.1 * jnp.eye(n)  # Process noise covariance
R_matrix = 0.5 * jnp.eye(n)#make_spd_matrix(n)  # Generating a symmetric positive definite matrix for R
R = jnp.array(R_matrix)  # Observation noise covariance
H = jnp.eye(n)  # Observation matrix
initial_state = random.normal(random.PRNGKey(0), (n,))  # Initial state
observation_interval = 1


A_stable = create_stable_matrix(n, key)


def state_transition_function(x):
    return jnp.dot(A_stable, x)

A_step = Partial(state_transition_function)


num_steps = 100

In [2]:
#A_stable

In [8]:
observations, true_states = generate_true_states(key, num_steps, n, initial_state, H, Q, R, A_step, observation_interval)

In [9]:
import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.linalg import inv
import numpy as np  # For handling NaN checks with JAX arrays

@jit
def kalman_filter_step(carry, input):
    m_prev, C_prev, M, H, Q, R, observation_interval = carry
    y_curr = input
    m_pred = M @ m_prev
    C_pred = M @ C_prev @ M.T + Q
    def update():
        S = H @ C_pred @ H.T + R
        K = C_pred @ H.T @ inv(S)
        y_hat = H @ m_pred
        m_update = m_pred + K @ (y_curr - y_hat)
        C_update = (jnp.eye(C_prev.shape[0]) - K @ H) @ C_pred
        return m_update, C_update
    def no_update():
        return m_pred, C_pred
    is_observation_available = jnp.logical_not(jnp.any(jnp.isnan(y_curr)))
    m_update, C_update = lax.cond(is_observation_available, update, no_update)

    return (m_update, C_update, M, H, Q, R, observation_interval), (m_update, C_update)

def apply_kalman_filter(y, m0, C0, M, H, Q, R, observation_interval):
    carry_init = (m0, C0, M, H, Q, R, observation_interval)
    _, (ms, Cs) = lax.scan(kalman_filter_step, carry_init, y)
    return ms, Cs

ms, Cs = apply_kalman_filter(observations, initial_state, Q, A_stable, H, Q, R,1)



In [38]:
inflation = 1.1
states, covariances = ensrf_steps(A_step, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, localization_matrix, inflation, key)
#print(covariances)

In [48]:
# from jax_vi import KL_gaussian, log_likelihood, KL_sum


# def var_cost(inflation, ensemble_init, observations, H, Q, R, localization_matrix, key, num_steps, J0):
   
#     states, covariances = ensrf_steps(A_step, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, localization_matrix, inflation, key)

#     ensemble_mean = jnp.mean(states, axis=-1)  # Taking the mean across the ensemble members dimension
#     key, *subkeys = random.split(key, num=N+1)
#     kl_sum = KL_sum(ensemble_mean, covariances, n, A_step, Q, key, N)
#     print(kl_sum)

#     def inner_map(subkey):
#         return log_likelihood(random.multivariate_normal(subkey, ensemble_mean, covariances), observations, H, R, num_steps, J0)  # Sometimes the covariances are negative definite. Fix
#     cost = kl_sum - jnp.nanmean(jax.lax.map(inner_map, jnp.vstack(subkeys)))
#     return cost

In [69]:
from jax.lax import scan

def nearest_pd_matrix(M):
    eigval, eigvec = jnp.linalg.eig(M)
    eigval = eigval.at[eigval < 0].set(0)

    return eigvec.dot(jnp.diag(eigval)).dot(eigvec.T)

@jit
def log_likelihood(v, y, H, R, J, J0):
    """
    Computes the log-likelihood of observations given state estimates as a sum.
    Compares this to R, observation noise
    """
    def log_likelihood_j(_, v_y):
        v_j, y_j = v_y
        error = y_j - H @ v_j
        ll = error.T @ inv(R) @ error
        return _, ll
    _, lls = lax.scan(log_likelihood_j, None, (v, y))
    sum_ll = jnp.nansum(lls)
    return -0.5 * sum_ll - 0.5 * (J - J0) * jnp.log(2 * jnp.pi) - 0.5 * (J - sum(jnp.isnan(lls)) - J0) * jnp.linalg.slogdet(R)[1]


@jit
def KL_gaussian(m1, C1, m2, C2):
    """
    Computes the Kullback-Leibler divergence between two Gaussian distributions.
    m1, C1: Mean and covariance of the first Gaussian distribution.
    m2, C2: Mean and covariance of the second Gaussian distribution.
    n: number of state variables
    """
    C2_inv = inv(C2)
    log_det_ratio = (jnp.log(jnp.linalg.eigvals(C2)).sum() - jnp.log(jnp.linalg.eigvals(C1)).sum()).real # log(det(C2) / det(C1)), works better with limited precision because the determinant is practically 0
    return 0.5 * (log_det_ratio - n + jnp.trace(C2_inv @ C1) + ((m2 - m1).T @ C2_inv @ (m2 - m1)))

@jit
def KL_sum(m, C, Q, key):
    """
    Computes the sum of KL divergences between the predicted and updated state distributions.
    Calculates divergence between pairs of predicted and updated state distributions
    
    """
    def KL_j(_, m_C_y):
        m_prev, m_curr, C_prev, C_curr, key = m_C_y
        key, *subkeys_inner = random.split(key, num=N)
        def inner_map(subkey):
            perturbed_state = m_prev + random.multivariate_normal(subkey, jnp.zeros(n), C_prev)
            m_pred = A_step(perturbed_state)
            return KL_gaussian(m_curr, C_curr, m_pred, Q) 
        mean_kl = jnp.nanmean(lax.map(inner_map, jnp.array(subkeys_inner)), axis=0)
        return _, mean_kl

    _, mean_kls = scan(KL_j, None, (m[:-1, :], m[1:, :], C[:-1, :, :], C[1:, :, :], jnp.array(random.split(key, num=m.shape[0]-1))))
    kl_sum = sum(mean_kls)
    return kl_sum

In [70]:
@jit
def ledoit_wolf(P, shrinkage):
    return (1 - shrinkage) * P + shrinkage * jnp.trace(P)/P.shape[0] * jnp.eye(P.shape[0])

@jit
def sqrtm(M):
    eigenvalues, eigenvectors = jnp.linalg.eigh(M)
    inv_sqrt_eigenvalues = jnp.sqrt(eigenvalues)
    Lambda_inv_sqrt = jnp.diag(inv_sqrt_eigenvalues)
    M_inv_sqrt = eigenvectors @ Lambda_inv_sqrt @ eigenvectors.T
    return M_inv_sqrt.real

shrinkage = 0.1

@jit
def ensrf_step(ensemble, y, H, Q, R, localization_matrix, inflation, key):
    n_ensemble = ensemble.shape[1]
    x_m = jnp.mean(ensemble, axis=1)
    #ensemble += random.multivariate_normal(key, jnp.zeros(ensemble.shape[0]), Q, (n_ensemble,)).T
    A = ensemble - x_m.reshape((-1, 1))
    A = A*inflation
    Pf = (A @ A.T) / (n_ensemble - 1) + Q
    P = Pf * localization_matrix  # Element-wise multiplication for localization
    K = P @ H.T @ jnp.linalg.inv(H @ P @ H.T + R)
    x_m += K @ (y - H @ x_m)
    #M = jnp.eye(x_m.shape[0]) + P @ H.T @ jnp.linalg.inv(R) @ H
    #M_inv_sqrt = inv(jax.scipy.linalg.sqrtm(M).real)#inv_sqrtmh(M)
    M_inv_sqrt = sqrtm(jnp.eye(x_m.shape[0]) - K@H)
    updated_A = M_inv_sqrt @ A
    updated_ensemble = x_m.reshape((-1, 1)) + updated_A
    #updated_A = updated_ensemble - jnp.mean(updated_ensemble, axis=1).reshape((-1, 1))
    updated_P = (updated_A @ updated_A.T / (n_ensemble - 1))
    updated_P = ledoit_wolf(updated_P, shrinkage)
    return updated_ensemble, updated_P# + jnp.eye(x_m.shape[0])*1e-4 

@jit
def ensrf_steps(ensemble_init, observations, H, Q, R, localization_matrix, inflation, key):
    """
    Deterministic Ensemble Square Root Filter generalized for any model.
    """
    model_vmap = jax.vmap(lambda v: state_transition_function(v), in_axes=1, out_axes=1)
    key, *subkeys = random.split(key, num= num_steps+1)
    subkeys = jnp.array(subkeys)

    def inner(carry, t):
        ensemble, _ = carry
        ensemble_predicted = model_vmap(ensemble)

        # Update ensemble and covariance based on observation availability
        ensemble_updated, Pf_updated = ensrf_step(ensemble_predicted, observations[t, :], H, Q, R, localization_matrix, inflation, subkeys[t])
    

        return (ensemble_updated, Pf_updated), (ensemble_updated, Pf_updated)

    # Initialize storage for covariance matrices, including one for the initial state
    covariance_init = jnp.zeros((n,n))

    # Perform the scan over timesteps with the initial ensemble and covariances
    _, output = jax.lax.scan(inner, (ensemble_init, covariance_init), jnp.arange(num_steps))

    ensembles, covariances = output

    return ensembles, covariances

In [71]:
@jit
def var_cost(inflation, ensemble_init, observations, H, Q, R, localization_matrix, key, J, J0):
   
    states, covariances = ensrf_steps(ensemble_init, observations, H, Q, R, localization_matrix, inflation, key)
    ensemble_mean = jnp.mean(states, axis=-1)  # Taking the mean across the ensemble members dimension
    key, *subkeys = random.split(key, num=N+1)
    kl_sum = KL_sum(ensemble_mean, covariances, Q, key)

    def inner_map(subkey):
        return log_likelihood(random.multivariate_normal(subkey, ensemble_mean, covariances), observations, H, R, J, J0)  # Sometimes the covariances are negative definite. Fix
    cost = kl_sum - jnp.nanmean(jax.lax.map(inner_map, jnp.vstack(subkeys)))
    
    return cost

In [77]:
J = num_steps
var_cost(inflation + 0.8, ensemble_init, observations, H, Q, R, localization_matrix, key, J, J0)

Array(8452.507, dtype=float32)

In [84]:
from IPython.display import clear_output
from jax import grad
from tqdm.notebook import tqdm
import jax.numpy as jnp
from jax import random
import properscoring



# Modification: Use grad to compute the gradient with respect to the inflation parameter
var_cost_grad = grad(var_cost, argnums=0)

J0 = 0
inflation_opt = 1.1  # Example starting value for inflation
alpha = 1e-6  # Learning rate
key = random.PRNGKey(0)  # Random key
N = 10  # Number of MC samples
m0 = initial_state
C0 = Q  # Initial covariance, assuming Q is your process noise covariance
localization_matrix = generate_gc_localization_matrix(n, 15)# jnp.ones((n, n)) # 
n_ensemble  = 20
ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T

crpss = []
rmses=[]
inflations = []
true_filter_divergences = []

# from jax import config
# config.update("jax_debug_nans", True)

for i in tqdm(range(50)):
    key, subkey = random.split(key)
        
    # Gradient descent step for inflation parameter
    grad_inflation = var_cost_grad(inflation_opt, ensemble_init, observations, H, Q, R, localization_matrix, subkey, num_steps, J0)
    inflation_opt -= alpha * grad_inflation  # Update inflation parameter
    
    inflations.append(inflation_opt)

    states, covariances = ensrf_steps(ensemble_init, observations, H, Q, R, localization_matrix, inflation_opt, key)
    
    ensemble_mean = jnp.mean(states, axis=-1)  # Taking the mean across the ensemble members dimension
    print(ensemble_mean.shape)
    print(ms.shape)
    true_filter_divergences.append(KL_gaussian(ensemble_mean.T, covariances, ms.T, Cs))

    rmse = jnp.sqrt(jnp.mean((ensemble_mean - true_states)**2))
    rmses.append(rmse)
    crps = properscoring.crps_ensemble(true_states, states).mean(axis=1).mean()
    crpss.append(crps)
    #clear_output(wait=True)
    
    print(inflation_opt, crps)

  0%|          | 0/50 [00:00<?, ?it/s]

(100, 40)
(100, 40)


ValueError: Incompatible shapes for broadcasting: shapes=[(40,), (100, 100, 100)]

In [51]:
ms

Array([[ 0.07836241, -0.7641989 , -0.12098181, ..., -0.20468104,
        -0.53091204,  0.32743043],
       [ 0.12314163, -0.52652967, -0.26787078, ..., -0.44446802,
        -0.34339535,  0.09057826],
       [ 0.21506685, -0.34971774,  0.22789443, ..., -0.23412657,
        -0.08423281,  0.46471524],
       ...,
       [ 0.60597956, -0.00444032, -0.76824206, ...,  0.15328997,
        -0.3658951 , -0.06035931],
       [ 0.06862107,  0.02553357,  0.6577666 , ...,  0.07941943,
        -0.88178587,  0.00518188],
       [-0.05369499,  0.08855119, -0.27075857, ..., -0.10694164,
        -0.12803587,  0.00748018]], dtype=float32)