# Detecting marine heatwave events with xmhw

## Introduction

*xmhw* is an xarray-based version of the MarineHeatWave code.
The main differences from the original code are the following:
* it uses xarray and dask
* the calculation of climatologies and the detection of mhw events are in separate functions, so they can be called independently
* it can handle multi-dimensional grids and detect land points
* the treatment of NaNs can be customised
* the severity of events is added to the *detect* function output
* it produces xarray datasets instead of lists of dictionaries

### Import functions from xmhw
Import *threshold* to calculate the climatologies and *detect* to detect the mhw events.

In [1]:
import xarray as xr
#import numpy as np
import dask
from xmhw.xmhw import threshold, detect

### Calculating the climatologies

For this demo I am using a small subset of the NOAA OISST timeseries. You can use whatever seawater temperature dataset you have available, just select a small region initially to test the code.

In [3]:
# open file, read sst and calculate climatologies
ds = xr.open_dataset('sst_test.nc')
sst = ds['sst']
clim = threshold(sst)
clim

Output (dask array summary; the same summary is shown for each variable in the dataset):

| | Array | Chunk |
|---|---|---|
| Bytes | 686.25 kiB | 57.03 kiB |
| Shape | (366, 12, 20) | (365, 1, 20) |
| Count | 11531 Tasks | 24 Chunks |
| Type | float64 | numpy.ndarray |


As you can see above, **clim** is an xarray dataset with two variables:
 * thresh - the percentile threshold
 * seas - the climatology mean.<br>
 
The dimension **doy** stands for day of the year; it is based on a 366-day calendar. <br>
Finally, the dataset includes a few global attributes detailing the climatology period, the percentile used and other parameters used in the calculation.<br>
The dataset can be saved to a file simply by running:<br>
> clim.to_netcdf('filename')

### Detecting MHW events

Now that we have the climatologies we can run *detect*.

In [5]:
mhw = detect(sst, clim['thresh'], clim['seas'])
mhw

We can see above all the output variables listed and again global attributes detailing the dataset settings.<br>
The dimension **events** represents the starting point of each event. Let's select one grid point to see more in detail its structure.

In [6]:
mhw_point = mhw.isel(lat=2, lon=15)
mhw_point.events

Printing the whole events array shows that the first detected event occurs at the 91st timestep of the original timeseries and the last event starts at timestep 14381.<br>
Not all of these events occur at the selected grid point. We can see that by looking at the index_start or time_start variables.<br>
By dropping all the NaN values along the events dimension, we can see there are 60 mhw events occurring at this grid point.

In [7]:
mhw_point.time_start.dropna(dim='events')

As for the climatologies dataset, we can save the mhw dataset to a netcdf file easily.

In [8]:
mhw.to_netcdf('mhw_test.nc')

This file has a small grid, so we could save it as it is and still produce a small file. However, it is worth adding some "encoding" to save storage; this becomes necessary when dealing with bigger grids.<br>
Xarray has automatically used a float64 format for ~20 of the variables. Converting all the variables to float32 will save a lot of storage.<br>
This dataset also has a lot of NaN values, as its structure is "sparse", so it is a good idea to save the results in a compressed format.<br>
Encoding allows us to add internal compression and also to convert the array format.<br>

In [9]:
# First we create a dictionary representing the settings we want to use
# then we apply that to all the dataset variables and we use the resulting dictionary when calling to_netcdf() 
#
comp = dict(zlib=True, complevel=5, shuffle=True, dtype='float32')
encoding = {var: comp for var in mhw.data_vars}
mhw.to_netcdf('mhw_test_encoded.nc', encoding=encoding)

Checking the sizes of both files

In [10]:
!du -sh mhw_test.nc
!du -sh mhw_test_encoded.nc

109M	mhw_test.nc
2.2M	mhw_test_encoded.nc


## Threshold in detail
Earlier we used the threshold function with its default arguments, so we only needed to pass the temperature timeseries.<br>
As in the original MarineHeatWave code, several parameters can be set:
````
 threshold(temp, tdim='time', climatologyPeriod=[None,None], pctile=90, windowHalfWidth=5,  
           smoothPercentile=True, smoothPercentileWidth=31, maxPadLength=None, 
           coldSpells=False, Ly=False, anynans=False, skipna=False):
````

Where *temp* is the temperature timeseries, the only required input. We tried to be as consistent as possible with the original MarineHeatWave code:<br>
 * **climatologyPeriod** - can be used to set a different base period for the climatologies. It accepts a list of two integers indicating the alternative start and end year, e.g. [1983,2003]. If not specified, the entire timeseries is used; this is the default behaviour.<br>
 * **pctile** - the percentile used to calculate the threshold. Default is 90.<br>
 * **windowHalfWidth** - width of the window (one-sided) around each day of the year used for pooling values and calculating the threshold percentile. Default is 5.<br>
 * **smoothPercentile** - if True (default), smooth the percentile using a moving average.<br>
 * **smoothPercentileWidth** - the width of the window used to smooth the percentile. Default is 31.<br>
 * **maxPadLength** - specifies the maximum length [days] over which to interpolate NaNs in the input timeseries. Any consecutive block of NaNs longer than maxPadLength will be left as NaN. Default is None.<br>
 * **coldSpells** - specifies if the code should detect cold events instead of heat events. Default is False.<br>
 * **Ly** - boolean, specifies if the length of the year is < 365/366 days (e.g. a 360-day year from a climate model). This affects the calculation of the climatology. Not yet fully implemented.<br>
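
To make the roles of **pctile**, **windowHalfWidth** and **smoothPercentileWidth** more concrete, here is a minimal numpy sketch of the idea for a single 1D timeseries. It is not the xmhw implementation: the helper names `pooled_stats` and `smooth_periodic` are hypothetical, and the wrap-around treatment at the start and end of the year is an assumption.

```python
import numpy as np

def pooled_stats(temp, doy, target_doy, window_half_width=5, pctile=90, ndays=366):
    """Return (threshold, mean) for one day of year from a 1D timeseries."""
    # Circular distance in days, so the window wraps around the new year
    diff = np.abs(doy - target_doy)
    dist = np.minimum(diff, ndays - diff)
    # Pool all values, from all years, that fall inside the window
    pool = temp[dist <= window_half_width]
    return np.percentile(pool, pctile), pool.mean()

def smooth_periodic(clim, width=31):
    """Running mean of odd width over a periodic (day-of-year) series."""
    half = width // 2
    if half == 0:
        return clim.copy()
    # Pad with values from the other end of the year before averaging
    padded = np.concatenate([clim[-half:], clim, clim[:half]])
    return np.convolve(padded, np.ones(width) / width, mode="valid")

# Synthetic data: three identical 366-day "years" where temperature equals doy
doy = np.tile(np.arange(1, 367), 3)
temp = doy.astype(float)
thresh_100, seas_100 = pooled_stats(temp, doy, target_doy=100)
smoothed = smooth_periodic(np.full(366, 20.0))
```

With the default window, the pool for day 100 contains days 95 to 105 from every year, so the mean is 100 and the 90th percentile sits in the upper part of that range.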
 

#### New arguments
 * **tdim** - optional, specifies the name of the time dimension; default is "time". NB you do not need to pass the time array as in the original code: the timeseries is an xarray data array, so the time dimension is included. <br>
 * **anynans** - boolean, defines which grid points land_check() will classify as land. By default only the points where all values are NaN; if anynans is True, all cells with even a single NaN along the time dimension will be dropped. <br>
 * **skipna** - boolean, determines whether the percentile and mean calculations skip NaNs. Default is False, which is much faster than having to skip NaNs.<br>

More on missing values later.

#### Example
This is just showing how we can call the function changing some of the default parameters. <br>
In this case we are assuming sst time dimension is called 'time_0' and we want a base period from 1 Jan 1984 to 31 Dec 1994.

> clim = threshold(sst, climatologyPeriod=[1984,1994], tdim='time_0')

NB after passing the timeseries as the first argument, the order of the other
   ones is irrelevant as they are all keyword arguments.

Note that, unlike the original function, which takes a 1D numpy array, here we can pass a 3D array because we are using xarray (in fact we could pass any n-dimensional array) and the code will deal with it.<br>
We selected a 12x20 lat-lon region; 135 of these grid cells are ocean. <br>

The function returns a dataset with the arrays: <br>
   - **thresh** - for the threshold timeseries
   - **seas** - for the seasonal mean <br>

Unlike the original function, here the climatologies are not saved along the entire timeseries but only along the new **doy** dimension. Given that xarray keeps the coordinates with the arrays, there is no need to repeat the climatologies along the time axis.<br>
We also try to follow the CF conventions and define appropriate variable attributes, plus some global attributes that record the parameters used to calculate the threshold, for provenance.
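
If you do need the climatology along the time axis, it can be rebuilt from the **doy** coordinate. The sketch below assumes a 366-day convention in which 1 March is always day 61, so non-leap years skip day 60; whether this matches the exact xmhw convention is an assumption, and `doy_366` is a hypothetical helper. The threshold array is synthetic.

```python
import numpy as np
import pandas as pd
import xarray as xr

def doy_366(time):
    """Day of year on a 366-day calendar: 1 March is always day 61."""
    doy = time.dt.dayofyear
    # In non-leap years, shift every day from 1 March onwards by one
    shift = (~time.dt.is_leap_year) & (time.dt.month >= 3)
    return doy + shift.astype(int)

# Synthetic threshold defined on doy only (values equal to doy for clarity)
thresh = xr.DataArray(np.arange(1.0, 367.0), dims="doy",
                      coords={"doy": np.arange(1, 367)})

# A few days around the end of February in a non-leap year
dates = pd.date_range("2001-02-27", "2001-03-02")
time = xr.DataArray(dates, dims="time", coords={"time": dates})

# Vectorised selection: one threshold value per timestep
thresh_on_time = thresh.sel(doy=doy_366(time))
```

The result has a `time` dimension and picks days 58, 59, 61 and 62 of the climatology.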

### Handling of dimensions and land points

As seen before, we are passing the full grid to the function without worrying about land points or how many dimensions it has. Before calculating anything, the code calls the function land_check() (from xmhw.identify). This function handles the dimensions and land points of the grid in two steps:<br>
  - stacks all dimensions but the time dimension in a new 'cell' dimension;
  - removes all the land points, which are assumed to have all NaN values along the time axis.
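
The two steps can be sketched with plain xarray calls. This is an illustration of the approach rather than the actual land_check() code: `stack_and_drop_land` is a hypothetical name and the data is synthetic.

```python
import numpy as np
import xarray as xr

def stack_and_drop_land(temp, tdim="time", anynans=False):
    """Stack all non-time dimensions into 'cell' and drop land cells."""
    dims = [d for d in temp.dims if d != tdim]
    stacked = temp.stack(cell=dims)
    # how="all": drop cells that are NaN at every timestep (default)
    # how="any": drop cells with at least one NaN (anynans=True)
    return stacked.dropna(dim="cell", how="any" if anynans else "all")

data = np.ones((4, 2, 3))
data[:, 0, 0] = np.nan   # "land": NaN at every timestep
data[1, 1, 2] = np.nan   # ocean cell with a single missing value
temp = xr.DataArray(data, dims=("time", "lat", "lon"),
                    coords={"lat": [10.0, 11.0], "lon": [150.0, 151.0, 152.0]})

ocean = stack_and_drop_land(temp)                  # keeps 5 of the 6 cells
strict = stack_and_drop_land(temp, anynans=True)   # keeps 4 of the 6 cells

# After a per-cell calculation, unstack to get the lat-lon grid back
grid = ocean.mean("time").unstack("cell")
```

Cells removed as land come back as NaN after unstacking, as long as their latitude and longitude still appear in at least one remaining cell.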

In our example 'cell' will be composed of stacked (lat,lon) points. The resulting array will have (time, cell) dimensions, and the cell points which are land will not be part of it. The climatologies are then calculated for each cell point. Finally, the results are unstacked before returning the final output.<br>
NB This approach can occasionally produce a grid of a different size from the original, if all the cells at a specific latitude or longitude are masked as land. In that case the final grid will be smaller; you can however easily reindex your results to the original grid.
> clim = clim.reindex_like(sst)

### Handling of NaNs

It is important to understand how the **threshold()** function deals with NaNs.<br>
If there are NaN values in the timeseries passed to the function, this could produce wrong results.
You can take care of NaNs in the timeseries before passing it to threshold, or you can take one of the following approaches:
1) We already saw that land_check() will remove all the points that have all NaN values along the time dimension.<br> You can choose to be stricter and also exclude any cell points that have even just one NaN value.
To do so you can set the **anynans** argument to True.<br> This is a rather extreme approach, as it is not unusual to have a few NaNs, especially with observational data.<br>
> clim = threshold(sst, anynans=True)

2) Set **skipna** to True - this tells the code to skip NaNs when calculating averages and/or the percentile.<br>
By default the **skipna** argument is set to False, as using this option can double the execution time. But if you are working on a small grid then it is a safer option.<br>
> clim = threshold(sst, skipna=True)

3) Use **maxPadLength** - this will trigger interpolation for all NaN points, with the exception of consecutive blocks with length greater than maxPadLength. 
> clim = threshold(sst, maxPadLength=5, anynans=True)

Used in conjunction with **anynans**, as shown above, it lets you eliminate only the cell points that have bigger gaps.
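
The effect of **maxPadLength** on a single timeseries can be illustrated with a small numpy sketch. This is an assumption about the behaviour described above, not the xmhw code; `pad_short_gaps` is a hypothetical helper and it expects at least one valid value in the series.

```python
import numpy as np

def pad_short_gaps(values, max_pad_length):
    """Linearly interpolate NaN gaps, leaving gaps longer than max_pad_length as NaN."""
    values = np.asarray(values, dtype=float)
    isnan = np.isnan(values)
    idx = np.arange(values.size)
    filled = values.copy()
    # Interpolate every NaN first (np.interp repeats the edge values at the ends)
    filled[isnan] = np.interp(idx[isnan], idx[~isnan], values[~isnan])
    # Locate each run of consecutive NaNs as a half-open interval [start, end)
    edges = np.diff(np.concatenate([[0], isnan.astype(int), [0]]))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    # Put the NaNs back where the gap is too long to be trusted
    for s, e in zip(starts, ends):
        if e - s > max_pad_length:
            filled[s:e] = np.nan
    return filled

series = [1.0, np.nan, 3.0, np.nan, np.nan, np.nan, 7.0]
padded = pad_short_gaps(series, max_pad_length=2)
```

Here the single missing value is filled, while the three-day gap stays NaN. A cell like this one would then still be dropped if **anynans** is True.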

## Detect function in detail
The *detect* function identifies all the mhw events and their characteristics. It corresponds to the second part of the original detect function and again mimics the logic of the original code.

````
    def detect(temp, th, se, minDuration=5, joinAcrossGaps=True, maxGap=2, maxPadLength=None,
           coldSpells=False, tdim='time', intermediate=False, anynans=False):

````
This time you have to pass the timeseries, the threshold and the seasonal average. The other arguments are all optional.<br>
Again we kept most of the original arguments and added an option to pass the name of the time dimension (**tdim**) and the **anynans** argument to define which grid cells will be removed from the calculation.<br> It is important that this is consistent with the approach used when calculating the threshold.<br><br>
The last new argument is **intermediate**. When set to True, intermediate results are also saved; these include the original timeseries, climatologies, detected events, categories and some of the mhw variables, but along the time axis.<br>

Arguments specific to **detect()**:
 * **minDuration** - integer, minimum duration for acceptance of detected MHWs. Default is 5 [days].
 * **joinAcrossGaps** - boolean, switch indicating whether to join MHWs separated by a short gap. Default is True.
 * **maxGap** - maximum length of gap allowed for the joining of MHWs. Default is 2 [days].
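
The interplay of these three arguments can be shown on a single boolean series marking the days above the threshold. The sketch below is a simplified stand-in, not the xmhw code: `find_events` is a hypothetical helper, and it assumes that short events are discarded before the remaining ones are joined across gaps.

```python
import numpy as np

def find_events(exceed, min_duration=5, join_across_gaps=True, max_gap=2):
    """Return (start, end) index pairs, end inclusive, for a 1D boolean series."""
    exceed = np.asarray(exceed, dtype=bool)
    # Locate runs of consecutive True values
    edges = np.diff(np.concatenate([[0], exceed.astype(int), [0]]))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1) - 1
    # Keep only the runs that last at least min_duration days
    events = [(int(s), int(e)) for s, e in zip(starts, ends)
              if e - s + 1 >= min_duration]
    if not join_across_gaps:
        return events
    # Merge consecutive events separated by max_gap days or fewer
    joined = []
    for s, e in events:
        if joined and s - joined[-1][1] - 1 <= max_gap:
            joined[-1] = (joined[-1][0], e)
        else:
            joined.append((s, e))
    return joined

# 5 warm days, a 2-day gap, 6 warm days, a 4-day gap, then 3 warm days
exceed = [0] * 3 + [1] * 5 + [0] * 2 + [1] * 6 + [0] * 4 + [1] * 3
events_joined = find_events(exceed)
events_separate = find_events(exceed, join_across_gaps=False)
```

With the defaults the first two runs are merged into one event, and the final 3-day run is dropped because it is shorter than minDuration.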

In [11]:
mhw, intermediate = detect(sst, clim['thresh'], clim['seas'], intermediate=True)
intermediate

In [1]:
intermediate.variables


This time the function returns an xarray dataset in which the 'cell' dimension is still present, so we need to unstack it if we want the latitude and longitude grid back.

The resulting dataset has a new dimension `events`, which is defined as the starting day of an mhw event.
Each variable is a characteristic of the detected mhw:
````
    event          (events, lat, lon)
    index_start    (events, lat, lon)
    index_end      (events, lat, lon)
    time_start     (events, lat, lon)
    time_end       (events, lat, lon)
    intensity_max  (events, lat, lon)
    intensity_mean (events, lat, lon)
    ...
````

The size of the *events* dimension is determined by the number of separate events identified. Separate events have different starting times. This means that if two different cells have events starting at timestep=50, these events will have the same index along the dimension `events` regardless of their duration.<br>
Clearly this is an approximation, because an event that starts even one timestep later is classified as separate.
This is because, as in the original code, each event is identified cell by cell. 

In [None]:
mhwds.intensity_cumulative[1768,:,:].plot()

**Block average**<br><br>
The blockAverage function in the original MHW code is used to calculate statistics along a block of time. The default is a 1-year block. If the timeseries starts or ends in the middle of a year, the results for these two years have to be treated carefully.<br>
Most of the statistics calculated on the block are averages. Given that the mhw properties are now saved as arrays, we can simply calculate a mean after grouping by year, or by "bins" of years, on the entire dataset.<br>
Below I'm showing the current stage of a new block_average function, which I'm adding to xmhw. 

In [None]:
def block_average(mhwds, temp=None, clim=None, blockLength=1, removeMissing=False):
    '''
    Options:

      blockLength            Size of block (in years) over which to calculate the
                             averaged MHW properties. Must be an integer greater than
                             or equal to 1 (DEFAULT = 1 [year])
      removeMissing          Boolean switch indicating whether to remove (set = NaN)
                             statistics for any blocks in which there were missing 
                             temperature values (DEFAULT = FALSE)
      clim                   The temperature climatology (including missing value information)
                             as output by marineHeatWaves.detect 
      temp                   Temperature time series. If included mhwBlock will output block
                             averages of mean, max, and min temperature (DEFAULT = None but
                             required if removeMissing = TRUE)
                             

                             If both clim and temp are provided, this will output annual counts
                             of moderate, strong, severe, and extreme days.

    Notes:

      This function assumes that the input time vector consists of continuous daily values. Note that
      in the case of time ranges which start and end part-way through the calendar year, the block
      averages at the endpoints, for which there is less than a block length of data, will need to be
      interpreted with care.

    '''
    # Check if all the necessary variables are present 
    if removeMissing and temp is None:
        print('To remove missing values you need to pass the original temperature timeseries')
        return None
    # Check what stats to output
    # if temp included calculate stats for it, if clim also included calculate categories days count
    sw_temp=False
    sw_cats=False
    if temp is not None:
        sw_temp = True
        if clim is not None:
            sw_cats = True
        else:
            sw_cats = False
    
    # create bins based on blockLength to use with groupby_bins
    # NB if the last bin has less than blockLength years, it won't be included.
    # So I'm using last-year+blockLength+1 to make sure we get a bin for last year/s included
    bins=range(1981,2021+blockLength+1,blockLength)

    # calculate mean of variables after grouping by year
    block = mhwds.groupby_bins(mhwds.time_start.dt.year, bins, right=False).mean()
    
    # remove averages of indexes, events and category (which need special treatment)
    block = block.drop_vars(['event', 'index_start', 'index_end', 'category'])
    
    # Other stats can be calculated one by one
    # calculate maximum of intensity_max
    block['intensity_max_max'] = mhwds.intensity_max.groupby_bins(mhwds.time_start.dt.year, bins, right=False).max()
    
    # if sw_temp
    if sw_temp:
        pass
    if sw_cats:
        block['moderate_days'] = mhwds.duration_moderate.groupby_bins(mhwds.time_start.dt.year, bins, right=False).sum()
    
    return block

In [None]:
# To call with standard parameters, below results for a grid point are shown compared to the original function 
block = block_average(mhwds)
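
The binning used in block_average can be tried in isolation on a few made-up events. This sketch uses synthetic start dates and intensities (none of these values come from the dataset above) to show how `groupby_bins` with `right=False` assigns events to blocks of years.

```python
import pandas as pd
import xarray as xr

# Four synthetic events with their start dates and mean intensities
start = pd.to_datetime(["1982-03-01", "1982-11-15", "1983-06-10", "1985-01-20"])
intensity = xr.DataArray([1.0, 3.0, 2.0, 4.0], dims=["start"],
                         coords={"start": start})

block_length = 2
# Bin edges: [1982, 1984) and [1984, 1986); right=False closes bins on the left
bins = list(range(1982, 1985 + block_length + 1, block_length))
block_mean = intensity.groupby_bins(intensity.start.dt.year, bins, right=False).mean()
```

The three events starting in 1982 and 1983 fall in the first block and the 1985 event in the second, so each block mean is taken over the events that start within it.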

**Setting up dask**<br>
Both the threshold and detect functions are set up to use dask delayed. I found this was a good way to make sure the main processes would automatically run in parallel, even if you are not experienced with dask. This approach adds some overhead before the actual calculation starts. This is usually negligible, but it can become significant with big grids, as it will produce too many tasks. In that case it's better to split the grid, run the code separately for each grid section and then recombine the results. An example is shown below.
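
A minimal sketch of the split-and-recombine approach follows. `run_in_sections` is a hypothetical helper, and a simple time mean stands in for the real calculation; in practice `func` would be something like a wrapper around threshold(), with the section size chosen to suit your grid and memory.

```python
import numpy as np
import xarray as xr

def run_in_sections(temp, func, dim="lon", step=10):
    """Apply func to consecutive slices of temp along dim and concatenate the results."""
    sections = []
    for start in range(0, temp.sizes[dim], step):
        section = temp.isel({dim: slice(start, start + step)})
        sections.append(func(section))
    return xr.concat(sections, dim=dim)

def time_mean(section):
    # Placeholder for the real per-section calculation
    return section.mean("time")

# Synthetic (time, lat, lon) grid
temp = xr.DataArray(np.random.default_rng(0).random((30, 4, 25)),
                    dims=("time", "lat", "lon"),
                    coords={"lat": np.arange(4), "lon": np.arange(25)})

combined = run_in_sections(temp, time_mean, dim="lon", step=10)
```

For a calculation that is independent for each grid cell, the recombined result should match what you would get from processing the whole grid at once.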


**Find MHW using original code**<br><br>

In [None]:
#%%time
import datetime
import numpy as np
from marineHeatWaves import detect as orig_detect
from marineHeatWaves import blockAverage

# create necessary time numpy array
t = np.arange(datetime.date(1981,9,1).toordinal(),datetime.date(2021,1,25).toordinal()+1)
sst = tos[:,0,0].squeeze().values
# call function with default settings
orig_mhw, orig_clim = orig_detect(t, sst)

In [None]:
# test to see if groupby_bins as used in the new block_average function produces the same values, 
# first of all I need to go from a list of values to an xarray array for one variable to apply groupby_bins
import pandas
start = pandas.to_datetime(orig_mhw['date_start'])
intensity_mean = xr.DataArray(orig_mhw['intensity_mean'],
                              dims=['start'], 
                              coords=dict(start=start))
intensity_mean

In [None]:
#compare mean of intensity_mean with blockLength=1
blockLength=1
#calculate with Eric's code
blockMHW1=blockAverage(t, orig_mhw, clim=orig_clim, blockLength=1, temp=sst)
# create bins and use groupby_bins
bins=range(1981,2021+1+blockLength, blockLength)
print(intensity_mean.groupby_bins(intensity_mean.start.dt.year,bins, right=False).mean())
print(blockMHW1['intensity_mean'])

In [None]:
#compare mean of intensity_mean with blockLength=2
blockLength=2
#calculate with Eric's code
blockMHW2=blockAverage(t, orig_mhw, clim=orig_clim, blockLength=2, temp=sst)
# create bins and use groupby_bins
bins=range(1981,2021+1+blockLength, blockLength)
print(intensity_mean.groupby_bins(intensity_mean.start.dt.year,bins, right=False).mean())
print(blockMHW2['intensity_mean'])

In [None]:
#compare mean of intensity_mean with blockLength=3
blockLength=3
#calculate with Eric's code
blockMHW3=blockAverage(t, orig_mhw, clim=orig_clim, blockLength=3, temp=sst)
# create bins and use groupby_bins
bins=range(1981,2021+1+blockLength, blockLength)
print(intensity_mean.groupby_bins(intensity_mean.start.dt.year,bins, right=False).mean())
print(blockMHW3['intensity_mean'])

**Working with big grids**<br>
Consider casting your array to float32: while there would be some loss in precision, it should not matter overall, and it will halve your memory usage.
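
As a quick check of the memory saving, the sketch below uses a plain numpy array with the same shape as the climatology shown earlier. The sample temperature value is made up, and with an xarray object you would call `astype('float32')` in the same way.

```python
import numpy as np

# Same shape as the (doy, lat, lon) climatology arrays above
sst64 = np.zeros((366, 12, 20), dtype="float64")
sst32 = sst64.astype("float32")

size64_kib = sst64.nbytes / 1024   # size in kiB as float64
size32_kib = sst32.nbytes / 1024   # size in kiB as float32
ratio = sst64.nbytes / sst32.nbytes

# Rounding error introduced for a typical (made-up) temperature value
value = 28.123456789
error = abs(float(np.float32(value)) - value)
```

The float64 size matches the 686.25 kiB reported for the climatology arrays, and the float32 copy takes half of that.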


**Comparison with original code**<br><br>


We added tests and used them to make sure we could produce exactly the same results as the original code, since then however we introduced 

In [None]:
# you need this if running original code otherwise it really slows down
import datetime
import warnings
warnings.filterwarnings('ignore')