In [9]:
# Import necessary libraries
import jax.numpy as np
from jax import random, grad, jit, lax
from jax.scipy.linalg import inv, svd, eigh, det
from jax.numpy.linalg import norm
from tqdm import tqdm
from sklearn.datasets import make_spd_matrix
from jax_models import Lorenz96
from jax_models import visualize_observations, Lorenz96, generate_true_states, generate_gc_localization_matrix
from jax_filters import ensrf_steps
#from jax_vi import KL_gaussian, log_likelihood


# Parameters
F = 8.0
dt = 0.01
num_steps = 30  # Number of time steps
n_timesteps = num_steps
J0 = 0
n = 40   # Number of state variables
Q = 0.1 * np.eye(n)  # Process noise covariance
R_matrix = make_spd_matrix(n)  # Generating a symmetric positive definite matrix for R
R = np.array(R_matrix)  # Observation noise covariance
inv_R = inv(R)
H = np.eye(n)  # Observation matrix


N = 10
n_ensemble = 20
observation_interval = 1
initial_state = random.normal(random.PRNGKey(0), (n,))  # Initial state

l96_model = Lorenz96(dt = 0.01, F = 8)
state_transition_function = l96_model.step
# Generate true states and observations using the Lorenz '96 model
key = random.PRNGKey(0)
observations, true_states = generate_true_states(key, num_steps, n, initial_state, H, Q, R, l96_model.step, observation_interval)

In [10]:
from jax.lax import scan

def log_likelihood(v, y, H, inv_R, R, J, J0):
    """
    Computes the log-likelihood of observations given state estimates.
    """
    def log_likelihood_j(_, v_y):
        v_j, y_j = v_y
        error = y_j - H @ v_j
        ll = error.T @ inv_R @ error
        return _, ll
    _, lls = lax.scan(log_likelihood_j, None, (v, y))
    sum_ll = sum(lls)
    return -0.5 * sum_ll - 0.5 * (J - J0) * np.log(2 * np.pi) - 0.5 * (J - J0) * np.log(det(R))


def KL_gaussian(m1, C1, m2, C2):
    """
    Computes the Kullback-Leibler divergence between two Gaussian distributions.
    m1, C1: Mean and covariance of the first Gaussian distribution.
    m2, C2: Mean and covariance of the second Gaussian distribution.
    n: number of state variables
    """
    C2_inv = inv(C2)
    log_det_ratio = (np.log(np.linalg.eigvals(C2)).sum() - np.log(np.linalg.eigvals(C1)).sum()).real # log(det(C2) / det(C1)), works better with limited precision because the determinant is practically 0
    return 0.5 * (log_det_ratio - n + np.trace(C2_inv @ C1) + ((m2 - m1).T @ C2_inv @ (m2 - m1)))


def KL_sum(m, C, Q, key):
    """
    Computes the sum of KL divergences between the predicted and updated state distributions.
    """
    def KL_j(_, m_C_y):
        m_prev, m_curr, C_prev, C_curr, key = m_C_y
        key, *subkeys_inner = random.split(key, num=N)
        def inner_map(subkey):
            perturbed_state = m_prev + random.multivariate_normal(subkey, np.zeros(n), C_prev)
            m_pred = state_transition_function(perturbed_state)
            return KL_gaussian(m_curr, C_curr, m_pred, Q) #not sure if use of Q here is correct
        mean_kl = np.mean(lax.map(inner_map, np.array(subkeys_inner)), axis=0)
        return _, mean_kl

    _, mean_kls = scan(KL_j, None, (m[:-1, :], m[1:, :], C[:-1, :, :], C[1:, :, :], np.array(random.split(key, num=m.shape[0]-1))))
    kl_sum = sum(mean_kls)
    return kl_sum


In [11]:

def ensrf_step(ensemble, y, H, Q, R, localization_matrix, inflation):
    n_ensemble = ensemble.shape[1]
    x_m = np.mean(ensemble, axis=1)
    A = ensemble - x_m.reshape((-1, 1))
    Pf = inflation * A @ A.T / (n_ensemble - 1)
    P = Pf * localization_matrix + Q  # Element-wise multiplication for localization
    K = P @ H.T @ np.linalg.inv(H @ P @ H.T + R)
    x_m += K @ (y - H @ x_m)
    M = np.eye(x_m.shape[0]) + P @ H.T @ np.linalg.inv(R) @ H
    # U, s, Vh = svd(M)
    # s_inv_sqrt = np.diag(s**-0.5)
    # M_inv_sqrt = U @ s_inv_sqrt @ Vh apparently svd cannot be gradiented
    eigenvalues, eigenvectors = eigh(M)
    inv_sqrt_eigenvalues = 1 / np.sqrt(eigenvalues)
    Lambda_inv_sqrt = np.diag(inv_sqrt_eigenvalues)
    M_inv_sqrt = eigenvectors @ Lambda_inv_sqrt @ eigenvectors.T
    updated_ensemble = x_m.reshape((-1, 1)) + M_inv_sqrt @ A
    return updated_ensemble, P  # Now also returning P


def ensrf_steps(ensemble_init, observations, H, Q, R, localization_matrix, inflation):
    def inner(carry, t):
        ensemble, covariances = carry
        obs = observations[t, :]
        ensemble_updated, P_updated = lax.cond(
            t % observation_interval == 0,
            lambda _: ensrf_step(ensemble, obs, H, Q, R, localization_matrix, inflation),
            lambda _: (ensemble, np.zeros_like(Q)),  # Return zero covariance for non-observation steps
            None)
        covariances = covariances.at[t].set(P_updated)
        return (ensemble_updated, covariances), ensemble_updated

    covariances_init = np.zeros((n_timesteps, *Q.shape))
    _, states = lax.scan(inner, (ensemble_init, covariances_init), np.arange(n_timesteps))

    return states, covariances_init



def var_cost(radius, ensemble_init, observations, H, Q, R, inflation, key, J, J0):
    localization_matrix = generate_gc_localization_matrix(n, radius)
    states, covariances = ensrf_steps(ensemble_init, observations, H, Q, R, localization_matrix, inflation)
    ensemble_mean = np.mean(states, axis=-1)  # Taking the mean across the ensemble members dimension

    key, *subkeys = random.split(key, num=N+1)
    kl_sum = KL_sum(ensemble_mean, covariances, Q, key)
    

    # Calculate log-likelihood values for a batch of perturbed states
    log_likelihood_vals = lax.map(
        lambda subkey: log_likelihood(
            random.multivariate_normal(subkey, ensemble_mean, covariances),
            observations, H, np.linalg.inv(R), R, J, J0),
        np.array(subkeys))

    cost = kl_sum - np.mean(log_likelihood_vals)
    
    return cost

In [12]:
from IPython.display import clear_output
from jax import grad
from tqdm.notebook import tqdm
import jax.numpy as np
from jax import random


# Modification: Use grad to compute the gradient with respect to the inflation parameter
var_cost_grad = grad(var_cost, argnums=0)

radius_opt = 10.01  # Example starting value for inflation
alpha = 1e-6  # Learning rate
key = random.PRNGKey(0)  # Random key
N = 10  # Number of MC samples
m0 = initial_state
C0 = Q  # Initial covariance, assuming Q is your process noise covariance
ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T
inflation = 1.2

rmses = []
norms = []
from jax.config import config
config.update("jax_debug_nans", True)
for i in tqdm(range(20)):
    key, subkey = random.split(key)
    localization_matrix = generate_gc_localization_matrix(n, radius_opt)
    states, _ = ensrf_steps(ensemble_init, observations, H, Q, R, localization_matrix, inflation)
    print(states.shape)
    ensemble_mean = np.mean(states, axis=-1)  # Taking the mean across the ensemble members dimension
    rmse = np.sqrt(np.mean((ensemble_mean - true_states)**2))
    rmses.append(rmse)
    #clear_output(wait=True)
    print(f"RMSE: {rmse}")
    
    # Gradient descent step for inflation parameter
    grad_inflation = var_cost_grad(radius_opt, ensemble_init, observations, H, Q, R, inflation, subkey, num_steps, J0)
    radius_opt -= alpha * grad_inflation  # Update inflation parameter
    print(radius_opt)
    



  from jax.config import config


  0%|          | 0/20 [00:00<?, ?it/s]

(30, 40, 20)
RMSE: 0.5659351348876953


FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/google/jax.