## Bessel filter from py-eddy-tracker program
- Uses a fixed 700 km cutoff wavelength for the spatial high-pass filter
- xarray (and consequently dask) support within py-eddy-tracker was made possible by Aaron Wienkers. Much of this example is taken from https://github.com/eerie-project/EERIE_hackathon_2023/tree/main/RESULTS/pyeddytracker_xarray_dask_parallel
- The intake catalog of EERIE data was prepared by Fabian Wachsmann

Feb 2024, Aaron Wienkers (ETHZ) and Dian Putrasahan (MPIM)

In [2]:
import xarray as xr
import numpy as np
from scipy.interpolate import CloughTocher2DInterpolator, LinearNDInterpolator, NearestNDInterpolator
import glob
import intake
import intake_xarray
import dask
import pandas as pd
dask.config.set({"array.slicing.split_large_chunks": True}) 

from py_eddy_tracker.dataset.grid import RegularGridDataset
from datetime import datetime, timedelta
from netCDF4 import Dataset

import io
import os,sys

import warnings
warnings.filterwarnings("ignore")

In [3]:
## Parallel execution setup
# Only the standard-library executors are imported here; no pool is started in this cell.
# Note: Could also use Dask Distributed Client
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

n_cpu = 64  # worker count, later passed to the executor as max_workers

In [4]:
# Experiment selectors used to navigate the EERIE intake catalogue
# (and, in later cells, to build the output directory names).
model = 'ifs-fesom2-sr'
expid = 'eerie-control-1950'
gridspec = 'gr025'

# Open the top-level catalogue, then descend to the node for the chosen
# model / experiment / ocean realm / grid.
cat = intake.open_catalog("https://raw.githubusercontent.com/eerie-project/intake_catalogues/main/eerie.yaml")
cat_regrid = cat['dkrz.disk.model-output'][model][expid]['ocean'][gridspec]

# Show which entries are available under that node.
print(list(cat_regrid))

['daily', 'monthly']


In [5]:
# Open the 'daily' entry of the selected catalogue node; `ds` is the dataset
# that the following cells subset and filter.
ds = cat_regrid['daily'].to_dask()

In [6]:
# Keep only the time steps selected by the 1950-01-01 .. 1950-12-31 slice.
# To process the whole record instead, use: ds_subset = ds
year_window = slice('1950-01-01', '1950-12-31')
ds_subset = ds.sel(time=year_window)

# One Python datetime per selected time step; used for the log message and
# for the date stamp in the output file names.
datearr = np.array(
    [pd.Timestamp(stamp).to_pydatetime() for stamp in ds_subset.time.values]
)


In [7]:
# --- Run configuration ---
varname = 'sst'     # variable to high-pass filter
wavelength = 700    # filter wavelength (km, per the notebook header and the 'sm...km' directory name)

print('High pass filter daily '+varname+' for year='+str(datearr[0].year)+'-'+str(datearr[-1].year))

# --- Output directory layout: <scratch>/<expid>/<gridspec>/<model>/<varname>/Bessel/sm<wavelength>km/ ---
scratch = '/scratch/m/m300466/'
datadir = scratch+expid+'/'+gridspec+'/'

vardir = datadir+model+'/'+varname
filtdir = vardir+'/Bessel'
smdatadir = filtdir+'/sm'+str(int(wavelength))+'km/'

# smdatadir is the deepest directory of the tree, so one makedirs call creates
# every missing parent (datadir/model, vardir, filtdir) as well. exist_ok=True
# replaces the racy "check exists, then create" pattern and makes the cell
# safe to re-run.
os.makedirs(smdatadir, exist_ok=True)

High pass filter daily sst for year=1950-1950


In [8]:
def besselhighpass(ncfile, varname, smdatadir, wavelength, date):
    """High-pass filter one field with py-eddy-tracker's Bessel filter and write it to disk.

    Parameters
    ----------
    ncfile
        Open netCDF dataset handle, forwarded to ``RegularGridDataset`` as
        ``nc4file``. The grid coordinates are read from ``lon`` / ``lat``.
    varname : str
        Variable label. It is used in the output file name and, except for
        ``'rho'`` (which is filtered under the name ``'rhopoto'``), as the
        name of the variable to filter.
    smdatadir : str
        Directory the filtered file is written to.
    wavelength : int or float
        Wavelength passed to ``bessel_high_filter``. Its integer part is used
        in the output file name.
    date
        Object providing ``strftime``; formatted as ``%Y%m%d`` for the file name.

    Output file name
    ----------------
    ``<varname>_1_<YYYYMMDD>_hp<wavelength>.nc`` for ``'to'``, ``'so'`` and
    ``'rho'``; ``<varname>_<YYYYMMDD>_hp<wavelength>.nc`` otherwise.
    """
    g = RegularGridDataset(None, "lon", "lat", centered=True, nc4file=ncfile)

    # 'rho' is the label used in file names; the field that gets filtered is 'rhopoto'.
    field = 'rhopoto' if varname == 'rho' else varname
    # Per the original author's note, the filter operates on a single time index.
    g.bessel_high_filter(field, wavelength, order=1)

    stamp = date.strftime('%Y%m%d')
    # Always format the wavelength as an integer so that every file name (and
    # the 'sm<wavelength>km' directory name built by the caller) agree even
    # when a float such as 700.0 is passed.
    cutoff = str(int(wavelength))

    if varname in ('to', 'so', 'rho'):
        # Level label embedded in the file name; hard-coded to 1 here.
        zidx = 1
        fname = varname+'_'+str(zidx)+'_'+stamp+'_hp'+cutoff+'.nc'
    else:
        fname = varname+'_'+stamp+'_hp'+cutoff+'.nc'

    # os.path.join avoids the doubled '/' when smdatadir already ends with one.
    g.write(os.path.join(smdatadir, fname))
        
        
# Parallel function wrapper to the for-loop
def delayed_filter_and_save(date,tt):
    """Filter and save the field at one time index; intended as the per-task worker.

    Reads the notebook-level globals ``ds_subset``, ``varname``, ``smdatadir``
    and ``wavelength``.

    Parameters
    ----------
    date
        Date of the time step, forwarded to ``besselhighpass`` for the file name.
    tt : int
        Positional index along ``time`` of the step to process.
    """
    # Select the single time step and serialise it into an in-memory netCDF4 dataset.
    da_field = ds_subset[varname].isel(time=tt)
    # Drop any _FillValue from the time encoding before serialising
    # (presumably to keep to_netcdf() from failing or warning on the time coordinate).
    da_field.time.encoding.pop("_FillValue",None)
    da_netcdf = Dataset('in-mem-file', mode='r', memory=da_field.to_netcdf())

    try:
        #print('High pass filter of '+varname+' for '+date.strftime('%Y%m%d'))
        besselhighpass(da_netcdf,varname,smdatadir,wavelength,date)
    finally:
        # Release the in-memory dataset even if filtering or writing fails;
        # otherwise one handle (and its buffer) leaks per time step.
        # isopen() guards against the handle having been closed downstream.
        if da_netcdf.isopen():
            da_netcdf.close()
    
    

In [9]:
# Filter every selected time step in parallel; each task writes its own output file.
#   For some reason (presumably memory-related) this works better with
#   ThreadPool than ProcessPool...
n_steps = len(datearr)
with ThreadPoolExecutor(max_workers=n_cpu) as executor:
    # Each task receives (date, time index) for one step.
    mapped = executor.map(delayed_filter_and_save, datearr, range(n_steps))
    # Draining the iterator waits for all tasks and surfaces any worker exception.
    results = list(mapped)

    