# Test the significance of improvements from MA

In [1]:
import os
import sys
import json
import xarray as xr
from scipy.stats import permutation_test

# Put the sibling ``../src`` directory (resolved against the current working
# directory) on the import path so modules located there can be imported.
# The membership test keeps re-runs of this cell from appending duplicates.
module_path = os.path.abspath('../src')
if not (module_path in sys.path):
    sys.path.append(module_path)

from utils import *
from analog import *
from resample import *

%load_ext autoreload
%autoreload 2

# Prepare data

In [2]:
# Parameters
# ---- experiment to evaluate (swap in the commented line for a quick test) ----
# exp = 'test'
exp = 'ch256_sst_ssh_taux_scaled_month01_lr1.5e-05-0'
test_data = 'real'

# ---- directories ----
out_dir = f'../output/{exp}'      # experiment outputs; results are written here too
library_dir = '../data/cesm2'     # analog library
data_dir = '../data/real'         # target data

# ---- evaluation settings ----
vname = 'sst'
leads = np.arange(3, 15, 3)       # lead times 3, 6, 9, 12 (months, per the cell outputs)
n_resamples = 10000               # Number of resamples
batch = n_resamples // 50         # 200; forwarded as ``batch=`` to the permutation routines
# NOTE(review): no random seed is set in the visible notebook. If the
# permutation routines draw random resamples, results will differ between
# runs -- confirm whether they accept a seed.

# Hyperparameters saved with the experiment, wrapped for attribute access
# (used below as hp.month, hp.periods, hp.n_analog).
with open(f'{out_dir}/hyperparameters.json', 'r') as hp_file:
    hp = DotDict(json.load(hp_file))

# Select the epoch with the lowest validation MSE.
# NOTE(review): argmin() returns the row *position*, not the index label.
# This equals the epoch number only if history.csv's first column is a
# 0-based epoch counter -- confirm against the training script.
history = pd.read_csv(f'{out_dir}/history.csv', index_col=0)
epoch = history['val_mse'].argmin()

In [3]:
# Pick the grid and the target file name for the chosen variable.
# 'pr' is the only special case: it uses a different grid and a target file
# whose name carries a 1987-2016 suffix.
if vname == 'pr':
    grid = '2.5x2.5'
    target_path = f'{data_dir}/{vname}_anomaly_{grid}_1987-2016.nc'
else:
    grid = '2x2'
    target_path = f'{data_dir}/{vname}_anomaly_{grid}.nc'

# Load target data
da = xr.open_dataarray(target_path)

# Load the library; its file name follows the same pattern for every variable.
library_path = f'{library_dir}/{vname}_anomaly_{grid}.nc'
library = xr.open_dataarray(library_path)

# # Region for evaluating skills
# da = da.sel(lat=slice(-10, 10), lon=slice(120, 290))
# library = library.sel(lat=slice(-10, 10), lon=slice(120, 290))

In [4]:
# Base MA
print('Base MA')

# Baseline analog indices, restricted to the month configured for the experiment.
base_idx_path = f'../output/MA/ma_index_{test_data}.nc'
ma_idx = xr.open_dataarray(base_idx_path).sel(month=hp.month)

# For 'pr' only, keep the years 1987-2015.
if vname == 'pr':
    ma_idx = ma_idx.sel(year=slice(1987, 2015))

# get_af_month is a project helper (imported via the star imports above).
# Its result carries an 'analog' dimension, which is averaged out below.
af_base = get_af_month(
    library,
    hp.month,
    hp.periods['train'],
    ma_idx,
    hp.n_analog,
    leads,
)
afm_base = af_base.mean(dim='analog')

Base MA


In [5]:
# MA-ML exp
print(f'{exp} at epoch {epoch}')

# Analog indices produced by the experiment at the selected epoch.
# NOTE(review): unlike the baseline cell, no .sel(month=hp.month) is applied
# here -- presumably this file only contains the experiment's month; confirm.
exp_idx_path = f'{out_dir}/ma_index_{test_data}_epoch{epoch}.nc'
ma_idx = xr.open_dataarray(exp_idx_path)

# For 'pr' only, keep the years 1987-2015.
if vname == 'pr':
    ma_idx = ma_idx.sel(year=slice(1987, 2015))

# Same helper call as for the baseline, but with the experiment's indices;
# the 'analog' dimension of the result is averaged out below.
af = get_af_month(
    library,
    hp.month,
    hp.periods['train'],
    ma_idx,
    hp.n_analog,
    leads,
)
afm = af.mean(dim='analog')

ch256_sst_ssh_taux_scaled_month01_lr1.5e-05-0 at epoch 45


# Permutation of ACU difference

In [6]:
%%time
t_acu_diff = permute_acu_diff(
    da, afm, afm_base, ['year'], 
    month=hp.month, 
    n_resamples=n_resamples, batch=batch,
)

ACU diff: 3-month lead
ACU diff: 6-month lead


In [7]:
# Save
# Output file sits next to the experiment's other outputs; its name records
# the variable, the test data set and the selected epoch.
outf = f'{out_dir}/{vname}_t_acu_diff_{test_data}_epoch{epoch}.nc'

# Write every variable of the result as 32-bit floats.
encoding = {name: {'dtype': 'float32'} for name in t_acu_diff.keys()}
t_acu_diff.to_netcdf(outf, encoding=encoding)
print(outf)

../output/ch256_sst_ssh_taux_scaled_month01_lr1.5e-05-0/sst_t_acu_diff_real_epoch45.nc


# Permutation of NRMSE skill

In [8]:
%%time
t_nrmse = permute_nrmse(
    da, afm, afm_base, ['year'], 
    month=hp.month, 
    n_resamples=n_resamples, batch=batch,
)

NRMSE: 3-month lead
NRMSE: 6-month lead
NRMSE: 9-month lead
NRMSE: 12-month lead
CPU times: user 8.51 s, sys: 1.76 s, total: 10.3 s
Wall time: 10.3 s


In [9]:
# Save
# Output file sits next to the experiment's other outputs; its name records
# the variable, the test data set and the selected epoch.
outf = f'{out_dir}/{vname}_t_nrmse_{test_data}_epoch{epoch}.nc'

# Write every variable of the result as 32-bit floats.
encoding = {name: {'dtype': 'float32'} for name in t_nrmse.keys()}
t_nrmse.to_netcdf(outf, encoding=encoding)
print(outf)

../output/ch256_sst_ssh_taux_scaled_month01_lr1.5e-05-0/sst_t_nrmse_real_epoch45.nc
