# Annual integrated NDVI

Rather than integrating NDVI over each growing season (from the start of season, SOS, to the end of season, EOS), calculate the integrated NDVI for each calendar year and then compute the trends. The soil signal also needs to be subtracted before integrating.

In [None]:
%matplotlib inline
import os
import sys
import pingouin as pg
import xarray as xr
import numpy as np
import pandas as pd
import geopandas as gpd
import contextily as ctx
import matplotlib.pyplot as plt
from xarrayMannKendall import Mann_Kendall_test
from odc.geo.xr import assign_crs

import warnings
# Suppress all warnings for the remainder of the session.
warnings.simplefilter(action='ignore')

In [None]:
# Put the AusEFlux source directory on the import path so the project helper
# module `_utils` can be imported.
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import start_local_dask

# start_local_dask is a project helper (presumably starts a local Dask
# cluster -- confirm in _utils); 12 single-threaded workers are requested.
start_local_dask(
    n_workers=12,
    threads_per_worker=1,
    memory_limit='300GiB',
)

#### Open NDVI data and subtract soil signal

The NDVI data have previously been interpolated to a biweekly time step and smoothed with a Savitzky-Golay (S-G) filter.

In [None]:
# Chunk in 250 x 250 spatial tiles; time=-1 keeps the whole time axis in one chunk.
dask = {'latitude': 250, 'longitude': 250, 'time': -1}
path = '/g/data/os22/chad_tmp/Aus_phenology/data/NDVI/NDVI_smooth_AusENDVI-clim_MCD43A4.nc'
ds = xr.open_dataset(path, chunks=dask)['NDVI']

# Subtract the bare-soil NDVI value so only the signal above the soil
# background remains, then restore the variable name.
ndvi_min_path = '/g/data/xc0/project/AusEFlux/data/ndvi_of_baresoil_5km.nc'
ndvi_min = xr.open_dataarray(ndvi_min_path, chunks=dask)
ds = (ds - ndvi_min).rename('NDVI')
ds

## Annual iNDVI

The integration is computed on daily NDVI, so this step needs the hugemem queue.

In [None]:
# Restrict to 1982-2021 and replace missing values with zero.
ds = ds.sel(time=slice('1982', '2021'))
ds = ds.fillna(0)

# Upsample to a daily time step using quadratic interpolation.
ds = ds.resample(time='1D').interpolate(kind='quadratic')

# Integrate NDVI over time (in units of days) separately for each calendar
# year, then load the result into memory.
indvi = (
    ds.groupby('time.year')
    .map(lambda yearly: yearly.integrate('time', datetime_unit='D'))
    .compute()
)

### mask urban and water

In [None]:
# Load the urban/water mask and rename its x/y dimensions to match the NDVI grid.
mask = xr.open_dataarray('/g/data/xc0/project/AusEFlux/data/urban_water_mask_5km.nc')
mask = mask.rename({'x': 'longitude', 'y': 'latitude'})

# Cells where the mask equals 1 are replaced with NaN.
indvi = indvi.where(mask != 1)

In [None]:
# Persist the masked annual iNDVI so the trend cells can reload it from disk.
indvi.to_netcdf('/g/data/os22/chad_tmp/Aus_phenology/results/iNDVI.nc')

### Trends

In [None]:
# Reload the saved annual iNDVI and the urban/water mask from disk.
# NOTE(review): the mask is opened here without the x/y -> longitude/latitude
# rename applied in the other masking cells -- confirm its dimension names
# match the object it is later compared against.
indvi = xr.open_dataarray('/g/data/os22/chad_tmp/Aus_phenology/results/iNDVI.nc')
mask = xr.open_dataarray('/g/data/xc0/project/AusEFlux/data/urban_water_mask_5km.nc')

In [None]:
# Mann-Kendall test on annual iNDVI with method='theilslopes' at alpha=0.05;
# cells where the mask equals 1 are set to NaN afterwards.
trends = Mann_Kendall_test(
    indvi,
    alpha=0.05,
    method='theilslopes',
    coords_name={'time': 'year', 'x': 'longitude', 'y': 'latitude'},
).compute().where(mask != 1)

In [None]:
# Single-panel map of the iNDVI trend.
fig, ax = plt.subplots(1, 1, figsize=(6, 6), sharey=True, layout='constrained')

# Trend values with exact zeros blanked out; the colourbar is added manually below.
im = trends.trend.where(trends.trend != 0).plot(
    ax=ax, cmap='PuOr', vmin=-1, vmax=1, add_colorbar=False
)
ctx.add_basemap(
    ax,
    source=ctx.providers.CartoDB.VoyagerNoLabels,
    crs='EPSG:4326',
    attribution='',
    attribution_size=1,
)

# Fully transparent filled contours that only contribute a dot hatch where
# `signif` is truthy.
trends.trend.where(trends.signif).plot.contourf(
    ax=ax, alpha=0, hatches=['....'], add_colorbar=False
)

# Strip the title, axis labels and tick labels.
ax.set_title(None)
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xticklabels([])
ax.set_yticklabels([])

cb = fig.colorbar(
    im,
    ax=ax,
    shrink=0.65,
    orientation='vertical',
    label='NDVI yr\u207B\u00B9',
)
fig.savefig(
    "/g/data/os22/chad_tmp/Aus_phenology/results/figs/iNDVI_trends.png",
    bbox_inches='tight',
    dpi=300,
)

## Annual max NDVI

In [None]:
# Maximum NDVI within each calendar year, loaded into memory.
# NOTE(review): if the cells are run in order, `ds` here is the zero-filled,
# daily-interpolated series from the iNDVI section -- confirm that is intended.
annual_max = ds.groupby('time.year').max().compute()

In [None]:
# Load the urban/water mask and rename its x/y dimensions to match the NDVI grid.
mask = xr.open_dataarray('/g/data/xc0/project/AusEFlux/data/urban_water_mask_5km.nc').rename({'x':'longitude', 'y':'latitude'})

# Set cells where the mask equals 1 to NaN, using the same `mask != 1` test as
# the iNDVI cells. `~mask` is a bitwise NOT: it only acts as a logical
# negation on boolean data, masks nothing on integer data and raises on
# floats. Non-positive annual maxima are discarded as well.
annual_max = annual_max.where(mask != 1).where(annual_max > 0)

### Trends

In [None]:
# Mann-Kendall test on the annual NDVI maximum with method='theilslopes'
# at alpha=0.05.
trends_vpos = Mann_Kendall_test(
    annual_max,
    alpha=0.05,
    method='theilslopes',
    coords_name={'time': 'year', 'x': 'longitude', 'y': 'latitude'},
).compute()

In [None]:
# Display the annual-max trend result for inspection.
trends_vpos

In [None]:
# Time-mean of the 'trees' variable, with dimensions renamed to x/y.
trees = xr.open_dataset('/g/data/os22/chad_tmp/AusENDVI/data/5km/trees_5km_monthly_1982_2022.nc')['trees']
# The mapping passed to rename must be a dict literal; the opening brace was
# missing, which made this cell a syntax error.
trees = trees.mean('time').rename({'longitude': 'x', 'latitude': 'y'})

# Binary mask: 1 where the mean 'trees' value is at least 0.5, otherwise 0.
tree_mask = xr.where(trees >= 0.5, 1, 0)

#and significant change in vPOS
# tree_mask = ((tree_mask) & (trends_vpos.p <=0.05))

In [None]:
# Quick-look plot of the tree mask.
tree_mask.plot()

## Partial correlation

In [None]:
import os
import xarray as xr
import numpy as np
import pandas as pd
from odc.geo.xr import assign_crs

import sys
# Put the AusEFlux source directory on the import path so the project helper
# module `_utils` can be imported.
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import start_local_dask

# start_local_dask is a project helper (presumably starts a local Dask
# cluster -- confirm in _utils); 15 single-threaded workers are requested.
start_local_dask(
    n_workers=15,
    threads_per_worker=1,
    memory_limit='120GiB',
)

In [None]:
# Aggregate each OzWALD 8-day NDVI file to monthly means and write one
# float32 NetCDF file per year.
base = '/g/data/ub8/au/OzWALD/8day/NDVI/'
files = sorted(f'{base}{i}' for i in os.listdir(base) if i.endswith(".nc"))

# NOTE(review): files are paired with years purely by sorted position. This
# assumes one .nc file per year starting at 2000 -- confirm against the
# directory listing, otherwise outputs are written under the wrong year.
years = [str(i) for i in range(2000, 2024)]
for f, y in zip(files, years):
    print(f)
    ds = xr.open_dataset(f, chunks=dict(time=-1, latitude=1000, longitude=1000))

    # tidy up: attach the CRS, collapse the Dataset into a single DataArray
    # and remove the leftover 'variable' coordinate.
    ds = assign_crs(ds, crs='epsg:4326')
    ds = ds.to_array()
    # drop_vars replaces the deprecated DataArray.drop
    ds = ds.squeeze().drop_vars('variable')
    ds.attrs['nodata'] = np.nan

    # resample time: monthly means labelled at month start. The `loffset`
    # argument of resample is deprecated/removed, so the 14-day label shift
    # is applied to the time index explicitly instead.
    ds = ds.resample(time='MS').mean().compute()
    ds = ds.assign_coords(time=ds.get_index('time') + pd.Timedelta(14, 'd'))
    # set again after the reduction (attrs may not survive `.mean()`)
    ds.attrs['nodata'] = np.nan
    ds = ds.transpose('time', 'latitude', 'longitude')
    ds.astype('float32').to_netcdf(f'/g/data/os22/chad_tmp/Aus_phenology/data/NDVI/ozwald_ndvi/NDVI_{y}.nc')


In [None]:
# Collect the per-year monthly NDVI files in sorted order.
base = '/g/data/os22/chad_tmp/Aus_phenology/data/NDVI/ozwald_ndvi/'
files = sorted(f'{base}/{i}' for i in os.listdir(base) if i.endswith(".nc"))
# chunks=dict(time=1, latitude=1000, longitude=1000)
# Combine the annual files into one dataset and keep 2001-2022.
ds = xr.open_mfdataset(files)
ds = ds.sel(time=slice('2001', '2022'))

# Put the whole time axis in one chunk, then linearly interpolate gaps
# along time (limit=2).
ds = ds.chunk({'time': -1})
ds = ds.interpolate_na(dim='time', method='linear', limit=2)

# ds_monthly = ds.groupby('time.month').mean()
# ds = ds.groupby("time.month").fillna(ds_monthly)
# ds = assign_crs(ds, crs='epsg:4326')



In [None]:
# Evaluate the lazily-built dataset and load it into memory.
ds = ds.compute()

In [None]:
# Write the combined, gap-filled NDVI dataset to a single NetCDF file.
ds.to_netcdf('/g/data/os22/chad_tmp/Aus_phenology/data/NDVI/NDVI_OzWALD_500m.nc')

In [None]:
# v.plot(figsize=(14,5))

In [None]:
# with plt.style.context('ggplot'):
#     fig,axes = plt.subplots(5,2, figsize=(20,15), layout='constrained')
#     for ax, (k,v) in zip(axes.reshape(-1), flux_tss.items()):
#         v.plot(ax=ax, c='tab:blue', label='OzFlux')
#         ndvi_tss[k].plot(ax=ax, c='tab:red', label='AusEFlux')
#         ax.set_title(k)
#         ax.grid(axis='y', which='both')
#         ax.set_xlabel(None)
#         ax.set_ylabel('GPP gC/m2/month')
#         ax.legend()
#         # ax.set_ylim(0.10,0.9)
#         # ax1.set_ylim(-1,350)
    
        
#         # ax.scatter(x=[pd.to_datetime(d-1, unit='D', origin=str(int(y))) for d,y in zip(ndvi_pheno[k].SOS.values, ndvi_pheno[k].SOS_year.values)],
#         #            y=ndvi_pheno[k].vSOS,
#         #           c='tab:green', label='SOS', zorder=10)
        
#         # ax.scatter(x=[pd.to_datetime(d-1, unit='D', origin=str(int(y))) for d,y in zip(ndvi_pheno[k].EOS.values, ndvi_pheno[k].EOS_year.values)],
#         #            y=ndvi_pheno[k].vEOS,
#         #           c='tab:purple', label='EOS', zorder=10)
        
#         # ax.scatter(x=[pd.to_datetime(d-1, unit='D', origin=str(int(y))) for d,y in zip(ndvi_pheno[k].POS.values, ndvi_pheno[k].POS_year.values)],
#         #                y=ndvi_pheno[k].vPOS,
#         #               c='black', label='POS', zorder=10)
            
#         # ax.scatter(x=[pd.to_datetime(d-1, unit='D', origin=str(int(y))) for d,y in zip(ndvi_pheno[k].TOS.values, ndvi_pheno[k].TOS_year.values)],
#         #            y=ndvi_pheno[k].vTOS,
#         #           c='tab:orange', label='TOS', zorder=10)
# fig.savefig('/g/data/os22/chad_tmp/Aus_phenology/results/figs/flux_tower_validate_GPP.png',
#             bbox_inches='tight', dpi=300)

In [None]:
    # # Index NDVI at location and time so we have matching time series
    # lat,lon = v.latitude, v.longitude
    # ndvi = ds.sel(latitude=lat, longitude=lon, method='nearest')
    
    # #smooth
    # ndvi = ndvi.resample(time="2W").interpolate("linear")
    # v = v.sel(time=ndvi.time, method='nearest')
    # ndvi=sg_smooth(ndvi, window=11, poly=3, deriv=0)
    # v=sg_smooth(v, window=11, poly=3, deriv=0)

    # #interpolate
    # v = v.drop_duplicates(dim='time')
    # ndvi = ndvi.dropna(dim='time',
    #         how='all').resample(time='1D').interpolate(kind='quadratic')
    # v = v.dropna(dim='time',
    #         how='all').resample(time='1D').interpolate(kind='quadratic')

    # # same length of time for both ds
    # ndvi = ndvi.sel(time=v.time, method='nearest')
    # v = v.sel(time=ndvi.time, method='nearest')