In [None]:
import os
os.environ["JAX_ENABLE_X64"] = "true"
import sys
sys.path.append('..')
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter
from filters.filter import ParticleFilterAll
from jax import config
jax.config.update("jax_enable_x64", True)
import numpy as np

True
not currently supported.


Initialisation of a twin experiment. 

In [2]:
signal_params = ConfigDict(KS_params_SALT)
ensemble_params = ConfigDict(KS_params_SALT)
ensemble_params.update(E=1,P=32,stochastic_advection_basis='constant')
ensemble_params.update(E=128,P=32,stochastic_advection_basis='constant')

Now we specify the models, by calling the class.

In [3]:
signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
ensemble_model = ETD_KT_CM_JAX_Vectorised(ensemble_params)

initial_signal = initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition)
initial_ensemble = initial_condition(ensemble_model.x, ensemble_params.E, ensemble_params.initial_condition)

print(ensemble_model.params,ensemble_model.params.nt)
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

Advection_basis_name: sin
E: 128
Forcing_basis_name: none
P: 32
S: 0
c_0: 0
c_1: 1
c_2: 1
c_3: 0.0
c_4: 1
dt: 0.25
equation_name: Kuramoto-Sivashinsky
initial_condition: Kassam_Trefethen_KS_IC
method: Dealiased_SETDRK4
noise_magnitude: 0.001
nt: 600
nx: 256
stochastic_advection_basis: constant
tmax: 150
xmax: 100.53096491487338
xmin: 0
 600
multinomial, systematic, no_resampling, none, default


In [4]:
observation_spatial_frequency = 32
observation_locations = np.arange(0,signal_model.x.shape[0],observation_spatial_frequency)
observation_noise = 0.5
number_of_observations_time = 32 #ensemble_model.params.nt
observation_temporal_frequency = int(ensemble_model.params.nt/number_of_observations_time)
print(f"Observation temporal frequency: {observation_temporal_frequency}")
print(f"Observation spatial locations:{observation_locations}")
pf_systematic = ParticleFilterAll(
    n_particles = ensemble_params.E,
    n_steps = observation_temporal_frequency,
    n_dim = initial_signal.shape[-1],
    forward_model = ensemble_model,
    signal_model = signal_model,
    sigma = observation_noise,
    resampling="systematic",#'default',
    observation_locations = observation_locations,
)

Observation temporal frequency: 1
Observation spatial locations:[  0  32  64  96 128 160 192 224]


To run the filter with all outputs, the input arrays need to have 3 dimensions.
They are (time, n_particles, space).

In [5]:
da_steps = number_of_observations_time
print(initial_ensemble.shape, initial_signal.shape, da_steps)
initial_ensemble = initial_ensemble[None, ...] # adding the time dimension
initial_signal = initial_signal[None, ...] # adding the time dimension
key = jax.random.PRNGKey(0) # random key for JAX
final, all = pf_systematic.run(initial_ensemble, initial_signal, da_steps,key) #the final input is scan length? 

(128, 256) (1, 256) 600


In [6]:
print(len(final), len(all))

3 3


The two outputs are tuples. The first tuple has length 2 and contains the final ensemble and the final signal arrays.
The second tuple has length three and contains all outputs through time of the ensemble, signal and observation.

In [7]:
for i in range(3):
    print(all[i].shape)
arr = all[0]
all_0_reshaped = jnp.reshape(arr, (-1, arr.shape[2], arr.shape[3]))
print(all_0_reshaped.shape)
def plot_member(e):
    plt.imshow(all_0_reshaped[:, e, :], aspect='auto', origin='lower')
    plt.title(f'Ensemble member {e}')
    plt.xlabel('Space')
    plt.ylabel('Time')
    plt.show()

interact(plot_member, e=(0, all_0_reshaped.shape[1] - 1))
# plt.imshow(all_0_reshaped[:,2,:], aspect='auto', origin='lower')
# plt.show()


(600, 1, 128, 256)
(600, 1, 1, 256)
(600, 1, 256)
(600, 128, 256)


interactive(children=(IntSlider(value=63, description='e', max=127), Output()), _dom_classes=('widget-interact…

<function __main__.plot_member(e)>

In [8]:
for i in range(2):
    print(final[i].shape)

(1, 128, 256)
(1, 1, 256)


Put in the initial condition.

In [9]:
print(initial_ensemble.shape)
print(all[0].shape)
all_0_reshaped = jnp.reshape(all[0], (-1, all[0].shape[2], all[0].shape[3]))
print(all_0_reshaped.shape)

particles =jnp.concatenate([initial_ensemble[...], all_0_reshaped], axis=0)
print(particles.shape)

print(all[1].shape)
all_1_reshaped = jnp.reshape(all[1], (-1, all[1].shape[2], all[1].shape[3]))
print(all_1_reshaped.shape)
print(initial_signal.shape)
signal = jnp.concatenate([initial_signal[...], all_1_reshaped], axis=0)
print(signal.shape)
print(all[2].shape)
# observations = jnp.concatenate([initial_signal[...], all[2]], axis=0)
observations = all[2][:,:, observation_locations]
print(observations.shape)


(1, 128, 256)
(600, 1, 128, 256)
(600, 128, 256)
(601, 128, 256)
(600, 1, 1, 256)
(600, 1, 256)
(1, 1, 256)
(601, 1, 256)
(600, 1, 256)
(600, 1, 8)


In [10]:
def plot(time):
    plt.plot(signal_model.x, signal[time,0,:], color='k',label='signal')
    plt.plot(signal_model.x, particles[time,:,:].T, color='b',label='particles',linewidth=0.1)
    if (time ) % observation_temporal_frequency == 0 and time !=0:
        plt.plot(signal_model.x[observation_locations], observations[int(time/observation_temporal_frequency),0,:], 'ro',label='observations')
    plt.show()

interact(plot, time=(0, ensemble_model.params.nt - 1, 1))

interactive(children=(IntSlider(value=299, description='time', max=599), Output()), _dom_classes=('widget-inte…

<function __main__.plot(time)>