## Testing particle filters, a hybrid particle filter / EnKF, and an ensemble Kalman filter (EnKF) on the stochastic Kuramoto–Sivashinsky (KS) equation

In [1]:
import os
os.environ["JAX_ENABLE_X64"] = "true"
import sys
sys.path.append('..')
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import Hybrid_composed_ParticleFilter_of_EnKF, EnsembleKalmanFilter, ParticleFilter
from filters.filter import ParticleFilterAll, ParticleFilter_Sequential
from jax import config
jax.config.update("jax_enable_x64", True)
import numpy as np

float64


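Initialisation of a twin experiment: a single-member "signal" run plays the role of the synthetic truth from which observations are taken, while a 128-member ensemble built from the same parameter preset is used by the filters.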

In [2]:
# Signal (truth) and ensemble are built from the same parameter preset.
# Set model_preset = KS_params_SALT to run the experiment with that preset instead.
model_preset = KS_params_Force
signal_params = ConfigDict(model_preset)
ensemble_params = ConfigDict(model_preset)

# The truth is a single run (E=1); the filters use a 128-member ensemble.
# Both share P=32 and disable the stochastic advection basis.
signal_params.update(E=1, P=32, stochastic_advection_basis='none')
ensemble_params.update(E=128, P=32, stochastic_advection_basis='none')
print(signal_params)


Advection_basis_name: none
E: 1
Forcing_basis_name: sin
P: 32
S: 1
c_0: 0
c_1: 1
c_2: 1
c_3: 0.0
c_4: 1
dt: 0.25
equation_name: Kuramoto-Sivashinsky
initial_condition: Kassam_Trefethen_KS_IC
method: Dealiased_SETDRK4_forced
noise_magnitude: 0.001
nt: 600
nx: 256
stochastic_advection_basis: none
tmax: 150
xmax: 100.53096491487338
xmin: 0



Next we build the signal and ensemble models from their configurations and draw their initial conditions.

In [3]:
# One solver for the truth (signal) and one for the ensemble, each configured
# solely by its ConfigDict from the previous cell.
signal_model, ensemble_model = map(ETD_KT_CM_JAX_Vectorised, (signal_params, ensemble_params))

# Starting fields on each model's own grid: a batch of E initial states per model
# (E=1 for the truth, E=128 for the ensemble, per the shapes printed below).
initial_signal, initial_ensemble = (
    initial_condition(model.x, cfg.E, cfg.initial_condition)
    for model, cfg in ((signal_model, signal_params), (ensemble_model, ensemble_params))
)

# Names of the resampling schemes available in filters.resamplers.
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

multinomial, systematic, no_resampling, none, default


In [4]:
observation_spatial_frequency = 8 # 4
observation_locations = np.arange(0,signal_model.x.shape[0],observation_spatial_frequency)
observation_noise = .1
number_of_observations_time = int(ensemble_model.params.nt/8) # ensemble_model.params.nt/1
observation_temporal_frequency = int(ensemble_model.params.nt/number_of_observations_time)

print(observation_locations)
print(observation_temporal_frequency)
# Arguments shared by every filter, so the three filters cannot drift apart:
# - ensemble size
# - model steps between observations
# - state dimension
# - the forecast (ensemble) and truth (signal) models
# - observation noise
# - observed grid points
common_filter_kwargs = dict(
    n_particles=ensemble_params.E,
    n_steps=observation_temporal_frequency,
    n_dim=initial_signal.shape[-1],
    forward_model=ensemble_model,
    signal_model=signal_model,
    sigma=observation_noise,
    observation_locations=observation_locations,
)

# Plain particle filter with systematic resampling (alternatives tried: "none", "default").
pf_systematic = ParticleFilter(
    **common_filter_kwargs,
    resampling="systematic",
)

# Hybrid filter: a particle filter composed with an EnKF.
pf_tj = Hybrid_composed_ParticleFilter_of_EnKF(
    **common_filter_kwargs,
    # NOTE(review): the original note said values above 1 always resample and 0
    # never does, so 100 would always resample (25 was also tried) — confirm
    # against the filter implementation.
    ess_threshold=100,
    resampling="systematic",  # alternative tried: 'default'
)

# Alternative tempered-and-jittered particle filter, kept for reference.
# NOTE(review): ParticleFilter_tempered_jittered is not among the names imported
# from filters.filter at the top of this notebook — import it before re-enabling.
# pf_tj = ParticleFilter_tempered_jittered(
#     **common_filter_kwargs,
#     ess_threshold=0.5,  # 25 was also tried; above one always resamples, zero never resamples
#     resampling="systematic",  # or 'default'
#     jitter_magnitude=0.01,  # adjust the jitter magnitude as needed
#     tempering_steps=10,  # number of tempering steps
# )

# Ensemble Kalman filter.
kal = EnsembleKalmanFilter(**common_filter_kwargs)

[  0   8  16  24  32  40  48  56  64  72  80  88  96 104 112 120 128 136
 144 152 160 168 176 184 192 200 208 216 224 232 240 248]
8


In [5]:
# Smoke test: advance each filter through one assimilation cycle, starting from
# the same initial ensemble, truth and PRNG key.
key = jax.random.PRNGKey(0)

# Uniform starting weights; only the hybrid filter's run_step takes weights.
initial_weights = jnp.ones(ensemble_params.E) / ensemble_params.E
print(*(arr.shape for arr in (initial_weights, initial_signal, initial_ensemble, key)))

pf_tj_out = pf_tj.run_step(initial_ensemble, initial_weights, initial_signal, key)

out = pf_systematic.run_step(initial_ensemble, initial_signal, key)
out_kal = kal.run_step(initial_ensemble, initial_signal, key)

(128,) (1, 256) (128, 256) (2,)


In [6]:
# The plain particle filter's run_step returns (particles, signal, observations).
print(len(out))
ps, ss, obser = out
print(*(arr.shape for arr in (ps, ss, obser)))

# Observations are stored on the full grid. The non-zero count shown below (32)
# matches len(observation_locations).
jnp.count_nonzero(obser)

3
(128, 256) (1, 256) (1, 256)


Array(32, dtype=int64)

In [7]:
# Shapes of the hybrid filter's one-step output. Compared with the plain particle
# filter's (particles, signal, observations), it has an extra entry at index 1 of
# shape (n_particles,) — presumably the updated particle weights, since run_step
# was given initial_weights.
for i, arr in enumerate(pf_tj_out):
    print(f"pf_tj_out[{i}] shape:", arr.shape)

pf_all_out[0] shape: (128, 256)
pf_all_out[1] shape: (128,)
pf_all_out[2] shape: (1, 256)
pf_all_out[3] shape: (1, 256)


In [8]:
# Number of assimilation cycles to run: one per observation time.
da_steps = number_of_observations_time
# Re-create the same key as the one-step test, so every full run starts from the same seed.
key = jax.random.PRNGKey(0)
# The 4th argument (da_steps) is the number of assimilation cycles; the history arrays
# in `all` have da_steps entries along their first axis (75, per the shapes printed below).
# NOTE(review): `all` shadows the Python builtin; it is kept because the cells below index it.
final, all = pf_systematic.run(initial_ensemble, initial_signal, da_steps,key)

In [9]:
# Full EnKF run over da_steps assimilation cycles, with the same seed as the particle filter run.
final_kal, all_kal = kal.run(initial_ensemble, initial_signal, da_steps,key)

In [10]:
# Full run of the hybrid filter. Unlike the other filters, its run() also takes the
# initial particle weights (between the ensemble and the signal).
final_tj, all_tj = pf_tj.run(initial_ensemble, initial_weights,initial_signal, da_steps, key)

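Prepend the initial condition to each filter's particle and signal histories, so that index 0 is the state before any assimilation and index k corresponds to assimilation cycle k.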

In [11]:
# Prepend the initial condition to the particle and signal histories, so that index
# k corresponds to assimilation cycle k (index 0 = initial condition).
# Observations are NOT prepended: there is none for the initial state, so
# observations[k - 1] belongs with particles[k] and signal[k].
# Only the observed grid points are kept.

# Plain particle filter: history entries are (particles, signal, observations).
particles = jnp.concatenate([initial_ensemble[None, ...], all[0]], axis=0)
signal = jnp.concatenate([initial_signal[None, ...], all[1]], axis=0)
observations = all[2][:, :, observation_locations]
print(observations.shape)
print(particles.shape)

# Ensemble Kalman filter: same layout as the particle filter.
particles_kal = jnp.concatenate([initial_ensemble[None, ...], all_kal[0]], axis=0)
signal_kal = jnp.concatenate([initial_signal[None, ...], all_kal[1]], axis=0)
observations_kal = all_kal[2][:, :, observation_locations]
print(observations_kal.shape)
print(particles_kal.shape)

# Hybrid filter: entry 1 is extra (presumably the particle weights), so the signal is
# entry 2 and the observations are entry 3 — not entry 2, which is the signal.
particles_tj = jnp.concatenate([initial_ensemble[None, ...], all_tj[0]], axis=0)
signal_tj = jnp.concatenate([initial_signal[None, ...], all_tj[2]], axis=0)
observations_tj = all_tj[3][:, :, observation_locations]
print(observations_tj.shape)
print(particles_tj.shape)

(75, 1, 32)
(76, 128, 256)
(75, 1, 32)
(76, 128, 256)
(75, 1, 32)
(76, 128, 256)


In [15]:
def plot_all(da_step):
    """Plot every filter's ensemble against the truth at one assimilation step.

    On a single figure over the signal model's grid, draws:
      * the hybrid PF-of-EnKF ensemble (``particles_tj``) in orange,
      * the signal (truth) from the particle-filter run in black,
      * the systematic-resampling particle filter ensemble (``particles``) in blue,
      * the EnKF ensemble (``particles_kal``) in green,
      * for ``da_step > 0``, the particle-filter run's observations in red.

    Parameters
    ----------
    da_step : int
        Index into the prepended histories; 0 is the initial condition.
        Step k >= 1 is paired with ``observations[k - 1]``, because the
        observation history has no entry for the initial state.

    Notes
    -----
    Reads the notebook globals ``signal_model``, ``particles``, ``particles_tj``,
    ``particles_kal``, ``signal``, ``observations`` and ``observation_locations``.
    Only the particle filter's signal and observations are shown; whether the
    other runs saw identical ones is not checked here.
    """
    x = signal_model.x
    fig, ax = plt.subplots(figsize=(12, 6))

    # Each ensemble: member 0 is drawn once with a label (one legend entry), and
    # the remaining members are drawn unlabelled — so no member is drawn twice.

    # Hybrid composed particle filter / EnKF (pf_tj).
    ax.plot(x, particles_tj[da_step, 0, :], color='orange', label='Particles TJ', linewidth=0.5)
    ax.plot(x, particles_tj[da_step, 1:, :].T, color='orange', linewidth=0.5, alpha=0.5)

    # Truth. The raised zorder keeps it visible above the ensembles drawn after it.
    ax.plot(x, signal[da_step, 0, :], color='k', label='Signal', linewidth=3, zorder=4)

    # Particle filter with systematic resampling.
    ax.plot(x, particles[da_step, 0, :], color='b', linewidth=0.05, alpha=0.5, label='Particles')
    ax.plot(x, particles[da_step, 1:, :].T, color='b', linewidth=0.05, alpha=0.5)

    # Ensemble Kalman filter.
    ax.plot(x, particles_kal[da_step, 0, :], color='g', label='ENKF', linewidth=0.5)
    ax.plot(x, particles_kal[da_step, 1:, :].T, color='g', linewidth=0.5)

    # There are no observations at step 0 (initial condition only).
    if da_step > 0:
        ax.scatter(x[observation_locations], observations[da_step - 1, 0, :],
                   color='r', label='Observations', zorder=5)

    ax.set_xlabel('Spatial Domain', fontsize=14)
    ax.set_ylabel('Amplitude', fontsize=14)
    ax.set_title(f'Data Assimilation Step {da_step}', fontsize=16)
    ax.legend(fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.7)
    plt.show()

# The slider covers 0..da_steps inclusive, matching the da_steps + 1 entries of each
# prepended history. The trailing ';' suppresses the returned function's repr.
interact(plot_all, da_step=(0, da_steps));

interactive(children=(IntSlider(value=37, description='da_step', max=75), Output()), _dom_classes=('widget-int…

<function __main__.plot_all(da_step)>