# Persistence forecast

In [None]:
import os, argparse
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy as ctp

import hyblim.geoplot as gpl
from hyblim.data import eof, preproc
from hyblim.utils import metric, eval, enso


## Load data

In [None]:
datapaths = {
    'ssta': "../../data/cesm2-picontrol/b.e21.B1850.f09_g17.CMIP6-piControl.001.pop.h.ssta_lat-31_33_lon130_290_gr1.0.nc",
    'ssha': "../../data/cesm2-picontrol/b.e21.B1850.f09_g17.CMIP6-piControl.001.pop.h.ssha_lat-31_33_lon130_290_gr1.0.nc",
}
# Presumably the number of EOF components per variable (ssta, ssha);
# not used by the cells in this notebook section.
n_eof = [20, 10]

# Load each variable and normalize it. The fitted Normalizer is kept in
# `normalizer` (keyed by variable name) so later cells can invert the scaling.
# Its parameters are also serialized into the DataArray attrs.
print("Load data!", flush=True)
normalizer = {}
da_arr = []
for var, path in datapaths.items():
    norm = preproc.Normalizer()
    da = norm.fit_transform(xr.open_dataset(path)[var])
    da.attrs = norm.to_dict()
    normalizer[var] = norm
    da_arr.append(da)

ds = xr.merge(da_arr)

# Land-sea mask: keep grid points where lsm != 1 (presumably ocean) and set
# all others to NaN.
lsm = xr.open_dataset("../../data/land_sea_mask_common.nc")['lsm']
ds = ds.where(lsm != 1, other=np.nan)

# Chronological 80 / 10 / 10 split of the time axis (no shuffling).
# Each period is a (start, stop) index pair usable with slice(*period).
n_time = len(ds['time'])
i_val, i_test = int(0.8 * n_time), int(0.9 * n_time)
train_period = (0, i_val)
val_period = (i_val, i_test)
test_period = (i_test, n_time)

data = {
    name: ds.isel(time=slice(*period))
    for name, period in (
        ('train', train_period),
        ('val', val_period),
        ('test', test_period),
    )
}

## Persistence forecast

In [None]:
datasplit = 'test'
lag_arr = [1, 3, 6, 9, 12, 15, 18, 21, 24]

ds_eval = data[datasplit]
n_time_eval = len(ds_eval['time'])
verification_per_gridpoint, verification_per_time, nino_indices = [], [], []
for lag in lag_arr:
    # A persistence forecast needs at least one (initial, target) pair.
    # Out-of-range lags would otherwise silently produce empty slices.
    if not 0 <= lag < n_time_eval:
        raise ValueError(
            f"lag={lag} must satisfy 0 <= lag < {n_time_eval} "
            f"(length of the '{datasplit}' split)"
        )
    print(f"Compute metrics for lag {lag}", flush=True)

    # Target: states from index `lag` onwards.
    # Forecast: the state `lag` steps earlier, persisted unchanged.
    # An explicit stop index is used instead of `-lag` because
    # slice(None, -0) would be empty for lag=0.
    x_target = ds_eval.isel(time=slice(lag, None))
    x_frcst = ds_eval.isel(time=slice(None, n_time_eval - lag))
    # Relabel forecast times with the target times they are verified against.
    x_frcst['time'] = x_target['time']

    # Undo the normalization fitted at load time, so scores are computed on the
    # original data scale. Iterate over a snapshot of the variable names
    # because the loop body reassigns variables of the dataset.
    for var in list(x_target.data_vars):
        x_target[var] = normalizer[var].inverse_transform(x_target[var])
        x_frcst[var] = normalizer[var].inverse_transform(x_frcst[var])

    # Metrics per grid point (third argument passed as None, as before).
    grid_verif = metric.verification_metrics_per_gridpoint(
        x_target, x_frcst, None
    )
    grid_verif['lag'] = lag
    verification_per_gridpoint.append(grid_verif)

    # Metrics per time step.
    time_verif = metric.verification_metrics_per_time(
        x_target, x_frcst, None
    )
    time_verif['lag'] = lag
    verification_per_time.append(time_verif)

    # Nino indices of target and persisted forecast, computed from SST anomalies.
    nino_index = {
        'target': enso.get_nino_indices(x_target['ssta']),
        'frcst': enso.get_nino_indices(x_frcst['ssta']),
        'lag': lag
    }
    nino_indices.append(nino_index)

# Combine the per-lag dicts, keyed by 'lag' (presumably stacked along a new
# 'lag' dimension by the helper).
verification_per_gridpoint = metric.listofdicts_to_dictofxr(verification_per_gridpoint, dim_key='lag')
verification_per_time = metric.listofdicts_to_dictofxr(verification_per_time, dim_key='lag')
nino_indices = metric.listofdicts_to_dictofxr(nino_indices, dim_key='lag')

In [None]:
# Save metrics to file
print("Save metrics to file!", flush=True)
scorepath = "../../models/persistence/"
# exist_ok=True replaces the racy exists-then-create check.
os.makedirs(scorepath, exist_ok=True)

# Write one NetCDF file per metric key and score family, named
# <prefix>_<key>_<datasplit>.nc. os.path.join avoids the doubled slash that
# concatenating "/..." onto the trailing-slash scorepath produced.
for prefix, scores in (
    ("gridscore", verification_per_gridpoint),
    ("timescore", verification_per_time),
    ("nino", nino_indices),
):
    for key, score in scores.items():
        score.to_netcdf(os.path.join(scorepath, f"{prefix}_{key}_{datasplit}.nc"))

## Plotting

In [None]:
# Plotting: map of one verification score at one lag, one panel per variable.
scorekey = 'cc'
lag = 12
# Named `plot_vars` rather than `vars` so the builtin vars() is not shadowed.
plot_vars = ['ssta', 'ssha']
# Colormap / range settings per score key and variable, passed to gpl.plot_map.
plparam = {
    'mse': {'ssta': dict(cmap='plasma', vmin=0, vmax=2, eps=0.25),
            'ssha': dict(cmap='plasma', vmin=0, vmax=150, eps=10)},
    'rmsess': {'ssta': dict(cmap='RdGy_r', vmin=-.9, vmax=.9, eps=0.1),
               'ssha': dict(cmap='RdGy_r', vmin=-.9, vmax=.9, eps=0.1)},
    'cc': {'ssta': dict(cmap='RdBu_r', vmin=-1.0, vmax=1.0, eps=0.1, centercolor="#FFFFFF"),
           'ssha': dict(cmap='RdBu_r', vmin=-1.0, vmax=1.0, eps=0.1, centercolor="#FFFFFF")},
    'crpss': {'ssta': dict(cmap='viridis', vmin=0, vmax=.5, eps=0.05),
              'ssha': dict(cmap='viridis', vmin=0, vmax=1.0, eps=.1)},
}


ncols = len(plot_vars)
nrows = 1
fig = plt.figure(figsize=(ncols*5, nrows*2.5))

for i, var in enumerate(plot_vars):
    score = verification_per_gridpoint[scorekey][var]
    ax = fig.add_subplot(nrows, ncols, i+1, projection=ctp.crs.PlateCarree(central_longitude=180))
    im = gpl.plot_map(score.sel(lag=lag), ax=ax, **plparam[scorekey][var], add_bar=False)

    ax.set_title(var)

    # Colorbar under each panel: 80% of the panel width, centred on it.
    # The offset comes from the panel's own position instead of assuming a
    # fixed inter-panel spacing.
    # NOTE(review): y=-0.01 is below the figure canvas; presumably this relies
    # on a tight bounding box when rendering/saving — confirm.
    pos = ax.get_position()
    cbar_ax = fig.add_axes([pos.x0 + 0.1 * pos.width, -0.01, 0.8 * pos.width, 0.02])
    cb = fig.colorbar(im['im'], cax=cbar_ax, orientation='horizontal', extend='both')
    cb.set_label(f"{scorekey} {var}")

# Title
fig.suptitle(rf"$\tau$ = {lag}")
