# High pass filter other fields (already on 0.25deg)

In [None]:
from py_eddy_tracker.dataset.grid import RegularGridDataset
from datetime import datetime, timedelta
import numpy as np
from netCDF4 import Dataset
from matplotlib import pyplot as plt
import xarray as xr
# Select the experiment and variable whose fields (already mapped onto a
# 0.25deg regular grid) will be high-pass filtered.

expid='erc1011'  # experiment id; used in the input path and in output file names
varname='to'     # name of the variable to filter
fq='dm'          # tag inserted into output file names (presumably the output frequency - confirm)

# directory holding the regridded input files of the selected variable;
# built from expid so the path cannot drift away from the experiment id
datadir = '/work/bm1344/k203123/reg25/'+expid+'/'+varname+'/'

In [None]:
import glob
# collect the input files in sorted order
# (assumes the file names sort chronologically, one file per year - confirm)
datafiles = sorted(glob.glob(datadir+"*.nc"))
print('# data files for identifying eddies: ', len(datafiles))
print('datafiles for identifying eddies: ', datafiles)
# one array of daily datetime objects per file: the n-th file is treated as
# year 2002+n, covering 1 January up to (excluding) 1 January of the next year
datearrs = [
    np.arange(datetime(year, 1, 1), datetime(year + 1, 1, 1), timedelta(days=1)).astype(datetime)
    for year in range(2002, 2002 + len(datafiles))
]
print('datearrs: ', datearrs)

### Start SLURM cluster

In [None]:
import dask
from dask_jobqueue import SLURMCluster
from dask.distributed import Client

In [None]:
# Route the dashboard link through the JupyterHub proxy. Uses the public
# dask.config.set API instead of mutating the internal config dict, so it
# does not fail if the 'distributed'/'dashboard' sections are not yet present.
dask.config.set({'distributed.dashboard.link': '{JUPYTERHUB_SERVICE_PREFIX}/proxy/{port}/status'})

In [None]:
# Describe a single SLURM job; cluster.scale() later decides how many are submitted.
cluster = SLURMCluster(
    name='dask-cluster',
    # SLURM submission settings
    queue='compute',
    account='mh0033',
    walltime='01:00:00',
    # resources per job: 10 cores and 256GB shared by 5 worker processes
    cores=10,
    processes=5,
    memory='256GB',
    # network interface used by the workers
    interface='ib0',
    # synchronous (blocking) cluster object
    asynchronous=0,
)

In [None]:
# request enough worker jobs to reach 200 cores in total
cluster.scale(cores=200)
# connect a client to the cluster; later dask.compute calls run through it
client = Client(cluster)
# bare expression: shows the client summary as the notebook cell output
client

### Define high pass filter 

In [None]:
def besselhighpass(varfile, varname, datearr, tt, wavelength):
    """High-pass filter one time step of a gridded field and write it to disk.

    Loads time index ``tt`` of ``varfile`` as a ``RegularGridDataset``,
    applies ``bessel_high_filter`` with the given cutoff and writes the
    result to a NetCDF file whose name contains the date of that time step.
    Handling one time step per call keeps memory use low.

    Reads the module-level names ``expid`` and ``fq`` to build the file name.

    Parameters
    ----------
    varfile : str
        Path of the input NetCDF file (coordinates named "lon" and "lat").
    varname : str
        Variable to filter. For ``'rho'`` the field ``'rhopoto'`` is filtered.
    datearr : sequence of datetime
        Dates of the time steps in ``varfile``; ``datearr[tt]`` names the output.
    tt : int
        Index of the time step to process.
    wavelength : int or float
        Spatial cutoff of the high-pass filter in km.

    Returns
    -------
    None
        The filtered field is written to disk as a side effect.
    """
    # NOTE(review): the output directory is fixed to the 'erc1011' / 'to_hbp'
    # layout regardless of expid and varname - confirm this is intended.
    outdir = '/path/to/output/data/erc1011_eddytrack/to_hbp/'+'wv_'+str(int(wavelength))+'/'
    wv_tag = str(int(wavelength))  # same integer formatting as in the directory name

    # look the date up first so an out-of-range tt fails before the costly filter
    date = datearr[tt]

    g = RegularGridDataset(varfile, "lon", "lat", centered=True, indexs = dict(time=tt))
    # only the time step selected via indexs=dict(time=tt) was loaded, so mark
    # the time dimension as length 1 (presumably needed by write() - confirm)
    g.dimensions['time']=1

    # density is stored under a different field name than its short name
    field = 'rhopoto' if varname=='rho' else varname
    g.bessel_high_filter(field, wavelength, order=1)  # acts on the single loaded time index

    if varname in ('to', 'so', 'rho'):
        # these variables carry an extra index in the file name
        # (presumably the depth level - confirm)
        zidx=1
        stem = expid+'_'+varname+'_'+str(zidx)+'_'+fq
    else:
        stem = expid+'_'+varname+'_'+fq
    g.write(outdir+stem+'_'+date.strftime('%Y%m%d')+'_IFS25_hp'+wv_tag+'.nc')


### Apply high pass filter to data

In [None]:
# Apply the high-pass filter for every cutoff wavelength, year and day.
# The days of a year are submitted in batches so that only a limited number
# of tasks is in flight at once.
ntsteps_per_loop = 61  # number of time steps per dask.compute call

#looping over wavelengths for high band pass filter
for wavelength in [200,700]:
    print('wavelength = ', wavelength)
    # looping over years (2002,2003...): one input file and one date array each
    for varfile, datearr in zip(datafiles, datearrs):
        print('year = ', datearr[0].year)
        ntsteps = len(datearr)
        # batch boundaries follow from the data length, so every time step is
        # covered whatever the year length or batch size
        for start in range(0, ntsteps, ntsteps_per_loop):
            tt_vals = np.arange(start, min(start + ntsteps_per_loop, ntsteps))
            print('tt vals = ', tt_vals)
            # define the computations of this batch without running them yet
            lazy_results = [
                dask.delayed(besselhighpass)(varfile=varfile,
                                             varname=varname,
                                             datearr=datearr,
                                             tt=int(tt),
                                             wavelength=wavelength)
                for tt in tt_vals
            ]
            # run the batch in parallel; blocks until all files are written
            results = dask.compute(*lazy_results)

### Shutdown cluster

In [None]:
# Stop the scheduler and workers while the client is still connected,
# then release the client itself (shutdown needs a live connection).
client.shutdown()
client.close()