In [1]:
import os
import sys
import json
import numpy as np
import pandas as pd
import xarray as xr

# Make the project's local modules (../src, relative to the notebook's working
# directory) importable.  The membership check keeps re-running this cell from
# adding the same entry to sys.path more than once.
module_path = os.path.abspath('../src')
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)
    
from analog import *
from utils import DotDict, to_monthly

In [2]:
# Label of the data set being processed, and the directory results are written to
test_data = 'real'
out_dir = '../output/MA'

# Run parameters.  Wrapped in DotDict (from utils) so that entries can be read
# with attribute access, e.g. param.data_dir or param.periods.library.
param = DotDict(dict(
    data_dir='../data/real',        # directory holding the target anomaly files
    library_dir='../data/cesm2',    # directory holding the library anomaly files
    vnames=['sst', 'ssh'],          # variables; one file per variable is loaded
    lat_slice=(-30, 30),            # (start, stop) passed to slice() on `lat`
    periods=dict(                   # (start, stop) passed to slice() on `year`
        library=(1865, 1958),
        target=(1987, 2020),
    ),
))

In [3]:
def _anomaly_files(data_dir, vnames):
    """Return one '<data_dir>/<vname>_anomaly_2x2.nc' path per variable name."""
    return [f'{data_dir}/{vname}_anomaly_2x2.nc' for vname in vnames]


# Latitude band shared by the target and the library
lat_band = slice(*param.lat_slice)

# Target data: open all variables as one dataset, restricted to the band
flst = _anomaly_files(param.data_dir, param.vnames)
ds = xr.open_mfdataset(flst).sel(lat=lat_band)

# Library: same file layout and latitude band, different directory
flst = _anomaly_files(param.library_dir, param.vnames)
library = xr.open_mfdataset(flst).sel(lat=lat_band)

# Find analog indices: the following cells preprocess the fields and rank the library samples by distance

In [4]:
# Reshape both datasets with the project helper.  Everything below selects and
# reduces over a `year` dimension, so to_monthly is presumably splitting time
# into year x month -- see utils.to_monthly to confirm.
ds_month = to_monthly(ds)
library_month = to_monthly(library)

# Area weighting: sqrt(cos(lat)) ~ sqrt(grid area).  The weights are built from
# the target's latitudes and applied to both the target and the library.
lat_rad = np.deg2rad(ds.lat)
wgt = np.sqrt(np.cos(lat_rad))
ds_wgt = ds_month * wgt
library_wgt = library_month * wgt

# Normalisation: square root of the domain-mean variance.  The target variance
# is taken over `year`; the library variance over both `ens` and `year`.
target_var = ds_wgt.var(dim='year')
std = target_var.mean(dim=['lat', 'lon']) ** 0.5
ds_wgt_std = ds_wgt / std

library_var = library_wgt.var(dim=['ens', 'year'])
library_std = library_var.mean(dim=['lat', 'lon']) ** 0.5
library_wgt_std = library_wgt / library_std

# Library states: keep the library period, then flatten (ens, year) into a
# single `lsample` dimension placed first.  The stacked index coordinate is
# dropped, so samples are identified by position only.
library_period = slice(*param.periods.library)
library_stacked = library_wgt_std.sel(year=library_period).stack(
    lsample=('ens', 'year'))
t0_library = (
    library_stacked
    .to_array()
    .transpose('lsample', ...)
    .drop_vars('lsample')
)

# Target states: keep the target period, with `year` as the leading dimension.
target_period = slice(*param.periods.target)
t0_target = (
    ds_wgt_std
    .sel(year=target_period)
    .to_array()
    .transpose('year', ...)
)

In [5]:
%%time
# t0 distance: sum of squared differences between each target year and every
# library sample, accumulated over variables and grid points (a sum, not a
# mean, so values are not normalised by the number of points).

# Number of closest library samples kept per (year, month)
n_analog = 1000

# One (month, lsample) distance field per target year; looping over years
# avoids materialising the full year x lsample x grid difference at once.
lst = [((t0_target.isel(year=i) - t0_library) ** 2
        ).sum(dim=['variable', 'lat', 'lon'])
       for i in range(t0_target.year.size)]

# Pin the axis order explicitly before dropping to NumPy: the ranking below
# works on the raw array and labels its axes by position, so `lsample` must be
# the last axis.  If the dimensions are not what is expected, transpose raises
# instead of silently producing mislabelled indices.
d0 = xr.concat(lst, dim='year').transpose('year', 'month', 'lsample').compute()

# Positions of the library samples ordered from smallest to largest distance
ma_idx = np.argsort(d0.data, axis=-1)

# To xarray, keeping only the n_analog closest samples
ma_idx = xr.DataArray(
    ma_idx[:, :, :n_analog], dims=['year', 'month', 'analog'],
    coords={'month': np.arange(1, 13),
            'year': t0_target.year}
)

  x = np.divide(x1, x2, out)


CPU times: user 5min 25s, sys: 2min 55s, total: 8min 21s
Wall time: 4min 29s


In [7]:
# Write the run parameters and the analog indices side by side in out_dir,
# both tagged with the test_data label.
os.makedirs(out_dir, exist_ok=True)

param_path = f'{out_dir}/param_{test_data}.json'
index_path = f'{out_dir}/ma_index_{test_data}.nc'

# Parameters as JSON (tuples are written as JSON arrays)
with open(param_path, 'w') as f:
    json.dump(param, f)

# Analog indices as netCDF
ma_idx.to_netcdf(index_path)