# Testing Notebook


Random online Python phenology functions that might help:

- https://gist.github.com/YanCheng-go/d4e17831f294199443d0f7682558e608

- https://github.com/JavierLopatin/PhenoPY



In [None]:
# !pip install richdem
# !pip install xarray --upgrade

In [None]:
%matplotlib inline

import datacube
import matplotlib.pyplot as plt
from odc.algo import xr_reproject
import hdstats
import numpy as np
import pandas as pd
import sys
import xarray as xr
import datetime as dt
import os

sys.path.append('../Scripts')
from deafrica_datahandling import load_ard
from deafrica_bandindices import calculate_indices
from deafrica_plotting import display_map, rgb
from deafrica_temporal_statistics import xr_phenology, temporal_statistics, fast_completion, smooth, allNaN_arg
from datacube.utils.geometry import assign_crs
from deafrica_dask import create_local_dask_cluster

import warnings
warnings.filterwarnings("ignore", "Mean of empty slice")
warnings.simplefilter("ignore", FutureWarning)

%load_ext autoreload
%autoreload 2

In [None]:
create_local_dask_cluster()

### Connect to the datacube

In [None]:
dc = datacube.Datacube(app='Vegetation_phenology')

### Analysis parameters


In [None]:
# Vegetation index used as the proxy for vegetation condition
veg_proxy = 'NDVI'

# Centre of the area of interest (latitude, longitude)
# (22.817 / 28.518 were noted next to these values as an alternative site)
lat, lon = -34.288, 20.012

# Half-width of the bounding box along each axis, same units as lat/lon
lat_buffer, lon_buffer = 0.004, 0.0175

# Bounding box expressed as (min, max) tuples around the centre point
lat_range = (lat - lat_buffer, lat + lat_buffer)
lon_range = (lon - lon_buffer, lon + lon_buffer)

# First and last month of the analysis period
years_range = ('2018-01', '2018-12')

## View the selected location

In [None]:
# display_map(x=lon_range, y=lat_range)

## Load cloud-masked Sentinel-2 data

The first step is to load Sentinel-2 data for the specified area of interest and time range. 
The `load_ard` function is used here to load data that has been masked for cloud, shadow and quality filters, making it ready for analysis.

In [None]:
# Query parameters shared by the load calls in this cell
query = dict(
    y=lat_range,
    x=lon_range,
    time=years_range,
    measurements=['blue'],
    resolution=(-20, 20),
    output_crs='epsg:6933',
)

# Variant of the Sentinel-2 load that passes dask_chunks, kept for reference
# ds = load_ard(dc=dc,
#               products=['s2_l2a'],
#               dask_chunks={'x':1000, 'y':1000,'time':-1},
#               **query
#               )

# Load the available Sentinel-2 (s2_l2a) data without dask chunking
ds1 = load_ard(dc=dc, products=['s2_l2a'], **query)

# print(ds)

In [None]:
da = ds1.blue

In [None]:
def date_of_median(da, year, sample_lat, sample_lon):
    """
    Find, month by month, the observation that lies closest to the
    monthly median and report its date and distance for one pixel.

    Parameters
    ----------
    da : xarray.DataArray
        Time-series with 'time', 'y' and 'x' dimensions. Assumed to be
        an annual time-series: the monthly medians are computed from
        every time step in `da`.
    year : str or int
        Year of the time-series in `da`, e.g. '2018'.
    sample_lat : float
        'y' coordinate of the pixel to report (nearest pixel is used).
    sample_lon : float
        'x' coordinate of the pixel to report (nearest pixel is used).

    Returns
    -------
    dates : xarray.DataArray
        For each month found in `da`, the 'time' label of the
        observation closest to that month's median at the sampled
        pixel, joined along the 'date of median' dimension. Masked
        where the pixel has no valid observation in the month.
    dist_from_median : xarray.DataArray
        Absolute distance of that observation from the monthly median,
        joined along the 'dist_from_monthly_median' dimension and
        masked in the same places as `dates`.
    """

    # calculate medians for each month
    monthly_medians = da.groupby('time.month').median()

    dates = []
    values = []
    # Iterate over the month labels (not positions) so that every month
    # is paired with its own median even if some months are absent
    for month in monthly_medians.month.values:

        # select the month of interest from da
        m = da.sel(time=f"{year}-{int(month):02d}")

        # find regions with all-NaN slices
        mask = m.isnull().all('time')

        # absolute distance each observation has from the monthly median
        distance = abs(m - monthly_medians.sel(month=month))

        # Fill gaps with a value larger than every valid distance so a
        # missing observation can never be picked as the minimum. The
        # fill value must be derived AFTER taking the absolute value,
        # otherwise a large negative distance could exceed it.
        distance = distance.fillna(float(distance.max() + 1))

        # date of the smallest distance, masked where the pixel had no
        # valid observation at all during this month
        idx = distance.idxmin(dim='time', skipna=True).where(~mask)

        # The smallest distance itself; taken with min() instead of
        # re-selecting by `idx`, which holds masked labels for all-NaN
        # pixels. The matching date is kept as a 'time' coordinate.
        value = distance.min(dim='time').where(~mask).assign_coords(time=idx)

        values.append(value)
        dates.append(idx)

    # join into dataarray along new dimension
    dates = xr.concat(dates, "date of median")
    dist_from_median = xr.concat(values, 'dist_from_monthly_median')

    # select pixel
    dates = dates.sel(x=sample_lon, y=sample_lat, method='nearest')
    dist_from_median = dist_from_median.sel(x=sample_lon, y=sample_lat, method='nearest')

    return dates, dist_from_median


In [None]:
a,b=date_of_median(da, sample_lon=1929690., sample_lat=-4123870., year='2018')

In [None]:
b.plot()

In [None]:
c.plot(col='median_month_argmin', col_wrap=4, vmax=1000)

In [None]:
1929690

In [None]:
da.y

In [None]:
z.isel(day_of_month=1).time.dt.dayofyear.plot()

In [None]:
monthly_medians = da.groupby('time.month').median()

In [None]:
# ds1.blue.sel(time='2018-01')

In [None]:
# Pair January's observations with January's median by label;
# isel(month=1) would pick the second monthly group instead.
distance = da.sel(time='2018-01') - monthly_medians.sel(month=1)


In [None]:
idx = distance.idxmax(dim='time', skipna=True)

In [None]:
da.sel(time=idx, method='nearest')

In [None]:
idx = allNaN_arg(distance, "time", "min").astype("int16")

In [None]:
median_date = jan.isel(time=idx)

In [None]:
median_date.time.dt.dayofyear.plot()

In [None]:
def allNaN_arg(da, dim, stat):
    """
    Locate the maximum or minimum of `da` along `dim` while handling
    all-NaN slices. NaNs are first filled with a value that can never
    win the comparison, then the locations where every value along
    `dim` was NaN are masked in the result.

    Params
    ------
    da : xarray.DataArray
    dim : str,
            Dimension over which to locate the extreme, e.g. 'time'
    stat : str,
        The statistic to calculate, either 'min' (uses .idxmin())
        or 'max' (uses .idxmax())

    Returns
    ------
    xarray.DataArray
        Coordinate labels of `dim` at the requested extreme, masked
        where the whole slice along `dim` was NaN.

    Raises
    ------
    ValueError
        If `stat` is neither 'min' nor 'max'.
    """
    # Reject unknown statistics up front instead of silently returning None
    if stat not in ("max", "min"):
        raise ValueError(f"stat must be 'min' or 'max', got {stat!r}")

    # generate a mask where entire axis along dimension is NaN
    mask = da.isnull().all(dim)

    if stat == "max":
        # fill below the global minimum so a gap never becomes the maximum
        filled = da.fillna(float(da.min() - 1))
        return filled.idxmax(dim=dim, skipna=True).where(~mask)

    # stat == "min": fill above the global maximum so a gap never
    # becomes the minimum, then take the label of the minimum
    filled = da.fillna(float(da.max() + 1))
    return filled.idxmin(dim=dim, skipna=True).where(~mask)


**Once the load is complete**, we can plot the data as a true-colour image using the `rgb` function.  

In [None]:
# rgb(ds, index=[0,5], col_wrap=1)

In [None]:
# Calculate the chosen vegetation proxy index and add it to the loaded data set
# ds = (ds.nir - ds.red)/(ds.nir + ds.red)
ds = calculate_indices(ds, index=veg_proxy, collection='s2')
# ds1 = calculate_indices(ds1, index=veg_proxy, collection='s2')
# ds

In [None]:
stats=['discordance','abs_change','complexity','f_mean','central_diff']

In [None]:
x = temporal_statistics(ds1.NDVI, stats=stats)
x

In [None]:
%time
y = temporal_statistics(ds.NDVI, stats=stats).compute()
y

In [None]:
z = x - y

In [None]:
x.discordance.plot()

In [None]:
y.discordance.plot()

In [None]:
z.discordance.plot()

In [None]:
%%time
phen = xr_phenology(ds.NDVI,
                    method_sos='median',
                    method_eos='median',
                    complete='linear',
                    smoothing='rolling_mean').compute()
phen

In [None]:
%%time
phen1 = xr_phenology(ds1.NDVI,
                    method_sos='median',
                    method_eos='median',
                    complete='fast_complete',
                    smoothing='wiener')
phen1

In [None]:
z  = phen - phen1

In [None]:
phen.SOS.plot()

In [None]:
phen1.SOS.plot()

In [None]:
z.SOS.plot()

In [None]:
i_complete=fast_completion(i.NDVI)

In [None]:
i_complete.mean(['x', 'y']).plot()

In [None]:
x=smooth(i_complete)

In [None]:
x.mean(['x', 'y']).plot()

In [None]:
z = I_mapblocks - i_complete

In [None]:
template=i.NDVI.drop('spatial_ref')

I_mapblocks = i.NDVI.map_blocks(
    fast_completion,
    template=template)

# I_mapblocks

In [None]:
template=I_mapblocks

I_mapblocks_smooth = I_mapblocks.map_blocks(
    smooth,
    template=template).compute()

In [None]:
I_mapblocks_smooth.mean(['x', 'y']).plot()

In [None]:
j = I_mapblocks_smooth - x

In [None]:
j.mean(['x', 'y']).plot()

In [None]:
# def poly_fit(time, data, degree):
    
#     pfit = np.polyfit(time, data, degree) 
    
#     return np.transpose(np.polyval(pfit,time))

# def poly_fit_smooth(time, data, degree, n_pts):
#         """
#         """
    
#         time_smooth_inds = np.linspace(0, len(time), n_pts)
#         time_smooth = np.interp(time_smooth_inds, np.arange(len(time)), time)

#         data_smooth = np.array([np.array([coef * (x_val ** current_degree) for
#                                 coef, current_degree in zip(np.polyfit(time, data, degree),
#                                 range(degree, -1, -1))]).sum() for x_val in time_smooth])

#         return data_smooth

# def xr_polyfit(doy,
#                da,
#                degree,
#                interp_multiplier=1):    
    
#     # Fit polynomial curve to observed data points
#     if interp_multiplier==1:
#         print('Fitting polynomial curve to existing observations')
#         pfit = xr.apply_ufunc(
#             poly_fit,
#             doy,
#             da, 
#             kwargs={'degree':degree},
#             input_core_dims=[["time"], ["time"]], 
#             output_core_dims=[['time']],
#             vectorize=True,  
#             dask="parallelized",
#             output_dtypes=[da.dtype],
#         )
    
#     if interp_multiplier > 1:
#         print("Fitting polynomial curve to "+str(len(doy)*interp_multiplier)+
#                                                       " interpolated points")
#         pfit = xr.apply_ufunc(
#             poly_fit_smooth,  # The function
#             doy,# time
#             da,#.chunk({'time': -1}), #the data
#             kwargs={'degree':degree, 'n_pts':len(doy)*interp_multiplier},
#             input_core_dims=[["time"], ["time"]], 
#             output_core_dims=[['new_time']], 
#             output_sizes = ({'new_time':len(doy)*interp_multiplier}),
#             exclude_dims=set(("time",)),
#             vectorize=True, 
#             dask="parallelized",
#             output_dtypes=[da.dtype],
#         ).rename({'new_time':'time'})
    
#         # Map 'dayofyear' onto interpolated time dim
#         time_smooth_inds = np.linspace(0, len(doy), len(doy)*interp_multiplier)
#         new_datetimes = np.interp(time_smooth_inds, np.arange(len(doy)), doy)
#         pfit = pfit.assign_coords({'time':new_datetimes})
    
#     return pfit

# # da=xr_polyfit(dayofyear=dayofyear, 
# #               da=da,
# #               degree=degree,
# #               interp_multiplier=interp_multiplier)

In [None]:
# #set up figure
# fig, ax = plt.subplots(nrows=5,ncols=2,figsize=(18,25), sharex=True, sharey=True)

# #start of season
# temp_stats.discordance.plot(ax=ax[0,0])
# ax[0,0].set_title('discordance')
# temp_stats.f_std.plot(ax=ax[0,1])
# ax[0,1].set_title('f_std')

# #peak of season
# temp_stats.f_mean.plot(ax=ax[1,0])
# ax[1,0].set_title('f_mean')
# phen.f_median.plot(ax=ax[1,1])
# ax[1,1].set_title('f_median')

# #end of season
# temp_stats.mean_change.plot(ax=ax[2,0])
# ax[2,0].set_title('mean_change')
# phen.med_change.plot(ax=ax[2,1])
# ax[2,1].set_title('med_change')

# #Length of Season
# temp_stats.abs_change.plot(ax=ax[3,0])
# ax[3,0].set_title('abs_change');

# #Amplitude
# temp_stats.complexity.plot(ax=ax[3,1])
# ax[3,1].set_title('complexity')

# #rate of growth
# temp_stats.central_diff.plot(ax=ax[4,0])
# ax[4,0].set_title('central_diff')

# #rate of Sensescence
# temp_stats.num_peaks.plot(ax=ax[4,1])
# ax[4,1].set_title('num_peaks');

# plt.tight_layout();