# Vegetation Phenology

## Load packages

In [None]:
%matplotlib inline
%load_ext autoreload

import os, sys
import xarray as xr
import numpy as np
import pandas as pd
import datacube

from scipy.signal import savgol_filter, wiener
from scipy.stats import zscore
from statsmodels.tsa.seasonal import STL as stl
import matplotlib.pyplot as plt

from datacube.drivers.netcdf import write_dataset_to_netcdf

sys.path.append('../Scripts')
from dea_datahandling import load_ard
from dea_dask import create_local_dask_cluster
from dea_plotting import display_map, rgb
#import deafrica_temporal_statistics as ts

sys.path.append('./scripts')
import phenolopy

### Set up a dask cluster

In [None]:
# initialise the cluster (helper imported from dea_dask). paste url into dask panel for more info.
# the call is left as the cell's last expression so the notebook displays its result.
create_local_dask_cluster()

In [None]:
# open up a datacube connection; `dc` is passed to load_ard when loading sentinel-2 data
dc = datacube.Datacube(app='phenolopy')

## Study area and data setup

### Set study area and time range

In [None]:
# set lat, lon (y, x) dictionary of testing areas for gdv project
# each value is a study-area centroid as (lat, lon). every key here must also
# exist in buf_dict, since both dicts are indexed by the same `study_area` name.
loc_dict = {
    'yan_full':   (-22.750, 119.10),
    'yan_full_1': (-22.725, 119.05),
    'yan_full_2': (-22.775, 119.15),
    'roy_full_1': (-22.487, 119.927),
    'roy_full_2': (-22.487, 120.092),
    'roy_full_3': (-22.623, 119.927),
    'roy_full_4': (-22.623, 120.092),
    'oph_full_1': (-23.375319, 119.859309),
    'oph_full_2': (-23.185611, 119.859309),
    'oph_full_3': (-23.233013, 119.859309),
    'oph_full_4': (-23.280432, 119.859309),
    'oph_full_5': (-23.327867, 119.859309),
    'test':       (-31.6069288, 116.9426373)
}

# set buffer length and height (x, y)
# each value is (lon_buff, lat_buff): the half-width subtracted from and added to
# the centroid on each axis to form the study extent, so the full box spans
# 2 * buffer in each direction (same units as the coordinates in loc_dict).
buf_dict = {
    'yan_full':   (0.15, 0.075),
    'yan_full_1': (0.09, 0.025),
    'yan_full_2': (0.05, 0.0325),
    'roy_full_1': (0.165209/2, 0.135079/2),
    'roy_full_2': (0.165209/2, 0.135079/2),
    'roy_full_3': (0.165209/2, 0.135079/2),
    'roy_full_4': (0.165209/2, 0.135079/2),
    'oph_full_1': (0.08, 0.047452/2),
    'oph_full_2': (0.08, 0.047452/2),
    'oph_full_3': (0.08, 0.047452/2),
    'oph_full_4': (0.08, 0.047452/2),
    'oph_full_5': (0.08, 0.047452/2),
    'test':       (0.075, 0.065)
}

In [None]:
# name of the study area to run; must be a key of both loc_dict and buf_dict
study_area = 'test'

# half-width buffers for the chosen area, stored as (lon, lat) i.e. (x, y)
lon_buff, lat_buff = buf_dict[study_area]

# time range to load. for a single year, use that same year at both ends with
# months 01 to 12 (e.g. ('2018-01', '2018-12')). when several years are loaded
# they will be averaged together (presumably by the monthly grouping step - confirm).
time_range = ('2015', '2020')

In [None]:
# centroid of the chosen study area, stored as (lat, lon)
lat, lon = loc_dict[study_area]

# grow the centroid by the buffer on both sides of each axis -> study boundary
lat_extent, lon_extent = (lat - lat_buff, lat + lat_buff), (lon - lon_buff, lon + lon_buff)

# display onto interactive map (kept as the last expression so the map renders)
display_map(x=lon_extent, y=lat_extent)

### Load Sentinel-2A and 2B data for the above parameters


In [None]:
# spectral bands (measurements) to load from the sentinel-2 ARD products
measurements = [
    'nbart_blue',
    'nbart_green',
    'nbart_red',
    'nbart_nir_1',
    'nbart_swir_2'
]

# load query: spatial extent, time range, bands and output grid/grouping
query = dict(
    x=lon_extent,
    y=lat_extent,
    time=time_range,
    measurements=measurements,
    output_crs='EPSG:3577',
    resolution=(-10, 10),
    group_by='solar_day',
)

# load sentinel-2a and 2b as a lazy (dask) dataset, one time step per chunk.
# min_gooddata=0.90 presumably discards scenes with < 90% good pixels - confirm in dea_datahandling.
ds = load_ard(dc=dc,
              products=['s2a_ard_granule', 's2b_ard_granule'],
              min_gooddata=0.90,
              dask_chunks={'time': 1},
              **query)

# display dataset
print(ds)

# display a rgb data result of temporary resampled median 
#rgb(ds.resample(time='1M').median(), bands=['nbart_red', 'nbart_green', 'nbart_blue'], col='time', col_wrap=12)

### Conform DEA band names

In [None]:
# takes our dask ds and conforms (renames) bands from the DEA naming scheme
# (e.g. 'nbart_red') to phenolopy's own names - presumably the names that
# calc_vege_index expects; confirm in phenolopy.conform_dea_band_names
ds = phenolopy.conform_dea_band_names(ds)

# display dataset
print(ds)

### Calculate vegetation index

In [None]:
# takes our dask ds and calculates the MAVI veg index from the spectral bands.
# drop=True presumably drops the source bands, leaving only the index variable
# (referenced elsewhere in this notebook as 'veg_index') - confirm in phenolopy
ds = phenolopy.calc_vege_index(ds, index='mavi', drop=True)

# display dataset
print(ds)

## Pre-processing phase

### Temporary - load MODIS dataset

In [None]:
#ds = phenolopy.load_test_dataset(data_path='./data/')

### Group data by month and reduce by median

In [None]:
# take our dask ds, group the time steps by month and reduce each group to its median
ds = phenolopy.group(ds, group_by='month', reducer='median')

# display dataset
print(ds)

### Remove outliers from dataset on per-pixel basis

In [None]:
# chunk dask to -1 to make compatible with this function: the whole time
# dimension becomes a single chunk, so each pixel's full series is in one block
ds = ds.chunk({'time': -1})

# takes our dask ds and remove outliers from data using median method.
# user_factor scales the outlier cut-off; z_pval looks like it only applies to
# method='zscore' (see the commented-out call further down) - confirm in phenolopy
ds = phenolopy.remove_outliers(ds=ds, method='median', user_factor=2, z_pval=0.05)

# display dataset
print(ds)

### Resample dataset to fortnightly (2-week) medians

In [None]:
# takes our dask ds and resamples data to 2-week ('2W') medians, i.e. about
# 26 time steps per year
ds = phenolopy.resample(ds, interval='2W', reducer='median')

# display dataset
print(ds)

### Interpolate missing (i.e. nan) values linearly

In [None]:
# chunk dask to -1 to make compatible with this function: interpolating along
# 'time' needs the whole time dimension in a single chunk
ds = ds.chunk({'time': -1})

# takes our dask ds and linearly interpolates missing (nan) values along time.
# NOTE: xarray does not extrapolate by default, so nans at the very start/end
# of a pixel's series may remain after this step
ds = ds.interpolate_na(dim='time', method='linear')

# display dataset
print(ds)

In [None]:
# TODO: perform forward and backward fills to fill any nans that remain after interpolation (e.g. at the start/end of a series)


### Smooth data on per-pixel basis

In [None]:
# take our dask ds and smooth each pixel's series using a Savitzky-Golay filter
# (window of 3 time steps, 1st-order polynomial). note the method key is spelt
# 'savitsky' in phenolopy's API
ds = phenolopy.smooth(ds=ds, method='savitsky', window_length=3, polyorder=1)

# display dataset
print(ds)

### Upper envelope correction
TODO - not yet implemented.

In [None]:
# todo

## Calculate Phenolometrics

In [None]:
# working - mos date (middle of season)
#time for the mid of the season; computed as the mean value of the times for which,
#respectively, the left edge has increased to the 80 % level and the right edge has decreased
#to the 80 % level.

In [None]:
%autoreload

# calc phenometrics via phenolopy!
ds_phenos = phenolopy.calc_phenometrics(da=da, peak_metric='pos', base_metric='bse', method='seasonal_amplitude', factor=0.2, thresh_sides='two_sided', abs_value=40)

### Work in progress - preliminary phenometrics

In [None]:
# calculate value (max vege) at pos (peak of season)
#def get_pos_v(da):
    #pos_v = da.max('time')  # get max veg value across time per vector
    
    #return pos_v


# calculate date (doy) at pos (peak of season)
#def get_pos_d(da):
    #i = da.argmax('time')  # get time index of max veg value across time per vector
    #pos_d = da.isel(time=i)  # select day of year for index
    #pos_d = pos_d['time'].dt.dayofyear  # convert to doy

    #return pos_d


# calculate value (min vege) at trh (lowest point of season)
#def get_trh_v(da):
    #trh_v = da.min('time')  # get min veg value across time per vector
    
    #return trh_v


# calculate date (doy) at trh (trough of season)
#def get_trh_d(da):
    #i = da.argmin('time')  # get time index of min veg value across time per vector
    #trh_d = da.isel(time=i)  # select day of year for index
    #trh_d = trh_d['time'].dt.dayofyear  # convert to doy

    #return trh_d


# calculate value (vege) at sos (start of season)
#def get_sos_v(da, pos_d, method='percentile', threshold=0.2):   
    #greening = da.where(da['time'] <= pos_d['time'])  # select times prior to pos (greening) (changed < to <= to capture when first is highest)
    #slope_diffs = greening.differentiate('time')  # calc second order diffs
    #pos_diffs = slope_diffs.where(slope_diffs > 0)  # select only positive diffs (i.e. non neg)
    
    #pos_greening = greening.where(pos_diffs)  # select raw veg values where positive on greening slope
    
    #if method == 'first':
        #slope_med = pos_greening.median('time')  # get median of raw veg on pos greening slopes
        #dists_from_median = pos_greening - slope_med  # calc vege distances from median val
        
        #nan_mask = dists_from_median.isnull().all('time')  # calc mask where all vals across time is nan
        
        #i = dists_from_median.fillna(float(dists_from_median.max() + 1))  # fill nans by taking highest val + 1
        #i = i.argmin(dim='time', skipna=True).where(~nan_mask).astype('int16')  # get index of lowest (furthest) dist from median
    
    #elif method == 'percentile':
        #slope_pct = pos_greening.quantile(dim='time', q=threshold, interpolation ='nearest',)  # cut off lower percentile of veg values
        #dists_from_percent = pos_greening - slope_pct  # calc vege distances from percentile val
        #dists_from_percent_abs = xr.ufuncs.fabs(dists_from_percent)  # convert values to absolute value
        
        #nan_mask = dists_from_percent_abs.isnull().all('time')  # calc mask where all vals across time is nan
        
        #i = dists_from_percent_abs.fillna(float(dists_from_percent_abs.max() + 1))  # fill nans by taking highest val + 1
        #i = i.argmin(dim='time', skipna=True).astype('int16')

    #elif method == 'median':
        #slope_med = pos_greening.median('time')  # get median of raw veg on pos greening slopes
        #dists_from_median = pos_greening - slope_med  # calc vege distances from median val
        #dists_from_median_abs = xr.ufuncs.fabs(dists_from_median)  # convert values to absolute value
        
        #nan_mask = dists_from_median_abs.isnull().all('time')  # calc mask where all vals across time is nan
        
        #i = dists_from_median_abs.fillna(float(dists_from_median_abs.max() + 1))  # fill nans by taking highest val + 1
        #i = i.argmin(dim='time', skipna=True).where(~nan_mask).astype('int16')  # get index of lowest (furthest) dist from median

    #sos_v = pos_greening.isel(time=i)
    
    #return sos_v


# calculate date (doy) at sos (start of season)
#def get_sos_d(da, sos_v):
    #sos_d = sos_v['time'].dt.dayofyear
    
    #return sos_d


# calculate value (vege) at eos (end of season)
#def get_eos_v(da, pos_d, method='percentile', threshold=0.8):
    #browning = da.where(da['time'] >= pos_d['time'])  # select times prior to pos (greening) (changed > to >= to capture when first is highest)
    #slope_diffs = browning.differentiate('time')  # calc second order diffs
    #neg_diffs = slope_diffs.where(slope_diffs < 0)  # select only negative diffs (i.e. non pos)
    
    #neg_browning = browning.where(neg_diffs)  # select raw veg values where negative on browning slope
    
    #if method == 'first':
        #slope_med = neg_browning.median('time')  # get median of raw veg on neg browing slopes
        #dists_from_median = neg_browning - slope_med  # calc vege distances from median
    
        #nan_mask = dists_from_median.isnull().all('time')  # calc mask where all vals across time is nan
    
        #i = dists_from_median.fillna(float(dists_from_median.max() + 1))  # fill nans by taking highest val + 1
        #i = i.argmin(dim='time', skipna=True).where(~nan_mask).astype('int16')  # get index of lowest (furthest) dist from median
    
    #elif method == 'percentile':
        #slope_pct = neg_browning.quantile(dim='time', q=threshold, interpolation ='nearest',)  # cut off lower percentile of veg values
        #dists_from_percent = neg_browning - slope_pct  # calc vege distances from percentile val
        #dists_from_percent_abs = xr.ufuncs.fabs(dists_from_percent)  # convert values to absolute value
        
        #nan_mask = dists_from_percent_abs.isnull().all('time')  # calc mask where all vals across time is nan
        
        #i = dists_from_percent_abs.fillna(float(dists_from_percent_abs.max() + 1))  # fill nans by taking highest val + 1
        #i = i.argmin(dim='time', skipna=True).astype('int16')

    #elif method == 'median':
        #slope_med = neg_browning.median('time')  # get median of raw veg on neg browing slopes
        #dists_from_median = neg_browning - slope_med  # calc vege distances from median
        #dists_from_median_abs = xr.ufuncs.fabs(dists_from_median)  # convert values to absolute value
    
        #nan_mask = dists_from_median_abs.isnull().all('time')  # calc mask where all vals across time is nan
    
        #i = dists_from_median_abs.fillna(float(dists_from_median_abs.max() + 1))  # fill nans by taking highest val + 1
        #i = i.argmin(dim='time', skipna=True).where(~nan_mask).astype('int16')  # get index of lowest (furthest) dist from median
        
    #eos_v = neg_browning.isel(time=i)
    
    #return eos_v


# calculate date (doy) at eos (end of season)
#def get_eos_d(da, eos_v, method='percentile'):
    #eos_d = eos_v['time'].dt.dayofyear
    
    #return eos_d


# calculate aos value (amplitude of season)
#def get_aos_v(da, pos_v, trh_v):
    #aos_v = pos_v - trh_v  # minus peak of season from lowest point of season
    
    #return aos_v


# calculate los (doy) from sos (start of season) to eos (end of season)
#def get_los_d(da, sos_d, eos_d):
    #los_d = eos_d - sos_d  # get difference in doys between end and start of season
    
    #max_doy = int(da['time'].dt.dayofyear[-1])  # get max doy in da
    #los_d = xr.where(los_d >= 0, los_d, max_doy + (eos_d.where(los < 0) - sos_d.where(los < 0)))  # correct for neg vals
    
    #return los_d

In [None]:
# method = first gets first positive slope value from left (greening), first negative slope value from right (browning)
# method = percentile gets greening or browning point above/below a user veg threshold (via percentile)
# method = median gets greening or browning point from the median of positive vals on slope (be it pos or neg)

#def calc_phenology(da):
#return pos_v
#pos_v = ds['veg_index'].map_blocks(get_pos_v, template=template)
#pos_d = ds['veg_index'].map_blocks(get_pos_d, template=template)
#trh_v = ds['veg_index'].map_blocks(get_trh_v, template=template)
#trh_d = ds['veg_index'].map_blocks(get_trh_d, template=template)
#sos_v = ds['veg_index'].map_blocks(get_sos_v, template=template, kwargs={'pos_d': pos_d, 'method': 'percentile', 'threshold': 0.2})
#sos_d = ds['veg_index'].map_blocks(get_sos_d, template=template, kwargs={'sos_v': sos_v})
#eos_v = ds['veg_index'].map_blocks(get_eos_v, template=template, kwargs={'pos_d': pos_d, 'method': 'percentile', 'threshold': 0.8})
#eos_d = ds['veg_index'].map_blocks(get_eos_d, template=template, kwargs={'method': 'first', 'eos_v': eos_v})
#aos_v = ds['veg_index'].map_blocks(get_aos_v, template=template, kwargs={'pos_v': pos_v, 'trh_v': trh_v})
#los_d = ds['veg_index'].map_blocks(get_los_d, template=template, kwargs={'sos_d': sos_d, 'eos_d': eos_d})

### working code

In [None]:
# different types of detection, using stl residuals - remove outlier method
#from scipy.stats import median_absolute_deviation

#v = ds.isel(x=0, y=0, time=slice(0, 69))
#v['veg_index'].data = data

#v_med = remove_outliers(v, method='median', user_factor=1, num_dates_per_year=24, z_pval=0.05)
#v_zsc = remove_outliers(v, method='zscore', user_factor=1, num_dates_per_year=24, z_pval=0.1)

#stl_res = stl(v['veg_index'], period=24, seasonal=5, robust=True).fit()
#v_rsd = stl_res.resid
#v_wgt = stl_res.weights

#o = v.copy()
#o['veg_index'].data = v_rsd

#w = v.copy()
#w['veg_index'].data = v_wgt

#m = xr.where(o > o.std('time'), True, False)
#o = v.where(m)

#m = xr.where(w < 1e-8, True, False)
#w = v.where(m)

#fig = plt.figure(figsize=(18, 7))
#plt.plot(v['time'], v['veg_index'], color='black', marker='o')
#plt.plot(o['time'], o['veg_index'], color='red', marker='o', linestyle='-')
#plt.plot(w['time'], w['veg_index'], color='blue', marker='o', linestyle='-')
#plt.axhline(y=float(o['veg_index'].std('time')))
#plt.show()

In [None]:
# working method for STL outlier detection. results don't yet quite match TIMESAT's.
# need to speed this up - very slow for even relatively small datasets
#def func_stl(vec, period, seasonal, jump_l, jump_s, jump_t):
    #resid = stl(vec, period=period, seasonal=seasonal, 
                #seasonal_jump=jump_s, trend_jump=jump_t, low_pass_jump=jump_l).fit()
    #return resid.resid

#def do_stl_apply(da, multi_pct, period, seasonal):
    
    # calc jump size for lowpass, season and trend to speed up processing
    #jump_l = int(multi_pct * (period + 1))
    #jump_s = int(multi_pct * (period + 1))
    #jump_t = int(multi_pct * 1.5 * (period + 1))
    
    #f = xr.apply_ufunc(func_stl, da,
                       #input_core_dims=[['time']], 
                       #output_core_dims=[['time']], 
                       #vectorize=True, dask='parallelized', 
                       #output_dtypes=[ds['veg_index'].dtype],
                       #kwargs={'period': period, 'seasonal': seasonal, 
                               #'jump_l': jump_l, 'jump_s': jump_s, 'jump_t': jump_t}) 
    #return f

# chunk up to make use of dask parallel
#ds = ds.chunk({'time': -1})

# calculate residuals for each vector  stl
#stl_resids = do_stl_apply(ds['veg_index'], multi_pct=0.15, period=24, seasonal=13)

#s = ds['veg_index'].stack(z=('x', 'y'))
#s = s.chunk({'time': -1})
#s = s.groupby('z').map(func_stl)
#out = out.unstack()

#s = ds.chunk({'time': -1})
#t = xr.full_like(ds['veg_index'], np.nan)
#out = xr.map_blocks(func_stl, ds['veg_index'], template=t).compute()

#stl_resids = stl_resids.compute()

In [None]:
# working double logistic - messy though
# https://colab.research.google.com/github/1mikegrn/pyGC/blob/master/colab/Asymmetric_GC_integration.ipynb#scrollTo=upaYKFdBGEAo
# see for asym gaussian example

# NOTE(review): `v` is not assigned by any active cell in this notebook. It only
# appears in the commented-out outlier-detection cell, as a single pixel taken via
# ds.isel(x=0, y=0, ...). Define `v` as a single-pixel dataset with a 'veg_index'
# variable before running this cell, otherwise this line raises NameError.
# keep only the 2016 time steps of the pixel series
da = v.where(v['time.year'] == 2016, drop=True)

def logi(x, a, b, c, d):
    """Evaluate the logistic curve ``a / (1 + exp(-c * (x - d))) + b``.

    For positive ``a`` and ``c`` the curve rises from ``b`` (as x -> -inf)
    to ``a + b`` (as x -> +inf), passing ``a / 2 + b`` at ``x == d``.

    Parameters
    ----------
    x : array-like or float
        Positions at which to evaluate the curve.
    a : float
        Height of the curve above ``b``.
    b : float
        Offset added to the whole curve.
    c : float
        Steepness of the transition.
    d : float
        Position of the curve's midpoint.

    Returns
    -------
    array-like or float
        The curve evaluated element-wise at ``x``.
    """
    # np.exp replaces xr.ufuncs.exp, which was removed from xarray; np.exp
    # still dispatches on xarray objects via __array_ufunc__, so DataArray
    # inputs keep working.
    return a / (1 + np.exp(-c * (x - d))) + b

# curve_fit is not among the notebook's top-level imports, so import it here
from scipy.optimize import curve_fit

# get time index of max veg index (peak of season).
# assumes `da` is a single-pixel series (time is its only dimension) - TODO confirm
idx = int(da['veg_index'].argmax())

# split at the peak into left (greening) and right (browning) halves;
# both halves include the peak time step itself
da_l = da.where(da['time'] <= da['time'].isel(time=idx), drop=True)
da_r = da.where(da['time'] >= da['time'].isel(time=idx), drop=True)

# sort right curve (da_r) descending in time so it rises toward the peak like da_l
da_r = da_r.sortby(da_r['time'], ascending=False)

# use 1-based integer positions as x values (datetimes cannot be fed to exp)
da_l_x_idxs = np.arange(1, len(da_l['time']) + 1, step=1)
da_r_x_idxs = np.arange(1, len(da_r['time']) + 1, step=1)

# fit a logistic curve to each half
popt_l, pcov_l = curve_fit(logi, da_l_x_idxs, da_l['veg_index'], method="trf")
popt_r, pcov_r = curve_fit(logi, da_r_x_idxs, da_r['veg_index'], method="trf")

# evaluate the fitted curves at the original positions
da_fit_l = logi(da_l_x_idxs, *popt_l)
da_fit_r = logi(da_r_x_idxs, *popt_r)

# flip the right fit back to ascending time order; its first element is now the peak
da_fit_r = np.flip(da_fit_r)

# the peak is duplicated (last of the left fit, first of the right fit): replace
# both copies with their mean. the right fit must drop index 0 (the peak); the
# previous index 1 discarded a non-peak value and kept the duplicate peak.
pos_mean = (da_fit_l[-1] + da_fit_r[0]) / 2
da_fit_l = np.delete(da_fit_l, -1)
da_fit_r = np.delete(da_fit_r, 0)

# concat back together with the mean peak value in between
da_logi = np.concatenate([da_fit_l, [pos_mean], da_fit_r])

# smooth final curve with mild savgol (window 3, linear)
da_logi = savgol_filter(da_logi, 3, 1)

# plot raw values against the fitted double-logistic curve
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.plot(da['time'], da['veg_index'], 'o')
ax.plot(da['time'], da_logi)
plt.show()

In [None]:
#from scipy.signal import find_peaks

#x, y = 0, 1

#v = da.isel(x=x, y=y)

#height = float(v.quantile(dim='time', q=0.75))
#distance = math.ceil(len(v['time']) / 4)

#p = find_peaks(v, height=height, distance=distance)[0]

#p_dts = v['time'].isel(time=p)

#for p_dt in p_dts:
    #plt.axvline(p_dt['time'].dt.dayofyear, color='black', linestyle='--')

#count_peaks = len(num_peaks[0])
#if count_peaks > 0:
    #return count_peaks
#else:
    #return 0
    
#plt.plot(v['time.dayofyear'], v)

In [None]:
# flip to get min closest to pos
# if we want closest sos val to pos we flip instead to trick argmin
#flip = dists_sos_v.sortby(dists_sos_v['time'], ascending=False)
#min_right = flip.isel(time=flip.argmin('time'))
#temp_pos_cls = da.isel(x=x, y=0).where(da['time'] == min_right['time'].isel(x=x, y=0))
#plt.plot(temp_pos_cls.time, temp_pos_cls, marker='o', color='black', alpha=0.25)

In [None]:
# roi and rod method from timesat. chad seems superior
# cut into left and right

# get value and time at 20% quantile on left then right
# get value and time at 80% quantile on left then right

# roi = (left 80% val - left 20% val) / (left 80% doy - left 20% doy)

#slope_l = da.where(da['time.dayofyear'] <= da_phenos['sos_times'])

#slope_l_low_values = slope_l.quantile(dim='time', q=0.2)
#slope_l_low_abs_dists = abs(slope_l_low_values - slope_l)
#slope_l_low_times = slope_l['time.dayofyear'].isel(time=slope_l_low_abs_dists.argmin('time'))

#slope_l_high_abs_dists = abs(slope_l_high_values - slope_l)
#slope_l_high_times = slope_l['time.dayofyear'].isel(time=slope_l_high_abs_dists.argmin('time'))

#roi = (slope_l_high_values - slope_l_low_values) / (slope_l_high_times - slope_l_low_times)
#roi.plot(cmap='RdYlGn', robust=False)