In [7]:
import importlib
import jax_models
from jax.scipy.linalg import inv, det, svd
import jax.numpy as np
from jax import random, jit
from sklearn.datasets import make_spd_matrix
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import jax
from jax_models import  generate_true_states, generate_gc_localization_matrix #visualize_observations, KuramotoSivashinsky,
from jax_filters import ensrf_steps
# Reload jax_models so edits to the module are picked up without restarting
# the kernel. `importlib.reload` only refreshes the module object: names bound
# earlier with `from jax_models import ...` keep pointing at the old function
# objects, so re-bind them from the reloaded module explicitly.
jax_models = importlib.reload(jax_models)
generate_true_states = jax_models.generate_true_states
generate_gc_localization_matrix = jax_models.generate_gc_localization_matrix

<module 'jax_models' from '/central/home/eluk/variational_filtering/jax_models.py'>

In [13]:
from jax.numpy.fft import fft, ifft

@jit
def KS_step(u, k, E, E2, Q, f1, f2, f3, g, h):
    """Advance the Kuramoto-Sivashinsky field ``u`` by one ETDRK4 step.

    Implements the exponential time-differencing fourth-order Runge-Kutta
    update of Kassam & Trefethen (2005) in Fourier space, using coefficient
    arrays such as those returned by ``precompute_constants``.

    Args:
        u: real field on the periodic grid (presumably 1-D of length ``s``).
        k: wavenumbers; accepted for call-site compatibility, not used here.
        E, E2: exp(h*L) and exp(h*L/2) for the linear operator L.
        Q, f1, f2, f3: ETDRK4 coefficient arrays.
        g: multiplier applied to the FFT of the squared field to form the
            nonlinear term (``precompute_constants`` uses -0.5j*k).
        h: time step; accepted for call-site compatibility, not used here
            (``precompute_constants`` folds it into E, E2, Q, f1, f2, f3).

    Returns:
        The real field after one time step, same length as ``u``.
    """
    def nonlinear(w_hat):
        # Pseudo-spectral nonlinear term: square in physical space, transform
        # back to Fourier space, then scale by g.
        return g * fft(np.real(ifft(w_hat)) ** 2)

    u_hat = fft(u)
    n_u = nonlinear(u_hat)

    stage_a = E2 * u_hat + Q * n_u
    n_a = nonlinear(stage_a)

    stage_b = E2 * u_hat + Q * n_a
    n_b = nonlinear(stage_b)

    stage_c = E2 * stage_a + Q * (2 * n_b - n_u)
    n_c = nonlinear(stage_c)

    next_hat = E * u_hat + n_u * f1 + 2 * (n_a + n_b) * f2 + n_c * f3
    return np.real(ifft(next_hat))

# Precompute constants needed for the time step function
def precompute_constants(s, l, h, M):
    """Precompute the ETDRK4 coefficients for the Kuramoto-Sivashinsky equation.

    Follows Kassam & Trefethen (2005): the linear operator in Fourier space is
    L = k**2 - k**4, and the phi-function coefficients Q, f1, f2, f3 are
    evaluated as averages over M points on a unit circle around each h*L,
    which avoids cancellation error when |h*L| is small.

    Args:
        s: number of grid points / Fourier modes (even or odd).
        l: domain length; wavenumbers are integer multiples of 2*pi/l.
        h: time step.
        M: number of contour points used for the averages.

    Returns:
        Tuple (k, E, E2, Q, f1, f2, f3, g), each of length s: wavenumbers,
        exp(h*L), exp(h*L/2), the four ETDRK4 coefficient arrays, and the
        nonlinear-term multiplier g = -0.5j*k.
    """
    # Integer Fourier mode numbers in FFT order: 0, 1, ..., then negatives.
    # A fixed "[0..s//2-1, 0, -s//2+1..-1]" layout is only right for even s;
    # building from arange handles odd s as well.
    mode = np.arange(s)
    mode = np.where(mode < (s + 1) // 2, mode, mode - s)
    if s % 2 == 0:
        # Even s has an unpaired Nyquist mode; zero it as in Kassam & Trefethen.
        mode = mode.at[s // 2].set(0)
    k = (2 * np.pi / l) * mode

    L = k**2 - k**4  # Fourier symbol of -u_xx - u_xxxx
    E = np.exp(h * L)
    E2 = np.exp(h * L / 2)

    # M points on the upper half of the unit circle; taking the real part of
    # the mean over them equals the full-circle contour average.
    r = np.exp(1j * np.pi * (np.arange(1, M + 1) - .5) / M)
    LR = h * L[:, None] + r[None, :]  # shape (s, M)
    Q = h * np.real(np.mean((np.exp(LR / 2) - 1) / LR, axis=1))
    f1 = h * np.real(np.mean((-4 - LR + np.exp(LR) * (4 - 3 * LR + LR**2)) / LR**3, axis=1))
    f2 = h * np.real(np.mean((2 + LR + np.exp(LR) * (-2 + LR)) / LR**3, axis=1))
    f3 = h * np.real(np.mean((-4 - 3 * LR - LR**2 + np.exp(LR) * (4 - LR)) / LR**3, axis=1))
    g = -0.5j * k

    return k, E, E2, Q, f1, f2, f3, g
# Kuramoto-Sivashinsky discretisation: s grid points / Fourier modes on a
# periodic domain of length l, ETDRK4 time step h, and M contour points for
# the phi-function averages.
s, l, h, M = 256, 22, 0.25, 16

# Precompute the spectral constants once. The ETDRK4 coefficient array is kept
# as Q_ks so it stays distinct from the process-noise covariance named Q.
(k, E, E2, Q_ks, f1, f2, f3, g) = precompute_constants(s, l, h, M)


In [None]:
import pickle
import os
from datetime import datetime
from tqdm import tqdm

# Experiment grid: every (localization radius, inflation, ensemble size)
# combination is evaluated in every trial.
radii = [2, 10, 100]
inflations = [1.0, 1.05, 1.1]
num_trials = 10
ensemble_sizes = [10, 20]
n = s  # state dimension = number of KS grid points
num_steps = 10  # Total number of time steps
observation_interval = 5  # Observation is available every 5 time steps
seed = 0  # root PRNG seed for the whole experiment

# JAX PRNG keys are deterministic: reusing one key reproduces the same draws.
# Split a dedicated subkey for every independent random quantity.
key = random.PRNGKey(seed)
key, init_key = random.split(key)
ks_model_step = lambda state: KS_step(state, k, E, E2, Q_ks, f1, f2, f3, g, h)

# Process noise covariance and Observation noise covariance
Q = 0.01 * np.eye(n)
R = 0.5 * np.eye(n)
# Observation matrix (identity matrix for direct observation of all state variables)
H = np.eye(n)

# Initial state (shared by all trials; each trial draws fresh noise below)
initial_state = random.normal(init_key, (n,))

configs = [(radius, inflation, n_ensemble)
           for radius in radii
           for inflation in inflations
           for n_ensemble in ensemble_sizes]

# Initialize data structures for results
std_errors = {cfg: [] for cfg in configs}
errors = {cfg: [] for cfg in configs}

# The localization matrix depends only on (n, radius): build each one once
# instead of once per trial.
local_mats = {radius: generate_gc_localization_matrix(n, radius) for radius in radii}

for trial in tqdm(range(num_trials), desc="Running Trials"):
    tqdm.write(f"Trial {trial}")  # prints without breaking the progress bar
    # Fresh subkeys per trial so trials are independent replicates.
    key, truth_key, ens_key = random.split(key, 3)
    observations, true_states = generate_true_states(truth_key, num_steps, n, initial_state, H, Q, R, ks_model_step, observation_interval)

    for radius, inflation, n_ensemble in configs:
        # ens_key is shared across configurations within a trial (common random
        # numbers): configurations with the same ensemble size start from the
        # same initial ensemble, so differences reflect radius/inflation only.
        ensemble_init = random.multivariate_normal(ens_key, initial_state, Q, (n_ensemble,)).T
        states = ensrf_steps(ks_model_step, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, local_mats[radius], inflation)
        # NOTE(review): the axis choices below assume states is
        # (time, state, ensemble) — TODO confirm against ensrf_steps.
        average_state = np.mean(states, axis=2)  # Calculate the mean along the ensemble dimension
        error = np.sqrt(np.mean((average_state - true_states) ** 2, axis=1))
        errors[(radius, inflation, n_ensemble)].append(error)
        # Standard deviation across all ensemble members and state dimensions.
        # NOTE(review): pooling axes (1, 2) mixes the spatial variability of the
        # field with ensemble spread; if ensemble spread is intended, consider
        # np.sqrt(np.mean(np.var(states, axis=2), axis=1)) — TODO confirm intent.
        std_dev = np.std(states, axis=(1, 2))
        std_errors[(radius, inflation, n_ensemble)].append(std_dev)

# Preparing data for saving
all_data = {
    'std_errors': std_errors,
    'errors': errors,
    'parameters': {
        'radii': radii,
        'inflations': inflations,
        'ensemble_sizes': ensemble_sizes,
        'num_trials': num_trials,
        'num_steps': num_steps,
        'observation_interval': observation_interval,
        'seed': seed,
    }
}

# File saving path; create the directory so a long run is not lost to a
# missing folder at the final write.
directory = 'variational_filtering/experiment_data/'
os.makedirs(directory, exist_ok=True)
filename = f'ks_experiment_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
file_path = os.path.join(directory, filename)

with open(file_path, 'wb') as f:
    pickle.dump(all_data, f)

file_path

Running Trials:   0%|                                                                            | 0/10 [00:00<?, ?it/s]

Trial 0


Running Trials:  10%|██████▊                                                             | 1/10 [00:46<06:54, 46.09s/it]

Trial 1


Running Trials:  20%|█████████████▌                                                      | 2/10 [01:32<06:10, 46.34s/it]

Trial 2


Running Trials:  30%|████████████████████▍                                               | 3/10 [02:18<05:21, 45.92s/it]

Trial 3


Running Trials:  40%|███████████████████████████▏                                        | 4/10 [03:03<04:33, 45.61s/it]

Trial 4


Running Trials:  50%|██████████████████████████████████                                  | 5/10 [03:48<03:48, 45.67s/it]

Trial 5
