In [None]:
# Third-party libraries.
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Project modules: data generation, Kalman-SEM pipeline and plotting helpers.
from kalman_reconstruction import example_models, pipeline
from kalman_reconstruction.custom_plot import (
    plot_state_with_probability,
    set_custom_rcParams,
)

# Presumably installs the project's matplotlib rcParams (per the helper's name),
# so it is called once here, before any figure below is created.
set_custom_rcParams()

Multiple random latent variables can now easily be added, and Kalman_SEM is run on all of them.

In [None]:
# Lorenz-63 trajectory used by every reconstruction below.
data = example_models.Lorenz_63_xarray(dt=0.01, time_length=5, time_steps=None)

seed = 345831200837
variance = 5
# Number of random latent variables to attach; they are named z1, z2, ..., zN.
nb_random_variables = 4

# Each latent variable z_i gets its own generator seeded with seed + (i - 1),
# so every variable is reproducible independently of how many are added.
# add_random_variable is called for its effect on `data` (its return value is
# not used); later cells select these variables from `data` by name.
for i in range(1, nb_random_variables + 1):
    pipeline.add_random_variable(
        data,
        var_name=f"z{i}",
        random_generator=np.random.default_rng(seed=seed + i - 1),
        variance=variance,
    )

In [None]:
# Baseline: Kalman-SEM on the three observed states, with no latent variables.
test = pipeline.run_Kalman_SEM_to_xarray(
    ds=data,
    state_variables=["x1", "x2", "x3"],
    random_variables=[],
    nb_iter_SEM=10,
    variance_obs_comp=0.0001,
)

# One curve per reconstructed state. For each state, the matching diagonal
# entry of `uncertainties` (same name on both state dimensions) is passed as `prob`.
fig, ax = plt.subplots()
for state_name in test.state_names:
    plot_state_with_probability(
        ax=ax,
        x_value=test.time,
        state=test.states.sel(state_names=state_name),
        prob=test.uncertainties.sel(
            state_names=state_name, state_names_copy=state_name
        ),
        line_kwargs={"label": state_name.values},
    )

ax.legend()
ax.set(
    xlim=(0, 2),
    xlabel="time",
    ylabel="Values",
    title="Validation using all states",
);

100%|██████████| 10/10 [00:01<00:00,  5.06it/s]


In [None]:
# x1 is left out of the observed states; the latent variables z1 and z2 are
# added to the state vector instead.
test = pipeline.run_Kalman_SEM_to_xarray(
    ds=data,
    state_variables=["x2", "x3"],
    random_variables=["z1", "z2"],
    nb_iter_SEM=10,
    variance_obs_comp=0.0001,
)

# One curve per reconstructed state (observed and latent). For each state, the
# matching diagonal entry of `uncertainties` is passed as `prob`.
fig, ax = plt.subplots()
for state_name in test.state_names:
    plot_state_with_probability(
        ax=ax,
        x_value=test.time,
        state=test.states.sel(state_names=state_name),
        prob=test.uncertainties.sel(
            state_names=state_name, state_names_copy=state_name
        ),
        line_kwargs={"label": state_name.values},
    )

ax.legend()
ax.set(
    xlim=(0, 2),
    xlabel="time",
    ylabel="Values",
    title="Using multiple random latent variables",
);

100%|██████████| 10/10 [00:02<00:00,  4.44it/s]


In [None]:
# Time-dependent Kalman-SEM variant on x2 and x3, with all four latent
# variables z1..z4. In the recorded run this cell was much slower than the
# previous ones (~2 s per SEM iteration).
test = pipeline.xarray_Kalman_SEM_time_dependent(
    ds=data,
    state_variables=["x2", "x3"],
    random_variables=["z1", "z2", "z3", "z4"],
    nb_iter_SEM=10,
    variance_obs_comp=0.0001,
)

# One curve per reconstructed state. For each state, the matching diagonal
# entry of `uncertainties` (same name on both state dimensions) is passed as `prob`.
fig, ax = plt.subplots(1, 1)
for var in test.state_names:
    plot_state_with_probability(
        ax=ax,
        x_value=test.time,
        state=test.states.sel(state_names=var),
        prob=test.uncertainties.sel(state_names=var, state_names_copy=var),
        line_kwargs={"label": var.values},
    )

ax.legend()
ax.set_xlim((0, 2))
ax.set_xlabel("time")
ax.set_ylabel("Values")
# Distinct title. Reusing the previous figure's title made the
# constant-model (z1, z2) and time-dependent (z1..z4) plots indistinguishable.
ax.set_title("Time-dependent Kalman-SEM using four random latent variables");

100%|██████████| 10/10 [00:21<00:00,  2.12s/it]
