In [None]:
from os import PathLike
from pathlib import Path
from typing import Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import xarray as xr
from matplotlib.figure import Figure

# from sklearn.linear_model import LinearRegression
from kalman_reconstruction import pipeline
from kalman_reconstruction.custom_plot import (
    adjust_lightness,
    plot_colors,
    plot_state_with_probability,
    set_custom_rcParams,
)
from kalman_reconstruction.statistics import normalize
from reconstruct_climate_indices.statistics import xarray_dataset_welch

In [None]:
# Colour theme. To switch to the light theme, replace the active DARK THEME block
# with these settings:
#   plt.style.use("seaborn-v0_8-whitegrid")
#   dark_color, light_color = [0.3, 0.3, 0.3], [0.8, 0.8, 0.8]
#   lightness_0, lightness_1 = 0.75, 0.5
#   cmap, cmap_r = "rocket", "rocket_r"

# DARK THEME (active)
plt.style.use("dark_background")
dark_color, light_color = [0.7, 0.7, 0.7], [0.2, 0.2, 0.2]
lightness_0, lightness_1 = 1.15, 1.5
cmap, cmap_r = "rocket_r", "rocket"

# Project palette, then grid tweaks applied on top of the chosen style.
colors = set_custom_rcParams()
plt.rcParams.update({"grid.alpha": 0.5, "axes.grid": True})

plot_colors(colors)

# Fixed colour per variable so all figures share one mapping.
# "log_likelihod" (sic) is the variable name used in the Kalman result dataset.
variables_color = {
    "NAO_ST": colors[0],
    "AMO": colors[2],
    "latent": colors[1],
    "latent_2": colors[4],
    "log_likelihod": colors[-1],
}

In [None]:
# Repository root: two levels above the current working directory.
REPO_PATH = Path(".").resolve().parent.parent
# All figures of this notebook are written below this directory.
results_path = REPO_PATH.joinpath("results", "earth_system_models", "CMIP6", "MIROC6")
results_path.mkdir(parents=True, exist_ok=True)
# Set to False to render figures without writing them to disk.
SAVE_FIGURES = True


def save_fig(fig: Figure, relative_path: PathLike, kwargs: Optional[Dict] = None):
    """Save ``fig`` below ``results_path`` if ``SAVE_FIGURES`` is enabled.

    Parameters
    ----------
    fig : Figure
        Object with a matplotlib-style ``savefig`` method (a ``Figure``; a seaborn
        grid such as ``PairGrid`` also works).
    relative_path : PathLike
        Destination relative to ``results_path``; its suffix selects the file format.
    kwargs : dict, optional
        Extra keyword arguments forwarded to ``fig.savefig`` (e.g. ``dpi``).
        ``None`` (default) forwards nothing; a fresh dict is used on every call, so
        no mutable default is shared between calls.
    """
    if not SAVE_FIGURES:
        return
    store_path = results_path / relative_path
    # Create intermediate directories only when a file is actually written.
    store_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(store_path, **(kwargs or {}))

In [None]:
# Pre-processed MIROC6 (CMIP6) ensemble indices.
DATA_PATH = REPO_PATH.joinpath("data", "earth_system_models", "cmip6", "miroc6")
AMO, NAO = (
    xr.open_dataset(DATA_PATH / file_name)
    for file_name in ("AMO_anomalies_MIROC.nc", "NAO_ensemble_norm_MIROC.nc")
)
AMO

In [None]:
# Harmonise variable names and combine both indices into one dataset.
AMO = AMO.rename(dict(tos="AMO"))
NAO = NAO.rename(dict(psl="NAO_ST"))

miroc6 = xr.merge([AMO, NAO])
# Attach `member` explicitly as a coordinate (a no-op if the files already provide it).
miroc6 = miroc6.assign_coords(member=miroc6.member)

# Observed variables in reverse `data_vars` order, i.e. ["NAO_ST", "AMO"].
state_variables = list(miroc6.data_vars)[::-1]

# Single latent initialisation named "latent". Later cells read `miroc6["latent"]`
# (reference spectrum of the initial latent) and, via the ensemble mean,
# `smooth_miroc6["latent"]`; without this variable those cells raise KeyError.
# NOTE(review): seed 12347654 equals the seed of `latent1` below (12347653 + 1);
# confirm it matches the initialisation of the loaded Kalman run.
pipeline.add_random_variable(
    ds=miroc6,
    var_name="latent",
    random_generator=np.random.default_rng(seed=12347654),
    variance=1,
    dim="time",
)

# Further latent initialisations latent1..latentN, each with its own seed.
number_latent = 3
seed = 12347653
latent_names = [f"latent{i}" for i in range(1, number_latent + 1)]
for i, latent_name in enumerate(latent_names, start=1):
    pipeline.add_random_variable(
        ds=miroc6,
        random_generator=np.random.default_rng(seed=seed + i),
        var_name=latent_name,
        variance=1,
        dim="time",
    )

state_variables += latent_names
state_variables

### Run Kalman-SEM or load a precomputed result

In [None]:
# Option A (slow, disabled): run Kalman-SEM separately for every ensemble member.
# subdataset_selections = [dict(member=idx) for idx in miroc6.member.values]
# kalman_full_miroc6 = pipeline.run_function_on_multiple_subdatasets(
#     processing_function=pipeline.xarray_Kalman_SEM,
#     parent_dataset=miroc6,
#     subdataset_selections=subdataset_selections,
#     func_kwargs=dict(
#         observation_variables=["AMO", "NAO_ST"],
#         state_variables=state_variables,
#         nb_iter_SEM=50,
#     ),
# )

# Option B (active): load the stored result of a previous run.
kalman_file = (
    REPO_PATH
    / "data/earth_system_models/processed/thoughtful-finch-610/thoughtful-finch-610_kalman.nc"
)
kalman_result = xr.open_dataset(kalman_file)
# Presumably splits the "states" variable into one data variable per state;
# later cells index e.g. kalman_states["AMO"] and kalman_states["latent"].
kalman_states = pipeline.from_standard_dataset(kalman_result, var_name="states")

#### Create a test sine signal for the frequency analysis

In [None]:
# Synthetic sine with a known 6-year period. Its PSD peak should sit at 1/6 per year,
# which checks the sampling frequency used in the Welch analysis.
period_in_years = 6
one_year = pd.Timedelta(value=365.25, unit="days")
# Scale by 1/(2*pi) so that sin(elapsed / period) completes one cycle per `period_in_years`.
period = period_in_years * one_year / (2 * np.pi)
elapsed = miroc6.time - miroc6.time[0].values
miroc6["sinus"] = np.sin(elapsed / period)
miroc6["sinus"].plot()
# Number of NaN samples in the test signal.
np.isnan(miroc6["sinus"]).sum()

### Analysis

#### Prepare Fourier analysis

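**Note: the time axis is not evenly spaced** (the time step varies between samples, see the plot below).
The Welch analysis assumes uniform sampling, so the mean time step is used to define the sampling frequency `fs`. **TODO: handle the uneven spacing properly.**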

In [None]:
# Spacing to the next time stamp in days (the last entry has no successor and is NaN).
dts = (miroc6.time.shift(time=-1) - miroc6.time) / np.timedelta64(1, "D")

fig, ax = plt.subplots(1, 1)
# First 5 * 12 samples only (five years if the data are monthly).
first_samples = dts.isel(time=slice(0, 5 * 12))
first_samples.plot(ax=ax, linestyle="", marker="o", markersize="10")
ax.set_ylabel("Time-Step in days")
ax.set_xlabel("Date")
for suffix in ("png", "svg"):
    save_fig(fig, relative_path=f"Time-Step.{suffix}")

# Sampling frequency in samples per year, based on the mean time step.
fs = (365.25 / dts.mean()).values

In [None]:
# One boxcar segment covering the whole series (no segment averaging).
welch_kwargs = dict(fs=fs, window="boxcar", nperseg=len(miroc6.time))

miroc6_freq, kalman_states_freq = (
    xarray_dataset_welch(ds=dataset, dim="time", welch_kwargs=welch_kwargs)
    for dataset in (miroc6, kalman_states)
)

In [None]:
# The PSD of the test sine should peak at the theoretical frequency 1 / period_in_years.
fig, ax = plt.subplots()
ax.loglog(miroc6_freq.frequency, miroc6_freq.sinus, linestyle="-", label="result of FFT")
ax.axvline(
    1 / period_in_years, color=dark_color, label="theoretical value of sine period"
)
ax.legend(loc="lower left")
ax.set_ylim(10 ** (-8), 10 ** (2.5))

(1e-08, 316.22776601683796)

#### Plot log-likelihood

In [None]:
def _draw_member_lines(ax, da, annotate_all, highlighted):
    """Draw ensemble mean +/- 2 std, every member in grey and ``highlighted`` members in red.

    ``da`` must carry ``member`` and ``kalman_iteration`` dimensions. With
    ``annotate_all`` every grey line is labelled with its member id at the last iteration;
    highlighted members are always labelled.
    """
    last_iteration = da.kalman_iteration[-1]
    plot_state_with_probability(
        x_value=da["kalman_iteration"],
        state=da.mean(dim="member"),
        prob=da.std(dim="member"),
        stds=2,
        ax=ax,
        line_kwargs=dict(linewidth=3, zorder=10),
    )
    grey = adjust_lightness(dark_color, 0.6)
    for member in da.member:
        series = da.sel(member=member)
        ax.plot(da.kalman_iteration, series, linestyle="-", color=grey, alpha=0.4, zorder=1)
        if annotate_all:
            ax.annotate(
                f"{member.values}",
                xy=(last_iteration, series.isel(kalman_iteration=-1)),
            )
    for member in highlighted:
        series = da.sel(member=member)
        ax.plot(da.kalman_iteration, series, linestyle="-", color="r", alpha=1, zorder=1)
        ax.annotate(
            f"{member}",
            xy=(last_iteration, series.isel(kalman_iteration=-1)),
            color="r",
        )


var = "log_likelihod"
data = kalman_result[var]
# Gain relative to the first iteration makes members with different offsets comparable.
data_diff = data - data.isel(kalman_iteration=0)
fig, axs = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(12, 7))

highlight_members = [4, 42]
_draw_member_lines(axs[0], data, annotate_all=False, highlighted=highlight_members)
_draw_member_lines(axs[1], data_diff, annotate_all=True, highlighted=highlight_members)

panel_titles = ("Log Likelihood", "Log Likelihood difference to initialization")
for ax, title in zip(axs, panel_titles):
    ax.set_title(title)
    ax.set_ylabel(title)
    ax.set_xlabel("Kalman iteration")

save_fig(fig, relative_path="Loglikelihood.svg")
save_fig(fig, relative_path="Loglikelihood.png", kwargs=dict(dpi=256))

#### Plot frequency analysis

In [None]:
def _median_with_quantile_band(ax, da, color, label):
    """Plot the member median of a spectrum as a thick line, shading its 0.1-0.9 member quantiles."""
    ax.loglog(
        da.frequency,
        da.median("member"),
        linestyle="-",
        color=color,
        zorder=5,
        label=label,
        linewidth=3,
    )
    ax.fill_between(
        da.frequency,
        da.quantile(q=0.1, dim="member"),
        da.quantile(q=0.9, dim="member"),
        color=color,
        alpha=0.3,
        zorder=3,
    )


fig, axs = plt.subplots(
    nrows=2, ncols=1, sharex=True, sharey=True, figsize=(12, 9), layout="constrained"
)

# Top panel: ensemble spread of every Kalman state.
for var in ["AMO", "NAO_ST", "latent"]:
    data = kalman_states_freq[var]
    _median_with_quantile_band(axs[0], data, variables_color[var], var)
axs[0].legend(loc="lower left")

# Bottom panel: latent state only, each member as a thin grey line underneath.
var = "latent"
data = kalman_states_freq[var]
for mem in data.member:
    axs[1].loglog(
        data.frequency,
        data.sel(member=mem),
        linestyle="-",
        color=adjust_lightness(dark_color, 0.6),
        alpha=0.2,
        zorder=1,
    )
_median_with_quantile_band(axs[1], data, variables_color[var], var)
# Spectrum of the random initialisation of the latent variable, for comparison.
axs[1].loglog(
    miroc6_freq.frequency,
    miroc6_freq[var],
    linestyle="-",
    color=variables_color["latent_2"],
    zorder=2,
    alpha=0.75,
    linewidth=3,
    label=f"{var} init.",
)
axs[1].legend(loc="lower left")

axs[0].set_title("PSD spectra | Median as thick line | 0.1 to 0.9 quantile shaded")
axs[1].set_title("As above for latent | Individual members as thin lines")

for ax in axs:
    ax.set_ylim(ymin=10 ** (-5.5), ymax=10 ** (0.5))
    ax.set_ylabel(r"PSD in $V^2 y^{-1}$")
    ax.set_xlabel(r"Frequency in $y^{-1}$")

save_fig(fig, relative_path="PSD_spectra.svg")
save_fig(fig, relative_path="PSD_spectra.png", kwargs=dict(dpi=256))

#### Evolution of the latent variable

In [None]:
member_idx = 42
current_ds = kalman_states.sel(member=member_idx)

fig, axs = plt.subplots(
    nrows=3, ncols=1, sharex=True, sharey=False, figsize=(12, 6), layout="constrained"
)
axs[0].set_title(f"Evolution of states | member: {member_idx}")

# One panel per state; the y-label is coloured like its line.
for ax, var in zip(axs, ["NAO_ST", "AMO", "latent"]):
    var_color = variables_color[var]
    ax.plot(current_ds.time, current_ds[var], label=var, color=var_color)
    ax.set_ylabel(var, color=var_color)

# Clip all panels to the member's exact time span.
t_start, t_end = current_ds.time.min(), current_ds.time.max()
for ax in axs:
    ax.set_xlim(xmin=t_start, xmax=t_end)

for suffix, extra in (("svg", {}), ("png", dict(dpi=256))):
    save_fig(fig, relative_path=f"Evolution_{member_idx}.{suffix}", kwargs=extra)

### Smooth the time series in frequency space

In [None]:
def _median_spectrum_signal(da):
    """Inverse real FFT of the member-wise median of the real FFT of ``da``.

    ``da`` is transposed to (member, time) first, so the FFT always runs along time
    and the median across members, whatever dimension order the file stores.
    ``n`` is passed to ``irfft`` so the result has exactly one value per time step;
    without it an odd-length series comes back one sample short and cannot be
    assigned back onto the time axis.
    NOTE(review): numpy orders complex numbers lexicographically (real part, then
    imaginary part), so this is not a component-wise median.
    """
    values = da.transpose("member", "time").values
    spectrum = sp.fft.rfft(values, axis=1)
    return sp.fft.irfft(np.median(spectrum, axis=0), n=da.sizes["time"])


smooth_AMO = _median_spectrum_signal(miroc6.AMO)
smooth_NAO_ST = _median_spectrum_signal(miroc6.NAO_ST)

# The ensemble mean serves as a template (coordinates, other variables); the observed
# indices are replaced by their frequency-smoothed counterparts.
smooth_miroc6 = miroc6.mean("member").copy()
smooth_miroc6["AMO"] = smooth_miroc6["AMO"].copy(data=smooth_AMO)
smooth_miroc6["NAO_ST"] = smooth_miroc6["NAO_ST"].copy(data=smooth_NAO_ST)

In [None]:
current_ds = smooth_miroc6

fig, axs = plt.subplots(
    nrows=3, ncols=1, sharex=True, sharey=False, figsize=(12, 6), layout="constrained"
)
axs[0].set_title("Evolution of states | inverse DFFT of median over members of DFFT")

# One panel per variable of the smoothed input dataset.
for ax, var in zip(axs, ["NAO_ST", "AMO", "latent"]):
    var_color = variables_color[var]
    ax.plot(current_ds.time, current_ds[var], label=var, color=var_color)
    ax.set_ylabel(var, color=var_color)

t_start, t_end = current_ds.time.min(), current_ds.time.max()
for ax in axs:
    ax.set_xlim(xmin=t_start, xmax=t_end)

# Not written to disk: the previously drafted file names ("Evolution_{member_idx}.*")
# are the same as those of the member figure and would overwrite it.

NameError: name 'plt' is not defined

In [None]:
# Kalman-SEM on the frequency-smoothed signal. The following cells (log-likelihood,
# smoothed states and their spectra) read `smooth_kalman_result`, so this cell must
# run for the notebook to work on a fresh kernel.
if "latent" not in smooth_miroc6:
    # The state vector below needs a "latent" initialisation in the input dataset.
    # NOTE(review): seed 12347654 is an assumption; confirm which initialisation is intended.
    pipeline.add_random_variable(
        ds=smooth_miroc6,
        var_name="latent",
        random_generator=np.random.default_rng(seed=12347654),
        variance=1,
        dim="time",
    )

smooth_kalman_result = pipeline.xarray_Kalman_SEM(
    ds=smooth_miroc6,
    observation_variables=["AMO", "NAO_ST"],
    state_variables=["AMO", "NAO_ST", "latent"],
    nb_iter_SEM=50,
)

In [None]:
smooth_log_likelihood = smooth_kalman_result["log_likelihod"]
# Log-likelihood gain relative to the first Kalman iteration.
smooth_gain = smooth_log_likelihood - smooth_log_likelihood.isel(kalman_iteration=0)
smooth_gain.plot()

NameError: name 'smooth_kalman_result' is not defined

In [None]:
# Same unpacking as for `kalman_states` of the ensemble run: presumably splits the
# "states" variable of the smoothed Kalman-SEM result into one variable per state.
smooth_kalman_states = pipeline.from_standard_dataset(
    smooth_kalman_result, var_name="states"
)

In [None]:
fig, axs = plt.subplots(
    nrows=3, ncols=1, sharex=True, sharey=False, figsize=(12, 6), layout="constrained"
)

# The figure shows the final Kalman-SEM states of the smoothed run, so plot
# `smooth_kalman_states` (the filter output), not the smoothed input `smooth_miroc6`.
current_ds = smooth_kalman_states
axs[0].set_title("Evolution of states | Final state using frequency smoothed")
for idx, var in enumerate(["NAO_ST", "AMO", "latent"]):
    axs[idx].plot(
        current_ds.time, current_ds[var], label=var, color=variables_color[var]
    )
    axs[idx].set_ylabel(var, color=variables_color[var])

for ax in axs.flatten():
    ax.set_xlim(
        xmin=current_ds.time.min(),
        xmax=current_ds.time.max(),
    )

# Not written to disk: the previously drafted file names ("Evolution_{member_idx}.*")
# are the same as those of the member figure and would overwrite it.

In [None]:
# Same Welch settings as for the ensemble: one boxcar segment over the full series.
welch_kwargs = dict(fs=fs, window="boxcar", nperseg=len(miroc6.time))

smooth_miroc6_freq, smooth_kalman_states_freq = (
    xarray_dataset_welch(ds=dataset, dim="time", welch_kwargs=welch_kwargs)
    for dataset in (smooth_miroc6, smooth_kalman_states)
)

#### Plot frequency analysis

In [None]:
fig, ax = plt.subplots(
    nrows=1, ncols=1, sharex=True, sharey=True, figsize=(12, 6), layout="constrained"
)

# PSD of each state of the smoothed Kalman-SEM run.
for var in ("AMO", "NAO_ST", "latent"):
    data = smooth_kalman_states_freq[var]
    ax.loglog(
        data.frequency,
        data,
        linestyle="-",
        color=variables_color[var],
        zorder=5,
        label=var,
        linewidth=3,
    )

# PSD of the latent initialisation in the smoothed input, drawn underneath for reference.
var = "latent"
init_psd = smooth_miroc6_freq[var]
ax.loglog(
    init_psd.frequency,
    init_psd,
    linestyle="-",
    color=variables_color["latent_2"],
    zorder=2,
    alpha=0.75,
    linewidth=3,
    label=f"{var} init.",
)

ax.legend(loc="lower left")
ax.set_ylim(ymin=10 ** (-6.5), ymax=10 ** (0.5))
ax.set_ylabel(r"PSD in $V^2 y^{-1}$")
# Not written to disk: the previously drafted file names ("PSD_spectra.*") match the
# ensemble PSD figure and would overwrite it.
ax.set_xlabel(r"Frequency in $y^{-1}$")

Text(0.5, 0, 'Frequency in $y^{-1}$')

### Member 42: analysis with multiple latent variables

In [None]:
plot_colors(colors)

# Distinct colours for the three latent variables of the member-42 analysis.
variables_color.update(latent1=colors[1], latent2=colors[3], latent3=colors[5])

In [None]:
# Single-member dataset with its own set of latent initialisations.
ds_member42 = miroc6.sel(member=42).copy()
number_latent = 3
seed = 903298487326
latent_names = [f"latent{i}" for i in range(1, number_latent + 1)]
for offset, latent_name in enumerate(latent_names, start=1):
    pipeline.add_random_variable(
        ds=ds_member42,
        random_generator=np.random.default_rng(seed=seed + offset),
        var_name=latent_name,
        variance=1,
        dim="time",
    )

state_variables = ["NAO_ST", "AMO"] + latent_names

In [None]:
# Kalman-SEM for member 42 alone: AMO and NAO_ST observed, plus the latent states
# defined above; `variance_obs_comp` is forwarded unchanged to the pipeline.
sem_settings = dict(
    observation_variables=["AMO", "NAO_ST"],
    state_variables=state_variables,
    nb_iter_SEM=50,
    variance_obs_comp=0.001,
)
member42_kalman_result = pipeline.xarray_Kalman_SEM(ds=ds_member42, **sem_settings)

100%|██████████| 50/50 [00:36<00:00,  1.38it/s]


In [None]:
# NOTE(review): unlike the other `from_standard_dataset` calls in this notebook,
# `var_name="states"` is not passed here; presumably that is the default — confirm.
member42_kalman_states = pipeline.from_standard_dataset(member42_kalman_result)

In [None]:
# One panel per state variable; figure height grows with the number of panels.
height = int(6 / 4 * len(state_variables))
fig, axs = plt.subplots(
    nrows=len(state_variables),
    ncols=1,
    sharex=True,
    sharey=False,
    figsize=(12, height),
    layout="constrained",
)

# NOTE(review): the first two time steps are skipped, presumably to hide the filter's
# initialisation transient — confirm.
current_ds = member42_kalman_states.isel(time=slice(2, None))
# This run uses the unsmoothed member-42 data with several latent variables.
axs[0].set_title(
    f"Evolution of states | member: 42 | {number_latent} latent variables"
)
for idx, var in enumerate(state_variables):
    axs[idx].plot(
        current_ds.time, current_ds[var], label=var, color=variables_color[var]
    )
    axs[idx].set_ylabel(var, color=variables_color[var])

for ax in axs.flatten():
    ax.set_xlim(
        xmin=current_ds.time.min(),
        xmax=current_ds.time.max(),
    )

save_fig(fig, relative_path=f"Evolution_member_42_latent{number_latent}.svg")
save_fig(
    fig,
    relative_path=f"Evolution_member_42_latent{number_latent}.png",
    kwargs=dict(dpi=256),
)

In [None]:
# Same Welch settings as for the ensemble: one boxcar segment over the full series.
welch_kwargs = dict(fs=fs, window="boxcar", nperseg=len(miroc6.time))

member42_kalman_states_freq, ds_member42_freq = (
    xarray_dataset_welch(ds=dataset, dim="time", welch_kwargs=welch_kwargs)
    for dataset in (member42_kalman_states, ds_member42)
)

In [None]:
fig, ax = plt.subplots(
    nrows=1, ncols=1, sharex=True, sharey=True, figsize=(12, 6), layout="constrained"
)

# PSD of every state of the member-42 Kalman-SEM run.
current_ds = member42_kalman_states_freq
for var in ("AMO", "NAO_ST", "latent1", "latent2", "latent3"):
    data = current_ds[var]
    ax.loglog(
        data.frequency,
        data,
        linestyle="-",
        color=variables_color[var],
        zorder=5,
        label=var,
        linewidth=3,
        alpha=0.7,
    )

ax.legend(loc="lower left")
ax.set_ylim(ymin=10 ** (-6.5), ymax=10 ** (0.5))
ax.set_ylabel(r"PSD in $V^2 y^{-1}$")
ax.set_xlabel(r"Frequency in $y^{-1}$")

file_stem = f"PSD_spectra_42_latent{number_latent}"
save_fig(fig, relative_path=f"{file_stem}.svg")
save_fig(fig, relative_path=f"{file_stem}.png", kwargs=dict(dpi=256))

In [None]:
# Table of the normalised member-42 states for seaborn. "member", "kalman_iteration"
# and "state_name_copy" are removed first, presumably so that only the states end up
# as DataFrame columns. `drop_vars` replaces the deprecated `Dataset.drop` for this.
member42_kalman_states_df = normalize(
    member42_kalman_states.drop_vars(["member", "kalman_iteration", "state_name_copy"]),
    method="mean",
).to_dataframe()
member42_kalman_states_df

In [None]:
# Pairwise distributions of the normalised states: histogram + KDE above and on the
# diagonal, filled KDE contours below. Each `map_*` call returns the grid itself.
g = (
    sns.PairGrid(member42_kalman_states_df)
    .map_upper(sns.histplot, kde=True)
    .map_lower(sns.kdeplot, fill=True)
    .map_diag(sns.histplot, kde=True)
)
g

<seaborn.axisgrid.PairGrid at 0x19b68f7bb80>

In [None]:
# Write the pair plot in every required format (vector and raster).
for extension in ("pdf", "svg", "png"):
    save_fig(
        g,
        relative_path=f"3_latent_distribution_42_latent_{number_latent}.{extension}",
    )