In [None]:
import numpy as np
import xarray as xr
import yaml
from pathlib import Path
from kalman_reconstruction.custom_plot import (
    set_custom_rcParams,
    adjust_lightness,
    plot_colors,
    symmetrize_axis,
    handler_map_alpha,
    symmetrize_axis,
    plot_state_with_probability,
)
from reconstruct_climate_indices.idealized_ocean import spunge_ocean, oscillatory_ocean
from reconstruct_climate_indices.statistics import (
    linear_regression_loglog,
    power_density_spectrum,
    xarray_dataset_welch,
    xarray_dataarray_welch,
)
from kalman_reconstruction.pipeline import from_standard_dataset
from kalman_reconstruction.statistics import normalize, crosscorr
import matplotlib.pyplot as plt

# from sklearn.linear_model import LinearRegression
from scipy import signal

In [None]:
# Global plot styling shared by every figure in this notebook.
# Values that were used with the default (light) style, kept for reference:
# dark_color = "k"
# lightness_0 = 0.75
# lightness_1 = 0.5
plt.style.use("dark_background")
dark_color = [0.7, 0.7, 0.7]
light_color = [0.2, 0.2, 0.2]
lightness_0 = 1.15
lightness_1 = 1.5
# Project-specific rcParams first, then switch the grid off on top of them.
set_custom_rcParams()
plt.rcParams["axes.grid"] = False
colors = [
    "#CC6677",
    "#6E9CB3",
    "#CA8727",
    "#44AA99",
    "#AA4499",
    "#D6BE49",
    "#A494F5",
]
# NOTE(review): presumably renders a preview of the palette - confirm in custom_plot.
plot_colors(colors)
# One fixed colour per plotted quantity so figures stay comparable.
variables_color = {
    "SAT": colors[0],
    "SST": colors[2],
    "DOT": colors[1],
    "latent": colors[3],
    # Plain colour string like every other entry; a stray trailing comma
    # previously turned this value into the 1-tuple ``(colors[-1],)``.
    "loglikelihood": colors[-1],
}

In [None]:
# The repository root is taken to be the parent of the current working directory.
REPO_PATH = Path(".").resolve().parent
# Folder that receives the figures of this notebook; created on first run.
results_path = REPO_PATH.joinpath("results", "Report", "results", "Example")
results_path.mkdir(exist_ok=True, parents=True)
# Switch read by save_fig: set to False to skip writing figures to disk.
SAVE_FIGURES = True


def save_fig(fig, relative_path, **kwargs):
    """
    Save ``fig`` below ``results_path`` when ``SAVE_FIGURES`` is enabled.

    Parameters
    ----------
    fig :
        Object providing a ``savefig`` method (e.g. a matplotlib figure).
    relative_path : str or Path
        Target file, relative to the notebook-level ``results_path``.
        Missing parent folders are created before saving.
    **kwargs :
        Forwarded unchanged to ``fig.savefig``.

    Returns
    -------
    None
    """
    if not SAVE_FIGURES:
        # Saving is switched off: leave the file system untouched
        # (previously the target folder was created even in this case).
        return
    store_path = results_path / relative_path
    store_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(store_path, **kwargs)

# Evaluation 

In [None]:
def all_choords_as_dims(ds):
    """
    Promote every coordinate of ``ds`` to a dimension where possible.

    Each name in ``ds.coords`` is passed to ``expand_dims`` in turn.
    Coordinates that cannot be expanded (``expand_dims`` raises a
    ``ValueError``, e.g. because the name already is a dimension) are skipped.

    Parameters
    ----------
    ds :
        Object exposing ``coords`` and ``expand_dims`` (e.g. an xarray Dataset).

    Returns
    -------
    The result of the successive ``expand_dims`` calls; ``ds`` itself if no
    coordinate could be expanded.
    """
    # The loop iterates over the coordinates of the object passed in, even
    # though ``ds`` is rebound to the expanded result inside the loop.
    for dim in ds.coords:
        try:
            ds = ds.expand_dims(dim)
        except ValueError:
            # Best effort: skip this name. Only ValueError is swallowed so
            # that KeyboardInterrupt and unrelated errors are not hidden.
            continue
    return ds


import seaborn as sns

## Sponge Ocean

In [None]:
# Names of the evaluation runs that are combined for the sponge ocean.
run_names = ["polite-eel-349", "flawless-loon-25"]

# Every run lives in a folder named after it and provides two files:
# "<name>_kalman.nc" (Kalman result) and "<name>_input.nc" (input data).
kalman_list, input_list = [], []
for name in run_names:
    data_path = REPO_PATH / "data" / "Evaluation-Idealized-Ocean" / name
    kalman_list.append(xr.open_dataset(data_path / f"{name}_kalman.nc"))
    input_list.append(xr.open_dataset(data_path / f"{name}_input.nc"))

# Merge the individual runs into one dataset each.
sponge_kalman = xr.merge(kalman_list)
sponge_input = xr.merge(input_list)

sponge_kalman_states = from_standard_dataset(ds=sponge_kalman, dim="time")

# Number of entries along the "seed" coordinate, quoted in the figure titles.
number_of_runs = len(sponge_kalman["seed"])

In [None]:
ocean = "Sponge Ocean"
# Half-width of the shaded band in standard deviations. The same number is
# quoted in the axes title; previously the band used 5 while the title said 2.
n_stds = 2
fig, ax = plt.subplots(1, 1)
for t in sponge_kalman.tau0:
    select_dict = dict(tau0=t)
    # Log-likelihood of all seeds for this tau0 (selected once, used twice).
    log_likelihood = sponge_kalman.sel(select_dict)["log_likelihod"]
    m = log_likelihood.mean(dim="seed")
    s = log_likelihood.std(dim="seed")
    # NOTE(review): the standard deviation over seeds is passed as `prob`;
    # confirm that plot_state_with_probability expects a std, not a variance.
    plot_state_with_probability(
        ax=ax,
        x_value=sponge_kalman["kalman_itteration"],
        state=m,
        prob=s,
        stds=n_stds,
    )
ax.set_xlabel("Itterations of Kalman-SEM")
ax.set_ylabel("Loglikelihood")
ax.set_title(
    f"Solid as mean and shaded as {n_stds} std based on {number_of_runs} random runs for each parameter set"
)
fig.suptitle(f"{ocean} | Log likelihood evolution over kalman itterations.")

Text(0.5, 0.98, 'Sponge Ocean | Log likelihood evolution over kalman itterations.')

In [None]:
ocean = "Sponge Ocean"

# Gain in log-likelihood between the first and the last Kalman iteration,
# stored back on the dataset as "log_likelihod_increase".
sponge_kalman["log_likelihod_increase"] = (
    sponge_kalman["log_likelihod"].isel(kalman_itteration=-1)
    - sponge_kalman["log_likelihod"].isel(kalman_itteration=0)
)
# A length-one "df" dimension is added in front so the heatmap gets a 2D input.
da_lli = sponge_kalman["log_likelihod_increase"].expand_dims(df=[0.115])

# Options shared by both heatmaps.
# NOTE(review): tau0 is divided by 365.25, presumably days -> years; confirm units.
heatmap_kwargs = {
    "xticklabels": da_lli.tau0.values / 365.25,
    "yticklabels": da_lli.df.values,
    "square": True,
    "annot": True,
    "fmt": ".2f",
    "vmin": 0,
    "vmax": 2.6,
}

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
# Mean and spread over the seeds, scaled by 10^-3.
m = da_lli.mean(dim="seed") * 10 ** (-3)
s = da_lli.std(dim="seed") * 10 ** (-3)

# Left panel: mean over all seeds.
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(rf"Mean of {number_of_runs} random runs ($\times 1000$)")

# Right panel: std relative to the mean, shown on a 0..1 colour range.
sns.heatmap(s / m, ax=axs[1], **{**heatmap_kwargs, "vmin": 0, "vmax": 1})
axs[1].set_title(f"Std./Mean of {number_of_runs} random runs")

for ax in axs:
    ax.set_xlabel(r"$\tau_0$ in y")
    ax.set_ylabel(r"$df$ in y")

fig.suptitle(f"{ocean} | Log likelihood increase over kalman itterations.")

Text(0.5, 0.98, 'Sponge Ocean | Log likelihood increase over kalman itterations.')

In [None]:
# Correlate the reconstructed latent state with every input variable along time.
latent_state = sponge_kalman_states["latent"]
corr_list = []
for var in sponge_input.data_vars:
    temp = crosscorr(ds1=latent_state, ds2=sponge_input[var], dim="time")
    # Name each result after its input variable so the merge keeps them apart.
    temp.name = var
    corr_list.append(temp)

# Merge into one dataset and add the length-one "df" dimension.
sponge_corr = xr.merge(corr_list).expand_dims(df=[0.115])

In [None]:
var = "SAT"
da_corr = sponge_corr[var]

# Options shared by both heatmaps: tau0 on the x axis, df on the y axis.
heatmap_kwargs = dict(
    xticklabels=da_corr.tau0.values / 365.25,
    yticklabels=da_corr.df.values,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=1,
)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
# Only the magnitude of the correlation is used, so its sign is dropped
# before the statistics over the seeds are taken.
abs_corr = np.abs(da_corr)
m = abs_corr.mean(dim="seed")
s = abs_corr.std(dim="seed")

# Left panel: mean over all seeds.
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(f"Mean of {number_of_runs} random runs")

# Right panel: std relative to the mean.
sns.heatmap(s / m, ax=axs[1], **heatmap_kwargs)
axs[1].set_title(f"Std./Mean of {number_of_runs} random runs")

for ax in axs:
    ax.set_xlabel(r"$\tau_0$ in y")
    # The y tick labels are the df values, so the axis is labelled df
    # (it was labelled omega_0, copied from the oscillatory-ocean plot).
    ax.set_ylabel(r"$df$ in y")

fig.suptitle(f"{ocean} | Correlation Coefficient {var} to latent.")

Text(0.5, 0.98, 'Sponge Ocean | Correlation Coefficient SAT to latent.')

In [None]:
# Display `m` (the seed-mean of the absolute correlation from the previous
# cell) to check that its dimension order matches the tick labels of the heatmap.
m

## Oscillatory Ocean

In [None]:
# Names of the evaluation runs that are combined for the oscillatory ocean.
run_names = ["dapper-fox-131", "respected-fowl-948"]

# Every run lives in a folder named after it and provides two files:
# "<name>_kalman.nc" (Kalman result) and "<name>_input.nc" (input data).
kalman_list, input_list = [], []
for name in run_names:
    data_path = REPO_PATH / "data" / "Evaluation-Idealized-Ocean" / name
    kalman_list.append(xr.open_dataset(data_path / f"{name}_kalman.nc"))
    input_list.append(xr.open_dataset(data_path / f"{name}_input.nc"))

# Merge the individual runs into one dataset each.
oscillatory_kalman = xr.merge(kalman_list)
oscillatory_input = xr.merge(input_list)

oscillatory_kalman_states = from_standard_dataset(ds=oscillatory_kalman, dim="time")

# Number of entries along the "seed" coordinate, quoted in the figure titles.
number_of_runs = len(oscillatory_kalman["seed"])

In [None]:
ocean = "Oscillatory Ocean"
fig, ax = plt.subplots(1, 1)
# One curve (mean over seeds) with a shaded band per (tau0, per0) combination.
for t in oscillatory_kalman.tau0:
    for p in oscillatory_kalman.per0:
        select_dict = {"tau0": t, "per0": p}
        log_likelihood = oscillatory_kalman.sel(select_dict)["log_likelihod"]
        m = log_likelihood.mean(dim="seed")
        s = log_likelihood.std(dim="seed")
        plot_state_with_probability(
            ax=ax,
            x_value=oscillatory_kalman["kalman_itteration"],
            state=m,
            prob=s,
            stds=2,
        )

ax.set_xlabel("Itterations of Kalman-SEM")
ax.set_ylabel("Loglikelihood")
ax.set_title(
    f"Solid as mean and shaded as 2 std based on {number_of_runs} random runs for each parameter set"
)
fig.suptitle(f"{ocean} | Log likelihood evolution over kalman itterations.")

Text(0.5, 0.98, 'Oscillatory Ocean | Log likelihood evolution over kalman itterations.')

In [None]:
ocean = "Oscillatory Ocean"
# Gain in log-likelihood between the first and the last Kalman iteration,
# stored back on the dataset as "log_likelihod_increase".
oscillatory_kalman["log_likelihod_increase"] = oscillatory_kalman["log_likelihod"].isel(
    kalman_itteration=-1
) - oscillatory_kalman["log_likelihod"].isel(kalman_itteration=0)
da_lli = oscillatory_kalman["log_likelihod_increase"]

# Options shared by both heatmaps: per0 on the x axis, tau0 on the y axis.
heatmap_kwargs = dict(
    xticklabels=da_lli.per0.values / 365.25,
    yticklabels=da_lli.tau0.values / 365.25,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=2.6,
)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
# The heatmap draws the first array axis as rows (y) and the second as
# columns (x), while the tick labels above are assigned independently of the
# data. Force the order (tau0, per0) so the values always sit under the
# labels they belong to, whatever order the merged dataset stores them in.
m = (da_lli.mean(dim="seed") * 10 ** (-3)).transpose("tau0", "per0")
s = (da_lli.std(dim="seed") * 10 ** (-3)).transpose("tau0", "per0")

# Left panel: mean over all seeds.
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(rf"Mean of {number_of_runs} random runs ($\times 1000$)")

# Right panel: std relative to the mean, shown on a 0..1 colour range.
sns.heatmap(
    s / m,
    ax=axs[1],
    **{
        **heatmap_kwargs,
        **dict(vmin=0, vmax=1),
    },
)
axs[1].set_title(f"Std./Mean of {number_of_runs} random runs")

for ax in axs:
    ax.set_ylabel(r"$\tau_0$ in y")
    ax.set_xlabel(r"$\omega_0$ in y")

fig.suptitle(f"{ocean} | Log likelihood increase over kalman itterations.")

Text(0.5, 0.98, 'Oscillatory Ocean | Log likelihood increase over kalman itterations.')

In [None]:
# Correlate the reconstructed latent state with every input variable along time.
latent_state = oscillatory_kalman_states["latent"]
corr_list = []
for var in oscillatory_input.data_vars:
    temp = crosscorr(ds1=latent_state, ds2=oscillatory_input[var], dim="time")
    # Name each result after its input variable so the merge keeps them apart.
    temp.name = var
    corr_list.append(temp)

oscillatory_corr = xr.merge(corr_list)

In [None]:
var = "DOT"
da_corr = oscillatory_corr[var]

# Options shared by both heatmaps: per0 on the x axis, tau0 on the y axis.
heatmap_kwargs = dict(
    xticklabels=da_corr.per0.values / 365.25,
    yticklabels=da_corr.tau0.values / 365.25,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=0,
    vmax=1,
)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
# Only the magnitude of the correlation is used, so its sign is dropped
# before the statistics over the seeds are taken.
abs_corr = np.abs(da_corr)
# The heatmap draws the first array axis as rows (y) and the second as
# columns (x), while the tick labels above are assigned independently of the
# data. Force the order (tau0, per0) so the values always sit under the
# labels they belong to, whatever order the merged dataset stores them in.
m = abs_corr.mean(dim="seed").transpose("tau0", "per0")
s = abs_corr.std(dim="seed").transpose("tau0", "per0")

# Left panel: mean over all seeds.
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(f"Mean of {number_of_runs} random runs")

# Right panel: std relative to the mean.
sns.heatmap(s / m, ax=axs[1], **heatmap_kwargs)
axs[1].set_title(f"Std./Mean of {number_of_runs} random runs")

for ax in axs:
    ax.set_ylabel(r"$\tau_0$ in y")
    ax.set_xlabel(r"$\omega_0$ in y")

fig.suptitle(f"{ocean} | Correlation Coefficient {var} to latent.")

Text(0.5, 0.98, 'Oscillatory Ocean | Correlation Coefficient DOT to latent.')

In [None]:
# Spot check for the heatmap orientation (the x and y axes may appear flipped):
# mean absolute correlation for tau0 = 5 y and tau0 = 20 y at per0 = 48 y,
# to be compared with the annotated heatmap cells.
# Both tau0 values are selected in one expression because a cell only displays
# its last expression; the first of two separate lookups was silently dropped.
np.abs(
    oscillatory_corr[var].sel(tau0=[5 * 365.25, 20 * 365.25], per0=48 * 365.25)
).mean(dim="seed")

## Power Spectral Density

In [None]:
def _welch_psd(ds):
    """Return ``xarray_dataset_welch`` of ``ds`` along "time" with ``fs=12``."""
    # NOTE(review): fs=12 presumably means 12 samples per year (monthly data) - confirm.
    return xarray_dataset_welch(ds, dim="time", welch_kwargs=dict(fs=12))


# Spectra of the reconstructed Kalman states ...
sponge_kalman_psd = _welch_psd(sponge_kalman_states)
oscillatory_kalman_psd = _welch_psd(oscillatory_kalman_states)

# ... and of the original model input.
sponge_psd = _welch_psd(sponge_input)
oscillatory_psd = _welch_psd(oscillatory_input)

In [None]:
from importlib import reload
import reconstruct_climate_indices.statistics as stati
from sklearn.linear_model import LinearRegression

In [None]:
# Reload the statistics module so that the `stati.` calls below use the
# module's current source on disk without restarting the kernel.
reload(stati)

<module 'reconstruct_climate_indices.statistics' from 'C:\\Users\\Niebaum\\Documents\\Repositories\\reconstruct-climate-indices\\reconstruct_climate_indices\\statistics.py'>

In [None]:
def _fit_loglog(psd):
    """
    Apply ``stati.linear_regression_loglog`` along "frequency" to ``psd``.

    The frequency coordinate and the spectrum are both passed with
    "frequency" as core dimension; all remaining dimensions are looped over
    (``vectorize=True``). Three outputs are returned: two along "frequency"
    and one without core dimension (dtype ``object``).
    NOTE(review): the third output presumably holds the fitted regression
    object of each spectrum - confirm in ``linear_regression_loglog``.
    """
    return xr.apply_ufunc(
        stati.linear_regression_loglog,
        psd["frequency"],  # input frequencies
        psd,  # input spectrum
        input_core_dims=[["frequency"], ["frequency"]],
        output_core_dims=[["frequency"], ["frequency"], []],
        vectorize=True,
        output_dtypes=[float, float, object],
        # No frequency limits, weights set to "f_inverse".
        kwargs={"f_low": -np.inf, "f_high": np.inf, "weights": "f_inverse"},
    )


# `frequencies` is overwritten by every call and ends up holding the result
# of the last one (oscillatory Kalman states).
frequencies, sponge_linear, sponge_regression = _fit_loglog(sponge_psd)
frequencies, oscillatory_linear, oscillatory_regression = _fit_loglog(oscillatory_psd)
frequencies, sponge_kalman_linear, sponge_kalman_regression = _fit_loglog(
    sponge_kalman_psd
)
frequencies, oscillatory_kalman_linear, oscillatory_kalman_regression = _fit_loglog(
    oscillatory_kalman_psd
)

In [None]:
def get_slope(x):
    """
    Return the element ``[0][0]`` of the ``coef_`` attribute of ``x``.

    ``x`` is expected to be a fitted regression object whose ``coef_`` can be
    indexed on two levels; the first entry of its first row is returned.
    """
    first_row = x.coef_[0]
    return first_row[0]


def _slopes_of(regression):
    """Apply ``get_slope`` to every element of ``regression`` and return floats."""
    return xr.apply_ufunc(
        get_slope,
        regression,  # fitted regression objects
        vectorize=True,
        output_dtypes=[float],
    )


# Slopes for the model input ...
sponge_slopes = _slopes_of(sponge_regression)
oscillatory_slopes = _slopes_of(oscillatory_regression)

# ... and for the reconstructed Kalman states.
sponge_kalman_slopes = _slopes_of(sponge_kalman_regression)
oscillatory_kalman_slopes = _slopes_of(oscillatory_kalman_regression)

#### plot the slope

In [None]:
var = "latent"
# Alternative configuration for the oscillatory ocean. When it is used, label
# the x axis with $\omega_0$ and the y axis with $\tau_0$ further below.
# ocean = "Oscillatory Ocean"
# da_slopes = oscillatory_kalman_slopes[var]
# heatmap_kwargs = dict(
#     xticklabels=da_slopes.per0.values / 365.25,
#     yticklabels=da_slopes.tau0.values / 365.25,
#     square=True,
#     annot=True,
#     fmt=".2f",
#     vmin = -3,
#     vmax = 0,
#     cmap = "rocket_r"
# )
ocean = "Sponge Ocean"
# A length-one "df" dimension is added in front so the heatmap gets a 2D input.
da_slopes = sponge_kalman_slopes[var].expand_dims(df=[0.115])
# Count the runs from the plotted data itself: the notebook-level
# `number_of_runs` was last set from the oscillatory-ocean dataset.
n_runs = da_slopes.sizes["seed"]
# Options shared by both heatmaps: tau0 on the x axis, df on the y axis.
heatmap_kwargs = dict(
    xticklabels=da_slopes.tau0.values / 365.25,
    yticklabels=da_slopes.df.values,
    square=True,
    annot=True,
    fmt=".2f",
    vmin=-3,
    vmax=0,
    cmap="rocket_r",
)

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
m = da_slopes.mean(dim="seed")
# The std is negated so that it falls into the shared colour range -3..0.
s = -da_slopes.std(dim="seed")

# Left panel: mean slope over all seeds.
sns.heatmap(m, ax=axs[0], **heatmap_kwargs)
axs[0].set_title(f"Mean of {n_runs} random runs")

# Right panel: (negated) std over all seeds.
sns.heatmap(s, ax=axs[1], **heatmap_kwargs)
axs[1].set_title(f"Std. of {n_runs} random runs (negative)")

for ax in axs:
    # Labels follow the tick labels above: tau0 on x and df on y (they were
    # labelled omega_0 / tau_0, copied from the oscillatory-ocean plot).
    ax.set_xlabel(r"$\tau_0$ in y")
    ax.set_ylabel(r"$df$ in y")

fig.suptitle(f"{ocean} | Slope of linear regression in loglog space for {var}.")

Text(0.5, 0.98, 'Sponge Ocean | Slope of linear regression in loglog space for latent.')