# Analyse modelled fluxes


In [1]:
import sys
import xarray as xr
import numpy as np
from scipy import stats
import scipy as sp
import geopandas as gpd
import pandas as pd
from odc.algo import xr_reproject
from matplotlib import pyplot as plt
from datacube.utils.dask import start_local_dask

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.spatial import xr_rasterize

In [2]:
# Start a local Dask cluster; the bare name on the last line renders its summary.
client = start_local_dask(mem_safety_margin="2Gb")
client

2023-02-03 09:47:24,490 - distributed.diskutils - INFO - Found stale lock file and directory '/local/u46/cb3058/tmp/dask-worker-space/worker-tbn07092', purging


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 1
Total threads: 16,Total memory: 44.92 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:45667,Workers: 1
Dashboard: /proxy/8787/status,Total threads: 16
Started: Just now,Total memory: 44.92 GiB

0,1
Comm: tcp://127.0.0.1:34625,Total threads: 16
Dashboard: /proxy/36623/status,Memory: 44.92 GiB
Nanny: tcp://127.0.0.1:38937,
Local directory: /local/u46/cb3058/tmp/dask-worker-space/worker-gzim8vq3,Local directory: /local/u46/cb3058/tmp/dask-worker-space/worker-gzim8vq3


In [34]:
# Run configuration: the flux being analysed and the prediction run to load.
var = 'ER'
suffix = '20230109'
results_name = f'{var}_2003_2021_5km_LGBM_{suffix}.nc'

# Gridded inputs used elsewhere in the notebook.
mask_path = '/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/mask_5km.nc'
data_path = '/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/data_5km.nc'

## Open predictions

In [35]:
# Lazily open the predicted flux; each spatial chunk holds its full time series.
predictions_file = '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/' + results_name
ds = xr.open_dataarray(predictions_file, chunks={'x': 250, 'y': 250, 'time': -1})

In [36]:
# Build a target grid in EPSG:3577 from the predictions' geobox and resample
# the predictions onto it with bilinear interpolation.
grid = ds.odc.geobox.to_crs('EPSG:3577')
ds = xr_reproject(ds, geobox=grid.compat, resampling='bilinear')
# Pixel area is taken as the square of the second resolution component.
area_per_pixel = ds.geobox.resolution[1]**2

# Scale every pixel by its area, 1e-15 and 12.
# NOTE(review): presumably converts gC m-2 month-1 into PgC yr-1 per pixel - confirm units.
# NOTE(review): `ds` is overwritten here, so re-running this cell without
# re-opening the predictions applies the scaling a second time.
ds = ds * area_per_pixel * 1e-15 * 12

## Open predictor data

In [24]:
# Open the predictor stack lazily and resample it onto the same grid as the
# fluxes, in one expression so that re-running the cell is idempotent.
data = xr_reproject(
    xr.open_dataset(data_path, chunks={'x': 250, 'y': 250, 'time': -1}),
    geobox=grid.compat,
    resampling='bilinear',
)

## Correlations



In [25]:
# Predictor whose anomalies are compared with the flux in the plots below.
data_var = 'vpd'

In [26]:
# Monthly climatology of the chosen predictor, and its departures from that
# climatology; both are loaded into memory.
var_by_month = data[data_var].groupby('time.month')
var_clim_mean = var_by_month.mean()
var_anom = (var_by_month - var_clim_mean).compute()

var_clim_mean = var_clim_mean.compute()

  return self.array[key]


In [37]:
# Monthly climatology of the flux, and its departures from that climatology;
# both are loaded into memory.
flux_by_month = ds.groupby('time.month')
ds_clim_mean = flux_by_month.mean()
ds_anom = (flux_by_month - ds_clim_mean).compute()

ds_clim_mean = ds_clim_mean.compute()

  return self.array[key]


### Create table of correlations per bioclimatic region

In [28]:
# Bioclimatic region polygons (one row per region); show the first rows.
regions_file = '/g/data/os22/chad_tmp/NEE_modelling/data/bioclimatic_regions.geojson'
gdf = gpd.read_file(regions_file)
gdf.head()

Unnamed: 0,bioclimatic_regions,region_name,geometry
0,1.0,Tropics,"MULTIPOLYGON (((122.92500 -16.42500, 122.92500..."
1,2.0,Savanna,"MULTIPOLYGON (((147.67500 -19.87500, 147.72500..."
2,3.0,Warm Temperate,"MULTIPOLYGON (((145.42500 -36.02500, 145.42500..."
3,4.0,Cool Temperate,"MULTIPOLYGON (((147.07500 -43.37500, 147.12500..."
4,5.0,Mediterranean,"MULTIPOLYGON (((135.82500 -34.87500, 135.82500..."


In [29]:
# Store the anomalies computed above under '<data_var>_anom' so the name follows
# the configured predictor rather than a hard-coded 'vpd'. The 'month'
# coordinate left behind by the groupby arithmetic is removed first.
data[data_var + '_anom'] = var_anom.drop_vars('month')

In [38]:
# Anomaly variables in `data` that are correlated against the flux anomalies.
clim_vars = [
    'rain_anom',
    'rain_cml3_anom',
    'rain_cml6_anom',
    'rain_cml12_anom',
    'tavg_anom',
    'srad_anom',
    'vpd_anom',
    'kNDVI_anom',
]

In [39]:
# For every bioclimatic region, compute the per-pixel Pearson correlation
# (through time) between the flux anomalies and each climate anomaly, then
# average it over the region. Result: {region_name: {variable: mean r}}.
outer = {}
for index, row in gdf.iterrows():
    print(row['region_name'])
    # iterrows() yields index *labels*, so select the row with .loc.
    mask = xr_rasterize(gdf.loc[[index]], ds_anom.isel(time=1))
    # The masked flux does not depend on the climate variable: mask it once.
    ds_anom_region = ds_anom.where(mask)
    inner = {}
    for v in clim_vars:
        var_anom_region = data[v].where(mask)
        # NB: despite the name, `r2` holds Pearson's r, not r squared.
        r2 = xr.corr(ds_anom_region, var_anom_region, dim='time').compute()
        r2 = r2.mean(['x', 'y'])
        print('  ', v, r2.values)
        # Store a plain float so the resulting DataFrame is numeric.
        inner[v] = float(r2)
    outer[row['region_name']] = inner

Tropics
   rain_anom 0.3010361125163598
   rain_cml3_anom 0.39584345331292536
   rain_cml6_anom 0.4190790970507721
   rain_cml12_anom 0.2847108527306472
   tavg_anom 0.01789328418619021
   srad_anom -0.24873191351631796


  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


   vpd_anom -0.3378394789346753
   kNDVI_anom 0.650658144003097
Savanna
   rain_anom 0.4258050943632751
   rain_cml3_anom 0.5027683227819504
   rain_cml6_anom 0.5432768869226903
   rain_cml12_anom 0.5180453711256738
   tavg_anom -0.3059702714608702
   srad_anom -0.5259331762625784


  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


   vpd_anom -0.6432145373483046
   kNDVI_anom 0.7519923824365157
Warm Temperate
   rain_anom 0.3992940151925918
   rain_cml3_anom 0.5639494752350361
   rain_cml6_anom 0.5465238147349588
   rain_cml12_anom 0.47743825179693516
   tavg_anom -0.17982722886641062
   srad_anom -0.4012077514730826


  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


   vpd_anom -0.5063410287873675
   kNDVI_anom 0.7492447259291836
Cool Temperate
   rain_anom 0.3376538771394378
   rain_cml3_anom 0.49003342614782147
   rain_cml6_anom 0.5084317919888883
   rain_cml12_anom 0.49023729271261157
   tavg_anom -0.016150216141434512
   srad_anom -0.2515735029544409


  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


   vpd_anom -0.3549042183542805
   kNDVI_anom 0.6768562107771269
Mediterranean
   rain_anom 0.64004979887547
   rain_cml3_anom 0.668576391055295
   rain_cml6_anom 0.5924952409413708
   rain_cml12_anom 0.5186857572094286
   tavg_anom -0.1903927615837482
   srad_anom -0.5224491431971122


  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


   vpd_anom -0.49352058552441497
   kNDVI_anom 0.6252493758229084
Desert
   rain_anom 0.6945312096731284
   rain_cml3_anom 0.6824930447715422
   rain_cml6_anom 0.6486216420880332
   rain_cml12_anom 0.5554339119165882
   tavg_anom -0.3357004329093372
   srad_anom -0.7106807338043947


  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


   vpd_anom -0.6932809524167751
   kNDVI_anom 0.586000895223868


In [40]:
# Table of correlations: one column per region, one row per climate variable.
df = pd.DataFrame.from_dict(outer)
df

Unnamed: 0,Tropics,Savanna,Warm Temperate,Cool Temperate,Mediterranean,Desert
rain_anom,0.3010361125163598,0.4258050943632751,0.3992940151925918,0.3376538771394378,0.64004979887547,0.6945312096731284
rain_cml3_anom,0.3958434533129253,0.5027683227819504,0.5639494752350361,0.4900334261478214,0.668576391055295,0.6824930447715422
rain_cml6_anom,0.4190790970507721,0.5432768869226903,0.5465238147349588,0.5084317919888883,0.5924952409413708,0.6486216420880332
rain_cml12_anom,0.2847108527306472,0.5180453711256738,0.4774382517969351,0.4902372927126115,0.5186857572094286,0.5554339119165882
tavg_anom,0.0178932841861902,-0.3059702714608702,-0.1798272288664106,-0.0161502161414345,-0.1903927615837482,-0.3357004329093372
srad_anom,-0.2487319135163179,-0.5259331762625784,-0.4012077514730826,-0.2515735029544409,-0.5224491431971122,-0.7106807338043947
vpd_anom,-0.3378394789346753,-0.6432145373483046,-0.5063410287873675,-0.3549042183542805,-0.4935205855244149,-0.6932809524167751
kNDVI_anom,0.650658144003097,0.7519923824365157,0.7492447259291836,0.6768562107771269,0.6252493758229084,0.586000895223868


In [41]:
# Write the region-by-variable correlation table to disk.
corr_table_path = f'/g/data/os22/chad_tmp/NEE_modelling/results/{var}_anomaly_bioregion_correlations.csv'
df.to_csv(corr_table_path)

### Over all of Aus

In [None]:
# Continental value: per-pixel Pearson correlation between the flux anomalies
# and the `data_var` anomalies, averaged over every pixel. The unmasked arrays
# are used here; the `*_region` variables left over from the region loop only
# cover the last region processed.
r2 = xr.corr(ds_anom, var_anom, dim='time').compute()
r2 = r2.mean(['x', 'y'])
print(r2)

In [None]:
# Axis limits; `ax_ylim` is also read by the per-bioregion plotting loop.
ax2_ylim = -100,100
ax_ylim = -1,1

fig,ax=plt.subplots(1,2, figsize=(18,5), gridspec_kw={'width_ratios': [3, 1]})

# Left panel: 3-month rolling means of the flux anomaly (summed over pixels)
# and of the predictor anomaly (averaged over pixels) on a twin axis.
ax2 = ax[0].twinx()
var_anom.mean(['x','y']).rolling(time=3).mean().plot(ax=ax2, label=data_var, c='orange')
ds_anom.sum(['x','y']).rolling(time=3).mean().plot(ax=ax[0], label=var)
ax[0].legend(loc=(0.80,0.925))
ax2.legend(loc=(0.80,0.85))
ax2.set_ylabel(data_var+' Anomaly', fontsize=15)
ax[0].set_xlabel('')
ax2.set_ylim(ax2_ylim)
ax[0].set_ylim(ax_ylim)
# `r2` is 0-dimensional after the spatial mean, so it cannot be indexed with
# [0]; convert it to a float instead.
ax[0].text(.05, .90, 'r={:.2f}'.format(float(r2)),
            transform=ax[0].transAxes, fontsize=15)
ax[0].set_ylabel(var+' Anomalies (PgC y⁻¹)', fontsize=15)
ax[0].tick_params(axis='x', labelsize=14)
ax[0].tick_params(axis='y', labelsize=14)
ax2.tick_params(axis='y', labelsize=14)
ax[1].tick_params(axis='x', labelsize=14)
ax[1].tick_params(axis='y', labelsize=14)

ax[0].axhline(0, c='grey', linestyle='--')

# Right panel: mean seasonal cycle of the flux and of the predictor.
ax3 = ax[1].twinx()

var_clim_mean.mean(['x','y']).plot(ax=ax3, label=data_var, c='orange')
# NOTE(review): the anomaly panel sums the flux over pixels but this panel
# averages it, while both axes are labelled PgC y⁻¹ - confirm which is intended.
ds_clim_mean.mean(['x','y']).plot(ax=ax[1], label=var)
ax3.set_ylabel(data_var, fontsize=15)
ax[1].set_ylabel(var+' (PgC y⁻¹)', fontsize=15)
ax[1].set_xticks(range(1,13))
ax[1].set_xticklabels(["J","F","M","A","M","J","J","A","S","O","N","D"])
ax[1].set_xlabel('')
ax3.tick_params(axis='y', labelsize=14)

# Drop the titles xarray adds automatically.
ax[0].set_title(None)
ax[1].set_title(None)
ax2.set_title(None)
ax3.set_title(None)
plt.tight_layout();
plt.savefig('/g/data/os22/chad_tmp/NEE_modelling/results/figs/'+var+'_Aus_'+data_var+'_correlations.png')

### Plots by bioregion

In [None]:
# Predictor-anomaly axis limits for the per-bioregion plots.
ax2_ylim = (-0.11, 0.11)

In [None]:
# One figure per bioclimatic region: anomaly time series (left) and mean
# seasonal cycle (right) of the flux and of `data_var`, annotated with the
# region-mean per-pixel Pearson correlation of the anomalies.
for index, row in gdf.iterrows():

    # Generate a polygon mask to keep only data within the polygon.
    # iterrows() yields index *labels*, so select the row with .loc.
    mask = xr_rasterize(gdf.loc[[index]], var_anom.isel(time=1))
    # The arrays masked below are reduced over 'x'/'y'. Only when the mask
    # comes back on a latitude/longitude grid are its coordinates rounded and
    # renamed; otherwise it is used as returned.
    if 'latitude' in mask.dims and 'longitude' in mask.dims:
        mask['latitude'] = np.round(mask.latitude.values.astype('float32'), 4)
        mask['longitude'] = np.round(mask.longitude.values.astype('float32'), 4)
        mask = mask.rename({'latitude':'y', 'longitude':'x'})

    # Mask dataset to set pixels outside the polygon to `NaN`
    var_anom_region = var_anom.where(mask)
    ds_anom_region = ds_anom.where(mask)

    var_clim_mean_region = var_clim_mean.where(mask)
    ds_clim_mean_region = ds_clim_mean.where(mask)

    # NB: despite the name, `r2` holds Pearson's r, not r squared.
    r2 = xr.corr(ds_anom_region, var_anom_region, dim='time').compute()
    r2 = r2.mean(['x', 'y'])
    print(row['region_name'], r2)

    fig,ax=plt.subplots(1,2, figsize=(18,5), gridspec_kw={'width_ratios': [3, 1]})
    ax2 = ax[0].twinx()
    var_anom_region.mean(['x','y']).rolling(time=3).mean().plot(ax=ax2, label=data_var, c='orange')
    ds_anom_region.mean(['x','y']).rolling(time=3).mean().plot(ax=ax[0], label=var)

    ax[0].legend(loc=(0.80,0.925))
    ax2.legend(loc=(0.80,0.85))
    ax2.set_ylabel(data_var+' Anomaly', fontsize=15)
    ax[0].set_xlabel('')
    ax2.set_ylim(ax2_ylim)
    # `ax_ylim` comes from the Australia-wide plotting cell.
    ax[0].set_ylim(ax_ylim)
    # `r2` is 0-dimensional after the spatial mean, so it cannot be indexed
    # with [0]; convert it to a float instead.
    ax[0].text(.05, .90, 'r={:.2f}'.format(float(r2)),
                transform=ax[0].transAxes, fontsize=15)
    # NOTE(review): the flux was scaled by pixel area earlier in the notebook,
    # so the unit in this label (and the one on the right panel) should be
    # confirmed.
    ax[0].set_ylabel(var+' Anomalies (gC m\N{SUPERSCRIPT TWO} m⁻¹)', fontsize=15)
    ax[0].tick_params(axis='x', labelsize=14)
    ax[0].tick_params(axis='y', labelsize=14)
    ax2.tick_params(axis='y', labelsize=14)
    ax[1].tick_params(axis='x', labelsize=14)
    ax[1].tick_params(axis='y', labelsize=14)

    ax[0].axhline(0, c='grey', linestyle='--')

    ax3 = ax[1].twinx()

    var_clim_mean_region.mean(['x','y']).plot(ax=ax3, label=data_var, c='orange')
    ds_clim_mean_region.mean(['x','y']).plot(ax=ax[1], label=var)

    ax3.set_ylabel(data_var, fontsize=15)
    ax[1].set_ylabel(var+' (gC m\N{SUPERSCRIPT TWO} m⁻¹)', fontsize=15)
    ax[1].set_xticks(range(1,13))
    ax[1].set_xticklabels(["J","F","M","A","M","J","J","A","S","O","N","D"])
    ax[1].set_xlabel('')
    ax3.tick_params(axis='y', labelsize=14)

    # Drop the titles xarray adds automatically.
    ax[0].set_title(None)
    ax[1].set_title(None)
    ax2.set_title(None)
    ax3.set_title(None)

    plt.suptitle(row['region_name'], fontsize=18)
    plt.tight_layout();
    plt.savefig('/g/data/os22/chad_tmp/NEE_modelling/results/figs/'+var+'_'+row['region_name']+'_'+data_var+'_correlations.png')

## Per-pixel correlations

In [None]:
# Per-pixel Pearson correlation through time between the flux and each raw
# predictor. The flux is rechunked once and reused for every variable.
_flux_by_pixel = ds.chunk({'x':250,'y':250, 'time':-1})
kndvi_corr = xr.corr(_flux_by_pixel, data['kNDVI'], dim='time').compute()
rain_corr = xr.corr(_flux_by_pixel, data['rain'], dim='time').compute()
vpd_corr = xr.corr(_flux_by_pixel, data['vpd'], dim='time').compute()
srad_corr = xr.corr(_flux_by_pixel, data['srad'], dim='time').compute()
tavg_corr = xr.corr(_flux_by_pixel, data['tavg'], dim='time').compute()

In [None]:
# kndvi_corr.plot.imshow()
# plt.title('kNDVI correlation with ER')

### Correlations with anomalies

In [None]:
# Per-pixel Pearson correlation through time between the flux anomalies and
# each predictor's anomalies. VPD is correlated against its anomaly
# ('vpd_anom'), like every other variable, rather than against raw 'vpd'.
_flux_anom_by_pixel = ds_anom.chunk({'x':250,'y':250, 'time':-1})
kNDVI_anom_corr = xr.corr(_flux_anom_by_pixel, data['kNDVI_anom'], dim='time').compute()
rain_anom_corr = xr.corr(_flux_anom_by_pixel, data['rain_anom'], dim='time').compute()
vpd_anom_corr = xr.corr(_flux_anom_by_pixel, data['vpd_anom'], dim='time').compute()
srad_anom_corr = xr.corr(_flux_anom_by_pixel, data['srad_anom'], dim='time').compute()
tavg_anom_corr = xr.corr(_flux_anom_by_pixel, data['tavg_anom'], dim='time').compute()

### Correlations with climatology

In [None]:
# Per-pixel Pearson correlation across the 12 calendar months between the
# flux climatology and each predictor's monthly climatology.
_flux_clim_by_pixel = ds_clim_mean.chunk({'x':250,'y':250, 'month':-1})
precip_clim_corr = xr.corr(_flux_clim_by_pixel, data['rain'].groupby('time.month').mean(), dim='month').compute()
srad_clim_corr = xr.corr(_flux_clim_by_pixel, data['srad'].groupby('time.month').mean(), dim='month').compute()
tavg_clim_corr = xr.corr(_flux_clim_by_pixel, data['tavg'].groupby('time.month').mean(), dim='month').compute()
kNDVI_clim_corr = xr.corr(_flux_clim_by_pixel, data['kNDVI'].groupby('time.month').mean(), dim='month').compute()


### Plot correlations with anomalies

In [None]:
# Correlation maps to draw, and the panel label for each (same order).
anom_data = [
    rain_anom_corr,
    tavg_anom_corr,
    srad_anom_corr,
    kNDVI_anom_corr,
]
# NOTE(review): this rebinds `clim_vars`, which earlier in the notebook holds
# the anomaly variable names used for the region table.
clim_vars = [
    'Rainfall',
    'Air Temperature',
    'Solar Radiation',
    'kNDVI',
]

In [None]:
# One map per climate variable showing its per-pixel anomaly correlation with
# the flux anomalies, on a shared colour scale.
fig,axes = plt.subplots(1,4, figsize=(24,7), sharey=True, sharex=True)

# The loop variables are deliberately not called `ds`: reusing that name would
# overwrite the flux DataArray that later cells still need.
for ax, corr_map, label in zip(axes.ravel(), anom_data, clim_vars):

    im = corr_map.plot.imshow(vmin=-0.8, vmax=0.8, cmap='RdBu_r', ax=ax, add_colorbar=False)
    ax.set_title(var+' Anomalies & '+label+' Anomalies',  fontsize=18);
    ax.set_yticklabels([])
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_xticklabels([])

plt.tight_layout();
# fig.savefig('/g/data/os22/chad_tmp/NEE_modelling/results/figs/'+var+'_anomalies_perpixel_climate_correlations.png', bbox_inches='tight')

In [None]:
# Add a horizontal colourbar for the last-drawn image `im` to the axes left in
# `ax` by the plotting loop.
cbar = fig.colorbar(im, spacing='uniform', ax=ax, orientation='horizontal', shrink=0.4);

# Stand-alone legend figure: an empty axes that only carries the colourbar.
figl, axl = plt.subplots(figsize=(11,4))
axl.axis(False)
cbar = plt.colorbar(im, spacing='uniform', ax=axl, orientation='horizontal')
cbar.ax.tick_params(labelsize=20)
cbar.set_label("Pearson's Correlation",size=20)
legend_file = '/g/data/os22/chad_tmp/NEE_modelling/results/figs/correlation_legend.png'
figl.savefig(legend_file, bbox_inches='tight')

### Plot correlations with climatology

In [None]:
# 2x2 grid of per-pixel correlations between the flux climatology and each
# predictor's climatology, all on the same colour scale.
fig,ax = plt.subplots(2,2, figsize=(12,12), sharey=True, sharex=True)

clim_panels = [
    (ax[0,0], precip_clim_corr, 'Rain'),
    (ax[0,1], tavg_clim_corr, 'TAVG'),
    (ax[1,0], srad_clim_corr, 'SRAD'),
    (ax[1,1], kNDVI_clim_corr, 'kNDVI'),
]
for panel, corr_map, label in clim_panels:
    # `im` ends up referring to the last (kNDVI) panel.
    im = corr_map.plot.imshow(vmin=-0.8, vmax=0.8, cmap='RdBu_r', ax=panel, add_colorbar=False)
    panel.set_title(var+' Climatology & '+label+' Climatology', fontsize=18)
    # Maps need no tick labels or axis labels.
    panel.set_yticklabels([])
    panel.set_ylabel('')
    panel.set_xlabel('')
    panel.set_xticklabels([])

plt.tight_layout();
fig.savefig('/g/data/os22/chad_tmp/NEE_modelling/results/figs/'+var+'_climatology_perpixel_climate_correlations.png', bbox_inches='tight')

### Variable with highest correlation

In [None]:
# Absolute per-pixel correlations for each climate variable, merged into one
# Dataset keyed by variable name.
_named_corrs = {
    'rain': rain_corr,
    'vpd': vpd_corr,
    'srad': srad_corr,
    'tavg': tavg_corr,
}
corrs = xr.merge([np.abs(corr.rename(name)) for name, corr in _named_corrs.items()])

# Name of the variable with the largest absolute correlation at each pixel,
# then recoded to 1..4 (rain, vpd, srad, tavg) so it can be drawn as a raster.
max_corrs = corrs.to_array("variable").idxmax("variable")
for code, name in enumerate(_named_corrs, start=1):
    max_corrs = xr.where(max_corrs == name, code, max_corrs)

max_corrs = max_corrs.astype(np.float32)

In [None]:
# Which rainfall accumulation window correlates most strongly with the flux?
# The per-window correlation maps are computed here because no earlier cell
# defines them. Keys are the output names, values the anomaly variables in `data`.
# NOTE(review): the mapping onto the 'rain_cml*_anom' variables is inferred
# from the variable names - confirm it matches the intended analysis.
_rain_anom_sources = {
    'precip_anom': 'rain_anom',
    'precip_3_anom': 'rain_cml3_anom',
    'precip_6_anom': 'rain_cml6_anom',
    'precip_12_anom': 'rain_cml12_anom',
}
_flux_anom_by_pixel = ds_anom.chunk({'x':250,'y':250, 'time':-1})

rain_corrs = xr.merge([
    np.abs(xr.corr(_flux_anom_by_pixel, data[source], dim='time').compute().rename(name))
    for name, source in _rain_anom_sources.items()
])

# Name of the window with the largest absolute correlation at each pixel,
# recoded to 1..4 (1, 3, 6, 12 months) so it can be drawn as a raster.
rain_max_corrs = rain_corrs.to_array("variable").idxmax("variable")

for code, name in enumerate(_rain_anom_sources, start=1):
    rain_max_corrs = xr.where(rain_max_corrs == name, code, rain_max_corrs)

rain_max_corrs = rain_max_corrs.astype(np.float32)

### Plots

In [None]:
# Map of the climate variable with the largest absolute correlation per pixel.
fig, ax = plt.subplots(figsize=(10,7))
im = max_corrs.plot.imshow(add_colorbar=False, ax=ax)
cbar = fig.colorbar(im, spacing='uniform', ax=ax, orientation='vertical', shrink=0.4)
# Raster codes and the label shown for each on the colourbar.
tick_names = {1: 'Rain', 2: 'VPD', 3: 'SRAD', 4: 'TAVG'}
cbar.set_ticks(list(tick_names.keys()))
cbar.set_ticklabels(list(tick_names.values()), fontsize=10)
plt.title('Climate Variable with Maximum Absolute Correlation with '+var);

In [None]:
# Map of the rainfall accumulation window with the largest absolute
# correlation per pixel.
fig, ax = plt.subplots(figsize=(10,7))
im = rain_max_corrs.plot.imshow(add_colorbar=False, ax=ax)
cbar = fig.colorbar(im, spacing='uniform', ax=ax, orientation='vertical', shrink=0.4)
# Raster codes and the label shown for each on the colourbar.
tick_names = {1: 'Rain', 2: 'Rain-3', 3: 'Rain-6', 4: 'Rain-12'}
cbar.set_ticks(list(tick_names.keys()))
cbar.set_ticklabels(list(tick_names.values()), fontsize=10)
plt.title('Rainfall Variable with Maximum Absolute Correlation with '+var);

## Linear trends

In [None]:
# import xarray as xr
import dask.array as da
from dask.delayed import delayed
from  scipy import stats

def _calc_slope(y):
    """return linear regression statistical variables"""
    mask = np.isfinite(y)
    x = np.arange(len(y))
    return stats.linregress(x[mask], y[mask])

# regression function definition
def regression(y):
    """Run `_calc_slope` on every 1-D series along the 'time' axis of `y`.

    :param y: array exposing ``get_axis_num`` and having a 'time' dimension.
    :return: whatever ``dask.array.apply_along_axis`` produces for `_calc_slope`.
    """
    time_axis = y.get_axis_num('time')
    fitted = da.apply_along_axis(_calc_slope, time_axis, y)
    return fitted

# Pixels with no valid value at any time step cannot be regressed; fill them
# with 0 for the computation and mask them out again at the end.
# NOTE(review): `ds` must still be the flux DataArray with a 'time' dimension here.
allnans = ds.isnull().all('time').compute()
ds = ds.where(~allnans, other=0)

# Regression analysis, wrapped in dask.delayed and persisted.
delayed_objs = delayed(regression)(ds).persist()

# Turn the delayed object back into a dask array. The leading axis of length 5
# holds the five values produced per pixel by `_calc_slope`.
# NOTE(review): the shape assumes `ds` is ordered (time, y, x) - confirm.
results = da.from_delayed(delayed_objs, shape=(5, ds.shape[1:][0], ds.shape[1:][1]), dtype=np.float32)
results = results.compute()
# A second compute is required in practice.
# NOTE(review): presumably because `regression` itself returns a dask array, so
# the first compute only yields that inner lazy array - confirm.
results = results.compute()

# Names of the five regression statistics, in output order.
variables = ['slope', 'intercept', 'r_value', 'p_value', 'std_err']

# Spatial coordinates for the output.
coords = {'y': ds.y, 'x': ds.x}

# Assemble one 2-D variable per regression statistic. `_calc_slope` regresses
# against the sample index, so the slope is per time step.
ds_out = xr.Dataset(
    data_vars=dict(slope=(["y", "x"], results[0]),
                   intercept=(["y", "x"], results[1]),
                   r_value=(["y", "x"], results[2]),
                   p_value=(["y", "x"], results[3]),
                   std_err=(["y", "x"], results[4]),
                  ),
    coords = coords)

# Re-mask the pixels that were all-NaN in the input.
ds_out = ds_out.where(~allnans)
ds_out

### Mask with Evergreen Trees

In [None]:
# Load the land-cover dataset and keep the time step at index 1.
lc = xr.open_dataset('/g/data/os22/chad_tmp/NEE_modelling/data/Landcover_merged_5km.nc').isel(time=1)
# Cast the coordinates to float32 and rename them to y/x.
# NOTE(review): the fluxes were reprojected to EPSG:3577 earlier in the
# notebook, while these coordinates are named latitude/longitude - confirm the
# two grids actually align before using `trees` as a mask on `ds_out`.
lc['latitude'] = lc.latitude.astype('float32')
lc['longitude'] = lc.longitude.astype('float32')
lc = lc.rename({'latitude':'y','longitude':'x'})

# Boolean mask of pixels whose PFT value equals 10.
# NOTE(review): presumably the evergreen-tree class - confirm against the land-cover legend.
trees = lc.PFT == 10

In [None]:
# Quick visual check of the PFT == 10 mask.
trees.plot.imshow()

In [None]:
# Largest fitted slope over all pixels.
ds_out.slope.max()

In [None]:
# import folium
# import odc.geo
# import folium
# from odc.geo.xr import assign_crs

# # Create folium Map (ipyleaflet is also supported)
# m = folium.Map(tiles='openstreetmap')

# # Plot each sample image with different colormap
# ds_out.slope.where(trees).odc.add_to(m, cmap='BrBG', vmax=0.2,vmin=-0.2, opacity=1.0)

# # Zoom map to Australia
# m.fit_bounds(ds_out.odc.map_bounds())

# # tile = folium.TileLayer(
# #         tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
# #         attr = 'Esri',
# #         name = 'Esri Satellite',
# #         overlay = True,
# #         control = True
# #        ).add_to(m)

# folium.LayerControl().add_to(m)
# display(m)


# ds_out.slope.where(trees).plot.imshow(size=10, robust=True, cmap='BrBG')
# plt.title('Linear Trend in Evergreen Forest GPP 2003-2018');

In [None]:
# Trend map restricted to the `trees` mask. The title names the flux actually
# analysed (`var`) rather than a hard-coded 'GPP'.
# NOTE(review): the year range follows the 2003_2021 predictions file name,
# since the 2003-2018 time slice is commented out where the file is opened - confirm.
ds_out.slope.where(trees).plot.imshow(size=10, robust=True, cmap='BrBG')
plt.title('Linear Trend in Evergreen Forest '+var+' 2003-2021');

In [None]:
# Trend map for every pixel. No forest mask is applied here, so the title does
# not mention evergreen forest, and it names the flux actually analysed (`var`).
# NOTE(review): the year range follows the 2003_2021 predictions file name - confirm.
ds_out.slope.plot.imshow(size=10, robust=True, cmap='BrBG')
plt.title('Linear Trend in '+var+' 2003-2021');

## Plot relationships between NEE, GPP, ER and environmental variables (P, T, SM etc)

Following Lui et al. (2018) 

In [None]:
# Lazily open the GPP and NEE predictions with the same chunking as the flux
# analysed above (full time series per spatial chunk).
_flux_chunks = {'x': 250, 'y': 250, 'time': -1}

gpp = xr.open_dataarray(
    '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/GPP_2003_2021_5km_LGBM.nc',
    chunks=_flux_chunks,
)

nee = xr.open_dataarray(
    '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/NEE_2003_2021_5km_LGBM.nc',
    chunks=_flux_chunks,
)

In [None]:
def s_cor(x,y, pthres = 0.05, direction = True):
    """
    Spearman rank correlation between two 1-D series, kept only when significant.

    :x vector: first series (e.g. one pixel's time series); may contain NaN
    :y vector: second series, same length as x; may contain NaN
    :pthres: p-value above which the correlation is discarded
    :direction: unused; kept so existing calls keep working
    :return: Spearman's r, or NaN when fewer than 4 paired non-NaN
             observations are available or when p > pthres
    """
    # Count the observations that are usable in *both* series: these are the
    # only pairs left once NaNs are omitted from the test.
    paired = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(paired) < 4:
        return np.nan

    # Run the Spearman test, ignoring NaNs
    r, p_value = stats.spearmanr(x, y, nan_policy='omit')

    # Only report correlations that pass the significance threshold
    if p_value > pthres:
        return np.nan
    return r

def spearman_correlation(x,y,dim='year'):
    """Apply `s_cor` along `dim` for every remaining coordinate of `x` and `y`.

    :x, y: xarray objects sharing the dimension `dim`
    :dim: name of the dimension consumed by the correlation
    :return: result of ``xr.apply_ufunc`` with `dim` removed, declared as float32
    """
    core_dims = [[dim], [dim]]
    return xr.apply_ufunc(
        s_cor,
        x,
        y,
        input_core_dims=core_dims,
        vectorize=True,            # s_cor works on one pair of 1-D vectors at a time
        dask='parallelized',
        output_dtypes=[np.float32],
    )


In [None]:
# Per-pixel significant Spearman correlation between GPP and NEE through time.
r = spearman_correlation(gpp, nee, dim='time').compute()

In [None]:
# Map of the significant correlations; non-significant pixels are NaN.
r.plot.imshow(size=6, vmin=-1, vmax=1, cmap='RdBu')
plt.title('Significant (p<0.05) Temporal Spearman Correlations: GPP & NEE');

## Causality

###  Granger causality tests?

### Bayesian structure learning?
https://towardsdatascience.com/a-step-by-step-guide-in-detecting-causal-relationships-using-bayesian-structure-learning-in-python-c20c6b31cee5

In [None]:
from statsmodels.tsa.stattools import grangercausalitytests

def grangers_causation_matrix(data, variables, maxlag=12, test='ssr_chi2test', verbose=False):
    """Build a matrix of Granger-causality p-values for every pair of series.

    Rows are the response (Y) and columns the predictor (X). Each cell holds
    the smallest p-value of `test` over lags 1..maxlag, rounded to 4 decimals.
    A p-value below the significance level (e.g. 0.05) rejects the null
    hypothesis that X does not Granger-cause Y.

    data      : pandas dataframe containing the time series variables
    variables : list containing names of the time series variables.
    maxlag    : largest lag tested
    test      : key of the test statistic read from the statsmodels result
    verbose   : print the per-lag p-values for every pair
    """
    names = list(variables)
    p_matrix = pd.DataFrame(np.zeros((len(names), len(names))), columns=names, index=names)
    for predictor in p_matrix.columns:
        for response in p_matrix.index:
            # statsmodels tests whether the second column helps predict the first.
            outcome = grangercausalitytests(data[[response, predictor]], maxlag=maxlag, verbose=False)
            lag_p_values = [round(outcome[lag][0][test][1], 4) for lag in range(1, maxlag + 1)]
            if verbose:
                print(f'Y = {response}, X = {predictor}, P Values = {lag_p_values}')
            p_matrix.loc[response, predictor] = np.min(lag_p_values)
    # Suffix the labels so the roles of rows and columns stay explicit.
    p_matrix.columns = [name + '_x' for name in names]
    p_matrix.index = [name + '_y' for name in names]
    return p_matrix

# NOTE(review): `ndvi` is not defined anywhere in the visible notebook; it is
# presumably a pandas DataFrame with one time series per column - define it
# before running this cell.
grangers_causation_matrix(ndvi, variables = ndvi.columns)  