In [None]:
import marineHeatWaves as mhw

In [None]:
# Load required modules
import numpy as np
from datetime import date
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# Generate time vector using datetime format (January 1 of year 1 is day 1)
t = np.arange(date(1982, 1, 1).toordinal(), date(2014, 12, 31).toordinal() + 1)
dates = [date.fromordinal(int(tt)) for tt in t]

# Seed the generator so Restart & Run All reproduces the same series and hence
# the same detected events (the draws were previously unseeded).
SEED = 42
rng = np.random.default_rng(SEED)

# Generate synthetic temperature time series: AR(1) noise plus an annual cosine
a = 0.85  # autoregressive parameter
noise = 0.75 * rng.standard_normal(len(t))  # noise[0] unused: sst[0] is the initial condition
seasonal = 0.5 * np.cos(t * 2 * np.pi / 365.25)
sst = np.zeros(len(t))
sst[0] = 0  # Initial condition
for i in range(1, len(t)):
    sst[i] = a * sst[i - 1] + noise[i] + seasonal[i]
# Shift the series so its minimum is exactly 5 deg C
sst = sst - sst.min() + 5.

In [None]:
%%time
# Run marineHeatWaves detection on the synthetic series. `mhws` holds per-event
# metrics (keys used below: 'intensity_max', 'time_start', 'time_end', ...) and
# `clim` holds the climatology arrays ('seas', 'thresh') aligned with `t`.
mhws, clim = mhw.detect(t, sst)

In [None]:
# Number of events reported by mhw.detect
mhws['n_events']

In [None]:
# Index of the event with the largest peak intensity; reused by the cells below
ev = np.argmax(mhws['intensity_max'])

# Print the numeric metrics of that event, then its start/end dates
numeric_fields = (
    ('Maximum intensity:', 'intensity_max', 'deg. C'),
    ('Average intensity:', 'intensity_mean', 'deg. C'),
    ('Cumulative intensity:', 'intensity_cumulative', 'deg. C-days'),
    ('Duration:', 'duration', 'days'),
)
for label, key, unit in numeric_fields:
    print(label, mhws[key][ev], unit)
for label, key in (('Start date:', 'date_start'), ('End date:', 'date_end')):
    print(label, mhws[key][ev].strftime("%d %B %Y"))

In [None]:
# Position in `t` of the first day of the largest event
np.flatnonzero(t == mhws['time_start'][ev])[0]

In [None]:
# Top panel: the full SST record with its seasonal climatology and MHW threshold.
# Bottom panel: zoom on the largest event `ev`; up to ten neighbouring events on
# each side are shaded light red and `ev` itself is shaded red.
fig, (ax_full, ax_zoom) = plt.subplots(2, 1, figsize=(14, 10))

# Plot SST, seasonal cycle, and threshold (no event shading in this panel)
ax_full.plot(dates, sst, 'k-')
ax_full.plot(dates, clim['thresh'], 'g-')
ax_full.plot(dates, clim['seas'], 'b-')
ax_full.set_title('SST (black), seasonal climatology (blue), threshold (green)')
ax_full.set_xlim(dates[0], dates[-1])
ax_full.set_ylim(sst.min() - 0.5, sst.max() + 0.5)
ax_full.set_ylabel(r'SST [$^\circ$C]')


def _event_index_span(i):
    """Return (first, last) indices into `t` of the start and end days of MHW ``i``."""
    first = np.where(t == mhws['time_start'][i])[0][0]
    last = np.where(t == mhws['time_end'][i])[0][0]
    return first, last


# Shade up to ten MHWs before and after the event of interest. The range is
# clamped to valid event indices: negative indices would silently wrap to the
# last events, and indices past the end raised IndexError.
n_events = len(mhws['time_start'])
for ev0 in range(max(ev - 10, 0), min(ev + 11, n_events)):
    t1, t2 = _event_index_span(ev0)
    ax_zoom.fill_between(dates[t1:t2 + 1], sst[t1:t2 + 1], clim['thresh'][t1:t2 + 1],
                         color=(1, 0.6, 0.5))
# Shade the MHW of interest in red on top
t1, t2 = _event_index_span(ev)
ax_zoom.fill_between(dates[t1:t2 + 1], sst[t1:t2 + 1], clim['thresh'][t1:t2 + 1],
                     color='r')
# Plot SST, seasonal cycle, threshold over the shading
ax_zoom.plot(dates, sst, 'k-', linewidth=2)
ax_zoom.plot(dates, clim['thresh'], 'g-', linewidth=2)
ax_zoom.plot(dates, clim['seas'], 'b-', linewidth=2)
# Implicit concatenation of adjacent literals: a backslash continuation inside
# the string would embed the next line's indentation into the title.
ax_zoom.set_title('SST (black), seasonal climatology (blue), '
                  'threshold (green), detected MHW events (shading)')
ax_zoom.set_xlim(date.fromordinal(int(mhws['time_start'][ev]) - 150),
                 date.fromordinal(int(mhws['time_end'][ev]) + 150))
ax_zoom.set_ylim(clim['seas'].min() - 1,
                 clim['seas'].max() + mhws['intensity_max'][ev] + 0.5)
ax_zoom.set_ylabel(r'SST [$^\circ$C]')
plt.show()

In [None]:
import xarray as xr


# Element-wise converter from datetime.date objects to numpy datetime64 values
vfunc = np.vectorize(np.datetime64)

# Wrap the synthetic SST in an xarray DataArray indexed by a datetime TIME axis
time_axis = vfunc(dates)
temp = xr.DataArray(sst, coords={'TIME': time_axis}, dims=['TIME'])

def rle(inarray, minlength):
    """Find runs of truthy values that are at least ``minlength`` long.

    Run-length encoding adapted from R's ``rle``; only runs of truthy elements
    are reported.

    Parameters
    ----------
    inarray : array_like
        1-D sequence whose elements are interpreted as booleans.
    minlength : int
        Minimum run length to keep.

    Returns
    -------
    (lengths, starts) : tuple of two integer numpy arrays
        Lengths of the kept runs and their start indices in ``inarray``.
        Both arrays are empty when ``inarray`` is empty or has no qualifying run.
    """
    # Cast to bool: the run values are used as a mask below, and for non-bool
    # input (e.g. 0/1 ints) they would otherwise act as integer fancy indices.
    ia = np.asarray(inarray).astype(bool)
    n = len(ia)
    if n == 0:
        # Return the same 2-tuple shape as the non-empty path so callers can unpack.
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.intp)
    change = ia[1:] != ia[:-1]                           # True where the value changes
    ends = np.append(np.flatnonzero(change), n - 1)      # last index of each run
    lengths = np.diff(np.append(-1, ends))               # run lengths
    starts = np.cumsum(np.append(0, lengths))[:-1]       # first index of each run
    keep = ia[ends] & (lengths >= minlength)             # truthy runs long enough
    return lengths[keep], starts[keep]

# calculate the clim
def ts2clm(ts, percentile=90):
    """Compute a day-of-year climatology and percentile threshold for ``ts``.

    Parameters
    ----------
    ts : xarray.DataArray
        Series with a datetime ``TIME`` dimension (needs the ``.dt`` accessor).
    percentile : float, default 90
        Percentile on the 0-100 scale (passed as ``q`` to ``np.nanpercentile``).

    Returns
    -------
    xarray.Dataset
        Variables ``seas`` (NaN-ignoring mean) and ``thresh`` (NaN-ignoring
        percentile), indexed by ``dayofyear``. For each day of year, values are
        pooled over all years within an 11-day centred window. The result is then
        smoothed with a 31-day centred running mean. The day-of-year axis is padded
        by 31 days with wrap-around so both windows span the new-year boundary,
        and the padding is sliced off again.
    """
    # Derive the grouping coordinates from ``ts`` itself (the original read the
    # global ``temp``), on a new object so the caller's array is not mutated.
    ts = ts.assign_coords(year=ts.TIME.dt.year, dayofyear=ts.TIME.dt.dayofyear)
    windowed = (
        ts.set_index(TIME=['dayofyear', 'year'])
        .unstack()
        .pad(dayofyear=31, mode='wrap')
        .rolling(dayofyear=11, min_periods=1, center=True)
        .construct("window_dim")
    )
    pooled = ('year', 'window_dim')
    seas = windowed.reduce(np.nanmean, dim=pooled).rolling(dayofyear=31, center=True).mean()[31:-31]
    thresh = windowed.reduce(np.nanpercentile, dim=pooled, q=percentile).rolling(dayofyear=31, center=True).mean()[31:-31]
    return xr.Dataset({'seas': seas, 'thresh': thresh})


# return the number of heat waves

def detect(ts, clim, minDuration=5):
    """Count runs of days on which ``ts`` exceeds the day-of-year threshold.

    Parameters
    ----------
    ts : xarray.DataArray
        Series with a datetime ``TIME`` dimension.
    clim : xarray.Dataset
        Climatology with a ``thresh`` variable indexed by ``dayofyear``
        (e.g. the output of ``ts2clm``).
    minDuration : int, default 5
        Minimum number of consecutive exceedance days for a run to count.

    Returns
    -------
    int
        Number of qualifying runs. NOTE(review): no gap-joining is done here, so
        this may differ from ``mhws['n_events']``; confirm against mhw.detect.
    """
    anomaly = ts.groupby('TIME.dayofyear') - clim.thresh
    _, starts = rle(anomaly > 0, minDuration)
    return len(starts)
# Keep the marineHeatWaves `clim` dict intact: rebinding `clim` to this xarray
# Dataset would break the plotting cell if it were re-run afterwards.
clim_xr = ts2clm(temp)
detect(temp, clim_xr)

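# NOTE: unused draft of NaN handling for `exceed_bool`, kept for reference.
# In `detect`, comparisons against NaN already evaluate to False.
# exceed_bool[exceed_bool<=0] = False
# exceed_bool[exceed_bool>0] = True
#     # Fix issue where missing temp values (nan) are counted as True
# exceed_bool[np.isnan(exceed_bool)] = False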