In [49]:
import os
import sys
import json
import numpy as np
import pandas as pd
import xarray as xr
import warnings

# Put the sibling ../src directory on the import path (once) so that the
# local modules imported right after this can be found.
module_path = os.path.abspath('../src')
if module_path not in sys.path:
    sys.path.append(module_path)

from utils import DotDict, nino_indices
from analog import *
from eval_pred import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [50]:
# Parameters
# exp_base = 'unet4-256_scaled_month01_lr1.5e-05'
exp_base = 'unet4-256_scaled_month09_lr2.0e-05'
common_epoch = 9   # epoch used for every member; None -> each member's own min-val_mse epoch
i_ens = 0          # index of the first ensemble member to combine
n_ens = 10         # number of consecutive members to combine

library_dir = '../data/cesm2'
vname = 'sst'

# Hyperparameters are read from the first member's run directory and wrapped
# in a DotDict so they can be accessed as attributes (hp.month, hp.n_analog, ...).
with open(f'../output/{exp_base}_{i_ens}/hyperparameters.json', 'r') as f:
    hp = DotDict(json.load(f))

data_dir = '../data/real'
test_data = 'real'

# Grid label and number of leads depend on the variable being evaluated.
grid, leads = ('2.5x2.5', np.arange(13)) if vname == 'pr' else ('2x2', np.arange(19))

# Out dir: one directory per combined ensemble, named after the member range.
out_dir = f'../output/{exp_base}_gens{i_ens}-{i_ens+n_ens-1}'
print(out_dir)
os.makedirs(out_dir, exist_ok=True)

../output/unet4-256_scaled_month09_lr2.0e-05_gens0-9


# Data for analog forecasting

In [51]:
# Target data: the 'pr' file name carries an extra 1987-2016 suffix,
# every other variable uses the plain anomaly file.
f = (f'{data_dir}/{vname}_anomaly_{grid}_1987-2016.nc' if vname == 'pr'
     else f'{data_dir}/{vname}_anomaly_{grid}.nc')
ref = xr.open_dataarray(f)

# Library the analogs are drawn from (same variable and grid, read from library_dir).
f = f'{library_dir}/{vname}_anomaly_{grid}.nc'
library = xr.open_dataarray(f)

# MA indices

In [52]:
# MA indices
# Gather the analog indices of every ensemble member and stack them along
# the 'analog' dimension to form one combined ensemble.
# NOTE(review): .sel() slices by label and includes both end points, so this
# may keep hp.n_analog + 1 entries per member — confirm against the 'analog'
# coordinate of the ma_index files.
analog_window = slice(0, hp.n_analog)

lst = []
for i in range(i_ens, i_ens + n_ens):
    exp = f'{exp_base}_{i}'
    in_dir = f'../output/{exp}'

    if common_epoch is not None:
        epoch = common_epoch
    else:
        # No shared epoch requested: use the position of the smallest
        # validation MSE in this member's training history.
        # NOTE(review): argmin() is positional, not the index label — confirm
        # that matches the epoch numbering used in the file names.
        history = pd.read_csv(f'{in_dir}/history.csv', index_col=0)
        epoch = history['val_mse'].argmin()

    f = f'{in_dir}/ma_index_{test_data}_epoch{epoch}.nc'
    ma_idx = xr.open_dataarray(f)
    lst.append(ma_idx.sel(analog=analog_window))

ma_idx = xr.concat(lst, dim='analog')

In [53]:
# Save
# Total number of analogs in the combined ensemble. The per-member count is
# re-read from the first member's hyperparameters file (the same file hp was
# loaded from) rather than taken from hp itself, so that re-running this cell
# does not multiply hp.n_analog by n_ens a second time.
with open(f'../output/{exp_base}_{i_ens}/hyperparameters.json', 'r') as f:
    n_analog_member = DotDict(json.load(f)).n_analog
hp.n_analog = n_ens * n_analog_member

# The combined hyperparameters are written only once; an existing file in
# out_dir is left untouched.
hp_file = f'{out_dir}/hyperparameters.json'
if not os.path.exists(hp_file):
    with open(hp_file, 'w') as f:
        json.dump(hp, f)

# The file name carries the epoch only when a common epoch was used.
if common_epoch is None:
    ma_idx.to_netcdf(f'{out_dir}/ma_index_{test_data}.nc')
else:
    ma_idx.to_netcdf(f'{out_dir}/ma_index_{test_data}_epoch{common_epoch}.nc')

# Variable skill

In [54]:
%%time
# For 'pr', keep only the years 1987-2015 of the analog indices.
if vname == 'pr':
    ma_idx = ma_idx.sel(year=slice(1987, 2015))

# Get analog forecasts (one per analog) and their ensemble mean
af = get_af_month(library, hp.month, hp.periods['train'],
                  ma_idx, hp.n_analog, leads)
afm = af.mean(dim='analog')

# Target region, selected once and reused by the regional statistics below
target_region = dict(lat=slice(*hp.target_lat_slice), lon=slice(*hp.target_lon_slice))
ref_target = ref.sel(**target_region)

# Time stats: each metric is evaluated along 'year'
t_mse = eval_stats_lead(eval_mse, ref, afm, month=hp.month, dim='year')
t_uac = eval_stats_lead(eval_uac, ref, afm, month=hp.month, dim='year')
t_cac = eval_stats_lead(eval_r, ref, afm, month=hp.month, dim='year')
t_rmsss = eval_stats_lead(eval_rmsss, ref, afm, month=hp.month, dim='year')
t_msss = eval_stats_lead(eval_msss, ref, afm, month=hp.month, dim='year')

# Over the target region: metrics evaluated along ('lat', 'lon')
xy_mse = eval_stats_lead(eval_mse, ref_target, afm, month=hp.month, dim=['lat', 'lon'])
xy_uac = eval_stats_lead(eval_uac, ref_target, afm, month=hp.month, dim=['lat', 'lon'])

# Ensemble spread (time-mean): square root of the year-mean analog variance
# Suppress runtime warning for empty array
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    t_std = af.var(dim='analog').mean(dim='year') ** 0.5
t_std = t_std.rename('std').assign_attrs(long_name='Ensemble spread')

# Probabilistic stats: uses the individual analog forecasts (not the mean),
# restricted to the target region
t_crps = eval_stats_lead(
    eval_crps_decomp,
    ref_target,
    af.sel(**target_region),
    month=hp.month, dim='year')

# Combine
t_stats = xr.merge([
    t_mse.rename('mse').assign_attrs(long_name='Mean square error'),
    t_uac.rename('uac').assign_attrs(long_name='Uncentered anomaly correlation'),
    t_cac.rename('cac').assign_attrs(long_name='Centered anomaly correlation'),
    t_rmsss.rename('rmsss').assign_attrs(long_name='Root mean square skill score'),
    t_msss.rename('msss').assign_attrs(long_name='Mean square skill score'),
    ])

xy_stats = xr.merge([
    xy_mse.rename('mse').assign_attrs(long_name='Mean square error'),
    xy_uac.rename('uac').assign_attrs(long_name='Uncentered anomaly correlation')
    ])

# Save
# The epoch suffix is derived from common_epoch (the value that selects the
# naming scheme) instead of a loop variable left over from an earlier cell.
epoch_tag = '' if common_epoch is None else f'_epoch{common_epoch}'
t_stats_f = f'{out_dir}/{vname}_t_stats_{test_data}{epoch_tag}.nc'
xy_stats_f = f'{out_dir}/{vname}_xy_stats_{test_data}{epoch_tag}.nc'
t_std_f = f'{out_dir}/{vname}_t_std_{test_data}{epoch_tag}.nc'
t_crps_f = f'{out_dir}/{vname}_t_crps_{test_data}{epoch_tag}.nc'

# All variables are stored as float32
encoding = {key: {'dtype': 'float32'} for key in list(t_stats.keys())}
t_stats.to_netcdf(t_stats_f, encoding=encoding)

encoding = {key: {'dtype': 'float32'} for key in list(xy_stats.keys())}
xy_stats.to_netcdf(xy_stats_f, encoding=encoding)

encoding = {'std': {'dtype': 'float32'}}
t_std.to_netcdf(t_std_f, encoding=encoding)

encoding = {key: {'dtype': 'float32'} for key in list(t_crps.keys())}
t_crps.to_netcdf(t_crps_f, encoding=encoding)

CPU times: user 27.8 s, sys: 8.42 s, total: 36.2 s
Wall time: 36.9 s


# Nino

In [55]:
# Nino-index skill is only evaluated for sea surface temperature.
if vname == 'sst':
    # Nino indices of the reference data and of the analog library
    nino_ref = nino_indices(ref)
    nino_library = nino_indices(library)

    # Analog forecasts of the indices (one per analog) and their ensemble mean
    nino_af = get_af_month(nino_library, hp.month, hp.periods['train'], ma_idx, hp.n_analog, leads)
    nino_afm = nino_af.mean(dim='analog')

    # Deterministic stats along 'year', ensemble spread, and CRPS for nino34 only
    nino_t_uac = eval_stats_lead(eval_uac, nino_ref, nino_afm, month=hp.month, dim='year')
    nino_t_mse = eval_stats_lead(eval_mse, nino_ref, nino_afm, month=hp.month, dim='year')
    nino_t_rmsss = eval_stats_lead(eval_rmsss, nino_ref, nino_afm, month=hp.month, dim='year')
    nino_t_std = nino_af.var(dim='analog').mean(dim='year') ** 0.5
    nino_t_crps = eval_stats_lead(eval_crps_decomp, nino_ref['nino34'], nino_af['nino34'],
                                  month=hp.month, dim='year')

    # Save
    # The epoch suffix is derived from common_epoch (the value that selects the
    # naming scheme) instead of a loop variable left over from an earlier cell.
    epoch_tag = '' if common_epoch is None else f'_epoch{common_epoch}'

    # The index statistics are encoded with the variable names of nino_ref
    encoding = {key: {'dtype': 'float32'} for key in list(nino_ref.keys())}
    nino_t_uac.to_netcdf(f'{out_dir}/nino_t_uac_{test_data}{epoch_tag}.nc', encoding=encoding)
    nino_t_mse.to_netcdf(f'{out_dir}/nino_t_mse_{test_data}{epoch_tag}.nc', encoding=encoding)
    nino_t_rmsss.to_netcdf(f'{out_dir}/nino_t_rmsss_{test_data}{epoch_tag}.nc', encoding=encoding)
    nino_t_std.to_netcdf(f'{out_dir}/nino_t_std_{test_data}{epoch_tag}.nc', encoding=encoding)

    # The CRPS output has its own variable names
    encoding = {key: {'dtype': 'float32'} for key in list(nino_t_crps.keys())}
    nino_t_crps.to_netcdf(f'{out_dir}/nino34_t_crps_{test_data}{epoch_tag}.nc', encoding=encoding)