In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
import seaborn as sns

import itertools
import warnings

import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
from matplotlib.legend_handler import HandlerPathCollection, HandlerLine2D

from kalman_reconstruction.kalman import (
    Kalman_SEM,
)
from reconstruct_climate_indices.idealized_ocean import AMO_oscillatory_ocean
from reconstruct_climate_indices.track_data import track_model
from tqdm import tqdm

In [None]:
SUBDATA_PATH = "AMO_oscillator"
PATH_FIGURES = Path("../results/AMO_oscillator")
SAVE_FIGURES = True


def save_fig(fig, relative_path, **kwargs):
    """Save ``fig`` below ``PATH_FIGURES`` when ``SAVE_FIGURES`` is enabled.

    Parameters
    ----------
    fig : object with a ``savefig`` method
        A matplotlib Figure or a seaborn grid (both expose ``savefig``).
    relative_path : str or Path
        File name (optionally with sub-directories) relative to ``PATH_FIGURES``.
    **kwargs
        Forwarded unchanged to ``fig.savefig`` (e.g. ``dpi``).
    """
    if not SAVE_FIGURES:
        return
    target = PATH_FIGURES / relative_path
    # Create the output directory on first use so a fresh checkout does not
    # fail with FileNotFoundError inside savefig.
    target.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(target, **kwargs)

In [None]:
# Global matplotlib styling for every figure in this notebook.
MEDIUM_SIZE = 12
BIGGER_SIZE = 15
HUGHER_SIZE = 18  # kept so other cells referencing it still work; not applied below
plt.style.use("seaborn-v0_8-whitegrid")

# Applied after ``plt.style.use`` so the style sheet cannot override them.
# Each rcParam is set exactly once (the previous version set figure.figsize and
# figure.titlesize twice; the values below are the ones that used to win).
plt.rcParams["figure.figsize"] = (10.0, 6.0)
plt.rc("font", size=MEDIUM_SIZE)  # default text size
plt.rc("figure", titlesize=BIGGER_SIZE)  # figure suptitle size
plt.rc("figure", labelsize=BIGGER_SIZE)  # figure supxlabel / supylabel size
plt.rc("axes", titlesize=BIGGER_SIZE)  # axes title size
plt.rc("axes", labelsize=MEDIUM_SIZE)  # x and y axis label size
# Draw only the left and bottom spines.
plt.rc("axes.spines", left=True, right=False, bottom=True, top=False)
plt.rc("xtick", labelsize=MEDIUM_SIZE)  # tick label size
plt.rc("ytick", labelsize=MEDIUM_SIZE)  # tick label size
plt.rc("legend", fontsize=MEDIUM_SIZE, loc="upper right")
# Colour-blind-safe palette, see https://davidmathlogic.com/colorblind
colors = [
    "#CC6677",
    "#6E9CB3",
    "#CA8727",
    "#44AA99",
    "#AA4499",
    "#D6BE49",
    "#A494F5",
]
plt.rcParams["axes.prop_cycle"] = plt.cycler(color=colors)


def plot_colors(colors):
    """Preview a colour palette as a single row of large dots.

    Parameters
    ----------
    colors : iterable of matplotlib colour specs
        Drawn left to right at x = 0, 1, 2, ...

    Returns
    -------
    (fig, ax) : the created figure and its single Axes.
    """
    figure, palette_ax = plt.subplots(figsize=(5, 1))
    position = 0
    for swatch in colors:
        palette_ax.scatter(position, 1, c=swatch, s=300)
        position += 1

    # The y position carries no information, so hide its ticks.
    palette_ax.set_yticks([])
    return figure, palette_ax

In [None]:
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the keyword values.

    Example: ``product_dict(a=[1, 2], b="x")`` yields ``{"a": 1, "b": "x"}``
    and then ``{"a": 2, "b": "x"}``. With no keywords, a single empty dict is
    yielded.
    """
    names = list(kwargs)
    for combo in itertools.product(*(kwargs[name] for name in names)):
        yield {name: value for name, value in zip(names, combo)}

In [None]:
# Baseline keyword arguments for AMO_oscillatory_ocean.
default_settings = {
    "nt": 3000,  # timesteps
    "dt": 30,  # days
    "per0": 24 * 365.25,  # days
    "tau0": 10 * 365.25,  # days
    "dNAO": 0.1,
    "dEAP": 0.1,
    "cNAOvsEAP": 0,
    "save_path": None,
    "return_settings": True,
}

# Run the model through track_model, tagged with SUBDATA_PATH.
data = track_model(
    _func=AMO_oscillatory_ocean,
    func_kwargs=default_settings,
    subdata_path=SUBDATA_PATH,
)
data

In [None]:
# make sure to get a kind of random seed
seed = np.random.default_rng(seed=2349832653).integers(0, 1e12, 1)
variance_unobs_comp = 1
random_generator = np.random.default_rng(seed=seed)

data["random_variable"] = (
    ["time"],
    random_generator.normal(loc=0, scale=variance_unobs_comp, size=len(data.time)),
)
iselect_dict = dict(
    dEAP=0,
    dNAO=0,
    cNAOvsEAP=0,
)

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2)
axs_flat = axs.flatten()
# One panel per model variable, filled row by row.
for panel, var in zip(axs_flat, ["NAO", "EAP", "ZOT", "AMO"]):
    data[var].plot(ax=panel, x="time_years")
    panel.set_title(var)
    panel.set_xlabel("Time in years")
    panel.set_ylabel("Value")

fig.tight_layout()
save_fig(fig, "Evolution.png", dpi=400)

In [None]:
# Pairwise distributions of all variables of the selected experiment.
# (seaborn is already imported as ``sns`` in the import cell.)
df = (
    data.isel(iselect_dict)
    # Parameter coordinates and the auxiliary year axis are not variables to
    # correlate; drop_vars replaces the deprecated Dataset.drop for names.
    .drop_vars(["dNAO", "dEAP", "cNAOvsEAP", "time_years"])
    .to_dataframe()
)
g = sns.PairGrid(df)
g.map_upper(sns.histplot)
g.map_lower(sns.kdeplot, fill=False)
g.map_diag(sns.histplot, kde=True)
save_fig(g, "CorrelationMap.png", dpi=400)

Code to run the Kalman-SEM iteration for the experiment selected by `iselect_dict`

In [None]:
# Kalman-SEM configuration.
# NOTE: the white noise z1 ("random_variable") was already drawn in an earlier
# cell; changing ``variance_unobs_comp`` here does not regenerate it.
variance_unobs_comp: float = 1
variance_obs_comp: float = 0.0001  # observation-error variance (diagonal of R)
nb_iter_SEM: int = 50  # number of stochastic-EM iterations

Create the Dataset that stores the Kalman results

In [None]:
def var_name(key, krn):
    """Return the Dataset variable name for result ``krn`` of variable ``key``.

    Example: ``var_name("NAO", "_kalman_state")`` -> ``"NAO_kalman_state"``.
    """
    return f"{key}{krn}"

# Container for all Kalman-SEM outputs, indexed by SEM iteration.
data_kalman = xr.Dataset({}).assign_coords(
    dict(kalman_iteration=np.arange(nb_iter_SEM))
)
# Log-likelihood per SEM iteration; NaN placeholder until the run is stored.
data_kalman["kalman_loglike"] = (
    ["kalman_iteration"],
    np.full(len(data_kalman.kalman_iteration), np.nan),
)
kalman_result_names = ["_kalman_state", "_kalman_prob"]

# Components of the Kalman state vector: the observed states (sorted) followed
# by the latent variable z1 ("random_variable").
state_names = sorted(["ZOT", "NAO", "EAP"])
obser_names = [*state_names, "random_variable"]

# Pre-allocate one empty array on the model's coordinates per
# (component, result) pair; presumably NaN-filled by xarray — confirm.
for key, krn in itertools.product(obser_names, kalman_result_names):
    data_kalman[var_name(key, krn)] = xr.DataArray(coords=data.coords)

In [None]:
def run_Kalman_SEM(y_list, z1):
    """Run Kalman-SEM on observed series ``y_list`` plus one latent component ``z1``.

    Parameters
    ----------
    y_list : list of 1-D arrays
        Observed time series of equal length T. The list is NOT modified
        (the previous version appended ``z1`` to the caller's list).
    z1 : 1-D array of length T
        Series stacked as the last, unobserved state component (presumably
        used by Kalman_SEM as its initial guess — confirm).

    Returns
    -------
    Whatever ``Kalman_SEM(x, y, H, R, nb_iter_SEM)`` returns; the calling cell
    unpacks seven values from it.

    Uses the notebook-level ``variance_obs_comp`` and ``nb_iter_SEM``.
    """
    # Observations y: (T, p); full state x = observations + z1: (T, p + 1).
    y = np.array(y_list).T
    x = np.array([*y_list, z1]).T

    n = np.shape(x)[1]  # state dimension
    p = np.shape(y)[1]  # observation dimension

    # Observation operator: the first p state components are observed directly.
    H = np.eye(p, n)
    R = variance_obs_comp * np.eye(p)

    # stochastic EM
    return Kalman_SEM(x, y, H, R, nb_iter_SEM)


# Positional index of the experiment to reconstruct (first value of each parameter).
iselect_dict = dict(
    dEAP=0,
    dNAO=0,
    cNAOvsEAP=0,
)

# Observed series of the selected experiment, in ``state_names`` order.
y_list = []
for var in state_names:
    if "random" not in var:
        y_list.append(data[var].isel(iselect_dict).values.flatten())

# White-noise series drawn earlier, stacked as the extra unobserved component.
z1 = data.random_variable.values.flatten()
kalman_state, kalman_prob, M, kalman_loglik, x, x_f, Q = run_Kalman_SEM(y_list, z1)

100%|██████████| 50/50 [00:53<00:00,  1.07s/it]


In [None]:
# Log-likelihood of every SEM iteration.
data_kalman["kalman_loglike"] = xr.DataArray(
    data=kalman_loglik,
    dims=["kalman_iteration"],
    coords=dict(
        kalman_iteration=data_kalman.kalman_iteration,
    ),
)

# Component ``idx`` of the Kalman output belongs to ``obser_names[idx]``: the
# same order in which the state vector was stacked before running Kalman-SEM.
for idx, key in enumerate(obser_names):
    for krn, temp in zip(kalman_result_names, [kalman_state, kalman_prob]):
        if np.ndim(temp) == 2:
            # State estimate, presumably shaped (time, component).
            values = temp[:, idx]
        elif np.ndim(temp) == 3:
            # Covariance, presumably (time, component, component): keep the
            # diagonal entry, i.e. the variance of this component.
            values = temp[:, idx, idx]
        else:
            # Fail loudly instead of silently leaving the slot empty.
            raise ValueError(
                f"Unexpected Kalman output for {var_name(key, krn)!r}: "
                f"expected 2 or 3 dimensions, got {np.ndim(temp)}."
            )
        data_kalman[var_name(key, krn)][iselect_dict] = xr.DataArray(
            data=values,
            dims=["time"],
            coords=dict(
                time=data_kalman.time,
            ),
        )

In [None]:
# Convergence of Kalman-SEM over the iterations.
fig, ax = plt.subplots(1, 1)
data_kalman.kalman_loglike.plot(ax=ax, x="kalman_iteration")
ax.set(
    ylabel="Value",
    xlabel="Iteration",
    title="Loglikelihood of Kalman-SEM",
)
save_fig(fig, "Loglikelihood.png", dpi=400)

In [None]:
def plot_state_prob(
    ax,
    x_value,
    state,
    prob,
    ci=1.96,
    line_kwargs=None,
    fill_kwargs=None,
):
    """Plot ``state`` with a shaded band ``state ± ci * sqrt(prob)``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    x_value : array-like
        x coordinates of the curve.
    state : array-like
        Estimated values, same length as ``x_value``.
    prob : array-like
        Presumably the variance of ``state``; its square root is used as the
        standard deviation of the band.
    ci : float, default 1.96
        Half-width of the band in standard deviations (1.96 ≈ 95 % for a
        Gaussian).
    line_kwargs : dict, optional
        Extra keyword arguments for ``ax.plot``. Defaults to none.
    fill_kwargs : dict, optional
        Extra keyword arguments for ``ax.fill_between``. Defaults to
        ``dict(alpha=0.3, label=None)``. The band uses the line colour unless
        ``color`` is given here.
    """
    # ``None`` defaults avoid sharing one mutable dict across calls.
    if line_kwargs is None:
        line_kwargs = {}
    if fill_kwargs is None:
        fill_kwargs = dict(alpha=0.3, label=None)

    lines = ax.plot(x_value, state, **line_kwargs)
    std = np.sqrt(prob)
    # Caller-supplied colour wins instead of raising a duplicate-keyword TypeError.
    band_kwargs = {"color": lines[0].get_color(), **fill_kwargs}
    ax.fill_between(
        x_value,
        state - ci * std,
        state + ci * std,
        **band_kwargs
    )

In [None]:
# Time stamps of the first 100 model years, used to zoom the plots below.
time_slice = data_kalman["time"][data_kalman["time_years"] < 100]

In [None]:
fig, ax = plt.subplots(1, 1)
# Kalman estimate (with band) and truth for each deterministic variable that
# is part of the Kalman state vector.
for name in ["AMO", "ZOT"]:
    if name not in obser_names:
        continue
    plot_state_prob(
        ax=ax,
        x_value=data_kalman.time_years.sel(time=time_slice),
        state=data_kalman[f"{name}_kalman_state"].sel(time=time_slice).values.flatten(),
        prob=data_kalman[f"{name}_kalman_prob"].sel(time=time_slice).values.flatten(),
        line_kwargs=dict(label=f"{name} kalman"),
    )
    data[name].sel(time=time_slice).plot(
        ax=ax, x="time_years", label=f"{name} truth", linestyle=":"
    )

ax.legend()
ax.set_ylabel("Value")
ax.set_xlabel("Time in years")
ax.set_title("Deterministic variables shading as 95% CI")
save_fig(fig, "deterministic_variables.png", dpi=400)

In [None]:
fig, ax = plt.subplots(1, 1)

# Kalman estimates of the deterministic variables present in the state vector ...
for name in ["AMO", "ZOT"]:
    if name not in obser_names:
        continue
    plot_state_prob(
        ax=ax,
        x_value=data_kalman.time_years.sel(time=time_slice),
        state=data_kalman[f"{name}_kalman_state"].sel(time=time_slice).values.flatten(),
        prob=data_kalman[f"{name}_kalman_prob"].sel(time=time_slice).values.flatten(),
        line_kwargs=dict(label=f"{name} kalman"),
    )
# ... together with the reconstructed latent component z1.
plot_state_prob(
    ax=ax,
    x_value=data_kalman.time_years.sel(time=time_slice),
    state=data_kalman["random_variable_kalman_state"]
    .sel(time=time_slice)
    .values.flatten(),
    prob=data_kalman["random_variable_kalman_prob"]
    .sel(time=time_slice)
    .values.flatten(),
    line_kwargs=dict(label="z1"),
)
ax.legend()
ax.set_ylabel("Value")
ax.set_xlabel("Time in years")
ax.set_title("Deterministic variables and z1 - shading as 95% CI")
save_fig(fig, "random_variable.png", dpi=400)

In [None]:
fig, ax = plt.subplots(1, 1)
# Kalman estimate (with band) and semi-transparent truth for each stochastic
# forcing that is part of the Kalman state vector.
for name in ["NAO", "EAP"]:
    if name not in obser_names:
        continue
    plot_state_prob(
        ax=ax,
        x_value=data_kalman.time_years.sel(time=time_slice),
        state=data_kalman[f"{name}_kalman_state"].sel(time=time_slice).values.flatten(),
        prob=data_kalman[f"{name}_kalman_prob"].sel(time=time_slice).values.flatten(),
        line_kwargs=dict(label=f"{name} kalman"),
    )
    data[name].sel(time=time_slice).plot(
        ax=ax, x="time_years", alpha=0.5, label=f"{name} truth"
    )

ax.set_ylabel("Value")
ax.set_xlabel("Time in years")
ax.legend()

ax.set_title("Stochastic variables shading as 95% CI")
save_fig(fig, "stochastic_variables.png", dpi=400)

In [None]:
# Zoom on the first 5 model years.
time_slice_beginning = data_kalman["time"][data_kalman["time_years"] < 5]
fig, ax = plt.subplots(1, 1)
for name in ["NAO", "EAP"]:
    if name not in obser_names:
        continue
    plot_state_prob(
        ax=ax,
        x_value=data_kalman.time_years.sel(time=time_slice_beginning),
        state=data_kalman[f"{name}_kalman_state"]
        .sel(time=time_slice_beginning)
        .values.flatten(),
        prob=data_kalman[f"{name}_kalman_prob"]
        .sel(time=time_slice_beginning)
        .values.flatten(),
        line_kwargs=dict(label=f"{name} kalman"),
    )
    data[name].sel(time=time_slice_beginning).plot(
        ax=ax, x="time_years", alpha=0.5, label=f"{name} truth"
    )

ax.set_ylabel("Value")
ax.set_xlabel("Time in years")
ax.legend()

ax.set_title("Stochastic variables shading as 95% CI")
save_fig(fig, "stochastic_variables_beginning.png", dpi=400)

Plot correlation maps of the truth and the Kalman results

In [None]:
# Variance ("prob") outputs are excluded: only state estimates are correlated.
prob_vars = [var for var in data_kalman.data_vars if "prob" in var]

# drop_vars replaces the deprecated Dataset.drop for dropping names.
df = (
    data.isel(iselect_dict)
    .drop_vars(["dNAO", "dEAP", "cNAOvsEAP", "time_years"])
    .to_dataframe()
)
df_kalman = (
    data_kalman.isel(iselect_dict)
    .drop_vars(["kalman_loglike", "kalman_iteration"])
    .drop_vars(["dNAO", "dEAP", "cNAOvsEAP", "time_years"])
    .drop_vars(prob_vars)
    .to_dataframe()
)
# Column-wise join on the shared index: truth next to Kalman estimates.
df_all = pd.concat([df, df_kalman], ignore_index=False, sort=False, axis=1)

In [None]:
# Pairwise distributions of truth and Kalman estimates:
# KDE contours below the diagonal, histograms (+ KDE) on the diagonal.
g = sns.PairGrid(df_all)
g.map_lower(sns.kdeplot, fill=False)
g.map_diag(sns.histplot, kde=True)
save_fig(g, "correlation_kalman_applied.png", dpi=400)

Plot a heatmap of the Pearson correlation coefficients

In [None]:
# Heatmap of correlation
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
sns.heatmap(
    df_all.corr(),
    ax=ax,
    annot=True,
    fmt=".2f",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
    square=True,
)
ax.set_title(f"Pearson-Correlation")
fig.tight_layout()
save_fig(fig, "pearson_coefficient.png", dpi=400)