In [8]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
# Star import presumably supplies ETD_KT_CM_JAX_Vectorised, initial_condition,
# KDV_params_2 and KDV_params_2_SALT, which are used below but not imported elsewhere.
# TODO(review): replace with explicit imports.
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter
# Run all JAX computations on the CPU backend.
jax.config.update("jax_platform_name", "cpu")
import metrics.ensemble as ens_metrics


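Initialisation

Load the KdV parameter sets: `KDV_params_2` for the signal and `KDV_params_2_SALT` for the stochastic ensemble.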

In [9]:
# Signal configuration: wrap the KdV parameter set in a ConfigDict so it can be updated later in the notebook.
signal_params = ConfigDict(KDV_params_2)
print(signal_params)

Advection_basis_name: none
E: 1
Forcing_basis_name: none
P: 1
S: 0
c_0: 0
c_1: 1
c_2: 0.0
c_3: 2.0e-05
c_4: 0.0
dt: 0.001
equation_name: KdV
initial_condition: gaussian
method: Dealiased_SETDRK4
noise_magnitude: 0.0
nx: 256
tmax: 1
xmax: 1
xmin: 0



In [10]:
# Ensemble configuration: the stochastic variant of the KdV parameter set (non-zero noise_magnitude).
ensemble_params = ConfigDict(KDV_params_2_SALT)
print(ensemble_params)

Advection_basis_name: constant
E: 10
Forcing_basis_name: none
P: 1
S: 0
c_0: 0
c_1: 1
c_2: 0.0
c_3: 2.0e-05
c_4: 0.0
dt: 0.001
equation_name: KdV
initial_condition: gaussian
method: Dealiased_SETDRK4
noise_magnitude: 0.01
nx: 256
tmax: 1
xmax: 1
xmin: 0



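Next, we configure the signal to use a deterministic solver (`Dealiased_ETDRK4` with `P=0`, `S=0`), and set the ensemble to 128 particles with noise magnitude 0.01. Both use `tmax=1` and `nmax=1024`.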

In [11]:
# Signal: a single deterministic trajectory.
signal_params.update(
    E=1,
    method='Dealiased_ETDRK4',
    P=0,
    S=0,
    tmax=1,
    nmax=256 * 4,
)
# Ensemble: 128 stochastic particles.
ensemble_params.update(
    E=128,
    noise_magnitude=0.01,
    P=32,
    tmax=1,
    nmax=256 * 4,
)

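The ensemble keeps its stochastic configuration (`KDV_params_2_SALT`, noise magnitude 0.01), so it forms a stochastic ensemble.

Now we instantiate the signal and ensemble models from their parameter sets and generate their initial conditions.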

In [12]:
# Build the signal model and the ensemble model (in that order) from their parameter sets.
signal_model, ensemble_model = (
    ETD_KT_CM_JAX_Vectorised(params) for params in (signal_params, ensemble_params)
)
# Draw each model's initial state on its own grid, sized by its ensemble count E.
initial_signal, initial_ensemble = (
    initial_condition(model.x, params.E, params.initial_condition)
    for model, params in ((signal_model, signal_params), (ensemble_model, ensemble_params))
)

# Names of the registered resampling schemes (presumably the valid `resampling=` values).
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

multinomial, systematic, no_resampling, default


In [13]:
def make_particle_filter(resampling, obs_stride=4, sigma=0.1, seed=11):
    """Build a ParticleFilter with the given resampling scheme and its own model instances.

    Each filter gets a freshly constructed forward (ensemble) model and signal model
    instead of sharing module-level instances. A JAX UnexpectedTracerError was observed
    when running several filters, reporting a value (uint32[2], presumably a PRNG key)
    created inside ETD_KT_CM_JAX_Vectorised.run escaping a scan trace. That is consistent
    with model instances carrying state between runs.
    TODO(review): confirm at models/ETD_KT_CM_JAX_Vectorised.py:110.

    Args:
        resampling: name of a scheme in `resamplers`, e.g. 'multinomial', 'systematic',
            'no_resampling'.
        obs_stride: observe every `obs_stride`-th grid point of the signal grid.
        sigma: passed to ParticleFilter as `sigma`. NOTE(review): this appears to differ
            from the noise magnitude of the stochastic ensemble (xi); confirm its meaning.
        seed: random seed passed to ParticleFilter.

    Returns:
        A configured ParticleFilter performing one model step per DA step.
    """
    return ParticleFilter(
        n_particles=ensemble_params.E,
        n_steps=1,
        n_dim=initial_signal.shape[-1],
        forward_model=ETD_KT_CM_JAX_Vectorised(ensemble_params),
        signal_model=ETD_KT_CM_JAX_Vectorised(signal_params),
        sigma=sigma,
        seed=seed,
        resampling=resampling,
        observation_locations=jnp.arange(0, signal_params.nx, obs_stride),
    )


pf_multinomial = make_particle_filter('multinomial')
pf_systematic = make_particle_filter('systematic')
pf_no_resampling = make_particle_filter('no_resampling')

In [14]:
# Run each filter from the same initial ensemble and signal for signal_model.nmax DA steps.
# Each call returns (final state, per-step outputs); the outputs are unpacked in the next cell.
n_da_steps = signal_model.nmax
final_systematic, all_systematic = pf_systematic.run(initial_ensemble, initial_signal, n_da_steps)
final_multinomial, all_multinomial = pf_multinomial.run(initial_ensemble, initial_signal, n_da_steps)
final_no_resampling, all_no_resampling = pf_no_resampling.run(initial_ensemble, initial_signal, n_da_steps)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was scan_fn at /Users/jmw/Documents/GitHub/Particle_Filter/filters/filter.py:51 traced for scan.
------------------------------
The leaked intermediate value was created on line /Users/jmw/Documents/GitHub/Particle_Filter/models/ETD_KT_CM_JAX_Vectorised.py:110 (run). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/Users/jmw/Documents/GitHub/Particle_Filter/filters/filter.py:56 (run)
/Users/jmw/Documents/GitHub/Particle_Filter/filters/filter.py:53 (scan_fn)
/Users/jmw/Documents/GitHub/Particle_Filter/filters/filter.py:42 (run_step)
/Users/jmw/Documents/GitHub/Particle_Filter/filters/filter.py:20 (advance_signal)
/Users/jmw/Documents/GitHub/Particle_Filter/models/ETD_KT_CM_JAX_Vectorised.py:110 (run)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

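Prepend the initial condition to each filter's output, so that index 0 of every particle and signal trajectory is the initial state.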

In [None]:
def prepend_initial_state(initial, history):
    """Return `history` with `initial` prepended along a new leading (DA-step) axis.

    `history` must have shape (T, *initial.shape); the result has shape (T + 1, *initial.shape).
    """
    return jnp.concatenate([initial[None, ...], history], axis=0)


particles_systematic = prepend_initial_state(initial_ensemble, all_systematic[0])
particles_multinomial = prepend_initial_state(initial_ensemble, all_multinomial[0])
particles_no_resampling = prepend_initial_state(initial_ensemble, all_no_resampling[0])

# All filters use the same signal parameters and seed, so the signal is taken from one run.
signal = prepend_initial_state(initial_signal, all_systematic[1])
observations = all_systematic[2]

print(f"Particles Shape: {particles_systematic.shape} is (N_da_steps+1, N_particles, N_dim)")
print(f"Signal Shape: {signal.shape} is (N_da_steps+1, 1,  N_dim)")
print(f"Observations Shape: {observations.shape} is (N_da_steps, 1,  N_dim)")
# TODO: the observations shape needs fixing; it should be (N_da_steps, N_obs_dim, N_dim).

In [None]:
def plot(da_step):
    """Plot the signal and two particle ensembles at DA step `da_step`.

    Black: signal. Blue: particles from the systematic-resampling filter.
    Red: particles from the filter without resampling.
    """
    plt.plot(signal_model.x, signal[da_step, 0, :], color='k', label='signal')
    for ensemble, colour, name in (
        (particles_systematic, 'b', 'particles (systematic)'),
        (particles_no_resampling, 'r', 'particles (no resampling)'),
    ):
        lines = plt.plot(signal_model.x, ensemble[da_step, :, :].T, color=colour, linewidth=0.01)
        # Label only one line per ensemble so the legend has a single entry for it.
        lines[0].set_label(name)
    # TODO: overlay observations (taken at every 4th grid point) once their shape is fixed.
    plt.title(f'DA step {da_step}')
    plt.xlabel('x')
    plt.legend()
    plt.show()


# Derive the slider range from the data: JAX clamps out-of-range indices silently.
interact(plot, da_step=(0, signal.shape[0] - 1))

In [None]:
# Score one filter's ensemble against the signal, skipping index 0 (the prepended initial condition).
# The systematic-resampling ensemble is used because the signal is also taken from that run;
# assign particles_multinomial or particles_no_resampling here to score another filter.
particles = particles_systematic
bias = ens_metrics.bias(signal[1:, ...], particles[1:, ...])
rmse = ens_metrics.rmse(signal[1:, ...], particles[1:, ...])
crps = ens_metrics.crps(signal[1:, ...], particles[1:, ...])
print(type(bias), type(rmse), type(crps))

In [None]:
# Shapes of the three metric arrays, space-separated.
print(*(metric.shape for metric in (bias, rmse, crps)))

In [27]:
# NOTE(review): commented-out experiment. Delete this cell, or enable it once
# ens_metrics.rmse_2 is confirmed to exist.
# rmse_new = ens_metrics.rmse_2(signal[1:,...], particles[1:,...])

In [None]:
# Bias returned by ens_metrics.bias; titled like the RMSE and CRPS plots.
plt.title('Bias')
plt.plot(bias)
plt.show()

In [None]:
# RMSE returned by ens_metrics.rmse.
plt.plot(rmse)
plt.title('RMSE')
plt.show()

In [None]:
# CRPS returned by ens_metrics.crps.
plt.plot(crps)
plt.title('CRPS')
plt.show()