In [1]:
from glob import glob
import numpy as np
import xarray as xr
from matplotlib import pyplot as plt
import xgcm
from xorca.lib import load_xorca_dataset
import pandas as pd
import xesmf as xe
from scipy import ndimage
from scipy.interpolate import interp1d
import pickle
import operator
import copy
import eddytools as et
from cmocean import cm
import dask

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)

%load_ext line_profiler

In [2]:
# paths
# Input locations for the whole notebook. `datapath` is an absolute,
# machine-specific directory; change it when running somewhere else.
datapath = '/scratch/usr/shkifmjr/eddy_tracking'
# `glob` returns a list, so `meshpath` is a list of matching mesh-mask files
meshpath = glob(datapath + '/1_mesh_mask.nc')
# `5d` grid_T / grid_U / grid_V files, sorted by file name
data_in = sorted(glob(datapath + '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_20??0101_????1231_grid_[TUV].nc'))
# `5d` `*_mops.nc` files, sorted by file name
data_in_mops = sorted(glob(datapath + '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_20??0101_????1231_mops.nc'))
# Tracer input is currently disabled, so `data_in_tracer` is not defined by this cell.
#data_in_tracer = sorted(glob(datapath + '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_200[4-9]0101_????1231_tracer.nc'))
# `1m` grid_T files (presumably monthly means -- TODO confirm), sorted by file name
data_in_1m = sorted(glob(datapath + '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_1m_20??0101_????1231_grid_T.nc'))

In [None]:
# define additional variables for xorca
# Each extra variable is described only by its target dimensions. The first
# group carries the vertical dimension `z_c`, the second group does not.
mops_vars = {
    name: {'dims': ['t', 'z_c', 'y_c', 'x_c']}
    for name in ('O2', 'PO4', 'NO3', 'DIC', 'DICP', 'ALK',
                 'idealpo4', 'CFC12', 'SF6')
}
mops_vars.update(
    (name, {'dims': ['t', 'y_c', 'x_c']})
    for name in ('co2flux', 'co2flux_pre', 'ph', 'ph_pre',
                 'fco2', 'fco2_pre')
)

In [None]:
# load data
# Read the `data_in` files together with the mesh mask into one dataset.
# Chunking: 73 entries along time, single levels in the vertical.
data = load_xorca_dataset(
    data_files=data_in,
    aux_files=meshpath,
    model_config='NEST',
    update_orca_variables=mops_vars,
    input_ds_chunks={
        "time_counter": 73, "t": 73, "z": 1,
        "deptht": 1, "depthu": 1, "depthv": 1, "depthw": 1,
    },
    target_ds_chunks={"t": 73, "z_c": 1, "z_l": 1},
)

In [None]:
# Read the `data_in_mops` files with the same settings as the main dataset.
data_mops = load_xorca_dataset(
    data_files=data_in_mops,
    aux_files=meshpath,
    model_config='NEST',
    update_orca_variables=mops_vars,
    input_ds_chunks={
        "time_counter": 73, "t": 73, "z": 1,
        "deptht": 1, "depthu": 1, "depthv": 1, "depthw": 1,
    },
    target_ds_chunks={"t": 73, "z_c": 1, "z_l": 1},
)

In [None]:
# `data_in_tracer` only exists as a commented-out line in the paths cell, so
# the file list is built here to keep this cell runnable on a fresh kernel.
# The pattern is the one from that commented-out line.
data_in_tracer = sorted(glob(datapath + '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_200[4-9]0101_????1231_tracer.nc'))
# Read the tracer files with the same settings as the main dataset.
data_tracer = load_xorca_dataset(data_files=data_in_tracer, aux_files=meshpath, model_config='NEST', update_orca_variables = mops_vars,
                               input_ds_chunks = {"time_counter": 73, "t": 73, "z": 1, "deptht": 1, "depthu": 1, "depthv": 1, "depthw": 1},
                               target_ds_chunks = {"t": 73, "z_c": 1, "z_l": 1})

In [None]:
# Read the `data_in_1m` files; time is chunked in blocks of 12 entries here.
data_1m = load_xorca_dataset(
    data_files=data_in_1m,
    aux_files=meshpath,
    model_config='NEST',
    update_orca_variables=mops_vars,
    input_ds_chunks={
        "time_counter": 12, "t": 12, "z": 1,
        "deptht": 1, "depthu": 1, "depthv": 1, "depthw": 1,
    },
    target_ds_chunks={"t": 12, "z_c": 1, "z_l": 1},
)

In [None]:
# define metrics for xgcm
# Horizontal cell areas at the T, U, V and F points: product of the two
# horizontal scale factors of the respective point.
at = data['e1t'] * data['e2t']
au = data['e1u'] * data['e2u']
av = data['e1v'] * data['e2v']
af = data['e1f'] * data['e2f']
# Cell volumes: vertical scale factor times horizontal area
# (`vw` is built from the T-point area `at`).
vt = data['e3t'] * at
vu = data['e3u'] * au
vv = data['e3v'] * av
vw = data['e3w'] * at

# `Dataset.update` works in place; assigning its return value back to `data`
# relies on a return value that xarray documents as deprecated. `assign`
# returns a new dataset with the added variables and is safe to rebind.
data = data.assign(at=at, au=au, av=av, af=af, vt=vt, vu=vu, vv=vv, vw=vw)
data = data.set_coords(['at', 'au', 'av', 'af', 'vt', 'vu', 'vv', 'vw'])

# Metrics including the vertical axis (passed to `xgcm.Grid` further below)
metrics = {
    ('X',): ['e1t', 'e1u', 'e1v', 'e1f'], # X distances
    ('Y',): ['e2t', 'e2u', 'e2v', 'e2f'], # Y distances
    ('Z',): ['e3t', 'e3u', 'e3v', 'e3w'], # Z distances
    ('X', 'Y'): ['at', 'au', 'av', 'af'], # Areas
    ('X', 'Y', 'Z'): ['vt', 'vu', 'vv', 'vw'] # Volumes
}

# Same metrics without the vertical axis
metrics_noZ = {
    ('X',): ['e1t', 'e1u', 'e1v', 'e1f'], # X distances
    ('Y',): ['e2t', 'e2u', 'e2v', 'e2f'], # Y distances
    ('X', 'Y'): ['at', 'au', 'av', 'af'], # Areas
}

In [None]:
# add bathymetry to data to have depth information
# NOTE: absolute, machine-specific path.
bathy = xr.open_dataset('/scratch/usr/shkifmjr/NUSERDATA/ORION/10-data/bathy_meter/1_bathy_meter__3.6.0_ORION10.L46_Kv1.0.0.nc')
# The raw array is attached on the (y_c, x_c) dimensions of `data`.
# `assign` returns a new dataset; `data = data.update(...)` relied on the
# return value of the in-place `Dataset.update`, which is documented as deprecated.
data = data.assign(bathymetry=(['y_c', 'x_c'], bathy['Bathymetry'].data))

In [None]:
# Combine the main dataset and the `data_mops` dataset into one dataset.
# NOTE: `data` is rebound here, so re-running this cell merges `data_mops`
# into the already merged dataset again.
data = xr.merge([data, data_mops])

### OKUBO-WEISS

In [None]:
# xgcm grid for `data`, using the metrics dictionary that includes the Z axis
grid = xgcm.Grid(data, metrics=metrics)

In [None]:
# Okubo-Weiss calculation from the velocity variables 'vozocrtx' and 'vomecrty';
# the result is re-chunked so that each horizontal dimension is one single chunk.
data_OW = et.okuboweiss.calc(data, grid, 'vozocrtx', 'vomecrty').chunk({'x_c': -1, 'x_r': -1, 'y_c': -1, 'y_r': -1})

### INTERPOLATION

In [None]:
# interpolate variables to a regular grid to simplify calculations
# Parameter dictionary handed to `et.interp.horizontal` in the next cell.
interpolation_parameters = {'start_time': '2000-01-01', # time range start
                            'end_time': '2018-12-31', # time range end
                            'lon1': np.floor(data['llon_cc'][0,0].values), # minimum longitude: value at index [0, 0], rounded down
                            'lon2': np.ceil(data['llon_cc'][0,-1].values + 360), # maximum longitude: value at index [0, -1] plus 360, rounded up (presumably because the region crosses the dateline -- TODO confirm)
                            'lat1': np.floor(data['llat_cc'].min().values), # minimum latitude
                            'lat2': np.ceil(data['llat_cc'].max().values), # maximum latitude
                            'res': 1./10., # resolution of the fields
                            'vars_to_interpolate': ['OW', 'vort'], # variables to be interpolated
                            'mask_to_interpolate': ['fmask', 'tmask', 'bathymetry']} # mask to interpolate

In [None]:
# Interpolate a single vertical level only: index 9 on both `z_c` and `z_l`
# (0-based; the output file written in the next cell is tagged `k10`).
data_int_OW = et.interp.horizontal(data_OW.isel(z_c=9, z_l=9), interpolation_parameters)

In [None]:
# save interpolated file so the calculations above do not have to be redone!
# mode='w' overwrites an existing file of the same name.
data_int_OW.to_netcdf(datapath + '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_20000101_20181231_OW_interpolated_k10.nc', 
                   mode='w', format='NETCDF4_CLASSIC')

In [3]:
# load interpolated file if it has been saved before
# Opened lazily with one time step per chunk and 510 x 200 (lon x lat) tiles.
data_int = xr.open_dataset(datapath + '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_20000101_20181231_OW_interpolated_k10.nc',
                           chunks={'time': 1, 'lon': 510, 'lat': 200})

In [4]:
# Open the `REG10` grid_T files from the `interpol` sub-directory as one
# dataset, using the same chunking as `data_int`.
data_int_T = xr.open_mfdataset(sorted(glob(datapath
                             + '/interpol/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_REG10_5d_20??0101_20??1231_grid_T.nc')),
                           chunks={'time': 1, 'lon': 510, 'lat': 200})

In [5]:
# Open the `REG10` mops files from the `interpol` sub-directory as one
# dataset, using the same chunking as `data_int`.
data_int_mops = xr.open_mfdataset(sorted(glob(datapath 
                                + '/interpol/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_REG10_5d_20??0101_20??1231_mops.nc')),
                           chunks={'time': 1, 'lon': 510, 'lat': 200})

In [6]:
# Combine the interpolated OW, grid_T and mops datasets into one dataset.
# NOTE: `data_int` is rebound, so this cell is not idempotent when re-run.
data_int = xr.merge([data_int, data_int_T, data_int_mops])

In [7]:
# Load the cached `mean_OW_std` file. A file with this name is written by the
# `to_netcdf` cell a few cells further down, so the file must already exist
# when this cell runs.
mean_OW_spatial_std = xr.open_dataset(datapath + '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_20000101_20181231_mean_OW_std.nc',
                           chunks={'lon': 510, 'lat': 200})

In [None]:
# Load the complete 'OW' variable into memory (`.compute()` evaluates the lazy array).
OW_tmp = data_int['OW'].compute()

In [None]:
# Replace exact zeros by NaN so they are skipped by the statistics below
# (presumably zeros mark land / missing values -- TODO confirm).
OW_tmp = OW_tmp.where(OW_tmp != 0)

In [None]:
# Rolling standard deviation over 100 points along `lon`, then a rolling
# standard deviation over 100 points along `lat` of that result, finally
# averaged over time.
# NOTE(review): this is the std (along lat) of a std (along lon), not a single
# two-dimensional windowed std -- confirm this is the intended statistic.
# The result is renamed to 'OW_std': without the rename it keeps the name 'OW'
# and would collide with `data_int['OW']` when merged into `data_int`.
mean_OW_spatial_std = (
    OW_tmp
    .rolling(lon=100, center=True, min_periods=1).std(skipna=True)
    .rolling(lat=100, center=True, min_periods=1).std(skipna=True)
    .mean('time')
    .rename('OW_std')
)

In [None]:
# Store the field as variable 'OW_std' so it can be re-loaded instead of recomputed.
# NOTE: this expects `mean_OW_spatial_std` to be a DataArray (it calls
# `.to_dataset(name=...)`), i.e. the freshly computed field, not the dataset
# loaded from the cache file.
mean_OW_spatial_std.to_dataset(name='OW_std').to_netcdf(datapath + 
        '/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_20000101_20181231_mean_OW_std.nc', format='NETCDF4_CLASSIC')

In [8]:
# Add the OW standard-deviation field to `data_int` and re-chunk the result
# to one time step per chunk.
# NOTE: `data_int` is rebound, so this cell is not idempotent when re-run.
data_int = xr.merge([data_int, mean_OW_spatial_std]).chunk({'time': 1})

### DETECTION

In [None]:
# Specify parameters for eddy detection
# Dictionary handed to `et.detection.detect` in the next cell.
detection_parameters = {'start_time': '2000-01-01', # time range start
                        'end_time': '2018-12-31', # time range end
                        'lon1': np.floor(data_int['lon'][0].values), # minimum longitude of detection region
                        'lon2': np.ceil(data_int['lon'][-1].values), # maximum longitude
                        'lat1': np.floor(data_int['lat'].min().values), # minimum latitude
                        'lat2': np.ceil(data_int['lat'].max().values), # maximum latitude
                        'min_dep': 1000, # minimum ocean depth where to look for eddies
                        'res': 1./10., # resolution of the fields
                        'OW_thr': -0.0001, # scalar Okubo-Weiss threshold. NOTE(review): how this combines with 'OW_thr_name' and 'OW_thr_factor' is decided inside et.detection.detect -- confirm the effective threshold
                        'OW_thr_name': 'OW_std', # name of a variable; 'OW_std' is the variable merged into `data_int` above
                        'OW_thr_factor': -0.3, # factor applied to the threshold -- TODO confirm against et.detection.detect
                        'Npix_min': 15, # minimum number of pixels (grid cells) to be considered as eddy
                        'Npix_max': 2000} # maximum number of pixels (grid cells)

In [None]:
# Detect eddies on the single depth level with index 9, using the variables
# 'OW' and 'vort' of `data_int`.
eddies = et.detection.detect(data_int.isel(depth=9), detection_parameters, 'OW', 'vort')

In [None]:
# Collect the entries eddies[0] ... eddies[len(eddies) - 1] in a plain list
eddies_list = [eddies[i] for i in np.arange(0, len(eddies))]

In [None]:
# Write one pickle file per entry of `eddies_list`. The date in the file name
# is taken from the first 10 characters of the 'time' entry of the first eddy
# of that entry.
for eddies_at_step in eddies_list:
    datestring = str(eddies_at_step[0]['time'])[0:10]
    # The `with` block closes the file; no explicit close() is needed.
    with open('/scratch/usr/shkifmjr/eddy_tracking/eddies/'
              + '1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_' + datestring + '_eddies_OW0.3.pickle', 'wb') as f:
        pickle.dump(eddies_at_step, f, pickle.HIGHEST_PROTOCOL)

### TRACKING

In [9]:
# Specify parameters for eddy tracking
# Dictionary handed to `et.tracking.track` in the next cell.
tracking_parameters = {'start_time': '2000-01-03', # time range start
                        'end_time': '2018-12-31', # time range end
                        'dt': 5, # presumably the time step of the detected eddies in days -- TODO confirm
                        'lon1': np.floor(data_int['lon'][0].values), # minimum longitude of detection region
                        'lon2': np.ceil(data_int['lon'][-1].values), # maximum longitude
                        'lat1': np.floor(data_int['lat'].min().values), # minimum latitude
                        'lat2': np.ceil(data_int['lat'].max().values), # maximum latitude
                       'dE': 0., # maximum distance of the search ellipse from the eddy center towards the east
                                 # (if set to 0, it will be calculated as (150. / (7. / dt)))
                       'eddy_scale_min': 0.5, # minimum factor by which eddy amplitude and area are allowed to change in one timestep
                       'eddy_scale_max': 1.5, # maximum factor by which eddy amplitude and area are allowed to change in one timestep
                       'data_path': datapath + '/', # path to the detected eddies pickle files
                       'file_root': 'eddies/1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d', # sub-directory and common file-name prefix of the eddy pickle files
                       'file_spec': 'eddies_OW0.3', # file-name suffix of the eddy pickle files
                       'ross_path': datapath + '/'} # path to rossrad.dat containing Chelton et al. 1998 Rossby radii

In [10]:
# Track the detected eddies through time; progress is printed by the call.
tracks = et.tracking.track(tracking_parameters)

tracking at time step  155  of  1387
tracking at time step  309  of  1387
tracking at time step  463  of  1387
tracking at time step  617  of  1387
tracking at time step  772  of  1387
tracking at time step  926  of  1387
tracking at time step  1080  of  1387
tracking at time step  1234  of  1387


In [11]:
# Save the tracks so the tracking does not have to be repeated.
# The `with` block closes the file; no explicit close() is needed.
with open('/scratch/usr/shkifmjr/eddy_tracking/tracks/'
          + '1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_20000101_20181231_tracks_OW0.3.pickle', 'wb') as f:
    pickle.dump(tracks, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Load previously saved tracks.
# NOTE(review): this reads `..._20000101_20091231_tracks_OW0.3.pickle` from the
# top-level directory, whereas the cell that saves tracks writes
# `..._20000101_20181231_...` into the `tracks/` sub-directory -- confirm that
# this older file is really the one wanted here.
# SECURITY: `pickle.load` executes arbitrary code; only load files you trust.
# The `with` block closes the file; no explicit close() is needed.
with open('/scratch/usr/shkifmjr/eddy_tracking/'
          + '1_ORION10.L46.LIM2vp.CFCSF6.MOPS.JRA.XIOS2.5-EXP01_5d_20000101_20091231_tracks_OW0.3.pickle', 'rb') as f:
    tracks = pickle.load(f)

### SAMPLING

In [None]:
# Parameters that select which eddy tracks are sampled and which variables
# are attached to them.
sample_parameters = {'start_time': '2000-01-01', # time range start
                     'end_time': '2000-12-31', # time range end
                     'lon1': -180, #np.floor(data_int['lon'][0].values), # minimum longitude of detection region
                     'lon2': 180, #np.ceil(data_int['lon'][-1].values), # maximum longitude
                     'lat1': np.floor(data_int['lat'].min().values), # minimum latitude
                     'lat2': np.ceil(data_int['lat'].max().values), # maximum latitude
                     'type': 'anticyclonic', # this is cyclonic due to error in detection
                     'lifetime': 30, # length of the eddy's track in days
                     'size': 25, # eddy size (diameter in km)
                     'range': False, # if True, only eddies whose `var_range` value lies inside `value_range` are sampled
                     'ds_range': 0, #data_int.isel(depth=9).chunk({'lat': 50, 'lon': 50, 'time': 1}), # dataset used for the range check (placeholder 0 while 'range' is False)
                     'var_range': ['votemper'], # variable used for the range check
                     'value_range': [[-2., 15.]], #[[-1.2, -0.1],], # only sample eddies within this mean var_range range
                     'split': False, # if True, eddies are split into `below` / `above` at `value_split`
                     'ds_split': 0, #data_int.isel(depth=9).chunk({'lat': 50, 'lon': 50, 'time': 1}), # dataset used for the split (placeholder 0 while 'split' is False)
                     'var_split': ['DIC'], # variable used for the split
                     'value_split': [2125,], # split eddies at this value of var_split
                     'sample_vars': ['votemper', 'DIC'] # variables attached to every sampled eddy
                      } 

In [None]:
# First track, used as input for profiling `add_fields`.
# NOTE: this is a reference, not a copy -- `add_fields` writes into it in place.
test_sample = tracks[0]

In [None]:
# Line-by-line profile of `add_fields`.
# NOTE: `add_fields` and `data_int_cunked` are only defined in later cells,
# so those cells have to be run before this one.
%lprun -f add_fields add_fields(test_sample, data_int_cunked, 'votemper')

In [None]:
def add_fields(sampled, interpolated, var):
    """Attach the field `var` from `interpolated` to one eddy track.

    For every time step `t` of the track, entries are written into the
    following sub-dictionaries of `sampled` (each keyed by `t`):

    - ``var``: values of ``interpolated[var]`` at the matching time index,
      over the full second axis, at the eddy's j / i indices
    - ``var + '_lon'`` / ``var + '_lat'``: 'lon' / 'lat' of that selection
    - ``var + '_sec'``: values along a section at the mean j index, from the
      smallest to the largest i index of the eddy
    - ``var + '_sec_lon'`` / ``var + '_sec_lat'``: 'lon' / 'lat' of the section
    - ``var + '_sec_norm_lon'``: section longitude minus its mean, divided by
      the difference between its last and first value
    - ``var + '_around'``: result of `average_surroundings` for this time step

    Parameters
    ----------
    sampled : dict
        One eddy track with at least the keys 'time', 'eddy_i' and 'eddy_j'.
        Modified in place.
    interpolated : mapping
        Provides ``interpolated['time'].values`` and ``interpolated[var]``,
        the latter indexed as ``[time, :, j, i]`` and carrying 'lon' / 'lat'
        (assumed to be an xarray Dataset -- TODO confirm).
    var : str
        Name of the variable to attach.

    Returns
    -------
    dict
        The same `sampled` object that was passed in.
    """
    # Initialize additional dictionary entries in which to write the fields
    for suffix in ('', '_lon', '_lat', '_sec', '_sec_lon', '_sec_lat',
                   '_around', '_sec_norm_lon'):
        sampled[var + suffix] = {}
    try:
        length = len(sampled['time'])
    except TypeError:
        # a scalar time stamp has no len(): the track has a single time step
        length = 1
    # These do not change inside the loop, so look them up only once
    field = interpolated[var]
    field_times = interpolated['time'].values
    # loop over all time steps of the eddy track
    for t in range(length):
        time = sampled['time'] if length == 1 else sampled['time'][t]
        # row 0: i indices, row 1: j indices of the eddy in `interpolated`
        idx = np.vstack((sampled['eddy_i'][t], sampled['eddy_j'][t]))
        # first time index of `interpolated` that is not earlier than `time`
        t_index = np.min(np.where(field_times >= time))
        # the variable and its coordinates inside the eddy
        inside = field[t_index, :, idx[1, :], idx[0, :]]
        sampled[var + '_lon'][t] = inside['lon'].values
        sampled[var + '_lat'][t] = inside['lat'].values
        # zonal section through the middle of the eddy.
        # NOTE(review): the slice end is exclusive, so the column with the
        # largest i index is not part of the section -- confirm this is intended.
        section = field[t_index, :, int(np.mean(idx[1, :])),
                        np.min(idx[0, :]):np.max(idx[0, :])]
        section_lon = section['lon']
        sampled[var + '_sec_lon'][t] = section_lon.values
        sampled[var + '_sec_lat'][t] = section['lat'].values
        # normalize longitude for easier comparison of different eddies
        diff_lon = section_lon - section_lon.mean()
        norm_lon = diff_lon / (diff_lon[-1] - diff_lon[0])
        sampled[var + '_sec_norm_lon'][t] = norm_lon.values
        sampled[var + '_sec'][t] = section.values
        # depth profile of `var` in the surroundings of the eddy, used to
        # calculate anomalies
        sampled[var + '_around'][t] = average_surroundings(idx, field, t_index)
        sampled[var][t] = inside.values
    return sampled

In [None]:
def average_surroundings(indeces, interpolated, t_index):
    """Average a field over the surroundings of an eddy.

    The surroundings are the index box around the eddy, enlarged by one eddy
    "radius" (half of its extent along the i axis) on every side, without the
    eddy points themselves. Exact zeros are ignored as well.

    Parameters
    ----------
    indeces : array of int, shape (2, n)
        Row 0 holds the i indices, row 1 the j indices of the eddy points.
    interpolated : array-like
        Field indexed as ``[time, ...]``; after selecting `t_index` and
        squeezing it must be three-dimensional ``(level, j, i)``. Must offer
        ``.copy()``, ``.squeeze()`` and ``.values``.
    t_index : int
        Index along the first axis of `interpolated`.

    Returns
    -------
    numpy.ndarray
        One mean value per level (NaN where no valid value is available).
    """
    i_idx = np.asarray(indeces[0, :])
    j_idx = np.asarray(indeces[1, :])
    # Calculate the radius of the eddy in "index space"
    radius = int((np.max(i_idx) - np.min(i_idx)) / 2)
    # add one radius in each direction to define what are the surroundings.
    # The lower bounds are clamped at 0: a negative slice start would count
    # from the end of the axis and select an empty (or wrong) window for eddies
    # close to the lower edge. Upper bounds beyond the axis are cut by slicing.
    imin = max(int(np.min(i_idx)) - radius, 0)
    imax = int(np.max(i_idx)) + radius + 1
    jmin = max(int(np.min(j_idx)) - radius, 0)
    jmax = int(np.max(j_idx)) + radius + 1
    # work on a copy so that the input field is left untouched
    reduced = interpolated[t_index, ...].copy().squeeze().values
    # Set the values inside the eddy to NaN because we only want the
    # surroundings (pointwise assignment over all (j, i) pairs)
    reduced[:, j_idx, i_idx] = np.nan
    reduced = np.where(reduced == 0, np.nan, reduced)
    # Select the region around the eddy (+- one radius) and average
    around = np.nanmean(reduced[:, jmin:jmax, imin:imax], axis=(1, 2))
    return around

In [None]:
def monotonic_lon(data, lon):
    """Make the 'lon' values of `data` increase monotonically.

    Nothing happens unless the first longitude is larger than the last one.
    In that case 360 is added to every longitude that is smaller than the
    rounded first longitude, the shifted values are written back into
    ``data['lon']``, and `lon` is shifted by 360 as well if it is smaller
    than the new first longitude.

    Returns the (possibly modified) `data` and `lon`.
    """
    lon_coord = data['lon']
    if not (lon_coord[0] > lon_coord[-1]):
        # already increasing from first to last value: nothing to do
        return data, lon
    threshold = np.around(lon_coord[0].values)
    shifted = lon_coord.where(lon_coord >= threshold, other=lon_coord + 360)
    data['lon'].values = shifted.values
    if lon < shifted[0]:
        lon = lon + 360
    return data, lon

In [None]:
# Shallow copy of the parameter dictionary (nested lists are shared with the original)
sample_param = sample_parameters.copy()

In [None]:
#def sample(tracks, data, sample_param):
# Cell-level copy of the set-up part of `sample`. It defines `i`, `j`,
# `lifetime`, `start_time` and `end_time` as notebook globals.
# Initialize depending on whether `split`: True or False
if sample_param['split']:
    below = {}
    above = {}
else:
    sampled = {}
i = 0
j = 0
# Detect the time step of the eddy tracks in days from the first track that
# has at least two time steps. `np.atleast_1d` also copes with tracks whose
# 'time' is a single scalar, where `len()` / indexing would raise.
timestep = None
for k in range(len(tracks)):
    track_times = np.atleast_1d(tracks[k]['time'])
    if track_times.size > 1:
        timestep = ((track_times[1] - track_times[0])
                    / np.timedelta64(1, 'D'))
        break
if timestep is None:
    raise ValueError('cannot detect the time step: '
                     'no track has at least two time steps')
# Convert lifetime from days to indeces
lifetime = sample_param['lifetime'] / timestep
start_time = np.datetime64(sample_param['start_time'])
end_time = np.datetime64(sample_param['end_time'])

# Remaining part of the function, still disabled:
    #if sample_param['split']:
    #    for ed in np.arange(0, len(tracks)):
    #        below[ed], above[ed] = sample_core(tracks[ed], sample_parameters)
    #    return below, above
    #else:
    #    for ed in np.arange(0, len(tracks)):
    #        sampled[ed] = sample_core(tracks[ed], sample_parameters)
   #     return sampled


In [None]:
def sample_core(tracks, data, sample_param, i, j, lifetime, start_time, end_time):
    """Decide whether one eddy track is sampled and, if so, attach fields.

    Parameters
    ----------
    tracks : dict
        A single eddy track with the keys 'time', 'lon', 'lat', 'type' and
        'scale' (plus whatever `add_fields` needs).
    data : mapping
        Passed on to `add_fields` as the source of the sampled variables.
    sample_param : dict
        Sampling parameters: 'split', 'range', 'type', 'size', 'lon1', 'lon2',
        'lat1', 'lat2', 'sample_vars' and, when 'range' / 'split' are set,
        the matching 'ds_*', 'var_*' and 'value_*' entries.
    i, j : int
        Counters; `i` is incremented for a sampled (or `below`) eddy and `j`
        for an `above` eddy.
    lifetime : float
        Minimum track length in time steps (exclusive).
    start_time, end_time : numpy.datetime64
        The track has to lie completely inside this time range.

    Returns
    -------
    tuple
        ``(below, above, i, j)`` if ``sample_param['split']`` is set,
        otherwise ``(sampled, i, j)``. Dictionaries of eddies that do not
        meet the conditions are empty.
    """
    split = sample_param['split']
    use_range = sample_param['range']
    below, above, sampled = {}, {}, {}
    # determine first and last time step of the eddy
    track_times = np.array(tracks['time'])
    try:
        length = len(tracks['time'])
        time0 = track_times[0]
        time1 = track_times[-1]
    except (TypeError, IndexError):
        # scalar (or empty) time: treat the track as one single time step
        length = 1
        time0 = track_times
        time1 = track_times
    # determine lon and lat of eddy (longitude mapped to at most 180)
    lon_for_sel = tracks['lon'][0]
    if lon_for_sel > 180.:
        lon_for_sel = lon_for_sel - 360.
    lat_for_sel = tracks['lat'][0]
    # if necessary extract values of `var_range` and `var_split` at the
    # eddy's location at its first time step
    if use_range:
        data_range, lon_for_sel = monotonic_lon(sample_param['ds_range'],
                                                lon_for_sel)
        vars_range_ed = (
            data_range[sample_param['var_range'][0]]
            .sel(time=time0, lat=lat_for_sel, lon=lon_for_sel,
                 method='nearest').values)
    if split:
        data_split, lon_for_sel = monotonic_lon(sample_param['ds_split'],
                                                lon_for_sel)
        vars_split_ed = (
            data_split[sample_param['var_split'][0]]
            .sel(time=time0, lat=lat_for_sel, lon=lon_for_sel,
                 method='nearest').values)
    # construct bool of all conditions to be met by the eddy to be
    # considered in `sampled`
    conditions_are_met = (
        (time0 >= start_time)
        & (time1 <= end_time)
        & (tracks['type'] == sample_param['type'])
        & (length > lifetime)
        & (tracks['scale'].mean() > sample_param['size'])
        & (tracks['lon'][0] > sample_param['lon1'])
        & (tracks['lon'][0] < sample_param['lon2'])
        & (tracks['lat'][0] > sample_param['lat1'])
        & (tracks['lat'][0] < sample_param['lat2']))
    if use_range:
        conditions_are_met = (
            conditions_are_met
            & (vars_range_ed > sample_param['value_range'][0][0])
            & (vars_range_ed < sample_param['value_range'][0][1]))
    # add the required fields to a copy of the eddy dictionary
    if split:
        split_value = sample_param['value_split'][0]
        if conditions_are_met & (vars_split_ed < split_value):
            below = tracks.copy()
            for variable in sample_param['sample_vars']:
                below = add_fields(below, data, variable)
            i = i + 1
        elif conditions_are_met & (vars_split_ed >= split_value):
            above = tracks.copy()
            for variable in sample_param['sample_vars']:
                above = add_fields(above, data, variable)
            j = j + 1
        return below, above, i, j
    if conditions_are_met:
        sampled = tracks.copy()
        for variable in sample_param['sample_vars']:
            sampled = add_fields(sampled, data, variable)
        i = i + 1
    return sampled, i, j

In [None]:
# Re-chunked view of `data_int` (50 x 50 horizontal tiles, one time step per
# chunk) used by the profiling cells. The name is spelled `cunked`; it is kept
# because the `%lprun` cells refer to it.
data_int_cunked = data_int.chunk({'lat': 50, 'lon': 50, 'time': 1})

In [None]:
# Line-by-line profile of `sample_core` for the track with index 4.
# `i`, `j`, `lifetime`, `start_time` and `end_time` come from the cell-level
# set-up cell above.
%lprun -f sample_core sample_core(tracks[4], data_int_cunked, sample_parameters, i, j, lifetime, start_time, end_time)

In [None]:
# Sample all tracks with the eddytools implementation, on `data_int`
# re-chunked to 50 x 50 horizontal tiles and one time step per chunk.
sampled = et.sample.sample(tracks, data_int.chunk({'lat': 50, 'lon': 50, 'time': 1}), sample_parameters)

In [None]:
# Zonal section of 'vosaline' through the first `above` eddy at time step 5,
# plotted against the level index (0..45).
# NOTE(review): no visible cell fills `above` with sampled eddies, and
# 'vosaline' is not part of `sample_parameters['sample_vars']` -- this cell
# depends on state created elsewhere.
plt.pcolormesh(above[0]['vosaline_sec_lon'][5], np.arange(0,46), above[0]['vosaline_sec'][5], vmin=34, vmax=35)

In [None]:
# Same section as an anomaly: the surrounding profile ('vosaline_around') is
# subtracted level by level (broadcast along the longitude axis).
# NOTE(review): relies on `above` holding sampled eddies with 'vosaline'
# fields, which no visible cell produces.
plt.pcolormesh(above[0]['vosaline_sec_lon'][5], np.arange(0,46), above[0]['vosaline_sec'][5] - above[0]['vosaline_around'][5][:,None], 
               vmin=-0.04, vmax=0.04)

#### Parallel execution of the loops in sample.py?

In [None]:
def add_fields(sampled, interpolated, i, var):
    """Attach the field `var` from `interpolated` to the eddy ``sampled[i]``.

    For every time step `t` of the track, entries are written into the
    following sub-dictionaries of ``sampled[i]`` (each keyed by `t`):

    - ``var``: values of ``interpolated[var]`` at the matching time index,
      over the full second axis, at the eddy's j / i indices
    - ``var + '_lon'`` / ``var + '_lat'``: 'lon' / 'lat' of that selection
    - ``var + '_sec'``: values along a section at the mean j index, from the
      smallest to the largest i index of the eddy
    - ``var + '_sec_lon'`` / ``var + '_sec_lat'``: 'lon' / 'lat' of the section
    - ``var + '_sec_norm_lon'``: section longitude minus its mean, divided by
      the difference between its last and first value
    - ``var + '_around'``: result of `average_surroundings` for this time step

    Parameters
    ----------
    sampled : dict
        Collection of eddy tracks; ``sampled[i]`` needs at least the keys
        'time', 'eddy_i' and 'eddy_j' and is modified in place.
    interpolated : mapping
        Provides ``interpolated['time'].values`` and ``interpolated[var]``,
        the latter indexed as ``[time, :, j, i]`` and carrying 'lon' / 'lat'
        (assumed to be an xarray Dataset -- TODO confirm).
    i : hashable
        Key of the eddy inside `sampled`.
    var : str
        Name of the variable to attach.

    Returns
    -------
    dict
        The same `sampled` object that was passed in.
    """
    eddy = sampled[i]
    # Initialize additional dictionary entries in which to write the fields
    for suffix in ('', '_lon', '_lat', '_sec', '_sec_lon', '_sec_lat',
                   '_around', '_sec_norm_lon'):
        eddy[var + suffix] = {}
    try:
        length = len(eddy['time'])
    except TypeError:
        # a scalar time stamp has no len(): the track has a single time step
        length = 1
    # These do not change inside the loop, so look them up only once
    field = interpolated[var]
    field_times = interpolated['time'].values
    # loop over all time steps of the eddy track
    for t in range(length):
        time = eddy['time'] if length == 1 else eddy['time'][t]
        # row 0: i indices, row 1: j indices of the eddy in `interpolated`
        idx = np.vstack((eddy['eddy_i'][t], eddy['eddy_j'][t]))
        # first time index of `interpolated` that is not earlier than `time`
        t_index = np.min(np.where(field_times >= time))
        # the variable and its coordinates inside the eddy
        inside = field[t_index, :, idx[1, :], idx[0, :]]
        eddy[var + '_lon'][t] = inside['lon'].values
        eddy[var + '_lat'][t] = inside['lat'].values
        # zonal section through the middle of the eddy.
        # NOTE(review): the slice end is exclusive, so the column with the
        # largest i index is not part of the section -- confirm this is intended.
        section = field[t_index, :, int(np.mean(idx[1, :])),
                        np.min(idx[0, :]):np.max(idx[0, :])]
        section_lon = section['lon']
        eddy[var + '_sec_lon'][t] = section_lon.values
        eddy[var + '_sec_lat'][t] = section['lat'].values
        # normalize longitude for easier comparison of different eddies
        diff_lon = section_lon - section_lon.mean()
        norm_lon = diff_lon / (diff_lon[-1] - diff_lon[0])
        eddy[var + '_sec_norm_lon'][t] = norm_lon.values
        eddy[var + '_sec'][t] = section.values
        # depth profile of `var` in the surroundings of the eddy, used to
        # calculate anomalies
        eddy[var + '_around'][t] = average_surroundings(idx, field, t_index)
        eddy[var][t] = inside.values
    return sampled

In [None]:
def average_surroundings(indeces, interpolated, t_index):
    """Average a field over the surroundings of an eddy.

    The surroundings are the index box around the eddy, enlarged by one eddy
    "radius" (half of its extent along the i axis) on every side, without the
    eddy points themselves. Exact zeros are ignored as well.

    Parameters
    ----------
    indeces : array of int, shape (2, n)
        Row 0 holds the i indices, row 1 the j indices of the eddy points.
    interpolated : array-like
        Field indexed as ``[time, ...]``; after selecting `t_index` and
        squeezing it must be three-dimensional ``(level, j, i)``. Must offer
        ``.copy()``, ``.squeeze()`` and ``.values``.
    t_index : int
        Index along the first axis of `interpolated`.

    Returns
    -------
    numpy.ndarray
        One mean value per level (NaN where no valid value is available).
    """
    i_idx = np.asarray(indeces[0, :])
    j_idx = np.asarray(indeces[1, :])
    # Calculate the radius of the eddy in "index space"
    radius = int((np.max(i_idx) - np.min(i_idx)) / 2)
    # add one radius in each direction to define what are the surroundings.
    # The lower bounds are clamped at 0: a negative slice start would count
    # from the end of the axis and select an empty (or wrong) window for eddies
    # close to the lower edge. Upper bounds beyond the axis are cut by slicing.
    imin = max(int(np.min(i_idx)) - radius, 0)
    imax = int(np.max(i_idx)) + radius + 1
    jmin = max(int(np.min(j_idx)) - radius, 0)
    jmax = int(np.max(j_idx)) + radius + 1
    # work on a copy so that the input field is left untouched
    reduced = interpolated[t_index, ...].copy().squeeze().values
    # Set the values inside the eddy to NaN because we only want the
    # surroundings (pointwise assignment over all (j, i) pairs)
    reduced[:, j_idx, i_idx] = np.nan
    reduced = np.where(reduced == 0, np.nan, reduced)
    # Select the region around the eddy (+- one radius) and average
    around = np.nanmean(reduced[:, jmin:jmax, imin:imax], axis=(1, 2))
    return around

In [None]:
def monotonic_lon(data, lon):
    """Make the 'lon' values of `data` increase monotonically.

    Nothing happens unless the first longitude is larger than the last one.
    In that case 360 is added to every longitude that is smaller than the
    rounded first longitude, the shifted values are written back into
    ``data['lon']``, and `lon` is shifted by 360 as well if it is smaller
    than the new first longitude.

    Returns the (possibly modified) `data` and `lon`.
    """
    lon_coord = data['lon']
    if not (lon_coord[0] > lon_coord[-1]):
        # already increasing from first to last value: nothing to do
        return data, lon
    threshold = np.around(lon_coord[0].values)
    shifted = lon_coord.where(lon_coord >= threshold, other=lon_coord + 360)
    data['lon'].values = shifted.values
    if lon < shifted[0]:
        lon = lon + 360
    return data, lon

In [None]:
from multiprocess import Process, Manager

In [None]:
def sample(tracks, data, sample_param, processes=None):
    """Sample all eddy tracks in parallel.

    Every track is checked by `sample_core` in a worker process. Each call
    gets its own empty result dictionaries, and the results are collected and
    numbered consecutively in the parent process. (Results written into a
    shared manager dictionary from separate processes are not reliable:
    nested updates are lost and every worker would use the same counter.)

    Parameters
    ----------
    tracks : sequence or dict keyed 0..n-1
        Eddy tracks; each track is a dict with at least a 'time' entry.
    data : xarray.Dataset
        Source of the sampled variables; restricted to the requested time
        range before sampling.
    sample_param : dict
        Sampling parameters, see `sample_core`.
    processes : int, optional
        Number of worker processes. ``None`` lets the pool decide; ``1``
        runs everything in the current process (handy for debugging).

    Returns
    -------
    dict or tuple of dict
        ``sampled`` or, if ``sample_param['split']`` is set,
        ``(below, above)``. Keys are consecutive integers starting at 0.

    Raises
    ------
    ValueError
        If no track has at least two time steps, so that the time step
        cannot be detected.
    """
    # Function-scope import: the notebook only imports Process and Manager.
    from multiprocess import Pool

    split = sample_param['split']
    n_tracks = len(tracks)
    if n_tracks == 0:
        return ({}, {}) if split else {}
    # Detect the time step of the eddy tracks in days from the first track
    # with at least two time steps (`np.atleast_1d` copes with scalar times)
    timestep = None
    for k in range(n_tracks):
        track_times = np.atleast_1d(tracks[k]['time'])
        if track_times.size > 1:
            timestep = ((track_times[1] - track_times[0])
                        / np.timedelta64(1, 'D'))
            break
    if timestep is None:
        raise ValueError('cannot detect the time step: '
                         'no track has at least two time steps')
    # Convert lifetime from days to indeces
    lifetime = sample_param['lifetime'] / timestep
    start_time = np.datetime64(sample_param['start_time'])
    end_time = np.datetime64(sample_param['end_time'])
    data = data.sel(time=slice(sample_param['start_time'],
                               sample_param['end_time']))

    def _sample_one(ed):
        # Fresh containers and counters per eddy: a selected eddy is stored
        # under key 0 of the returned dictionary.
        if split:
            return sample_core(tracks, data, sample_param, lifetime,
                               start_time, end_time, ed, 0,
                               j=0, above={}, below={})
        return sample_core(tracks, data, sample_param, lifetime,
                           start_time, end_time, ed, 0, sampled={})

    if processes == 1:
        results = [_sample_one(ed) for ed in range(n_tracks)]
    else:
        with Pool(processes) as pool:
            results = pool.map(_sample_one, range(n_tracks))

    def _collect(parts):
        # Renumber the selected eddies consecutively, in track order
        collected = {}
        for part in parts:
            for eddy in part.values():
                collected[len(collected)] = eddy
        return collected

    if split:
        below = _collect(res[0] for res in results)
        above = _collect(res[1] for res in results)
        return below, above
    return _collect(results)

In [None]:
def sample_core(tracks, data, sample_param, lifetime, start_time, end_time, ed, i, j=None, sampled=None, above=None, below=None):
    """Check the eddy ``tracks[ed]`` and, if selected, store it with fields.

    Parameters
    ----------
    tracks : sequence or dict
        Eddy tracks; ``tracks[ed]`` is a dict with the keys 'time', 'lon',
        'lat', 'type' and 'scale' (plus whatever `add_fields` needs).
    data : mapping
        Passed on to `add_fields` as the source of the sampled variables.
    sample_param : dict
        Sampling parameters: 'split', 'range', 'type', 'size', 'lon1', 'lon2',
        'lat1', 'lat2', 'sample_vars' and, when 'range' / 'split' are set,
        the matching 'ds_*', 'var_*' and 'value_*' entries.
    lifetime : float
        Minimum track length in time steps (exclusive).
    start_time, end_time : numpy.datetime64
        The track has to lie completely inside this time range.
    ed : hashable
        Key / index of the eddy inside `tracks`.
    i, j : int
        Keys under which a selected eddy is stored: `i` in `sampled` or
        `below`, `j` in `above`. A missing `j` is treated as 0.
    sampled, above, below : dict, optional
        Result containers, updated in place. A new empty dictionary is used
        for every container that is needed but not supplied.

    Returns
    -------
    dict or tuple of dict
        ``(below, above)`` if ``sample_param['split']`` is set, otherwise
        ``sampled``.
    """
    track = tracks[ed]
    split = sample_param['split']
    use_range = sample_param['range']
    # Create the containers that are needed but were not supplied; indexing
    # the `None` defaults would raise a TypeError.
    if split:
        below = {} if below is None else below
        above = {} if above is None else above
        j = 0 if j is None else j
    elif sampled is None:
        sampled = {}
    # determine first and last time step of the eddy
    track_times = np.array(track['time'])
    try:
        length = len(track['time'])
        time0 = track_times[0]
        time1 = track_times[-1]
    except (TypeError, IndexError):
        # scalar (or empty) time: treat the track as one single time step
        length = 1
        time0 = track_times
        time1 = track_times
    # determine lon and lat of eddy (longitude mapped to at most 180)
    lon_for_sel = track['lon'][0]
    if lon_for_sel > 180.:
        lon_for_sel = lon_for_sel - 360.
    lat_for_sel = track['lat'][0]
    # if necessary extract values of `var_range` and `var_split` at the
    # eddy's location at its first time step
    if use_range:
        data_range, lon_for_sel = monotonic_lon(sample_param['ds_range'],
                                                lon_for_sel)
        vars_range_ed = (
            data_range[sample_param['var_range'][0]]
            .sel(time=time0, lat=lat_for_sel, lon=lon_for_sel,
                 method='nearest').values)
    if split:
        data_split, lon_for_sel = monotonic_lon(sample_param['ds_split'],
                                                lon_for_sel)
        vars_split_ed = (
            data_split[sample_param['var_split'][0]]
            .sel(time=time0, lat=lat_for_sel, lon=lon_for_sel,
                 method='nearest').values)
    # construct bool of all conditions to be met by the eddy to be
    # considered in `sampled`
    conditions_are_met = (
        (time0 >= start_time)
        & (time1 <= end_time)
        & (track['type'] == sample_param['type'])
        & (length > lifetime)
        & (np.mean(track['scale']) > sample_param['size'])
        & (track['lon'][0] > sample_param['lon1'])
        & (track['lon'][0] < sample_param['lon2'])
        & (track['lat'][0] > sample_param['lat1'])
        & (track['lat'][0] < sample_param['lat2']))
    if use_range:
        conditions_are_met = (
            conditions_are_met
            & (vars_range_ed > sample_param['value_range'][0][0])
            & (vars_range_ed < sample_param['value_range'][0][1]))
    # store a copy of the eddy and add the required fields to it
    if split:
        split_value = sample_param['value_split'][0]
        if conditions_are_met & (vars_split_ed < split_value):
            below[i] = track.copy()
            for variable in sample_param['sample_vars']:
                below = add_fields(below, data, i, variable)
            i = i + 1
        elif conditions_are_met & (vars_split_ed >= split_value):
            above[j] = track.copy()
            for variable in sample_param['sample_vars']:
                above = add_fields(above, data, j, variable)
            j = j + 1
        return below, above
    if conditions_are_met:
        sampled[i] = track.copy()
        for variable in sample_param['sample_vars']:
            sampled = add_fields(sampled, data, i, variable)
        i = i + 1
    return sampled

# AVERAGING &#9744;

In [None]:
def prepare(sampled, vars, z):
    """Collect the interpolated anomaly sections of all eddies per month.

    Parameters
    ----------
    sampled : dict or sequence
        Sampled eddies, addressed as ``sampled[0] ... sampled[n - 1]``. Each
        eddy needs a 'time' entry and the fields used by `interp`.
    vars : iterable of str
        Names of the variables to process.
    z : sequence
        Vertical axis; only its length is used.

    Returns
    -------
    dict
        ``aves[variable][month]`` is an array of shape
        ``(number of eddies starting in that month, longest track length,
        len(z), number of target longitudes)``. `month` is the two-character
        month taken from the string of the eddy's first time stamp. Time
        steps beyond the end of a track are NaN.
    """
    lon_interp = np.arange(-0.5, 0.51, 0.01)
    n_eddies = len(sampled)
    # longest track: defines the size of the time axis for all eddies
    max_time = max((len(sampled[ed]['time']) for ed in range(n_eddies)),
                   default=0)
    block_shape = (max_time, len(z), len(lon_interp))
    aves = {}
    for v in vars:
        per_month = {}
        for ed in range(n_eddies):
            month = str(sampled[ed]['time'][0])[5:7]
            # One NaN-filled block per eddy. It is appended to the list of its
            # month, so its position is the eddy's rank within that month and
            # not the overall eddy index `ed`.
            block = np.full(block_shape, np.nan)
            for m in range(len(sampled[ed]['time'])):
                block[m, :, :] = interp(sampled[ed], v, m)
            per_month.setdefault(month, []).append(block)
        # stack once per month instead of growing the array eddy by eddy
        aves[v] = {month: np.stack(blocks)
                   for month, blocks in per_month.items()}
    return aves

In [None]:
def interp(sampled, v, t):
    """Interpolate the anomaly section of `v` at time step `t`.

    The anomaly is the section ``sampled[v + '_sec'][t]`` minus the
    surrounding profile ``sampled[v + '_around'][t]`` (one value per row of
    the section). It is interpolated linearly along the second axis from the
    normalized longitudes ``sampled[v + '_sec_norm_lon'][t]`` onto the
    regular axis ``np.arange(-0.5, 0.51, 0.01)``; values outside the input
    range are extrapolated.
    """
    target_lon = np.arange(-0.5, 0.51, 0.01)
    background = sampled[v + '_around'][t][:, None]
    anomaly = sampled[v + '_sec'][t] - background
    interpolator = interp1d(sampled[v + '_sec_norm_lon'][t], anomaly,
                            axis=1, fill_value="extrapolate")
    return interpolator(target_lon)

In [None]:
# Build the per-month anomaly sections for 'votemper' and 'DIC' of the `above` eddies.
# NOTE(review): the vertical axis is taken from `data_int['z']`, while other
# cells select levels of `data_int` via `depth` -- confirm that 'z' exists.
testave = prepare(above, ['votemper', 'DIC'], data_int['z'].values)

In [None]:
# Step-by-step copy of `prepare` for debugging: target longitude axis and the
# largest number of time steps among the `above` eddies.
lon_interp = np.arange(-0.5, 0.51, 0.01)
max_time = 0
for ed in np.arange(0, len(above)):
    max_time = max(max_time, len(above[ed]['time']))
aves = {}

In [None]:
# Debugging: pick one variable and the first `above` eddy
v = 'votemper'
aves[v] = {}
ed = 0
# two-character month from the string of the eddy's first time stamp
month = str(above[ed]['time'][0])[5:7]
# NOTE(review): other cells select levels of `data_int` via `depth` -- confirm that 'z' exists
z = data_int['z'].values

In [None]:
# Debugging: append one NaN-filled block for this eddy to the array of its
# month, or start that array if the month has no entry yet. An explicit
# membership test replaces the bare `except`, which also hid unrelated errors
# (for example a shape mismatch in `np.vstack`).
if month in aves[v]:
    aves[v][month] = np.vstack((aves[v][month], np.zeros((1, max_time, len(z), len(lon_interp))) + np.nan))
else:
    aves[v][month] = np.zeros((1, max_time, len(z), len(lon_interp))) + np.nan

In [None]:
# Debugging: fill the first time step of the selected eddy with its
# interpolated anomaly section
m = 0
aves[v][month][ed, m, :, :] = interp(above[ed], v, m)

In [None]:
# Inspect the surrounding 'votemper' profile of the first `above` eddy at time step 0
above[0]['votemper_around'][0]

In [None]:
# 'votemper' anomaly section of the first eddy of month '01' at time step 3
# (plotted against array indices, no physical axes)
plt.pcolormesh(testave['votemper']['01'][0,3,:,:])

In [None]:
# Mean 'votemper' anomaly over all eddies at time step 0.
# `testave['votemper']` is a dictionary keyed by month, so the per-month
# arrays are stacked along the eddy axis before slicing; indexing the
# dictionary directly with [:, 0, :, :] fails.
votemper_all = np.concatenate(list(testave['votemper'].values()), axis=0)
plt.pcolormesh(np.linspace(-1, 1, 101), data['depth_c'].values, np.nanmean(votemper_all[:, 0, :, :], axis=0),
              cmap=cm.balance, vmin=-0.6, vmax=0.6)
plt.colorbar()
plt.ylim(-1000, 0)

In [None]:
# Mean 'votemper' anomaly over all eddies at time step 10.
# `testave['votemper']` is a dictionary keyed by month, so the per-month
# arrays are stacked along the eddy axis before slicing; indexing the
# dictionary directly with [:, 10, :, :] fails.
votemper_all = np.concatenate(list(testave['votemper'].values()), axis=0)
plt.pcolormesh(np.linspace(-1, 1, 101), data['depth_c'].values, np.nanmean(votemper_all[:, 10, :, :], axis=0),
              cmap=cm.balance, vmin=-0.6, vmax=0.6)
plt.colorbar()
plt.ylim(-1000, 0)