In [1]:
import os
import sys
import json
import numpy as np
import pandas as pd
import xarray as xr

# Put the sibling ../src directory on the import path (once) so that the
# project modules imported below can be resolved from this notebook.
module_path = os.path.abspath('../src')
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)
    
from analog import *
from utils import DotDict, to_monthly

In [2]:
# Label embedded in the output file names, and the directory results go to
test_data = 'test'
out_dir = '../output/MA'

# Run configuration. Wrapped in DotDict so later cells can read entries
# with attribute access (param.data_dir, param.periods.library, ...).
param = DotDict(dict(
    data_dir='../data/cesm2',
    vnames=['sst', 'ssh'],
    lat_slice=(-30, 30),
    # (first year, last year) of each period, used with slice() on `year`
    periods=dict(
        library=(1865, 1958),
        target=(1986, 1998),
    ),
))

In [3]:
# Load data: one anomaly file per configured variable, opened together
flst = [f'{param.data_dir}/{vname}_anomaly_2x2.nc' for vname in param.vnames]

# Keep only the configured latitude band
lat_band = slice(*param.lat_slice)
ds = xr.open_mfdataset(flst).sel(lat=lat_band)

# Find model-analog indices

If the library contains 100 ensemble members with 100 years each, the flattened library index maps to (ensemble, year) as follows:  
- index 0 - (ensemble 1, year 1)  
- index 1 - (ensemble 1, year 2)  
- index 100 - (ensemble 2, year 1)  

In [4]:
# Rearrange with the project helper; the result is used below with a
# `year` dimension (presumably split into year/month - confirm in utils).
ds_month = to_monthly(ds)

# Area weighting: sqrt(cos(lat)) ~ sqrt(grid-cell area)
wgt = np.sqrt(np.cos(np.deg2rad(ds.lat)))
ds_wgt = ds_month * wgt

# Normalise each variable by its domain-averaged std:
# variance over (ens, year), averaged over (lat, lon), then square-rooted
std = ds_wgt.var(dim=['ens', 'year']).mean(dim=['lat', 'lon']) ** 0.5
ds_wgt_std = ds_wgt / std


def _stack_period(period, sample_dim):
    """Select the years of `period` and flatten (ens, year) into `sample_dim`.

    Returns a DataArray (variables merged along `variable`) with
    `sample_dim` as the leading dimension.
    """
    in_period = ds_wgt_std.sel(year=slice(*period))
    flattened = in_period.stack(**{sample_dim: ('ens', 'year')})
    return flattened.to_array().transpose(sample_dim, ...)


# Library states: the stacked coordinate is dropped, leaving a plain
# positional `lsample` dimension.
t0_library = _stack_period(param.periods.library, 'lsample').drop_vars('lsample')

# Target states: keep the (ens, year) coordinate on `sample`
t0_target = _stack_period(param.periods.target, 'sample')

In [5]:
%%time
# t0 distance (mse)
lst = [((t0_target.isel(sample=i) - t0_library) ** 2
        ).sum(dim=['variable', 'lat', 'lon'])
       for i in range(t0_target.sample.size)]    
d0 = xr.concat(lst, dim='sample').compute()

# sort along the last axis
ma_idx = np.argsort(d0.data)

# insample
if param.periods.library == param.periods.target:
    ma_idx = ma_idx[:, :, 1:]

# To xarray
ma_idx = xr.DataArray(
    ma_idx[:, :, :1000], dims=['sample', 'month', 'analog'],
    coords={'month': np.arange(1, 13),
            'sample': t0_target['sample']}
)

  x = np.divide(x1, x2, out)


CPU times: user 9h 30min 17s, sys: 5h 42min 16s, total: 15h 12min 33s
Wall time: 13min 43s


In [6]:
# Write the run configuration and the analog indices to out_dir
os.makedirs(out_dir, exist_ok=True)

param_path = f'{out_dir}/param_{test_data}.json'
with open(param_path, 'w') as f:
    json.dump(param, f)

# Cast to float before unstacking `sample` and back to int afterwards
# (presumably so the unstack step can fill gaps - confirm), then save.
index_path = f'{out_dir}/ma_index_{test_data}.nc'
ma_unstacked = ma_idx.astype(float).unstack()
ma_unstacked.astype(int).to_netcdf(index_path)