# The goal of this notebook is to demonstrate that the forward model is differentiable with respect to the Brownian noise paths. This makes it possible to drive the model with different noise signals, optimised by gradient descent, to produce a forecast with better skill.

In [1]:
import os
os.environ["JAX_ENABLE_X64"] = "true"
import sys
sys.path.append('..')
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter, ParticleFilter_Sequential
from jax import config
jax.config.update("jax_enable_x64", True)
import numpy as np

float64


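Initialisation of a twin experiment: a single-member signal run (the synthetic truth, E=1) and a 128-member ensemble are built from the same base parameters (`KDV_params_2_SALT_LEARNING`).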

In [None]:
# Overrides shared by the signal (truth) run and the ensemble; only the
# ensemble size E differs between the two configurations.
_shared_overrides = dict(P=3, noise_magnitude=0.1, stochastic_advection_basis='sin')

signal_params = ConfigDict(KDV_params_2_SALT_LEARNING)
ensemble_params = ConfigDict(KDV_params_2_SALT_LEARNING)
signal_params.update(E=1, **_shared_overrides)
ensemble_params.update(E=128, **_shared_overrides)

del _shared_overrides  # helper only; keep the notebook namespace unchanged

Now we instantiate the signal and ensemble models by calling the model class, and draw their initial conditions.

In [None]:
# Instantiate the truth ("signal") model and the ensemble model from their configs.
signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
ensemble_model = ETD_KT_CM_JAX_Vectorised(ensemble_params)

# Initial states on each model's grid, with one member per E (presumably
# shaped (E, nx) — TODO confirm against initial_condition).
initial_signal, initial_ensemble = (
    initial_condition(model.x, params.E, params.initial_condition)
    for model, params in (
        (signal_model, signal_params),
        (ensemble_model, ensemble_params),
    )
)

# Names of the resampling schemes registered in filters.resamplers.
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

In [None]:
# Spatial sampling: observe every `observation_spatial_frequency`-th grid point.
observation_spatial_frequency = 1
observation_locations = np.arange(
    0,
    signal_model.x.shape[0],
    observation_spatial_frequency,
)
# Passed to the particle filter as `sigma`.
observation_noise = 0.01

# Temporal sampling: one observation per model step (32 was also tried).
number_of_observations_time = ensemble_model.params.nt
# Model steps between consecutive observations.
observation_temporal_frequency = int(
    ensemble_model.params.nt / number_of_observations_time
)

print(observation_locations)

In [None]:
# Standard-normal driving noise for the ensemble, shaped (nt, E, P).
key = jax.random.PRNGKey(0)
noise = jax.random.normal(
    key,
    shape=(ensemble_params.nt, ensemble_params.E, ensemble_params.P),
)

In [None]:

# Sequential particle filter with systematic resampling, driven by `noise`.
pf_systematic = ParticleFilter_Sequential(
    # Models: ensemble forward model and the truth generator.
    forward_model=ensemble_model,
    signal_model=signal_model,
    Driving_Noise=noise,
    # Sizes.
    n_particles=ensemble_params.E,
    n_dim=initial_signal.shape[-1],
    n_steps=observation_temporal_frequency,
    # Observation operator.
    observation_locations=observation_locations,
    sigma=observation_noise,
    # Resampling: set ess_threshold below one for less ESS-based resampling.
    resampling="systematic",  # alternative: 'default'
    ess_threshold=1.1,
)

In [None]:
key = jax.random.PRNGKey(0)
da_steps = number_of_observations_time
# Start from equally weighted particles.
initial_weights = jnp.ones(ensemble_params.E) / ensemble_params.E

# Positional arguments: ensemble, weights, truth, number of DA cycles
# (presumably the scan length — TODO confirm in ParticleFilter_Sequential.run),
# PRNG key.
# NOTE: `all` shadows the Python builtin; kept because later cells index it.
final, all = pf_systematic.run(
    initial_ensemble,
    initial_weights,
    initial_signal,
    da_steps,
    key,
)

In [None]:
# ParticleFilter_Sequential.run returns several stacked outputs (four of
# them); report the shape of each one.
for output_idx, stacked in enumerate(all):
    print(f"Shape of all[{output_idx}]: {stacked.shape}")

Prepend the initial condition, so that index 0 of each trajectory is the initial state.

In [None]:
# Prepend the initial state so that index 0 is the initial condition and
# index k is the state after the k-th assimilation cycle.
particles = jnp.concatenate([initial_ensemble[None, ...], all[0]], axis=0)
signal = jnp.concatenate([initial_signal[None, ...], all[2]], axis=0)

# Values at the observed grid points, one row per DA cycle with no initial
# row, so DA step k maps to observations[k - 1].
# NOTE(review): this samples the signal output all[2], i.e. noise-free truth.
# If all[3] holds the noisy observations (the later "Noisy - observations"
# plot label suggests that was the intent), use all[3] here — TODO confirm.
observations = all[2][:, :, observation_locations]
print(observations.shape)
print(particles.shape)

In [None]:
def plot(da_step):
    """Show the truth, every particle and, after the first cycle, the observations.

    Args:
        da_step: index into the prepended trajectories (0 = initial state).
    """
    grid = signal_model.x
    plt.plot(grid, signal[da_step, 0, :], color='k', label='signal')
    plt.plot(grid, particles[da_step, :, :].T, color='b', label='particles', linewidth=0.1)
    has_observations = da_step > 0
    if has_observations:
        # `observations` has no initial-condition row, hence the -1 shift.
        obs_values = observations[da_step - 1, 0, :]
        plt.plot(grid[observation_locations], obs_values, 'ro', label='observations')
    plt.show()

interact(plot, da_step=(0, da_steps))

In [None]:
# Reference ensemble with the same overrides as `ensemble_params`, run through
# ParticleFilter with resampling disabled (labelled "No-PF" in later plots).
ensemble_params_1 = ConfigDict(KDV_params_2_SALT_LEARNING)
ensemble_params_1.update(
    E=128,
    P=3,
    noise_magnitude=0.1,
    stochastic_advection_basis='sin',
)
ensemble_model_1 = ETD_KT_CM_JAX_Vectorised(ensemble_params_1)
initial_ensemble_1 = initial_condition(
    ensemble_model_1.x,
    ensemble_params_1.E,
    ensemble_params_1.initial_condition,
)
key2 = jax.random.PRNGKey(0)  # results can depend on the choice of key
pf_systematic_1 = ParticleFilter(
    forward_model=ensemble_model_1,
    signal_model=signal_model,
    n_particles=ensemble_params_1.E,
    n_dim=initial_signal.shape[-1],
    n_steps=observation_temporal_frequency,
    observation_locations=observation_locations,
    sigma=observation_noise,
    resampling='none',
)

# Positional arguments: ensemble, truth, number of DA cycles, PRNG key.
final_1, all_1 = pf_systematic_1.run(initial_ensemble_1, initial_signal, da_steps, key2)


In [None]:
# Prepend the initial states, mirroring the filtered run: index 0 is the
# initial condition for particles, signal and observations alike.
particles_1, signal_1, observations_1 = (
    jnp.concatenate([start[None, ...], trajectory], axis=0)
    for start, trajectory in (
        (initial_ensemble_1, all_1[0]),
        (initial_signal, all_1[1]),
        (initial_signal, all_1[2]),
    )
)

def new_plot(da_step):
    """Compare the filtered ensemble with the no-resampling ensemble at one DA step.

    Draws the filtered particles (blue), the truth signal (black), the
    no-resampling particles (thin black) and, from the first assimilation
    cycle on, the observations. Reads the notebook globals `signal_model`,
    `particles`, `particles_1`, `signal`, `observations` and
    `observation_locations`.

    Args:
        da_step: index into the prepended trajectories (0 = initial state).
    """
    plt.figure(figsize=(10, 5))
    grid = signal_model.x

    plt.plot(grid, particles[da_step, :, :].T, color='b', linewidth=0.1)
    plt.plot(grid, signal[da_step, 0, :], color='k', label='signal')
    plt.plot(grid, particles_1[da_step, :, :].T, color='k', linewidth=0.01)
    if da_step > 0:
        # `observations` has no initial-condition row, hence the -1 shift.
        plt.plot(grid[observation_locations], observations[da_step - 1, 0, :], 'ro', label='observations')

    # Proxy artists: one legend entry per ensemble rather than one per particle.
    plt.plot([], [], color='b', label='particles')
    plt.plot([], [], color='k', linewidth=0.5, label='particles_no_da')
    plt.legend()
    # Render inside the interact output, as `plot` does.
    plt.show()

interact(new_plot, da_step=(0, da_steps))

In [None]:
# Output directory for figures: override with the PF_SAVE_DIR environment
# variable. The default assumes the notebook runs from a subdirectory of the
# repository root (as sys.path.append('..') suggests), not a fixed user path.
SAVE_DIR = os.environ.get("PF_SAVE_DIR", os.path.join("..", "Saving"))
os.makedirs(SAVE_DIR, exist_ok=True)

da_step = da_steps  # Only plot at the final DA stage
grid = signal_model.x

plt.figure(figsize=(10, 5), dpi=300)
plt.plot(grid, particles[da_step, :, :].T, color='b', linewidth=0.05)
# Invisible proxy lines give each ensemble a single, thick legend entry.
plt.plot([], [], color='b', label='PF-ensemble', linewidth=2)
plt.plot([], [], color='k', label='No-PF-ensemble', linewidth=2)
plt.plot(grid, signal[da_step, 0, :], color='k', label='signal', linewidth=2, linestyle='--')
plt.plot(grid, particles_1[da_step, :, :].T, color='k', linewidth=0.05)

# Vertical segments from each observation to the truth at the same grid point,
# drawn in one call. `observations` has no initial-condition row, hence -1.
plt.vlines(
    x=grid[observation_locations],
    ymin=observations[da_step - 1, 0, :],
    ymax=signal[da_step, 0, observation_locations],
    color='red',
    linewidth=2,
    alpha=0.7,
)
plt.plot(grid[observation_locations], observations[da_step - 1, 0, :], 'ro', label='Noisy - observations')
plt.legend()
plt.title('Ensemble and Signal at Final Data Assimilation Step: Particle Filter vs No Particle Filter')
plt.savefig(os.path.join(SAVE_DIR, 'EX3_KS_PF_NPF.png'), bbox_inches='tight', dpi=300)
plt.show()

In [None]:
from metrics.ensemble import rmse, crps

# RMSE and CRPS of the filtered and the no-resampling ensembles against the
# same truth `signal`, per DA step (index 0 = initial condition).
# Figures go to SAVE_DIR (environment variable PF_SAVE_DIR, default ../Saving)
# instead of a hard-coded absolute path, so the notebook runs on any machine.
SAVE_DIR = os.environ.get("PF_SAVE_DIR", os.path.join("..", "Saving"))
os.makedirs(SAVE_DIR, exist_ok=True)

print(particles.shape)
print(signal.shape)
rmse_score = rmse(signal, particles)
crps_score = crps(signal, particles)
rmse_score_1 = rmse(signal, particles_1)
crps_score_1 = crps(signal, particles_1)

dpi = 300
lead_times = np.arange(da_steps + 1)


def _plot_score_vs_lead_time(pf_score, no_pf_score, ylabel, title, filename):
    """Plot one skill score for both ensembles against lead time and save it to SAVE_DIR."""
    plt.figure(figsize=(7, 4), dpi=dpi)
    plt.plot(lead_times, pf_score, color='black', marker='o', label='Particle Filter')
    plt.plot(lead_times, no_pf_score, color='blue', marker='s', label='No Particle Filter')
    plt.xlabel('Lead Time (DA step)')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(os.path.join(SAVE_DIR, filename), bbox_inches='tight', dpi=dpi)
    plt.show()


_plot_score_vs_lead_time(rmse_score, rmse_score_1, 'RMSE', 'RMSE vs Lead Time', 'EX3_KS_PF_NPF_RMSE.png')
_plot_score_vs_lead_time(crps_score, crps_score_1, 'CRPS', 'CRPS vs Lead Time', 'EX3_KS_PF_NPF_CRPS.png')

In [None]:
from metrics.ensemble import crps_internal
from jax import grad

# Re-create the PRNG key and the standard-normal driving noise with the same
# seed and shape as the set-up cell. This section then starts from known
# values, whatever the cells in between may have rebound.
key = jax.random.PRNGKey(0)
noise = jax.random.normal(key, shape=(ensemble_params.nt, ensemble_params.E, ensemble_params.P))
# Define a function that computes the mean CRPS given noise
def crps_wrt_noise(noise_input, rng_key=None):
    """Mean CRPS of a particle-filter run driven by ``noise_input``.

    Builds a ParticleFilter_Sequential with resampling disabled, presumably
    because resampling is not differentiable (TODO confirm). It runs the
    filter for ``da_steps`` cycles from the notebook's initial ensemble,
    weights and signal, prepends the initial condition to the particle and
    signal trajectories, and averages ``crps_internal`` over all entries.
    Intended for use with ``jax.grad`` / ``jax.value_and_grad`` with respect to
    ``noise_input``.

    Args:
        noise_input: driving noise passed as ``Driving_Noise``; the notebook
            uses the same shape as ``noise`` (nt, E, P).
        rng_key: PRNG key for ``pf.run``. If None, the notebook-global ``key``
            is read at call time. Under ``jit`` that read happens once, at
            trace time. Other cells rebind ``key``, so pass it explicitly when
            results must be reproducible.

    Returns:
        Scalar mean CRPS.
    """
    if rng_key is None:
        rng_key = key
    pf = ParticleFilter_Sequential(
        n_particles=ensemble_params.E,
        n_steps=observation_temporal_frequency,
        n_dim=initial_signal.shape[-1],
        forward_model=ensemble_model,
        signal_model=signal_model,
        sigma=observation_noise,
        ess_threshold=0.0,
        resampling='none',  # no resampling: keep the run a smooth function of the noise
        observation_locations=observation_locations,
        Driving_Noise=noise_input,
    )
    _, outputs = pf.run(initial_ensemble, initial_weights, initial_signal, da_steps, rng_key)
    particles = jnp.concatenate([initial_ensemble[None, ...], outputs[0]], axis=0)
    signal = jnp.concatenate([initial_signal[None, ...], outputs[2]], axis=0)
    return crps_internal(signal, particles).mean()

# Gradient of the mean CRPS with respect to every entry of the driving noise.
crps_grad = grad(crps_wrt_noise)(noise)
print(crps_grad.shape)
# Summarise rather than dumping the full (nt, E, P) array into the output.
print(
    "Gradient of CRPS with respect to noise: "
    f"L2 norm={float(jnp.linalg.norm(crps_grad)):.3e}, "
    f"max |g|={float(jnp.abs(crps_grad).max()):.3e}, "
    f"finite={bool(jnp.isfinite(crps_grad).all())}"
)

In [None]:
def multidim_ou_process(key, shape, theta=0.15, mu=0.0, sigma=0.2, dt=1.0, x0=None):
    """
    Simulate independent Ornstein-Uhlenbeck (OU) processes with Euler-Maruyama.

    Every component of the state (all axes of ``shape`` after the first) is an
    independent OU process, updated as

        x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * xi_t,

    with xi_t standard-normal draws.

    Args:
        key: JAX PRNGKey.
        shape: ``(timesteps, *state_shape)``, e.g. ``(nt, E, P)`` for one
            process per (ensemble member, noise component), or ``(timesteps, dim)``.
        theta: Mean reversion rate.
        mu: Long-term mean; scalar or broadcastable to ``state_shape``.
        sigma: Volatility; scalar or broadcastable to ``state_shape``.
        dt: Time step size.
        x0: Initial value; scalar or broadcastable to ``state_shape``.
            If None, starts at ``mu``.

    Returns:
        ``(traj, increments)``. ``traj`` has shape ``shape`` and holds the state
        after each of the ``timesteps`` updates (``x0`` itself is not included).
        ``increments = traj[1:] - traj[:-1]`` has shape
        ``(timesteps - 1, *state_shape)``.
    """
    timesteps = shape[0]
    state_shape = tuple(shape[1:])

    noise_key, _ = jax.random.split(key)
    noise = jax.random.normal(noise_key, shape=(timesteps, *state_shape))
    dtype = noise.dtype

    mu = jnp.broadcast_to(jnp.asarray(mu, dtype=dtype), state_shape)
    sigma = jnp.broadcast_to(jnp.asarray(sigma, dtype=dtype), state_shape)
    # x0 must match the per-step state shape and dtype: lax.scan rejects a carry
    # whose shape/dtype differs from what the step function returns.
    if x0 is None:
        x0 = mu
    else:
        x0 = jnp.broadcast_to(jnp.asarray(x0, dtype=dtype), state_shape)
    noise_scale = sigma * jnp.sqrt(dt)  # loop-invariant

    def ou_step(x_prev, noise_t):
        x_next = x_prev + theta * (mu - x_prev) * dt + noise_scale * noise_t
        return x_next, x_next

    _, traj = jax.lax.scan(ou_step, x0, noise)
    ou_increments = traj[1:] - traj[:-1]

    return traj, ou_increments

# Example usage: one OU process per (time step, ensemble member, noise
# component), matching the (nt, E, P) shape of the driving noise.
# A dedicated key keeps this demo from rebinding the notebook-wide `key`.
# crps_wrt_noise and the later filter runs read that global at call time.
ou_key = jax.random.PRNGKey(42)
traj, ou_increments = multidim_ou_process(
    ou_key,
    shape=(ensemble_params.nt, ensemble_params.E, ensemble_params.P),
    theta=0.2,
    mu=0.0,
    sigma=0.3,
    dt=ensemble_model.params.dt,
)

# traj[:, m, p] is component p of ensemble member m. Plot every component of
# one member; indexing only the second axis would select whole members.
member = 0
n_components = traj.shape[-1]

print("Shape of OU process trajectory:", traj.shape)
plt.figure(figsize=(10, 5))
for comp in range(n_components):
    plt.plot(traj[:, member, comp], label=f'Dimension {comp + 1}')
plt.title(f'Multidimensional Ornstein-Uhlenbeck Process (ensemble member {member})')
plt.xlabel('Time Steps')
plt.ylabel('Value')
plt.legend()
plt.show()

# Increments between consecutive OU states, for the same member.
print("Shape of OU increments:", ou_increments.shape)
plt.figure(figsize=(10, 5))
for comp in range(n_components):
    plt.plot(ou_increments[:, member, comp], label=f'Increment Dimension {comp + 1}')
plt.title(f'Increments of Multidimensional OU Process (ensemble member {member})')
plt.xlabel('Time Steps')
plt.ylabel('Increment Value')
plt.legend()
plt.show()

In [None]:

# traj,ou_increments = multidim_ou_process(key, shape=(ensemble_params.nt, ensemble_params.E, ensemble_params.P), theta=0.2, mu=0.0, sigma=0.3, dt=ensemble_model.params.dt)


In [None]:
from jax import jit, value_and_grad

# Plain gradient descent on the driving noise to minimise the mean CRPS.
# value_and_grad returns (CRPS, dCRPS/dnoise); jit compiles the filter run once.
crps_wrt_noise_jit = jit(value_and_grad(crps_wrt_noise))

learning_rate = 1e2  # step size for the noise update
num_steps = 32
noise_opt = noise.copy()  # start from the original driving noise
crps_history = []

for step in range(num_steps):
    crps_value, crps_grad = crps_wrt_noise_jit(noise_opt)
    crps_value = float(crps_value)
    crps_history.append(crps_value)
    print(f"Step {step+1}, CRPS: {crps_value:.6f}")
    # Stop before a NaN/inf gradient step corrupts noise_opt.
    if not np.isfinite(crps_value):
        print("Non-finite CRPS; stopping early (try a smaller learning_rate).")
        break
    noise_opt = noise_opt - learning_rate * crps_grad
else:
    # crps_history holds the CRPS *before* each update; also score the result
    # of the final update so the curve shows where the optimisation ended.
    crps_history.append(float(crps_wrt_noise_jit(noise_opt)[0]))
    print(f"Final CRPS after {num_steps} updates: {crps_history[-1]:.6f}")

# Plot CRPS history
plt.figure()
plt.plot(crps_history, marker='o')
plt.xlabel('Gradient Descent Step')
plt.ylabel('Mean CRPS')
plt.title('CRPS during Noise Optimization')
plt.show()

In [None]:
# Overlaid histograms of the original and the optimised driving noise.
plt.figure(figsize=(10, 5))
for sample, sample_label in ((noise, 'noise'), (noise_opt, 'noise_opt')):
    plt.hist(sample.flatten(), bins=100, alpha=0.25, label=sample_label)
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Distribution of noise and noise_opt')
plt.legend()
plt.show()

# Element-wise change introduced by the optimisation.
noise_diff = noise_opt - noise

plt.figure(figsize=(10, 5))
plt.hist(noise_diff.flatten(), bins=100, alpha=0.5, label='noise_opt - noise')
plt.xlabel('Difference Value')
plt.ylabel('Frequency')
plt.title('Distribution of Difference: noise_opt - noise')
plt.legend()
plt.show()

for stat_label, stat_value in (
    ("Mean difference:", noise_diff.mean()),
    ("Std of difference:", noise_diff.std()),
    ("Max difference:", noise_diff.max()),
    ("Min difference:", noise_diff.min()),
):
    print(stat_label, stat_value)

In [None]:
import scipy.stats as stats
import matplotlib.pyplot as plt

# 1-D samples of the original and optimised noise for distribution fitting.
noise_flat, noise_opt_flat = noise.flatten(), noise_opt.flatten()

# Candidate continuous distributions, looked up by name in scipy.stats.
distributions = [
    'norm',
    'laplace',
    't',
    'cauchy',
    'logistic',
    'expon',
    'uniform',
    'gamma',
    'beta',
]

def best_fit_distribution(data, distributions):
    """Pick the scipy.stats distribution that best fits ``data`` by KS statistic.

    Each candidate is fitted with ``dist.fit`` and scored with a one-sample
    Kolmogorov-Smirnov test against the fitted parameters. The smallest KS
    statistic wins.

    Args:
        data: 1-D sample (NumPy or JAX array, or any array-like).
        distributions: names of continuous distributions in ``scipy.stats``.

    Returns:
        ``(best_dist, best_params, results)``. ``results`` is a list of
        ``(name, ks_stat, ks_pvalue, params)`` for every candidate that could
        be fitted. ``best_dist`` and ``best_params`` are None if no candidate
        could be fitted.

    Raises:
        AttributeError: if a name is not a ``scipy.stats`` distribution.
    """
    import warnings

    # Convert once; otherwise a JAX array is re-converted for every fit/test.
    data = np.asarray(data)
    best_dist = None
    best_params = None
    best_ks = float('inf')
    results = []
    for dist_name in distributions:
        dist = getattr(stats, dist_name)
        try:
            params = dist.fit(data)
            ks_stat, ks_p = stats.kstest(data, dist_name, args=params)
        except Exception as exc:
            # One failing candidate should not abort the search, but say so
            # instead of silently dropping it.
            warnings.warn(f"Skipping {dist_name!r}: fit/KS test failed ({exc})")
            continue
        results.append((dist_name, ks_stat, ks_p, params))
        if ks_stat < best_ks:
            best_ks, best_dist, best_params = ks_stat, dist_name, params
    return best_dist, best_params, results

best_dist_noise, params_noise, results_noise = best_fit_distribution(noise_flat, distributions)
best_dist_noise_opt, params_noise_opt, results_noise_opt = best_fit_distribution(noise_opt_flat, distributions)

print(f"Best fit for noise: {best_dist_noise} with params {params_noise}")
print(f"Best fit for noise_opt: {best_dist_noise_opt} with params {params_noise_opt}")

# Shared x-grid spanning both samples, so the two panels are comparable.
x = np.linspace(
    min(noise_flat.min(), noise_opt_flat.min()),
    max(noise_flat.max(), noise_opt_flat.max()),
    1000,
)

# Left panel: original noise; right panel: optimised noise. Each shows a
# density histogram with its best-fit pdf overlaid.
panels = (
    (noise_flat, 'noise', best_dist_noise, params_noise),
    (noise_opt_flat, 'noise_opt', best_dist_noise_opt, params_noise_opt),
)
plt.figure(figsize=(12, 5))
for panel_idx, (panel_sample, panel_name, fit_name, fit_params) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.hist(panel_sample, bins=100, density=True, alpha=0.5, label=panel_name)
    plt.plot(x, getattr(stats, fit_name).pdf(x, *fit_params), 'r-', label=f'{fit_name} fit')
    plt.title(panel_name)
    plt.legend()
plt.show()

In [None]:
# Re-run the filter with resampling disabled, once with the original and once
# with the optimised driving noise. Compare the final-time ensembles and the
# per-step CRPS.
# Set the plotted step explicitly instead of relying on a value left behind
# by an earlier plotting cell.
da_step = da_steps  # final DA step


def _run_pf_with_noise(driving_noise):
    """Run ParticleFilter_Sequential (no resampling) driven by ``driving_noise``.

    Uses the notebook-global ``key``, as ``crps_wrt_noise`` does by default,
    so the evaluation matches the optimisation run.

    Returns:
        ``(final, outputs, particles, signal)``, where ``particles`` and
        ``signal`` have the initial condition prepended along axis 0.
    """
    pf = ParticleFilter_Sequential(
        n_particles=ensemble_params.E,
        n_steps=observation_temporal_frequency,
        n_dim=initial_signal.shape[-1],
        forward_model=ensemble_model,
        signal_model=signal_model,
        sigma=observation_noise,
        # NOTE(review): crps_wrt_noise uses 0.0; presumably irrelevant with
        # resampling="none" — TODO confirm.
        ess_threshold=1.0,
        resampling="none",
        observation_locations=observation_locations,
        Driving_Noise=driving_noise,
    )
    final_state, outputs = pf.run(initial_ensemble, initial_weights, initial_signal, da_steps, key)
    particles_out = jnp.concatenate([initial_ensemble[None, ...], outputs[0]], axis=0)
    signal_out = jnp.concatenate([initial_signal[None, ...], outputs[2]], axis=0)
    return final_state, outputs, particles_out, signal_out


def _plot_final_step(signal_arr, particles_arr, color, alpha, title):
    """Plot the truth, the ensemble and the truth at observed points at ``da_step``."""
    grid = signal_model.x
    plt.figure(figsize=(10, 5), dpi=300)
    plt.plot(grid, signal_arr[da_step, 0, :], color='k', label='Signal', linewidth=2, linestyle='--')
    plt.plot(grid, particles_arr[da_step].T, color=color, alpha=alpha, linewidth=0.5)
    # These markers are the noise-free signal sampled at the observation locations.
    plt.plot(grid[observation_locations], signal_arr[da_step, 0, observation_locations], 'ro',
             label='Signal at observation locations')
    plt.legend()
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('Value')
    plt.tight_layout()
    plt.show()


# NOTE: `all` shadows the builtin and `crps` shadows metrics.ensemble.crps.
# The names are kept for compatibility; `crps` is read by the comparison plot.
final, all, particles, signal = _run_pf_with_noise(noise)
crps = crps_internal(signal, particles)

final_learned, all_learned, particles_learned, signal_learned = _run_pf_with_noise(noise_opt)
crps_learned = crps_internal(signal_learned, particles_learned)

_plot_final_step(signal, particles, 'b', 0.1, 'Final Time Step: Signal and PF Particles (Original Noise)')
_plot_final_step(signal_learned, particles_learned, 'g', 0.5, 'Final Time Step: Signal and PF Particles (Learned Noise)')




In [None]:

# Per-step CRPS of the original-noise and learned-noise runs.
# NOTE(review): if crps_internal returns more than one value per DA step, each
# column is drawn as its own line — TODO confirm its output shape.
fig, ax = plt.subplots(figsize=(8, 4), dpi=150)
ax.plot(crps, label='Original Noise CRPS', color='b', marker='o')
ax.plot(crps_learned, label='Learned Noise CRPS', color='g', marker='s')
ax.set_xlabel('DA Step')
ax.set_ylabel('CRPS')
ax.set_title('CRPS: Original vs Learned Noise')
ax.legend()
fig.tight_layout()
plt.show()