# Calculate Pearson correlations between climate variables and carbon fluxes


In [None]:
import sys
import xarray as xr
import numpy as np
from scipy import stats
import scipy as sp
import geopandas as gpd
from odc.geo.xr import assign_crs
import pandas as pd
from odc.algo import xr_reproject
# from odc.geo.geobox import zoom_out
from matplotlib import pyplot as plt
from datacube.utils.dask import start_local_dask

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.spatial import xr_rasterize

sys.path.append('/g/data/os22/chad_tmp/NEE_modelling/')
from _collect_prediction_data import round_coords, collect_prediction_data 

In [None]:
# Start a local Dask cluster for the chunked xarray computations below;
# mem_safety_margin='2Gb' is passed straight through to start_local_dask.
client = start_local_dask(mem_safety_margin='2Gb')
client  # last expression of the cell: shows the client summary in the notebook

In [None]:
# Flux variable to analyse and the matching AusEFlux prediction file.
var = 'GPP'
suffix = '20230320'  # only used by the older file-naming pattern below
# results_name = var+'_2003_2022_1km_quantiles_'+suffix+'.nc'
results_name = f'AusEFlux_{var}_2003_2022_1km_quantiles_v1.1.nc'

# Dask chunk layouts: the predictor data is opened with x/y dims (renamed
# later), the predictions use latitude/longitude.
chunks_data = {'x': 1100, 'y': 1100, 'time': 1}
chunks = {'longitude': 1100, 'latitude': 1100, 'time': 1}
# mask_path = '/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/mask_5km.nc'
# data_path = '/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/data_5km.nc'

## Open predictions

In [None]:
# Lazily open the prediction file and keep only the median estimate of `var`.
ds = xr.open_dataset(
    f'/g/data/os22/chad_tmp/NEE_modelling/results/predictions/{results_name}',
    chunks=chunks,
)[f'{var}_median']
ds

In [None]:
# grid = zoom_out(ds.odc.geobox, 2)
# ds = xr_reproject(ds, geobox=grid.compat, resampling='average').compute()


## Open predictor data

In [None]:
# Gather the predictor layers (e.g. rain, tavg, srad, kNDVI and their
# anomalies, as used below) for 2003-2022 without writing them to disk.
data = collect_prediction_data(
    time_start='2003', time_end='2022',
    verbose=False, export=False,
    chunks=chunks_data,
)
# Use the same dimension names and chunking as the predictions.
data = data.rename({'x': 'longitude', 'y': 'latitude'})
data = data.chunk(chunks)
data

## Calculate climatology and anomaly



In [None]:
# Monthly climatology: mean of each calendar month over the whole record.
ds_clim_mean = ds.groupby('time.month').mean().compute()
# Monthly anomalies: each time step minus its calendar-month climatology.
# This must be defined: `ds_anom` is used by assign_crs, the bioregion table,
# the per-pixel anomaly correlations and the per-region plots.
ds_anom = (ds.groupby('time.month') - ds_clim_mean).compute()

## Create table of correlations per bioclimatic region

In [None]:
# Bioclimatic region polygons; the 'region_name' column labels the results.
gdf = gpd.read_file(
    '/g/data/os22/chad_tmp/NEE_modelling/data/bioclimatic_regions.geojson'
)
gdf.head()

In [None]:
# Predictor anomaly variables (names in `data`) correlated against the flux
# anomalies in each bioregion.
clim_vars = [
    'rain_anom',
    'rain_cml3_anom',
    'rain_cml6_anom',
    'rain_cml12_anom',
    'tavg_anom',
    'srad_anom',
    'kNDVI_anom',
]

In [None]:
# Tag the flux anomalies with a geographic CRS (EPSG:4326). The region masks
# below are rasterised onto this grid with xr_rasterize, which presumably
# needs the template's CRS — confirm.
ds_anom = assign_crs(ds_anom, crs='EPSG:4326')

In [None]:
# Mean per-pixel Pearson correlation between flux anomalies and each climate
# anomaly, aggregated per bioclimatic region: outer[region][variable] -> r.
outer = {}
for index, row in gdf.iterrows():
    print(row['region_name'])
    # Rasterise this region onto the anomaly grid. `index` is a row label from
    # iterrows, so select it with .loc. round_coords presumably snaps the mask
    # coordinates to the data grid (helper from _collect_prediction_data).
    mask = xr_rasterize(gdf.loc[[index]], ds_anom.isel(time=1))
    mask = round_coords(mask)

    # The flux side does not depend on the climate variable: mask it once.
    ds_anom_region = ds_anom.where(mask).chunk(chunks)

    inner = {}
    for v in clim_vars:
        # Keep the predictor lazy; computing it here would load the whole
        # masked continental time series into memory just to re-chunk it.
        var_anom_region = data[v].where(mask).chunk(chunks)
        r = xr.corr(ds_anom_region, var_anom_region, dim='time').compute()
        # Spatial mean of the per-pixel r (NaNs outside the region are skipped).
        r = r.mean(['latitude', 'longitude'])
        print('  ', v, r.values)
        inner[v] = float(r)
    outer[row['region_name']] = inner

In [None]:
# Regions become columns, climate variables become rows.
df = pd.DataFrame.from_dict(outer, orient='columns')
df

In [None]:
# Save the bioregion correlation table (rows: climate variables, columns: regions).
df.to_csv(f'/g/data/os22/chad_tmp/NEE_modelling/results/{var}_anomaly_bioregion_correlations.csv')

### Over all of Australia

In [None]:
# r2 = xr.corr(ds_anom_region, var_anom_region, dim='time').compute()
# r2 = r2.mean(['x', 'y'])
# print(r2)

In [None]:
# ax2_ylim = -100,100
# ax_ylim = -1,1

# fig,ax=plt.subplots(1,2, figsize=(18,5), gridspec_kw={'width_ratios': [3, 1]})
# ax2 = ax[0].twinx()
# var_anom.mean(['x','y']).rolling(time=3).mean().plot(ax=ax2, label=data_var, c='orange')
# ds_anom.sum(['x','y']).rolling(time=3).mean().plot(ax=ax[0], label=var)
# ax[0].legend(loc=(0.80,0.925))
# ax2.legend(loc=(0.80,0.85))
# ax2.set_ylabel(data_var+' Anomaly', fontsize=15)
# ax[0].set_xlabel('')
# ax2.set_ylim(ax2_ylim)
# ax[0].set_ylim(ax_ylim)
# ax[0].text(.05, .90, 'r={:.2f}'.format(r2[0]),
#             transform=ax[0].transAxes, fontsize=15)
# ax[0].set_ylabel(var+' Anomalies (PgC y⁻¹)', fontsize=15)
# ax[0].tick_params(axis='x', labelsize=14)
# ax[0].tick_params(axis='y', labelsize=14)
# ax2.tick_params(axis='y', labelsize=14)
# ax[1].tick_params(axis='x', labelsize=14)
# ax[1].tick_params(axis='y', labelsize=14)

# ax[0].axhline(0, c='grey', linestyle='--')

# ax3 = ax[1].twinx()

# var_clim_mean.mean(['x','y']).plot(ax=ax3, label=data_var, c='orange')
# ds_clim_mean.mean(['x','y']).plot(ax=ax[1], label='NEE')
# ax3.set_ylabel(data_var)
# ax3.set_ylabel(data_var, fontsize=15)
# ax[1].set_ylabel(var+' (PgC y⁻¹)', fontsize=15)
# ax[1].set_xticks(range(1,13))
# ax[1].set_xticklabels(["J","F","M","A","M","J","J","A","S","O","N","D"]) 
# ax[1].set_xlabel('')
# ax3.tick_params(axis='y', labelsize=14)
# ax[0].set_title(None)
# ax[1].set_title(None)
# ax2.set_title(None)
# ax3.set_title(None)
# plt.tight_layout();
# plt.savefig('/g/data/os22/chad_tmp/NEE_modelling/results/figs/'+var+'_Aus_'+data_var+'_correlations.png')

## Per-pixel correlations

### Correlations with climatology

In [None]:
# Chunk layout for the 12-month climatologies (no time dimension).
c = {'latitude': 1100, 'longitude': 1100}

# Per-pixel Pearson r across the 12 calendar months between the flux
# climatology and the rainfall climatology.
precip_clim_corr = xr.corr(
    ds_clim_mean.chunk(c),
    data['rain'].groupby('time.month').mean(),
    dim='month',
).compute()
# srad_clim_corr = xr.corr(ds_clim_mean.chunk(c), data['srad'].groupby('time.month').mean(), dim='month').compute()
# tavg_clim_corr = xr.corr(ds_clim_mean.chunk(c), data['tavg'].groupby('time.month').mean(), dim='month').compute()
# kNDVI_clim_corr = xr.corr(ds_clim_mean.chunk(c), data['kNDVI'].groupby('time.month').mean(), dim='month').compute()


### Plot correlations with climatology

In [None]:
# Map of per-pixel climatology correlations. The 2x2 grid has room for TAVG
# (ax[0, 1]), SRAD (ax[1, 0]) and kNDVI (ax[1, 1]) panels, but only the rain
# correlation is currently computed, so the other three axes are saved empty.
fig, ax = plt.subplots(2, 2, figsize=(12, 12), sharey=True, sharex=True)

rain_panel = ax[0, 0]
precip_clim_corr.plot.imshow(
    vmin=-0.8, vmax=0.8, cmap='RdBu_r', ax=rain_panel, add_colorbar=False
)
rain_panel.set_title(var + ' Climatology & Rain Climatology', fontsize=18)
# Maps are shown without coordinate ticks or axis labels.
rain_panel.set_yticklabels([])
rain_panel.set_ylabel('')
rain_panel.set_xlabel('')
rain_panel.set_xticklabels([])

# Other panels follow the same pattern once their correlations exist, e.g.:
# tavg_clim_corr.plot.imshow(vmin=-0.8, vmax=0.8, cmap='RdBu_r', ax=ax[0,1], add_colorbar=False)
# ax[0,1].set_title(var+' Climatology & TAVG Climatology',  fontsize=18)
# ...then clear its tick/axis labels as above (SRAD -> ax[1,0], kNDVI -> ax[1,1]).

plt.tight_layout()
fig.savefig(
    '/g/data/os22/chad_tmp/NEE_modelling/results/figs/' + var + '_climatology_perpixel_climate_correlations.png',
    bbox_inches='tight',
    dpi=300,
)

### Correlations with anomalies

In [None]:
# Per-pixel Pearson r through time between the flux anomalies and each
# climate anomaly (computed in the order kNDVI, rain, srad, tavg).
_anom_corrs = {
    name: xr.corr(ds_anom.chunk(chunks), data[f'{name}_anom'], dim='time').compute()
    for name in ('kNDVI', 'rain', 'srad', 'tavg')
}
kNDVI_anom_corr = _anom_corrs['kNDVI']
rain_anom_corr = _anom_corrs['rain']
srad_anom_corr = _anom_corrs['srad']
tavg_anom_corr = _anom_corrs['tavg']
del _anom_corrs
# vpd_anom_corr = xr.corr(ds_anom.chunk(chunks), data['vpd'], dim='time').compute()

### Plot correlations with anomalies

In [None]:
# Correlation maps and their panel labels, in plotting order.
# NOTE(review): this rebinds `clim_vars` (previously the list of anomaly
# variable names used by the bioregion table); re-run that earlier cell before
# re-running the table loop.
anom_data = [
    rain_anom_corr,
    tavg_anom_corr,
    srad_anom_corr,
    kNDVI_anom_corr,
]
clim_vars = ['Rainfall', 'Air Temp.', 'Solar Rad.', 'kNDVI']

In [None]:
# One map per climate variable: per-pixel correlation between flux anomalies
# and that variable's anomalies, on a shared -0.8..0.8 colour scale.
fig, axes = plt.subplots(1, 4, figsize=(24, 7), sharey=True, sharex=True)

for ax, corr_map, clim in zip(axes.ravel(), anom_data, clim_vars):
    # The loop variable is `corr_map`, not `ds`: reusing `ds` would overwrite
    # the flux predictions loaded at the top of the notebook.
    # `im` stays bound after the loop and is reused for the stand-alone colour bar.
    im = corr_map.plot.imshow(vmin=-0.8, vmax=0.8, cmap='RdBu_r', ax=ax, add_colorbar=False)
    ax.set_title(var + ' Anomalies & ' + clim + ' Anomalies', fontsize=20)
    ax.set_yticklabels([])
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_xticklabels([])

plt.tight_layout()
fig.savefig('/g/data/os22/chad_tmp/NEE_modelling/results/figs/' + var + '_anomalies_perpixel_climate_correlations.png', bbox_inches='tight')

In [None]:
# Export a stand-alone horizontal colour bar (Pearson's r) to use as a shared
# legend. `im` is the last image drawn in the anomaly-correlation figure; all
# its panels use the same vmin/vmax and cmap, so it carries the common scale.
# (A colour bar is no longer bolted onto that already-saved figure here — it
# was never saved or shown and its handle was immediately overwritten.)
figl, axl = plt.subplots(figsize=(11, 4))
axl.axis(False)  # the axes only hosts the colour bar
cbar = plt.colorbar(im, spacing='uniform', ax=axl, orientation='horizontal')
cbar.ax.tick_params(labelsize=20)
cbar.set_label("Pearson's Correlation", size=20)
figl.savefig('/g/data/os22/chad_tmp/NEE_modelling/results/figs/correlation_legend.png', bbox_inches='tight', dpi=300)

### Correlations with climate vars

In [None]:
# kndvi_corr = xr.corr(ds.chunk(chunks), data['kNDVI'], dim='time').compute()
# rain_corr = xr.corr(ds.chunk(chunks), data['rain'], dim='time').compute()
# vpd_corr = xr.corr(ds.chunk(chunks), data['vpd'], dim='time').compute()
# srad_corr = xr.corr(ds.chunk(chunks), data['srad'], dim='time').compute()
# tavg_corr = xr.corr(ds.chunk(chunks), data['tavg'], dim='time').compute()

### Plots by bioregion

In [None]:
# Fixed y-limits for the climate-variable anomaly axis in the per-region plots.
ax2_ylim = (-0.11, 0.11)

In [None]:
# Per-bioregion plots. Left panel: 3-month rolling mean of the regional flux
# anomalies against the regional anomaly of one climate variable. Right panel:
# the two regional monthly climatologies. One figure is saved per region.
#
# NOTE(review): as exported, this cell used `data_var`, `var_anom`,
# `var_clim_mean` and `ax_ylim`, none of which is defined in the notebook.
# It also renamed the mask to x/y, although `ds_anom` and `data` use
# latitude/longitude. The inputs are defined here. 'kNDVI' is an assumption
# based on the +/-0.11 anomaly limits in `ax2_ylim`; change `data_var` to plot
# another predictor.
data_var = 'kNDVI'
var_anom = data[data_var + '_anom']
var_clim_mean = data[data_var].groupby('time.month').mean().compute()

# Regional mean correlation (r) keyed by region name
results = {}
for index, row in gdf.iterrows():

    # Rasterise the region onto the flux-anomaly grid and round the mask
    # coordinates with round_coords (as in the bioregion table cell) so the
    # mask lines up with `ds_anom` and `data` on latitude/longitude.
    mask = xr_rasterize(gdf.loc[[index]], ds_anom.isel(time=1))
    mask = round_coords(mask)

    # Set pixels outside the polygon to NaN
    var_anom_region = var_anom.where(mask)
    ds_anom_region = ds_anom.where(mask)
    var_clim_mean_region = var_clim_mean.where(mask)
    ds_clim_mean_region = ds_clim_mean.where(mask)

    # Spatial mean of the per-pixel Pearson r through time, kept as a float so
    # it can be formatted directly (indexing a 0-d result with [0] fails).
    r = xr.corr(ds_anom_region.chunk(chunks), var_anom_region.chunk(chunks), dim='time').compute()
    r = float(r.mean(['latitude', 'longitude']))
    results[row['region_name']] = r
    print(row['region_name'], r)

    fig, ax = plt.subplots(1, 2, figsize=(18, 5), gridspec_kw={'width_ratios': [3, 1]})

    # Left: anomalies (flux on ax[0], climate variable on a twin axis)
    ax2 = ax[0].twinx()
    var_anom_region.mean(['latitude', 'longitude']).rolling(time=3).mean().plot(ax=ax2, label=data_var, c='orange')
    ds_anom_region.mean(['latitude', 'longitude']).rolling(time=3).mean().plot(ax=ax[0], label=var)

    ax[0].legend(loc=(0.80, 0.925))
    ax2.legend(loc=(0.80, 0.85))
    ax2.set_ylabel(data_var + ' Anomaly', fontsize=15)
    ax[0].set_xlabel('')
    ax2.set_ylim(ax2_ylim)
    # The flux axis is left to autoscale (no `ax_ylim` was ever defined).
    ax[0].text(.05, .90, 'r={:.2f}'.format(r),
               transform=ax[0].transAxes, fontsize=15)
    # NOTE(review): units assumed to be monthly totals; the previous label
    # read "gC m² m⁻¹", which is dimensionally inconsistent — confirm.
    ax[0].set_ylabel(var + ' Anomalies (gC m⁻² month⁻¹)', fontsize=15)
    ax[0].tick_params(axis='x', labelsize=14)
    ax[0].tick_params(axis='y', labelsize=14)
    ax2.tick_params(axis='y', labelsize=14)
    ax[1].tick_params(axis='x', labelsize=14)
    ax[1].tick_params(axis='y', labelsize=14)

    ax[0].axhline(0, c='grey', linestyle='--')

    # Right: monthly climatologies (flux on ax[1], climate variable on a twin axis)
    ax3 = ax[1].twinx()
    var_clim_mean_region.mean(['latitude', 'longitude']).plot(ax=ax3, label=data_var, c='orange')
    ds_clim_mean_region.mean(['latitude', 'longitude']).plot(ax=ax[1], label=var)

    ax3.set_ylabel(data_var, fontsize=15)
    ax[1].set_ylabel(var + ' (gC m⁻² month⁻¹)', fontsize=15)
    ax[1].set_xticks(range(1, 13))
    ax[1].set_xticklabels(["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"])
    ax[1].set_xlabel('')
    ax3.tick_params(axis='y', labelsize=14)

    # Drop the automatic per-axes titles that xarray's .plot() adds.
    for a in (ax[0], ax[1], ax2, ax3):
        a.set_title('')

    fig.suptitle(row['region_name'], fontsize=18)
    plt.tight_layout()
    fig.savefig('/g/data/os22/chad_tmp/NEE_modelling/results/figs/' + var + '_' + row['region_name'] + '_' + data_var + '_correlations.png')

### Variable with the highest absolute correlation

In [None]:
# For each pixel, find the climate variable with the largest |r| with the flux.
# It is encoded as 1=rain, 2=vpd, 3=srad, 4=tavg (NaN where all four
# correlations are NaN). Using a numeric 'variable' coordinate lets idxmax
# return the code directly, instead of mapping string labels with four passes
# over an object-dtype array.
# NOTE(review): rain_corr, vpd_corr, srad_corr and tavg_corr are only created
# by the commented-out "Correlations with climate vars" cell; run it first.
corrs = xr.concat(
    [np.abs(corr) for corr in (rain_corr, vpd_corr, srad_corr, tavg_corr)],
    dim=pd.Index([1, 2, 3, 4], name='variable'),
)
max_corrs = corrs.idxmax('variable').astype(np.float32)

In [None]:
# For each pixel, find which rainfall-anomaly variant has the largest |r| with
# the flux. It is coded 1=precip, 2=precip_3, 3=precip_6, 4=precip_12
# (presumably 3/6/12-month cumulative anomalies), with NaN where all four are
# NaN. The numeric 'variable' coordinate lets idxmax return the code directly.
# NOTE(review): precip_corr, precip_3_corr, precip_6_corr and precip_12_corr
# are not created anywhere in this notebook. Compute them (e.g. xr.corr of
# ds_anom against data['rain_anom'], 'rain_cml3_anom', ...) before running.
rain_corrs = xr.concat(
    [np.abs(corr) for corr in (precip_corr, precip_3_corr, precip_6_corr, precip_12_corr)],
    dim=pd.Index([1, 2, 3, 4], name='variable'),
)
rain_max_corrs = rain_corrs.idxmax('variable').astype(np.float32)

### Plots

In [None]:
# Categorical map of the climate variable (codes 1-4) whose |r| with the flux
# is largest at each pixel.
fig, ax = plt.subplots(figsize=(10, 7))
im = max_corrs.plot.imshow(add_colorbar=False, ax=ax)
cbar = fig.colorbar(im, spacing='uniform', ax=ax, orientation='vertical', shrink=0.4)
tick_codes = [1, 2, 3, 4]
tick_names = ['Rain', 'VPD', 'SRAD', 'TAVG']
cbar.set_ticks(tick_codes)
cbar.set_ticklabels(tick_names, fontsize=10)
plt.title('Climate Variable with Maximum Absolute Correlation with ' + var)

In [None]:
# Categorical map of the rainfall variant (codes 1-4) whose |r| with the flux
# is largest at each pixel.
fig, ax = plt.subplots(figsize=(10, 7))
im = rain_max_corrs.plot.imshow(add_colorbar=False, ax=ax)
cbar = fig.colorbar(im, spacing='uniform', ax=ax, orientation='vertical', shrink=0.4)
tick_codes = [1, 2, 3, 4]
tick_names = ['Rain', 'Rain-3', 'Rain-6', 'Rain-12']
cbar.set_ticks(tick_codes)
cbar.set_ticklabels(tick_names, fontsize=10)
plt.title('Rainfall Variable with Maximum Absolute Correlation with ' + var)