# Show correlations between rainfall and NDVI

This notebook confirms that Australia's vegetation is typically water-limited, i.e. that NDVI anomalies track rainfall anomalies.

In [None]:
import os
import sys
import warnings
import xarray as xr
import seaborn as sb
import xarray as xr
from scipy import stats
from scipy.stats import gaussian_kde
import xskillscore as xs
import numpy as np
import matplotlib.pyplot as plt
import contextily as ctx
import odc.geo.xr
from odc.geo.xr import assign_crs

warnings.simplefilter('ignore')
import sys
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import round_coords

%matplotlib inline

In [None]:
# import sys
# sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
# from _utils import start_local_dask

# client = start_local_dask(mem_safety_margin='2Gb')
# client

## Analysis Parameters

In [None]:
# Root directory of the harmonised input data (note: ends with '/').
base = '/g/data/os22/chad_tmp/AusENDVI/data/'

# Gap-filled AusENDVI-clim product; the MODIS-era record is clipped from it below.
ndvi_path = ('/g/data/os22/chad_tmp/AusENDVI/results/publication/'
             'AusENDVI-clim_MCD43A4_gapfilled_1982_2022.nc')

# Alternative rainfall products, kept for reference:
# rain_cml3_path = base+'5km/rain_cml3_5km_monthly_1982_2022.nc'
# rain_cml6_path = base+'5km/rain_cml6_5km_monthly_1982_2022.nc'

# (start, end) of the analysis period; alternative: ('2000-03', '2022')
t_range = ('1982', '2022')

# CRS label assigned to every dataset on load.
crs = 'EPSG:4326'

## Open data

In [None]:
# GIMMS-PKU-consolidated NDVI (AVHRR GIMMS-PKU joined with MODIS).
pku = xr.open_dataarray(f'{base}/NDVI_harmonization/AVHRR_GIMMS-PKU-MODIS_1982_2022.nc')
pku = pku.rename('NDVI')
pku = assign_crs(pku, crs=crs)
pku.attrs['nodata'] = np.nan

# Alternative PKU input (GIMMS3g v1.1), kept for reference:
# pku = xr.open_dataset(base+'NDVI_harmonization/AVHRR_GIMMS3g_v1.1_1982_2013.nc')['NDVI']
# pku = assign_crs(pku, crs='epsg:3577')
# pku.attrs['nodata'] = np.nan

# MODIS is already joined into the AusENDVI-clim product, so open that file
# once: the full record is `ause`, and the 2000-03 onwards slice is `modis`.
ndvi_ds = xr.open_dataset(ndvi_path)
ause = assign_crs(ndvi_ds['AusENDVI_clim_MCD43A4'], crs=crs)

# `.sel` returns a new array, so setting attrs on `modis` leaves `ause` untouched.
modis = ause.sel(time=slice('2000-03', '2022'))
modis.attrs['nodata'] = np.nan
modis = modis.rename('NDVI')

# Synthetic NDVI, used later to fill gaps remaining in PKU.
syn = assign_crs(
    xr.open_dataset(base + '/synthetic/NDVI/NDVI_CLIM_synthetic_5km_monthly_1982_2022.nc')['NDVI'],
    crs=crs,
)

# Monthly rainfall; the file starts in 1981, so clip to the analysis years.
rain = xr.open_dataset(base + '5km/rain_5km_monthly_1981_2022.nc').rain
rain = rain.sel(time=slice('1982', '2022'))
rain = assign_crs(rain, crs=crs)
rain.attrs['nodata'] = np.nan

## Harmonise

In [None]:
# Put every dataset on the GIMMS-PKU grid using 'average' resampling.
ause = ause.odc.reproject(pku.odc.geobox, resampling='average')
modis = modis.odc.reproject(pku.odc.geobox, resampling='average')
rain = rain.odc.reproject(pku.odc.geobox, resampling='average')
syn = syn.odc.reproject(pku.odc.geobox, resampling='average')

# presumably rounds coordinate values so labels match exactly across
# datasets for xarray's label-based alignment — confirm in _utils
ause = round_coords(ause)
modis = round_coords(modis)
pku = round_coords(pku)
rain = round_coords(rain)
syn = round_coords(syn)

# Gap-fill PKU the same way as our product: linearly interpolate
# single-month gaps in the monthly anomalies, add the monthly climatology
# back, then fill anything still missing from the synthetic NDVI.
obs_monthly = pku.groupby('time.month').mean()
obs_anom = pku.groupby('time.month') - obs_monthly
obs_anom = obs_anom.interpolate_na(dim='time', method='linear', limit=1)
pku = obs_anom.groupby('time.month') + obs_monthly
# DataArray.drop is deprecated for removing variables; drop_vars is the
# supported equivalent.
pku = pku.drop_vars('month')
pku = pku.fillna(syn)

# Keep only pixels/times valid in AusENDVI, PKU and the synthetic NDVI so
# every dataset is compared over the same sample.
ause_mask = ause.notnull()
pku_mask = pku.notnull()
syn_mask = syn.notnull()

#combine masks
mask = (ause_mask & pku_mask & syn_mask)

pku = pku.where(mask)
ause = ause.where(mask)
# MODIS only covers 2000-03 onwards, so align the mask to its time axis.
modis = modis.where(mask.sel(time=modis.time))
rain = rain.where(mask)

### Clip time

In [None]:
# pku = pku.sel(time=slice(t_range[0], t_range[1]))

# ause = ause.sel(time=slice(t_range[0], t_range[1]))

# rain = rain.sel(time=slice(t_range[0], t_range[1]))

## Anomalies

In [None]:
#standardized anom
def stand_anomalies(ds):
    return xr.apply_ufunc(
        lambda x, m, s: (x - m) / s,
            ds.groupby("time.month"),
            ds.groupby("time.month").mean(),
            ds.groupby("time.month").std()
    )

# Standardized monthly anomalies. MODIS is only available from 2000-03,
# so its climatology is computed over that shorter record.
ause_anom = stand_anomalies(ause.sel(time=slice('1982','2022')))
pku_anom = stand_anomalies(pku.sel(time=slice('1982','2022')))
modis_anom = stand_anomalies(modis)
rain_anom = stand_anomalies(rain.sel(time=slice('1982','2022')))

# A zero monthly std turns (x - mean) / std into +/-inf ("a couple of pesky
# infs crept in"). Any dataset can hit this, so mask non-finite values in
# every anomaly series. Otherwise a single inf would poison the spatial
# means and regressions below. np.isfinite also flags NaN, which is already NaN.
ause_anom = ause_anom.where(np.isfinite(ause_anom))
pku_anom = pku_anom.where(np.isfinite(pku_anom))
modis_anom = modis_anom.where(np.isfinite(modis_anom))
rain_anom = rain_anom.where(np.isfinite(rain_anom))

## Demonstrate water limitation using MODIS

### per pixel

In [None]:
# Re-standardize rainfall over the MODIS era only (2000-03 onwards) so its
# monthly climatology covers the same years as modis_anom.
rain_anom_short = stand_anomalies(rain.sel(time=slice('2000-03', '2022')))

# Per-pixel Pearson correlation between the annual-mean anomalies.
corr = xr.corr(
    modis_anom.resample(time='1Y').mean(),
    rain_anom_short.resample(time='1Y').mean(),
    dim='time',
)


In [None]:
corr_data = [corr]

# Map of the per-pixel annual MODIS NDVI vs rainfall correlation (`corr`).
fig, ax = plt.subplots(1, 1, figsize=(5, 5), layout='constrained')
im = corr.plot(ax=ax, vmin=0.0, vmax=0.9, cmap='magma', add_colorbar=False, add_labels=False)
ctx.add_basemap(ax, source=ctx.providers.CartoDB.VoyagerNoLabels, crs='EPSG:4326', attribution='', attribution_size=1)
ax.set_title('Annual MODIS NDVI vs rainfall anomalies')

# The colorbar takes its colormap from `im`, so no cmap argument is needed.
cbar = fig.colorbar(im, orientation='vertical', ax=ax, shrink=0.6)
cbar.ax.set_title('R', fontsize=12);
fig.savefig('/g/data/os22/chad_tmp/AusENDVI/results/figs/MODIS_rainfall_correlation_perpixel.png',
            bbox_inches='tight', dpi=300);

## Derive NDVI-rain relationships

In [None]:
def _regional_rolling_mean(da, name):
    """Return the 12-month rolling mean of ``da``, averaged over space, as a Series.

    The rolling mean is applied per pixel before the spatial mean, and it
    requires a full 12-month window (``min_periods=12``). A pixel therefore
    contributes to a month only when it has 12 valid values in the window.
    The result is indexed by ``time`` and named ``name``.
    """
    regional = (
        da.rolling(time=12, min_periods=12)
        .mean()
        .mean(['latitude', 'longitude'])
    )
    return regional.to_dataframe(name=name)[name]


ndvi_data = [modis_anom,
             ause_anom.sel(time=slice('1982','2000')),
             pku_anom.sel(time=slice('1982','2000')),
             pku_anom.sel(time=slice('2000','2022'))
            ]

rain_data = [rain_anom.sel(time=slice('2000','2022')),
             rain_anom.sel(time=slice('1982','2000')),
             rain_anom.sel(time=slice('1982','2000')),
             rain_anom.sel(time=slice('2000','2022'))
            ]

names=['MODIS MCD43A4 2000-2022', 'AusENDVI-clim 1982-2000','GIMMS-PKU-consolidated 1982-2000', 'GIMMS-PKU-consolidated 2000-2022'] 

with plt.style.context('ggplot'):
    fig, axes = plt.subplots(1, 4, figsize=(16, 4), sharey=True, layout='constrained')
    for ax, ndvi, rain_da, title in zip(axes.ravel(), ndvi_data, rain_data, names):
        # Align NDVI onto the rainfall time index, then keep only months
        # where both series are defined.
        plot_df = _regional_rolling_mean(rain_da, 'rain').to_frame()
        plot_df['ndvi'] = _regional_rolling_mean(ndvi, 'ndvi')
        plot_df = plot_df.dropna()

        slope, intercept, r_value, p_value, std_err = stats.linregress(
            plot_df['rain'].values, plot_df['ndvi'].values)
        r2 = r_value**2

        sb.scatterplot(data=plot_df, x='rain', y='ndvi', alpha=1.0, ax=ax)
        sb.regplot(data=plot_df, x='rain', y='ndvi', scatter=False, color='blue', ax=ax)

        # The '+' format flag always prints the intercept's sign ("+0.012" /
        # "-0.012"), so no sign branch is needed.
        ax.text(.05, .9, f'y={slope:.2f}x{intercept:+.3f}',
                transform=ax.transAxes, fontsize=13)
        ax.text(.05, .825, 'r\N{SUPERSCRIPT TWO}={:.2f}'.format(r2),
                transform=ax.transAxes, fontsize=13)

        ax.set_ylabel('12-month rolling mean NDVI anomaly')
        ax.set_xlabel('12-month rolling rainfall anomaly')
        ax.set_title(title, fontsize=12)

fig.savefig('/g/data/os22/chad_tmp/AusENDVI/results/figs/rainfall_ndvi_relationships.png',
            bbox_inches='tight', dpi=300);