# Predictability

## Imports

In [None]:
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import src.XRO
import copy
import scipy.stats
import warnings
import calendar
import gsw

## plotting defaults: white axes background, grid switched off
sns.set(
    rc={
        "axes.facecolor": "white",
        "axes.grid": False,
    }
)

## figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## input/output directories are read from environment variables
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## Load data

In [None]:
## load the indices and swap the long index names for short aliases
Th = src.utils.load_cesm_indices().rename(
    dict(
        north_tropical_atlantic="natl",
        atlantic_nino="nino_atl",
        tropical_indian_ocean="iobm",
        indian_ocean_dipole="iod",
        north_pacific_meridional_mode="npmm",
        south_pacific_meridional_mode="spmm",
    )
)

## divide by the standard deviation (in place)
Th /= Th.std()

### Predictability

1. Fit RO to (i) early and (ii) late data
2. Assess forecast skill in both (compare difference)
3. Perfect model forecasts with RO

#### Fit models

In [None]:
def calc_forecast_skill(fcst_ds, ref_ds, metric="acc", is_mv3=False):
    """Verify forecasts against a reference with a climpred HindcastEnsemble.

    Args:
        fcst_ds: forecast dataset handed to ``climpred.HindcastEnsemble``.
        ref_ds: reference dataset added to the hindcast as observations.
        metric: metric name forwarded to ``HindcastEnsemble.verify``.
        is_mv3: if True, smooth the forecasts (along ``init``) and the
            reference (along ``time``) with a centered 3-step rolling mean
            before verification.

    Returns:
        The result of ``verify``; every data variable except ``"model"`` has
        its encoding set to float32 with a fill value of 1e20.

    NOTE(review): ``climpred`` is not imported in the visible part of this
    file, so this function raises ``NameError`` unless it is imported
    elsewhere -- confirm before use.
    """
    ## squeeze singleton dims and drop the "member" coordinate; xarray raises
    ## ValueError if "member" is absent, in which case the input is kept as-is
    try:
        fcst_ds = fcst_ds.squeeze().drop_vars("member")
    except ValueError:
        pass

    if is_mv3:
        fcst_ds = fcst_ds.rolling(init=3, center=True, min_periods=1).mean("init")
        ref = (
            ref_ds.rolling(time=3, center=True, min_periods=1).mean().dropna(dim="time")
        )
    else:
        ref = ref_ds

    ## hindcast
    hc_XRO = climpred.HindcastEnsemble(fcst_ds)
    hc_XRO = hc_XRO.add_observations(ref)

    ## compute skill
    skill_XRO = hc_XRO.verify(
        metric=metric,
        comparison="e2o",
        alignment="maximize",
        dim=["sample"],
        skipna=True,
        groupby="month",
    )

    ## best-effort cleanup: remove the "skipna" attribute and the "skill"
    ## variable. A missing attribute raises KeyError and a missing variable
    ## raises ValueError; either way the result is left as it is.
    try:
        del skill_XRO.attrs["skipna"]
        skill_XRO = skill_XRO.drop_vars("skill")
    except (KeyError, ValueError):
        pass

    ## set output encoding for everything except the "model" variable
    for var in skill_XRO.data_vars:
        if var != "model":
            skill_XRO[var].encoding["dtype"] = "float32"
            skill_XRO[var].encoding["_FillValue"] = 1e20

    return skill_XRO

In [None]:
## early window: 360 time steps after skipping the first 12
## late window: the final 360 time steps
varnames = ["T_34", "h_w"]
Th0, Th1 = (
    Th[varnames].isel(time=window)
    for window in (slice(12, 372), slice(-360, None))
)

## model specification
model = src.XRO.XRO(ncycle=12, ac_order=3, is_forward=True)

## fit the same model to each window
fit_kwargs = {"ac_mask_idx": None, "maskNT": []}
fit0, fit1 = (model.fit_matrix(period, **fit_kwargs) for period in (Th0, Th1))

#### Make forecasts

Function to get forecast data

In [None]:
def get_forecast_data(
    model, fit, data, max_lead=28, n_members=5, perfect_model=False, save_fp=None
):
    """Get forecasts and verification data for given model, fit, and data.

    If ``save_fp`` points to an existing file, the dataset stored there is
    opened and returned. Otherwise forecasts are initialized from every
    (member, time) sample of ``data`` and merged with the matching
    verification values at each lead.

    Args:
        model: object providing ``simulate`` and ``reforecast``.
        fit: fitted parameters, passed to the model as ``fit_ds``.
        data: dataset with ``member`` and ``time`` dimensions.
        max_lead: number of forecast months (``n_month``); the same number of
            trailing initialization times is trimmed so every lead has a
            verification value.
        n_members: number of members of ``data`` to use (and number of
            simulated copies when ``perfect_model`` is True).
        perfect_model: if True, replace ``data`` with a stochastic simulation
            from ``model``, started from the first member/time of ``data``.
        save_fp: optional path of a cached netCDF file; ``None`` skips the
            cache lookup.

    Returns:
        Dataset with forecasts (variables suffixed ``_hat``) and verification
        data (original variable names).
    """

    ## reuse cached results when a file is available (save_fp may be None)
    if save_fp is not None and pathlib.Path(save_fp).is_file():
        time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
        return xr.open_dataset(save_fp, decode_times=time_coder)

    ## stochastic simulation data
    if perfect_model:
        data = model.simulate(
            fit_ds=fit,
            X0_ds=data.isel(member=0, time=0),
            nyear=len(np.unique(data.time.dt.year)),
            ncopy=n_members,
        )

    ## trim and reshape for forecasting
    data_ = data.isel(member=slice(None, n_members))
    data_ = data_.stack(sample=["member", "time"]).transpose(..., "sample")

    ## make forecasts: mean over 100 forecast copies per initial condition
    ## (a previously used alternative was a single copy with
    ## ncopy=1, noise_type="zero")
    forecast = model.reforecast(
        fit_ds=fit,
        init_ds=data_,
        n_month=max_lead,
        ncopy=100,
    ).mean("member")

    ## massage into xarray
    forecast = forecast.swap_dims({"sample": "init"}).unstack("init")
    forecast = forecast.rename({n: f"{n}_hat" for n in list(forecast)})
    forecast = forecast.isel(time=slice(None, -max_lead))

    #### verification data

    ## get time index
    time_idx = forecast.time
    n_time = len(time_idx)

    ## for each lead, shift the data forward by that many time steps and
    ## relabel it with the initialization times
    target = []
    for lead in forecast.lead.values:

        target_ = data.sel(member=forecast.member).isel(
            time=slice(lead, n_time + lead)
        )
        target_ = target_.assign_coords(dict(time=time_idx))
        target.append(target_)

    ## convert to xr
    target = xr.concat(target, dim=forecast.lead)
    forecast_data = xr.merge([forecast, target])

    ## save to file (currently disabled)
    # forecast_data.to_netcdf(save_fp)

    return forecast_data

Compute forecasts

In [None]:
## toggle: perfect-model forecasts vs. forecasts of the original data
PERFECT_MODEL = True

## file names depend on the forecast type
FNAME0, FNAME1 = (
    ("forecast_pm_data0.nc", "forecast_pm_data1.nc")
    if PERFECT_MODEL
    else ("forecast_data0.nc", "forecast_data1.nc")
)

## full paths inside the forecast directory
FORECAST_DIR = pathlib.Path(os.environ["SAVE_FP"]) / "forecast_data"
FP0 = FORECAST_DIR / FNAME0
FP1 = FORECAST_DIR / FNAME1

## arguments common to both periods
kwargs = {"model": model, "n_members": 100, "perfect_model": PERFECT_MODEL}

## load (or compute) forecasts for the early and late periods
forecast_data0, forecast_data1 = (
    get_forecast_data(fit=fit_, data=data_in, save_fp=fp, **kwargs)
    for fit_, data_in, fp in ((fit0, Th0, FP0), (fit1, Th1, FP1))
)

#### Evaluate

Functions to compute correlation

In [None]:
def get_corr(ds, varname):
    """Correlate ``ds[varname]`` with ``ds[varname + "_hat"]`` over member and time."""

    truth = ds[varname]
    estimate = ds[f"{varname}_hat"]

    return xr.corr(truth, estimate, dim=["member", "time"])


def get_corr_bymonth(ds, varname):
    """Apply ``get_corr`` separately to each calendar month of ``ds.time``."""

    by_month = ds.groupby("time.month")

    return by_month.map(get_corr, varname=varname)

Compute

In [None]:
## variable to evaluate
plot_var = "T_34"

## monthly correlation for each period, keeping lead labels from 1 upward
corr0, corr1 = (
    get_corr_bymonth(fdata, plot_var).sel(lead=slice(1, None))
    for fdata in (forecast_data0, forecast_data1)
)

#### Plot difference

Plotting functions

In [None]:
def contourf_skill(ax, skill, **kwargs):
    """Draw ``skill`` as filled contours on ``ax`` (x: lead, y: month).

    Extra keyword arguments are forwarded to ``ax.contourf``, whose return
    value is passed back to the caller.
    """

    return ax.contourf(skill.lead, skill.month, skill, **kwargs)


def contour_cutoff(ax, skill, cutoff=0.5, **kwargs):
    """Draw the single black contour ``skill == cutoff`` on ``ax``.

    Extra keyword arguments are forwarded to ``ax.contour``, whose return
    value is passed back to the caller.
    """

    return ax.contour(
        skill.lead, skill.month, skill, colors="k", levels=[cutoff], **kwargs
    )


def format_ax(ax):
    """Apply the shared skill-panel formatting to a single axes.

    Sets an equal aspect ratio, month ticks and an "Init." label on the
    y-axis, and an upper x-limit of 24.5. Lead ticks and the "Lead" label are
    drawn only when ``ax`` sits in the bottom row of its subplot grid (or has
    no subplot spec at all); other panels get no x-ticks.

    Only the axes passed in is modified, so the function no longer depends on
    a module-level ``axs`` variable.
    """
    ax.set_aspect("equal")
    ax.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
    ax.set_ylabel("Init.")
    ax.set_xlim([None, 24.5])

    ## find out whether this panel is in the bottom row; the topmost spec is
    ## used so a nested grid around the axes does not hide its real position
    spec = ax.get_subplotspec()
    is_bottom = spec is None or spec.get_topmost_subplotspec().is_last_row()

    if is_bottom:
        ax.set_xticks([1, 8, 16, 24])
        ax.set_xlabel("Lead")
    else:
        ax.set_xticks([])

    return

Make the plot

In [None]:
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(4, 4), layout="constrained")

## top two panels: skill for each period with its cutoff contour
for ax, corr, ls in zip(axs[:2], [corr0, corr1], ["solid", "dashed"]):
    plot_data0 = contourf_skill(ax, corr, cmap="cmo.amp", levels=np.arange(0, 1.1, 0.1))
    contour_cutoff(ax, corr, linestyles=ls)

## bottom panel: skill difference (corr1 minus corr0)
plot_data_diff = contourf_skill(
    axs[-1],
    corr1 - corr0,
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.5, 0.1),
    extend="both",
    alpha=0.8,
)

## overlay both cutoff contours on the difference panel
for corr, ls in zip([corr0, corr1], ["solid", "dashed"]):
    contour_cutoff(axs[-1], corr, linestyles=ls)

## colorbars: one shared by the top two panels, one for the difference
fig.colorbar(plot_data0, ax=axs[:2], ticks=[0, 0.5, 1], label="ACC")
fig.colorbar(plot_data_diff, ax=axs[-1], ticks=[-0.5, 0, 0.5], label="Diff.")

## ticks and labels
for ax in axs:
    format_ax(ax)

plt.show()

#### Composite evolution of $T$ and $h$

In [None]:
def get_samples(forecast_data, q=0.9, is_warm=True, month=1):
    """Select forecasts from one calendar month with extreme initial T_34.

    Time steps whose month equals ``month`` are pooled over member and time
    into a ``sample`` dimension. Samples are kept when ``T_34`` at the first
    lead is at or above its ``q`` quantile (``is_warm=True``) or at or below
    its ``1 - q`` quantile (``is_warm=False``).
    """

    ## restrict to the requested month and pool member/time
    in_month = forecast_data.time.dt.month == month
    pooled = forecast_data.isel(time=in_month).stack(sample=["member", "time"])

    ## T_34 at the first lead is the sorting index
    t34_first = pooled["T_34"].isel(lead=0).drop_vars("lead")

    ## mask for the requested tail
    if is_warm:
        keep = t34_first >= t34_first.quantile(q=q)
    else:
        keep = t34_first <= t34_first.quantile(q=1 - q)

    return pooled.isel(sample=keep)


## warm and cold composites for the early (0) and late (1) periods
warm0, warm1 = (
    get_samples(fdata, is_warm=True) for fdata in (forecast_data0, forecast_data1)
)
cold0, cold1 = (
    get_samples(fdata, is_warm=False) for fdata in (forecast_data0, forecast_data1)
)

In [None]:
def format_axs(axs):
    """Apply the preferred formatting to a 2-D array of axes.

    Row 0 and row 1 get different y-limits; column 1 loses its y-ticks and
    has its y-label moved to the right; row 0 has no x-ticks while row 1 is
    ticked at 0/12/24 and labelled 0/1/2; every panel gets a zero line with
    dots at 0, 12 and 24.
    """
    top_row, bottom_row, right_col = axs[0, :], axs[1, :], axs[:, 1]

    ## y-limits share a lower bound but differ in the upper bound per row
    for row, upper in ((top_row, 3), (bottom_row, 4)):
        for ax in row:
            ax.set_ylim([-2.5, upper])

    ## right-hand column: no y-ticks, label on the right
    for ax in right_col:
        ax.set_yticks([])
        ax.yaxis.set_label_position("right")

    ## x-ticks only on the bottom row
    for ax in top_row:
        ax.set_xticks([])
    for ax in bottom_row:
        ax.set_xticks([0, 12, 24], labels=[0, 1, 2])

    ## zero line with markers at the tick positions
    markers = (0, 12, 24)
    for ax in axs.flatten():
        ax.axhline(0, c="k", lw=0.7)
        for m in markers:
            ax.scatter(m, 0, c="k", s=20, zorder=10)

    return

Plot historical vs. future composites

In [None]:
colors = sns.color_palette("colorblind")

fig, axs = plt.subplots(2, 2, figsize=(6, 5))

## thin dashed lines for the 0.1/0.9 quantiles, thick solid line for the median
for lw, ls, q in zip([1.5, 1.5, 4], ["--", "--", "-"], [0.1, 0.9, 0.5]):

    ## left column: warm composites (h_w sign-flipped);
    ## only the median line gets a legend entry
    for ci, (warm, label) in enumerate(zip([warm0, warm1], ["hist", "fut"])):
        plot_data = warm.quantile(q=q, dim="sample")
        plot_kwargs = dict(
            lw=lw, ls=ls, c=colors[ci], alpha=0.8, label=(label if q == 0.5 else None)
        )

        axs[0, 0].plot(plot_data.lead, plot_data["T_34"], **plot_kwargs)
        axs[1, 0].plot(plot_data.lead, -plot_data["h_w"], **plot_kwargs)

    ## right column: cold composites (T_34 sign-flipped)
    for ci, warm in enumerate([cold0, cold1]):
        plot_data = warm.quantile(q=q, dim="sample")
        plot_kwargs = dict(lw=lw, ls=ls, c=colors[ci], alpha=0.8)

        axs[0, 1].plot(plot_data.lead, -plot_data["T_34"], **plot_kwargs)
        axs[1, 1].plot(plot_data.lead, plot_data["h_w"], **plot_kwargs)

## titles, labels and legend
axs[0, 0].set_title("El Niño")
axs[0, 1].set_title("La Niña")
axs[0, 1].set_ylabel(r"$T$")
axs[1, 1].set_ylabel(r"$h_w$")
axs[0, 0].legend(prop=dict(size=8))
format_axs(axs)

plt.show()

Plot actual vs forecast in hist and future

In [None]:
colors = sns.color_palette("colorblind")

fig, axs = plt.subplots(2, 2, figsize=(5, 4), layout="constrained")

## thin dashed lines for the 0.1/0.9 quantiles, thick solid line for the median
for lw, ls, q in zip([1.5, 1.5, 4], ["--", "--", "-"], [0.1, 0.9, 0.5]):

    for ci, (n, label) in enumerate(zip(["h_w", "h_w_hat"], ["actual", "forecast"])):

        ## line style; only the median line gets a legend entry
        plot_kwargs = dict(
            lw=lw, ls=ls, c=colors[ci], alpha=0.8, label=(label if (q == 0.5) else None)
        )

        ## (panel, composite, flip sign?): warm events on the left,
        ## cold events (sign-flipped) on the right
        panels = (
            (axs[0, 0], warm0, False),
            (axs[1, 0], warm1, False),
            (axs[0, 1], cold0, True),
            (axs[1, 1], cold1, True),
        )
        for panel, composite, flip in panels:
            curve = composite[n].quantile(q=q, dim="sample")
            panel.plot(composite.lead, -curve if flip else curve, **plot_kwargs)


## titles, labels and legend
axs[0, 0].set_title("El Niño")
axs[0, 1].set_title("La Niña")
axs[0, 1].set_ylabel(r"Hist")
axs[1, 1].set_ylabel(r"Future")
axs[0, 0].legend(prop=dict(size=8))
format_axs(axs)
for ax in axs.flatten():
    ax.set_ylim([-3, 3])

plt.show()

In [None]:
## time steps in month 1 from forecast_data1, pooled over member and time
is_month1 = forecast_data1.time.dt.month == 1
data_dec = forecast_data1.isel(time=is_month1).stack(sample=["member", "time"])

## T_34 at the first lead, used to rank the samples
T34_init = data_dec["T_34"].isel(lead=0).drop_vars("lead")

## keep every sample at or above the 0.1 quantile
cutoff = T34_init.quantile(q=0.1)
samples = data_dec.isel(sample=(T34_init >= cutoff))

Try to debug the mismatch.  
Looks like the months are off by one.

In [None]:
## monthly variance (member mean) at lead index 1: verification, then forecast
plt.plot(forecast_data1["T_34"].isel(lead=1).groupby("time.month").var().mean("member"))
plt.plot(
    forecast_data1["T_34_hat"].isel(lead=1).groupby("time.month").var().mean("member")
)

In [None]:
## display the selected samples as the cell output
samples

In [None]:
PLOT_VAR = "h_w"

## sample-mean evolution of the late-period warm composite:
## verification variables and the stored forecasts
composite_mean = warm1.mean("sample")
x = composite_mean[["T_34", "h_w"]]
xhat = composite_mean[["T_34_hat", "h_w_hat"]]

## turn the first lead of the composite into a one-step initial condition
time_coord = {"time": [datetime.datetime(1990, 2, 1)]}
X0 = x.isel(lead=slice(0, 1))
X0 = X0.rename({"lead": "time"}).assign_coords(time_coord)

## fresh forecast from that initial condition
v0 = model.reforecast(
    fit_ds=fit1,
    init_ds=X0,
    n_month=28,
    ncopy=1000,
    noise_type="red",
)

## compare the three evolutions
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x.lead, x[PLOT_VAR], label="true evo")
ax.plot(x.lead, xhat[f"{PLOT_VAR}_hat"], label="forecast evo")
ax.plot(x.lead, v0[PLOT_VAR].mean("member"), ls="--", label="fixed forecast")
ax.legend()
plt.show()

In [None]:
## display the stacked data as the cell output
## NOTE(review): at module level `data_` is only assigned in a later cell,
## so this cell fails when the notebook is run top to bottom -- confirm
data_

In [None]:
## simulate two copies, starting from the first member/time of the late data
## and spanning as many years as the late data contains
data = model.simulate(
    fit_ds=fit1,
    X0_ds=Th1.isel(member=0, time=0),
    nyear=np.unique(Th1.time.dt.year).size,
    ncopy=2,
)

## pool member/time into a trailing "sample" dimension
data_ = data.stack(sample=["member", "time"])
data_ = data_.transpose(..., "sample")

## reforecast from every sample and average over the forecast copies
forecast = model.reforecast(
    fit_ds=fit1,
    init_ds=data_,
    n_month=29,
    ncopy=100,
)
forecast = forecast.mean("member")

In [None]:
## index by "init" instead of "sample", then unstack it
forecast = forecast.swap_dims({"sample": "init"}).unstack("init")

## mark forecast variables with a "_hat" suffix and drop the last 29 times
hat_names = {n: f"{n}_hat" for n in list(forecast)}
forecast = forecast.rename(hat_names).isel(time=slice(None, -29))

# #### verification data

In [None]:
## initialization times of the forecasts
time_idx = forecast.time
n_time = len(time_idx)

## for each lead: shift the simulated data forward by that many steps,
## relabel it with the initialization times, then concatenate along lead
target = xr.concat(
    [
        data.sel(member=forecast.member)
        .isel(time=slice(shift, n_time + shift))
        .assign_coords(dict(time=time_idx))
        for shift in forecast.lead.values
    ],
    dim=forecast.lead,
)
# forecast_data = xr.merge([forecast, target])

In [None]:
## display the length of each dimension (`.shap` was a typo, and a Dataset
## has no `.shape`; `.sizes` is the Dataset equivalent)
forecast.sizes

In [None]:
## monthly variance (member mean) at lead index 1: target, then forecast
plt.plot(
    target["T_34"].isel(lead=1).groupby("time.month").var().mean("member"),
)
plt.plot(
    forecast["T_34_hat"].isel(lead=1).groupby("time.month").var().mean("member"),
)

In [None]:
## display the forecast time coordinate as the cell output
forecast.time

In [None]:
## display the time coordinate of the stacked data as the cell output
data_.time

In [None]:
## display the time coordinate of the simulated data as the cell output
data.time

Scatter plot of predictions vs actual

In [None]:
## lead index and variable to compare
lead = 7
plot_var = "T_34"

## early-period warm composite at that lead
plot_data = warm0.isel(lead=lead)

fig, ax = plt.subplots(figsize=(3, 3))
ax.set_aspect("equal")

## one dot per sample: verification on x, forecast on y
ax.scatter(plot_data[plot_var], plot_data[f"{plot_var}_hat"], s=10)

## 1-1 line across the range of the verification values
min_ = plot_data[plot_var].min().values.item()
max_ = plot_data[plot_var].max().values.item()
z = np.linspace(min_, max_)
ax.plot(z, z, c="k", lw=1.5)

## dashed zero lines
kwargs = {"ls": "--", "c": "k", "lw": 1}
ax.axhline(0, **kwargs)
ax.axvline(0, **kwargs)

## limits and labels
ax.set(xlim=[-4, 4], ylim=[-4, 4], xlabel="actual", ylabel="Predicted")

plt.show()