# Seasonal Vegetation Anomaly Forecasts



## Background

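Understanding how the vegetated landscape responds to longer-term environmental drivers, such as the El Niño Southern Oscillation (ENSO) or climate change, requires calculating seasonal anomalies. A seasonal anomaly subtracts the long-term seasonal mean from a time series. This removes regular seasonal variability and highlights change related to longer-term drivers. Standardised anomalies, which this notebook uses, also divide by the long-term seasonal standard deviation, so anomalies can be compared between places with different variability. This notebook calculates a time series of standardised NDVI anomalies and uses an autoregressive model to predict the most recent month.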
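The arithmetic behind a standardised anomaly is shown below as a minimal sketch. It assumes the anomaly is the departure from the baseline mean divided by the baseline standard deviation (population, `ddof=0`), and it uses plain NumPy. It does not reproduce the notebook's `calculate_anomalies` helper, whose internals are not shown here. The function name `standardised_anomaly` is hypothetical.

```python
import numpy as np


def standardised_anomaly(baseline, value):
    """Standardised anomaly of `value` relative to a baseline.

    baseline: the same season's values from previous years (for example, the
              mean MAM NDVI for each year of a reference period); NaNs are ignored.
    value:    the season's value for the year of interest.
    """
    baseline = np.asarray(baseline, dtype=float)
    mean = np.nanmean(baseline)  # long-term seasonal mean
    std = np.nanstd(baseline)    # long-term seasonal standard deviation (ddof=0, an assumption)
    return (value - mean) / std


# Example: a MAM NDVI of 0.7 against three earlier MAM seasons
print(standardised_anomaly([0.4, 0.5, 0.6], 0.7))  # roughly 2.449
```

A value of about +2.4 means this season is roughly 2.4 standard deviations greener than usual. If the real helper uses the sample standard deviation (`ddof=1`), its values will be somewhat smaller.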

## Import libraries

```python
import xarray as xr
from datacube.utils.cog import write_cog
import matplotlib.pyplot as plt
import geopandas as gpd
import numpy as np
import pandas as pd
import sys
import os
from statsmodels.tsa.ar_model import AutoReg

sys.path.append('../dea-notebooks/Scripts')
from dea_plotting import display_map, map_shapefile
from anomalies import calculate_anomalies
from dea_dask import create_local_dask_cluster
from dea_classificationtools import HiddenPrints

%load_ext autoreload
%autoreload 2
```


### Set up local Dask cluster

Dask creates a local cluster of CPUs so the analysis can run in parallel. To watch what the cluster is doing, click the dashboard link that is printed after you run the cell.

```python
create_local_dask_cluster()
```

Output from the original run: a Dask client (scheduler `tcp://127.0.0.1:39561`, dashboard `/user/chad/proxy/8787/status`) on a cluster with 1 worker, 8 cores and 61.42 GB of memory.


## Analysis Parameters

The following cell sets the parameters that define the area of interest (AOI) and the period to analyse. The parameters are:

* `prediction_year`: The year to predict, e.g. `'2020'`. Anomalies are calculated for the three preceding years (here 2017 to 2019).
* `prediction_quarter`: The three-month season of interest, e.g. `'DJF'`, `'JFM'`, `'FMA'`. Note that the anomaly loop below iterates over every season regardless of this value.
* `shp_fpath`: Filepath to a shapefile that defines your AOI. If you are not using a shapefile, set this to `None`.
* `lat`, `lon`, `buff`: If you are not using a shapefile, a latitude, longitude and buffer (in degrees) that define a query box.
* `resolution`: Output pixel size as `(y, x)` in the units of the output CRS, e.g. `(-120, 120)`.
* `dask_chunks`: Dictionary of chunk sizes used to load the data with Dask, e.g. `{'x': 1000, 'y': 1000}`.
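When no shapefile is given, the query box extends `buff` degrees on each side of the centre point, as in the commented-out `display_map` call further down. The arithmetic is sketched below, assuming `calculate_anomalies` reads `query_box=(lat, lon, buff)` the same way. The helper name `query_box_bounds` is hypothetical and is not part of the DEA scripts.

```python
def query_box_bounds(lat, lon, buff):
    """Return ((lat_min, lat_max), (lon_min, lon_max)) for a square query box.

    Hypothetical helper: mirrors y=(lat-buff, lat+buff), x=(lon-buff, lon+buff).
    """
    return (lat - buff, lat + buff), (lon - buff, lon + buff)


y_range, x_range = query_box_bounds(-34.958, 150.281, 0.35)
print(y_range, x_range)  # roughly (-35.308, -34.608) (149.931, 150.631)
```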

```python
prediction_year = '2020'
prediction_quarter = 'MAM'

shp_fpath = 'data/mdb_shps/GWYDIR RIVER.shp'
lat, lon, buff = -34.958, 150.281, 0.35
resolution = (-120, 120)
dask_chunks = {'x': 1000, 'y': 1000}
```


### Examine your area of interest

```python
# Uncomment to map the shapefile AOI
# map_shapefile(gpd.read_file(shp_fpath), attribute='DNAME')
```

```python
# Uncomment to map the lat/lon query box
# display_map(y=(lat-buff, lat + buff), x=(lon-buff, lon + buff))
```



## Calculate time series of NDVI anomalies


The cell below calculates an anomaly for every overlapping three-month season in the three years before `prediction_year`. Each year therefore contributes 12 anomaly images, one per season.

```python
import warnings
warnings.filterwarnings("ignore")

# Define the 3-month seasons and the months each one spans
quarter = {'JFM': [1, 2, 3],
           'FMA': [2, 3, 4],
           'MAM': [3, 4, 5],
           'AMJ': [4, 5, 6],
           'MJJ': [5, 6, 7],
           'JJA': [6, 7, 8],
           'JAS': [7, 8, 9],
           'ASO': [8, 9, 10],
           'SON': [9, 10, 11],
           'OND': [10, 11, 12],
           'NDJ': [11, 12, 1],
           'DJF': [12, 1, 2]}

# Years to calculate: the three years before the prediction year
years_range = int(prediction_year) - 3, int(prediction_year)
years = [str(i) for i in range(years_range[0], years_range[1])]

# Loop through each 3-month season and calculate its anomaly
z = []
for year in years:
    for q in quarter:
        print(year, q, end="\r")  # progress indicator, overwritten in place
        with HiddenPrints():
            anomalies = calculate_anomalies(shp_fpath=shp_fpath,
                                            query_box=(lat, lon, buff),
                                            resolution=resolution,
                                            year=year,
                                            season=q,
                                            dask_chunks=dask_chunks).compute()

        z.append(anomalies.rename(year + '_' + q))
```
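The loop produces 3 years × 12 seasons = 36 anomaly images. The sketch below labels them by each season's middle month (JFM → February, ..., DJF → January of the following year). This convention is inferred from the February start date in the notebook's time index; it is not documented for `calculate_anomalies`. Generating the index with `periods=len(labels)` guarantees one timestamp per image. By contrast, a fixed range from February 2017 to `'1/2020'` with month-end frequency yields only 35 dates.

```python
import pandas as pd

seasons = ['JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA',
           'JAS', 'ASO', 'SON', 'OND', 'NDJ', 'DJF']
years = ['2017', '2018', '2019']

# One label per (year, season) pair, in the same order as the anomaly loop
labels = [f'{year}_{season}' for year in years for season in seasons]

# Monthly index starting at February (middle month of JFM) of the first year
time = pd.date_range(start=f'{years[0]}-02', periods=len(labels), freq='MS', name='time')

print(len(labels), time[0].date(), time[-1].date())  # 36 2017-02-01 2020-01-01
```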

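The anomaly images are then stacked into a monthly time series, with each season labelled by its middle month.

```python
# Build back into a time series: one monthly label per season, starting at
# February (the middle month of JFM) of the first year
time = pd.date_range(start=f'{years_range[0]}-02', periods=len(z), freq='MS', name='time')
stand_anomalies = xr.concat(z, dim=time)

# Plot the AOI-average anomaly through time
stand_anomalies.mean(['x', 'y']).plot(figsize=(11, 5))
```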

## Make a forecast

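Statsmodels' `AutoReg` does not handle series containing NaNs, such as the all-NaN pixels outside the AOI. The cells below first record a `mask` of pixels that have a valid anomaly at every time step, then fill NaNs with a placeholder value of `-999`. After prediction, `mask` is used to set the masked pixels back to NaN.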

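```python
# Pixels with a valid anomaly at every time step
mask = stand_anomalies.notnull().all('time')
```

```python
# Fill NaNs with a placeholder so AutoReg can run; `mask` restores them later
stand_anomalies = stand_anomalies.fillna(-999)
```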

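```python
test_length = 1  # number of final time steps held out and predicted
window = 20      # past values used in each walk-forward prediction (must equal lags)
lags = 20        # autoregressive order of the model
```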
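These settings are worth sanity-checking. An AR(p) model with a constant, fitted to n training values, has n − p usable observations to estimate p + 1 coefficients. With 36 monthly anomalies, `test_length = 1` and `lags = 20`, that leaves at most 15 observations to estimate 21 coefficients, so the fit is under-determined. A smaller `lags` may be worth trying. The helper name `ar_fit_size` and the `lags=6` comparison below are hypothetical illustrations.

```python
def ar_fit_size(n_series, test_length, lags):
    """Return (usable_observations, coefficients) for an AR(lags) fit with a constant."""
    n_train = n_series - test_length  # values left for training
    usable = n_train - lags           # each regression row needs `lags` earlier values
    n_coef = lags + 1                 # constant + one coefficient per lag
    return usable, n_coef


print(ar_fit_size(36, test_length=1, lags=20))  # (15, 21)
print(ar_fit_size(36, test_length=1, lags=6))   # (29, 7)
```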

```python
%%time
def xr_autoregress(da, test_length, window, lags):
    """Fit an AR(lags) model to one pixel's time series and predict the
    final `test_length` time steps by walk-forward validation."""
    if window != lags:
        raise ValueError("window must equal lags for the prediction loop")
    # Drop any NaNs (da is a 1-D NumPy array for a single pixel)
    da = da[~np.isnan(da)]
    # Split into training data and the held-out test period
    train, test = da[:len(da) - test_length], da[len(da) - test_length:]
    # Train the autoregression; params are [const, lag 1, ..., lag p]
    model = AutoReg(train, lags=lags)
    model_fit = model.fit()
    coef = model_fit.params

    # Walk forward over the test time steps, starting from the last `window` training values
    history = list(train[len(train) - window:])

    predictions = list()
    for t in range(len(test)):
        length = len(history)
        lag = history[length - window:length]
        # One-step prediction: constant + sum of lag coefficients x past values
        yhat = coef[0]
        for d in range(window):
            yhat += coef[d + 1] * lag[window - d - 1]
        predictions.append(yhat)
        # Append the observed value so the next step uses real data
        history.append(test[t])

    return np.array(predictions).flatten()

predict = xr.apply_ufunc(xr_autoregress,
                         stand_anomalies,
                         kwargs={'test_length': test_length, 'window': window, 'lags': lags},
                         input_core_dims=[['time']],
                         output_core_dims=[['predictions']],
                         exclude_dims={'time'},
                         vectorize=True,
                         dask="parallelized",
                         dask_gufunc_kwargs={'output_sizes': {'predictions': test_length}},
                         output_dtypes=[stand_anomalies.dtype]).compute()

print(predict)
```
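A NumPy-only sketch makes the walk-forward step concrete. It fits AR(p) coefficients by ordinary least squares (conditional least squares, which is how statsmodels' `AutoReg` estimates its parameters) and makes the one-step prediction `yhat = c + sum(phi_k * y[t-k])` used in the loop above. The helpers `fit_ar` and `predict_next` are hypothetical and are not part of statsmodels. The synthetic series is noise-free, so the true coefficients should be recovered.

```python
import numpy as np


def fit_ar(y, lags):
    """Fit y[t] = c + phi_1*y[t-1] + ... + phi_p*y[t-p] by least squares.

    Returns [c, phi_1, ..., phi_p], the same ordering as AutoReg's params
    with a constant trend.
    """
    y = np.asarray(y, dtype=float)
    # Each row holds [1, y[t-1], ..., y[t-lags]] for t = lags .. len(y)-1
    rows = [np.r_[1.0, y[t - lags:t][::-1]] for t in range(lags, len(y))]
    X = np.vstack(rows)
    coef, *_ = np.linalg.lstsq(X, y[lags:], rcond=None)
    return coef


def predict_next(history, coef):
    """One-step prediction from the most recent len(coef) - 1 values."""
    lags = len(coef) - 1
    recent = np.asarray(history[-lags:], dtype=float)[::-1]  # y[t-1], y[t-2], ...
    return coef[0] + np.dot(coef[1:], recent)


# Noise-free AR(1) series: y[t] = 0.5 + 0.8 * y[t-1]
y = [0.0]
for _ in range(30):
    y.append(0.5 + 0.8 * y[-1])

coef = fit_ar(y[:-1], lags=1)              # hold out the final value as the test step
print(coef)                                # approximately [0.5, 0.8]
print(predict_next(y[:-1], coef), y[-1])   # prediction vs held-out observation
```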

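```python
# Set pixels outside the mask back to NaN and drop the length-1 'predictions' dimension
predict = predict.where(mask).isel(predictions=0)
```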

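The model is trained on all but the final `test_length` months. The prediction is therefore a one-step-ahead prediction of the last month in the time series, which can be compared with the observed anomaly for that month.

```python
predict.plot(size=6, vmin=-2.0, vmax=2, cmap='BrBG')
plt.title('Standardised NDVI Anomaly one-month prediction');
```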

```python
# Restore NaNs outside the mask, then plot the observed anomaly for the final month
stand_anomalies = stand_anomalies.where(mask)
stand_anomalies.isel(time=-1).plot(size=6, vmin=-2.0, vmax=2, cmap='BrBG')
plt.title('Standardised NDVI Anomaly observation');
```

```python
# Prediction minus observation for the final month
diff = predict - stand_anomalies.isel(time=-1)

diff.plot(size=6, vmin=-2.0, vmax=2, cmap='RdBu')
plt.title('Difference');
```

```python
# Mean difference (bias) between prediction and observation across the AOI
diff.mean(['x', 'y'])
```
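The mean difference measures bias, but positive and negative errors can cancel out. Mean absolute error (MAE) and root-mean-square error (RMSE) summarise the typical size of the error regardless of sign. The sketch below computes all three while ignoring masked (NaN) pixels. The helper name `error_summary` and the sample arrays are hypothetical.

```python
import numpy as np


def error_summary(predicted, observed):
    """Bias, MAE and RMSE between two arrays, ignoring NaN pixels."""
    diff = np.asarray(predicted, dtype=float) - np.asarray(observed, dtype=float)
    return {
        'bias': np.nanmean(diff),              # mean of prediction - observation
        'mae': np.nanmean(np.abs(diff)),       # mean absolute error
        'rmse': np.sqrt(np.nanmean(diff**2)),  # root-mean-square error
    }


# Hypothetical 2 x 2 prediction and observation, with one masked pixel
pred = np.array([[0.5, -1.0], [np.nan, 0.2]])
obs = np.array([[0.0, -0.5], [1.0, 0.2]])
print(error_summary(pred, obs))
```

Here the bias is zero even though two of the three valid pixels are off by 0.5, which is why MAE and RMSE are useful alongside it. With xarray, the equivalents are `abs(diff).mean(['x', 'y'])` and `np.sqrt((diff**2).mean(['x', 'y']))`, since xarray skips NaNs in `mean` by default for float data.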