**Detecting marine heatwave events with xmhw**<br>
----

In [None]:
import xarray as xr
import numpy as np
import datetime  # this is needed to run original code
import dask

In [None]:
# suppress warnings: needed when running the original code, otherwise it slows down considerably
import warnings
warnings.filterwarnings('ignore')

**Import functions from xmhw**<br><br>
We separated the calculation of the climatologies from the identification of marine heat waves (mhw). This gives two separate functions, so you can save and reuse the threshold while experimenting with different settings for the detection step.

In [None]:
from xmhw.xmhw import threshold, detect

**Setting up dask**<br>

I am not a dask expert; this is something I tried in order to make the code run faster and to make sure that grid cells are processed in parallel. I use dask.delayed to create delayed functions, which speeds up the calculation. It also lets us compute the threshold separately, before the detection step.
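The threaded scheduler configured in the next cell runs tasks on a pool of threads. As a rough, standard-library analogy (this is `concurrent.futures`, not dask), the sketch below maps a hypothetical per-cell function, `count_hot_days`, over independent cell timeseries. Because each cell is processed on its own, cells can be handed to different threads; threads tend to pay off when the per-cell work is dominated by NumPy routines that release the GIL.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def count_hot_days(series, thresh):
    """Hypothetical per-cell computation: number of timesteps above a fixed threshold."""
    return int(np.sum(series > thresh))


# made-up data: 4 cells, each with its own 10-step timeseries
rng = np.random.default_rng(0)
cells = [rng.normal(20.0, 1.0, size=10) for _ in range(4)]

# cells are independent, so the per-cell work can be spread over a pool of threads;
# pool.map returns results in the same order as the input cells
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(lambda s: count_hot_days(s, 20.5), cells))

print(results)
```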

In [None]:
# configure dask to use the threaded scheduler:
# it executes computations with a local multiprocessing.pool.ThreadPool,
# so tasks run in multiple threads
dask.config.set(scheduler='threads')

In [None]:
# wrap the functions with dask.delayed: calling them now builds a task graph instead of running immediately
threshold = dask.delayed(threshold)
detect = dask.delayed(detect)

In [None]:
# import ProgressBar to help with diagnostics
from dask.diagnostics import ProgressBar

**Example calculation**<br>

Using the NOAA OISST timeseries, I select a smaller region to demo the xmhw code.<br>
Before calculating anything, the function land_check() (from xmhw.identify) is called. This function has two steps:<br>
  - stacks all dimensions except "time" into a new 'cell' dimension;
  - removes all the land points, which are assumed to have np.nan values along the whole time axis.

NB If the timeseries you want to use has a time axis that is not called 'time', you can specify its name with the `tdim` argument.<br>
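In plain NumPy, the two land_check() steps can be sketched as below. This is an illustration under simple assumptions (a (time, lat, lon) array, with land marked by NaN at every timestep), not xmhw's implementation; `stack_and_drop_land` is a hypothetical helper.

```python
import numpy as np


def stack_and_drop_land(arr):
    """Hypothetical helper: collapse (time, lat, lon) into (time, cell) and keep
    only the cells that are not NaN along the whole time axis."""
    ntime = arr.shape[0]
    stacked = arr.reshape(ntime, -1)             # (time, lat*lon)
    is_land = np.all(np.isnan(stacked), axis=0)  # True where every timestep is NaN
    keep = np.flatnonzero(~is_land)              # flat indices of the ocean cells
    return stacked[:, keep], keep


# 3 timesteps on a 2x2 grid; the (lat=0, lon=1) point is "land"
data = np.arange(12, dtype=float).reshape(3, 2, 2)
data[:, 0, 1] = np.nan

ocean, cell_index = stack_and_drop_land(data)
print(ocean.shape)   # (3, 3)
print(cell_index)    # [0 2 3]

# the inverse ("unstack") step: recover each cell's (lat, lon) position
lat_idx, lon_idx = np.unravel_index(cell_index, data.shape[1:])
```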

In [None]:
# using NOAA oisst as timeseries
ds = xr.open_mfdataset('/g/data/ua8/NOAA_OISST/AVHRR/v2-1_modified/timeseries/oisst_timeseries_*.nc',
                       concat_dim='time', combine='nested', chunks={'time':-1, 'lat': 10, 'lon': 10})
# remove the length-1 zlev dimension, then its leftover scalar coordinate
sst = ds['sst'].squeeze()
sst = sst.drop_vars('zlev')
# for the moment select a small region to test
# This corresponds to ... ocean grid cells
tos = sst.sel(lat=slice(-44,-41),lon=slice(144, 149))

In [None]:
# the data is small enough to fit in a single chunk
# NB for each cell the whole timeseries must be in the same chunk; for this reason chunk({'time-dimension': -1})
# is applied inside the module where necessary
ts = tos.chunk({'time':-1, 'lat':-1, 'lon':-1})
ts

**Calculate threshold separately and save it to file**<br>


The *threshold* function calculates the climatologies, i.e. the seasonal average and the threshold, which are then used to detect marine heat waves (mhw) along the timeseries.<br>This function mimics the behaviour of the original code, including returning a dictionary; we are looking at changing this so that it returns a dataset instead.<br> As in the original, several parameters can be set:
````
threshold(temp, tdim='time', climatologyPeriod=[None,None], pctile=90, windowHalfWidth=5,
          smoothPercentile=True, smoothPercentileWidth=31, maxPadLength=False, coldSpells=False, Ly=False)
````
Where *temp* is the temperature timeseries. This is the only required input if you're happy with the default settings and your time dimension is called 'time'.<br><br>
In the following example we're using all default settings for threshold.

In [None]:
# this won't do anything until we call compute(), because we are using delayed
clim = threshold(ts)

In [None]:
with ProgressBar():
    clim_dict = clim.compute()
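For intuition about what threshold() computes, here is a simplified NumPy sketch of a day-of-year climatology, loosely following the original approach: for each day of year, pool the values within +/- `windowHalfWidth` days across all years, take the `pctile` percentile as the threshold and the mean as the seasonal average, then optionally smooth both with a running mean of width `smoothPercentileWidth`. Simplifying assumptions: a single cell, every year has exactly 365 days (the original handles leap days), and the window wraps within the same year instead of reaching into the neighbouring year. The function names are hypothetical.

```python
import numpy as np


def circular_running_mean(x, width):
    """Running mean with an odd window `width`, wrapping around the ends of the year."""
    half = width // 2
    padded = np.concatenate([x[-half:], x, x[:half]])
    kernel = np.ones(width) / width
    return np.convolve(padded, kernel, mode="valid")  # same length as x


def doy_climatology(temp, nyears, pctile=90, window_half_width=5,
                    smooth=True, smooth_width=31):
    """Hypothetical sketch: temp is a 1D daily series of nyears * 365 values
    (no leap days). Returns (thresh, seas), each with one value per day of year."""
    temp = np.asarray(temp, dtype=float).reshape(nyears, 365)
    thresh = np.empty(365)
    seas = np.empty(365)
    offsets = np.arange(-window_half_width, window_half_width + 1)
    for d in range(365):
        # pool every year's values within the window around day d (wrapping in-year)
        cols = (d + offsets) % 365
        pooled = temp[:, cols].ravel()
        thresh[d] = np.nanpercentile(pooled, pctile)
        seas[d] = np.nanmean(pooled)
    if smooth:
        thresh = circular_running_mean(thresh, smooth_width)
        seas = circular_running_mean(seas, smooth_width)
    return thresh, seas


# made-up 10-year series: seasonal cycle plus noise
rng = np.random.default_rng(1)
years = 10
doy = np.tile(np.arange(365), years)
temp = 15 + 3 * np.sin(2 * np.pi * doy / 365) + rng.normal(0, 0.5, doy.size)
thresh, seas = doy_climatology(temp, years)
```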

Note that, unlike the original function, which takes a 1D numpy array, here we can pass a 3D array (in fact any n-dimensional array) because we are using xarray, and the code will deal with it.<br>
We selected a 12x20 lat-lon region, of which 135 grid cells are ocean. <br>

Before saving the results to netCDF, the data should be *unstacked*; threshold() already unstacks the dataset before returning it.<br> 
Unlike the original function, the climatologies here are not saved along the entire timeseries but only along the new *doy* dimension. Since xarray keeps the coordinates with the arrays, there is no need to repeat the climatologies along the time axis.
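Keeping the climatology along *doy* loses nothing, because it can be expanded onto the time axis on demand by indexing it with each timestep's day of year. A NumPy sketch, assuming a no-leap calendar and 0-based day-of-year indices (the array names are hypothetical):

```python
import numpy as np

# hypothetical threshold climatology: one value per day of year
thresh_doy = 20 + np.sin(2 * np.pi * np.arange(365) / 365)

# 0-based day of year of each timestep in a 3-year, no-leap daily series
doy_of_time = np.tile(np.arange(365), 3)

# fancy indexing repeats the climatology along the time axis only when needed
thresh_time = thresh_doy[doy_of_time]
print(thresh_doy.size, thresh_time.size)  # 365 1095
```

In xarray the analogous step would be vectorized indexing with a DataArray of day-of-year values along time; check how the *doy* coordinate is numbered (for example around leap days) before relying on that.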

In [None]:
# save threshold and seasonal average to netcdf
climds = xr.merge([clim_dict['thresh'], clim_dict['seas']])
climds.to_netcdf('climatology_tas.nc')

**Filter MHW passing calculated climatologies to detect**<br>
The *detect* function identifies all the mhw events and their characteristics. It corresponds to the second part of the original detect function and again mimics the logic of the original code.

````
    detect(temp, thresh, seas, minDuration=5, joinAcrossGaps=True, maxGap=2,
           maxPadLength=None, coldSpells=False, tdim='time')
````
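Roughly following the original code, the core of detection is: flag the timesteps where the temperature exceeds the threshold, keep only runs lasting at least `minDuration` timesteps and, if `joinAcrossGaps` is set, merge events separated by at most `maxGap` timesteps. A simplified single-cell sketch (`find_events` is hypothetical and ignores `maxPadLength` and `coldSpells`):

```python
import numpy as np


def find_events(temp, thresh, min_duration=5, join_across_gaps=True, max_gap=2):
    """Hypothetical single-cell sketch. Returns (start_idx, end_idx) pairs, end inclusive."""
    exceed = np.asarray(temp) > np.asarray(thresh)
    # 1. contiguous runs of exceedance
    runs, start = [], None
    for i, flag in enumerate(exceed):
        if flag and start is None:
            start = i
        elif not flag and start is not None:
            runs.append((start, i - 1))
            start = None
    if start is not None:
        runs.append((start, len(exceed) - 1))
    # 2. keep only runs lasting at least min_duration timesteps
    events = [(s, e) for s, e in runs if e - s + 1 >= min_duration]
    # 3. optionally merge events separated by a gap of at most max_gap timesteps
    if join_across_gaps and events:
        merged = [events[0]]
        for s, e in events[1:]:
            prev_s, prev_e = merged[-1]
            if s - prev_e - 1 <= max_gap:
                merged[-1] = (prev_s, e)
            else:
                merged.append((s, e))
        events = merged
    return events


temp = np.zeros(20)
temp[2:8] = 1.0    # 6-step exceedance
temp[10:15] = 1.0  # 5-step exceedance after a 2-step gap
temp[17:19] = 1.0  # 2-step exceedance, shorter than min_duration
thresh = np.full(20, 0.5)
print(find_events(temp, thresh))                          # [(2, 14)]
print(find_events(temp, thresh, join_across_gaps=False))  # [(2, 7), (10, 14)]
```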
This time you have to pass the timeseries, the threshold and the seasonal average; the other parameters are optional.<br> The results are stored differently from the original function:
````
   Original structure: 
       - mhw is a dictionary
       - each characteristic is a key with a list of values, each value represents an event
       - Ex.  mhw['intensity_max'][ev]
````
First of all, the new function returns an xarray dataset, not a dictionary. More importantly, there is one variable for each calculated field, and the events are stored together rather than as separate arrays.<br> Let's see an example using all the default settings for the MHW filter.

In [None]:
# as before, this won't do anything until compute() is called
mhw = detect(ts, clim_dict['thresh'], clim_dict['seas'])

In [None]:
with ProgressBar():
    ds = mhw.compute()

This time the function returns an xarray dataset in which the 'cell' dimension is still present, so we need to unstack it to get back the latitude and longitude grid.

In [None]:
mhwds = ds.unstack('cell')

The resulting dataset has two kinds of variables:
````
events (time, lat, lon)
relSeas (time, lat, lon)
relThresh (time, lat, lon)
end_idx (event, lat, lon)
start_idx (event, lat, lon)
intensity_cumulative (event, lat, lon)
````
Some are defined along time and are NaN everywhere except where an event is defined. *events* is one of them; for a single cell it looks like:<br>
  nan, nan, nan, 3, 3, 3, 3, 3, nan ... <br>
where 3 is the index of the first timestep of the event.
The *events* variable can be used as a coordinate for the other variables defined along the time axis.
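To illustrate how *events* ties the time-defined variables to individual events, the sketch below builds an *events*-style array for one cell and sums the anomaly over each event. It assumes that *relSeas* is the temperature minus the seasonal average and that the cumulative intensity is the sum of that anomaly over the event, as in the original code; the data are made up.

```python
import numpy as np

# made-up single-cell example: 12 timesteps, two events (end index inclusive)
seas = np.full(12, 15.0)
temp = np.array([15, 15, 15, 17, 18, 18, 17, 15, 15, 16, 17, 15], dtype=float)
start_idx = np.array([3, 9])
end_idx = np.array([6, 10])

# "events" along time: NaN outside events, start index of the event inside it
events = np.full(temp.size, np.nan)
for s, e in zip(start_idx, end_idx):
    events[s:e + 1] = s

rel_seas = temp - seas  # anomaly relative to the seasonal average

# cumulative intensity of each event: sum of rel_seas over the timesteps labelled with it
# (NaN == s is False, so timesteps outside events are never selected)
intensity_cumulative = np.array([rel_seas[events == s].sum() for s in start_idx])
```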
The other group holds the mhw characteristics, which are defined along the *event* dimension.
The size of the *event* dimension is the number of distinct events identified, where separate events are those with different starting times. This means that if two different cells both have an event starting at timestep=50, these events share the same index along the *event* dimension, regardless of their duration.<br>
This is clearly an approximation, because an event that starts even one timestep later is classified as separate.
This is because, as in the original code, events are identified cell by cell. 
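The layout of the *event* dimension described above can be sketched as follows: collect the start indices found independently in each cell, use their sorted union as the *event* coordinate, and leave NaN where a cell has no event with that start. The cell names and values are hypothetical, and this is not xmhw's code.

```python
import numpy as np

# hypothetical start indices and one characteristic (e.g. intensity_max) per cell
starts_per_cell = {"cell0": [50, 120], "cell1": [50, 51]}
values_per_cell = {"cell0": [4.2, 7.1], "cell1": [3.0, 5.5]}

# the event coordinate is the sorted union of every distinct start index
event = np.unique(np.concatenate([np.asarray(v) for v in starts_per_cell.values()]))
print(event)  # [ 50  51 120]

# fill an (event, cell) array, leaving NaN where a cell has no event with that start
table = np.full((event.size, len(starts_per_cell)), np.nan)
for j, cell in enumerate(starts_per_cell):
    rows = np.searchsorted(event, starts_per_cell[cell])
    table[rows, j] = values_per_cell[cell]
```

Here the two events starting at 50 share a row, while cell1's event starting one step later, at 51, gets a row of its own, which is the approximation mentioned above.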

In [None]:
mhwds

In [None]:
mhwds.intensity_cumulative

In [None]:
# save mhw to yearly netcdf files (this splits the output if you have a really long timeseries)
#years, datasets = zip(*mhwds.groupby("time.year"))
#paths = ["mhw_%s.nc" % y for y in years]
#xr.save_mfdataset(datasets, paths)
# if you are only working with a subset, a single file is enough:

mhwds.to_netcdf('mhw.nc')

**Find MHW using original code**<br><br>

In [None]:
#%%time
from datetime import date
from marineHeatWaves import detect as orig_detect

# create the time array the original code needs (ordinal day numbers)
t = np.arange(date(1981,9,1).toordinal(), date(2020,5,18).toordinal()+1)
# timeseries of a single grid cell as a 1D numpy array
sst = tos[:,0,0].squeeze().values
# call function with default settings
orig_mhw, orig_clim = orig_detect(t, sst)