# Create example plots that illustrate the methods

Each step of the method is illustrated, from the input NDVI time series to the output phenology metrics.

In [None]:
%matplotlib inline

import sys
import pickle
import warnings
import numpy as np
import xarray as xr
import pandas as pd
import seaborn as sb
import scipy.signal
from scipy import stats
import geopandas as gpd
import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs
from odc.geo.geom import Geometry

sys.path.append('/g/data/os22/chad_tmp/Aus_phenology/src')
from phenology import extract_peaks_troughs, phenometrics

%load_ext autoreload
%autoreload 2

## Analysis Parameters


In [None]:
# timeseries_file = '/g/data/os22/chad_tmp/Aus_phenology/data/ecoregions_NDVI_timeseries.pkl'
# ecoregions_file = '/g/data/os22/chad_tmp/Aus_phenology/data/Ecoregions2017_aus_processed.geojson'

In [None]:
# Name of the region to analyse (matched against column `var` below).
# Other regions that have been used:
#   'Australian Alps montane grasslands'
#   'Rainfed cropping savanna'
#   'Great Sandy-Tanami desert'
#   'Eastern Australian warm temperate forests'
k = 'Oberon rainfed crop'

# Gap-filled NDVI dataset (NetCDF), opened in the "Open data" step.
ds_path = ('/g/data/os22/chad_tmp/AusENDVI/results/publication/'
           'AusENDVI-clim_MCD43A4_gapfilled_1982_2022.nc')

# Region polygons. Alternative vector files that have been used:
#   /g/data/os22/chad_tmp/Aus_phenology/data/vectors/Ecoregions2017_modified.geojson
#   /g/data/os22/chad_tmp/Aus_phenology/data/vectors/IBRAv7_regions_modified.geojson
ecoregions_file = ('/g/data/os22/chad_tmp/Aus_phenology/data/vectors/'
                   'IBRAv7_subregions_modified.geojson')

# Attribute column holding the region name. Alternatives: 'REG_NAME_7', 'ECO_NAME'
# (presumably paired with the alternative vector files above — confirm).
var = 'SUB_NAME_7'

## Open data

In [None]:
# Open the NDVI dataset, tag it with the EPSG:4326 CRS, and keep only the
# 'AusENDVI_clim_MCD43A4' variable as a DataArray named 'NDVI'.
ds = (
    assign_crs(xr.open_dataset(ds_path), crs='EPSG:4326')
    ['AusENDVI_clim_MCD43A4']
    .rename('NDVI')
)

## Clip to a region

In [None]:
# Load the region polygons and keep only the feature(s) whose `var` attribute
# equals the region name `k`.
gdf = gpd.read_file(ecoregions_file)
gdf = gdf[gdf[var] == k]
if gdf.empty:
    # Fail here with a clear message rather than later with an opaque
    # indexing error when the geometry is extracted.
    raise ValueError(f"No feature with {var} == {k!r} found in {ecoregions_file}")

In [None]:
# Clip the NDVI cube to the region. All matching features are merged with
# `dissolve()` so a region stored as several rows is not silently reduced to
# its first feature (the previous `gdf.iloc[0]` behaviour).
region_geom = gdf.dissolve().geometry.iloc[0]
geom = Geometry(geom=region_geom, crs=gdf.crs)
ds = ds.odc.mask(poly=geom)
# Trim columns/rows that are entirely outside the polygon (all-NaN after masking).
ds = ds.dropna(dim='longitude', how='all').dropna(dim='latitude', how='all')

# Summarise into a 1-D time series: plain (not area-weighted) mean over pixels.
ds = ds.mean(['latitude', 'longitude'])

## Smoothing filters

In [None]:
# Put the series on a regular 2-weekly grid (linear interpolation) before
# smoothing: savgol_filter works on sample index, so samples must be evenly spaced.
ds_smooth = ds.resample(time="2W").interpolate("linear")

# Savitzky-Golay smoothing along time: 11-sample window, cubic polynomial,
# no derivative, 'interp' handling at the series ends.
savgol_kwargs = {'window_length': 11, 'polyorder': 3, 'deriv': 0, 'mode': 'interp'}
ds_smooth = xr.apply_ufunc(
    scipy.signal.savgol_filter,
    ds_smooth,
    input_core_dims=[['time']],
    output_core_dims=[['time']],
    kwargs=savgol_kwargs,
    dask='parallelized',
)

# Compare the observed and smoothed series over 2010-2022.
plot_window = slice('2010', '2022')
with plt.style.context('ggplot'):
    fig, ax = plt.subplots(1, 1, figsize=(12, 4))
    for series, series_label in ((ds, 'NDVI Observed'),
                                 (ds_smooth, 'Bi-monthly & S-G smoothed')):
        series.sel(time=plot_window).plot(ax=ax, label=series_label)
    ax.legend()
    ax.set_title(k)
    ax.set_xlabel(None)

## Interpolation to daily

In [None]:
# Restrict the observed and smoothed series to 2010-2016 for the close-up plot.
trim_window = slice('2010', '2016')
ds_trim = ds.sel(time=trim_window)
ds_smooth_trim = ds_smooth.sel(time=trim_window)

# Upsample the smoothed series to daily steps with quadratic interpolation,
# after dropping NaN time steps.
quad = (
    ds_smooth_trim
    .dropna(dim='time', how='all')
    .resample(time='1D')
    .interpolate(kind='quadratic')
)

In [None]:
# Close-up (2010-2016) of observed NDVI, the smoothed 2-weekly series and its
# daily quadratic interpolation. Everything is drawn on the `ax` created here,
# consistent with the other plotting cells, instead of pyplot's implicit axes.
with plt.style.context('ggplot'):
    fig, ax = plt.subplots(1, 1, figsize=(12, 4))
    ax.scatter(ds_trim.time, ds_trim.values, label='Observed NDVI', color='tab:orange')
    ax.scatter(ds_smooth_trim.time, ds_smooth_trim.values,
               label='Bi-monthly & S-G smoothed', color='tab:blue')
    ax.plot(quad.time, quad.values, label='Quadratic Interp. Daily', color='tab:green')
    ax.set_xlabel(None)
    ax.set_ylabel('NDVI')
    ax.legend()
    ax.set_title(k)
    plt.show()

## Show peak/trough extraction

In [None]:
# Daily, quadratic-interpolated version of the full smoothed record, stored in
# a dict keyed by the region name `k` (results below are looked up with [k]).
daily_ndvi = (
    ds_smooth.dropna(dim='time', how='all')
    .resample(time='1D')
    .interpolate(kind='quadratic')
)
d = {k: daily_ndvi}

In [None]:
# Detect peaks and troughs in the daily series. `rolling=90` presumably sets a
# 90-day window — the same window drawn as the rolling max/min in the next plot.
peak_trough_params = dict(
    rolling=90,
    distance=90,
    prominence='auto',
    plateau_size=10,
)
peaks_troughs = extract_peaks_troughs(d, **peak_trough_params)

In [None]:
# Plot the daily series, its centred 90-day rolling max/min envelope, and the
# detected peaks/troughs.
# NOTE(review): `.iloc[-47:-1]` is a hard-coded positional slice (the 46 rows
# before the final row), not the 2002-2022 window used for the lines — confirm
# that excluding the last row and this row count are intended.
pt_subset = peaks_troughs[k].iloc[-47:-1]
envelope = d[k].rolling(time=90, min_periods=1, center=True)
line_window = slice('2002', '2022')

with plt.style.context('ggplot'):
    fig, ax = plt.subplots(1, 1, figsize=(12, 4))
    ax.scatter(pt_subset.index, pt_subset['peaks'], c='black', label='Peak', zorder=10)
    ax.scatter(pt_subset.index, pt_subset['troughs'], c='tab:purple', label='Trough', zorder=11)
    d[k].sel(time=line_window).plot(ax=ax, label='Observed', c='tab:blue')
    envelope.max().sel(time=line_window).plot(ax=ax, label='Rolling maximum', c='tab:green')
    envelope.min().sel(time=line_window).plot(ax=ax, label='Rolling minimum', c='tab:orange')
    ax.legend(ncols=2)
    ax.set_xlabel(None)
    ax.set_title(k)

## Extract phenometrics 

<!-- import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def double_logistic_function(t, wNDVI, mNDVI, S, A, mS, mA):
    sigmoid1 = 1 / (1 + np.exp(-mS * (t - S)))
    sigmoid2 = 1 / (1 + np.exp(mA * (t - A)))
    seasonal_term = sigmoid1 + sigmoid2 - 1
    return wNDVI + (mNDVI - wNDVI) * seasonal_term

def weight_function(t, S, A, r):
    tr = 100 * (t - S) / (A - S)
    tr = np.clip(tr, 0, 100)
    return np.exp(-np.abs(r) / (1 + tr / 10))

def fit_curve(t, ndvi_observed):
    initial_guess = [np.min(ndvi_observed), np.max(ndvi_observed), np.mean(t), np.mean(t), 1, 1]
    params, _ = curve_fit(double_logistic_function, t, ndvi_observed, p0=initial_guess, maxfev=10000)
    residuals = ndvi_observed - double_logistic_function(t, *params)
    weights = weight_function(t, params[2], params[3], residuals)
    params, _ = curve_fit(double_logistic_function, t, ndvi_observed, p0=initial_guess, sigma=weights, maxfev=10000)
    return params

doys = ndvi_cycle.time.dt.dayofyear.values[2:]
doys_frac = doys/365
values = ndvi_cycle.values[2:]

##Fit the curve
parameters = fit_curve(doys_frac, values)

##Plot the observed NDVI values
plt.scatter(doys, values, label='Observed NDVI')

##Generate points for the fitted curve
t_fit = np.linspace(min(doys_frac), max(doys_frac), 365)
ndvi_fit = double_logistic_function(t_fit, *parameters)

##Plot the fitted curve
plt.plot(t_fit*365, ndvi_fit, label='Fitted Curve', color='red')

plt.xlabel('Day of the Year')
plt.ylabel('NDVI')
plt.legend()
plt.title('Double Logistic Curve Fitting for NDVI Observations')
plt.show() -->


<!-- def xr_count_peaks(ds, order=16):
    def _find_peaks(ds):
        peaks = scipy.signal.argrelextrema(ds, np.greater, order=order)
        # peaks = scipy.signal.find_peaks(ds, height=0.2, distance=order)
        return len(peaks[0])
    
    ds_n_peaks = xr.apply_ufunc(_find_peaks,
                              ds, 
                              input_core_dims=[['time']],
                              vectorize=True, 
                              dask='parallelized',
                              output_dtypes=[np.float32]
                             )

    return ds_n_peaks.rename('n_peaks')

%%time
n_peaks = xr_count_peaks(ds_smooth.sel(time=slice('2004-09', '2006-03'))) -->

In [None]:
# Daily quadratic-interpolated NDVI for 2010-2022, keyed by region name.
daily_ndvi_2010_2022 = (
    ds_smooth.dropna(dim='time', how='all')
    .resample(time='1D')
    .interpolate(kind='quadratic')
    .sel(time=slice('2010', '2022'))
)
d = {k: daily_ndvi_2010_2022}

# Phenometric extraction uses the same peak/trough settings as above, plus an
# `amplitude` threshold of 0.20 (its exact meaning is defined in `phenology`).
phenometric_params = dict(
    rolling=90,
    distance=90,
    prominence='auto',
    plateau_size=10,
    amplitude=0.20,
)
eco_regions_phenometrics = phenometrics(d, **phenometric_params)

In [None]:
def _doy_to_dates(doys, years):
    """Convert paired day-of-year / year values to Timestamps (day 1 -> 1 January)."""
    # NOTE(review): int(year) raises on NaN — assumes every season has a year.
    return [pd.to_datetime(doy - 1, unit='D', origin=str(int(year)))
            for doy, year in zip(doys, years)]


# Overlay the extracted start (SOS), end (EOS), peak (POS) and trough (TOS) of
# season on the daily NDVI. Each metric's date comes from its day-of-year
# column plus the matching `<metric>_year` column; its value from `v<metric>`.
metrics = eco_regions_phenometrics[k]
metric_colours = (
    ('SOS', 'tab:green'),
    ('EOS', 'tab:orange'),
    ('POS', 'black'),
    ('TOS', 'tab:purple'),
)

with plt.style.context('ggplot'):
    fig, ax = plt.subplots(1, 1, figsize=(12, 4))
    d[k].sel(time=slice('2010', '2022')).plot(ax=ax, color='tab:blue', label='Daily NDVI')
    for metric, colour in metric_colours:
        ax.scatter(
            x=_doy_to_dates(getattr(metrics, metric).values,
                            getattr(metrics, f'{metric}_year').values),
            y=getattr(metrics, f'v{metric}'),
            c=colour, label=metric, zorder=10,
        )
    ax.set_xlabel(None)
    ax.set_ylabel('NDVI')
    ax.set_title(k, fontsize=15)
    ax.legend(ncols=2)
    plt.tight_layout()