# Example 3 
In this experiment we run the KdV ensemble by transport noise, using the standard particle filter. We observe degeneracy in the filter, in the twin experiment. 

In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter
import numpy as np
jax.config.update("jax_enable_x64", True)

Initialisation

In [2]:
signal_params = ConfigDict(KDV_params_2_SALT)
ensemble_params = ConfigDict(KDV_params_2_SALT)
print(ensemble_params)


Advection_basis_name: constant
E: 10
Forcing_basis_name: none
P: 1
S: 0
c_0: 0
c_1: 1
c_2: 0.0
c_3: 2.0e-05
c_4: 0.0
dt: 0.001
equation_name: KdV
initial_condition: gaussian
method: Dealiased_SETDRK4
noise_magnitude: 0.01
nx: 256
tmax: 1
xmax: 1
xmin: 0



We now specify the number of ensemble members and the number of basis functions required for the salt noise ensemble. 

In [3]:
signal_params.update(tmax=4)
ensemble_params.update(tmax=4,E=128,noise_magnitude=0.001,P=32,stochastic_advection_basis='constant')

Now we specify the models, by calling the class.

In [4]:
signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
ensemble_model = ETD_KT_CM_JAX_Vectorised(ensemble_params)

initial_signal = initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition)
initial_ensemble = initial_condition(ensemble_model.x, ensemble_params.E, ensemble_params.initial_condition)

available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

I0000 00:00:1740671752.628184       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


multinomial, systematic, no_resampling, default


In [None]:
observation_spatial_frequency = 16
observation_locations = np.arange(0,signal_model.x.shape[0],observation_spatial_frequency)
observation_noise = 0.1
number_of_observations_time = 32
observation_temporal_frequency = int(ensemble_model.nt/number_of_observations_time)
print(observation_locations)
print(observation_temporal_frequency)

pf_systematic = ParticleFilter(
    n_particles = ensemble_params.E,
    n_steps = observation_temporal_frequency,
    n_dim = initial_signal.shape[-1],
    forward_model = ensemble_model,
    signal_model = signal_model,
    sigma = observation_noise,
    seed = 11,
    resampling = "systematic",#'default',
    observation_locations = observation_locations,
)

[  0  16  32  48  64  80  96 112 128 144 160 176 192 208 224 240]
125


In [6]:
da_steps = number_of_observations_time
final, all = pf_systematic.run(initial_ensemble, initial_signal, da_steps) 

TypeError: sub got incompatible shapes for broadcasting: (256,), (2560,).

Put in the initial condition.

In [None]:
particles =jnp.concatenate([initial_ensemble[None,...], all[0]], axis=0)
signal = jnp.concatenate([initial_signal[None,...], all[1]], axis=0)
print(particles.shape)

In [None]:
def plot(da_step):
    plt.plot(signal_model.x, signal[da_step,0,:], color='k',label='signal')
    plt.plot(signal_model.x, particles[da_step,:,:].T, color='b',label='particles',linewidth=0.1)
    plt.show()

interact(plot, da_step=(0, signal_model.nmax))