In [None]:
import numpy as np
import xarray as xr
import yaml
from pathlib import Path
from os import PathLike
import seaborn as sns
from kalman_reconstruction.custom_plot import (
    set_custom_rcParams,
    adjust_lightness,
    plot_colors,
    symmetrize_axis,
    handler_map_alpha,
    symmetrize_axis,
    plot_state_with_probability,
)
from reconstruct_climate_indices.idealized_ocean import spunge_ocean, oscillatory_ocean
from reconstruct_climate_indices.statistics import (
    linear_regression_loglog,
    power_density_spectrum,
    xarray_dataset_welch,
    xarray_dataarray_welch,
)
from kalman_reconstruction.pipeline import from_standard_dataset, all_choords_as_dim
from kalman_reconstruction.statistics import normalize, crosscorr
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.figure import Figure

# from sklearn.linear_model import LinearRegression
from scipy import signal
from typing import Dict
import reconstruct_climate_indices.statistics as stati
from sklearn.linear_model import LinearRegression

In [None]:
# ## LIGHT THEME
# plt.style.use("seaborn-v0_8-whitegrid")
# dark_color = [0.3, 0.3, 0.3]
# light_color = [0.8, 0.8, 0.8]
# lightness_0 = 0.75
# lightness_1 = 0.5
# cmap = "rocket"
# cmap_r = "rocket_r"

### DARK THEME
# Greys, lightness factors and colormap names used by the plotting cells.
# `cmap` / `cmap_r` are swapped with respect to the commented-out light theme.
plt.style.use("dark_background")
dark_color, light_color = [0.7] * 3, [0.2] * 3
lightness_0, lightness_1 = 1.15, 1.5
cmap, cmap_r = "rocket_r", "rocket"


# Project-wide rcParams; the returned colour sequence is indexed below.
colors = set_custom_rcParams()
plt.rcParams["grid.alpha"] = 0.5
plt.rcParams["axes.grid"] = True

plot_colors(colors)

# Fixed colour per plotted quantity, looked up by name in the plotting cells.
variables_color = {
    "SAT": colors[0],
    "SST": colors[2],
    "DOT": colors[1],
    "latent": colors[3],
    "loglikelihood": colors[-1],
}

In [None]:
# Repository root: parent of the directory the notebook is executed from.
REPO_PATH = Path.cwd().resolve().parent
# All figures of this notebook are stored below this directory.
results_path = REPO_PATH.joinpath(
    "results", "Report", "results", "Evaluation_final_dark"
)
results_path.mkdir(parents=True, exist_ok=True)
# Global switch read by `save_fig`.
SAVE_FIGURES = True


def save_fig(fig: Figure, relative_path: PathLike, kwargs: Dict = None):
    """
    Save ``fig`` below the notebook-wide ``results_path``.

    Nothing is written and no directory is created when the module-level
    switch ``SAVE_FIGURES`` is false.

    Parameters
    ----------
    fig : Figure
        Figure to store.
    relative_path : PathLike
        Target file, relative to ``results_path``. Missing parent
        directories are created.
    kwargs : Dict, optional
        Extra keyword arguments forwarded to ``fig.savefig``.
        ``None`` (the default) forwards nothing.
    """
    if not SAVE_FIGURES:
        return
    store_path = results_path / relative_path
    store_path.parent.mkdir(parents=True, exist_ok=True)
    # `None` replaces the former mutable `dict()` default.
    fig.savefig(store_path, **(kwargs or {}))

# Evaluation 

To evaluate the performance of the Kalman-SEM, the following success metrics were defined:

1. Increasing $\mathcal{L}$ (sigmoid shape)
2. Weak correlation of the final state of the latent variable with the observations, $\rho_{obs}\approx 0$. If the hidden component is known, high correlation with it, $\left| \rho_{hid} \right| \approx 1$.
3. The PSD of the final state of the latent variable should turn from white to red, i.e. the slope in log-log space should be $m \approx -2$ to $-4$.

The experiment used has the following properties:
- Experiment_ID = ``750584923317940775``
- ExperimentName = ``Evaluation-Idealized-Ocean``

## Load data 

### Sponge ocean 

In [None]:
# Runs that together make up the sponge-ocean experiment.
run_names = [
    "polite-eel-349",
    "flawless-loon-25",
    "bustling-horse-699",
    "grandiose-hawk-664",
]

# Each run directory provides a `<run>_kalman.nc` and a `<run>_input.nc` file.
experiment_dir = REPO_PATH / "data" / "Evaluation-Idealized-Ocean"
kalman_list = [
    xr.open_dataset(experiment_dir / name / f"{name}_kalman.nc") for name in run_names
]
input_list = [
    xr.open_dataset(experiment_dir / name / f"{name}_input.nc") for name in run_names
]
# Combine the per-run datasets into a single dataset each.
sponge_kalman = xr.merge(kalman_list)
sponge_input = xr.merge(input_list)

# Kalman result converted with `from_standard_dataset` for the analysis below.
sponge_kalman_states = from_standard_dataset(ds=sponge_kalman)

ocean_name_sponge = "Sponge Ocean"
number_of_runs_sponge = len(sponge_kalman.seed)

print(f"number of runs : {number_of_runs_sponge}")

number of runs : 30


### Oscillatory Ocean

In [None]:
# Runs that together make up the oscillatory-ocean experiment.
run_names = [
    "dapper-fox-131",
    "respected-fowl-948",
    "aged-bat-385",
    "useful-mink-301",
    "placid-vole-890",
    "fun-ape-341",
]

# Each run directory provides a `<run>_kalman.nc` and a `<run>_input.nc` file.
experiment_dir = REPO_PATH / "data" / "Evaluation-Idealized-Ocean"
kalman_list = [
    xr.open_dataset(experiment_dir / name / f"{name}_kalman.nc") for name in run_names
]
input_list = [
    xr.open_dataset(experiment_dir / name / f"{name}_input.nc") for name in run_names
]
# Combine the per-run datasets into a single dataset each.
oscillatory_kalman = xr.merge(kalman_list)
oscillatory_input = xr.merge(input_list)

# Kalman result converted with `from_standard_dataset` for the analysis below.
oscillatory_kalman_states = from_standard_dataset(ds=oscillatory_kalman)

ocean_name_oscillatory = "Oscillatory Ocean"
number_of_runs_oscillatory = len(oscillatory_kalman.seed)
print(f"number of runs : {number_of_runs_oscillatory}")

number of runs : 30


In [None]:
# check that seeds are the same:
# Raises an AssertionError if the `seed` coordinates of the two experiments differ.
xr.testing.assert_equal(oscillatory_kalman.seed, sponge_kalman.seed)

## Evaluation of loglikelihood and correlation

### Sponge Ocean

Plot the log-likelihoods of all experiments to check whether all experiments succeeded.

In [None]:
# Width of the shaded band in standard deviations. The same value is written
# into the axes title so that the caption always matches what is drawn
# (the title previously claimed "2 std" while 5 were plotted).
n_stds = 5

fig, ax = plt.subplots(nrows=1, ncols=1, layout="constrained")
for t in sponge_kalman.tau0:
    # Log-likelihood of all seeds for the current tau0.
    log_likelihood = sponge_kalman["log_likelihod"].sel(tau0=t)
    plot_state_with_probability(
        ax=ax,
        x_value=sponge_kalman["kalman_itteration"],
        state=log_likelihood.mean(dim="seed"),
        prob=log_likelihood.std(dim="seed"),
        stds=n_stds,
        line_kwargs=dict(color=variables_color["loglikelihood"]),
    )
ax.set_xlabel("Itterations of Kalman-SEM")
ax.set_ylabel("Loglikelihood")
ax.set_title(
    f"Solid as mean and shaded as {n_stds} std based on {number_of_runs_sponge} indep. runs for each parameter set"
)
title = "Log likelihood evolution over kalman itterations."
fig.suptitle(f"{ocean_name_sponge} | {title}")
save_name = "Loglikelihood_evolution"
save_fig(fig=fig, relative_path=Path(ocean_name_sponge) / (save_name + ".svg"))

Loglikelihood increase

In [None]:
# Calculate Loglikelihood increase: last minus first Kalman-SEM iteration.
sponge_kalman["log_likelihod_increase"] = sponge_kalman["log_likelihod"].isel(
    kalman_itteration=-1
) - sponge_kalman["log_likelihod"].isel(kalman_itteration=0)
da_lli = sponge_kalman["log_likelihod_increase"]
# Add a length-one `df` dimension so the result can be drawn as a 2D heatmap.
da_lli = da_lli.expand_dims(df=[0.115])

# PLOT
# NOTE(review): the mean panel uses the range [-2.6, 0] while the oscillatory
# counterpart uses [0, 2.6] — confirm the expected sign of the increase.
heatmap_kwargs = dict(
    xticklabels=da_lli.tau0.values / 365.25,
    yticklabels=da_lli.df.values,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=-2.6,
    vmax=0,
    cmap=cmap,
)
# The standard deviation is non-negative. Instead of plotting its negative to
# fit the range above, mirror both the range and the colormap: colours stay
# the same, but annotations and colour bar show the true positive values.
std_heatmap_kwargs = dict(heatmap_kwargs, vmin=0, vmax=2.6, cmap=cmap_r)

# Both statistics are scaled by 1e-3, i.e. displayed in units of 10^3.
m = da_lli.mean(dim="seed") * 10 ** (-3)
s = da_lli.std(dim="seed") * 10 ** (-3)

fig, axs = plt.subplots(nrows=2, ncols=1, layout="constrained", figsize=(5, 3))
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(rf"Mean of {number_of_runs_sponge} indep. runs (in $10^3$)")
sns.heatmap(s, ax=axs[1], **std_heatmap_kwargs)
axs[1].set_title(rf"Std. of {number_of_runs_sponge} indep. runs (in $10^3$)")

for ax in axs:
    ax.set_xlabel(r"$\tau_0$ in y")
    ax.set_ylabel(r"$df$ in y")

title = "Log likelihood increase over all kalman itterations."
fig.suptitle(f"{ocean_name_sponge} | {title}")
save_name = "Loglikelihood_increase"
save_fig(fig=fig, relative_path=Path(ocean_name_sponge) / (save_name + ".svg"))

calculate correlation to all input variables

In [None]:
# Correlate the latent state with every input variable along `time`;
# each result is named after the input variable it belongs to.
corr_list = [
    crosscorr(
        ds1=sponge_kalman_states["latent"],
        ds2=sponge_input[var],
        dim="time",
    ).rename(var)
    for var in sponge_input.data_vars
]
# Merge into one dataset and add a length-one `df` dimension for the heatmaps.
correlation_sponge = xr.merge(corr_list).expand_dims(df=[0.115])

Plot the correlation to one variable

In [None]:
var = "SST"
da_corr = correlation_sponge[var]

# PLOT
heatmap_kwargs = dict(
    xticklabels=da_corr.tau0.values / 365.25,
    yticklabels=da_corr.df.values,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=1,
    cmap=cmap_r,
)

# The absolute value is taken before the seed statistics, so the sign of the
# correlation is discarded.
m = np.abs(da_corr).mean(dim="seed")
s = np.abs(da_corr).std(dim="seed")

fig, axs = plt.subplots(nrows=2, ncols=1, layout="constrained", figsize=(5, 3))
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
# The correlation is not rescaled, so the title carries no scale factor.
axs[0].set_title(f"Mean of {number_of_runs_sponge} indep. runs")
sns.heatmap(s, ax=axs[1], **heatmap_kwargs)
axs[1].set_title(f"Std. of {number_of_runs_sponge} indep. runs")

for ax in axs:
    ax.set_xlabel(r"$\tau_0$ in y")
    # The y axis is the `df` dimension (see yticklabels above).
    ax.set_ylabel(r"$df$ in y")

title = f"Correlation Coefficient {var} to latent."
fig.suptitle(f"{ocean_name_sponge} | {title}")
save_name = f"correlation_coefficient_{var}"
save_fig(fig=fig, relative_path=Path(ocean_name_sponge) / (save_name + ".svg"))

In [None]:
# Check correct order
# Displays `m` (mean absolute correlation over seeds) so its dimension order
# can be compared with the heatmap ticks.
m

### Oscillatory Ocean

#### Loglikelihoods of all experiments

It is important to check if all experiments are saturated in the loglikelihood

In [None]:
# Width of the shaded band in standard deviations; also reported in the title.
n_stds = 2

fig, ax = plt.subplots(nrows=1, ncols=1, layout="constrained")
for t in oscillatory_kalman.tau0:
    for p in oscillatory_kalman.per0:
        # Log-likelihood of all seeds for the current (tau0, per0) pair.
        log_likelihood = oscillatory_kalman["log_likelihod"].sel(tau0=t, per0=p)
        plot_state_with_probability(
            ax=ax,
            x_value=oscillatory_kalman["kalman_itteration"],
            state=log_likelihood.mean(dim="seed"),
            prob=log_likelihood.std(dim="seed"),
            stds=n_stds,
        )
ax.set_xlabel("Itterations of Kalman-SEM")
ax.set_ylabel("Loglikelihood")
# Report the run count of the oscillatory experiment (was the sponge count).
ax.set_title(
    f"Solid as mean and shaded as {n_stds} std based on {number_of_runs_oscillatory} indep. runs for each parameter set"
)

title = "Log likelihood evolution over kalman itterations."
fig.suptitle(f"{ocean_name_oscillatory} | {title}")
save_name = "Loglikelihood_evolution"
save_fig(fig=fig, relative_path=Path(ocean_name_oscillatory) / (save_name + ".svg"))

#### Loglikelihood increase

In [None]:
# Log-likelihood increase per run: last Kalman-SEM iteration minus the first.
oscillatory_kalman["log_likelihod_increase"] = (
    oscillatory_kalman["log_likelihod"].isel(kalman_itteration=-1)
    - oscillatory_kalman["log_likelihod"].isel(kalman_itteration=0)
)
da_lli = oscillatory_kalman["log_likelihod_increase"]

# Settings shared by both panels; tick values are divided by 365.25 to match
# the "in y" axis labels.
heatmap_kwargs = {
    "xticklabels": da_lli.per0.values / 365.25,
    "yticklabels": da_lli.tau0.values / 365.25,
    "square": True,
    "annot": True,
    "fmt": ".2f",
    "vmin": 0,
    "vmax": 2.6,
    "cmap": cmap_r,
}
mean_heatmap_kwargs = {
    **heatmap_kwargs,
    "cbar_kws": {"label": r"Mean $\mathcal{LI}$ in $10^3$"},
}
std_heatmap_kwargs = {
    **heatmap_kwargs,
    "cbar_kws": {"label": r"Std. $\mathcal{LI}$ in $10^3$"},
}

# Seed statistics, scaled by 1e-3 as stated in the colour-bar labels.
scale = 10 ** (-3)
m = da_lli.mean(dim="seed") * scale
s = da_lli.std(dim="seed") * scale

# Left panel: mean, right panel: standard deviation.
fig, axs = plt.subplots(
    nrows=1, ncols=2, layout="constrained", sharex=True, sharey=True, figsize=(8.5, 4)
)
mean_ax, std_ax = axs

sns.heatmap(m, ax=mean_ax, **mean_heatmap_kwargs)
mean_ax.set_title(f"Mean of {number_of_runs_oscillatory} indep. runs")
sns.heatmap(s, ax=std_ax, **std_heatmap_kwargs)
std_ax.set_title(f"Std. of {number_of_runs_oscillatory} indep. runs")

for ax in axs:
    ax.set_xlabel(r"$T_0$ in y")
mean_ax.set_ylabel(r"$\tau_0$ in y")

title = r"Loglikelihood increase ($\mathcal{LI}$)"
fig.suptitle(f"{ocean_name_oscillatory} | {title}")
save_name = "Loglikelihood_increase"
save_fig(fig=fig, relative_path=Path(ocean_name_oscillatory) / f"{save_name}.svg")

In [None]:
# One histogram of the log-likelihood increase over all seeds per
# (tau0, per0) combination. The grid size follows the data instead of being
# fixed to 5 x 5; `squeeze=False` keeps `axs` two-dimensional in every case.
n_tau0 = da_lli.sizes["tau0"]
n_per0 = da_lli.sizes["per0"]
fig, axs = plt.subplots(
    ncols=n_per0,
    nrows=n_tau0,
    figsize=(10, 10),
    sharey=True,
    sharex=True,
    squeeze=False,
)

# Seed statistics are computed once instead of once per panel.
lli_mean = da_lli.mean(dim="seed")
lli_median = da_lli.median(dim="seed")

for i, t in enumerate(da_lli.tau0):
    for j, p in enumerate(da_lli.per0):
        sel_dict = dict(
            tau0=t,
            per0=p,
        )
        ax = axs[i, j]
        ax.hist(da_lli.sel(sel_dict), density=True, color=dark_color, alpha=0.2)
        ax.axvline(lli_mean.sel(sel_dict), label="mean", color="r", linewidth=2)
        ax.axvline(lli_median.sel(sel_dict), label="median", color="b", linewidth=2)
        # Raw strings avoid invalid escape sequences in the LaTeX labels;
        # `per0` is labelled T_0 as in the other figures of this notebook.
        ax.set_title(
            rf"$\tau_0$ = {t.values / 365.25}"
            + "\n"
            + rf"$T_0$ = {p.values / 365.25}"
        )
        ax.legend()

#### Correlation to other variables

In [None]:
# Correlate the latent state with every input variable along `time`;
# each result is named after the input variable it belongs to.
corr_list = [
    crosscorr(
        ds1=oscillatory_kalman_states["latent"],
        ds2=oscillatory_input[var],
        dim="time",
    ).rename(var)
    for var in oscillatory_input.data_vars
]
correlation_oscillatory = xr.merge(corr_list)

In [None]:
var = "DOT"
da_corr = correlation_oscillatory[var]

# Settings shared by both panels; tick values are divided by 365.25 to match
# the "in y" axis labels.
heatmap_kwargs = {
    "xticklabels": da_corr.per0.values / 365.25,
    "yticklabels": da_corr.tau0.values / 365.25,
    "square": True,
    "annot": True,
    "fmt": ".2f",
    "vmin": 0,
    "vmax": 1,
    "cmap": cmap_r,
}
mean_heatmap_kwargs = {
    **heatmap_kwargs,
    "cbar_kws": {"label": r"Mean $\left| \rho_{hid} \right|$"},
}
std_heatmap_kwargs = {
    **heatmap_kwargs,
    "cbar_kws": {"label": r"Std. $\left| \rho_{hid} \right|$"},
}

# Seed statistics of the absolute correlation (the sign is discarded).
abs_corr = np.abs(da_corr)
m = abs_corr.mean(dim="seed")
s = abs_corr.std(dim="seed")

# Left panel: mean, right panel: standard deviation.
fig, axs = plt.subplots(
    nrows=1, ncols=2, layout="constrained", sharex=True, sharey=True, figsize=(8.5, 4)
)
mean_ax, std_ax = axs

sns.heatmap(m, ax=mean_ax, **mean_heatmap_kwargs)
mean_ax.set_title(f"Mean of {number_of_runs_oscillatory} indep. runs")
sns.heatmap(s, ax=std_ax, **std_heatmap_kwargs)
std_ax.set_title(f"Std. of {number_of_runs_oscillatory} indep. runs")

for ax in axs:
    ax.set_xlabel(r"$T_0$ in y")
mean_ax.set_ylabel(r"$\tau_0$ in y")

title = f"Correlation Coefficient {var} to latent " + r"($\left| \rho_{hid} \right|$)"
fig.suptitle(f"{ocean_name_oscillatory} | {title}")
save_name = f"correlation_coefficient_{var}"
save_fig(fig=fig, relative_path=Path(ocean_name_oscillatory) / f"{save_name}.svg")

In [None]:
# NOTE(review): inspection-only cell — displays the second entry of the
# `time` coordinate of the sponge input; the value is not stored.
sponge_input.time[1]

In [None]:
# Validate that the axis are correct:
# Print the mean absolute correlation for a few (tau0, per0) pairs so the
# numbers can be compared by eye with the annotated heatmap.
print("Validate that the axis are correct!")
spot_checks = [(5, 48), (20, 48), (15, 36)]  # (tau0, per0); scaled by 365.25 below
for t, p in spot_checks:
    selected = correlation_oscillatory[var].sel(tau0=t * 365.25, per0=p * 365.25)
    val = np.abs(selected).mean(dim="seed").values
    print(f"tau0 = {t}, per0 = {p}\t {val:.2f}")

Validate that the axis are correct!
tau0 = 5, per0 = 48	 0.04
tau0 = 20, per0 = 48	 0.21
tau0 = 15, per0 = 36	 0.44


## Power Spectral Density

#### Compute the power spectral density

NOTE: 
In order to compute a Power Spectral Density **WITHOUT** using the Welch method, it is necessary to set ``window = "boxcar"`` and ``nperseg`` to the length of the data.
```` python
welch_kwargs = dict(
    fs = 12,                    # period is 1/12 y -> fs = 12 y^{-1} 
    nperseg = len(data.time),   # length in timesteps
    scaling = "density",
    window = "boxcar"
)
````

In [None]:
# Compute PSD with frequency in year**{-1}.
# A single boxcar segment spanning the whole series means NO Welch averaging.
welch_kwargs = {
    "fs": 12,  # period is 1/12 y -> fs = 12 y^{-1}
    "nperseg": len(sponge_input.time),  # length in timesteps
    "scaling": "density",
    "window": "boxcar",
}


def _psd_along_time(ds):
    """Apply ``xarray_dataset_welch`` along ``time`` with the shared ``welch_kwargs``."""
    return xarray_dataset_welch(ds, dim="time", welch_kwargs=welch_kwargs)


# Spectra of the model inputs ...
psd_sponge = _psd_along_time(sponge_input)
psd_oscillatory = _psd_along_time(oscillatory_input)

# ... and of the Kalman-SEM states.
psd_sponge_kalman = _psd_along_time(sponge_kalman_states)
psd_oscillatory_kalman = _psd_along_time(oscillatory_kalman_states)

#### Perform ``linear_regression_loglog`` 
$f_{low} = 1/100$ in $y^{-1}$ for all!

In [None]:
# Frequency band handed to `linear_regression_loglog` as `f_low` / `f_high`
# (frequencies are in 1/year, see the PSD computation).
F_LOW = 1 / 100
F_HIGH = np.inf

In [None]:
def _fit_loglog(psd):
    """
    Apply ``stati.linear_regression_loglog`` to every spectrum in ``psd``.

    The regression is vectorised over all non-``frequency`` dimensions and
    restricted to the band ``[F_LOW, F_HIGH]`` with ``weights="f_inverse"``.

    Parameters
    ----------
    psd : xarray object with a ``frequency`` dimension/coordinate
        Power spectral densities as returned by ``xarray_dataset_welch``.

    Returns
    -------
    tuple
        Three outputs of ``xr.apply_ufunc``: two with a ``frequency``
        dimension (float) and one without it holding Python objects.
        Judging by their use in this notebook these are the frequencies,
        the fitted spectrum and the regression object — TODO confirm
        against ``linear_regression_loglog``.
    """
    return xr.apply_ufunc(
        stati.linear_regression_loglog,
        psd["frequency"],  # Input frequencies
        psd,  # Input spectrum
        input_core_dims=[["frequency"], ["frequency"]],
        output_core_dims=[["frequency"], ["frequency"], []],
        vectorize=True,
        output_dtypes=[float, float, object],
        kwargs=dict(
            f_low=F_LOW,
            f_high=F_HIGH,
            weights="f_inverse",
        ),
    )


# The same regression for the inputs and the Kalman states of both oceans.
frequencies_0, sponge_linear, sponge_regression = _fit_loglog(psd_sponge)
frequencies_1, oscillatory_linear, oscillatory_regression = _fit_loglog(
    psd_oscillatory
)
frequencies_2, sponge_kalman_linear, sponge_kalman_regression = _fit_loglog(
    psd_sponge_kalman
)
(
    frequencies_3,
    oscillatory_kalman_linear,
    oscillatory_kalman_regression,
) = _fit_loglog(psd_oscillatory_kalman)

In [None]:
# Assert that all frequencies are the same
frequency_coords = [
    frequencies_0["frequency"],
    frequencies_1["frequency"],
    frequencies_2["frequency"],
    frequencies_3["frequency"],
]
# Compare every coordinate with the first one. The previous loop only covered
# all entries because index -1 wrapped around to the last element.
reference = frequency_coords[0]
for other in frequency_coords[1:]:
    xr.testing.assert_equal(reference, other)

#### extract the slopes using xr.apply_ufunc

In [None]:
def get_slope(x: LinearRegression) -> float:
    """
    Return the first coefficient of a fitted regression, used as the slope.

    Parameters
    ----------
    x : LinearRegression
        Regression object whose ``coef_`` attribute can be indexed as
        ``coef_[0][0]``.

    Returns
    -------
    float
        The value stored at ``x.coef_[0][0]``.
    """
    coefficients = x.coef_
    return coefficients[0][0]


def _extract_slopes(regressions):
    """Apply ``get_slope`` element-wise to an xarray object of regression objects."""
    return xr.apply_ufunc(
        get_slope,
        regressions,
        vectorize=True,
        output_dtypes=[float],
    )


# Slopes for the inputs and the Kalman states of both oceans.
sponge_slopes = _extract_slopes(sponge_regression)
oscillatory_slopes = _extract_slopes(oscillatory_regression)
sponge_kalman_slopes = _extract_slopes(sponge_kalman_regression)
oscillatory_kalman_slopes = _extract_slopes(oscillatory_kalman_regression)

#### plot the slope

In [None]:
var = "latent"
da_slopes = sponge_kalman_slopes[var]

# Seed statistics are computed once instead of once per panel.
slopes_mean = da_slopes.mean(dim="seed")
slopes_median = da_slopes.median(dim="seed")

# One panel per tau0 value; the count follows the data instead of being
# fixed to 5. `squeeze=False` keeps the axes indexable for a single panel.
n_tau0 = da_slopes.sizes["tau0"]
fig, axs = plt.subplots(ncols=n_tau0, figsize=(8, 2), sharey=True, squeeze=False)
axs = axs[0]
for idx in range(n_tau0):
    axs[idx].hist(da_slopes.isel(tau0=idx), density=True, color=dark_color, alpha=0.2)
    axs[idx].axvline(
        slopes_mean.isel(tau0=idx), label="mean", color="r", linewidth=2
    )
    axs[idx].axvline(
        slopes_median.isel(tau0=idx),
        label="median",
        color="b",
        linewidth=2,
    )
    axs[idx].set_title(rf"$\tau_0$ = {da_slopes.tau0[idx].values / 365.25}")
    axs[idx].legend()

In [None]:
var = "latent"
da_slopes = sponge_kalman_slopes[var]
# Add a length-one `df` dimension so the result can be drawn as a 2D heatmap.
da_slopes = da_slopes.expand_dims(df=[0.115])
heatmap_kwargs = dict(
    xticklabels=da_slopes.tau0.values / 365.25,
    yticklabels=da_slopes.df.values,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=-4,
    vmax=0,
    cmap=cmap,
)
# The standard deviation is non-negative. Instead of plotting its negative to
# fit the range above, mirror both the range and the colormap: colours stay
# the same, but annotations and colour bar show the true positive values.
std_heatmap_kwargs = dict(heatmap_kwargs, vmin=0, vmax=4, cmap=cmap_r)

# Calculate mean and std
m = da_slopes.mean(dim="seed")
s = da_slopes.std(dim="seed")

fig, axs = plt.subplots(nrows=2, ncols=1, layout="constrained", figsize=(5, 3))
# plot mean (slopes are not rescaled, so the title carries no scale factor)
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(f"Mean of {number_of_runs_sponge} indep. runs")
# plot standard deviation
sns.heatmap(s, ax=axs[1], **std_heatmap_kwargs)
# Report the run count of the sponge experiment (was the oscillatory count).
axs[1].set_title(f"Std. of {number_of_runs_sponge} indep. runs")

for ax in axs:
    # Labels follow the tick definitions above: x is tau0, y is df.
    ax.set_xlabel(r"$\tau_0$ in y")
    ax.set_ylabel(r"$df$ in y")

title = f"Slope of linear regression in loglog space for {var}."
fig.suptitle(f"{ocean_name_sponge} | {title}")
save_name = f"slope_linear_regression_{var}"
save_fig(fig=fig, relative_path=Path(ocean_name_sponge) / (save_name + ".svg"))

In [None]:
var = "latent"
da_slopes = oscillatory_kalman_slopes[var]

# Seed statistics are computed once instead of once per panel.
slopes_mean = da_slopes.mean(dim="seed")
slopes_median = da_slopes.median(dim="seed")

# One histogram of the slopes over all seeds per (tau0, per0) combination.
# The grid size follows the data instead of being fixed to 5 x 5;
# `squeeze=False` keeps `axs` two-dimensional in every case.
fig, axs = plt.subplots(
    ncols=da_slopes.sizes["per0"],
    nrows=da_slopes.sizes["tau0"],
    figsize=(10, 10),
    sharey=True,
    sharex=True,
    layout="constrained",
    squeeze=False,
)
for i, t in enumerate(da_slopes.tau0):
    for j, p in enumerate(da_slopes.per0):
        sel_dict = dict(
            tau0=t,
            per0=p,
        )
        ax = axs[i, j]
        ax.hist(da_slopes.sel(sel_dict), density=True, color=dark_color, alpha=0.2)
        ax.axvline(slopes_mean.sel(sel_dict), label="mean", color="r", linewidth=2)
        ax.axvline(
            slopes_median.sel(sel_dict), label="median", color="b", linewidth=2
        )
        # Raw strings avoid invalid escape sequences in the LaTeX labels;
        # `per0` is labelled T_0 as in the other figures of this notebook.
        ax.set_title(
            rf"$\tau_0$ = {t.values / 365.25}"
            + "\n"
            + rf"$T_0$ = {p.values / 365.25}"
        )
        ax.legend()

In [None]:
var = "latent"
da_slopes = oscillatory_kalman_slopes[var]
heatmap_kwargs = dict(
    xticklabels=da_slopes.per0.values / 365.25,
    yticklabels=da_slopes.tau0.values / 365.25,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=-2.5,
    vmax=0,
    cmap=cmap,
)

mean_heatmap_kwargs = heatmap_kwargs.copy()
mean_heatmap_kwargs["cbar_kws"] = dict(label=r"Mean $m_{lat}$")
# The standard deviation is non-negative. Instead of plotting its negative to
# fit the slope range, mirror both the range and the colormap: colours stay
# the same, but annotations and colour bar show the true positive values.
std_heatmap_kwargs = dict(heatmap_kwargs, vmin=0, vmax=2.5, cmap=cmap_r)
std_heatmap_kwargs["cbar_kws"] = dict(label=r"Std. $m_{lat}$")

# Calculate mean and std
m = da_slopes.mean(dim="seed")
s = da_slopes.std(dim="seed")

# PLOT
fig, axs = plt.subplots(
    nrows=1, ncols=2, layout="constrained", sharex=True, sharey=True, figsize=(8.5, 4)
)
# Plot mean
sns.heatmap(
    m,
    ax=axs[0],
    **mean_heatmap_kwargs,
)
axs[0].set_title(f"Mean of {number_of_runs_oscillatory} indep. runs")
# Plot standard deviation
sns.heatmap(
    s,
    ax=axs[1],
    **std_heatmap_kwargs,
)
axs[1].set_title(f"Std. of {number_of_runs_oscillatory} indep. runs")

for ax in axs:
    ax.set_xlabel(r"$T_0$ in y")
axs[0].set_ylabel(r"$\tau_0$ in y")

title = f"m for final state of {var} variable " + r"($m_{lat}$)"
fig.suptitle(f"{ocean_name_oscillatory} | {title}")
save_name = f"slope_linear_regression_{var}"
# The regression band is encoded in the file name.
save_fig(
    fig=fig,
    relative_path=Path(ocean_name_oscillatory)
    / (save_name + f"_{F_LOW:.2E}_{F_HIGH:.2E}.svg"),
)

## Detailed insights

#### Plot PSD Spectra SST, DOT and latent

In [None]:
# Seed statistics of the spectra: median and 10 %/90 % quantiles for the
# latent variable, mean for the DOT and SST inputs.
latent_psd = psd_oscillatory_kalman.latent
m = latent_psd.median(dim="seed")
m_DOT = psd_oscillatory.DOT.mean(dim="seed")
m_SST = psd_oscillatory.SST.mean(dim="seed")
mini = latent_psd.quantile(0.1, dim="seed")
maxi = latent_psd.quantile(0.9, dim="seed")
frequency = psd_oscillatory.frequency

fig, axs = plt.subplots(
    nrows=5, ncols=5, layout="constrained", sharex=True, sharey=True, figsize=(10, 10)
)
# Rows iterate over tau0, columns over per0.
for i, t in enumerate(oscillatory_kalman.tau0):
    for j, p in enumerate(oscillatory_kalman.per0):
        panel = axs[i, j]
        select_dict = {"tau0": t, "per0": p}
        latent_panel = latent_psd.sel(select_dict)

        # Every single run as a very faint line.
        for s in psd_oscillatory_kalman.seed:
            panel.plot(
                frequency,
                latent_panel.sel(seed=s),
                linewidth=1,
                alpha=0.01,
                color=dark_color,
            )

        # Median of the latent spectrum with the 10-90 % band in the same colour.
        (median_line,) = panel.plot(
            frequency,
            m.sel(select_dict),
            alpha=1,
            linewidth=2,
            color=variables_color["latent"],
        )
        panel.fill_between(
            x=frequency,
            y1=mini.sel(select_dict),
            y2=maxi.sel(select_dict),
            color=median_line.get_color(),
            alpha=0.4,
        )

        # Mean spectra of the two inputs for comparison.
        for name, mean_psd in (("DOT", m_DOT), ("SST", m_SST)):
            panel.plot(
                frequency,
                mean_psd.sel(select_dict),
                alpha=0.75,
                linewidth=2,
                color=variables_color[name],
            )

# Parameter values are shown as labels on the outer axes only.
for i, t in enumerate(oscillatory_kalman.tau0):
    axs[i, 0].set_ylabel(rf"$\tau_0$ = {t.values / 365.25} y")
for j, p in enumerate(oscillatory_kalman.per0):
    axs[-1, j].set_xlabel(rf"$T_0$ = {p.values / 365.25} y")

# Log-log axes; set after plotting, with the lower y limit fixed at 1e-8.
for ax in axs.flatten():
    ax.set_yscale("log")
    ax.set_xscale("log")
    ax.set_ylim(ymin=10 ** (-8))

for suffix in ("pdf", "png", "svg"):
    save_fig(fig, relative_path=f"ALL_FREQUENCIES.{suffix}")

#### Plot lagged correlation SST, DOT

In [None]:
# Cross-correlation between DOT and SST for lags 0, 6, ..., below 48 * 12
# time steps; each result gets a length-one `lag` dimension for later merging.
lagged_correlation_DOT_SST = [
    crosscorr(
        ds1=oscillatory_input.DOT, ds2=oscillatory_input.SST, lag=lag, dim="time"
    )
    .expand_dims(lag=[lag])
    .rename("lagged_correlation_DOT_SST")
    for lag in np.arange(0, 48 * 12, 6)
]

In [None]:
# Merge the per-lag arrays into one dataset and pull out its single variable.
ds_lagged_correlation_DOT_SST = xr.merge(lagged_correlation_DOT_SST)
da_lagged_correlation_DOT_SST = ds_lagged_correlation_DOT_SST[
    "lagged_correlation_DOT_SST"
]

In [None]:
# Mean and standard deviation of the lagged correlation over all seeds.
m_crosscorr = da_lagged_correlation_DOT_SST.mean(dim="seed")
std_crosscorr = da_lagged_correlation_DOT_SST.std(dim="seed")

In [None]:
fig, axs = plt.subplots(
    nrows=5, ncols=5, layout="constrained", sharex=True, sharey=True, figsize=(10, 10)
)
# x axis: lag divided by 12 (presumably time steps -> years, matching fs = 12).
seed_lags = da_lagged_correlation_DOT_SST.lag / 12
mean_lags = m_crosscorr.lag / 12

# Rows iterate over tau0, columns over per0.
for i, t in enumerate(oscillatory_kalman.tau0):
    for j, p in enumerate(oscillatory_kalman.per0):
        ax = axs[i, j]
        select_dict = {"tau0": t, "per0": p}
        panel_corr = da_lagged_correlation_DOT_SST.sel(select_dict)

        # Every single run as a faint line.
        for s in da_lagged_correlation_DOT_SST.seed:
            ax.plot(
                seed_lags,
                panel_corr.sel(seed=s),
                color=dark_color,
                alpha=0.25,
            )
        # Seed mean with a band of one standard deviation.
        plot_state_with_probability(
            ax=ax,
            x_value=mean_lags,
            state=m_crosscorr.sel(select_dict),
            prob=std_crosscorr.sel(select_dict),
            line_kwargs=dict(marker="+"),
            stds=1,
        )

# Parameter values are shown as labels on the outer axes only.
for i, t in enumerate(oscillatory_kalman.tau0):
    axs[i, 0].set_ylabel(rf"$\tau_0$ = {t.values / 365.25} y")
for j, p in enumerate(oscillatory_kalman.per0):
    axs[-1, j].set_xlabel(rf"$T_0$ = {p.values / 365.25} y")


for suffix in ("pdf", "png", "svg"):
    save_fig(fig, relative_path=f"LAGGED_CORRELATION_ALL_FREQUENCIES.{suffix}")

#### Plot normalized Distributions SAT, SST, DOT and latent

In [None]:
# Normalise every input variable along time before building histograms.
current_ds = normalize(oscillatory_input.copy(), method="norm", dim="time")

# Bin edges: spacing `dx`, shifted by half a bin; `x_values` are the bin centres.
dx = 0.5
extent = 5
bin_values = np.arange(-extent, extent, dx) + dx / 2
x_values = (bin_values[:-1] + bin_values[1:]) / 2

# `np.histogram` is applied to each time series. Its two return values become
# `c` (densities, dimension "count") and `v` (bin edges, dimension "edge").
histogram_kwargs = dict(bins=bin_values, density=True)
c, v = xr.apply_ufunc(
    np.histogram,
    current_ds,  # time series to be binned
    input_core_dims=[["time"]],
    output_core_dims=[["count"], ["edge"]],
    vectorize=True,
    output_dtypes=[float, float],
    kwargs=histogram_kwargs,
)

# current_ds_mean = current_ds.mean(dim = "seed")
# current_ds_std = current_ds.std(dim = "seed")

In [None]:
var = "DOT"
# Median histogram over all seeds, drawn on top of the individual runs.
var_counts = c[var]
count_median = var_counts.median(dim="seed")


fig, axs = plt.subplots(
    nrows=5, ncols=5, layout="constrained", sharex=True, sharey=True, figsize=(10, 10)
)
# Rows iterate over tau0, columns over per0.
for i, t in enumerate(current_ds.tau0):
    for j, p in enumerate(current_ds.per0):
        ax = axs[i, j]
        select_dict = {"tau0": t, "per0": p}
        panel_counts = var_counts.sel(select_dict)

        # Every single run as a faint line.
        for s in current_ds.seed:
            ax.plot(
                x_values,
                panel_counts.sel(seed=s),
                color=dark_color,
                alpha=0.2,
            )
        ax.plot(
            x_values,
            count_median.sel(select_dict),
            marker="+",
            color=variables_color[var],
            alpha=1,
            linewidth=2,
        )

# Parameter values are shown as labels on the outer axes only.
for i, t in enumerate(oscillatory_kalman.tau0):
    axs[i, 0].set_ylabel(rf"$\tau_0$ = {t.values / 365.25} y")
for j, p in enumerate(oscillatory_kalman.per0):
    axs[-1, j].set_xlabel(rf"$T_0$ = {p.values / 365.25} y")


# Only raster and vector-web formats are stored for this figure (no pdf).
for suffix in ("png", "svg"):
    save_fig(fig, relative_path=f"DISTRIBUTION_{var}_ALL_FREQUENCIES.{suffix}")