# Map annual and seasonal trends in climate variables

* P:PET (moisture index: precipitation / potential evapotranspiration)
* VPD (vapour pressure deficit)
* rainfall

In [None]:
%matplotlib inline

import sys
import dask
import warnings
import odc.geo.xr
import numpy as np
import os
from odc.geo.xr import assign_crs
import xarray as xr
import pandas as pd
# import seaborn as sb
import contextily as ctx
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import round_coords

# sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
# from _prediction import allNaN_arg

from xarrayMannKendall import Mann_Kendall_test

In [None]:
# Spin up a local dask cluster through the AusEFlux helper module
# (called with mem_safety_margin='2Gb').
import sys

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import start_local_dask

client = start_local_dask(mem_safety_margin='2Gb')
client  # last expression of the cell: Jupyter displays the client summary

## Open data


In [None]:
# Monthly climate inputs on the 5 km grid.
vpd = xr.open_dataset('/g/data/os22/chad_tmp/AusENDVI/data/5km/vpd_5km_monthly_1982_2022.nc')['vpd']

# GLEAM potential evaporation, restricted to the 1982-2022 study period.
# Earlier alternatives that were tried:
#   /g/data/os22/chad_tmp/Aus_phenology/data/PET_1981_2022.nc  (variable 'PET')
#   /g/data/os22/chad_tmp/Aus_phenology/data/MI_1982_2022.nc   (pre-computed moisture index)
pet = (
    xr.open_dataset('/g/data/os22/chad_tmp/Aus_phenology/data/PET_GLEAM_1982_2022.nc')['PET']
    .sel(time=slice('1982', '2022'))
)

# Monthly rainfall; the file name says 1981-2022, so trim to the same study period.
# Alternative daily source (would need monthly resampling first):
#   /g/data/os22/chad_tmp/Aus_phenology/data/rainfall_CHIRPS_daily_5km_1981_2023.nc
rain = (
    xr.open_dataset('/g/data/os22/chad_tmp/AusENDVI/data/5km/rain_5km_monthly_1981_2022.nc')['rain']
    .sel(time=slice('1982', '2022'))
)

# pet['time'] = [pd.offsets.MonthBegin().rollback(t)+pd.Timedelta(14, 'd') for t in pet['time'].values]

In [None]:
# Moisture index P:PET = monthly rainfall / monthly potential evaporation.
# Earlier pipeline, when rain came from daily CHIRPS:
# rain = rain.resample(time='MS').sum().compute()
# rain['time'] = rain['time'] + pd.Timedelta(14, 'd') #Make time the middle of the month

# xarray arithmetic silently inner-joins on coordinates. Any difference in
# timestamps (e.g. month-start vs mid-month) or in floating-point lat/lon
# values would quietly drop months or pixels from `mi`. Fail loudly instead;
# if this raises, align the grids first (e.g. with round_coords).
xr.align(rain, pet, join='exact')
mi = rain/pet

In [None]:
# Bound P:PET to the range [0, 5]. Ratios above 5 are capped at 5 (not removed)
# and negative ratios are set to 0. NaNs pass through unchanged.
mi = mi.clip(min=0, max=5).rename('P:PET')

## Mask to study region

In [None]:
lin_or_circ = 'circular'  # selects which variant of the per-pixel phenology results to load

# Per-pixel phenology results (means and trends), tagged with a lat/lon CRS.
p_average = xr.open_dataset(
    f'/g/data/os22/chad_tmp/Aus_phenology/results/mean_phenology_perpixel_{lin_or_circ}_final.nc'
)
p_average = assign_crs(p_average, crs='EPSG:4326')
p_trends = xr.open_dataset(
    f'/g/data/os22/chad_tmp/Aus_phenology/results/trends_phenology_perpixel_{lin_or_circ}_final.nc'
)
p_trends = assign_crs(p_trends, crs='EPSG:4326')

# True wherever the mean 'POS' value is defined (not NaN).
nan_mask = ~np.isnan(p_average['POS'])

# Flag pixels whose detected seasons-per-year ratio is far from one:
# <= 0.9 ("non-seasonal") or >= 1.1 ("extra-seasonal"). 1 = exclude, 0 = keep.
season_per_year = p_average['n_seasons'] / p_average['n_years']
non_seasonal = xr.where(season_per_year <= 0.90, 1, 0)
extra_seasonal = xr.where(season_per_year >= 1.1, 1, 0)
seasonality_mask = non_seasonal | extra_seasonal

In [None]:
# Blank out (set to NaN) pixels flagged by the seasonality mask, then pixels
# without a valid mean phenology, for every climate variable.
vpd, rain, mi = (
    da.where(seasonality_mask != 1).where(nan_mask)
    for da in (vpd, rain, mi)
)

## Annual trends

In [None]:
# Annual aggregates of the monthly series.
# 'YS' bins by calendar year and labels each bin with 1 January of that same
# year. The previous 'YE' + label='left' produced the same bins but stamped
# them with 31 December of the *previous* year (e.g. 1982 totals were labelled
# 1981-12-31), which is off by one for any later date-based selection.
a_rain = rain.resample(time='YS').sum()
# sum() skips NaN, so fully masked pixels come out as 0 in every year;
# turn pixels whose annual totals are never non-zero back into NaN.
mask = xr.where(a_rain.max('time')==0, 0, 1)
a_rain = a_rain.where(mask)

a_vpd = vpd.resample(time='YS').mean()
a_mi = mi.resample(time='YS').mean()

In [None]:
# Per-pixel Theil-Sen slope with Mann-Kendall significance (alpha = 0.05) for
# each annual series. Results are computed eagerly and keyed by variable name.
prod = [a_rain, a_vpd, a_mi]
names = ['rainfall', 'VPD', 'Aridity']  # 'Aridity' holds the P:PET moisture index
annual_res = {}
for name, ds in zip(names, prod):
    print(name)
    annual_res[name] = Mann_Kendall_test(
        ds,
        alpha=0.05,
        method='theilslopes',
        coords_name={'time': 'time', 'x': 'longitude', 'y': 'latitude'},
    ).compute()

In [None]:
# Three-panel map of annual trends (rainfall, VPD, P:PET) over a basemap.
# Exact-zero trends are hidden; hatching marks pixels where the Mann-Kendall
# test flags the trend as significant. Each panel gets its own inset colour bar.
fig, ax = plt.subplots(1, 3, figsize=(12, 4), sharey=True, sharex=True, layout='constrained')
cmaps = ['BrBG', 'RdBu_r', 'PuOr']
# The unit symbol for hectopascal is 'hPa' (lower-case SI prefix), not 'HPa'.
labels = 'Rainfall mm yr\u207B\u00B9', 'VPD hPa yr\u207B\u00B9', 'P:PET yr\u207B\u00B9'

for s, a, cmap, l in zip(names, ax.reshape(-1), cmaps, labels):
    trend = annual_res[s].trend
    if s == 'VPD':
        # fixed symmetric colour range for VPD; the other panels use robust limits
        im = trend.where(trend != 0).plot(ax=a, cmap=cmap, vmin=-0.1, vmax=0.1, add_colorbar=False)
    else:
        im = trend.where(trend != 0).plot(ax=a, cmap=cmap, robust=True, add_colorbar=False)
    ctx.add_basemap(a, source=ctx.providers.CartoDB.VoyagerNoLabels, crs='EPSG:4326', attribution='', attribution_size=1)
    # alpha=0 draws no fill, so only the hatching (significance stippling) shows
    xr.plot.contourf(trend.where(annual_res[s].signif), ax=a, alpha=0, hatches=['....'], add_colorbar=False)

    axins = inset_axes(a, width="55%", height="5%", loc="lower left", borderpad=2)
    cbar3 = fig.colorbar(im, cax=axins, orientation='horizontal')
    cbar3.ax.set_title(l, fontsize=10)

    a.set_yticklabels([])
    a.set_ylabel('')
    a.set_xlabel('')
    a.set_xticklabels([])

fig.savefig("/g/data/os22/chad_tmp/Aus_phenology/results/figs/climate_annual_trends.png", bbox_inches='tight', dpi=300)

## Seasonal trends

In [None]:
# Aggregate the monthly series to meteorological seasons (DJF, MAM, JJA, SON).
# 'QS-DEC' makes quarters that START on 1 Dec / 1 Mar / 1 Jun / 1 Sep and labels
# each with its first day, so `time.season` of the label matches the months
# inside the bin. The previous 'QE-DEC' + label='left' grouped Jan-Mar, Apr-Jun,
# Jul-Sep and Oct-Dec but labelled them 31 Dec, 31 Mar, 30 Jun and 30 Sep.
# Every "season" was therefore shifted one month late (e.g. "DJF" held Jan-Mar).
q_rain = rain.resample(time='QS-DEC').sum()
q_vpd = vpd.resample(time='QS-DEC').mean()
q_mi = mi.resample(time='QS-DEC').mean()
# q_wcf = wcf.resample(time='QS-DEC').mean()

# mask zeros from the sum() for rainfall: sum() skips NaN, so fully masked
# pixels come out as 0 rather than NaN
mask = xr.where(q_rain.mean('time')==0, 0, 1)
q_rain = q_rain.where(mask)

# Keep complete seasons only, identically for every variable:
# - the data start in Jan 1982, so the first DJF (label 1981-12-01) lacks Dec 1981;
# - the data end in Dec 2022, so the last DJF (label 2022-12-01) holds only Dec 2022.
# MAM 1982 (label 1982-03-01) through SON 2022 (label 2022-09-01) remain.
# (VPD and P:PET previously started at '1983-03', silently dropping 1982.)
q_rain = q_rain.sel(time=slice('1982-03', '2022-11'))
q_vpd = q_vpd.sel(time=slice('1982-03', '2022-11'))
q_mi = q_mi.sel(time=slice('1982-03', '2022-11'))
# q_wcf = q_wcf.sel(time=slice('1982-03', '2022-11'))

### rain

In [None]:
# Seasonal rainfall trends. Each season is reduced to one value per year, its
# time axis is replaced by the integer year, and the Theil-Sen slope plus
# Mann-Kendall significance (alpha = 0.05) is computed per pixel.
rain_res = {}
for season in ("DJF", "MAM", "JJA", "SON"):
    print(season)
    in_season = q_rain['time.season'] == season
    season_series = q_rain.sel(time=in_season)
    season_series['time'] = season_series.time.dt.year
    rain_res[season] = Mann_Kendall_test(
        season_series,
        alpha=0.05,
        method='theilslopes',
        coords_name={'time': 'time', 'x': 'longitude', 'y': 'latitude'},
    ).compute()


In [None]:
# fig,ax = plt.subplots(1,1, figsize=(8,8))
# im = rain_res['JJA'].trend.where(rain_res['JJA'].trend!=0).plot( vmin=-0.5, vmax=0.5)
# ctx.add_basemap(ax, source=ctx.providers.CartoDB.VoyagerNoLabels, crs='EPSG:4326', attribution='', attribution_size=1)

In [None]:
# Four-panel map of seasonal rainfall trends. For display only, the trend field
# is smoothed with a centred 5x5-pixel moving mean. Pixels whose raw trend is
# exactly zero are hidden, and hatching marks significant non-zero trends.
fig, ax = plt.subplots(1, 4, figsize=(15, 4), sharey=True, layout='constrained')
for s, a in zip(["DJF", "MAM", "JJA", "SON"], ax.reshape(-1)):
    res = rain_res[s]
    nonzero = res.trend != 0
    smoothed = res.trend.rolling(y=5, x=5, center=True, min_periods=1).mean()
    im = smoothed.where(nonzero).plot(ax=a, cmap='BrBG', vmin=-3, vmax=3, add_colorbar=False)
    ctx.add_basemap(a, source=ctx.providers.CartoDB.VoyagerNoLabels, crs='EPSG:4326', attribution='', attribution_size=1)
    # alpha=0: no fill, only the hatching is visible
    xr.plot.contourf(res.trend.where(res.signif).where(nonzero), ax=a, alpha=0, hatches=['....'], add_colorbar=False)
    a.set_title(s)
    a.set(yticklabels=[], ylabel='', xlabel='', xticklabels=[])

cb = fig.colorbar(im, ax=ax, shrink=0.65, orientation='vertical', label='Rainfall trend (mm/yr)')
fig.savefig("/g/data/os22/chad_tmp/Aus_phenology/results/figs/rainfall_seasonal_trends.png", bbox_inches='tight', dpi=300)

### VPD

In [None]:
# Seasonal VPD trends: one value per season per year, indexed by integer year,
# then a per-pixel Theil-Sen slope with Mann-Kendall significance (alpha = 0.05).
vpd_res = {}
for season in ("DJF", "MAM", "JJA", "SON"):
    print(season)
    yearly = q_vpd.sel(time=q_vpd['time.season'] == season)
    yearly['time'] = yearly.time.dt.year
    mk = Mann_Kendall_test(
        yearly,
        alpha=0.05,
        method='theilslopes',
        coords_name={'time': 'time', 'x': 'longitude', 'y': 'latitude'},
    )
    vpd_res[season] = mk.compute()

In [None]:
# Four-panel map of seasonal VPD trends (unsmoothed). Exact-zero trends are
# hidden; hatching marks pixels where the trend is significant.
fig, ax = plt.subplots(1, 4, figsize=(15, 4), sharey=True, layout='constrained')
for s, a in zip(["DJF", "MAM", "JJA", "SON"], ax.reshape(-1)):
    trend = vpd_res[s].trend
    im = trend.where(trend != 0).plot(ax=a, cmap='RdBu_r', vmin=-0.2, vmax=0.2, add_colorbar=False)
    ctx.add_basemap(a, source=ctx.providers.CartoDB.VoyagerNoLabels, crs='EPSG:4326', attribution='', attribution_size=1)
    # alpha=0: no fill, only the significance hatching is visible
    xr.plot.contourf(trend.where(vpd_res[s].signif), ax=a, alpha=0, hatches=['....'], add_colorbar=False)
    a.set_title(s)
    a.set_yticklabels([])
    a.set_ylabel('')
    a.set_xlabel('')
    a.set_xticklabels([])

# Unit symbol for hectopascal is 'hPa' (lower-case SI prefix), not 'HPa'.
cb = fig.colorbar(im, ax=ax, shrink=0.65, orientation='vertical', label='VPD hPa yr\u207B\u00B9')
fig.savefig("/g/data/os22/chad_tmp/Aus_phenology/results/figs/vpd_seasonal_trends.png", bbox_inches='tight', dpi=300)

### Moisture index (P:PET)

In [None]:
# Seasonal P:PET trends: select one season, swap the time axis for integer
# years, and run the per-pixel Theil-Sen / Mann-Kendall test (alpha = 0.05).
mi_res = {}
for season in ("DJF", "MAM", "JJA", "SON"):
    print(season)
    by_year = q_mi.sel(time=q_mi['time.season'] == season)
    by_year['time'] = by_year.time.dt.year
    mi_res[season] = Mann_Kendall_test(
        by_year,
        alpha=0.05,
        method='theilslopes',
        coords_name={'time': 'time', 'x': 'longitude', 'y': 'latitude'},
    ).compute()

In [None]:
# Four-panel map of seasonal P:PET trends. For display, the trend is smoothed
# with a centred 5x5-pixel moving mean and pixels with an exactly zero raw
# trend are hidden. Hatching marks pixels with a significant (raw) trend.
fig, ax = plt.subplots(1, 4, figsize=(15, 4), sharey=True, layout='constrained')
for s, a in zip(["DJF", "MAM", "JJA", "SON"], ax.reshape(-1)):
    res = mi_res[s]
    smoothed = res.trend.rolling(y=5, x=5, center=True, min_periods=1).mean()
    im = smoothed.where(res.trend != 0).plot(ax=a, cmap='PuOr', vmin=-0.007, vmax=0.007, add_colorbar=False)
    ctx.add_basemap(a, source=ctx.providers.CartoDB.VoyagerNoLabels, crs='EPSG:4326', attribution='', attribution_size=1)
    # alpha=0: no fill, only the hatching is visible
    xr.plot.contourf(res.trend.where(res.signif), ax=a, alpha=0, hatches=['....'], add_colorbar=False)
    a.set_title(s)
    a.set(yticklabels=[], ylabel='', xlabel='', xticklabels=[])

cb = fig.colorbar(im, ax=ax, shrink=0.65, orientation='vertical', label='P:PET yr\u207B\u00B9')
fig.savefig("/g/data/os22/chad_tmp/Aus_phenology/results/figs/mi_seasonal_trends.png", bbox_inches='tight', dpi=300)

## Preprocess GLEAM potential evaporation (PET)

In [None]:
# import os
# import xarray as xr
# import pandas as pd
# import numpy as np
# from odc.geo.xr import assign_crs

# import sys
# sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
# from _utils import start_local_dask, round_coords

# client = start_local_dask(mem_safety_margin='2Gb')
# client

In [None]:
# Disabled: regrid GLEAM 'Ep' onto the 5 km rainfall grid, rename it 'PET' and
# shift timestamps to mid-month (month start + 14 days).
# base = '/g/data/os22/chad_tmp/Aus_phenology/data/GLEAM/'
# gbox= xr.open_dataset('/g/data/os22/chad_tmp/AusENDVI/data/5km/rain_5km_monthly_1981_2022.nc').odc.geobox

# files = [f'{base}{i}' for i in os.listdir(base) if i.endswith(".nc")]
# files.sort()

# # #combine annual files into one file
# ds = xr.open_mfdataset(files)
# NOTE(review): slice(-10, -45) only selects data if 'lat' is stored in
# descending order (north to south) — confirm for the GLEAM files.
# ds = ds.sel(lat=slice(-10,-45), lon=slice(111,155))
# ds = ds.rename({'lat':'latitude', 'lon':'longitude'})
# NOTE(review): 'EPSG:4236' differs from the 'EPSG:4326' used everywhere else
# in this notebook; looks like a digit-swap typo — confirm before re-running.
# ds = assign_crs(ds['Ep'], crs='EPSG:4236')
# ds = ds.rename('PET')
# ds.attrs['nodata'] = np.nan
# NOTE(review): .chunk() is given both a positional dict and keyword chunks;
# xarray expects one or the other, so this call is likely to raise. A probable
# intent is .chunk(dict(time=1, longitude=1000, latitude=1000)).
# ds = ds.chunk(dict(time=1), longitude=1000, latitude=1000).odc.reproject(gbox, resampling='bilinear').compute()
# ds = round_coords(ds)
# ds = ds.rename('PET')
# ds['time'] = [pd.offsets.MonthBegin().rollback(t)+pd.Timedelta(14, 'd') for t in ds['time'].values]

In [None]:
# ds.isel(time=12).plot.imshow(robust=True)

In [None]:
# ds.to_netcdf('/g/data/os22/chad_tmp/Aus_phenology/data/PET_GLEAM_1980_2023.nc')