In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
import itertools
import matplotlib.pyplot as plt

# from kalman_reconstruction.kalman import (
#     Kalman_SEM,
# )
from kalman_reconstruction import pipeline
from kalman_reconstruction.custom_plot import (
    plot_state_with_probability,
    set_custom_rcParams,
)
from kalman_reconstruction.statistics import (
    normalize,
    autocorr,
)

set_custom_rcParams()
plt.rcParams.update({"figure.figsize": (8, 5)})
# Colors of the active property cycle, in cycle order; used to color curves consistently.
colors = [props["color"] for props in plt.rcParams["axes.prop_cycle"]]

from reconstruct_climate_indices.idealized_ocean import AMO_oscillatory_ocean
from reconstruct_climate_indices.track_data import (
    track_model,
    run_function_on_multiple_datasets,
)
from tqdm import tqdm

In [None]:
SUBDATA_PATH = "AMO_oscillator_V2"
# Figures of this experiment go to a results folder named after the sub-data path.
PATH_FIGURES = Path("..") / "results" / SUBDATA_PATH
SAVE_FIGURES = True


def save_fig(fig, relative_path, **kwargs):
    """Save ``fig`` below ``PATH_FIGURES`` when ``SAVE_FIGURES`` is enabled.

    Parameters
    ----------
    fig : object with a ``savefig`` method
        E.g. a matplotlib Figure or a seaborn grid, as used in this notebook.
    relative_path : str or Path
        Destination relative to ``PATH_FIGURES``. Missing parent directories
        are created.
    **kwargs
        Forwarded unchanged to ``fig.savefig`` (e.g. ``dpi``).
    """
    if not SAVE_FIGURES:
        return
    target = PATH_FIGURES / relative_path
    # savefig does not create directories; make sure the results folder exists.
    target.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(target, **kwargs)

In [None]:
def product_dict(**kwargs):
    """Iterate over the Cartesian product of keyword iterables.

    Each keyword maps to an iterable of candidate values. One dict per
    combination is yielded, keyed by the keyword names in argument order.
    With no keywords a single empty dict is yielded.
    """
    names = tuple(kwargs)
    yield from (
        dict(zip(names, combo)) for combo in itertools.product(*kwargs.values())
    )

In [None]:
# Default keyword arguments, passed as func_kwargs to track_model / AMO_oscillatory_ocean.
default_settings = dict(
    nt=1000,  # timesteps
    dt=30,  # days
    per0=24 * 365.25,  # days
    tau0=10 * 365.25,  # days
    dNAO=0.1,
    dEAP=0.1,
    cNAOvsEAP=0,
    save_path=None,
    return_settings=True,
)
setting = default_settings.copy()
# Coordinates of the default experiment, e.g. to select it from a merged dataset.
select_dict = {key: default_settings[key] for key in ["dNAO", "dEAP", "cNAOvsEAP"]}

# Multiplicative factors applied to the default dNAO / dEAP values.
factor = np.array([[0.1], [0.5], [1], [5]])
experiment_settings = dict()
experiment_settings_flat = dict()
for key in ["dNAO", "dEAP"]:
    experiment_settings[key] = default_settings[key] * factor
    experiment_settings_flat[key] = default_settings[key] * factor.flatten()

MLFLOW_EXPERIMENT_ID = "286934939241168502"

# One tracked model run per (dNAO, dEAP) combination.
# NOTE: after the loop `data` holds the LAST run (largest dNAO and dEAP), and
# later cells analyse that `data`, not the default-settings run.
data_list = []
for s in tqdm(list(product_dict(**experiment_settings_flat))):
    # Build fresh kwargs per run. Updating one shared dict in place would also
    # change the object handed to earlier track_model calls if they kept it.
    setting = {**default_settings, **s}
    data = track_model(
        func=AMO_oscillatory_ocean,
        mlflow_args=dict(experiment_id=MLFLOW_EXPERIMENT_ID),
        func_kwargs=setting,
        subdata_path=SUBDATA_PATH,
    )
    data_list.append(data)

100%|██████████| 16/16 [00:04<00:00,  3.65it/s]


In [None]:
# Combine all runs into one dataset (runs differ in their dNAO / dEAP coordinates).
experiments = xr.merge(objects=data_list)

In [None]:
fig, axs = plt.subplots(
    nrows=len(experiments.dNAO), ncols=len(experiments.dEAP), figsize=(15, 15)
)
# One panel per (dNAO, dEAP) pair: rows follow dNAO, columns follow dEAP.
for row, dNAO in tqdm(enumerate(experiments.dNAO)):
    for col, dEAP in enumerate(experiments.dEAP):
        ax = axs[row, col]
        for var in ("AMO", "ZOT"):
            ax.plot(
                experiments.time_years,
                experiments[var].sel(dNAO=dNAO, dEAP=dEAP),
                label=var,
            )
        ax.set_title(f"dNAO: {dNAO:.2f}, dEAP: {dEAP:.2f}")
        ax.set_ylabel("value")
        ax.set_xlabel("years")
        ax.legend()

fig.suptitle("Variation of dNAO and dEAP")
fig.tight_layout()

4it [00:00,  9.62it/s]


In [None]:
# Convert the simulated run into the pipeline's standard dataset layout; its
# state_name coordinate is iterated over further below.
# NOTE(review): `data` is whatever the experiment loop assigned last (the final
# dNAO/dEAP combination), not the default-settings run — confirm this is intended.
data_standard = pipeline.to_standard_dataset(data)

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2)
axs_flat = axs.flatten()
# One panel per simulated variable, plotted against time in years.
for var, ax in zip(["NAO", "EAP", "ZOT", "AMO"], axs_flat):
    data[var].plot(ax=ax, x="time_years")
    ax.set_title(var)
    ax.set_xlabel("Time in years")
    ax.set_ylabel("Value")

fig.tight_layout()
save_fig(fig, "Evolution.png", dpi=400)

Code to run the Kalman iteration for all experiments

In [None]:
# Run configuration for the Kalman SEM cells below.
# NOTE(review): variance, forcast_duration, rng1, rng2 and rng4 are not used by
# any cell shown here.
seed = 39264
variance = 5
nb_iter_SEM = 50
forcast_duration = 0.5

# Four independent generators, seeded with consecutive offsets from `seed`.
rng1, rng2, rng3, rng4 = (
    np.random.default_rng(seed=seed + offset) for offset in range(4)
)

Create the dataset for the Kalman results: add a random latent variable to the observations and run the Kalman SEM

In [None]:
observation_variables = ["ZOT", "NAO", "EAP"]
# State vector = the observed variables plus one extra "latent" component.
state_variables = [*observation_variables, "latent"]

data_1_latent = data.copy()
# Adds a random "latent" variable (drawn with rng3) to data_1_latent. The return
# value is not used, so this presumably modifies the dataset in place.
pipeline.add_random_variable(
    ds=data_1_latent, var_name="latent", random_generator=rng3, variance=1
)
sem_kwargs = dict(
    observation_variables=observation_variables,
    state_variables=state_variables,
    nb_iter_SEM=nb_iter_SEM,
)
kalman_result = pipeline.xarray_Kalman_SEM(ds=data_1_latent, **sem_kwargs)

100%|██████████| 50/50 [00:22<00:00,  2.25it/s]


In [None]:
# Unpack the "states" variable of the Kalman result back into separate variables,
# presumably one per state, each prefixed with "kalman_".
kalman_states = pipeline.from_standard_dataset(
    kalman_result, prefix="kalman_", var_name="states"
)

In [None]:
# Plot the SEM log-likelihood (presumably one value per SEM iteration). The
# attribute name "log_likelihod" (sic) must match the name stored in kalman_result.
kalman_result.log_likelihod.plot()

[<matplotlib.lines.Line2D at 0x203099a37c0>]

In [None]:
# Scatter the normalized Kalman estimate of each observed variable against the
# normalized simulated values.
for state in observation_variables:
    estimate = normalize(kalman_result.states.sel(state_name=state))
    truth = normalize(data[state].values.flatten())
    plt.scatter(estimate, truth, alpha=0.5, label=state)

# Reference diagonal (estimate == truth).
plt.plot([0, 1], [0, 1], color="k")
plt.legend()

<matplotlib.legend.Legend at 0x203091a32e0>

In [None]:
# Compare the estimated latent component with every state of the standard dataset;
# the legend shows the correlation of each pair.
# The latent estimate does not change inside the loop, so normalize it once.
latent_normalized = normalize(kalman_result.states.sel(state_name="latent"))
for state in data_standard.state_name.values:
    truth = data[state]
    corr = xr.corr(latent_normalized, normalize(truth))
    plt.scatter(
        latent_normalized,
        normalize(truth.values.flatten()),
        alpha=0.5,
        label=f"{state} : {corr:.2f}",
    )


plt.plot([-1, 1], [-1, 1], color="k")
plt.legend()

<matplotlib.legend.Legend at 0x2030a7539d0>

In [None]:
fig, ax = plt.subplots(1, 1)
# Overlay every Kalman state. Its diagonal covariance entry is passed as `prob`.
for state in kalman_result.state_name:
    state_values = kalman_result.states.sel(state_name=state)
    state_cov = kalman_result.covariance.sel(state_name=state, state_name_copy=state)
    plot_state_with_probability(
        ax=ax,
        x_value=kalman_result.time_years,
        state=state_values,
        prob=state_cov,
        line_kwargs={"label": state.values},
        output=False,
    )
ax.legend()

<matplotlib.legend.Legend at 0x2030a402470>

In [None]:
def adjust_lightness(color, amount=0.5):
    """Return ``color`` with its HLS lightness multiplied by ``amount``.

    Parameters
    ----------
    color : str or sequence of float
        A named matplotlib color, a hex string, or anything else accepted by
        ``matplotlib.colors.to_rgb``.
    amount : float, optional
        Factor applied to the lightness (< 1 darkens, > 1 lightens). The scaled
        lightness is clipped to [0, 1].

    Returns
    -------
    tuple of float
        The adjusted color as an ``(r, g, b)`` tuple; any alpha is dropped.
    """
    import colorsys

    import matplotlib.colors as mc

    try:
        # Resolve named colors (e.g. "red") to their hex string.
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a color name (hex string, RGB tuple or unhashable list): use as-is.
        # Only these lookup errors are expected; anything else should propagate.
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(hue, max(0, min(1, amount * lightness)), saturation)

In [None]:
fig, ax = plt.subplots(1, 1)
truth_states = ["AMO", "ZOT"]
for idx, state in enumerate(truth_states):
    color = colors[idx]
    # A darker shade of the same color marks the simulated values of this variable.
    dark_color = adjust_lightness(color)
    # Only variables in the Kalman state vector have an estimate to draw;
    # the simulated values are drawn either way.
    if state in kalman_result.state_name:
        plot_state_with_probability(
            ax=ax,
            x_value=kalman_result.time_years,
            state=kalman_result.states.sel(state_name=state),
            prob=kalman_result.covariance.sel(state_name=state, state_name_copy=state),
            line_kwargs=dict(color=color, label=f"{state}"),
        )
    data[state].plot(ax=ax, x="time_years", label=f"{state} truth", color=dark_color)

# The latent estimate gets the next color of the cycle. Use an explicit index
# rather than the loop variable leaking out of the for-loop.
# No simulated counterpart is drawn for it here.
color = colors[len(truth_states)]
plot_state_with_probability(
    ax=ax,
    x_value=kalman_result.time_years,
    state=kalman_result.states.sel(state_name="latent"),
    prob=kalman_result.covariance.sel(state_name="latent", state_name_copy="latent"),
    line_kwargs=dict(label="latent", color=color),
)

ax.legend()
ax.set_ylabel("Value")
ax.set_xlabel("Time in years")
ax.set_title("Deterministic variables shading as 95% CI")
save_fig(fig, "deterministic_variables.png", dpi=400)
save_fig(fig, "deterministic_variables.svg")

In [None]:
def norm(self, method="minmax", ddof=0):
    """Rescale an array-like object.

    Parameters
    ----------
    self : array-like
        Object supporting elementwise arithmetic and ``min``/``max``
        (plus ``mean``/``std`` for ``method="meanstd"``), e.g. a numpy array
        or an xarray DataArray.
    method : {"minmax", "meanstd"}, optional
        ``"minmax"`` (default) maps the values onto [0, 1]. ``"meanstd"``
        subtracts the mean and divides by the standard deviation.
    ddof : int, optional
        Delta degrees of freedom of the standard deviation (``"meanstd"`` only).

    Returns
    -------
    The result of the elementwise arithmetic on ``self``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method == "minmax":
        # A constant input has zero range, so every value becomes NaN.
        return (self - self.min()) / (self.max() - self.min())
    if method == "meanstd":
        return (self - self.mean()) / self.std(ddof=ddof)
    raise ValueError(f"Unknown normalization method: {method!r}")

In [None]:
# Positional index window along "time" used for this plot.
time_slice = slice(0, 3000)
kalman_window = kalman_result.isel(time=time_slice)
data_window = data.isel(time=time_slice)

fig, ax = plt.subplots(1, 1)
for state in ["NAO", "EAP"]:
    # Variables that are not part of the Kalman state vector cannot be shown.
    if state not in kalman_result.state_name:
        print(f"{state} not in results")
        continue
    plot_state_with_probability(
        ax=ax,
        x_value=kalman_window.time_years,
        state=kalman_window.states.sel(state_name=state),
        prob=kalman_window.covariance.sel(state_name=state, state_name_copy=state),
        line_kwargs=dict(label=f"{state}"),
    )
    data_window[state].plot(
        ax=ax, x="time_years", label=f"{state} truth", linestyle=":", color="grey"
    )

ax.legend()
ax.set_ylabel("Value")
ax.set_xlabel("Time in years")
ax.set_title("Stochastic variables shading as 95% CI")
save_fig(fig, "stochastic_variables.png", dpi=400)

In [None]:
def vars_to_dataframe(ds):
    """Convert the data variables of ``ds`` to a DataFrame.

    Every coordinate other than ``time`` is dropped first, so that extra
    coordinates (e.g. ``time_years``) are not carried into the frame.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to convert.

    Returns
    -------
    pandas.DataFrame
        Result of ``to_dataframe()`` on the reduced dataset.
    """
    keep = set(ds.data_vars) | {"time"}
    drop_vars = [name for name in ds.coords if name not in keep]
    # Dataset.drop is deprecated in xarray; drop_vars is the supported spelling.
    return ds.drop_vars(drop_vars).to_dataframe()


# Pairwise distributions of the simulated variables and the Kalman estimates,
# over the first 100 time steps, after applying `normalize` to the merged dataset.
# FIXME: `data_restored` is not defined in any cell shown above, and `sns`
# (seaborn) is never imported, so this cell raises NameError on a fresh kernel.
# NOTE(review): kalman_states was estimated from `data` (the last run of the
# experiment loop), while `select_dict` holds the DEFAULT dNAO/dEAP/cNAOvsEAP
# values. Make sure both sides of the merge come from the same simulation.
data_all = xr.merge([data_restored.sel(select_dict), kalman_states])
data_all = normalize(data_all)
df_all = vars_to_dataframe(data_all.isel(time=slice(0, 100)))

g = sns.PairGrid(df_all)
g.map_diag(sns.histplot, kde=True, bins=20)  # diagonal: histogram + KDE
g.map_upper(sns.histplot)  # upper triangle: bivariate histograms
g.map_lower(sns.kdeplot, fill=False)  # lower triangle: KDE contours
save_fig(g, "CorrelationMap.png", dpi=400)