# Create input datasets for the GAM

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from scipy import stats
import geopandas as gpd
from odc.geo.xr import assign_crs
import matplotlib.pyplot as plt

import sys
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _collect_prediction_data import round_coords

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Run configuration: region label and the variable being modelled.
# NOTE: `name` is assigned again in the "Mask" section further down, before it
# is used to choose a mask or to build the output file names.
name = 'nontrees'
model_var = 'NDVI'

In [None]:
# Open the AVHRR NDVI file and attach an EPSG:4326 CRS to it.
avhrr_path = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/AVHRR_NDVI_5km_monthly_1982_2013.nc'
ds = assign_crs(xr.open_dataset(avhrr_path), crs='epsg:4326')

In [None]:
# Per-pixel fraction of time steps holding a non-NaN NDVI value, computed
# before any filtering so it can be compared with the post-filter fraction.
valid_before = ~np.isnan(ds['NDVI_avhrr'])
before_fraction_avail = valid_before.sum('time') / len(ds.time)

In [None]:
# Filtering on the number of observations per month is currently switched off:
# ds = ds.where(ds['n_obs']>=2)

# Keep only samples whose NDVI is at least 0.01; everything else (including
# NaN, which fails the comparison) becomes NaN in every variable.
ndvi_not_too_low = ds['NDVI_avhrr'] >= 0.01
ds = ds.where(ndvi_not_too_low)

# Coefficient of variation (std dev / NDVI); keep samples where it is below 0.5.
ds['ndvi_cv'] = ds['NDVI_stddev'] / ds['NDVI_avhrr']
cv_is_small = ds['ndvi_cv'] < 0.5
ds = ds.where(cv_is_small)

#filter by large std dev anomalies
def stand_anomalies(ds, clim_mean, clim_std):
    """
    Standardise `ds` against a per-month mean and standard deviation.

    `ds` is computed (loaded) and grouped by the calendar month of its `time`
    coordinate; `(x - clim_mean) / clim_std` is then applied to the grouped
    data through `xr.apply_ufunc`.

    Parameters
    ----------
    ds
        Object providing `.compute()` and `.groupby("time.month")`.
    clim_mean
        Mean subtracted from each monthly group.
    clim_std
        Standard deviation each monthly group is divided by. No guard against
        zero is applied here.

    Returns
    -------
    The result of `xr.apply_ufunc`, i.e. the standardised anomalies.
    """
    def _standardise(values, mean, std):
        return (values - mean) / std

    grouped_by_month = ds.compute().groupby("time.month")
    return xr.apply_ufunc(_standardise, grouped_by_month, clim_mean, clim_std)

# Monthly climatology (mean and std dev) of the filtered data, and the
# standardised anomalies relative to it.
ndvi_clim = ds.groupby('time.month').mean()
ndvi_clim_std = ds.groupby('time.month').std()
ndvi_std_anom = stand_anomalies(ds, ndvi_clim, ndvi_clim_std)


def _inside_4_std(std_anom):
    """Return 1 where -4 < std_anom < 4 and 0 elsewhere (NaN maps to 0)."""
    return xr.where((std_anom > -4) & (std_anom < 4), 1, 0)


# One mask per variable: NDVI, 'SZEN_median' and 'TIMEOFDAY_median'.
ndvi_anom_mask = _inside_4_std(ndvi_std_anom['NDVI_avhrr'])
sza_anom_mask = _inside_4_std(ndvi_std_anom['SZEN_median'])
tod_anom_mask = _inside_4_std(ndvi_std_anom['TIMEOFDAY_median'])

# A sample survives only if it passes all three masks.
for anom_mask in (ndvi_anom_mask, sza_anom_mask, tod_anom_mask):
    ds = ds.where(anom_mask)

### Plot available fraction of data before/after filtering

In [None]:
# Same availability metric as before filtering, now on the filtered data.
valid_after = ~np.isnan(ds['NDVI_avhrr'])
after_fraction_avail = valid_after.sum('time') / len(ds.time)

# Side-by-side maps (before on the left, after on the right); each panel is
# titled with the mean available fraction.
fig, ax = plt.subplots(1, 2, figsize=(11, 4))
for panel, fraction in zip(ax, (before_fraction_avail, after_fraction_avail)):
    fraction.plot.imshow(robust=True, ax=panel, cmap='magma')
    panel.set_title(str(fraction.mean().values))

### Add lagged NDVI features

In [None]:
# New variable name -> number of steps NDVI is shifted along `time`.
# A shift of +1 puts the previous time step's value at each position, a shift
# of -1 the following one. The two-step lags are left disabled.
lagged_features = {
    'NDVI_avhrr_1f': 1,
    'NDVI_avhrr_1b': -1,
    # 'NDVI_avhrr_2f': 2,
    # 'NDVI_avhrr_2b': -2,
}
for feature_name, n_steps in lagged_features.items():
    ds[feature_name] = ds['NDVI_avhrr'].shift(time=n_steps)

In [None]:
# ds = ds.sel(time=slice(t1, t2))

# Remove the 'month' variable/coordinate from the dataset.
# NOTE(review): 'month' is presumably picked up from the month-grouped anomaly
# masks applied earlier — confirm.
# `drop_vars` replaces the deprecated `Dataset.drop` for removing variables.
ds = ds.drop_vars('month')

### Open covariables

In [None]:
# Root directory that the covariable paths below are joined onto.
base = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/'

# Covariable files to load, relative to `base`.
datasets = [
    'NDVI_harmonization/MODIS_NDVI_5km_monthly_200003_202212.nc',
    '5km/rain_cml3_5km_monthly_1982_2022.nc',
    '5km/srad_5km_monthly_1982_2022.nc',
]

In [None]:
# Load every covariable file, restricted to 1982-2013, and collect the
# datasets (`dss`) together with the name of their first data variable (`names`).
dss = []
names = []
for d in datasets:
    xx = xr.open_dataset(base+d)
    xx = assign_crs(xx, crs='epsg:4326').sel(time=slice('1982','2013'))
    # if "MODIS" in d:
    #     #xx = xx.odc.reproject(how=ds.odc.geobox)
    #     #xx = round_coords(xx)
    xx = round_coords(xx)
    # `drop_vars` replaces the deprecated `Dataset.drop` for removing variables.
    xx = xx.drop_vars('spatial_ref')
    names.append(list(xx.data_vars)[0])
    # Enforce a common dimension order before the datasets are merged.
    dss.append(xx.transpose('time', 'latitude', 'longitude'))

# Combine the covariables into one dataset, re-attach the CRS and give the
# MODIS NDVI variable an explicit name.
covars = assign_crs(xr.merge(dss), crs='epsg:4326')
covars = covars.rename({'NDVI_median': 'NDVI_modis'})

# Merge the AVHRR dataset with the covariables.
ds = xr.merge([ds, covars])

### Add some MODIS summary stats

In [None]:
# Time-invariant MODIS NDVI summaries (mean, 5th and 95th percentile over
# `time`), each broadcast back along `time` so it can be stored as a
# per-time-step variable of `ds`.
modis_ndvi = ds['NDVI_modis']

mean_modis = modis_ndvi.mean('time')
mean_modis = mean_modis.expand_dims(time=ds.time)
ds['NDVI_modis_mean'] = mean_modis

# `quantile` leaves a scalar 'quantile' coordinate behind; remove it with
# `drop_vars` (the non-deprecated replacement for `DataArray.drop`).
min_modis = modis_ndvi.quantile(0.05, dim='time').drop_vars('quantile')
min_modis = min_modis.expand_dims(time=ds.time)
ds['NDVI_modis_min'] = min_modis

max_modis = modis_ndvi.quantile(0.95, dim='time').drop_vars('quantile')
max_modis = max_modis.expand_dims(time=ds.time)
ds['NDVI_modis_max'] = max_modis

# Remove variables that are not exported.
ds = ds.drop_vars(['NDVI_stddev', 'n_obs', 'ndvi_cv', 'NDVI_modis'])

### Mask

In [None]:
# Region to mask and export. This replaces the value of `name` assigned at the
# top of the notebook.
name = 'trees'

In [None]:
# Build the region mask (1 = keep, 0 = discard) from the 2001-2018 time-mean
# of the 'WCF' variable.
# NOTE(review): WCF is presumably woody cover fraction — confirm.
trees = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/WCF_5km_monthly_1982_2022.nc')['WCF']
trees = assign_crs(trees, crs ='epsg:4326')
trees = trees.sel(time=slice('2001', '2018'))
trees = trees.mean('time')

if model_var == 'NDVI' and name == 'trees':
    mask = xr.where(trees > 0.25, 1, 0)
elif model_var == 'NDVI' and name == 'nontrees':
    mask = xr.where(trees <= 0.25, 1, 0)
elif model_var == 'LST' and name == 'AUS':
    # Every pixel with a non-NaN mean WCF >= 0 is treated as valid.
    mask = xr.where(trees >= 0, 1, 0)
else:
    # Fail here instead of leaving `mask` undefined, or silently reusing a
    # stale `mask` left over from an earlier notebook run.
    raise ValueError(
        f"No mask defined for model_var={model_var!r} and name={name!r}"
    )

In [None]:
# Keep values where the region mask is non-zero; everything else becomes NaN.
ds = ds.where(cond=mask)

### Export

In [None]:
# Remove the 'grid_mapping' attribute from every data variable before export.
# `pop` with a default skips variables that lack the attribute without a bare
# `except` that would also hide unrelated errors.
for i in ds.data_vars:
    ds[i].attrs.pop('grid_mapping', None)

In [None]:
# Write the masked AVHRR dataset with its extra features; the file name is
# prefixed with the region label.
avhrr_out_dir = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/regions/'
avhrr_out_file = name+'_AVHRR_NDVI_5km_monthly_1982_2013_extraFeatures.nc'
ds.to_netcdf(avhrr_out_dir+avhrr_out_file)

### MODIS clipped to regions

In [None]:
# Open the full MODIS NDVI file and attach an EPSG:4326 CRS to it.
mod_path = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/MODIS_NDVI_5km_monthly_200003_202212.nc'
mod = assign_crs(xr.open_dataset(mod_path), crs='epsg:4326')

In [None]:
# Apply the same region mask to MODIS: values where the mask is zero become NaN.
mod = mod.where(cond=mask)

In [None]:
# Write the region-masked MODIS dataset; the file name is prefixed with the
# region label.
modis_out_dir = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/regions/'
modis_out_file = name+'_MODIS_NDVI_5km_monthly_200003_202212.nc'
mod.to_netcdf(modis_out_dir+modis_out_file)

## Post-process GAM results

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from scipy import stats
import geopandas as gpd
from odc.geo.xr import assign_crs
import matplotlib.pyplot as plt

import sys
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _collect_prediction_data import round_coords

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Region label and modelled variable for the post-processing step; `name`
# selects which GAM result file is read and how the output file is named.
name = 'nontrees'
model_var = 'NDVI'

In [None]:
# Open the GAM output file for the selected region and attach an EPSG:4326 CRS.
path = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/regions/'+name+'_AVHRR_MODIS_NDVI_GAM_harmonized_climate_1982_2013.nc'
ds = assign_crs(xr.open_dataset(path), crs='epsg:4326')

In [None]:
# Keep only the 'ndvi_mcd_pred' variable and name the resulting array 'NDVI'.
gam_prediction = ds['ndvi_mcd_pred']
ds = gam_prediction.rename('NDVI')

In [None]:
# ds.sel(time='2001').plot.imshow(col='time', col_wrap=4, robust=True)

In [None]:
# Write the renamed NDVI array; the region label is embedded in the file name.
harmonised_out_dir = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/'
harmonised_out_file = 'NDVI_'+name+'_GAM_harmonize_5km_monthly_1982_2013.nc'
ds.to_netcdf(harmonised_out_dir+harmonised_out_file)