In [1]:
import importlib
import jax_models
import jax_filters
from jax.scipy.linalg import inv, det, svd
import jax.numpy as np  # NOTE: `np` is jax.numpy throughout this notebook, not NumPy
from jax import random, jit
from sklearn.datasets import make_spd_matrix
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import jax

# Reload the project modules BEFORE binding names out of them. `importlib.reload`
# only refreshes the module object; names already pulled in with `from ... import`
# keep pointing at the old definitions, so reloading after the `from` imports (as
# this cell used to) silently leaves the notebook running stale code.
# jax_models is reloaded first so that jax_filters, if it imports from jax_models,
# picks up the fresh definitions when it is reloaded.
importlib.reload(jax_models)
importlib.reload(jax_filters)
from jax_models import visualize_observations, KuramotoSivashinsky, generate_true_states, generate_gc_localization_matrix
from jax_filters import ensrf_steps

<module 'jax_models' from '/central/home/eluk/variational_filtering/jax_models.py'>

In [2]:
# Initialize parameters
num_steps = 1000  # Number of simulation steps
n = 256  # Dimensionality of the state space for KS model
observation_interval = 5  # Interval at which observations are made
dt = 0.25  # Time step for the KS model

ks_model = KuramotoSivashinsky(dt=dt, s=n, l=22, M=16)

# Random keys. A JAX PRNG key must not be consumed twice: the initial state and the
# truth/observation noise previously both used `key`, correlating their draws.
# Split the root key into independent subkeys instead.
key = random.PRNGKey(0)  # Root key for reproducibility
key, init_key, obs_key = random.split(key, 3)

# Initial state
x0 = random.normal(init_key, (n,))
initial_state = x0
# Noise covariances
Q = 0.01 * np.eye(n)  # Process noise covariance
R = 0.5 * np.eye(n)  # Observation noise covariance
# Observation matrix (identity matrix for direct observation of all state variables)
H = np.eye(n)
# Generate the true trajectory and its noisy observations
observations, true_states = generate_true_states(obs_key, num_steps, n, x0, H, Q, R, ks_model.step, observation_interval)

# Visualize the true states
visualize_observations(true_states)

TypeError: Cannot interpret value of type <class 'method'> as an abstract array; it does not have a dtype attribute

In [4]:
import pickle
import os
from datetime import datetime
from tqdm import tqdm

# Sweep grid: localization radius x multiplicative inflation x ensemble size.
radii = [2, 5, 10, 50, 100]
inflations = [1.0, 1.05, 1.1, 1.2, 1.3, 1.5]
num_trials = 10
ensemble_sizes = [30, 40]

n = ks_model.s
num_steps = 1000  # Total number of time steps
observation_interval = 5  # Observation is available every 5 time steps

# Random keys. JAX sampling is a pure function of the key: passing the same key to
# every trial makes every trial draw identical noise and identical initial ensembles,
# i.e. the trials become exact replicates. Split off a key for the initial state and
# derive a fresh key for each trial inside the loop.
key = random.PRNGKey(0)
key, init_key = random.split(key)

# Process noise covariance and Observation noise covariance
Q = 0.01 * np.eye(n)
R = 0.5 * np.eye(n)
# Observation matrix (identity matrix for direct observation of all state variables)
H = np.eye(n)

# Initial state, shared by all trials (trials differ through their noise realizations)
initial_state = random.normal(init_key, (n,))

# Initialize data structures for results, keyed by (radius, inflation, n_ensemble)
configs = [(radius, inflation, n_ensemble) for radius in radii for inflation in inflations for n_ensemble in ensemble_sizes]
std_errors = {config: [] for config in configs}  # per trial: std pooled over state dims AND members
errors = {config: [] for config in configs}      # per trial: RMSE of the ensemble mean vs truth
spreads = {config: [] for config in configs}     # per trial: ensemble spread, comparable to RMSE
diverged = []  # (radius, inflation, n_ensemble, trial) whose filter output contains NaN/Inf

for trial in tqdm(range(num_trials), desc="Running Trials"):
    tqdm.write(f"Trial {trial}")  # print() would corrupt the progress bar
    trial_key = random.fold_in(key, trial)
    obs_key, ens_key = random.split(trial_key)
    observations, true_states = generate_true_states(obs_key, num_steps, n, initial_state, H, Q, R, ks_model.step, observation_interval)

    for radius in radii:
        local_mat = generate_gc_localization_matrix(n, radius)
        for inflation in inflations:
            for n_ensemble in ensemble_sizes:
                config = (radius, inflation, n_ensemble)
                # ens_key is shared by all configurations within one trial (common random
                # numbers), so configurations are compared on the same initial draws.
                ensemble_init = random.multivariate_normal(ens_key, initial_state, Q, (n_ensemble,)).T
                states = ensrf_steps(ks_model.step, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, local_mat, inflation)
                # NOTE(review): assumes states is (time, n, n_ensemble), matching the axis
                # usage below and true_states being (time, n) — confirm against ensrf_steps.
                if not np.all(np.isfinite(states)):
                    # Filter blew up (the MKL SLASCL errors point at NaN/Inf in an SVD);
                    # record it so these runs can be excluded from the statistics.
                    diverged.append((*config, trial))
                average_state = np.mean(states, axis=2)  # Calculate the mean along the ensemble dimension
                error = np.sqrt(np.mean((average_state - true_states) ** 2, axis=1))
                errors[config].append(error)
                # Pooled std over state dimensions and members: mixes spatial variability
                # of the field with ensemble spread. Kept for backward compatibility.
                std_dev = np.std(states, axis=(1, 2))
                std_errors[config].append(std_dev)
                # Ensemble spread: RMS over state dimensions of the across-member spread.
                spread = np.sqrt(np.mean(np.var(states, axis=2, ddof=1), axis=1))
                spreads[config].append(spread)

# Preparing data for saving
all_data = {
    'std_errors': std_errors,
    'errors': errors,
    'spreads': spreads,
    'diverged': diverged,
    'parameters': {
        'radii': radii,
        'inflations': inflations,
        'ensemble_sizes': ensemble_sizes,
        'num_trials': num_trials,
        'filter_params': (H, Q, R, observation_interval, num_steps)
    }
}

# File saving path
directory = '/central/home/eluk/variational_filtering/experiment_data/'
os.makedirs(directory, exist_ok=True)  # open(..., 'wb') does not create missing directories
filename = f'ks_experiment_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
file_path = os.path.join(directory, filename)

with open(file_path, 'wb') as f:
    pickle.dump(all_data, f)

Running Trials:   0%|                                                                            | 0/10 [00:00<?, ?it/s]

Trial 0
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  10%|██████▋                                                            | 1/10 [05:02<45:22, 302.55s/it]

Trial 1
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  20%|█████████████▍                                                     | 2/10 [09:54<39:28, 296.10s/it]

Trial 2
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  30%|████████████████████                                               | 3/10 [15:21<36:13, 310.57s/it]

Trial 3
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  40%|██████████████████████████▊                                        | 4/10 [20:41<31:24, 314.06s/it]

Trial 4
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  50%|█████████████████████████████████▌                                 | 5/10 [25:46<25:53, 310.76s/it]

Trial 5
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  60%|████████████████████████████████████████▏                          | 6/10 [30:43<20:24, 306.17s/it]

Trial 6
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  70%|██████████████████████████████████████████████▉                    | 7/10 [35:39<15:08, 302.91s/it]

Trial 7
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  80%|█████████████████████████████████████████████████████▌             | 8/10 [40:48<10:09, 304.77s/it]

Trial 8
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials:  90%|████████████████████████████████████████████████████████████▎      | 9/10 [45:38<05:00, 300.12s/it]

Trial 9
<class 'jaxlib.xla_extension.ArrayImpl'>

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.

Intel MKL ERROR: Parameter 4 was incorrect on entry to SLASCL.


Running Trials: 100%|██████████████████████████████████████████████████████████████████| 10/10 [50:30<00:00, 303.00s/it]


In [None]:
# Re-save the sweep results held in `all_data` (built by the experiment cell above).
# Each run writes a NEW timestamped file, so running this after a successful
# experiment cell produces a duplicate copy of the same data.
directory = '/central/home/eluk/variational_filtering/experiment_data/'
os.makedirs(directory, exist_ok=True)  # open(..., 'wb') does not create missing directories
filename = f'ks_experiment_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
file_path = os.path.join(directory, filename)

with open(file_path, 'wb') as f:
    pickle.dump(all_data, f)

: 