# Test harmonic fitting for the seasonal cycle

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import pathlib
import cmocean
import os
import copy

# Import custom modules
import src.utils
import src.XRO

## plotting style: white axes background, no grid lines
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## higher resolution for inline figures
mpl.rcParams.update({"figure.dpi": 100})

## data/output directories come from environment variables
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[var]) for var in ("DATA_FP", "SAVE_FP"))

## Load

In [None]:
## spatial anomaly fields (the first return value is not needed here)
_, anom = src.utils.load_consolidated()

## T/h index time series; keep only the variables used below and merge
Th = src.utils.load_cesm_indices()
anom = xr.merge(
    [
        anom[["sst", "sst_comp", "T", "T_comp"]],
        Th[["T_3", "h_w"]],
    ]
)

## Differentiate

In [None]:
## compute time tendencies of sst and T_3 with get_ddt, stored as ddt_<name>
## NOTE: no unit conversion happens here; later cells multiply by 1/12
## where a per-month rate is wanted
for n in ["sst", "T_3"]:
    anom[f"ddt_{n}"] = src.utils.get_ddt(anom[[n]])[f"ddt_{n}"]

## Early/late split

In [None]:
## early (1851-1880) and late (2071-2100) 30-year windows
t_early = {"time": slice("1851", "1880")}
t_late = {"time": slice("2071", "2100")}

## subset the merged data to each window and load it into memory
anom_early, anom_late = (anom.sel(t).compute() for t in (t_early, t_late))

## Testing

### Scalar

#### Fit

In [None]:
## number of harmonics used for the seasonal cycle
ac_order = 4

## XRO model, used as a reference fit
model = src.XRO.XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## per-period results, appended in order (early, late)
R_recons, params, R_recons_bymonth = [], [], []

## options for the three fitting approaches
custom_kwargs = dict(
    y_vars=["ddt_T_3"],
    x_vars=["T_3", "h_w"],
    max_order=ac_order,
)
custom_kwargs_bymonth = dict(y_var="ddt_T_3", x_vars=["T_3", "h_w"])
xro_kwargs = dict(ac_mask_idx=None, maskNT=[])

## fit each period three ways
for x in (anom_early, anom_late):
    # harmonic regression (custom); keep the ell=0 entry of the ddt_T_3 output
    R_recon = src.utils.regress_harm_wrapper(x, **custom_kwargs)
    R_recons.append(R_recon["ddt_T_3"].isel(ell=0))

    # month-by-month regression (older custom approach); keep the T_3 output
    R_recon_bymonth = src.utils.regress_bymonth(x, **custom_kwargs_bymonth)
    R_recons_bymonth.append(R_recon_bymonth["T_3"])

    # XRO fit on (T_3, h_w), then extract its RO parameters
    fit = model.fit_matrix(x[["T_3", "h_w"]], **xro_kwargs)
    params.append(model.get_RO_parameters(fit))

#### Plot

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

## left: harmonic-regression estimates; right: month-by-month estimates.
## Each panel draws the early period first, then the late period.
for i in (0, 1):
    axs[0].plot(np.arange(1, 13), R_recons[i])
for i in (0, 1):
    axs[1].plot(np.arange(1, 13), R_recons_bymonth[i])

## overlay the XRO R (dashed; early black, late gray), offset by -0.5 on x,
## plus thin vertical reference lines at x=5 and x=6
for ax in axs:
    ax.plot(params[0].cycle - 0.5, params[0]["R"], ls="--", c="k", alpha=0.8)
    ax.plot(params[1].cycle - 0.5, params[1]["R"], ls="--", c="gray")
    for x_ref in (5, 6):
        ax.axvline(x_ref, lw=0.6)

## give the left panel the right panel's y-limits
axs[0].set_ylim(axs[1].get_ylim())
plt.show()

### Spatial

#### Fit

In [None]:
## harmonic regression of ddt_sst on (T_3, h_w) for the late period,
## presumably applied at each spatial point by regress_xr
coefs_spatial = src.utils.regress_xr(
    data=anom_late,
    y_vars=["ddt_sst"],
    x_vars=["T_3", "h_w"],
    helper_fn=src.utils.regress_harm_wrapper,
    max_order=ac_order,
)

## keep the ell=0 entry, scaled by 1/12 (presumably per-year -> per-month)
R_recon_spatial = coefs_spatial["ddt_sst"].isel(ell=0) * (1 / 12)

#### Plot

In [None]:
## arguments shared by every panel
kwargs = {"amp": 0.5, "lat_bound": 5}

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## one hovmöller per panel
## NOTE(review): all three panels currently show the same field
cp0, cp1, cp2 = (
    src.utils.make_cycle_hov(ax, data=R_recon_spatial, **kwargs) for ax in axs
)

## shared colorbar beside the last panel, ticks at -amp, 0, +amp
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)
src.utils.format_hov_axs(axs)

## dashed horizontal reference line at y=7 on every panel
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

## Test gradient

In [None]:
def ddt(x, order=2):
    """Finite-difference time derivative of ``x`` along ``"time"``, times 12.

    Args:
        x: xarray object with a ``"time"`` dimension.
        order: 2 for a centered difference, evaluated at interior time
            steps only; 1 for a forward difference, stamped at the earlier
            of each pair of time steps (the last step is dropped).

    Returns:
        DataArray with ``"time"`` as its leading dimension, multiplied by
        12 (presumably converting per-step differences of monthly data to
        per-year rates -- confirm against the data's time step).

    Raises:
        ValueError: if ``order`` is not 1 or 2.
    """
    if order not in (1, 2):
        raise ValueError(f"order must be 1 or 2, got {order!r}")

    ## put time first so positional slicing on axis 0 is along time
    x_ = x.transpose("time", ...)
    vals = x_.values

    if order == 2:
        # centered difference: 0.5 * (x[t+1] - x[t-1]) on interior points
        ddt_x = x_.isel(time=slice(1, -1)).copy(data=0.5 * (vals[2:] - vals[:-2]))
    else:
        # forward difference: x[t+1] - x[t], labelled at time t
        ddt_x = x_.isel(time=slice(None, -1)).copy(data=vals[1:] - vals[:-1])

    return ddt_x * 12

In [None]:
## library derivative: default settings, then with is_forward=False
ddt0, ddt0_ = (
    src.utils.get_ddt(Th[["T_3"]], **opts)["ddt_T_3"]
    for opts in ({}, {"is_forward": False})
)

## local finite differences: first order (forward), then second order (centered)
ddt1, ddt2 = (ddt(Th["T_3"], order=k) for k in (1, 2))

## Test time stepping

In [None]:
def fun_2d(x):
    """Mean of ``x`` over latitude in [-1.5, 1.5] and longitude in [210, 270]."""
    box = x.sel(latitude=slice(-1.5, 1.5), longitude=slice(210, 270))
    return box.mean(["latitude", "longitude"])


def fun_1d(x):
    """Longitude mean over [210, 270] of the first z_t level, length-1 dims dropped."""
    top = x.sel(longitude=slice(210, 270)).isel(z_t=0)
    return top.mean("longitude").squeeze(drop=True)


## rebuild each early-period field from its scores/components, reduced by
## the matching function above
sst_recon = src.utils.reconstruct_fn(
    scores=anom_early["sst"],
    components=anom_early["sst_comp"],
    fn=fun_2d,
)

T_recon = src.utils.reconstruct_fn(
    scores=anom_early["T"],
    components=anom_early["T_comp"],
    fn=fun_1d,
)

In [None]:
## member 9, years 1851-1854
plot_idx = {"time": slice("1851", "1854"), "member": 9}


def shift(x):
    """Drop the first time step of ``x``."""
    return x.isel(dict(time=slice(1, None)))


fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(sst_recon.sel(plot_idx))  # reconstructed sst, reduced with fun_2d
ax.plot(T_recon.sel(plot_idx), ls="--")  # reconstructed T, reduced with fun_1d
## alternative overlays, toggle as needed:
# ax.plot(shift(T_recon.sel(plot_idx)), ls="--")
# ax.plot(anom_early["T_3"].sel(plot_idx))
plt.show()

## Test $\frac{d}{dt}$ function

In [None]:
## budget anomalies (first return value unused), early period loaded into memory
_, budg_anom = src.utils.load_budget_data()
budg_early = budg_anom.sel(**t_early).compute()

In [None]:
## budget TEND_TEMP term, reduced the same way as T_recon
adv0 = fun_1d(budg_early["TEND_TEMP"])

## get_ddt tendency of the reconstructed T, scaled by 1/12
adv1 = 1 / 12 * src.utils.get_ddt(T_recon.to_dataset(name="T"))["ddt_T"]

In [None]:
## manual centered difference of T_recon, 0.5 * (T[t+1] - T[t-1]), on
## interior time steps (no x12 scaling). Slice by dimension name rather than
## assuming time is axis 1 of the underlying array.
adv3 = T_recon.isel(time=slice(1, -1)).copy(
    data=0.5
    * (
        T_recon.isel(time=slice(2, None)).values
        - T_recon.isel(time=slice(None, -2)).values
    )
)

In [None]:
## member 0, years 1851-1854
plot_idx = {"time": slice("1851", "1854"), "member": 0}


def shift(x):
    """Drop the first time step of ``x``."""
    return x.isel(dict(time=slice(1, None)))


fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(adv0.sel(plot_idx))  # budget TEND_TEMP
ax.plot(adv1.sel(plot_idx), ls="-")  # get_ddt, scaled by 1/12
ax.plot(adv3.sel(plot_idx), ls="--")  # manual centered difference
plt.show()