# Try our phenology code per pixel

Doing this per pixel is extremely difficult. Each pixel yields a different number of peaks and troughs, and they occur at different times in every pixel. The results would therefore need either variable-length arrays, which a regular latitude/longitude grid cannot hold, or padding to a common length with fill values.

In [None]:
%matplotlib inline

import sys
import warnings
import numpy as np
import xarray as xr
import pandas as pd
from scipy import stats
from scipy import signal
import contextily as ctx
import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs

sys.path.append('/g/data/os22/chad_tmp/Aus_phenology/src')
from phenology import extract_peaks_troughs, phenometrics

%load_ext autoreload
%autoreload 2

In [None]:
# Make the AusEFlux helper modules importable, then call its dask helper.
# NOTE(review): start_local_dask presumably starts a local dask client/cluster
# for the chunked computations below - confirm in AusEFlux/src/_utils.py.
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import start_local_dask
start_local_dask()

## Analysis Parameters


In [None]:
# Input NDVI file (per its name: AusENDVI-clim, MCD43A4, gap-filled,
# 1982-2022) and the spatial dask chunk sizes used when opening it.
ds_path = '/g/data/os22/chad_tmp/AusENDVI/results/publication/AusENDVI-clim_MCD43A4_gapfilled_1982_2022.nc'
chunks = {'latitude': 250, 'longitude': 250}

## Open data

In [None]:
# Open the file lazily with the dask chunks above and attach the EPSG:4326 CRS.
# Then rename the NDVI variable and keep just that DataArray.
ds = (
    assign_crs(xr.open_dataset(ds_path, chunks=chunks), crs='EPSG:4326')
    .rename({'AusENDVI_clim_MCD43A4': 'NDVI'})
    ['NDVI']
)

## Smoothing filters

In [None]:
# Resample onto a regular 2-weekly step before smoothing, so the
# Savitzky-Golay window always spans the same amount of time.
ds = ds.resample(time="2W").interpolate("linear")

# Savitzky-Golay smoothing along 'time'.
# - dask='parallelized' requires the core dimension ('time') to sit in a
#   single chunk, so rechunk explicitly instead of relying on upstream
#   operations to happen to produce one.
# - The input is cast to float32 because resample/interpolate may upcast.
#   savgol_filter preserves float32 input, so the blocks dask computes then
#   really match the declared output dtype.
ds_smooth = xr.apply_ufunc(
    signal.savgol_filter,
    ds.astype('float32').chunk({'time': -1}),
    input_core_dims=[['time']],
    output_core_dims=[['time']],
    kwargs=dict(
        window_length=11,
        polyorder=3,
        deriv=0,
        mode='interp',
    ),
    dask='parallelized',
    output_dtypes=[np.float32],
)

# apply_ufunc moves core dimensions to the end of the output, so 'time'
# comes back last; restore the (time, latitude, longitude) order.
ds_smooth = ds_smooth.transpose('time', 'latitude', 'longitude')
# Interpolate the smoothed 2-weekly series onto a weekly step.
ds_smooth = ds_smooth.resample(time="1W").interpolate("slinear")

## Extract phenometrics 

## Test per pixel

In [None]:
def xr_find_peaks(ds, peak_or_trough='peak', rolling=12, distance=12, prominence=0.01, plateau_size=2):
    """Count the peaks (or troughs) in every pixel's time series.

    Each pixel's series is processed in two steps:

    1. A centred rolling max (for peaks) or rolling min (for troughs) of
       ``rolling`` samples is applied. This merges nearby local extrema
       into flat plateaus.
    2. The result is searched with ``scipy.signal.find_peaks``. Troughs are
       found by searching the negated series.

    Only the *number* of extrema is returned. Their timing differs from
    pixel to pixel, so the peak indices cannot be stored in a regular
    (latitude, longitude) array.

    Parameters
    ----------
    ds : xarray.DataArray
        Data with a 'time' dimension (e.g. smoothed NDVI).
    peak_or_trough : {'peak', 'trough'}
        Which extrema to count.
    rolling : int
        Window length, in time steps, of the centred rolling max/min.
    distance, prominence, plateau_size
        Passed unchanged to ``scipy.signal.find_peaks``.

    Returns
    -------
    xarray.DataArray
        float32 count of extrema per pixel, with 'time' removed and the
        array named after ``peak_or_trough``. The value is NaN where a
        pixel's series is entirely NaN.

    Raises
    ------
    ValueError
        If ``peak_or_trough`` is neither 'peak' nor 'trough'.
    """
    if peak_or_trough not in ('peak', 'trough'):
        raise ValueError(
            f"peak_or_trough must be 'peak' or 'trough', got {peak_or_trough!r}"
        )

    def _count_extrema(series):
        # With vectorize=True this receives one pixel's 1-D numpy array, not
        # an xarray object, so the rolling window is computed with pandas.
        if np.all(np.isnan(series)):
            # Fully masked pixel: report "no data" rather than zero peaks.
            return np.nan
        window = pd.Series(series).rolling(rolling, min_periods=1, center=True)
        if peak_or_trough == 'peak':
            search = window.max().to_numpy()
        else:
            # find_peaks only locates maxima, so search the inverted minima.
            search = -window.min().to_numpy()
        peak_idx, _ = signal.find_peaks(
            search,
            distance=distance,
            prominence=prominence,
            plateau_size=plateau_size,
        )
        return len(peak_idx)

    # dask='parallelized' requires the core dimension in a single chunk.
    if ds.chunks is not None:
        ds = ds.chunk({'time': -1})

    # Scalar parameters reach the per-pixel function via the closure above,
    # so only the data array is passed (with 'time' as its sole core dim).
    n_extrema = xr.apply_ufunc(
        _count_extrema,
        ds,
        input_core_dims=[['time']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[np.float32],
    )

    return n_extrema.rename(peak_or_trough)

In [None]:
%%time
# Quick test of the per-pixel peak finder on a two-year subset (2004-2005).
# NOTE(review): ds_smooth is dask-backed (opened with chunks) and
# xr_find_peaks uses dask='parallelized', so this timing likely covers only
# building the lazy graph. Call .compute() to time the actual work.
n_peaks = xr_find_peaks(ds_smooth.sel(time=slice('2004', '2005')))