In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import seaborn as sns
import xarray as xr
import yaml
from scipy import signal

# from sklearn.linear_model import LinearRegression
from kalman_reconstruction.custom_plot import (
    adjust_lightness,
    handler_map_alpha,
    plot_colors,
    plot_state_with_probability,
    set_custom_rcParams,
    symmetrize_axis,
)
from kalman_reconstruction.pipeline import from_standard_dataset
from kalman_reconstruction.statistics import crosscorr, normalize
from reconstruct_climate_indices.idealized_ocean import oscillatory_ocean, spunge_ocean
from reconstruct_climate_indices.statistics import linear_regression_loglog

In [None]:
# Plot styling: dark background with light-grey reference lines and text.
# Light-theme alternative: dark_color = "k", lightness_0 = 0.75, lightness_1 = 0.5
plt.style.use("dark_background")
dark_color = [0.7, 0.7, 0.7]  # reference lines, annotations and the N(0,1) curve
light_color = [0.2, 0.2, 0.2]  # face colour of the slope-annotation boxes
lightness_0 = 1.15  # lightness factor applied to the Welch-spectrum lines
lightness_1 = 1.5  # lightness factor applied to fitted lines and slope labels
set_custom_rcParams()
plt.rcParams["axes.grid"] = False
colors = [
    "#CC6677",
    "#6E9CB3",
    "#CA8727",
    "#44AA99",
    "#AA4499",
    "#D6BE49",
    "#A494F5",
]
plot_colors(colors)

# Colour used for each plotted quantity. Every value is a plain colour string.
variables_color = {
    "SAT": colors[0],
    "SST": colors[2],
    "DOT": colors[1],
    "latent": colors[3],
    # Previously `(colors[-1],)`: a 1-tuple, which matplotlib does not accept as a colour.
    "loglikelihood": colors[-1],
}

In [None]:
REPO_PATH = Path(".").resolve().parent
results_path = REPO_PATH / Path("results") / "Report" / "results" / "Example"
SAVE_FIGURES = True


def save_fig(fig, relative_path, *, base_path=None, enabled=None, **kwargs):
    """Save ``fig`` below the results directory when figure saving is enabled.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save (any object with a ``savefig`` method works).
    relative_path : str or pathlib.Path
        Target path relative to ``base_path``; its suffix selects the format.
    base_path : pathlib.Path, optional
        Root directory. Defaults to the notebook-level ``results_path``.
    enabled : bool, optional
        Whether to save at all. Defaults to the notebook-level ``SAVE_FIGURES``
        (read at call time, so toggling the flag later takes effect).
    **kwargs
        Forwarded unchanged to ``fig.savefig``.
    """
    if enabled is None:
        enabled = SAVE_FIGURES
    if not enabled:
        # Saving disabled: do not touch the file system at all.
        return
    root = results_path if base_path is None else Path(base_path)
    store_path = root / relative_path
    store_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(store_path, **kwargs)

In [None]:
# Reference parameters of the idealized-ocean experiments analysed below.
df = 0.115
tau0 = 10  # years
omega0 = 24  # years
lambda0 = np.pi * tau0 / 2  # years

for symbol, value in ((r"\lambda_0", lambda0), (r"\omega_0", omega0)):
    print(rf"${symbol}$ = {value}")

$\lambda_0$ = 15.707963267948966
$\omega_0$ = 24


In [None]:
# Sponge ocean: input data, Kalman-SEM output and run settings of one experiment.
run_name = "dapper-doe-978"
SubdataPath = "simplified_ocean_experiments"
RunPath = Path(".").resolve().parent / "data" / SubdataPath / run_name
InputPath, KalmanPath, SettingsPath = (
    RunPath / f"{run_name}{suffix}"
    for suffix in ("_input.nc", "_kalman.nc", "_setup.yaml")
)

experiments = xr.open_dataset(InputPath)
experiments_kalman = xr.open_dataset(KalmanPath)
experiments_kalman_states = from_standard_dataset(experiments_kalman)

# Run settings (YAML); not used further in the visible cells.
with open(SettingsPath, "r") as stream:
    settings = yaml.safe_load(stream)

# Parameter combination to analyse; tau0 is converted from years via * 365.25.
select_dict = {"tau0": tau0 * 365.25, "df": df}

sponge_data = experiments.sel(select_dict)
sponge_kalman_states = experiments_kalman_states.sel(select_dict)
# The variable in the Kalman output file is spelled "log_likelihod".
sponge_kalman_states["loglikelihood"] = experiments_kalman["log_likelihod"].sel(
    select_dict
)

In [None]:
# Oscillatory ocean: input data, Kalman-SEM output and run settings of one experiment.
run_name = "angry-swan-795"
SubdataPath = "simplified_ocean_experiments"
RunPath = Path(".").resolve().parent / "data" / SubdataPath / run_name
InputPath, KalmanPath, SettingsPath = (
    RunPath / f"{run_name}{suffix}"
    for suffix in ("_input.nc", "_kalman.nc", "_setup.yaml")
)

experiments = xr.open_dataset(InputPath)
experiments_kalman = xr.open_dataset(KalmanPath)
experiments_kalman_states = from_standard_dataset(experiments_kalman)

# Run settings (YAML); not used further in the visible cells.
with open(SettingsPath, "r") as stream:
    settings = yaml.safe_load(stream)

# Parameter combination to analyse; per0 and tau0 are converted from years via * 365.25.
select_dict = {"per0": omega0 * 365.25, "tau0": tau0 * 365.25}

oscillator_data = experiments.sel(select_dict)
oscillator_kalman_states = experiments_kalman_states.sel(select_dict)
# The variable in the Kalman output file is spelled "log_likelihod".
oscillator_kalman_states["loglikelihood"] = experiments_kalman["log_likelihod"].sel(
    select_dict
)

### Get the record length, time step and sampling frequency

In [None]:
def get_dt(l):
    """Return the time step shared by all datasets in ``l``.

    The step of each dataset is taken as ``time[1] - time[0]``.

    Parameters
    ----------
    l : iterable
        Objects with a ``time`` sequence of at least two values
        (e.g. xarray datasets).

    Returns
    -------
    The time step of the first dataset (same type as ``time[1] - time[0]``).

    Raises
    ------
    ValueError
        If ``l`` is empty or the datasets do not share the same time step.
    """
    datasets = list(l)
    if not datasets:
        raise ValueError("get_dt needs at least one dataset")
    first = datasets[0]
    dt = first.time[1] - first.time[0]
    # Compare every other dataset against the first one. Previously the index was
    # never incremented, so this check never ran and the last dataset's dt was returned.
    for position, current_dataset in enumerate(datasets[1:], start=1):
        current_dt = current_dataset.time[1] - current_dataset.time[0]
        if not bool(current_dt == dt):
            raise ValueError(
                f"dataset at position {position} has time step {current_dt}, "
                f"expected {dt}"
            )
    return dt


def get_T(l):
    """Return the record length shared by all datasets in ``l``, divided by 365.25.

    The length of each dataset is taken as ``time[-1] - time[0]``. The division
    by 365.25 converts to years, assuming ``time`` is in days (TODO confirm).

    Parameters
    ----------
    l : iterable
        Objects with a non-empty ``time`` sequence (e.g. xarray datasets).

    Returns
    -------
    ``(time[-1] - time[0]) / 365.25`` of the first dataset.

    Raises
    ------
    ValueError
        If ``l`` is empty or the datasets do not share the same record length.
    """
    datasets = list(l)
    if not datasets:
        raise ValueError("get_T needs at least one dataset")
    first = datasets[0]
    T = first.time[-1] - first.time[0]
    # Compare every other dataset against the first one. Previously the index was
    # never incremented, so this check never ran.
    for position, current_dataset in enumerate(datasets[1:], start=1):
        current_T = current_dataset.time[-1] - current_dataset.time[0]
        if not bool(current_T == T):
            raise ValueError(
                f"dataset at position {position} has length {current_T}, "
                f"expected {T}"
            )
    return T / 365.25


# These four datasets are meant to share time step and record length;
# get_dt / get_T are intended to check that.
datasets_to_check = [
    oscillator_data,
    oscillator_kalman_states,
    sponge_data,
    sponge_kalman_states,
]
dt = get_dt(datasets_to_check).values  # time step (fs below treats it as days)
print(dt)
years = np.round(get_T(datasets_to_check).values, decimals=2)  # record length
print(years)
time_steps = years * 365.25 / dt  # number of time steps in the record
fs = 365.25 / dt  # sampling frequency in 1/years
welch_window_width = 250  # years

30
985.54


In [None]:
# Log-likelihood of the Kalman-SEM fit per iteration, for both idealized oceans.
fig, ax = plt.subplots(nrows=1, ncols=1, layout="constrained", figsize=(6, 4))

for states, label in (
    (sponge_kalman_states, "Sponge Ocean"),
    (oscillator_kalman_states, "Oscillatory Ocean"),
):
    # "kalman_itteration" is the coordinate name stored in the data files.
    ax.plot(states["kalman_itteration"], states["loglikelihood"], label=label)

ax.set_xlabel("Iterations of Kalman-SEM")
ax.set_ylabel("Loglikelihood")
ax.legend(loc="lower right")
ax.set_title("Loglikelihood against Iterations of Kalman-SEM")

fig_name = f"Loglikelihood_{years}y"
save_fig(fig=fig, relative_path=fig_name + ".pdf")
save_fig(fig=fig, relative_path=fig_name + ".png")

In [None]:
# Histograms of the normalized observations / hidden component, compared to N(0, 1).
fig, axs = plt.subplots(
    nrows=1, ncols=2, layout="constrained", sharex=True, sharey=True
)

# Each series is drawn twice: a thick outline plus a faint fill of the same histogram.
stepfill_kwargs = dict(bins=31, histtype="stepfilled", density=True, alpha=0.1)
step_kwargs = dict(bins=31, histtype="step", density=True, alpha=1, linewidth=3)

axs_spunge = axs[0]
axs_oscill = axs[1]
axs_spunge.set_title("Sponge Ocean")
axs_oscill.set_title("Oscillatory Ocean")

# Variables shown per ocean (DOT only for the oscillatory ocean).
for ax, dataset, variables in (
    (axs_spunge, sponge_data, ["SAT", "SST"]),
    (axs_oscill, oscillator_data, ["SAT", "SST", "DOT"]),
):
    for var in variables:
        # Normalize once and reuse it for both the outline and the fill.
        values = normalize(dataset[var], method="mean")
        ax.hist(values, label=var, color=variables_color[var], **step_kwargs)
        ax.hist(values, color=variables_color[var], **stepfill_kwargs)

# Standard normal reference curve N(0, 1).
mu = 0
variance = 1
width = 4  # half-width of the plotted range in standard deviations
sigma = np.sqrt(variance)
x = np.linspace(mu - width * sigma, mu + width * sigma, 100)
for ax in axs.flatten():
    ax.plot(
        x,
        stats.norm.pdf(x, mu, sigma),
        color=dark_color,
        label=r"$\mathcal{N}(0,1)$",
        linewidth=4,
    )
    ax.legend()
    ax.set_xlabel("Normalized Values")
    symmetrize_axis(axes=ax, axis="x")
axs[0].set_ylabel("Probability Density")

fig.suptitle("Normalized PDD Observations and hidden Component")

fig_name = f"Distribution_{years}y"
save_fig(fig=fig, relative_path=fig_name + ".pdf")
save_fig(fig=fig, relative_path=fig_name + ".png")

In [None]:
# dt = 30                     # days
# fs = 365.25/dt              # 1/years
# welch_window_width = 250    # years

# fig, axs = plt.subplots(nrows = 2, ncols = 2, layout="constrained", sharex=True, sharey=True)
# kwargs = dict(
#     linestyle = "-",
#     linewidth = "1",
#     marker = ".",
# )

# idx = 0 # x index to plot the text at

# axs_spunge = axs[:,0]
# axs_oscill = axs[:,1]
# axs_spunge[0].set_title('Sponge Ocean')
# axs_oscill[0].set_title('Oscillatory Ocean')

# # ------------------------------
# # plot SAT
# # ------------------------------
# for current_dataset, ax in zip(
#     [sponge_data, oscillator_data],
#     [axs_spunge[0],axs_oscill[0]]
#     ):
#     x = current_dataset["SST"].values
#     frequencies, spectrum = signal.welch(
#         x = x,
#         fs=fs,
#         window="hann",
#         nperseg = len(x)
#     )
#     h = ax.loglog(
#         frequencies,
#         spectrum,
#         label = "SAT",
#         alpha = 0.7,
#         color = SAT_color,
#     )
#     color = adjust_lightness(h[0].get_color(), lightness_0)
#     frequencies, spectrum = signal.welch(
#         x = x,
#         fs=fs,
#         window="hann",
#         nperseg = int(welch_window_width*fs)
#     )
#     ax.loglog(
#         frequencies,
#         spectrum,
#         label = "SAT welch",
#         color = color
#     )
#     frequencies_linear, spectrum_linear, regression = linear_regression_loglog(
#         frequencies=frequencies,
#         spectrum=spectrum,
#         weights="f_inverse"
#     )
#     slope = regression.coef_[0,0]
#     color = adjust_lightness(color, lightness_1)
#     ax.loglog(
#         frequencies_linear,
#         spectrum_linear,
#         color = color,
#     )
#     ax.text(frequencies_linear[idx],
#             spectrum_linear[idx],
#             f"m={slope:.2f}",
#             ha='right',
#             va='bottom',
#             bbox=dict(facecolor=light_color, edgecolor="None", alpha=0.25),
#             color = adjust_lightness(color, lightness_1),)

# # ------------------------------
# # plot SST
# # ------------------------------
# for current_dataset, ax, f_low in zip(
#     [sponge_data, oscillator_data],
#     [axs_spunge[1],axs_oscill[1]],
#     [1/lambda0, 1/omega0]
#     ):
#     x = current_dataset["SST"].values
#     frequencies, spectrum = signal.welch(
#         x = x,
#         fs=fs,
#         window="hann",
#         nperseg = len(x)
#     )
#     h = ax.loglog(
#         frequencies,
#         spectrum,
#         label = "SST",
#         alpha = 0.7,
#         color = SST_color
#     )
#     color = adjust_lightness(h[0].get_color(), lightness_0)
#     frequencies, spectrum = signal.welch(
#         x = x,
#         fs=fs,
#         window="hann",
#         nperseg = int(welch_window_width*fs)
#     )
#     ax.loglog(
#         frequencies,
#         spectrum,
#         label = "SST welch",
#         color = color
#     )
#     frequencies_linear, spectrum_linear, regression = linear_regression_loglog(
#         frequencies=frequencies,
#         spectrum=spectrum,
#         weights="f_inverse",
#         f_low = f_low,
#     )
#     slope = regression.coef_[0,0]
#     color = adjust_lightness(color, lightness_1)
#     ax.loglog(
#         frequencies_linear,
#         spectrum_linear,
#         color = color,
#     )
#     ax.text(frequencies_linear[idx],
#             spectrum_linear[idx],
#             f"m={slope:.2f}",
#             ha='right',
#             va='bottom',
#             bbox=dict(facecolor=light_color, edgecolor="None", alpha=0.25),
#             color = adjust_lightness(color, lightness_1),)

# # ------------------------------
# # plot DOT
# # ------------------------------

# for current_dataset, ax, f_low in zip(
#     [oscillator_data],
#     [axs_oscill[1]],
#     [1/omega0]
#     ):
#     x = current_dataset["DOT"].values
#     frequencies, spectrum = signal.welch(
#         x = x,
#         fs=fs,
#         window="hann",
#         nperseg = len(x)
#     )
#     h = ax.loglog(
#         frequencies,
#         spectrum,
#         label = "DOT",
#         alpha = 0.7,
#         color = DOT_color
#     )
#     color = adjust_lightness(h[0].get_color(), lightness_0)
#     frequencies, spectrum = signal.welch(
#         x = x,
#         fs=fs,
#         window="hann",
#         nperseg = int(welch_window_width*fs)
#     )
#     ax.loglog(
#         frequencies,
#         spectrum,
#         label = "DOT welch",
#         color = color
#     )
#     frequencies_linear, spectrum_linear, regression = linear_regression_loglog(
#         frequencies=frequencies,
#         spectrum=spectrum,
#         weights="f_inverse",
#         f_low = f_low,
#     )
#     slope = regression.coef_[0,0]
#     color = adjust_lightness(color, lightness_1)
#     ax.loglog(
#         frequencies_linear,
#         spectrum_linear,
#         color = color,
#     )
#     ax.text(frequencies_linear[idx],
#             spectrum_linear[idx],
#             f"m={slope:.2f}",
#             ha='right',
#             va='bottom',
#             bbox=dict(facecolor=light_color, edgecolor="None", alpha=0.25),
#             color = adjust_lightness(color, lightness_1),)

# # ------------------------------
# # plot latent
# # ------------------------------

# for current_dataset, ax, f_low in zip(
#     [sponge_kalman_states, oscillator_kalman_states],
#     [axs_spunge[1],axs_oscill[1]],
#     [1/omega0, 1/omega0]
#     ):
#     x = current_dataset["latent"].values
#     frequencies, spectrum = signal.welch(
#         x = x,
#         fs=fs,
#         window="hann",
#         nperseg = len(x)
#     )
#     h = ax.loglog(
#         frequencies,
#         spectrum,
#         label = "latent",
#         alpha = 0.7,
#         color = latent_color
#     )
#     color = adjust_lightness(h[0].get_color(), lightness_0)
#     frequencies, spectrum = signal.welch(
#         x = x,
#         fs=fs,
#         window="hann",
#         nperseg = int(welch_window_width*fs)
#     )
#     ax.loglog(
#         frequencies,
#         spectrum,
#         label = "latent welch",
#         color = color
#     )
#     frequencies_linear, spectrum_linear, regression = linear_regression_loglog(
#         frequencies=frequencies,
#         spectrum=spectrum,
#         weights="f_inverse",
#         f_low = f_low,
#     )
#     slope = regression.coef_[0,0]
#     color = adjust_lightness(color, lightness_1)
#     ax.loglog(
#         frequencies_linear,
#         spectrum_linear,
#         color = color,
#     )
#     ax.text(frequencies_linear[idx],
#             spectrum_linear[idx],
#             f"m={slope:.2f}",
#             ha='right',
#             va='bottom',
#             bbox=dict(facecolor=light_color, edgecolor="None", alpha=0.25),
#             color = adjust_lightness(color, lightness_1),)


# for ax in axs.flatten() :
#     ax.axvline(1/lambda0, color = dark_color, linestyle = ":") #, label = fr"$\tau_0$")
#     ax.text(1/lambda0,
#             10**(-6.5),
#             r"$\pi/(2\tau_0)$",
#             ha='left',
#             va='bottom',
#             c=dark_color,
#             rotation = 90
#     )
#     ax.axvline(1/omega0, color = dark_color, linestyle = "--") #, label = fr"$\omega_0$")
#     ax.text(1/omega0,
#             10**(-6.5),
#             r"$1/\omega_0$",
#             ha='right',
#             va='bottom',
#             c=dark_color,
#             rotation = 90
#     )
#     ax.axvline(1/welch_window_width, color = dark_color, linestyle = "-") #, label = f"welch window.:\n{welch_window_width}")
#     ax.text(1/welch_window_width,
#             10**(-6.5),
#             r"$f_{welch}$",
#             ha='right',
#             va='bottom',
#             c=dark_color,
#             rotation = 90
#     )
#     ax.legend(ncol=2,labelspacing=0.01, loc="lower left")

#     ax.set_xlim(left=10**(-3), right=10**(1))

# save_fig(fig, relative_path="1000y_run.pdf")

In [None]:
# Power spectra (full-record periodogram, Welch average and a log-log power-law fit)
# of observations, hidden component and latent variable for both oceans.
fig, axs = plt.subplots(
    nrows=2, ncols=2, layout="constrained", sharex=True, sharey=True
)

idx = 0  # index into the fitted line at which the slope annotation is placed

axs_spunge = axs[:, 0]
axs_oscill = axs[:, 1]
axs_spunge[0].set_title("Sponge Ocean")
axs_oscill[0].set_title("Oscillatory Ocean")


def _plot_spectrum(ax, x, var, f_low=None, text_va="bottom"):
    """Draw the spectrum of ``x`` (full record and Welch-averaged) plus a log-log fit.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    x : numpy.ndarray
        Time series sampled at the notebook-level rate ``fs`` (samples per year).
    var : str
        Variable name, used for the legend labels and the ``variables_color`` lookup.
    f_low : float, optional
        Lowest frequency passed to ``linear_regression_loglog``. When ``None``
        the argument is omitted, so that function's own default applies.
    text_va : str
        Vertical alignment of the slope annotation.
    """
    color = variables_color[var]
    fit_color = adjust_lightness(color, lightness_1)

    # One Hann window over the whole record: finest frequency resolution, noisiest.
    frequencies, spectrum = signal.welch(x=x, fs=fs, window="hann", nperseg=len(x))
    ax.loglog(frequencies, spectrum, label=f"{var}", color=color, alpha=0.7)

    # Welch average over segments of `welch_window_width` years.
    frequencies, spectrum = signal.welch(
        x=x, fs=fs, window="hann", nperseg=int(welch_window_width * fs)
    )
    ax.loglog(
        frequencies,
        spectrum,
        label=f"{var} welch",
        color=adjust_lightness(color, lightness_0),
    )

    regression_kwargs = dict(
        frequencies=frequencies, spectrum=spectrum, weights="f_inverse"
    )
    if f_low is not None:
        regression_kwargs["f_low"] = f_low
    frequencies_linear, spectrum_linear, regression = linear_regression_loglog(
        **regression_kwargs
    )
    slope = regression.coef_[0, 0]
    ax.loglog(frequencies_linear, spectrum_linear, color=fit_color)
    ax.text(
        frequencies_linear[idx],
        spectrum_linear[idx],
        f"m={slope:.2f}",
        ha="right",
        va=text_va,
        bbox=dict(facecolor=light_color, edgecolor="None", alpha=0.25),
        color=fit_color,
    )


# ------------------
# plot SAT (regression uses linear_regression_loglog's default f_low)
# ------------------
var = "SAT"
for current_dataset, ax in zip(
    [sponge_data, oscillator_data], [axs_spunge[0], axs_oscill[0]]
):
    _plot_spectrum(ax, current_dataset[var].values, var)

# ------------------
# plot SST
# ------------------
var = "SST"
for current_dataset, ax, f_low in zip(
    [sponge_data, oscillator_data],  # datasets
    [axs_spunge[1], axs_oscill[1]],  # axes object to plot into
    [1 / lambda0, 1 / omega0],  # lowest frequency to use for linear regression
):
    # Use the per-ocean f_low (it was previously hard-coded to 1 / lambda0 for both).
    _plot_spectrum(ax, current_dataset[var].values, var, f_low=f_low)

# ------------------
# plot DOT (oscillatory ocean only)
# ------------------
var = "DOT"
for current_dataset, ax, f_low in zip(
    [oscillator_data],  # datasets
    [axs_oscill[1]],  # axes object to plot into
    [1 / omega0],  # lowest frequency to use for linear regression
):
    _plot_spectrum(ax, current_dataset[var].values, var, f_low=f_low)

# ------------------------------
# plot latent
# ------------------------------
latent_f_low = 0
var = "latent"
for current_dataset, ax, f_low in zip(
    [sponge_kalman_states, oscillator_kalman_states],
    [axs_spunge[1], axs_oscill[1]],
    [latent_f_low, latent_f_low],
):
    _plot_spectrum(ax, current_dataset[var].values, var, f_low=f_low, text_va="top")


# Reference frequencies: 1/lambda0 (= 2/(pi*tau0)), 1/omega0 and the Welch window.
for ax in axs.flatten():
    ax.axvline(1 / lambda0, color=dark_color, linestyle=":")
    ax.text(
        1 / lambda0,
        10 ** (-6.5),
        r"$2/(\pi\tau_0)$",
        ha="left",
        va="bottom",
        c=dark_color,
        rotation=90,
    )
    ax.axvline(1 / omega0, color=dark_color, linestyle="--")
    ax.text(
        1 / omega0,
        10 ** (-6.5),
        r"$1/\omega_0$",
        ha="right",
        va="bottom",
        c=dark_color,
        rotation=90,
    )
    ax.axvline(1 / welch_window_width, color=dark_color, linestyle="-")
    ax.text(
        1 / welch_window_width,
        10 ** (-6.5),
        r"$f_{welch}$",
        ha="right",
        va="bottom",
        c=dark_color,
        rotation=90,
    )
    ax.legend(ncol=2, labelspacing=0.01, loc="lower left")

fig.suptitle(
    "Power Density Spectrum of Latent variable, Observations and hidden Component"
)

fig_name = f"Frequency_{years}y"
save_fig(fig=fig, relative_path=fig_name + ".pdf")
save_fig(fig=fig, relative_path=fig_name + ".png")

In [None]:
# Scatter of each (normalized) true state against the (normalized) latent variable,
# with the correlation coefficient in the legend and as marker opacity.
fig, axs = plt.subplots(
    nrows=1,
    ncols=2,
    layout="constrained",
)
for ax, current_kalman_states, current_data, title in zip(
    axs,
    [sponge_kalman_states, oscillator_kalman_states],
    [sponge_data, oscillator_data],
    ["Sponge Ocean", "Oscillatory Ocean"],
):
    # Use this panel's own Kalman states. Previously the global `current_dataset`
    # leaked from an earlier cell was used, i.e. the same states in both panels.
    reconst = normalize(current_kalman_states, method="11")
    truth = normalize(current_data, method="11")
    for var in truth.data_vars:
        if var not in variables_color:
            # Only variables with an assigned colour are plotted.
            continue
        try:
            corr = float(xr.corr(truth[var], reconst["latent"]))
            ax.scatter(
                truth[var],
                reconst["latent"],
                marker=".",
                color=variables_color[var],
                alpha=abs(corr) / 1.2,
                label=f"{var} : {corr:.2f}",
            )
        except Exception as err:
            # Best effort: report and skip variables that cannot be correlated/plotted.
            print(f"Skipping {var} in {title}: {err}")
    ax.set_title(title)
    ax.set_xlabel("true state")
    ax.set_ylabel("latent variable")
    ax.legend(
        markerscale=3,
        handler_map=handler_map_alpha(),
    )
    symmetrize_axis(axes=ax, axis=0)
    symmetrize_axis(axes=ax, axis=1)


fig.suptitle("Latent Variable against true States (Observations and Hidden)")
fig_name = f"Truth_against_Latent_{years}y"
save_fig(fig=fig, relative_path=fig_name + ".pdf")
save_fig(fig=fig, relative_path=fig_name + ".png")

In [None]:
# Histograms of the normalized latent variable for both oceans, compared to N(0, 1).
fig, axs = plt.subplots(
    nrows=1, ncols=2, layout="constrained", sharex=True, sharey=True
)

# Each series is drawn twice: a thick outline plus a faint fill of the same histogram.
stepfill_kwargs = dict(bins=31, histtype="stepfilled", density=True, alpha=0.1)
step_kwargs = dict(bins=31, histtype="step", density=True, alpha=1, linewidth=3)

axs_spunge = axs[0]
axs_oscill = axs[1]
axs_spunge.set_title("Sponge Ocean")
axs_oscill.set_title("Oscillatory Ocean")

var = "latent"
for ax, kalman_states in (
    (axs_spunge, sponge_kalman_states),
    (axs_oscill, oscillator_kalman_states),
):
    # Normalize once and reuse it for both the outline and the fill.
    values = normalize(kalman_states[var], method="mean")
    ax.hist(values, label=var, color=variables_color[var], **step_kwargs)
    ax.hist(values, color=variables_color[var], **stepfill_kwargs)

# Standard normal reference curve N(0, 1).
mu = 0
variance = 1
width = 4  # half-width of the plotted range in standard deviations
sigma = np.sqrt(variance)
x = np.linspace(mu - width * sigma, mu + width * sigma, 100)
for ax in axs.flatten():
    ax.plot(
        x,
        stats.norm.pdf(x, mu, sigma),
        color=dark_color,
        label=r"$\mathcal{N}(0,1)$",
        linewidth=4,
    )
    ax.legend()
    ax.set_xlabel("Normalized Values")
    symmetrize_axis(axes=ax, axis="x")
axs[0].set_ylabel("Probability Density")
fig.suptitle("Normalized PDD latent variable")

fig_name = f"Distribution_{years}y_latent"
save_fig(fig=fig, relative_path=fig_name + ".pdf")
save_fig(fig=fig, relative_path=fig_name + ".png")

# Evaluation 

In [None]:
def all_choords_as_dims(ds):
    """Try to turn every coordinate of ``ds`` into a dimension via ``expand_dims``.

    Coordinates for which ``expand_dims`` raises ``ValueError`` (for example
    because a dimension of that name already exists) are left unchanged; any
    other error propagates.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Object whose coordinates should be promoted.

    Returns
    -------
    The object after all successful ``expand_dims`` calls.
    """
    # Snapshot the names first: `ds` is rebound inside the loop.
    for dim in list(ds.coords):
        try:
            ds = ds.expand_dims(dim)
        except ValueError:
            # xarray refuses names that already exist as dimensions or
            # non-scalar variables; keep those as they are.
            continue
    return ds


import seaborn as sns

## Sponge Ocean

In [None]:
ocean = "Sponge Ocean"

# Evaluation runs live in <repo>/data/Evaluation-Idealized-Ocean/<run>/<run>_{kalman,input}.nc
# (previously hard-coded as absolute Windows paths of one machine).
EVALUATION_PATH = REPO_PATH / "data" / "Evaluation-Idealized-Ocean"
evaluation_runs = ["flawless-loon-25", "polite-eel-349"]

test = xr.merge(
    [
        xr.open_dataset(EVALUATION_PATH / run / f"{run}_kalman.nc")
        for run in evaluation_runs
    ]
)
test_input = xr.merge(
    [
        xr.open_dataset(EVALUATION_PATH / run / f"{run}_input.nc")
        for run in evaluation_runs
    ]
)

test_kalman_states = from_standard_dataset(ds=test, dim="time")

number_of_runs = len(test.seed)

In [None]:
# Log-likelihood per Kalman-SEM iteration: mean over seeds with a shaded spread,
# one curve per tau0. Same band width as the oscillatory-ocean cell.
n_stds = 2  # value passed as `stds` to plot_state_with_probability
fig, ax = plt.subplots(1, 1)
for t in test.tau0:
    select_dict = dict(tau0=t)
    log_likelihood = test.sel(select_dict)["log_likelihod"]
    m = log_likelihood.mean(dim="seed")
    s = log_likelihood.std(dim="seed")
    plot_state_with_probability(
        ax=ax, x_value=test["kalman_itteration"], state=m, prob=s, stds=n_stds
    )
ax.set_xlabel("Iterations of Kalman-SEM")
ax.set_ylabel("Loglikelihood")
ax.set_title(
    f"Solid as mean and shaded as {n_stds} std based on {number_of_runs} random runs for each parameter set"
)
fig.suptitle(f"{ocean} | Log likelihood evolution over Kalman iterations.");

Text(0.5, 0.98, 'Sponge Ocean | Log likelihood evolution over kalman itterations.')

In [None]:
# Log-likelihood increase from the first to the last Kalman-SEM iteration.
ds_lli = test["log_likelihod"].isel(kalman_itteration=-1) - test["log_likelihod"].isel(
    kalman_itteration=0
)
# Add a length-1 `df` dimension (df = 0.115) so the heatmap is 2-D: rows df, columns tau0.
ds_lli = ds_lli.expand_dims(df=[0.115])

# PLOT
heatmap_kwargs = dict(
    xticklabels=ds_lli.tau0.values / 365.25,  # tau0 coordinate / 365.25 -> years
    yticklabels=ds_lli.df.values,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=2.6,
)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
# Explicit axis order so the data rows/columns line up with the tick labels.
m = (np.abs(ds_lli).mean(dim="seed") * 10 ** (-3)).transpose("df", "tau0")
s = (np.abs(ds_lli).std(dim="seed") * 10 ** (-3)).transpose("df", "tau0")
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(rf"Mean of {number_of_runs} random runs ($\times 1000$)")
# Relative spread across seeds (coefficient of variation) on a fixed 0..1 scale.
sns.heatmap(s / m, ax=axs[1], **{**heatmap_kwargs, "vmin": 0, "vmax": 1})
axs[1].set_title(f"Std./Mean of {number_of_runs} random runs")

for ax in axs:
    ax.set_xlabel(r"$\tau_0$ in y")
    # df tick values are shown unconverted (unlike tau0), so no unit is attached.
    ax.set_ylabel(r"$df$")

fig.suptitle(f"{ocean} | Log likelihood increase over Kalman iterations.");

Text(0.5, 0.98, 'Sponge Ocean | Log likelihood increase over kalman itterations.')

In [None]:
# Correlation between the latent variable and an observed variable, per parameter set.
var = "SAT"
ds_corr = crosscorr(
    ds1=test_kalman_states["latent"],
    ds2=test_input[var],
    dim="time",
)
# Add a length-1 `df` dimension (df = 0.115) so the heatmap is 2-D: rows df, columns tau0.
ds_corr = ds_corr.expand_dims(df=[0.115])

# PLOT
heatmap_kwargs = dict(
    xticklabels=ds_corr.tau0.values / 365.25,  # tau0 coordinate / 365.25 -> years
    yticklabels=ds_corr.df.values,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=1,
)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
# Explicit axis order so the data rows/columns line up with the tick labels.
m = np.abs(ds_corr).mean(dim="seed").transpose("df", "tau0")
s = np.abs(ds_corr).std(dim="seed").transpose("df", "tau0")
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(f"Mean of {number_of_runs} random runs")
# Relative spread across seeds (coefficient of variation); heatmap_kwargs already fix 0..1.
sns.heatmap(s / m, ax=axs[1], **heatmap_kwargs)
axs[1].set_title(f"Std./Mean of {number_of_runs} random runs")

for ax in axs:
    ax.set_xlabel(r"$\tau_0$ in y")
    # The rows are the `df` coordinate (previously mislabelled as omega_0).
    ax.set_ylabel(r"$df$")

fig.suptitle(f"{ocean} | Correlation Coefficient {var} to latent.");

Text(0.5, 0.98, 'Sponge Ocean | Correlation Coefficient SAT to latent.')

## Oscillatory Ocean

In [None]:
ocean = "Oscillatory Ocean"

# Evaluation runs live in <repo>/data/Evaluation-Idealized-Ocean/<run>/<run>_{kalman,input}.nc
# (previously hard-coded as absolute Windows paths of one machine).
EVALUATION_PATH = REPO_PATH / "data" / "Evaluation-Idealized-Ocean"
evaluation_runs = ["dapper-fox-131", "respected-fowl-948"]

test = xr.merge(
    [
        xr.open_dataset(EVALUATION_PATH / run / f"{run}_kalman.nc")
        for run in evaluation_runs
    ]
)
test_input = xr.merge(
    [
        xr.open_dataset(EVALUATION_PATH / run / f"{run}_input.nc")
        for run in evaluation_runs
    ]
)

test_kalman_states = from_standard_dataset(ds=test, dim="time")

number_of_runs = len(test.seed)

In [None]:
# Log-likelihood per Kalman-SEM iteration: mean over seeds with a shaded spread,
# one curve per (tau0, per0) combination.
n_stds = 2  # value passed as `stds` to plot_state_with_probability
fig, ax = plt.subplots(1, 1)
for t in test.tau0:
    for p in test.per0:
        select_dict = dict(tau0=t, per0=p)
        log_likelihood = test.sel(select_dict)["log_likelihod"]
        m = log_likelihood.mean(dim="seed")
        s = log_likelihood.std(dim="seed")
        plot_state_with_probability(
            ax=ax, x_value=test["kalman_itteration"], state=m, prob=s, stds=n_stds
        )
ax.set_xlabel("Iterations of Kalman-SEM")
ax.set_ylabel("Loglikelihood")
ax.set_title(
    f"Solid as mean and shaded as {n_stds} std based on {number_of_runs} random runs for each parameter set"
)
fig.suptitle(f"{ocean} | Log likelihood evolution over Kalman iterations.");

Text(0.5, 0.98, 'Oscillatory Ocean | Log likelihood evolution over kalman itterations.')

In [None]:
# Log-likelihood increase from the first to the last Kalman-SEM iteration.
ds_lli = test["log_likelihod"].isel(kalman_itteration=-1) - test["log_likelihod"].isel(
    kalman_itteration=0
)

# PLOT
heatmap_kwargs = dict(
    xticklabels=ds_lli.tau0.values / 365.25,  # tau0 coordinate / 365.25 -> years
    yticklabels=ds_lli.per0.values / 365.25,  # per0 coordinate / 365.25 -> years
    square=True,
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=2.6,
)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
# Explicit axis order (rows per0, columns tau0) so the data match the tick labels.
m = (np.abs(ds_lli).mean(dim="seed") * 10 ** (-3)).transpose("per0", "tau0")
s = (np.abs(ds_lli).std(dim="seed") * 10 ** (-3)).transpose("per0", "tau0")
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(rf"Mean of {number_of_runs} random runs ($\times 1000$)")
# Relative spread across seeds (coefficient of variation) on a fixed 0..1 scale.
sns.heatmap(s / m, ax=axs[1], **{**heatmap_kwargs, "vmin": 0, "vmax": 1})
axs[1].set_title(f"Std./Mean of {number_of_runs} random runs")

for ax in axs:
    ax.set_xlabel(r"$\tau_0$ in y")
    # The rows are per0, i.e. omega_0 (previously mislabelled as df).
    ax.set_ylabel(r"$\omega_0$ in y")

fig.suptitle(f"{ocean} | Log likelihood increase over Kalman iterations.");

Text(0.5, 0.98, 'Oscillatory Ocean | Log likelihood increase over kalman itterations.')

In [None]:
# Correlation between the latent variable and an observed variable, per parameter set.
var = "DOT"
ds_corr = crosscorr(
    ds1=test_kalman_states["latent"],
    ds2=test_input[var],
    dim="time",
)

# PLOT
heatmap_kwargs = dict(
    xticklabels=ds_corr.tau0.values / 365.25,  # tau0 coordinate / 365.25 -> years
    yticklabels=ds_corr.per0.values / 365.25,  # per0 coordinate / 365.25 -> years
    square=True,
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=1,
)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
# Explicit axis order (rows per0, columns tau0) so the data match the tick labels.
# The labels used to be np.flip-ped while the data were not, mislabelling every cell.
m = np.abs(ds_corr).mean(dim="seed").transpose("per0", "tau0").values
s = np.abs(ds_corr).std(dim="seed").transpose("per0", "tau0").values
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(f"Mean of {number_of_runs} random runs")
# Relative spread across seeds (coefficient of variation); heatmap_kwargs already fix 0..1.
sns.heatmap(s / m, ax=axs[1], **heatmap_kwargs)
axs[1].set_title(f"Std./Mean of {number_of_runs} random runs")

for ax in axs:
    ax.set_xlabel(r"$\tau_0$ in y")
    ax.set_ylabel(r"$\omega_0$ in y")

fig.suptitle(f"{ocean} | Correlation Coefficient {var} to latent.");

Text(0.5, 0.98, 'Oscillatory Ocean | Correlation Coefficient DOT to latent.')

In [None]:
# One parameter combination: tau0 = 15 y and per0 = 24 y, both multiplied by 365.25.
select_dict = {"tau0": 15 * 365.25, "per0": 24 * 365.25}
oscillator_single = test_kalman_states.sel(select_dict)

In [None]:
# Welch spectra (plus log-log fit) of the latent variable for every seed of the
# single parameter combination selected above.
# Local settings: do NOT overwrite the notebook-wide `fs` / `welch_window_width`.
# NOTE(review): 12 samples per year assumes monthly sampling of the evaluation runs;
# the earlier spectrum cell uses fs = 365.25 / dt instead. Confirm.
fs_eval = 12  # samples per year
welch_window_width_eval = 100  # years

fig, ax = plt.subplots(nrows=1, ncols=1, layout="constrained")
idx = 0  # index into the fitted line at which the slope annotation is placed

var = "latent"
color = variables_color[var]
fit_color = adjust_lightness(color, lightness_1)
for seed in oscillator_single.seed:
    x = oscillator_single.sel(seed=seed)[var].values
    frequencies, spectrum = signal.welch(
        x=x,
        fs=fs_eval,
        window="hann",
        nperseg=int(welch_window_width_eval * fs_eval),
    )
    ax.loglog(
        frequencies,
        spectrum,
        label=f"{var} welch",
        color=adjust_lightness(color, lightness_0),
        alpha=0.5,
    )
    frequencies_linear, spectrum_linear, regression = linear_regression_loglog(
        frequencies=frequencies, spectrum=spectrum, weights="f_inverse"
    )
    slope = regression.coef_[0, 0]
    ax.loglog(frequencies_linear, spectrum_linear, color=fit_color, alpha=0.5)
    ax.text(
        frequencies_linear[idx],
        spectrum_linear[idx],
        f"m={slope:.2f}",
        ha="right",
        va="bottom",
        bbox=dict(facecolor=light_color, edgecolor="None", alpha=0.25),
        color=fit_color,
    )

ax.set_xlim(left=10 ** (-4))
ax.set_xlabel("Frequency in 1/y")
ax.set_ylabel("Power spectral density")
ax.set_title(f"Welch spectra of the {var} variable for each seed");

(0.0001, 8.261498206262935)

In [None]:
# Display the Kalman states of the single parameter combination selected above.
oscillator_single