# Compare standardised LST anomalies between the synthetic LST product and AVHRR-Adj. (LGBM-harmonised) LST, alongside rainfall anomalies.


In [None]:
import numpy as np
import xarray as xr
import seaborn as sb
import pandas as pd
import geopandas as gpd
import matplotlib as mpl
from matplotlib.cm import ScalarMappable
import contextily as ctx
import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs
from matplotlib.ticker import FormatStrFormatter

import sys
sys.path.append('/g/data/os22/chad_tmp/NEE_modelling/')
from _collect_prediction_data import round_coords

# Use the ggplot style for figures in this notebook. Note that the rolling-mean
# cell further down calls plt.style.use('default'), so ggplot only applies to
# figures drawn before that cell runs.
plt.style.use('ggplot')

# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

## Analysis Parameters

In [None]:
base = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/'  # data root (note the trailing slash)
crs = 'epsg:4326'   # CRS assigned to every dataset opened below
model_var = 'LST'   # variable under comparison; used in axis labels

## Open datasets

In [None]:
# Synthetic LST. Subtracting 273.15 presumably converts Kelvin to degrees Celsius
# (NOTE(review): confirm units; `merge` below is not converted). The conversion is
# done *before* setting attrs because xarray arithmetic can drop attrs (it does
# under the default `keep_attrs` option), which would discard `nodata`.
syn = xr.open_dataset(f'{base}synthetic/LST/LST_CLIM_synthetic_5km_monthly_1982_2022.nc')['LST']
syn = syn - 273.15
syn = assign_crs(syn, crs=crs)
syn.attrs['nodata'] = np.nan
syn = syn.rename('LST')

# Harmonised (LGBM) LST, the 'wGaps' file variant; its NaNs define the analysis
# mask built in a later cell.
merge = xr.open_dataset(f'{base}LST_harmonization/LGBM/LST_NOCLIM_LGBM_5km_monthly_1982_2022_wGaps.nc')['LST']
merge = assign_crs(merge, crs=crs)
merge.attrs['nodata'] = np.nan
merge = merge.rename('LST')

# Monthly rainfall on the 5 km grid. The path is built from `base` (instead of a
# duplicated hard-coded root) so the data location is configured in one place.
rain = xr.open_dataset(f'{base}5km/rain_5km_monthly_1981_2022.nc').rain
rain = assign_crs(rain, crs=crs)
rain.attrs['nodata'] = np.nan

## Match datasets

In [None]:
# NOTE(review): this regridding/rounding step is disabled. It references `pku`,
# which is not defined anywhere in this notebook, so uncommenting it as-is would
# raise a NameError. The cells below presumably assume merge, rain and syn
# already share a grid — confirm.
# merge = merge.odc.reproject(pku.odc.geobox, resampling='average')
# rain = rain.odc.reproject(pku.odc.geobox, resampling='average')
# syn = syn.odc.reproject(pku.odc.geobox, resampling='average')

# merge = round_coords(merge)
# rain = round_coords(rain)
# syn = round_coords(syn)

In [None]:
# Boolean mask of the pixels/time steps where the harmonised LST has data; it
# restricts rain and the synthetic LST to the same support in the next cell.
mask = np.logical_not(np.isnan(merge))
mask.attrs.pop('nodata')
mask = assign_crs(mask, crs=crs)

In [None]:
# Restrict every dataset to the harmonised-LST support (rain also to its time
# steps), then rename the spatial dims to x/y.
latlon_to_xy = {'latitude': 'y', 'longitude': 'x'}
merge = merge.where(mask).rename(latlon_to_xy)
rain = rain.sel(time=merge.time).where(mask).rename(latlon_to_xy)
syn = syn.where(mask).rename(latlon_to_xy)

In [None]:
# Fraction of time steps with valid harmonised LST at each pixel.
n_times = len(merge.time)
fraction_avail = (~np.isnan(merge)).sum('time') / n_times
# Pixels that never have data are excluded from both the map and the mean.
observed = fraction_avail.where(fraction_avail > 0)

fig, ax = plt.subplots(1, 1, figsize=(5, 4), sharey=True, layout='constrained')

im = observed.rename('').plot(
    ax=ax, vmin=0.1, vmax=0.95, cmap='magma', add_labels=False, add_colorbar=False
)
ctx.add_basemap(ax, source=ctx.providers.CartoDB.Voyager, crs='EPSG:4326',
                attribution='', attribution_size=1)

cb = fig.colorbar(im, ax=ax, shrink=0.75, orientation='vertical', label='Fraction Available');
mean_fraction = round(observed.mean().values.item(), 3)
ax.set_title('Mean Fraction: ' + str(mean_fraction));

## Calculate standardised anomalies

In [None]:
import warnings
# NOTE: this silences *all* warnings for the rest of the session, not just this
# cell — presumably to hide the numeric RuntimeWarnings (all-NaN slices,
# division by a zero std) from the anomaly calculations below.
warnings.simplefilter('ignore')

#standardized anom
def stand_anomalies(ds, group="time.month"):
    """Return standardised (z-score) anomalies of ``ds`` against its own climatology.

    Each value is centred on the mean of its group and divided by that group's
    standard deviation, i.e. ``(x - mean_g) / std_g``. With the default
    ``group="time.month"`` every calendar month is compared only with the same
    month in other years, removing the mean seasonal cycle.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Input to standardise; needs a ``time`` dimension for the default grouping.
    group : str, optional
        Grouping key passed to ``ds.groupby``. Defaults to ``"time.month"``.

    Returns
    -------
    Same type as ``ds``. The group label (e.g. ``month``) is carried as a
    coordinate along ``time``, which callers drop before plotting/exporting.

    Notes
    -----
    A group with zero standard deviation yields ``inf``/``NaN`` values.
    NOTE(review): presumably tolerated downstream — confirm.
    """
    # Build the GroupBy once and reuse it for the climatology and the input,
    # rather than regrouping the (potentially large) array three times.
    grouped = ds.groupby(group)
    clim_mean = grouped.mean()
    clim_std = grouped.std()
    return xr.apply_ufunc(
        lambda x, m, s: (x - m) / s,
        grouped,
        clim_mean,
        clim_std,
    )
    
# Standardised anomalies for each dataset, each against its own monthly climatology.
rain_std_anom, merge_std_anom, syn_std_anom = (
    stand_anomalies(da) for da in (rain, merge, syn)
)

## Rolling mean anomalies

In [None]:
roll = 12  # rolling-window length in time steps (months, for this monthly data)

In [None]:
def _rolling_spatial_mean(da):
    """Return the rolling mean of `da` over `roll` time steps (full windows only), averaged over x and y."""
    return da.rolling(time=roll, min_periods=roll).mean().mean(['x', 'y'])


# Spatially averaged rolling rainfall anomaly; colours the background bars.
rain_df = (
    _rolling_spatial_mean(rain_std_anom.rename('rain'))
    .to_dataframe()
    .drop(['spatial_ref', 'month'], axis=1)
)

plt.style.use('default')
fig, ax = plt.subplots(1, 1, figsize=(14, 4))
ax2 = ax.twinx()  # separate axis holding the rainfall bars

# `month` is the per-time-step group label left by stand_anomalies; drop it first.
# (`drop_vars` replaces the deprecated `DataArray.drop`.)
_rolling_spatial_mean(syn_std_anom.drop_vars('month')).plot(ax=ax, label='Synthetic')
_rolling_spatial_mean(merge_std_anom.drop_vars('month')).plot(ax=ax, label='AVHRR-Adj. (GBM)')

norm = plt.Normalize(-2.5, 2.5)
cmap = mpl.colors.LinearSegmentedColormap.from_list(
    "", ['saddlebrown', 'chocolate', 'white', 'darkturquoise', 'darkcyan'], N=256)

# One full-height bar per time step, coloured by the rainfall anomaly.
bar = ax2.bar(rain_df.index, 1, color=cmap(norm(rain_df['rain'])), width=32)
sm = ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
# `sm` is not attached to any Axes, so the colorbar's parent Axes must be given
# explicitly (newer Matplotlib no longer falls back to gca()). ax2 is the Axes
# that gca() returned here, so the layout is unchanged.
cbar = fig.colorbar(sm, ax=ax2, shrink=0.8, pad=0.01)
cbar.set_label('Rainfall Anomaly (z-score)', labelpad=.5)

# Draw the bar axis behind the line axis; hide the line axis frame/background so
# the bars remain visible.
ax2.set_zorder(ax.get_zorder() - 1)
ax.set_frame_on(False)
ax.axhline(0, c='grey', linestyle='--')

# Reformat y-axis label and tick labels
ax.set_ylabel(model_var + ' Anomaly (z-score)')
ax.set_xlabel('')
ax2.set_ylabel('')
ax2.set_yticks([])
ax2.set_ylim([0, 1])  # bars have height 1, so they span the full axis
ax.margins(x=0)
ax2.margins(x=0)

# Adjust the margins around the plot area
plt.subplots_adjust(left=0.1, right=None, top=None, bottom=0.2, wspace=None, hspace=None)

ax.legend()
# Use `model_var` (rather than a hard-coded 'LST') so the title and output file
# follow the analysis parameter; identical output for model_var == 'LST'.
ax.set_title(f'Australian standardised {model_var} anomalies ({roll}-month rolling mean)');
fig.savefig(
    f"/g/data/os22/chad_tmp/climate-carbon-interactions/results/figs/Australian_{model_var}_{roll}Mrollingmean.png",
    bbox_inches='tight', dpi=300)

### Compare rainfall anomalies with LST anomalies per pixel

In [None]:
year = '1988'

fig, ax = plt.subplots(1, 3, figsize=(14, 4), sharex=True, layout='constrained')

# (anomaly array, colormap, title prefix) for each panel. Titles use `model_var`
# instead of the stale 'NDVI' label (and drop the stray ')' after 'Synthetic').
# NOTE(review): BrBG looks carried over from the NDVI version of this notebook;
# a red/blue diverging map may suit temperature anomalies better.
panels = [
    (rain_std_anom, 'RdBu', 'Rainfall anomaly'),
    (merge_std_anom, 'BrBG', f'AVHRR Adj (GBM) {model_var} anomaly'),
    (syn_std_anom, 'BrBG', f'Synthetic {model_var} anomaly'),
]

for axis, (anom, cmap_name, label) in zip(ax, panels):
    anom_year = anom.sel(time=year)
    # Map: mean anomaly over the months of `year` at each pixel.
    anom_year.mean('time').plot.imshow(vmin=-2.5, vmax=2.5, cmap=cmap_name, ax=axis, add_labels=False)
    # Title value: mean over all pixels and months of `year`, rounded like the
    # fraction-available title.
    axis.set_title(f'{label} {year} {round(anom_year.mean().values.item(), 3)}')
    # Maps only: strip tick labels and axis labels.
    axis.set_yticklabels([])
    axis.set_ylabel('')
    axis.set_xlabel('')
    axis.set_xticklabels([])

In [None]:
# Rainfall anomaly maps for `year`, one panel per time step.
rain_std_anom.sel(time=year).plot.imshow(
    col='time', col_wrap=6, cmap='RdBu', vmin=-2.5, vmax=2.5,
);

In [None]:
# Harmonised (LGBM) anomaly maps for `year`, one panel per time step.
merge_std_anom.sel(time=year).plot.imshow(
    col='time', col_wrap=6, cmap='BrBG', vmin=-2.5, vmax=2.5,
);

In [None]:
# Synthetic anomaly maps for `year`, one panel per time step.
syn_std_anom.sel(time=year).plot.imshow(
    col='time', col_wrap=6, cmap='BrBG', vmin=-2.5, vmax=2.5,
);

In [None]:
# Raw (non-standardised) difference between harmonised and synthetic values for
# `year`, one panel per time step, with robust colour limits.
diff_year = merge.sel(time=year) - syn.sel(time=year)
diff_year.plot.imshow(col='time', col_wrap=6, cmap='BrBG', robust=True);