In [None]:
import numpy as np
import pandas as pd
import xarray as xr
from pathlib import Path
from kalman_reconstruction.custom_plot import (
    set_custom_rcParams,
    adjust_lightness,
    handler_map_alpha,
    plot_colors,
    symmetrize_axis,
)
from kalman_reconstruction.pipeline import (
    add_random_variable,
    xarray_Kalman_SEM,
    xarray_Kalman_SEM_full_output,
    from_standard_dataset,
)
from kalman_reconstruction.statistics import normalize
from reconstruct_climate_indices.idealized_ocean import spunge_ocean, oscillatory_ocean
from reconstruct_climate_indices.statistics import (
    linear_regression_loglog,
    xarray_dataset_welch,
)
import matplotlib.pyplot as plt
from matplotlib import animation
from kalman_reconstruction.statistics import normalize

# from sklearn.linear_model import LinearRegression
from scipy import signal

In [None]:
# Presentation theme: dark background.
# For a light theme switch to plt.style.use("default") together with
#   dark_color = [0.2, 0.2, 0.2], light_color = [0.9, 0.9, 0.9],
#   lightness_0 = 0.75, lightness_1 = 0.5
plt.style.use("dark_background")
dark_color = [0.7, 0.7, 0.7]
light_color = [0.1, 0.1, 0.1]
lightness_0, lightness_1 = 1.15, 1.5

# Project-wide rcParams; the returned sequence is used as the colour palette below.
colors = set_custom_rcParams()

# Draw only the bottom spine and switch the default grid off.
plt.rc("axes.spines", left=False, right=False, bottom=True, top=False)
plt.rcParams["axes.grid"] = False

plot_colors(colors)

# One fixed colour per variable so that every figure uses the same coding.
variables_color = {
    "NAO": colors[0],
    "AMO": colors[2],
    "sin": colors[1],
    "latent": colors[3],
}

In [None]:
# Global switch read by save_fig(): set to False to run the notebook without writing files.
SAVE_FIGURES = True

# The parent of the current working directory is taken as the repository root.
REPO_PATH = Path(".").resolve().parent

# All figures of this notebook are stored below <REPO_PATH>/results/Presentation.
results_path = REPO_PATH.joinpath("results", "Presentation")
results_path.mkdir(parents=True, exist_ok=True)


def save_fig(fig, relative_path, **kwargs):
    """Save ``fig`` below ``results_path`` if ``SAVE_FIGURES`` is enabled.

    Parameters
    ----------
    fig :
        Object providing a ``savefig(path, **kwargs)`` method
        (e.g. a matplotlib figure).
    relative_path : str or pathlib.Path
        Target file, relative to the module-level ``results_path``.
    **kwargs
        Forwarded unchanged to ``fig.savefig``.

    Returns
    -------
    None

    Notes
    -----
    When ``SAVE_FIGURES`` is falsy the function does nothing at all; in
    particular no directories are created on disk.
    """
    if not SAVE_FIGURES:
        return
    store_path = results_path / relative_path
    # Create missing sub-folders only when a file is actually written.
    store_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(store_path, **kwargs)

In [None]:
# Load the NAO index table (whitespace separated, first column used as index).
# NOTE(review): assumes REPO_PATH is the repository root, i.e. that
# REPO_PATH / "data" is the folder the former hard-coded absolute path
# (...\reconstruct-climate-indices\data) pointed at — confirm.
NAO_file = REPO_PATH / "data" / "observations" / "NAO" / "nao.csv"
NAO_df = pd.read_csv(
    NAO_file,
    sep=r"\s+",  # replaces the deprecated delim_whitespace=True
    skiprows=1,
    header=None,
    index_col=0,
    skipfooter=3,
    engine="python",  # skipfooter is not supported by the C engine
)
# The index column presumably holds 4-digit years (it is named "Year" below);
# parse it as such instead of letting integers be read as epoch nanoseconds.
NAO_df.index = pd.to_datetime(NAO_df.index.astype(str), format="%Y")
NAO_df.index.name = "Year"

# Flatten the table row by row into one series and attach a
# month-start time axis beginning in January 1948.
NAO_np = NAO_df.to_numpy().flatten()
NAO_date = pd.date_range(start="1948-01-01", periods=len(NAO_np), freq="MS")

  NAO_df = pd.read_csv(


In [None]:
# Load the AMO index table (whitespace separated, first column used as index).
# NOTE(review): assumes REPO_PATH is the repository root, i.e. that
# REPO_PATH / "data" is the folder the former hard-coded absolute path
# (...\reconstruct-climate-indices\data) pointed at — confirm.
AMO_file = REPO_PATH / "data" / "observations" / "AMO" / "amo.csv"
AMO_df = pd.read_csv(
    AMO_file,
    sep=r"\s+",  # replaces the deprecated delim_whitespace=True
    skiprows=1,
    header=None,
    index_col=0,
    skipfooter=4,
    engine="python",  # skipfooter is not supported by the C engine
)
# The index column presumably holds 4-digit years (it is named "Year" below);
# parse it as such instead of letting integers be read as epoch nanoseconds.
AMO_df.index = pd.to_datetime(AMO_df.index.astype(str), format="%Y")
AMO_df.index.name = "Year"

# Flatten the table row by row into one series and attach a
# month-start time axis beginning in January 1948.
AMO_np = AMO_df.to_numpy().flatten()
AMO_date = pd.date_range(start="1948-01-01", periods=len(AMO_np), freq="MS")

# Show only the first rows instead of dumping the whole table.
AMO_df.head()

  AMO_df = pd.read_csv(


In [None]:
# Values that the code below treats as "missing" in both index series.
MISSING_VALUES = (-99.99, -99.9)
# Width of the centred running mean applied to the AMO, in time steps (5 * 12).
AMO_SMOOTHING_STEPS = 5 * 12

# Both series share one time axis, so they must have the same length.
if len(NAO_np) != len(AMO_np):
    raise ValueError(
        f"NAO ({len(NAO_np)}) and AMO ({len(AMO_np)}) differ in length "
        "and cannot share one time axis."
    )

data = xr.Dataset(
    data_vars=dict(NAO=(["time"], NAO_np), AMO=(["time"], AMO_np)),
    coords=dict(time=(["time"], AMO_date)),
)

# Mask the missing-value markers. ``where`` returns new arrays, so the input
# arrays NAO_np / AMO_np are not modified (chained item assignment on
# ``data[var]`` may write through to them).
for missing_value in MISSING_VALUES:
    data = data.where(data != missing_value)

# smooth
data["AMO"] = data["AMO"].rolling(time=AMO_SMOOTHING_STEPS, center=True).mean()

# Keep only the time steps at which both indices are finite.
is_complete = np.isfinite(data["AMO"]) & np.isfinite(data["NAO"])
data = data.isel(time=is_complete.values)


# plot data: NAO on the left y-axis, AMO on a twin right y-axis
fig, ax_NAO = plt.subplots(1, 1)
ax_AMO = ax_NAO.twinx()

# plot NAO
color = variables_color["NAO"]
ax_NAO.plot(data.time, data.NAO, color=color, linewidth=2, alpha=0.75)
ax_NAO.set_xlabel("year")
ax_NAO.set_ylabel("NAO")
# The NAO scale lives on the left side, so colour the left spine.
ax_NAO.spines["left"].set_color(color)
ax_NAO.yaxis.label.set_color(color)
ax_NAO.tick_params(axis="y", colors=color)

# plot AMO (the x-axis label is already set on ax_NAO, which shares the x-axis)
color = variables_color["AMO"]
ax_AMO.plot(data.time, data.AMO, color=color, linewidth=2)
ax_AMO.set_ylabel("AMO")
ax_AMO.spines["right"].set_color(color)
ax_AMO.yaxis.label.set_color(color)
ax_AMO.tick_params(axis="y", colors=color)

In [None]:
# Number of SEM iterations; also read by the animation cell further down.
nb_iter_SEM = 50
# Fixed seed so the random initial latent variable is reproducible.
latent_rng = np.random.default_rng(seed=10000)

data = normalize(data)

# The return value is not used: the helper presumably adds "latent" to ``data``
# in place (later cells read the "latent" variable from it) — verify.
add_random_variable(
    ds=data,
    var_name="latent",
    random_generator=latent_rng,
    dim="time",
    variance=1,
)

# Optional synthetic test signal (disabled): sine with a period of 10 * 12 time steps.
# data["sin"] = (
#     np.sin(2 * np.pi * np.arange(0, len(data.time)) / (10 * 12)) + data.AMO * 0
# )
# data["sin"].plot()

# Both indices are observed; the state additionally contains the latent variable.
kalman_settings = dict(
    observation_variables=["AMO", "NAO"],
    state_variables=["AMO", "NAO", "latent"],
    nb_iter_SEM=nb_iter_SEM,
    variance_obs_comp=0.0001,
)
data_kalman = xarray_Kalman_SEM_full_output(ds=data, **kalman_settings)
data_kalman_states = from_standard_dataset(data_kalman)

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:25<00:00,  1.97it/s]


In [None]:
# Plot the log-likelihood stored in the Kalman-SEM output.
# The key is spelled "log_likelihod" in the dataset; keep it as is.
log_likelihood = data_kalman["log_likelihod"]
log_likelihood.plot()

[<matplotlib.lines.Line2D at 0x1b750972140>]

In [None]:
# Compare the latent state of the last Kalman-SEM iteration with each index.
# The latent series does not depend on the loop variable, so normalise it once.
latent_final = normalize(
    data_kalman_states["latent"].isel(kalman_iteration=-1), "oneone"
)

fig_scatter, ax_scatter = plt.subplots()
for var in ["NAO", "AMO"]:
    ax_scatter.scatter(
        normalize(data[var], "oneone"),
        latent_final,
        label=var,
        color=variables_color[var],
    )
ax_scatter.set_xlabel("observation")
ax_scatter.set_ylabel("latent")
# The labels passed to scatter() only become visible through a legend.
ax_scatter.legend();

## Power Spectral Density

In [None]:
# Sampling setup for the spectra; frequencies are expressed per year.
fs = 12  # time steps per year
welch_window_width = 150  # years
nperseg = data.sizes["time"]  # a single segment spanning the whole record
years = nperseg / fs  # record length in years

In [None]:
# Power spectral densities with frequency in year**{-1}.
# nperseg equals the full record length and the window is rectangular ("boxcar"),
# so no segment averaging takes place, i.e. effectively NO Welch method is used.
welch_kwargs = {
    "fs": fs,  # period is 1/12 y -> fs = 12 y^{-1}
    "nperseg": nperseg,  # length in timesteps
    "scaling": "density",
    "window": "boxcar",
}

# Spectra of the input dataset ...
psd_data = xarray_dataset_welch(data, dim="time", welch_kwargs=welch_kwargs)
# ... and of the Kalman-SEM state estimates.
psd_data_kalman = xarray_dataset_welch(
    data_kalman_states,
    dim="time",
    welch_kwargs=welch_kwargs,
)

In [None]:
# Figure for the PSD animation: spectra of both indices plus the latent spectrum,
# whose line handle (psd_latent) is updated frame by frame in the cells below.
fig, ax = plt.subplots(nrows=1, ncols=1, layout="constrained", sharex=True, sharey=True)
# NOTE(review): ``kwargs`` is not used within this cell.
kwargs = dict(linestyle="-", linewidth="1", marker=".")

ax.set_title("Observations")


def _loglog_psd(spectra, name, **line_kwargs):
    """Draw ``spectra[name]`` over ``spectra.frequency`` on ``ax``; return the line."""
    (line,) = ax.loglog(
        spectra.frequency,
        spectra[name],
        label=name,
        color=adjust_lightness(variables_color[name], lightness_0),
        **line_kwargs,
    )
    return line


# ------------------
# plot observations
# ------------------
var = "AMO"
psd_AMO = _loglog_psd(psd_data, var, alpha=0.75)
var = "NAO"
psd_NAO = _loglog_psd(psd_data, var, alpha=0.75)

# ------------------
# plot latent variable (first Kalman iteration)
# ------------------
var = "latent"
(psd_latent,) = ax.loglog(
    psd_data_kalman.frequency,
    psd_data_kalman[var].isel(kalman_iteration=0),
    label=var,
    color=adjust_lightness(variables_color[var], lightness_0),
)

legend_layout = dict(
    ncols=2,
    loc="lower left",
    handlelength=1,
    labelspacing=0.01,
    handletextpad=0.15,
    columnspacing=0.2,
)
ax.legend(**legend_layout)

# Fixed y-range so the axes do not rescale while the latent line is animated.
ax.set_ylim(bottom=10 ** (-7), top=10 ** (4))
# fig.suptitle(
#     f"Power Density Spectrum of Latent variable, Observations and hidden Component"
# )
ax.set_xlabel(r"$f$ in $y^{-1}$")
ax.set_ylabel(r"PSD $K^{2}y^{-1}$")
ax.grid()


def init_lines():
    """Set the animated latent line to the ``"latent"`` spectrum of ``psd_data``.

    Returns
    -------
    tuple
        One-element tuple holding the modified line ``psd_latent``.
    """
    initial_spectrum = psd_data["latent"]
    psd_latent.set_ydata(initial_spectrum)
    return (psd_latent,)


def update_lines(idx):
    """Update the animated latent line for animation frame ``idx``.

    Frame 0 shows the ``"latent"`` spectrum of ``psd_data``; every later frame
    ``idx`` shows ``psd_data_kalman["latent"]`` at ``kalman_iteration = idx - 1``.

    Parameters
    ----------
    idx : int
        Frame number.

    Returns
    -------
    tuple
        One-element tuple holding the modified line ``psd_latent``.
    """
    if idx == 0:
        spectrum = psd_data["latent"]
    else:
        # Frames are shifted by one because frame 0 is reserved for psd_data.
        spectrum = psd_data_kalman["latent"].isel(kalman_iteration=idx - 1)
    psd_latent.set_ydata(spectrum)
    return (psd_latent,)


from IPython.display import HTML

# Frame 0 shows the "latent" spectrum of psd_data, frames 1 .. nb_iter_SEM the
# latent spectrum after each Kalman iteration (see update_lines).
# ``frames`` is given explicitly so that a live animation also stops after the
# last iteration instead of counting on and indexing past the available
# kalman_iteration entries.
ani_PSD = animation.FuncAnimation(
    fig,
    update_lines,
    init_func=init_lines,
    frames=nb_iter_SEM + 1,
    interval=200,
    blit=True,
)

# Render the animation as an inline HTML5 video (last expression = cell output).
HTML(ani_PSD.to_html5_video())
# ani_PSD
# # # To save the animation as an mp4 file using FFMpeg
# writer = animation.FFMpegWriter(
#     fps=1.5,
#     metadata=dict(artist='Me'),
#     bitrate=-1,
# )
# ani_PSD.save(results_path / 'PSD_evolution.mp4', writer=writer, dpi = 256)