In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from scipy import signal

# from sklearn.linear_model import LinearRegression

from kalman_reconstruction.custom_plot import (
    set_custom_rcParams,
    adjust_lightness,
    handler_map_alpha,
    plot_colors,
    symmetrize_axis,
)
from kalman_reconstruction.statistics import normalize
from reconstruct_climate_indices.idealized_ocean import sponge_ocean, oscillatory_ocean
from reconstruct_climate_indices.statistics import linear_regression_loglog

In [None]:
# Number of monthly time steps (dt = 30.4375 d = 365.25 d / 12) in 150 years.
150 * (365.25 / 30.4375)

1800.0

In [None]:
# Theme selection: the light theme below is active. To render the figures for a
# dark background, comment it out and uncomment the DARK THEME block instead.
## LIGHT THEME
plt.style.use("seaborn-v0_8-whitegrid")
dark_color, light_color = [0.3, 0.3, 0.3], [0.8, 0.8, 0.8]
lightness_0, lightness_1 = 0.75, 0.5
cmap, cmap_r = "rocket", "rocket_r"

# ### DARK THEME
# plt.style.use("dark_background")
# dark_color = [0.7, 0.7, 0.7]
# light_color = [0.2, 0.2, 0.2]
# lightness_0 = 1.15
# lightness_1 = 1.5
# cmap = "rocket_r"
# cmap_r = "rocket"


# Apply the project rcParams (returns the color cycle) and keep the grid on.
colors = set_custom_rcParams()
plt.rcParams["axes.grid"] = True

plot_colors(colors)

# A fixed color per variable, reused by every figure in this notebook.
variables_color = {
    "SAT": colors[0],
    "SST": colors[2],
    "DOT": colors[1],
    "latent": colors[3],
}

In [None]:
# Output location for the figures. REPO_PATH is the parent of the current working
# directory (presumably the repository root when the notebook runs from a subfolder).
REPO_PATH = Path(".").resolve().parent
results_path = REPO_PATH / "results" / "Report" / "data" / "idealized_ocean"
# Set to False to run the notebook without writing anything to disk.
SAVE_FIGURES = True
# Only create the output directory when figures are actually going to be written.
if SAVE_FIGURES:
    results_path.mkdir(parents=True, exist_ok=True)


def save_fig(fig, relative_path, **kwargs):
    """Save ``fig`` below ``results_path`` when the global ``SAVE_FIGURES`` flag is set.

    Nothing is written (and no directory is created) when ``SAVE_FIGURES`` is False.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save.
    relative_path : str or pathlib.Path
        Destination relative to ``results_path``; missing parent folders are created.
    **kwargs
        Forwarded unchanged to ``fig.savefig``.
    """
    if not SAVE_FIGURES:
        return
    store_path = results_path / relative_path
    store_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(store_path, **kwargs)

In [None]:
# Simulate 1000 years of both idealized ocean models with a monthly time step.
DAYS_PER_YEAR = 365.25
dt = 30.4375  # days per time step (= DAYS_PER_YEAR / 12, i.e. one month)
years = 1000
# The number of time steps is a count, so pass an integer (1000 y * 12 steps/y = 12000).
time_steps = int(round(years * DAYS_PER_YEAR / dt))

tau0 = 10  # years
lambda0 = np.pi * tau0 / 2  # years
omega0 = 24  # years; passed to the oscillatory ocean as ``per0`` (in days)
fs = DAYS_PER_YEAR / dt  # sampling frequency in 1/years

welch_window_width = 250  # years; segment length of the Welch PSD estimate

# Both models share the time axis, ``df`` and ``tau0`` (converted from years to days);
# the oscillatory ocean additionally receives ``per0``.
sponge_settings = dict(
    nt=time_steps,
    dt=dt,
    df=1.15e-1,
    tau0=tau0 * DAYS_PER_YEAR,
    save_path=None,
)
sponge_data = sponge_ocean(**sponge_settings)

oscillatory_settings = dict(
    nt=time_steps,
    dt=dt,
    df=1.15e-1,
    tau0=tau0 * DAYS_PER_YEAR,
    per0=omega0 * DAYS_PER_YEAR,
    save_path=None,
)
oscillatory_data = oscillatory_ocean(**oscillatory_settings)

print(rf"$\lambda_0$ = {lambda0}")
print(rf"$\omega_0$ = {omega0}")

SAT_units = r"$K d^{-0.5}$"
SST_DOT_units = r"$K$"

# NOTE:
# The PSD units for the SAT are originally given in K^{2} d^{-1} y^{-1}, which is awkward to read.
# So the SAT is multiplied by SAT_factor before the spectral analysis to express it in K y^{-0.5}.
# NOTE(review): numerically, a value in K d^{-0.5} equals sqrt(365.25) times that value in
# K y^{-0.5}; confirm whether the factor should be sqrt(365.25) rather than sqrt(1 / 365.25).
SAT_factor = np.sqrt(
    1 / DAYS_PER_YEAR
)  # factor by which the input to the fft function should be multiplied to get the units mentioned above
SAT_PSD_units = r"$K^2 y^{-2}$"
SST_DOT_PSD_units = r"$K^2 y^{-1}$"

$\lambda_0$ = 15.707963267948966
$\omega_0$ = 24


They were modified and are now: timesteps = 12000
  warn(


### Check the normality of the data

In [None]:
fig, axs = plt.subplots(
    nrows=1, ncols=2, layout="constrained", sharex=True, sharey=True
)

# Every variable is drawn twice with identical bins: a faint filled area and a solid outline.
stepfill_kwargs = dict(
    bins=31,
    histtype="stepfilled",
    density=True,
    alpha=0.1,
)
step_kwargs = dict(
    bins=31,
    histtype="step",
    density=True,
    alpha=1,
    linewidth=3,
)

axs_sponge = axs[0]
axs_oscill = axs[1]
axs_sponge.set_title("Sponge Ocean")
axs_oscill.set_title("Oscillatory Ocean")

# Plot the normalized distributions (only SAT and SST are shown for the sponge ocean).
for ax, dataset, dataset_vars in (
    (axs_sponge, sponge_data, ["SAT", "SST"]),
    (axs_oscill, oscillatory_data, ["SAT", "SST", "DOT"]),
):
    for var in dataset_vars:
        normalized = normalize(dataset[var], method="mean")
        ax.hist(normalized, label=var, color=variables_color[var], **step_kwargs)
        ax.hist(normalized, color=variables_color[var], **stepfill_kwargs)

# Reference standard normal distribution N(0, 1) in every panel.
mu = 0
variance = 1
width = 4  # plotted range in standard deviations around mu
sigma = np.sqrt(variance)
x = np.linspace(mu - width * sigma, mu + width * sigma, 100)
for ax in axs.flatten():
    ax.plot(
        x,
        stats.norm.pdf(x, mu, sigma),
        color=dark_color,
        label=r"$\mathcal{N}(0,1)$",
        linewidth=4,
    )
    ax.legend()
    ax.set_xlabel("Normalized Values")
    symmetrize_axis(axes=ax, axis="x")
axs[0].set_ylabel("Probability Density")

save_fig(fig=fig, relative_path=f"Distribution_{years}y.pdf")

### Plot the evolution for the first 150 years

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, layout="constrained", sharex=True)
line_kwargs = dict(linestyle="-", linewidth=2, marker=".", alpha=0.75)
plot_years = 150  # only the first 150 years are shown

axs_sponge = axs[:, 0]
axs_oscill = axs[:, 1]
axs_sponge[0].set_title("Sponge Ocean")
axs_oscill[0].set_title("Oscillatory Ocean")

# Top row: SAT; bottom row: SST (plus DOT for the oscillatory ocean).
panels = (
    (axs_sponge[0], sponge_data, ["SAT"]),
    (axs_sponge[1], sponge_data, ["SST"]),
    (axs_oscill[0], oscillatory_data, ["SAT"]),
    (axs_oscill[1], oscillatory_data, ["SST", "DOT"]),
)
for ax, dataset, panel_vars in panels:
    for var in panel_vars:
        ax.plot(
            dataset["time_years"],
            dataset[var],
            label=var,
            color=variables_color[var],
            **line_kwargs,
        )

# Labels and limits for the SAT row
for ax in axs[0, :]:
    ax.set_ylabel(SAT_units)
    ax.set_ylim([-0.5, 0.5])

# Labels and limits for the SST / DOT row (the shared x-axis is labelled here)
for ax in axs[1, :]:
    ax.set_ylabel(SST_DOT_units)
    ax.set_ylim([-12, 12])
    ax.set_xlabel("Time in y")

for ax in axs.flatten():
    ax.set_xlim([0, plot_years])
    ax.legend(
        ncols=2,
        loc="lower left",
        handlelength=1,
        handletextpad=0.15,
        columnspacing=0.2,
        handler_map=handler_map_alpha(),
    )
save_fig(fig=fig, relative_path=f"Evolution_{years}y.pdf")

### Plot the frequency analysis (power spectral densities)

In [None]:
def plot_all_frequency_analysis(
    x: np.ndarray, ax: plt.Axes, color: str, f_low: float, idx: int, var: str
) -> None:
    """Plot the power spectral density of ``x`` together with a log-log power-law fit.

    Three curves are drawn into ``ax`` on log-log axes:

    1. a single-segment estimate (boxcar window, ``nperseg=len(x)``), i.e. no Welch averaging,
    2. a Welch estimate with Hann windows of ``welch_window_width`` years,
    3. a linear regression of the Welch estimate in log-log space
       (``weights="f_inverse"``), annotated with its slope ``m``.

    Relies on the notebook globals ``fs`` (sampling frequency in 1/years),
    ``welch_window_width``, ``lightness_0``, ``lightness_1`` and ``light_color``.

    Parameters
    ----------
    x : np.ndarray
        Time series sampled at ``fs``.
    ax : plt.Axes
        Axes to draw into.
    color : str
        Base color; the Welch and regression curves use lightness-adjusted variants.
    f_low : float
        Lower frequency bound passed to ``linear_regression_loglog``
        (presumably lower frequencies are excluded from the fit; see that helper).
    idx : int
        Index of the regression point at which the slope label is placed.
    var : str
        Variable name used in the legend labels.
    """
    # No Welch averaging: a single boxcar-windowed segment spanning the whole series.
    frequencies, spectrum = signal.welch(x=x, fs=fs, window="boxcar", nperseg=len(x))
    ax.loglog(
        frequencies,
        spectrum,
        label=f"{var}",
        alpha=0.7,
        color=color,
    )

    # Welch method: Hann-windowed segments of `welch_window_width` years, averaged,
    # which reduces the variance of the estimate at the cost of frequency resolution.
    frequencies, spectrum = signal.welch(
        x=x,
        fs=fs,
        window="hann",
        nperseg=int(welch_window_width * fs),
    )
    ax.loglog(
        frequencies,
        spectrum,
        label=f"{var} welch",
        color=adjust_lightness(color, lightness_0),
    )

    # Linear regression of the Welch spectrum in log-log space; its slope is the
    # exponent of the fitted power law.
    frequencies_linear, spectrum_linear, regression = linear_regression_loglog(
        frequencies=frequencies,
        spectrum=spectrum,
        weights="f_inverse",
        f_low=f_low,
    )
    slope = regression.coef_[0, 0]
    ax.loglog(
        frequencies_linear,
        spectrum_linear,
        color=adjust_lightness(color, lightness_1),
    )
    ax.text(
        frequencies_linear[idx],
        spectrum_linear[idx],
        f"m= {slope:.2f}",
        ha="right",
        va="top",
        bbox=dict(facecolor=light_color, edgecolor="None", alpha=0.25),
        color=adjust_lightness(color, lightness_1),
    )

In [None]:
fig, axs = plt.subplots(
    nrows=2, ncols=2, layout="constrained", sharex=True, sharey=True
)

idx = 1  # index of the regression point at which the slope label is placed

axs_sponge = axs[:, 0]
axs_oscill = axs[:, 1]
axs_sponge[0].set_title("Sponge Ocean")
axs_oscill[0].set_title("Oscillatory Ocean")

# One entry per spectrum:
# (variable, dataset, axes to plot into, lowest frequency for the linear regression, scale factor)
# The SAT is rescaled by SAT_factor so that its PSD matches SAT_PSD_units.
spectra_to_plot = [
    ("SAT", sponge_data, axs_sponge[0], 0, SAT_factor),
    ("SAT", oscillatory_data, axs_oscill[0], 0, SAT_factor),
    ("SST", sponge_data, axs_sponge[1], 1 / lambda0, None),
    ("SST", oscillatory_data, axs_oscill[1], 1 / omega0, None),
    ("DOT", oscillatory_data, axs_oscill[1], 1 / omega0, None),
]
for var, current_dataset, ax, f_low, factor in spectra_to_plot:
    x = current_dataset[var].values
    if factor is not None:
        x = x * factor
    plot_all_frequency_analysis(
        x=x,
        ax=ax,
        color=variables_color[var],
        f_low=f_low,
        idx=idx,
        var=var,
    )

for ax in axs.flatten():
    # Reference frequencies: 1 / lambda0 (dotted), 1 / omega0 (dashed)
    # and 1 / Welch window width (solid).
    ax.axvline(1 / lambda0, color=dark_color, linestyle=":", alpha=0.5)
    ax.axvline(1 / omega0, color=dark_color, linestyle="--", alpha=0.5)
    ax.axvline(1 / welch_window_width, color=dark_color, linestyle="-", alpha=0.5)
    ax.legend(
        ncols=2,
        loc="lower left",
        handlelength=1,
        labelspacing=0.01,
        handletextpad=0.15,
        columnspacing=0.2,
    )
    ax.set_ylim(1e-10, 1e6)

axs_sponge[0].set_ylabel(f"PSD [{SAT_PSD_units}]")
axs_sponge[1].set_ylabel(f"PSD [{SST_DOT_PSD_units}]")

# The x-axis is shared, so only the bottom row gets a label.
for ax in axs[1, :]:
    ax.set_xlabel(r"$f$ in $y^{-1}$")

save_fig(fig=fig, relative_path=f"Frequency_{years}y.pdf")