In [1]:
import os
os.environ["JAX_ENABLE_X64"] = "true"
import sys
sys.path.append('..')
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter
jax.config.update("jax_platform_name", "cpu")
import metrics.ensemble as ens_metrics
from jax import config
config.update("jax_enable_x64", True)

In this experiment we increase the number of ensemble members (particles) from 16 to 2048, to assess how ensemble size affects forecast skill scores in an idealised testing environment. 

In [2]:
# Configuration of the signal (truth) run, built from the KDV_params_2 preset
# and printed so the baseline settings are visible before they are modified.
signal_params = ConfigDict(KDV_params_2)
print(signal_params)

Advection_basis_name: none
E: 1
Forcing_basis_name: none
P: 0
S: 0
c_0: 0
c_1: 1
c_2: 0.0
c_3: 2.0e-05
c_4: 0.0
dt: 0.001
equation_name: KdV
initial_condition: gaussian
method: Dealiased_SETDRK4
noise_magnitude: 0.0
nt: 4000
nx: 256
tmax: 4
xmax: 1
xmin: 0



In [3]:
# One separate ConfigDict per ensemble size, each built from the
# KDV_params_2_SALT preset; they are specialised in the next cell.
(ensemble_params_1, ensemble_params_2, ensemble_params_3, ensemble_params_4,
 ensemble_params_5, ensemble_params_6, ensemble_params_7, ensemble_params_8) = (
    ConfigDict(KDV_params_2_SALT) for _ in range(8)
)


Next, we specify the signal by choosing a deterministic solver, and set the grid resolution and ensemble size of each stochastic ensemble.

In [4]:
# Grid resolution used for the signal and every ensemble
# (the KDV_params_2 preset printed above uses nx = 256).
nx_new = 64

# Settings shared by the signal and all ensembles.
# NOTE(review): `nmax` does not appear in the printed preset — confirm the model reads it.
_shared_settings = dict(tmax=4, nmax=256 * 4, nx=nx_new)

# Signal: a single member (E=1) integrated with Dealiased_ETDRK4.
signal_params.update(E=1, method='Dealiased_ETDRK4', P=1, S=0, **_shared_settings)

# Stochastic ensembles: identical settings except for the ensemble size E.
for _cfg, _n_members in zip(
    (ensemble_params_1, ensemble_params_2, ensemble_params_3, ensemble_params_4,
     ensemble_params_5, ensemble_params_6, ensemble_params_7, ensemble_params_8),
    (16, 32, 64, 128, 256, 512, 1024, 2048),
):
    _cfg.update(E=_n_members, noise_magnitude=0.001, P=32, **_shared_settings)

All stochastic ensembles share the same noise magnitude (0.001), P=32 and grid; only the ensemble size E differs between them.

Now we instantiate the models by calling the `ETD_KT_CM_JAX_Vectorised` class with each configuration.

In [5]:
# One signal-model instance per particle filter, all built from the same signal_params.
(signal_model_1, signal_model_2, signal_model_3, signal_model_4,
 signal_model_5, signal_model_6, signal_model_7, signal_model_8) = [
    ETD_KT_CM_JAX_Vectorised(signal_params) for _ in range(8)
]

In [6]:
# Forward models for the stochastic ensembles, one per configuration, built in order 1..8.
(ensemble_model_1, ensemble_model_2, ensemble_model_3, ensemble_model_4,
 ensemble_model_5, ensemble_model_6, ensemble_model_7, ensemble_model_8) = [
    ETD_KT_CM_JAX_Vectorised(cfg)
    for cfg in (ensemble_params_1, ensemble_params_2, ensemble_params_3, ensemble_params_4,
                ensemble_params_5, ensemble_params_6, ensemble_params_7, ensemble_params_8)
]

In [7]:

def _initial_state(model, cfg):
    """Evaluate `cfg.initial_condition` for `cfg.E` members on the grid `model.x`."""
    return initial_condition(model.x, cfg.E, cfg.initial_condition)


# Initial signal first, then the initial ensemble for each configuration.
initial_signal = _initial_state(signal_model_1, signal_params)
initial_ensemble_1 = _initial_state(ensemble_model_1, ensemble_params_1)
initial_ensemble_2 = _initial_state(ensemble_model_2, ensemble_params_2)
initial_ensemble_3 = _initial_state(ensemble_model_3, ensemble_params_3)
initial_ensemble_4 = _initial_state(ensemble_model_4, ensemble_params_4)
initial_ensemble_5 = _initial_state(ensemble_model_5, ensemble_params_5)
initial_ensemble_6 = _initial_state(ensemble_model_6, ensemble_params_6)
initial_ensemble_7 = _initial_state(ensemble_model_7, ensemble_params_7)
initial_ensemble_8 = _initial_state(ensemble_model_8, ensemble_params_8)

# List the resampling schemes registered in `resamplers`.
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

multinomial, systematic, no_resampling, none, default


In [8]:
# Observation network: every `observation_spatial_frequency`-th grid point is observed.
# NOTE(review): with nx_new = 64 this gives only two observation locations — confirm intended.
observation_spatial_frequency = 32
observation_locations = jnp.arange(0, signal_model_1.x.shape[0], observation_spatial_frequency)
observation_noise = 0.1
number_of_observations_time = 32
# Model steps per observation interval: nt split into `number_of_observations_time` windows.
observation_temporal_frequency = int(ensemble_model_1.params.nt / number_of_observations_time)


def _build_filter(ensemble_cfg, forward_model, truth_model):
    """Create a multinomial-resampling ParticleFilter for one ensemble configuration.

    Every filter shares the observation locations, observation noise and
    observation interval defined above; only the ensemble size and models differ.
    """
    return ParticleFilter(
        n_particles=ensemble_cfg.E,
        n_steps=observation_temporal_frequency,
        n_dim=initial_signal.shape[-1],
        forward_model=forward_model,
        signal_model=truth_model,
        # NOTE(review): this is the observation-noise sigma; it seems to be
        # different from the sigma used for the stochastic forcing xi — confirm.
        sigma=observation_noise,
        resampling='multinomial',
        observation_locations=observation_locations,
    )


pf_1 = _build_filter(ensemble_params_1, ensemble_model_1, signal_model_1)
pf_2 = _build_filter(ensemble_params_2, ensemble_model_2, signal_model_2)
pf_3 = _build_filter(ensemble_params_3, ensemble_model_3, signal_model_3)
pf_4 = _build_filter(ensemble_params_4, ensemble_model_4, signal_model_4)
pf_5 = _build_filter(ensemble_params_5, ensemble_model_5, signal_model_5)
pf_6 = _build_filter(ensemble_params_6, ensemble_model_6, signal_model_6)
pf_7 = _build_filter(ensemble_params_7, ensemble_model_7, signal_model_7)
pf_8 = _build_filter(ensemble_params_8, ensemble_model_8, signal_model_8)

In [9]:
# Run each particle filter from its own initial ensemble and the initial signal.
# The same PRNG key is passed to every run (presumably so the runs are comparable
# under common random numbers — depends on how ParticleFilter.run uses the key).
key = jax.random.PRNGKey(0)
_runs = [
    filt.run(start_ensemble, initial_signal, truth_model.params.nt, key)
    for filt, start_ensemble, truth_model in zip(
        (pf_1, pf_2, pf_3, pf_4, pf_5, pf_6, pf_7, pf_8),
        (initial_ensemble_1, initial_ensemble_2, initial_ensemble_3, initial_ensemble_4,
         initial_ensemble_5, initial_ensemble_6, initial_ensemble_7, initial_ensemble_8),
        (signal_model_1, signal_model_2, signal_model_3, signal_model_4,
         signal_model_5, signal_model_6, signal_model_7, signal_model_8),
    )
]
((final_1, all_1), (final_2, all_2), (final_3, all_3), (final_4, all_4),
 (final_5, all_5), (final_6, all_6), (final_7, all_7), (final_8, all_8)) = _runs

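Prepend the initial condition to the particle and signal trajectories returned by the filters, so that time index 0 is the initial state.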

In [10]:
def _prepend_initial(start_ensemble, run_output):
    """Stack `start_ensemble` (as a new leading axis) in front of `run_output[0]`."""
    return jnp.concatenate([start_ensemble[jnp.newaxis], run_output[0]], axis=0)


particles_1 = _prepend_initial(initial_ensemble_1, all_1)
particles_2 = _prepend_initial(initial_ensemble_2, all_2)
particles_3 = _prepend_initial(initial_ensemble_3, all_3)
particles_4 = _prepend_initial(initial_ensemble_4, all_4)
particles_5 = _prepend_initial(initial_ensemble_5, all_5)
particles_6 = _prepend_initial(initial_ensemble_6, all_6)
particles_7 = _prepend_initial(initial_ensemble_7, all_7)
particles_8 = _prepend_initial(initial_ensemble_8, all_8)

In [11]:
# Signal trajectory taken from the first filter run, with the initial signal prepended
# so its time index lines up with the particle arrays.
signal = jnp.concatenate((initial_signal[jnp.newaxis], all_1[1]), axis=0)

In [12]:
def plot(da_step):
    """Plot the signal and one member of the smallest and largest ensembles at `da_step`.

    Only the member at index 0 (axis 1) of each ensemble is drawn, not the
    whole ensemble, so the legend labels say so explicitly.

    Args:
        da_step: time index into `signal`, `particles_1` and `particles_8`.
    """
    fig, ax = plt.subplots()
    ax.plot(signal_model_1.x, signal[da_step, 0, :], color='k', label='signal')
    ax.plot(signal_model_1.x, particles_1[da_step, 0, :], color='b',
            label=f'E={ensemble_params_1.E} (member 0)')
    ax.plot(signal_model_1.x, particles_8[da_step, 0, :], color='g',
            label=f'E={ensemble_params_8.E} (member 0)')
    ax.set_xlabel('x')
    ax.set_title(f'time index {da_step}')
    ax.legend()
    plt.show()


# Bound the slider by the stored trajectory length so it can never index past the
# end of `signal`; the trailing ';' suppresses the function repr output.
interact(plot, da_step=(0, signal.shape[0] - 1));

interactive(children=(IntSlider(value=2000, description='da_step', max=4000), Output()), _dom_classes=('widget…

<function __main__.plot(da_step)>

In [13]:
# Subsample signal and particles to reduce the cost of the skill-score computations:
# keep every `subsample_step`-th time index and every `subsample_space`-th grid point.
subsample_step = 8   # keep every 8th time index
subsample_space = 8  # keep every 8th spatial point


def _subsample(trajectory):
    """Return `trajectory[::subsample_step, :, ::subsample_space]` (axis 1 is not subsampled)."""
    return trajectory[::subsample_step, :, ::subsample_space]


signal_sub = _subsample(signal)
particles_1_sub = _subsample(particles_1)
particles_2_sub = _subsample(particles_2)
particles_3_sub = _subsample(particles_3)
particles_4_sub = _subsample(particles_4)
particles_5_sub = _subsample(particles_5)
particles_6_sub = _subsample(particles_6)
particles_7_sub = _subsample(particles_7)
particles_8_sub = _subsample(particles_8)
# Observed in a previous run: subsampling and reallocating these arrays took ~3 min.

In [None]:
# CRPS of each ensemble against the signal, skipping index 0 (the prepended initial states).
_truth_sub = signal_sub[1:]
(crps_1, crps_2, crps_3, crps_4, crps_5, crps_6, crps_7, crps_8) = [
    ens_metrics.crps_internal(_truth_sub, ensemble_sub[1:])
    for ensemble_sub in (particles_1_sub, particles_2_sub, particles_3_sub, particles_4_sub,
                         particles_5_sub, particles_6_sub, particles_7_sub, particles_8_sub)
]
# Timing notes: about 7 minutes at 32 spatial resolution using the xskillscore CRPS.
# TODO: time the internal JAX crps_internal at 256 spatial resolution.

: 

In [None]:
# RMSE of each ensemble against the signal, skipping index 0 (the prepended initial states).
_truth_sub = signal_sub[1:]
(rmse_1, rmse_2, rmse_3, rmse_4, rmse_5, rmse_6, rmse_7, rmse_8) = [
    ens_metrics.rmse(_truth_sub, ensemble_sub[1:])
    for ensemble_sub in (particles_1_sub, particles_2_sub, particles_3_sub, particles_4_sub,
                         particles_5_sub, particles_6_sub, particles_7_sub, particles_8_sub)
]

In [None]:
# RMSE curves for every ensemble size against the subsampled time index.
# (Renamed from `list`, which shadowed the builtin.)
rmse_curves = [rmse_1, rmse_2, rmse_3, rmse_4, rmse_5, rmse_6, rmse_7, rmse_8]
member_counts = [cfg.E for cfg in (ensemble_params_1, ensemble_params_2, ensemble_params_3,
                                   ensemble_params_4, ensemble_params_5, ensemble_params_6,
                                   ensemble_params_7, ensemble_params_8)]
plt.title('RMSE')
for i, (curve, n_members) in enumerate(zip(rmse_curves, member_counts)):
    plt.plot(curve, label=f'E={n_members}',
             color=plt.cm.viridis(i / (len(rmse_curves) - 1)))
plt.plot(observation_noise * jnp.ones_like(rmse_curves[0]), c='k',
         label='observation noise magnitude')
# The x axis is the index of the subsampled, initial-state-excluded series, not physical time.
plt.xlabel(f'subsampled time index (every {subsample_step} steps, index 0 dropped)')
plt.legend()
plt.show()

In [None]:
# Physical time of each point in the RMSE curves. Assumes `signal` holds one entry per
# model step, so signal_sub[k] is at step k * subsample_step; the metrics skip k = 0,
# hence the range starts at 1. This matches the length of the subsampled curves
# (an nt + 1 point axis would not).
time = jnp.arange(1, signal_sub.shape[0]) * subsample_step * signal_model_1.dt
rmse_curves = [rmse_1, rmse_2, rmse_3, rmse_4, rmse_5, rmse_6, rmse_7, rmse_8]
member_counts = [cfg.E for cfg in (ensemble_params_1, ensemble_params_2, ensemble_params_3,
                                   ensemble_params_4, ensemble_params_5, ensemble_params_6,
                                   ensemble_params_7, ensemble_params_8)]
assert all(curve.shape[0] == time.shape[0] for curve in rmse_curves), \
    'time axis does not match the length of the RMSE curves'
plt.title('RMSE')
for i, (curve, n_members) in enumerate(zip(rmse_curves, member_counts)):
    plt.plot(time, curve, label=f'E={n_members}',
             color=plt.cm.viridis(i / (len(rmse_curves) - 1)))
plt.plot(time, observation_noise * jnp.ones_like(rmse_curves[0]), c='k',
         label='observation noise magnitude')
plt.xlabel('time')
plt.legend()
plt.show()

In [None]:
# CRPS curves for every ensemble size against the subsampled time index.
# (Renamed from `list`, which shadowed the builtin.)
crps_curves = [crps_1, crps_2, crps_3, crps_4, crps_5, crps_6, crps_7, crps_8]
member_counts = [cfg.E for cfg in (ensemble_params_1, ensemble_params_2, ensemble_params_3,
                                   ensemble_params_4, ensemble_params_5, ensemble_params_6,
                                   ensemble_params_7, ensemble_params_8)]
plt.title('CRPS')
for i, (curve, n_members) in enumerate(zip(crps_curves, member_counts)):
    plt.plot(curve, label=f'E={n_members}',
             color=plt.cm.viridis(i / (len(crps_curves) - 1)))
plt.plot(observation_noise * jnp.ones_like(crps_curves[0]), c='k',
         label='observation noise magnitude')
# The x axis is the index of the subsampled, initial-state-excluded series, not physical time.
plt.xlabel(f'subsampled time index (every {subsample_step} steps, index 0 dropped)')
plt.legend()
plt.show()