# Given weights, find analog states

In [1]:
import os
import sys
import json
import torch
import numpy as np
import pandas as pd
import xarray as xr

module_path = os.path.abspath(os.path.join('../src'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from loaddata import load_cesm2_by_period
from utils import *
from ocean_basin import *
from analog import *
from eval_pred import *

In [2]:
exp, epoch = 'unet4-256_scaled_month01_lr1.5e-05_0', None
test_data = 'test'

out_dir = f'../output/{exp}'
if test_data == 'test':
    data_dir = '../data/cesm2'
elif test_data == 'real':
    data_dir = '../data/real'

In [3]:
# Read history & hyperparameters
out_dir = f'../output/{exp}'
if epoch is None:
    history = pd.read_csv(f'{out_dir}/history.csv', index_col=0)
    epoch = history['val_mse'].argmin()

with open(f'{out_dir}/hyperparameters.json', 'r') as f:
    hp = json.load(f)
    hp = DotDict(hp)

# MA indices
ma_idx_base = xr.open_dataarray(f'{out_dir}/ma_index_{test_data}_epoch{epoch}.nc')
ma_idx_base = ma_idx_base.stack(sample=['ens', 'year'])

In [5]:
# load data
(datasets, dataloaders, 
 t0_library, t0_mask, 
 t1_library) = load_cesm2_by_period(data_dir, **hp, shuffle=False)

# Get a dataset
ds = datasets[test_data].t0_ds

# To tensor
x = torch.from_numpy(ds.to_array().transpose('sample', ...).data)

In [34]:
# Read weights
weight = xr.open_dataset(f'{out_dir}/weight_{test_data}_epoch{epoch}.nc')

# sample dim
weight = weight.stack(sample=['ens', 'year'])

# Normalize weight
weight = weight / weight.to_array().sum(dim=['variable', 'lat', 'lon']) * 100

# Ocean basins
basin_da = basin_mask(weight.isel(sample=0)[hp.vnames[0]])
basins = np.array(['IO', 'PO', 'AO'])

# Select basins to retain
ibasins = np.array([1, 2])
# ibasins = np.array([3])
basin_str = '_'.join(basins[ibasins - 1])
print(basin_str)
cond = xr.concat([(basin_da == ibasin) for ibasin in ibasins], dim='basin').any(dim='basin')
weight = weight.where(cond, 0)

# To tensor
weight = torch.from_numpy(weight.to_array().transpose('sample', ...).data)

IO_PO


In [35]:
# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
t0_library = t0_library.to(device)
t0_mask = t0_mask.to(device)
weight = weight.to(device)
x = x.to(device)

# Mask weight
weight = torch.where(t0_mask[None], 0, weight)

Using cuda device


In [36]:
sample_size = x.shape[0]

# Weighted initial distance
d0 = torch.stack([
    ((x[i] - t0_library) ** 2 * weight[i]).sum(dim=[1, 2, 3])
    for i in range(sample_size)])

# Sort
ma_idx = torch.argsort(d0)

# to cpu
ma_idx = ma_idx.cpu().detach().numpy()

# insample
if test_data == 'train':
    ma_idx = ma_idx[:, 1:]

# To xarray
ma_idx = xr.DataArray(
    ma_idx[:, :100], dims=['sample', 'analog'],
    coords={'sample': ds['sample']})

ma_idx = ma_idx.astype(float).unstack().astype(int)

# save
ma_idx.to_netcdf(f'{out_dir}/ma_index_{test_data}_epoch{epoch}_{basin_str}.nc')

# Evaluate statistics

In [37]:
leads = np.arange(19)
vname = 'sst'

# Read
f = f'{data_dir}/{vname}_anomaly_2x2.nc'
da = xr.open_dataarray(f)

# Analog forecasts
af = get_af_month(
    da, hp.month, hp.periods['train'], 
    ma_idx, hp.n_analog, leads)
afm = af.mean(dim='analog')

# Time stats
t_mse = eval_stats_lead(eval_mse, da, afm, month=hp.month, dim=['ens', 'year'])
t_uac = eval_stats_lead(eval_uac, da, afm, month=hp.month, dim=['ens', 'year'])
t_cac = eval_stats_lead(eval_r, da, afm, month=hp.month, dim=['ens', 'year'])
t_rmsss = eval_stats_lead(eval_rmsss, da, afm, month=hp.month, dim=['ens', 'year'])
t_msss = eval_stats_lead(eval_msss, da, afm, month=hp.month, dim=['ens', 'year'])

# Over the equatorial Pacific
xy_mse = eval_stats_lead(
    eval_mse, da.sel(lat=slice(-10, 10), lon=slice(120, 290)), afm, month=hp.month, dim=['lat', 'lon'])
xy_uac = eval_stats_lead(
    eval_uac, da.sel(lat=slice(-10, 10), lon=slice(120, 290)), afm, month=hp.month, dim=['lat', 'lon'])

# Combine
t_stats = xr.merge([
    t_mse.rename('mse').assign_attrs(long_name='Mean square error'), 
    t_uac.rename('uac').assign_attrs(long_name='Uncentered anomaly correlation'),
    t_cac.rename('cac').assign_attrs(long_name='Centered anomaly correlation'),
    t_rmsss.rename('rmsss').assign_attrs(long_name='Root mean square skill score'),
    t_msss.rename('msss').assign_attrs(long_name='Mean square skill score'),
    ])

xy_stats = xr.merge([
    xy_mse.rename('mse').assign_attrs(long_name='Mean square error'), 
    xy_uac.rename('uac').assign_attrs(long_name='Uncentered anomaly correlation')
    ])

# Save
encoding = {key: {'dtype': 'float32'} for key in list(t_stats.keys())}
t_stats.to_netcdf(f'{out_dir}/{vname}_t_stats_{test_data}_epoch{epoch}_{basin_str}.nc',
                   encoding=encoding)

encoding = {key: {'dtype': 'float32'} for key in list(xy_stats.keys())}
xy_stats.to_netcdf(f'{out_dir}/{vname}_xy_stats_{test_data}_epoch{epoch}_{basin_str}.nc',
                    encoding=encoding)

# check how many analogs overlap with the original

In [38]:
overlap = (ma_idx.isel(analog=slice(0, hp.n_analog)) 
           == ma_idx_base.unstack().isel(analog=slice(0, hp.n_analog)))

overlap_frac = overlap.sum(dim='analog') / hp.n_analog

print(f'Overlap: {overlap_frac.mean().data*100:.2f}%')

Overlap: 5.28%


  multiarray.copyto(res, fill_value, casting='unsafe')
