In [None]:
import numpy as np
import xarray as xr
import seaborn as sb
import pandas as pd
import geopandas as gpd
from scipy import stats
import xskillscore as xs
import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs
from scipy.stats import gaussian_kde
from sklearn.metrics import mean_absolute_error

import sys
sys.path.append('/g/data/os22/chad_tmp/NEE_modelling/')
from _collect_prediction_data import round_coords

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.spatial import xr_rasterize

## Plotting harmonized NDVI time series for selected regions

In [None]:
var = 'ndvi'
crs = 'epsg:3577'

# Harmonised AVHRR/MODIS product for the 'trees' region, the LightGBM
# harmonisation test output, and the Australia-wide harmonised product.
# The CRS is assigned explicitly to the two harmonised datasets.
merge = assign_crs(
    xr.open_dataset(
        f'/g/data/os22/chad_tmp/climate-carbon-interactions/data/{var.upper()}_harmonization/'
        f'regions/trees_Harmonized_{var.upper()}_AVHRR_MODIS_1982_2013.nc'
    ),
    crs=crs,
)

merge_lgbm = xr.open_dataset(
    '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/'
    'NDVI_trees_LGBM_harmonize_test_5km_monthly_1982_2013.nc'
)['NDVI']

merge_old = assign_crs(
    xr.open_dataset(
        f'/g/data/os22/chad_tmp/climate-carbon-interactions/data/{var.upper()}_harmonization/'
        f'Harmonized_{var.upper()}_AVHRR_MODIS_1982_2013.nc'
    ),
    crs=crs,
)

# merge = merge.isel(x=range(625,715), y=range(650,755))

In [None]:
# Mean 2001-2019 'trees' layer, regridded onto `merge`'s grid. Pixels above
# 0.5 become the tree mask (1); every other pixel is 0.
trees = (
    assign_crs(
        xr.open_dataset(
            '/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/trees_5km_monthly_1982_2022.nc'
        )['trees'],
        crs='epsg:4326',
    )
    .sel(time=slice('2001', '2019'))
    .odc.reproject(how=merge.odc.geobox)
    .mean('time')
)
mask = xr.where(trees > 0.5, 1, 0)

In [None]:
# Keep only tree pixels (mask == 1) of the Australia-wide product; the rest become NaN.
merge_old = merge_old.where(cond=mask)
# merge_lgbm = merge_lgbm.where(mask)

In [None]:
# True wherever the original AVHRR (CDR) value is present.
avhrr_mask = merge_old['ndvi_cdr'].notnull()

In [None]:
# Round the mask's coordinates with the project helper `round_coords`
# (from _collect_prediction_data). This is presumably done to strip
# floating-point noise so the mask aligns with other grids in `.where`
# — TODO confirm.
# NOTE(review): `mask` was built on `merge`'s x/y grid, while `ndvi` (loaded in
# the next cell) is indexed by latitude/longitude. If the dimension names still
# differ after rounding, `ndvi.NDVI.where(mask)` broadcasts rather than masks
# — confirm.
# NOTE(review): the commented alternative below refers to `ndvi`, which is not
# loaded until the next cell.
# mask = mask.odc.reproject(ndvi.odc.geobox, resampling= 'nearest')
mask = round_coords(mask)

In [None]:
ndvi_path = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/NDVI_5km_monthly_1982_2022.nc'
# 1982-2022 monthly 5 km NDVI dataset (holds an `NDVI` variable).
ndvi = xr.open_dataset(ndvi_path)

In [None]:
# Quick-look map of one time step of masked NDVI over lon 140-150, lat -10 to -20.
# slice(-10, -20) assumes latitude is stored in descending order.
# Alternative box: .sel(longitude=slice(144,149), latitude=slice(-40,-45))
(
    ndvi['NDVI']
    .where(mask)
    .sel(longitude=slice(140, 150), latitude=slice(-10, -20))
    .isel(time=1)
    .plot()
)

In [None]:
# Spatial mean of the LightGBM-harmonised NDVI over a fixed x/y index window,
# smoothed with a 3-step rolling mean.
(
    merge_lgbm
    .isel(x=slice(600, 800), y=slice(0, 230))
    .mean(['x', 'y'])
    .rolling(time=3)
    .mean()
    .plot(figsize=(11, 4))
)

In [None]:
# merge_lgbm.sel(longitude=slice(140,150), latitude=slice(-10,-20)).mean(['x','y']).rolling(time=3).mean().plot(figsize=(11,4))
# plt.title("Tropical Forests QLD ('trees') Calibrated AVHRR & MODIS NDVI")

In [None]:
# Spatial-mean NDVI time series (3-step rolling mean) for the tropical QLD box.
(
    ndvi['NDVI']
    .where(mask)
    .sel(longitude=slice(140, 150), latitude=slice(-10, -20))
    .mean(['latitude', 'longitude'])
    .rolling(time=3)
    .mean()
    .plot(figsize=(11, 4))
)
plt.title("Tropical Forests QLD ('trees') Calibrated AVHRR & MODIS NDVI")

## Comparing the different models

In [None]:
var = 'ndvi'
crs = 'epsg:3577'
name = 'nontrees'  # land-cover subset to compare: 'trees' or 'nontrees'

# Region-specific harmonised files use the prefix 'non_trees' where `name` is
# 'nontrees'. Derive the prefix from `name` so the loaded GAM product always
# matches the selected subset (and the LGBM file, mask and plot labels).
region_file_prefix = {'trees': 'trees', 'nontrees': 'non_trees'}
if name not in region_file_prefix:
    raise ValueError(f"name must be one of {sorted(region_file_prefix)}, got {name!r}")

harm_dir = f'/g/data/os22/chad_tmp/climate-carbon-interactions/data/{var.upper()}_harmonization'

# Region-specific harmonised AVHRR/MODIS product ('GAM' in the plots below).
merge = assign_crs(
    xr.open_dataset(
        f'{harm_dir}/regions/{region_file_prefix[name]}_Harmonized_{var.upper()}_AVHRR_MODIS_1982_2013.nc'
    ),
    crs=crs,
)

# LightGBM harmonisation test output for the same subset.
# NOTE(review): section 1 loaded the trees LGBM file with a 1982_2013 suffix;
# confirm a 2001_2013 file exists when name == 'trees'.
merge_lgbm = xr.open_dataset(
    '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/'
    f'NDVI_{name}_LGBM_harmonize_test_5km_monthly_2001_2013.nc'
)['NDVI']

# Australia-wide harmonised product ('Aus GAM' in the plots below).
merge_old = assign_crs(
    xr.open_dataset(f'{harm_dir}/Harmonized_{var.upper()}_AVHRR_MODIS_1982_2013.nc'),
    crs=crs,
)

# merge = merge.isel(x=range(625,715), y=range(650,755))

In [None]:
# Mean 2001-2018 'trees' layer, used to build the land-cover mask for `name`.
# NOTE(review): section 1 averages 2001-2019; confirm the different end year is intended.
trees = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/trees_5km_monthly_1982_2022.nc')['trees']
trees = assign_crs(trees, crs='epsg:4326')
trees = trees.sel(time=slice('2001', '2018'))
# Regrid onto `merge`'s grid (its CRS is assigned on load). This is the same
# grid `sami` is reprojected to and then masked with in the next cell.
trees = trees.odc.reproject(how=merge.odc.geobox)
trees = trees.mean('time')

# Always (re)define `mask` here so an unexpected `name` cannot silently reuse a
# mask left over from an earlier cell.
if name == 'trees':
    mask = xr.where(trees > 0.5, 1, 0)
elif name == 'nontrees':
    mask = xr.where(trees <= 0.5, 1, 0)
else:
    raise ValueError(f"Unknown name {name!r}; expected 'trees' or 'nontrees'")

In [None]:
# The `sami` MCD43/AVHRR hybrid product (EPSG:4326). It is clipped to
# 2001-2013, regridded onto `merge`'s grid, cast to float32 and restricted to
# the current land-cover mask.
sami = (
    assign_crs(
        xr.open_dataset(
            '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/'
            'MCD43_AVHRR_NDVI_hybrid_EasternOzWoody.nc'
        ),
        crs='epsg:4326',
    )[['ndvi_mcd_pred', 'ndvi_mcd', 'ndvi_cdr']]
    .sel(time=slice('2001', '2013'))
    .odc.reproject(how=merge.odc.geobox)
    .astype('float32')
    .where(mask)
)
# Reuse merge's 2001-2013 timestamps so both products share the time axis.
sami['time'] = merge.sel(time=slice('2001', '2013')).time

# Keep only pixels/times where the original AVHRR (CDR) value exists.
sami_mask = ~np.isnan(sami['ndvi_cdr'])
sami = sami.where(sami_mask)

In [None]:
# Restrict every product to 2001-2013 and to the pixels/times where `merge`
# has an original AVHRR (CDR) value, so they are compared on the same samples.
period = slice('2001', '2013')
avhrr_mask = merge['ndvi_cdr'].sel(time=period).notnull()

merge = merge.where(avhrr_mask).sel(time=period)
merge_old = merge_old.where(avhrr_mask).sel(time=period)
merge_lgbm = merge_lgbm.where(avhrr_mask).sel(time=period)

# sami = sami.where(avhrr_mask)

## Convert to DataFrames for plotting

In [None]:
# Pixel-wise MODIS / original AVHRR / adjusted AVHRR values from the `sami` product.
# Each of the DataFrame cells overwrites `df`/`df_sample`; the scatter plot
# uses whichever one ran last.
modis_flat = sami[var+'_mcd'].values.flatten()
avhrr_flat = sami[var+'_cdr'].values.flatten()
avhrr_adjust = sami[var+'_mcd_pred'].values.flatten()

df = pd.DataFrame({'MODIS': modis_flat, 'AVHRR-original': avhrr_flat, 'AVHRR-adjusted': avhrr_adjust})
df = df.dropna()
# Too many pixels to plot, so draw a reproducible sample. The sample is capped
# at the number of valid rows because DataFrame.sample raises ValueError when
# n > len(df) without replacement.
df_sample = df.sample(n=min(20000, len(df)), random_state=1)

In [None]:
# Pixel-wise MODIS and original AVHRR from `merge`, paired with the LightGBM-adjusted AVHRR.
# Each of the DataFrame cells overwrites `df`/`df_sample`; the scatter plot
# uses whichever one ran last.
modis_flat = merge[var+'_mcd'].values.flatten()
avhrr_flat = merge[var+'_cdr'].values.flatten()
# Flattened `.values` are paired purely by position. Put the LightGBM output
# in `merge`'s dimension order first; transpose raises if the dim names differ
# rather than silently mis-pairing pixels.
avhrr_adjust = merge_lgbm.transpose(*merge[var+'_mcd'].dims).values.flatten()

df = pd.DataFrame({'MODIS': modis_flat, 'AVHRR-original': avhrr_flat, 'AVHRR-adjusted': avhrr_adjust})
df = df.dropna()
# Too many pixels to plot, so draw a reproducible sample. The sample is capped
# at the number of valid rows because DataFrame.sample raises ValueError when
# n > len(df) without replacement.
df_sample = df.sample(n=min(20000, len(df)), random_state=1)

In [None]:
# Pixel-wise MODIS / original AVHRR / adjusted AVHRR from the Australia-wide product.
# Each of the DataFrame cells overwrites `df`/`df_sample`; the scatter plot
# uses whichever one ran last.
modis_flat = merge_old[var+'_mcd'].values.flatten()
avhrr_flat = merge_old[var+'_cdr'].values.flatten()
avhrr_adjust = merge_old[var+'_mcd_pred'].values.flatten()

df = pd.DataFrame({'MODIS': modis_flat, 'AVHRR-original': avhrr_flat, 'AVHRR-adjusted': avhrr_adjust})
df = df.dropna()
# Too many pixels to plot, so draw a reproducible sample. The sample is capped
# at the number of valid rows because DataFrame.sample raises ValueError when
# n > len(df) without replacement.
df_sample = df.sample(n=min(20000, len(df)), random_state=1)

In [None]:
# Pixel-wise MODIS / original AVHRR / adjusted AVHRR from the region-specific product.
# Each of the DataFrame cells overwrites `df`/`df_sample`; the scatter plot
# uses whichever one ran last.
modis_flat = merge[var+'_mcd'].values.flatten()
avhrr_flat = merge[var+'_cdr'].values.flatten()
avhrr_adjust = merge[var+'_mcd_pred'].values.flatten()

df = pd.DataFrame({'MODIS': modis_flat, 'AVHRR-original': avhrr_flat, 'AVHRR-adjusted': avhrr_adjust})
df = df.dropna()
# Too many pixels to plot, so draw a reproducible sample. The sample is capped
# at the number of valid rows because DataFrame.sample raises ValueError when
# n > len(df) without replacement.
df_sample = df.sample(n=min(20000, len(df)), random_state=1)

In [None]:
# df = merge.ndvi_mcd.mean(['x', 'y']).rename('MODIS').drop('spatial_ref').to_dataframe()
# df['AVHRR-original'] = merge.ndvi_cdr.mean(['x', 'y']).drop('spatial_ref').to_dataframe()
# df['AVHRR-adjusted'] = merge.ndvi_mcd_pred.mean(['x', 'y']).drop('spatial_ref').to_dataframe()
# df = df.dropna()
# df.head()

## Scatter plots before and after harmonization

In [None]:
# Density-coloured scatter of each AVHRR product against MODIS for the sampled
# pixels. Each panel shows a linear fit (blue), the 1:1 line (dashed black),
# and annotated r² and MAE.
products = ['AVHRR-original', 'AVHRR-adjusted']
font = 15

fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)

# Iterate over `axes` so the loop variable `ax` does not shadow the axes array.
for prod, ax in zip(products, axes.ravel()):
    obs, pred = df_sample['MODIS'].values, df_sample[prod].values
    r2 = stats.linregress(obs, pred).rvalue ** 2
    ac = mean_absolute_error(obs, pred)

    # Gaussian KDE evaluated at each point, used as the point colour.
    xy = np.vstack([obs, pred])
    z = gaussian_kde(xy)(xy)

    sb.scatterplot(data=df_sample, x='MODIS', y=prod, c=z, s=20, lw=1, alpha=0.5, ax=ax)
    sb.regplot(data=df_sample, x='MODIS', y=prod, scatter=False, color='blue', ax=ax)
    # Regressing MODIS on itself draws the 1:1 reference line.
    sb.regplot(data=df_sample, x='MODIS', y='MODIS', color='black', scatter=False,
               line_kws={'linestyle': 'dashed'}, ax=ax)
    ax.set_title(prod, fontsize=font)
    ax.set_xlabel('MODIS '+var.upper(), fontsize=font)
    ax.set_ylabel('')
    ax.set_ylim(0.0, 0.8)
    ax.set_xlim(0.0, 0.8)
    # r2 and ac are already scalars, so they are formatted directly.
    ax.text(.05, .90, 'r\N{SUPERSCRIPT TWO}={:.2f}'.format(r2),
            transform=ax.transAxes, fontsize=font)
    ax.text(.05, .825, 'MAE={:.2g}'.format(ac),
            transform=ax.transAxes, fontsize=font)
    ax.tick_params(axis='x', labelsize=font)
    ax.tick_params(axis='y', labelsize=font)

fig.supylabel('AVHRR '+var.upper(), fontsize=font)
plt.tight_layout();

## Comparisons

In [None]:
# fig, ax = plt.subplots(1,1, figsize=(11,4))
# merge_old[var+'_mcd'].mean(['x','y']).plot(ax=ax, label='MODIS-mine', c='black')
# sami[var+'_mcd'].mean(['x','y']).plot(ax=ax, label='MODIS-sami', c='red')

In [None]:
# Dictionary to save results 
gdf = gpd.read_file('/g/data/os22/chad_tmp/NEE_modelling/data/bioclimatic_regions.geojson')

for index, row in gdf.iterrows():
    
    if (name=='trees') & (row['region_name']=='Desert'):
        pass

    else:
        print(row['region_name'])
        
        # Generate a polygon mask to keep only data within the polygon
        mask_region = xr_rasterize(gdf.iloc[[index]], merge)
        mask_lgbm = xr_rasterize(gdf.iloc[[index]], merge_lgbm)
        mask_old = xr_rasterize(gdf.iloc[[index]], merge_old)
        # mask_sami = xr_rasterize(gdf.iloc[[index]], sami)
        #mask = round_coords(mask)
        
        # Mask dataset to set pixels outside the polygon to `NaN`
        merge_region = merge.where(mask_region)
        merge_lgbm_region = merge_lgbm.where(mask_lgbm)
        merge_old_region = merge_old.where(mask_old)
        #sami_region = merge_old.where(mask_sami)
        
        fig, ax = plt.subplots(1,1, figsize=(11,4))
        merge_region[var+'_cdr'].mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label='AVHRR original')
        # sami_region[var+'_cdr'].sel(time=slice('2001', '2013')).mean(['x','y']).plot(ax=ax, label='AVHRR original')
        
        merge_region[var+'_mcd_pred'].mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label='AVHRR adjusted - nontrees GAM')
        
        merge_lgbm_region.mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label='AVHRR adjusted - nontrees LGBM')
        
        merge_old_region[var+'_mcd_pred'].mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label='AVHRR adjusted - Aus GAM')
        
        # sami_region[var+'_mcd_pred'].mean(['x','y']).plot(ax=ax, label='AVHRR adjusted - Sami GAM')
        
        merge_old_region[var+'_mcd'].mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label='MODIS', c='black')
        # sami_region[var+'_mcd'].mean(['x','y']).plot(ax=ax, label='MODIS', c='black')
        
        # sami_region[var+'_cdr'].sel(time=slice('2001', '2013')).mean(['x','y']).plot(ax=ax, label='AVHRR original')
        ax.legend()
        ax.set_title(row['region_name']);

In [None]:
# Continental time series (spatial mean, then 3-step rolling mean) of each
# AVHRR product against MODIS for the selected land-cover subset.
fig, ax = plt.subplots(1, 1, figsize=(13, 5))
merge[var+'_cdr'].mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label='AVHRR original')
merge[var+'_mcd_pred'].mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label=f'AVHRR adjusted - {name} GAM')
merge_lgbm.mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label=f'AVHRR adjusted - {name} LGBM')
# Spatial mean first, then the rolling mean, like every other curve. Rolling per
# pixel first (default min_periods = window) turns any pixel with a gap in the
# window into NaN and changes which pixels enter each month's average.
merge_old[var+'_mcd_pred'].mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label='AVHRR adjusted - Aus model')

merge[var+'_mcd'].mean(['x','y']).rolling(time=3).mean().plot(ax=ax, label='MODIS', c='black')
# ax.set_ylim(0.15, 0.40)
ax.legend()
ax.set_title(f'Australia ({name} only)');

In [None]:
# Per-pixel skill over time of the original (cdr) and adjusted (mcd_pred)
# AVHRR series in `merge_old`, using MODIS (mcd) as the reference.
modis_ref = merge_old[var+'_mcd']
avhrr_adjusted = merge_old[var+'_mcd_pred']
avhrr_original = merge_old[var+'_cdr']

# Pearson correlation along the time dimension.
adjusted_corr = xr.corr(modis_ref, avhrr_adjusted, dim='time')
orig_corr = xr.corr(modis_ref, avhrr_original, dim='time')

# MAPE multiplied by 100 (presumably to express it as a percentage).
adjusted_mape = xs.mape(modis_ref, avhrr_adjusted, dim='time', skipna=True) * 100
orig_mape = xs.mape(modis_ref, avhrr_original, dim='time', skipna=True) * 100

# Root-mean-square error in NDVI units.
adjusted_rmse = xs.rmse(modis_ref, avhrr_adjusted, dim='time', skipna=True)
orig_rmse = xs.rmse(modis_ref, avhrr_original, dim='time', skipna=True)

In [None]:
# Per-pixel temporal correlation with MODIS for the original and adjusted AVHRR,
# plus their difference (adjusted - original). Titles show the spatial-mean R.
corr_data = [orig_corr, adjusted_corr, adjusted_corr-orig_corr]
products = ['AVHRR-original', 'AVHRR-adjusted', 'Difference']

fig, axes = plt.subplots(1, 3, figsize=(20, 5), sharey=True)

for ax, ds, n in zip(axes.ravel(), corr_data, products):
    # Symmetric diverging scale for the difference; 0-1 sequential scale for R.
    if n == 'Difference':
        cmap, vmin, vmax = 'RdBu', -0.5, 0.5
    else:
        cmap, vmin, vmax = 'magma', 0, 1
    im = ds.plot.imshow(vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, add_colorbar=True)
    # One f-string with the spatial-mean R (NaNs skipped by .mean()).
    ax.set_title(f'{n} R {float(ds.mean()):.2f}')
    ax.set_yticklabels([])
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_xticklabels([])

# fig.subplots_adjust(wspace=0.05)
# fig.colorbar(im, ax=axes.ravel().tolist(), pad=0.01, label='Correlation');
# plt.suptitle('Correlation', fontsize=15)
plt.tight_layout();

## Adding features to GAM inputs

In [None]:
# mod=assign_crs(xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/MODIS_NDVI_5km_monthly_2001_2022.nc')['NDVI_median'], crs='epsg:3577')
# av=assign_crs(xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/AVHRR_5km_monthly_1982_2013_climate.nc'), crs='epsg:3577')

In [None]:
# rain_cml3 = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/rain_cml3_5km_monthly_1982_2022.nc')['rain_cml3']
# rain_cml3 = assign_crs(rain_cml3, crs ='epsg:4326')
# rain_cml3=rain_cml3.sel(time=slice('1982', '2013'))
# rain_cml3=rain_cml3.odc.reproject(how=av.odc.geobox)
# av['rain_cml3'] = rain_cml3

# rain_cml6 = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/rain_cml6_5km_monthly_1982_2022.nc')['rain_cml6']
# rain_cml6 = assign_crs(rain_cml6, crs ='epsg:4326')
# rain_cml6=rain_cml6.sel(time=slice('1982', '2013'))
# rain_cml6=rain_cml6.odc.reproject(how=av.odc.geobox)
# av['rain_cml6'] = rain_cml6

# rain_cml3_anom = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/rain_cml3_anom_5km_monthly_1982_2022.nc')['rain_cml3_anom']
# rain_cml3_anom = assign_crs(rain_cml3_anom, crs ='epsg:4326')
# rain_cml3_anom=rain_cml3_anom.sel(time=slice('1982', '2013'))
# rain_cml3_anom=rain_cml3_anom.odc.reproject(how=av.odc.geobox)
# av['rain_cml3_anom'] = rain_cml3_anom

# vpd = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/vpd_5km_monthly_1982_2022.nc')['vpd']
# vpd = assign_crs(vpd, crs ='epsg:4326')
# vpd=vpd.sel(time=slice('1982', '2013'))
# vpd=vpd.odc.reproject(how=av.odc.geobox)
# av['vpd'] = vpd

# srad = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/srad_5km_monthly_1982_2022.nc')['srad']
# srad = assign_crs(srad, crs ='epsg:4326')
# srad=srad.sel(time=slice('1982', '2013'))
# srad=srad.odc.reproject(how=av.odc.geobox)
# av['srad'] = srad

# CO2 = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/CO2_5km_monthly_1982_2022.nc')['CO2']
# CO2 = assign_crs(CO2, crs ='epsg:4326')
# CO2=CO2.sel(time=slice('1982', '2013'))
# CO2=CO2.odc.reproject(how=av.odc.geobox)
# av['CO2'] = CO2.transpose('time','x', 'y')

# for i in av.data_vars:
#     try:
#         del av[i].attrs['grid_mapping']
#     except:
#         continue

# av.to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/AVHRR_5km_monthly_1982_2013_climate.nc')

In [None]:
# trees = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/trees_5km_monthly_1982_2022.nc')['trees']
# trees = assign_crs(trees, crs ='epsg:4326')
# trees=trees.sel(time=slice('2001', '2019'))
# trees=trees.odc.reproject(how=mod.odc.geobox)
# trees = trees.mean('time')
# trees = xr.where(trees>0.5, 1, 0)
# not_trees = xr.where(trees<=0.5, 1, 0)

# vars = [trees,not_trees]

# for mask, n in zip(vars, ['trees', 'non_trees']):
#     mod_region = mod.where(mask)
#     av_region = av.where(mask)
#     mod_region.to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/regions/'+n+'_MODIS_NDVI_5km_monthly_2001_2022.nc')
#     av_region.to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/regions/'+n+'_AVHRR_5km_monthly_1982_2013_climate.nc')

In [None]:
# gdf = gpd.read_file('/g/data/os22/chad_tmp/NEE_modelling/data/bioclimatic_regions.geojson')

# for index, row in gdf.iterrows():
#     print(row['region_name'])
#     mask = xr_rasterize(gdf.iloc[[index]], mod.isel(time=1))
#     mod_region = mod.where(mask)
#     av_region = av.where(mask)
        
#     mod_region.to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/regions/'+row['region_name']+'_MODIS_NDVI_5km_monthly_2001_2022.nc')
#     av_region.to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/regions/'+row['region_name']+'_AVHRR_5km_monthly_1982_2013_extras.nc')


In [None]:
# mod = mod.isel(x=range(625,715), y=range(650,755))
# av = av

# mod.to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/TAS_MODIS_NDVI_5km_monthly_2001_2022.nc')
# av.to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/TAS_AVHRR_5km_monthly_1982_2013.nc')