In [1]:
import os
import sys
import json
import numpy as np
import xarray as xr
import warnings

module_path = os.path.abspath(os.path.join('../src'))
if module_path not in sys.path:
    sys.path.append(module_path)

from utils import DotDict, nino_indices
from analog import *
from eval_pred import *

%load_ext autoreload
%autoreload 2

In [2]:
# Parameters: experiment / output locations and the variable being verified.
exp = 'test'
out_dir = f'../output/{exp}'
library_dir = '../data/cesm2'
vname = 'sst'

# Hyperparameters saved by the training run, wrapped for attribute access (hp.month, ...).
with open(f'{out_dir}/hyperparameters.json', 'r') as f:
    hp = DotDict(json.load(f))

# Verification (test) dataset
data_dir = '../data/real'
test_data = 'real'

# Precipitation uses a coarser grid and 13 lead indices (0-12); other variables use
# the 2x2 grid and 19 lead indices (0-18).
grid, leads = ('2.5x2.5', np.arange(13)) if vname == 'pr' else ('2x2', np.arange(19))

# Data for analog forecasting

In [3]:
# Reference (verification) anomalies. The pr file name carries an explicit 1987-2016 range.
ref_path = (f'{data_dir}/{vname}_anomaly_{grid}_1987-2016.nc' if vname == 'pr'
            else f'{data_dir}/{vname}_anomaly_{grid}.nc')
ref = xr.open_dataarray(ref_path)

# Library of model anomalies from which analogs are drawn
f = f'{library_dir}/{vname}_anomaly_{grid}.nc'
library = xr.open_dataarray(f)

# Variable skill

In [4]:
%%time
for epoch in range(hp.n_epochs):
    f = f'{out_dir}/ma_index_{test_data}_epoch{epoch}.nc'
    if not os.path.exists(f):
        continue

    ma_idx = xr.open_dataarray(f)

    if vname == 'pr':
        ma_idx = ma_idx.sel(year=slice(1987, 2015))

    # Get analog forecasts
    af = get_af_month(library, hp.month, hp.periods['train'], 
                      ma_idx, hp.n_analog, leads)
    afm = af.mean(dim='analog')        

    # Time stats
    t_mse = eval_stats_lead(eval_mse, ref, afm, month=hp.month, dim='year')
    t_uac = eval_stats_lead(eval_uac, ref, afm, month=hp.month, dim='year')
    t_cac = eval_stats_lead(eval_r, ref, afm, month=hp.month, dim='year')
    t_rmsss = eval_stats_lead(eval_rmsss, ref, afm, month=hp.month, dim='year')
    t_msss = eval_stats_lead(eval_msss, ref, afm, month=hp.month, dim='year')

    # Over the target region
    xy_mse = eval_stats_lead(
        eval_mse, ref.sel(lat=slice(*hp.target_lat_slice), lon=slice(*hp.target_lon_slice)), 
        afm, month=hp.month, dim=['lat', 'lon'])
    xy_uac = eval_stats_lead(
        eval_uac, ref.sel(lat=slice(*hp.target_lat_slice), lon=slice(*hp.target_lon_slice)), 
        afm, month=hp.month, dim=['lat', 'lon'])

    # Ensemble spread (time-mean)
    # Suppress runtime warning for empty array
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)
        t_std = af.var(dim='analog').mean(dim='year') ** 0.5
    t_std = t_std.rename('std').assign_attrs(long_name='Ensemble spread') 

    # Probablistic stats
    t_crps = eval_stats_lead(
        eval_crps_decomp, 
        ref.sel(lat=slice(*hp.target_lat_slice), lon=slice(*hp.target_lon_slice)), 
        af.sel(lat=slice(*hp.target_lat_slice), lon=slice(*hp.target_lon_slice)), 
        month=hp.month, dim='year')

    # Combine
    t_stats = xr.merge([
        t_mse.rename('mse').assign_attrs(long_name='Mean square error'), 
        t_uac.rename('uac').assign_attrs(long_name='Uncentered anomaly correlation'),
        t_cac.rename('cac').assign_attrs(long_name='Centered anomaly correlation'),
        t_rmsss.rename('rmsss').assign_attrs(long_name='Root mean square skill score'),
        t_msss.rename('msss').assign_attrs(long_name='Mean square skill score'),
        ])

    xy_stats = xr.merge([
        xy_mse.rename('mse').assign_attrs(long_name='Mean square error'), 
        xy_uac.rename('uac').assign_attrs(long_name='Uncentered anomaly correlation')
        ])

    # Save
    encoding = {key: {'dtype': 'float32'} for key in list(t_stats.keys())}
    t_stats.to_netcdf(f'{out_dir}/{vname}_t_stats_{test_data}_epoch{epoch}.nc', encoding=encoding)

    encoding = {key: {'dtype': 'float32'} for key in list(xy_stats.keys())}
    xy_stats.to_netcdf(f'{out_dir}/{vname}_xy_stats_{test_data}_epoch{epoch}.nc', encoding=encoding)

    encoding = {'std': {'dtype': 'float32'}}
    t_std.to_netcdf(f'{out_dir}/{vname}_t_std_{test_data}_epoch{epoch}.nc', encoding=encoding) 

    encoding = {key: {'dtype': 'float32'} for key in list(t_crps.keys())}
    t_crps.to_netcdf(f'{out_dir}/{vname}_t_crps_{test_data}_epoch{epoch}.nc', encoding=encoding)
    print(f'Epoch {epoch} saved')

Epoch 9 saved
Epoch 19 saved
Epoch 29 saved
Epoch 39 saved
Epoch 49 saved
Epoch 51 saved
Epoch 59 saved
CPU times: user 59.3 s, sys: 46.3 s, total: 1min 45s
Wall time: 1min 50s


# Nino index skill (SST only)

In [5]:
if vname == 'sst':
    # Skill of the analog forecasts for the Nino indices derived from SST.
    nino_ref = nino_indices(ref)
    nino_library = nino_indices(library)

    # The uac/mse/rmsss/std outputs all carry the same variables as `nino_ref`, so their
    # float32 encoding is built once here rather than rebuilt every epoch.
    nino_encoding = {key: {'dtype': 'float32'} for key in nino_ref.keys()}
    nino34_ref = nino_ref['nino34']

    for epoch in range(hp.n_epochs):
        f = f'{out_dir}/ma_index_{test_data}_epoch{epoch}.nc'
        if not os.path.exists(f):
            continue

        # Load the index into memory and close the file right away (no leaked handles).
        with xr.open_dataarray(f) as ma_idx_file:
            ma_idx = ma_idx_file.load()

        # Analog ensemble of the indices and its ensemble mean
        nino_af = get_af_month(nino_library, hp.month, hp.periods['train'], ma_idx, hp.n_analog, leads)
        nino_afm = nino_af.mean(dim='analog')

        # Deterministic skill of the ensemble mean and ensemble spread, over years
        nino_t_uac = eval_stats_lead(eval_uac, nino_ref, nino_afm, month=hp.month, dim='year')
        nino_t_mse = eval_stats_lead(eval_mse, nino_ref, nino_afm, month=hp.month, dim='year')
        nino_t_rmsss = eval_stats_lead(eval_rmsss, nino_ref, nino_afm, month=hp.month, dim='year')
        nino_t_std = nino_af.var(dim='analog').mean(dim='year') ** 0.5

        # Probabilistic skill (CRPS decomposition), Nino3.4 only
        nino_t_crps = eval_stats_lead(eval_crps_decomp, nino34_ref, nino_af['nino34'],
                                      month=hp.month, dim='year')

        # Save: one file per statistic, named nino_t_<stat>_<test_data>_epoch<epoch>.nc
        for stat, result in [('uac', nino_t_uac), ('mse', nino_t_mse),
                             ('rmsss', nino_t_rmsss), ('std', nino_t_std)]:
            result.to_netcdf(f'{out_dir}/nino_t_{stat}_{test_data}_epoch{epoch}.nc',
                             encoding=nino_encoding)

        crps_encoding = {key: {'dtype': 'float32'} for key in nino_t_crps.keys()}
        nino_t_crps.to_netcdf(f'{out_dir}/nino34_t_crps_{test_data}_epoch{epoch}.nc',
                              encoding=crps_encoding)

        print(f'Epoch {epoch} saved')


Epoch 9 saved
Epoch 19 saved
Epoch 29 saved
Epoch 39 saved
Epoch 49 saved
Epoch 51 saved
Epoch 59 saved
