In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import yaml
from matplotlib.ticker import FuncFormatter, NullFormatter

from kalman_reconstruction import pipeline
from kalman_reconstruction.custom_plot import (
    ncols_nrows_from_N,
    set_custom_rcParams,
    plot_state_with_probability,
    adjust_lightness,
)
from kalman_reconstruction.statistics import normalize

# Apply the project-wide matplotlib defaults first, then force a grid on every axes.
set_custom_rcParams()
plt.rcParams.update({"axes.grid": True})

In [None]:
experiment_name = "smoothed_10y_not_NAO"
REPO_PATH = Path("..")
PATH_FIGURES = REPO_PATH / Path("results") / "CiCMOD" / experiment_name
SAVE_FIGURES = True


def save_fig(fig, relative_path, **kwargs):
    """Save ``fig`` below ``PATH_FIGURES`` if ``SAVE_FIGURES`` is enabled.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to write (only its ``savefig`` method is used).
    relative_path : str or pathlib.Path
        Target file, relative to ``PATH_FIGURES``. Missing parent directories
        are created.
    **kwargs
        Forwarded unchanged to ``fig.savefig`` (e.g. ``dpi``).
    """
    if not SAVE_FIGURES:
        # Saving disabled: do not touch the file system at all (previously the
        # output directory was created even when nothing was written).
        return
    store_path = PATH_FIGURES / relative_path
    store_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(store_path, **kwargs)

In [None]:
time_slice = slice(500, 1000)
# Rolling-mean window in time steps (the data are monthly, so 10 * 12 would be a
# 10-year window); 0 disables the smoothing in the preprocessing cell.
# NOTE(review): experiment_name says "smoothed_10y" but the window is 0 -- confirm
# which one is intended before trusting the output folder name.
rolling_window = 0 * 12
rng_seed = 83653
random_variance = 1
nb_iter_SEM = 50
observation_variables = ["AMO", "NAO_ST", "SAT_N_OCEAN"]
state_variables = ["AMO", "NAO_ST", "SAT_N_OCEAN", "latent"]

# Load the raw data here, before it is sliced and before `data_path` is recorded
# in the settings, so this cell also works on a fresh kernel (Restart & Run All).
data_path = (
    REPO_PATH
    / Path("data")
    / "earth_system_models"
    / "CiCMOD"
    / "climate_indices_CESM.nc"
)
data_original = xr.load_dataset(data_path)
data = data_original.sel(time=time_slice).copy()

settings = dict(
    rolling_window=rolling_window,
    rng_seed=rng_seed,
    random_variance=random_variance,
    observation_variables=observation_variables,
    state_variables=state_variables,
    nb_iter_SEM=nb_iter_SEM,
    time_slice=dict(
        start=time_slice.start,
        stop=time_slice.stop,
    ),
    data_path=str(data_path),
)
# The results folder may not exist yet on a fresh checkout.
PATH_FIGURES.mkdir(parents=True, exist_ok=True)
with open(PATH_FIGURES / "settings.yaml", "w", encoding="utf-8") as stream:
    stream.write(
        "#Settings used in the application of the Kalman_SEM on the CiCMOD dataset.\n"
    )
    yaml.dump(data=settings, stream=stream, default_flow_style=False)

In [None]:
# Location of the CESM climate-index file in the CiCMOD data folder.
data_path = REPO_PATH.joinpath(
    "data", "earth_system_models", "CiCMOD", "climate_indices_CESM.nc"
)
# load_dataset reads the whole file into memory and closes it again.
data_original = xr.load_dataset(data_path)

In [None]:
# Rebuild `data` from the raw dataset every time, so re-running this cell does
# not smooth the already-smoothed data again or stack random variables.
data = data_original.sel(time=time_slice).copy()

# Explicit switch instead of relying on `rolling(time=0)` raising inside a bare
# `except:` (which would also have hidden any other error, e.g. a missing key).
if rolling_window > 0:
    # Trailing rolling mean; trim both ends so incomplete windows are dropped.
    data = (
        data.rolling(time=rolling_window)
        .mean()
        .isel(time=slice(rolling_window, -rolling_window))
    )
    # Keep these two indices unsmoothed. xarray aligns the assigned arrays to
    # the time coordinate of `data`.
    data["NAO_ST"] = data_original["NAO_ST"]
    data["SAT_N_OCEAN"] = data_original["SAT_N_OCEAN"]


random_vars = ["latent"]
# Seed one generator for all random variables. Re-seeding inside the loop would
# give every random variable the exact same values.
rng = np.random.default_rng(seed=rng_seed)
for random_var in random_vars:
    pipeline.add_random_variable(
        ds=data,
        var_name=random_var,
        random_generator=rng,
        variance=random_variance,
        dim="time",
    )

In [None]:
# Run the Kalman SEM on `data`. This is the slow step of the notebook
# (about 4.5 s per iteration in the recorded run).
kalman_results = pipeline.xarray_Kalman_SEM(
    ds=data,
    nb_iter_SEM=nb_iter_SEM,
    state_variables=state_variables,
    observation_variables=observation_variables,
)

100%|██████████| 50/50 [03:47<00:00,  4.56s/it]


In [None]:
# Convert the SEM output into a per-state dataset. Later cells index it as
# kalman_states[var] and kalman_states.latent, so presumably there is one
# variable per state name -- confirm against pipeline.from_standard_dataset.
kalman_states = pipeline.from_standard_dataset(kalman_results)

In [None]:
# Log-likelihood per SEM iteration (e.g. to check that the SEM has converged).
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
# Attribute names must match those provided by `kalman_results` exactly.
ax.plot(kalman_results.kalman_itteration, kalman_results.log_likelihod)
ax.set_xlabel("Kalman iteration")
ax.set_ylabel("log likelihood")
fig.suptitle("CiCMOD | Loglikelihood Kalman SEM")
# Own file name: "CiCMOD_latent_indices_scatter.png" is used by the latent-vs-
# indices scatter figure, which would overwrite this plot.
save_fig(fig, "CiCMOD_loglikelihood.png", dpi=400)

In [None]:
fig, axs = plt.subplots(
    nrows=len(state_variables),
    ncols=1,
    layout="constrained",
    figsize=(12, 7),
    sharex=True,
    squeeze=False,  # keep a 2-D array so .flatten() also works for one variable
)
axs = axs.flatten()
for ax, var in zip(axs, state_variables):
    # Reconstructed state together with its variance (`prob`), drawn by the
    # project helper; it returns the line handles used below.
    handle1, handle2 = plot_state_with_probability(
        ax=ax,
        x_value=kalman_results.time,
        state=kalman_results.states.sel(state_name=var),
        prob=kalman_results.covariance.sel(state_name=var, state_name_copy=var),
        line_kwargs=dict(label=f"{var} recon."),
        output=True,
    )

    # Draw the reference series in a lightness-adjusted version of the
    # reconstruction colour so each pair is visually linked (the colour was
    # previously computed but never used).
    truth_color = adjust_lightness(handle1[0].get_color())
    # NOTE: for "latent", data[var] is the random series added by
    # pipeline.add_random_variable before the SEM run, not an observed truth.
    ax.plot(
        data.time,
        data[var],
        label=f"{var} truth",
        alpha=0.7,
        linestyle=":",
        color=truth_color,
    )
    ax.set_title(var)
    ax.set_ylabel("value")
    ax.legend()
# x-axes are shared, so only the bottom panel needs the label.
axs[-1].set_xlabel("time in years")
fig.suptitle("CiCMOD | Reconstruction against truth")
save_fig(fig, "CiCMOD_recons_truth.png", dpi=400)

In [None]:
n_cols = len(state_variables)
# squeeze=False keeps `axs` an array even when there is a single state variable.
fig, axs = plt.subplots(
    layout="constrained", figsize=(12, 4), ncols=n_cols, squeeze=False
)
axs = axs.flatten()

for ax, var in zip(axs, state_variables):
    # NOTE: for "latent" the "truth" axis is the random series added before
    # the SEM run, so no agreement is expected there.
    ax.scatter(
        kalman_states[var],
        data[var],
    )
    ax.set_xlabel("reconstruction")
    ax.set_ylabel("truth")
    ax.set_title(var)

fig.suptitle("CiCMOD | Reconstruction against truth")
save_fig(fig, "CiCMOD_recons_truth_scatter.png", dpi=400)

In [None]:
data_vars = data.data_vars
row_col = ncols_nrows_from_N(len(data_vars))

fig, axs = plt.subplots(
    layout="constrained", figsize=(20, 20), squeeze=False, **row_col
)
axs = axs.flatten()

for ax, var in zip(axs, data_vars):
    ax.scatter(
        kalman_states.latent,
        data[var],
        marker=".",
        alpha=0.75,
    )
    # x is the reconstructed latent variable and y the climate index (the labels
    # were previously the other way round).
    ax.set_xlabel("latent")
    ax.set_ylabel(var)
    ax.set_title(var)

# The grid may contain more panels than variables; hide the spare ones.
for ax in axs[len(data_vars):]:
    ax.set_visible(False)

fig.suptitle("CiCMOD | Latent variable against climate indices")
save_fig(fig, "CiCMOD_latent_indices_scatter.png", dpi=400)

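### Compute lagged cross-correlation and covariance

For every variable in `data`, compute the covariance and correlation with the reconstructed latent variable as a function of the lag in years. For a positive lag the latent series is shifted forward in time, i.e. the latent variable leads the index.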

In [None]:
# Lags in years at which every variable is compared with the latent variable.
# np.arange excludes its stop value, hence the +1 for a symmetric -30 ... +30 range.
max_lag_years = 30
lag_years = np.arange(-max_lag_years, max_lag_years + 1, 1)
# The data are stored in monthly form, so a lag of one year is a shift of 12 steps.
steps_per_year = 12
data_vars = data.data_vars
latent = kalman_states.latent


def _lagged_statistic(statistic, series, reference, lags, steps_per_lag, name):
    """Evaluate ``statistic(series, reference.shift(time=lag * steps_per_lag))`` per lag.

    ``reference.shift(time=k)`` moves the values ``k`` steps forward in time
    (padding with NaN), so for a positive lag the reference series leads
    ``series``. Returns a 1-D DataArray along ``lag_years`` called ``name``.
    """
    values = [
        float(
            statistic(series, reference.shift(time=lag * steps_per_lag), dim="time")
        )
        for lag in lags
    ]
    return xr.DataArray(
        values, dims=["lag_years"], coords=dict(lag_years=lags), name=name
    )


# Build each variable's full lag curve at once instead of merging hundreds of
# one-element DataArrays.
data_ccov = xr.Dataset(
    {
        var: _lagged_statistic(
            xr.cov, data[var], latent, lag_years, steps_per_year, var
        )
        for var in data_vars
    }
)
data_ccor = xr.Dataset(
    {
        var: _lagged_statistic(
            xr.corr, data[var], latent, lag_years, steps_per_year, var
        )
        for var in data_vars
    }
)

In [None]:
data_vars = state_variables

fig, axs = plt.subplots(
    layout="constrained",
    figsize=(12, 4),
    sharex=True,
    sharey=True,
    ncols=len(data_vars),
    squeeze=False,  # keep an array even for a single variable
)
axs = axs.flatten()

for ax, var in zip(axs, data_vars):
    ax.step(data_ccor.lag_years, data_ccor[var], label="cor")
    ax.set_xlabel("lag in years")
    ax.set_ylabel("correlation")
    ax.set_title(var)

# y-axes are shared, so setting one panel makes every panel symmetric around 0.
extend = np.max(np.abs(axs[0].get_ylim()))
axs[0].set_ylim((-extend, extend))

fig.suptitle("CiCMOD | Lagged correlation of latent variable against state variables")
save_fig(fig, "CiCMOD_latent_states_lagged_corr.png", dpi=400)

In [None]:
data_vars = data.data_vars
row_col = ncols_nrows_from_N(len(data_vars))

fig, axs = plt.subplots(
    layout="constrained",
    figsize=(20, 20),
    sharex=True,
    sharey=True,
    squeeze=False,
    **row_col,
)
axs = axs.flatten()

for ax, var in zip(axs, data_vars):
    ax.step(data_ccor.lag_years, data_ccor[var], label="cor")
    ax.set_xlabel("lag in years")
    ax.set_ylabel("correlation")
    ax.set_title(var)
    # Every panel carries its own x label, so also show its tick labels. This
    # keeps them visible above spare panels that are hidden below.
    ax.tick_params(labelbottom=True)

# The grid may contain more panels than variables; hide the spare ones.
for ax in axs[len(data_vars):]:
    ax.set_visible(False)

# y-axes are shared, so setting one panel makes every panel symmetric around 0.
extend = np.max(np.abs(axs[0].get_ylim()))
axs[0].set_ylim((-extend, extend))

fig.suptitle("CiCMOD | Lagged correlation of latent variable against climate indices")
save_fig(fig, "CiCMOD_latent_indices_lagged_corr.png", dpi=400)

### Perform frequency analysis on input and output data

In [None]:
from scipy import fftpack
from scipy.ndimage import uniform_filter1d

def do_fft(x, y):
    """Single-sided amplitude spectrum of a regularly sampled series.

    Parameters
    ----------
    x : array-like
        Sample positions (e.g. time in years). They must be evenly spaced; only
        the first two entries are used to get the sample spacing ``dt``.
    y : array-like
        Sample values. NaNs are dropped before the transform and the remaining
        values are treated as contiguous, so NaNs should only occur at the ends
        (e.g. from shifting or rolling means).

    Returns
    -------
    xf : numpy.ndarray
        Frequencies of the first ``N // 2`` FFT bins, in cycles per unit of ``x``.
    yf : numpy.ndarray
        Complex FFT of the NaN-free ``y`` (all ``N`` bins).
    yf_plot : numpy.ndarray
        Single-sided amplitude ``2 / N * |yf|`` of the first ``N // 2`` bins.
    f_min : float
        Lowest non-zero resolvable frequency, ``1 / (N * dt)``.
    f_max : float
        Nyquist frequency, ``1 / (2 * dt)``.
    """
    x = np.asarray(x)
    y = np.asarray(y, dtype=float)
    y = y[~np.isnan(y)]

    # Sample spacing as a plain float (x may be an xarray coordinate).
    dt = float(x[1] - x[0])
    n_samples = len(y)

    yf = fftpack.fft(y)
    # Exact bin frequencies k / (N * dt); linspace(0, 1 / (2 dt), N // 2) is
    # slightly stretched and misplaces every peak.
    xf = fftpack.fftfreq(n_samples, d=dt)[: n_samples // 2]
    yf_plot = 2.0 / n_samples * np.abs(yf[: n_samples // 2])
    f_min = 1.0 / (n_samples * dt)
    f_max = 1.0 / (2.0 * dt)
    return xf, yf, yf_plot, f_min, f_max

In [None]:
data_vars = state_variables
# Longest period (in years) shown on the frequency axis.
max_period_years = 100


def _period_tick_label(freq, _pos):
    """Format a frequency tick (cycles per year) as the matching period in years."""
    return f"{1 / freq:g}" if freq > 0 else ""


fig, axs = plt.subplots(
    layout="constrained",
    figsize=(12, 4),
    sharex=True,
    sharey=True,
    ncols=len(data_vars),
    squeeze=False,  # keep an array even for a single variable
)
axs = axs.flatten()

for ax, var in zip(axs, data_vars):
    for source, label in ((data, "truth"), (kalman_states, "reconst.")):
        xf, yf, yf_plot, f_min, f_max = do_fft(x=source.time, y=source[var].values)
        # Skip the zero-frequency (mean) bin: it cannot be drawn on a log axis.
        ax.loglog(xf[1:], yf_plot[1:], label=label, alpha=0.7)

    # Keep frequency on the axis but label each tick with its period. A
    # formatter stays correct when the limits change (fixed labels set from
    # get_xticks() do not) and never divides by a zero tick.
    ax.xaxis.set_major_formatter(FuncFormatter(_period_tick_label))
    ax.xaxis.set_minor_formatter(NullFormatter())
    ax.set_xlim((1 / max_period_years, f_max))

    ax.set_xlabel("Period in years")
    # do_fft returns 2/N * |FFT|: a single-sided amplitude spectrum in the units
    # of the variable, not a power spectrum.
    ax.set_ylabel("Amplitude")
    ax.set_title(var)
    ax.legend()

fig.suptitle("CiCMOD | Frequency spectrum state variables")
save_fig(fig, "CiCMOD_fft_states.png", dpi=400)

  ax.set_xticks(ticks=xticks, labels= 1/xticks)
  ax.set_xticks(ticks=xticks, labels= 1/xticks)
  ax.set_xticks(ticks=x_ticks, labels=1/x_ticks)
