In [1]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, jacfwd, jacrev
from jax.scipy.linalg import inv, svd, eigh, det
from jax.lax import scan
from scipy.linalg import solve_discrete_are, norm
from jax import random, jit, value_and_grad

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from jax.tree_util import Partial
from functools import partial
from jax_vi import KL_gaussian, log_likelihood, KL_sum, plot_optimization_results, plot_k_matrices
from jax_filters import apply_filtering_fixed_nonlinear, kalman_filter_process, filter_step_nonlinear, ensrf_steps
from jax_models import visualize_observations, Lorenz63, generate_true_states, generate_localization_matrix

from flax.training import train_state
from functools import partial
import optax

N = 10  # number of Monte Carlo samples
num_steps = 1000  # Total number of time steps
num_train_steps = 500  # Number of training time steps
num_test_steps = num_steps - num_train_steps  # Number of testing time steps
J0 = 0
n = 3   # Number of state variables
key = random.PRNGKey(42)  # Random key for reproducibility
Q = 0.1 * jnp.eye(n)  # Process noise covariance
R = 0.05 * jnp.eye(n)  # Observation noise covariance
H = jnp.eye(n)  # Observation matrix (identity matrix for direct observation of all state variables)

n_ensemble = 10
observation_interval = 1
initial_state = random.normal(random.PRNGKey(42), (n,)) 
m0 = initial_state
C0 = Q

l63_model = Lorenz63()
l63_step = Partial(l63_model.step)

jacobian_function = jacrev(l63_step, argnums=0)
jac_func = Partial(jacobian_function)
state_transition_function = l63_step

# Generate true states and observations for 1000 steps
true_states, observations = generate_true_states(key, num_steps, n, initial_state, H, Q, R, l63_step, observation_interval)

In [2]:
from jax import random, jit, lax

@jit
def kalman_step(state, observation, params):
    m_prev, C_prev = state
    state_transition_function, jacobian_function, H, Q, R = params
    
    # Prediction step
    m_pred = state_transition_function(m_prev)
    F_jac = jacobian_function(m_prev)
    C_pred = F_jac @ C_prev @ F_jac.T + Q
    
    # Update step
    S = H @ C_pred @ H.T + R
    K_curr = C_pred @ H.T @ jnp.linalg.inv(S)
    m_update = m_pred + K_curr @ (observation - H @ m_pred)
    C_update = (jnp.eye(H.shape[1]) - K_curr @ H) @ C_pred
    
    return (m_update, C_update), (m_pred, C_pred, m_update, C_update, K_curr)

@jit
def kalman_filter_process(state_transition_function, jacobian_function, m0, C0, observations, H, Q, R):
    params = (state_transition_function, jacobian_function, H, Q, R)
    initial_state = (m0, C0)
    
    # Modified scan to capture both prediction and analysis states
    _, (m_preds, C_preds, m_updates, C_updates, Ks) = lax.scan(
        lambda state, obs: kalman_step(state, obs, params),
        initial_state, 
        observations
    )
    
    return m_preds, C_preds, m_updates, C_updates, Ks


m_preds, C_preds, m_updates, C_updates, Ks = kalman_filter_process(state_transition_function, jac_func, m0, C0, observations, H, Q, R)

In [3]:
key, subkey = random.split(key)
ensemble_init = random.multivariate_normal(subkey, m0, C0, (n_ensemble,)).T  # Shape: (n, ensemble_size)
localization_matrix = generate_localization_matrix(3,1)

ensemble_preds, C_preds, ensembles, covariances = ensrf_steps(state_transition_function, ensemble_init, num_train_steps, observations, 1, H, Q, R, localization_matrix=localization_matrix, inflation=1.9, key=key)



In [5]:
@partial(jit, static_argnums=(2, 5))
def nn_analysis_filter_steps(
    state_transition_function,
    ensemble_init,
    num_steps,
    observations,
    observation_interval,
    model,
    params,
    key,
):
    model_vmap = jax.vmap(lambda v: state_transition_function(v), in_axes=1, out_axes=1)
    key, *subkeys = random.split(key, num=num_steps + 1)
    subkeys = jnp.array(subkeys)

    def inner(carry, t):
        ensemble = carry
        ensemble_predicted = model_vmap(ensemble)
        def true_fun(_):
            # Flatten the predicted ensemble and prepare input for the model
            pred_flat = ensemble_predicted.reshape(-1)  # (n_ensemble * n,)
            input_t = jnp.concatenate([pred_flat, observations[t]])  # Append observation
            # Use the NN to predict the analysis ensemble
            analysis_flat = model.apply(params, input_t)
            # Reshape back to (n, n_ensemble)
            analysis_ensemble = analysis_flat.reshape(ensemble_predicted.shape)
            return analysis_ensemble

        def false_fun(_):
            return ensemble_predicted

        updated_ensemble = lax.cond(
            t % observation_interval == 0, true_fun, false_fun, operand=None
        )
        return updated_ensemble, (ensemble_predicted, updated_ensemble)

    # Perform filtering over all time steps
    _, (ensemble_preds, ensembles) = lax.scan(
        inner, ensemble_init, jnp.arange(num_steps)
    )

    return ensemble_preds, ensembles


In [12]:
from flax import linen as nn

class AnalysisNet(nn.Module):
    input_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_dim)(x)
        return x
        
key, subkey = random.split(key)
kl_model = AnalysisNet(input_dim=inputs.shape[1], output_dim=outputs.shape[1])
params = kl_model.init(subkey, inputs[0])
#params = mse_state.params

num_train_steps = 500

@jit
def var_cost(params):
    key, *subkeys = random.split(random.PRNGKey(42), num=N+1)

    y = observations[:num_train_steps]
    ensemble_init = random.normal(key, (n, n_ensemble))  
    ensemble_preds, ensembles = nn_analysis_filter_steps(
        state_transition_function, ensemble_init, num_train_steps, y, observation_interval, kl_model, params, key
    )

    m_preds = jnp.mean(ensemble_preds, axis=2)

    m_updates = jnp.mean(ensembles, axis=2)

    C_preds = jnp.array([jnp.cov(ensemble.T, rowvar=False) for ensemble in ensemble_preds])
    C_updates = jnp.array([jnp.cov(ensemble.T, rowvar=False) for ensemble in ensembles])
    
    kl_sum = KL_sum(m_preds, C_preds, m_updates, C_updates, n, state_transition_function, Q, key)

    def inner_map(subkey):
        return log_likelihood(random.multivariate_normal(subkey, m_updates, C_updates), y, H, R, num_train_steps, J0)

    cost = kl_sum - jnp.nanmean(jax.lax.map(inner_map, jnp.vstack(subkeys)))
    return cost
    
tx = optax.adam(learning_rate=1e-3)
kl_state = train_state.TrainState.create(apply_fn=kl_model.apply, params=params, tx=tx)

@jit
def train_step(state):
    loss, grads = value_and_grad(var_cost)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

num_epochs = 200

for epoch in tqdm(range(num_epochs)):    
    epoch_loss = 0
    kl_state, loss = train_step(kl_state)
    epoch_loss = loss
    if epoch % 5 == 0:
        print(f"Epoch {epoch + 1}, Loss: Loss: {epoch_loss:.2e}")


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1, Loss: Loss: 4.11e+06
Epoch 6, Loss: Loss: 1.83e+06
Epoch 11, Loss: Loss: 1.20e+06
Epoch 16, Loss: Loss: 5.10e+05
Epoch 21, Loss: Loss: 2.07e+05
Epoch 26, Loss: Loss: 1.58e+05
Epoch 31, Loss: Loss: 1.21e+05
Epoch 36, Loss: Loss: 8.24e+04
Epoch 41, Loss: Loss: 5.95e+04
Epoch 46, Loss: Loss: 4.84e+04
Epoch 51, Loss: Loss: 3.65e+04
Epoch 56, Loss: Loss: 2.52e+04
Epoch 61, Loss: Loss: 1.80e+04
Epoch 66, Loss: Loss: 1.20e+04
Epoch 71, Loss: Loss: 7.36e+03
Epoch 76, Loss: Loss: 4.95e+03
Epoch 81, Loss: Loss: 3.55e+03
Epoch 86, Loss: Loss: 2.62e+03
Epoch 91, Loss: Loss: 2.06e+03
Epoch 96, Loss: Loss: 1.64e+03
Epoch 101, Loss: Loss: 1.36e+03
Epoch 106, Loss: Loss: 1.17e+03
Epoch 111, Loss: Loss: 1.02e+03
Epoch 116, Loss: Loss: 9.14e+02
Epoch 121, Loss: Loss: 8.35e+02
Epoch 126, Loss: Loss: 8.41e+02
Epoch 131, Loss: Loss: 8.05e+02
Epoch 136, Loss: Loss: 7.49e+02
Epoch 141, Loss: Loss: 6.33e+02
Epoch 146, Loss: Loss: 6.55e+02
Epoch 151, Loss: Loss: 6.34e+02
Epoch 156, Loss: Loss: 5.52e+0

In [22]:
class AnalysisNet2(nn.Module):
    input_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        batch_size = 1
        pred_ensemble_flat = x[:n_ensemble * n]  
        obs = x[n_ensemble * n:]  
        pred_ensemble = pred_ensemble_flat.reshape(n_ensemble, n)  # Shape: (10, 3)

        # Apply a 1D convolution to the ensemble part
        pred_ensemble = nn.Conv(features=64, kernel_size=3, strides=1, padding='VALID')(pred_ensemble)
        pred_ensemble = nn.relu(pred_ensemble)  

        # Flatten the convolutional output
        pred_ensemble_flat = pred_ensemble.reshape(-1)  # Shape: (e, 8 * 64)

        # Concatenate the processed ensemble with the observations
        combined = jnp.concatenate([pred_ensemble_flat, obs], axis=-1)  # Shape: (batch_size, 8 * 64 + 3)

        # Pass through fully connected layers
        x = nn.Dense(512)(combined)
        x = nn.relu(x)
        x = nn.Dropout(0.3)(x, deterministic=deterministic)

        # Another fully connected layer
        x = nn.Dense(512)(x)
        x = nn.relu(x)

        # Final output layer to produce the updated analysis ensemble
        x = nn.Dense(self.output_dim)(x)  # Shape: (batch_size, output_dim)
        return x


Loss from MSE-trained NN analysis filter

In [27]:
key, subkey = random.split(key)
kl_model2 = AnalysisNet2(input_dim=inputs.shape[1], output_dim=outputs.shape[1])
params = kl_model2.init(subkey, inputs[0])
#params = mse_state.params

num_train_steps = 500

@jit
def var_cost(params):
    key, *subkeys = random.split(random.PRNGKey(42), num=N+1)

    y = observations[:num_train_steps]
    ensemble_init = random.normal(key, (n, n_ensemble))  
    ensemble_preds, ensembles = nn_analysis_filter_steps(
        state_transition_function, ensemble_init, num_train_steps, y, observation_interval, kl_model2, params, key
    )

    m_preds = jnp.mean(ensemble_preds, axis=2)

    m_updates = jnp.mean(ensembles, axis=2)

    C_preds = jnp.array([jnp.cov(ensemble.T, rowvar=False) for ensemble in ensemble_preds])
    C_updates = jnp.array([jnp.cov(ensemble.T, rowvar=False) for ensemble in ensembles])
    
    kl_sum = KL_sum(m_preds, C_preds, m_updates, C_updates, n, state_transition_function, Q, key)

    def inner_map(subkey):
        return log_likelihood(random.multivariate_normal(subkey, m_updates, C_updates), y, H, R, num_train_steps, J0)

    cost = kl_sum - jnp.nanmean(jax.lax.map(inner_map, jnp.vstack(subkeys)))
    return cost
    
tx = optax.adam(learning_rate=1e-4)
kl_state2 = train_state.TrainState.create(apply_fn=kl_model2.apply, params=params, tx=tx)

@jit
def train_step(state):
    loss, grads = value_and_grad(var_cost)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

num_epochs = 200

for epoch in tqdm(range(num_epochs)):    
    epoch_loss = 0
    kl_state2, loss = train_step(kl_state2)
    epoch_loss = loss
    if epoch % 5 == 0:
        print(f"Epoch {epoch + 1}, Loss: Loss: {epoch_loss:.2e}")


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1, Loss: Loss: 4.33e+06
Epoch 6, Loss: Loss: 4.15e+06
Epoch 11, Loss: Loss: 3.93e+06
Epoch 16, Loss: Loss: 3.50e+06
Epoch 21, Loss: Loss: 2.53e+06
Epoch 26, Loss: Loss: 2.08e+06
Epoch 31, Loss: Loss: 1.59e+06
Epoch 36, Loss: Loss: 1.20e+06
Epoch 41, Loss: Loss: 9.09e+05
Epoch 46, Loss: Loss: 6.84e+05
Epoch 51, Loss: Loss: 5.66e+05
Epoch 56, Loss: Loss: 4.56e+05
Epoch 61, Loss: Loss: 3.53e+05
Epoch 66, Loss: Loss: 2.60e+05
Epoch 71, Loss: Loss: 1.98e+05
Epoch 76, Loss: Loss: 1.63e+05
Epoch 81, Loss: Loss: 1.29e+05
Epoch 86, Loss: Loss: 1.11e+05
Epoch 91, Loss: Loss: 9.78e+04
Epoch 96, Loss: Loss: 8.89e+04
Epoch 101, Loss: Loss: 8.21e+04
Epoch 106, Loss: Loss: 7.69e+04
Epoch 111, Loss: Loss: 7.30e+04
Epoch 116, Loss: Loss: 7.00e+04
Epoch 121, Loss: Loss: 6.75e+04
Epoch 126, Loss: Loss: 6.58e+04
Epoch 131, Loss: Loss: 6.42e+04
Epoch 136, Loss: Loss: 6.25e+04
Epoch 141, Loss: Loss: 6.20e+04
Epoch 146, Loss: Loss: 6.10e+04
Epoch 151, Loss: Loss: 5.97e+04
Epoch 156, Loss: Loss: 5.67e+0

In [158]:
ensemble_preds, nn_ensemble = nn_analysis_filter_steps(
    state_transition_function=l63_step,
    ensemble_init=ensemble_init,
    num_steps=num_steps,
    observations=observations,
    observation_interval=1,
    model=mse_model,
    params=kl_state.params,
    key=key,
)
ensemble_means = jnp.mean(nn_ensemble, axis=2)  # Shape: (num_steps, n)
pred_means = jnp.mean(ensemble_preds, axis=2) 
# Compute the RMSE
rmse = jnp.sqrt(jnp.mean((ensemble_means - true_states) ** 2, axis=1))  # Shape: (num_steps,)
print("RMSE Analysis to True State",jnp.mean(rmse))
rmse = jnp.sqrt(jnp.mean((pred_means - true_states) ** 2, axis=1))  # Shape: (num_steps,)
print("RMSE Predicted to True State", jnp.mean(rmse))
true_states_std = jnp.std(true_states, axis=1)  # Shape: (num_steps,)
print("Standard Deviation of True States (per timestep):", jnp.mean(true_states_std))


# Compute the RMSE for extended kalman
rmse = jnp.sqrt(jnp.mean((m_updates - true_states) ** 2, axis=1))  # Shape: (num_steps,)
print("MSE Loss from Kalman Filter", jnp.mean(rmse))

key, subkey = random.split(key)
ensemble_init = random.multivariate_normal(subkey, m0, C0, (n_ensemble,)).T  # Shape: (n, ensemble_size)
localization_matrix = generate_localization_matrix(3,1)

ensemble_preds, C_preds, ensembles, covariances = ensrf_steps(state_transition_function, ensemble_init, num_steps, observations, 1, H, Q, R, localization_matrix=localization_matrix, inflation=1.9, key=key)

ensemble_means = jnp.mean(ensembles, axis=2)  # Shape: (num_steps, n)

# Compute the RMSE
rmse = jnp.sqrt(jnp.mean((ensemble_means - true_states) ** 2, axis=1))  # Shape: (num_steps,)
print(jnp.mean(rmse))

RMSE Analysis to True State 15.012643
RMSE Predicted to True State 15.237214
Standard Deviation of True States (per timestep): 12.618656


In [35]:
import properscoring as ps
import jax.numpy as jnp
from jax import random

# MSE Loss Calculation
def compute_mse_for_all_timesteps(pred_means, true_states):
    mse_vals = jnp.mean((pred_means - true_states) ** 2, axis=1)  # Shape: (num_steps,)
    return jnp.mean(mse_vals)

# Neural Network Filter
ensemble_preds, nn_ensemble = nn_analysis_filter_steps(
    state_transition_function=l63_step,
    ensemble_init=ensemble_init,
    num_steps=num_steps,
    observations=observations,
    observation_interval=1,
    model=kl_model,
    params=kl_state.params,
    key=key,
)
ensemble_means = jnp.mean(nn_ensemble, axis=2)  # Shape: (num_steps, n)
pred_means = jnp.mean(ensemble_preds, axis=2)

# Compute CRPS using properscoring for the neural network-based filtering
crps_nn_analysis = jnp.mean(jnp.array([ps.crps_ensemble(true_states[t], nn_ensemble[t]) for t in range(nn_ensemble.shape[0])]))
crps_nn_pred = jnp.mean(jnp.array([ps.crps_ensemble(true_states[t], ensemble_preds[t]) for t in range(ensemble_preds.shape[0])]))

# Compute MSE for the neural network-based filtering
mse_nn_analysis = compute_mse_for_all_timesteps(ensemble_means, true_states)
mse_nn_pred = compute_mse_for_all_timesteps(pred_means, true_states)

print("CRPS Analysis to True State (NN)", crps_nn_analysis)
print("CRPS Predicted to True State (NN)", crps_nn_pred)
print("MSE Analysis to True State (NN)", mse_nn_analysis)
print("MSE Predicted to True State (NN)", mse_nn_pred)

# Compute CRPS for extended Kalman Filter (EKF) using properscoring
crps_ekf = jnp.mean(jnp.array([ps.crps_ensemble(true_states[t], m_updates[t]) for t in range(m_updates.shape[0])]))
mse_ekf = compute_mse_for_all_timesteps(m_updates, true_states)

print("CRPS Loss from Kalman Filter", crps_ekf)
print("MSE Loss from Kalman Filter", mse_ekf)

# Ensemble Kalman Filter (EnsRF)
key, subkey = random.split(key)
ensemble_init = random.multivariate_normal(subkey, m0, C0, (n_ensemble,)).T  # Shape: (n, ensemble_size)
localization_matrix = generate_localization_matrix(3, 1)

ensemble_preds, C_preds, ensembles, covariances = ensrf_steps(
    state_transition_function, ensemble_init, num_steps, observations, 
    1, H, Q, R, localization_matrix=localization_matrix, inflation=1.9, key=key
)

# Compute CRPS for the ensemble-based filtering method (EnsRF) using properscoring
crps_ensrf = jnp.mean(jnp.array([ps.crps_ensemble(true_states[t], ensembles[t]) for t in range(ensembles.shape[0])]))
mse_ensrf = compute_mse_for_all_timesteps(jnp.mean(ensembles, axis=2), true_states)

print("CRPS from Ensemble Kalman Filter (EnsRF)", crps_ensrf)
print("MSE from Ensemble Kalman Filter (EnsRF)", mse_ensrf)


CRPS Analysis to True State (NN) 12.7385025
CRPS Predicted to True State (NN) 12.739106
MSE Analysis to True State (NN) 293.74744
MSE Predicted to True State (NN) 293.79538
CRPS Loss from Kalman Filter 0.19136108
MSE Loss from Kalman Filter 0.057385277
CRPS from Ensemble Kalman Filter (EnsRF) 0.15973847
MSE from Ensemble Kalman Filter (EnsRF) 0.058262277


Loss from Classic Kalman Filter:

Loss from EnSRF

In [160]:
key, subkey = random.split(key)
ensemble_init = random.multivariate_normal(subkey, m0, C0, (n_ensemble,)).T  # Shape: (n, ensemble_size)
localization_matrix = generate_localization_matrix(3,1)

ensemble_preds, C_preds, ensembles, covariances = ensrf_steps(state_transition_function, ensemble_init, num_steps, observations, 1, H, Q, R, localization_matrix=localization_matrix, inflation=1.9, key=key)

ensemble_means = jnp.mean(ensembles, axis=2)  # Shape: (num_steps, n)

# Compute the RMSE
rmse = jnp.sqrt(jnp.mean((ensemble_means - true_states) ** 2, axis=1))  # Shape: (num_steps,)
print(jnp.mean(rmse))

0.22254649
