In [16]:
# Import necessary libraries
import jax.numpy as np
from jax import random, grad, jit
from jax.scipy.linalg import inv
from jax.numpy.linalg import norm
from tqdm import tqdm
from sklearn.datasets import make_spd_matrix
from jax_models import Lorenz96
from jax_models import visualize_observations, Lorenz96, generate_true_states, generate_gc_localization_matrix
from jax_filters import ensrf_steps
#from jax_vi import KL_gaussian, log_likelihood


# Parameters
F = 8.0
dt = 0.01
num_steps = 30  # Number of time steps
n_timesteps = num_steps
J = num_steps
J0 = 0
n = 40   # Number of state variables
Q = 0.1 * np.eye(n)  # Process noise covariance
R_matrix = make_spd_matrix(n)  # Generating a symmetric positive definite matrix for R
R = np.array(R_matrix)  # Observation noise covariance
inv_R = inv(R)
H = np.eye(n)  # Observation matrix


N = 10
n_ensemble = 20
observation_interval = 1
initial_state = random.normal(random.PRNGKey(0), (n,))  # Initial state

l96_model = Lorenz96(dt = 0.01, F = 8)
state_transition_function = l96_model.step
# Generate true states and observations using the Lorenz '96 model
key = random.PRNGKey(0)
observations, true_states = generate_true_states(key, num_steps, n, initial_state, H, Q, R, l96_model.step, observation_interval)

In [10]:
import jax.numpy as jnp
from jax import random, grad, jit, jacfwd, lax, vmap, jacrev
from jax.scipy.linalg import inv, det, svd
from jax.lax import scan
@jit
def KL_gaussian(m1, C1, m2, C2):
    """
    Computes the Kullback-Leibler divergence between two Gaussian distributions.
    m1, C1: Mean and covariance of the first Gaussian distribution.
    m2, C2: Mean and covariance of the second Gaussian distribution.
    """
    C2_inv = inv(C2)
    log_det_ratio = (jnp.log(jnp.linalg.eigvals(C2)).sum() - jnp.log(jnp.linalg.eigvals(C1)).sum()).real # log(det(C2) / det(C1)), works better with limited precision because the determinant is practically 0
    return 0.5 * (log_det_ratio - n + jnp.trace(C2_inv @ C1) + ((m2 - m1).T @ C2_inv @ (m2 - m1)))

@jit
def log_likelihood(v, y):
    """
    v: State estimates.
    y: Observations.
    """
    def log_likelihood_j(_, v_y):
        v_j, y_j = v_y
        error = y_j - H @ v_j
        ll = error.T @ inv_R @ error
        return _, ll
    _, lls = scan(log_likelihood_j, None, (v[1:, :], y))
    sum_ll = sum(lls)
    return -0.5 * sum_ll - 0.5 * (J - J0) * jnp.log(2 * jnp.pi) - 0.5 * (J - J0) * jnp.log(det(R))


In [18]:


@jit
def KL_sum(m, C, Q, key):
    """
    Computes the sum of KL divergences between the predicted and updated state distributions.
    """
    def KL_j(_, m_C_y):
        m_prev, m_curr, C_prev, C_curr, key = m_C_y
        key, *subkeys_inner = random.split(key, num=N)
        def inner_map(subkey):
            perturbed_state = m_prev + random.multivariate_normal(subkey, np.zeros(n), C_prev)
            m_pred = state_transition_function(perturbed_state)
            return KL_gaussian(m_curr, C_curr, m_pred, Q) #not sure if use of Q here is correct
        mean_kl = np.mean(lax.map(inner_map, np.array(subkeys_inner)), axis=0)
        return _, mean_kl

    _, mean_kls = scan(KL_j, None, (m[:-1, :], m[1:, :], C[:-1, :, :], C[1:, :, :], np.array(random.split(key, num=m.shape[0]-1))))
    kl_sum = sum(mean_kls)
    return kl_sum


@jit
def filter_step(m_C_prev, y_curr, K, Q, H, R):
    """
    Apply a single forecast and update step using the Kalman filter.
    """
    m_prev, C_prev = m_C_prev
    m_pred = state_transition_function(m_prev)
    F_jac = jacrev(state_transition_function)(m_pred)
    m_update = (np.eye(n) - K @ H) @ m_pred + K @ y_curr
    C_pred = F_jac @ C_prev @ F_jac.T + Q
    C_update = (np.eye(n) - K @ H) @ C_pred @ (np.eye(n) - K @ H).T + K @ R @ K.T
    return (m_update, C_update), (m_update, C_update)

@jit
def filtered(K, m0, C0, Q, H, R, y):
    """
    Applies the filtering process to estimate the system state over time.
    """
    _, m_C = scan(lambda m_C_prev, y_curr: filter_step(m_C_prev, y_curr, K, Q, H, R), (m0, C0), y)

    m, C = m_C
    return np.vstack((m0, m)), np.vstack((C0.reshape(1, n, n), C))

@jit
def var_cost(K, m0, C0, Q, H, R, y, key, J, J0):
    """
    Computes the cost function for optimization, combining KL divergence and log-likelihood.
        J, J0, H, inv_R, R, n: Parameters for log_likelihood calculation.
    """
    m, C = filtered(K, m0, C0, Q, H, R,y)
    key, *subkeys = random.split(key, num=N+1)
    log_likelihood_vals = lax.map(lambda subkey: log_likelihood(random.multivariate_normal(subkey, m, C), y), np.array(subkeys))
    return (KL_sum(m, C, Q, key) - np.mean(log_likelihood_vals))

In [19]:
from IPython.display import clear_output

var_cost_grad = grad(var_cost, argnums=0)

# Initial guess for K and optimization parameters
K_opt = np.eye(n) * 0.4
alpha = 1e-6
key = random.PRNGKey(0)
N = 10
m0 = initial_state
C0 = Q
#kl_divergence_diff = []
#iterations = []
rmses = []
norms = []
for i in tqdm(range(20)):
    key, _ = random.split(key)
    m, C = filtered(K_opt, m0, C0, Q, H, R, observations)
    rmses.append(np.sqrt(np.mean((m[1:,:] - true_states)**2)))
    clear_output(wait=True)
    print(np.sqrt(np.mean((m[1:,:] - true_states)**2)))
    norms.append(np.linalg.norm(K_opt))
    K_opt -= alpha * var_cost_grad(K_opt, initial_state, Q, Q,H, R, observations, key, num_steps, J0)


100%|██████████| 20/20 [00:06<00:00,  2.95it/s]

0.24399751



