# Extract covariable timeseries over the ecoregions

The results are saved to disk because this process is time-consuming.
Run this notebook on a `hugemem` queue.


***

**To Do:**
* Consider sourcing climate data from somewhere other than ANUClim, perhaps ERA5 or AGCD, or TerraClimate
* Could get some of the climate variables from AusEFlux (SILO, OzWald)
* Get burned area from here: https://data.cci.ceda.ac.uk/thredds/catalog/esacci/fire/data/burned_area/AVHRR-LTDR/grid/v1.1/catalog.html

In [None]:
%matplotlib inline
import os
import sys
import math
import pickle
import warnings
import dask
import pandas as pd
import xarray as xr
import rioxarray as rxr
import geopandas as gpd
import numpy as np

from odc.geo.xr import assign_crs
from odc.geo.geom import Geometry

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _feature_datasets import _vegetation_fractions
from _utils import start_local_dask

In [None]:
# Helper imported from the AusEFlux `_utils` module above.
# NOTE(review): presumably starts a local Dask cluster for the lazy/delayed work below - confirm in `_utils`.
start_local_dask()

## Analysis Parameters


In [None]:
# The three settings below must be switched together: each commented-out line is
# the matching choice for a different polygon layer, and the active (uncommented)
# line in each group is the one this run uses.

# Output pickle that receives the {region name: zonal timeseries} dictionary.
# save_file = '/g/data/os22/chad_tmp/Aus_phenology/data/ecoregions_NDVI_timeseries.pkl'
# save_file = '/g/data/os22/chad_tmp/Aus_phenology/data/IBRA_regions_NDVI_timeseries.pkl'
save_file = '/g/data/os22/chad_tmp/Aus_phenology/data/IBRA_subregions_climate_timeseries.pkl'

# Vector file holding the polygons to summarise over.
# ecoregions_file = '/g/data/os22/chad_tmp/Aus_phenology/data/vectors/Ecoregions2017_modified.geojson'
# ecoregions_file = '/g/data/os22/chad_tmp/Aus_phenology/data/vectors/IBRAv7_regions_modified.geojson'
ecoregions_file = '/g/data/os22/chad_tmp/Aus_phenology/data/vectors/IBRAv7_subregions_modified.geojson'

# Attribute column in the vector file whose values become the keys of the results dictionary.
# var='ECO_NAME'
# var='REG_NAME_7'
var='SUB_NAME_7'

## Load climate data

In [None]:
# Directory holding the monthly 5 km climate grids.
base_clim = '/g/data/os22/chad_tmp/AusENDVI/data/5km/'

# CO2 is opened as-is (no `spatial_ref` variable is dropped from it).
co2 = xr.open_dataset(base_clim + 'CO2_5km_monthly_1982_2022.nc')

# The rainfall file starts in 1981, so trim it to the 1982-2022 window
# before dropping `spatial_ref`.
rain = xr.open_dataset(base_clim + 'rain_5km_monthly_1981_2022.nc')
rain = rain.sel(time=slice('1982', '2022'))
rain = rain.drop_vars('spatial_ref')

# Remaining variables only need `spatial_ref` removed.
srad = xr.open_dataset(base_clim + 'srad_5km_monthly_1982_2022.nc')
srad = srad.drop_vars('spatial_ref')

tavg = xr.open_dataset(base_clim + 'tavg_5km_monthly_1982_2022.nc')
tavg = tavg.drop_vars('spatial_ref')

vpd = xr.open_dataset(base_clim + 'vpd_5km_monthly_1982_2022.nc')
vpd = vpd.drop_vars('spatial_ref')

## Calculate tree fractions

The tree (persistent vegetation) fraction will be our measure of woody encroachment.

In [None]:
# Output directory (only referenced by the optional export at the end of this cell).
results = '/g/data/os22/chad_tmp/Aus_phenology/data/'
ndvi_path = '/g/data/os22/chad_tmp/AusENDVI/results/publication/AusENDVI-clim_MCD43A4_gapfilled_1982_2022.nc'
# Kept separate from `ndvi_min` (the opened array) so the name is not rebound
# from a path string to a DataArray.
ndvi_min_path = '/g/data/os22/chad_tmp/AusEFlux/data/ndvi_of_baresoil_5km.nc'
# NDVI value that maps to a total vegetation fraction of 1 in the scaling below.
ndvi_max = 0.91
dask_chunks = {'latitude': 250, 'longitude': 250, 'time': -1}

# NDVI value of bare soil (supplied by Luigi Renzullo).
# It is chunked with the same spatial chunk sizes as the NDVI stack.
ndvi_min = xr.open_dataarray(
    ndvi_min_path,
    chunks={dim: dask_chunks[dim] for dim in ('latitude', 'longitude')},
)

# NDVI data is here
ds = xr.open_dataset(ndvi_path, chunks=dask_chunks)
ds = ds.rename({'AusENDVI_clim_MCD43A4': 'NDVI'})
ds = ds['NDVI']

# Calculate f-total: NDVI rescaled between the bare-soil value and `ndvi_max`,
# then clamped to [0, 1].
ft = (ds - ndvi_min) / (ndvi_max - ndvi_min)
ft = xr.where(ft < 0, 0, ft)
ft = xr.where(ft > 1, 1, ft)

# Calculate initial persistent fraction (equation 1 & 2 in Donohue 2009)
persist = ft.rolling(time=7, min_periods=1).min()
persist = persist.rolling(time=9, min_periods=1).mean()

# Calculate initial recurrent fraction (equation 3 in Donohue 2009)
recurrent = ft - persist

###------- equations 4 & 5 in Donohue 2009----------------
persist = xr.where(recurrent < 0, persist - np.abs(recurrent), persist)  # eq 4
recurrent = ft - persist  # eq 5
## ---------------------------------------------------------

# Ensure values are not negative
persist = xr.where(persist < 0, 0, persist)
recurrent = xr.where(recurrent < 0, 0, recurrent)

# Assign variable names
recurrent.name = 'grass'
persist.name = 'trees'

# Aggregate to annual layers using the maximum tree fraction in each year.
# Only the tree fraction is merged into the covariables further down, so the
# grass and bare-soil layers are no longer computed and expanded: `recurrent`
# stays a lazy array and costs nothing unless someone asks for it.
# NOTE(review): recent pandas releases may warn about the 'Y' frequency alias
# (renamed 'YE') - confirm against the installed pandas/xarray versions.
persist_annual = persist.resample(time='1Y').max().compute()


def _mid_month_stamps(year):
    """Return twelve timestamps for *year*: the first day of each month plus 14 days."""
    month_starts = pd.date_range(f'{year}-01', f'{year}-12', freq='MS')
    return [start + pd.Timedelta(14, 'd') for start in month_starts]


# Create a monthly timeseries (same value for each month within a year)
dss_trees = []
for y in persist_annual.time.dt.year.values:
    y = str(y)
    # `drop_vars` replaces the deprecated `.drop`, matching the climate-loading cell.
    annual_layer = persist_annual.sel(time=y).squeeze().drop_vars('time')
    dss_trees.append(annual_layer.expand_dims(time=_mid_month_stamps(y)))

# Join all the years back together
trees = xr.concat(dss_trees, dim='time').sortby('time')

# Add right metadata
trees.attrs['nodata'] = np.nan
trees = assign_crs(trees, crs='EPSG:4326')
# Export
# trees.to_netcdf(results+'trees_5km.nc')

## Merge all variables

In [None]:
# Combine the climate variables and the tree fraction into a single dataset,
# tag it with a CRS, fix the dimension order and restrict it to 1982-2022.
covars = xr.merge([co2, rain, srad, tavg, vpd, trees])
covars = assign_crs(covars, crs='EPSG:4326')
covars = covars.transpose('time', 'latitude', 'longitude')
covars = covars.sel(time=slice('1982', '2022'))

# Remove the per-variable `grid_mapping` attribute wherever it is present.
# `pop` with a default skips variables without the attribute, so no blanket
# `except` is needed and unrelated errors are no longer silently swallowed.
# NOTE(review): presumably removed so the attribute does not clash when the
# dataset is written to netCDF - confirm.
for v in covars.data_vars:
    covars[v].attrs.pop('grid_mapping', None)

In [None]:
# covars['vpd'].isel(time=1).plot()

In [None]:
# Save the merged covariables to disk for multiprocessing next: the zonal
# summary cell below reopens this same file through `dask.delayed`.
covars.to_netcdf('/g/data/os22/chad_tmp/Aus_phenology/data/covars.nc')

### Summarise covariables data over polygons

This step is slow, so Dask is used to process the polygons in parallel.

In [None]:
# Read the polygons eagerly; this copy drives the loop that builds one delayed
# task per polygon and supplies the names used as result keys.
gdf = gpd.read_file(ecoregions_file)

In [None]:
# Wrapped with dask.delayed so that calling it only records a task.
@dask.delayed
def clim_zonal_timeseries(index, ds, gdf, var):
    """Average every covariable over one polygon, giving a 1-D timeseries.

    Parameters
    ----------
    index : int
        Position of the polygon in ``gdf`` (looked up with ``gdf.iloc``).
    ds : xarray.Dataset
        Covariables on ``latitude``/``longitude``/``time`` dimensions; must
        contain a ``rain`` variable, which is used for the no-data check.
    gdf : geopandas.GeoDataFrame
        Polygons; its CRS is attached to the selected geometry.
    var : str
        Not used inside the function; kept so the call signature is unchanged.

    Returns
    -------
    xarray.Dataset or float
        The spatial mean over the polygon for each time step, or ``np.nan``
        when ``rain`` is NaN at every time step.
    """
    ds = assign_crs(ds, crs='EPSG:4326')

    # Mask the dataset with the polygon, then trim rows/columns that are
    # entirely NaN.
    polygon = Geometry(geom=gdf.iloc[index].geometry, crs=gdf.crs)
    masked = ds.odc.mask(poly=polygon)
    masked = masked.dropna(dim='longitude', how='all')
    masked = masked.dropna(dim='latitude', how='all')

    # Collapse the spatial dimensions into a single timeseries.
    series = masked.mean(['latitude', 'longitude'])

    # A polygon with no valid rainfall at any time step is flagged with NaN.
    if np.isnan(series['rain']).sum() == len(series.time):
        return np.nan

    return series

# Open the inputs lazily so they are loaded as part of the Dask graph rather
# than being passed in from this session.
path = '/g/data/os22/chad_tmp/Aus_phenology/data/covars.nc'
dss = dask.delayed(xr.open_dataset)(path)
gdff = dask.delayed(gpd.read_file)(ecoregions_file)

# Lazily build one task per polygon, keyed by the polygon's `var` value.
# `enumerate` supplies the row *position*, which is what `gdf.iloc[...]` inside
# `clim_zonal_timeseries` expects; index labels from `iterrows()` only matched
# positions while the index was a default RangeIndex.
# NOTE(review): polygons that share the same `var` value overwrite one another
# in this dictionary - confirm the names are unique.
results_clim = {}
for position, name in enumerate(gdf[var]):
    results_clim[name] = clim_zonal_timeseries(position, dss, gdff, var)

In [None]:
%%time
# Execute all delayed tasks and bring the results into memory.
results_clim = dask.compute(results_clim)[0]

# Remove NaNs: `clim_zonal_timeseries` returns the float `np.nan` for polygons
# with no valid data, so any float entry is discarded.
results_clim = {
    name: series
    for name, series in results_clim.items()
    if not isinstance(series, float)
}

## Save data

In [None]:
# Serialise the {region name: timeseries} dictionary to the configured pickle file.
with open(save_file, 'wb') as out_handle:
    pickle.dump(results_clim, out_handle)