In [85]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, jacfwd, jacrev
from jax.scipy.linalg import inv, svd, eigh, det
from jax.lax import scan
from scipy.linalg import solve_discrete_are, norm
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from jax.tree_util import Partial
from functools import partial
from jax_vi import KL_gaussian, log_likelihood, KL_sum, plot_optimization_results, plot_k_matrices
from jax_filters import apply_filtering_fixed_nonlinear, kalman_filter_process, filter_step_nonlinear, ensrf_steps
from jax_models import visualize_observations, Lorenz63, generate_true_states, generate_localization_matrix

from functools import partial
import optax

N = 10 # number of Monte Carlo samples
num_steps = 500  # Number of time steps
J0 = 0
n = 3   # Number of state variables
key = random.PRNGKey(0)  # Random key for reproducibility
Q = 0.1 * jnp.eye(n)  # Process noise covariance
R = 0.05 * jnp.eye(n)  # Observation noise covariance
H = jnp.eye(n) # Observation matrix (identity matrix for direct observation of all state variables)

n_ensemble = 10
observation_interval = 1
initial_state = random.normal(random.PRNGKey(0), (n,)) 
m0 = initial_state
C0 = Q

l63_model = Lorenz63()
l63_step = Partial(l63_model.step)

jacobian_function = jacrev(l63_step, argnums=0)
jac_func = Partial(jacobian_function)
state_transition_function = l63_step

observations, true_states = generate_true_states(key, num_steps, n, initial_state, H, Q, R, l63_step, observation_interval)



In [92]:
key, subkey = random.split(key)
ensemble_init = random.multivariate_normal(subkey, m0, C0, (n_ensemble,)).T  # Shape: (n, ensemble_size)
localization_matrix = generate_localization_matrix(3,1)

ensemble_preds, C_preds, ensembles, covariances = ensrf_steps(state_transition_function, ensemble_init, num_steps, observations, 1, H, Q, R, localization_matrix=localization_matrix, inflation=1.9, key=key)



In [102]:
from jax import random, jit, lax

@jit
def kalman_step(state, observation, params):
    m_prev, C_prev = state
    state_transition_function, jacobian_function, H, Q, R = params
    
    # Prediction step
    m_pred = state_transition_function(m_prev)
    F_jac = jacobian_function(m_prev)
    C_pred = F_jac @ C_prev @ F_jac.T + Q
    
    # Update step
    S = H @ C_pred @ H.T + R
    K_curr = C_pred @ H.T @ jnp.linalg.inv(S)
    m_update = m_pred + K_curr @ (observation - H @ m_pred)
    C_update = (jnp.eye(H.shape[1]) - K_curr @ H) @ C_pred
    
    return (m_update, C_update), (m_pred, C_pred, m_update, C_update, K_curr)

@jit
def kalman_filter_process(state_transition_function, jacobian_function, m0, C0, observations, H, Q, R):
    params = (state_transition_function, jacobian_function, H, Q, R)
    initial_state = (m0, C0)
    
    # Modified scan to capture both prediction and analysis states
    _, (m_preds, C_preds, m_updates, C_updates, Ks) = lax.scan(
        lambda state, obs: kalman_step(state, obs, params),
        initial_state, 
        observations
    )
    
    return m_preds, C_preds, m_updates, C_updates, Ks


m_preds, C_preds, m_updates, C_updates, Ks = kalman_filter_process(state_transition_function, jac_func, m0, C0, observations, H, Q, R)

In [104]:
m_preds.shape

(500, 3)

In [105]:
ensemble_preds = ensemble_preds.transpose(0, 2, 1)  # Shape: (num_steps, ensemble_size, n)
ensembles = ensembles.transpose(0, 2, 1)            # Shape: (num_steps, ensemble_size, n)

# Prepare inputs and outputs. Note that rather than use ensemble mean and covariances, we are using the entire ensemble
inputs = []
outputs = []

for t in range(num_steps):
    pred_ensemble = ensemble_preds[t]  # Shape: (ensemble_size, n)
    obs = observations[t]              # Shape: (observation_dim,)
    analysis_ensemble = ensembles[t]   # Shape: (ensemble_size, n)
    
    pred_ensemble_flat = pred_ensemble.reshape(-1)  # Shape: (ensemble_size * n,)
    input_t = jnp.concatenate([pred_ensemble_flat, obs])  # Shape: (ensemble_size * n + observation_dim,)
    
    output_t = analysis_ensemble.reshape(-1)  # Flatten the analysis ensemble    
    inputs.append(input_t)
    outputs.append(output_t)

inputs = jnp.array(inputs)   # Shape: (num_steps, input_dim)
outputs = jnp.array(outputs) # Shape: (num_steps, output_dim)

print(inputs.shape)
print(outputs.shape)

(500, 33)
(500, 30)


In [56]:
def kl_loss(params, batch):
    # note that this loss does not use y at all, similar to fixed gain experiments
    
    x, y = batch  # x: (batch_size, input_dim), y: (batch_size, ensemble_size * n)
    batch_size = x.shape[0]
    ensemble_size = n_ensemble
    n = y.shape[1] // ensemble_size  # State dimension (since y is flattened)
    pred_ensemble_flat = x[:, :ensemble_size * n]  # Shape: (batch_size, ensemble_size * n)
    obs = x[:, ensemble_size * n:]  # Shape: (batch_size, observation_dim)

    # Reshape the predicted ensemble to (batch_size, ensemble_size, n)
    pred_ensemble = pred_ensemble_flat.reshape(batch_size, ensemble_size, n)
    def compute_mean_cov(ensemble):
        mu = jnp.mean(ensemble, axis=0)  # Mean over ensemble members, shape: (n,)
        Sigma = jnp.cov(ensemble, rowvar=False)  # Covariance, shape: (n, n)
        return mu, Sigma

    mu_pred_ensemble, Sigma_pred_ensemble = jax.vmap(compute_mean_cov)(pred_ensemble)
    # Model prediction: predicted analysis ensemble (flattened)
    preds = model.apply(params, x)  # preds: (batch_size, ensemble_size * n)
    # Reshape preds and y to (batch_size, ensemble_size, n)
    preds = preds.reshape(batch_size, ensemble_size, n)
    y = y.reshape(batch_size, ensemble_size, n)
    # Compute mean and covariance of the predicted analysis ensemble
    mu_anal, Sigma_anal = jax.vmap(compute_mean_cov)(preds)

    # Compute mean and covariance of the target analysis ensemble (unused)
    # mu_y, Sigma_y = jax.vmap(compute_mean_cov)(y)
    key = random.PRNGKey(0)  # Random key for reproducibility
    key, *subkeys = random.split(key, num=N+1)
    kl_sum = KL_sum(mu_pred_ensemble, Sigma_pred_ensemble, mu_anal, Sigma_anal, n, state_transition_function, Q, key)
    def inner_map(subkey):
        return log_likelihood(random.multivariate_normal(subkey, mu_anal, Sigma_anal), obs, H, R, J=num_steps/batch_size, J0=0)  
    cost = kl_sum - jnp.mean(jax.lax.map(inner_map, jnp.vstack(subkeys)))

    return cost


In [131]:
from flax import linen as nn

class AnalysisNet(nn.Module):
    input_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_dim)(x)
        return x
        
key, subkey = random.split(key)
model = AnalysisNet(input_dim=inputs.shape[1], output_dim=outputs.shape[1])
params = model.init(subkey, inputs[0])
tx = optax.adam(learning_rate=1e-3)
kl_state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def mse_loss(params, batch):
    x, y = batch
    preds = model.apply(params, x)
    loss = jnp.mean((preds - y) ** 2)
    return loss

@jit
def train_step(state, batch):
    loss, grads = value_and_grad(kl_loss)(state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, loss

batch_size = 32
num_epochs = 20
num_batches = inputs.shape[0] // batch_size

for epoch in tqdm(range(num_epochs)):
    # Shuffle the data
    perm = random.permutation(random.PRNGKey(epoch), inputs.shape[0])
    inputs_shuffled = inputs[perm]
    outputs_shuffled = outputs[perm]
    
    epoch_loss = 0
    for i in range(num_batches):
        batch_indices = slice(i * batch_size, (i + 1) * batch_size)
        batch = (
            inputs_shuffled[batch_indices],
            outputs_shuffled[batch_indices],
        )
        kl_state, loss = train_step(kl_state, batch)
        epoch_loss += loss
    epoch_loss /= num_batches
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1, Loss: 35939450880.0
Epoch 2, Loss: 7468823040.0
Epoch 3, Loss: 4116057856.0
Epoch 4, Loss: 2309668608.0
Epoch 5, Loss: 1059303552.0
Epoch 6, Loss: 578861440.0
Epoch 7, Loss: 392701216.0
Epoch 8, Loss: 303462976.0
Epoch 9, Loss: 177092080.0
Epoch 10, Loss: 170842816.0
Epoch 11, Loss: 136709280.0
Epoch 12, Loss: 163207536.0
Epoch 13, Loss: 140942144.0
Epoch 14, Loss: 114853976.0
Epoch 15, Loss: 138087680.0
Epoch 16, Loss: 176819472.0
Epoch 17, Loss: 180131328.0
Epoch 18, Loss: 149741504.0
Epoch 19, Loss: 328492224.0
Epoch 20, Loss: 234007072.0


In [106]:
def evaluate_mse_loss(params, inputs, outputs):
    predictions = jax.vmap(lambda x: model.apply(params, x))(inputs)
    mse_loss = jnp.mean((predictions - outputs) ** 2)
    return mse_loss
final_mse_loss = evaluate_mse_loss(state.params, inputs, outputs)
print(f"Final MSE Loss (All Timesteps): {final_mse_loss}")


Final MSE Loss (All Timesteps): 120.82587432861328


In [128]:
params = model.init(subkey, inputs[0])
tx = optax.adam(learning_rate=1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def mse_loss(params, batch):
    x, y = batch
    preds = model.apply(params, x)
    loss = jnp.mean((preds - y) ** 2)
    return loss

@jit
def train_step(state, batch):
    loss, grads = value_and_grad(mse_loss)(state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, loss

batch_size = 32
num_epochs = 20
num_batches = inputs.shape[0] // batch_size

for epoch in tqdm(range(num_epochs)):
    # Shuffle the data
    perm = random.permutation(random.PRNGKey(epoch), inputs.shape[0])
    inputs_shuffled = inputs[perm]
    outputs_shuffled = outputs[perm]
    
    epoch_loss = 0
    for i in range(num_batches):
        batch_indices = slice(i * batch_size, (i + 1) * batch_size)
        batch = (
            inputs_shuffled[batch_indices],
            outputs_shuffled[batch_indices],
        )
        state, loss = train_step(state, batch)
        epoch_loss += loss
    epoch_loss /= num_batches
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1, Loss: 69.54347229003906
Epoch 2, Loss: 5.544004440307617
Epoch 3, Loss: 1.4658074378967285
Epoch 4, Loss: 0.45176395773887634
Epoch 5, Loss: 0.22495898604393005
Epoch 6, Loss: 0.16154973208904266
Epoch 7, Loss: 0.13796788454055786
Epoch 8, Loss: 0.12263406813144684
Epoch 9, Loss: 0.11631610989570618
Epoch 10, Loss: 0.11260665208101273
Epoch 11, Loss: 0.11302581429481506
Epoch 12, Loss: 0.10883960127830505
Epoch 13, Loss: 0.10708191245794296
Epoch 14, Loss: 0.10203983634710312
Epoch 15, Loss: 0.10470855981111526
Epoch 16, Loss: 0.0971754938364029
Epoch 17, Loss: 0.10021547973155975
Epoch 18, Loss: 0.0974670797586441
Epoch 19, Loss: 0.0876523032784462
Epoch 20, Loss: 0.09050250798463821


In [108]:
final_mse_loss = evaluate_mse_loss(state.params, inputs, outputs)
print(f"Final MSE Loss (All Timesteps): {final_mse_loss}")


Final MSE Loss (All Timesteps): 0.08142220228910446


In [118]:
@partial(jit, static_argnums=(2, 5))
def nn_analysis_filter_steps(
    state_transition_function,
    ensemble_init,
    num_steps,
    observations,
    observation_interval,
    model,
    params,
    key,
):
    """
    A simplified filter step function that uses a neural network for analysis.

    Args:
       
        model: The trained neural network model.
        params: Parameters of the trained neural network.
        key: PRNG key for reproducibility.
    Returns:
        ensemble_preds: Predicted ensembles at each time step, shape (num_steps, n, n_ensemble).
        ensembles: Analysis ensembles at each time step, shape (num_steps, n, n_ensemble).
    """
    model_vmap = jax.vmap(lambda v: state_transition_function(v), in_axes=1, out_axes=1)
    key, *subkeys = random.split(key, num=num_steps + 1)
    subkeys = jnp.array(subkeys)

    def inner(carry, t):
        ensemble = carry
        ensemble_predicted = model_vmap(ensemble)

        def true_fun(_):
            # Flatten the predicted ensemble and prepare input for the model
            pred_flat = ensemble_predicted.reshape(-1)  # (n_ensemble * n,)
            input_t = jnp.concatenate([pred_flat, observations[t]])  # Append observation
            # Use the NN to predict the analysis ensemble
            analysis_flat = model.apply(params, input_t)
            # Reshape back to (n, n_ensemble)
            analysis_ensemble = analysis_flat.reshape(ensemble_predicted.shape)
            return analysis_ensemble

        def false_fun(_):
            return ensemble_predicted

        updated_ensemble = lax.cond(
            t % observation_interval == 0, true_fun, false_fun, operand=None
        )
        return updated_ensemble, (ensemble_predicted, updated_ensemble)

    # Perform filtering over all time steps
    _, (ensemble_preds, ensembles) = lax.scan(
        inner, ensemble_init, jnp.arange(num_steps)
    )

    return ensemble_preds, ensembles


In [133]:
ensemble_preds, nn_ensemble = nn_analysis_filter_steps(
    state_transition_function=l63_step,
    ensemble_init=ensemble_init,
    num_steps=num_steps,
    observations=observations,
    observation_interval=1,
    model=model,
    params=state.params,
    key=key,
)
ensemble_means = jnp.mean(nn_ensemble, axis=2)  # Shape: (num_steps, n)

# Compute the RMSE
rmse = jnp.sqrt(jnp.mean((ensemble_means - true_states) ** 2, axis=1))  # Shape: (num_steps,)
print(jnp.mean(rmse))

20.56104


In [125]:
ensemble_means = m_updates

# Compute the RMSE
rmse = jnp.sqrt(jnp.mean((ensemble_means - true_states) ** 2, axis=1))  # Shape: (num_steps,)
print(jnp.mean(rmse))

0.18123291


In [127]:
key, subkey = random.split(key)
ensemble_init = random.multivariate_normal(subkey, m0, C0, (n_ensemble,)).T  # Shape: (n, ensemble_size)
localization_matrix = generate_localization_matrix(3,1)

ensemble_preds, C_preds, ensembles, covariances = ensrf_steps(state_transition_function, ensemble_init, num_steps, observations, 1, H, Q, R, localization_matrix=localization_matrix, inflation=1.9, key=key)

ensemble_means = jnp.mean(ensembles, axis=2)  # Shape: (num_steps, n)

# Compute the RMSE
rmse = jnp.sqrt(jnp.mean((ensemble_means - true_states) ** 2, axis=1))  # Shape: (num_steps,)
print(jnp.mean(rmse))

0.18224059
