# Compare CSIRO above ground biomass plots with AusEFlux

Alternatively, compare with Fluxcom-RS, which has a spatial resolution of ~8 km. The Fluxcom data used here cover NEE only.

All data, outputs, and code are on Google Drive.

## Import packages

In [None]:
import os
import sys
import folium
import odc.geo.xr
import numpy as np
import xarray as xr
import pandas as pd
import seaborn as sb
import geopandas as gpd
import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs

## Analysis Parameters

Variables to adjust depending on analysis required

In [None]:
# Flux variable to analyse: 'NEE', 'ER', or 'GPP'
var = 'NEE'
# AusEFlux release version
suffix = 'v1.1'
# Source dataset: 'AusEFlux' or 'Fluxcom'
product = 'AusEFlux'

# Monthly date range to extract from AusEFlux (or Fluxcom).
# Note: Fluxcom data stop at 2015.
start_date, end_date = '2003-01', '2021-12'

# Input and output locations
auseflux_flux_loc = f'/g/data/os22/chad_tmp/NEE_modelling/results/predictions/AusEFlux_{var}_2003_2022_1km_quantiles_{suffix}.nc'
fluxcom_flux_loc = '/g/data/os22/chad_tmp/NEE_modelling/data/FLUXCOM/NEE_rs.nc'
input_loc = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/'
output_loc = '/g/data/os22/chad_tmp/climate-carbon-interactions/results/hannah_csiro_plots/'

## Import data

### Import gridded flux data

In [None]:
# Load the flux grid as a DataArray tagged with EPSG:4326 so the odc-geo
# accessors (.odc.geobox, .odc.add_to) work on it. AusEFlux provides the
# median estimate (var + '_median'); Fluxcom provides var directly.
if product == 'AusEFlux':
    ds = assign_crs(xr.open_dataset(auseflux_flux_loc), crs='epsg:4326')[var+'_median']
elif product == 'Fluxcom':
    ds = assign_crs(xr.open_dataset(fluxcom_flux_loc), crs='epsg:4326')[var]
else:
    # Fail fast rather than leaving `ds` undefined (or stale from a previous
    # run) and producing confusing errors in later cells.
    raise ValueError(f"Unknown product {product!r}; expected 'AusEFlux' or 'Fluxcom'")
print(ds)

## Import CSIRO plots

And convert to geopandas to allow for exploring interactively on a map

In [None]:
#open as pandas dataframe
sites = pd.read_csv(input_loc+'CSIRO_plot_locations.csv')

#convert to geodataframe
gdf_sites = gpd.GeoDataFrame(
    sites, geometry=gpd.points_from_xy(sites.Long, sites.Lat), crs="EPSG:4326"
)

## Exploratory Interactive plots

Uncomment if interested

In [None]:
#plot the site data over a basemap

# gdf_sites.explore(column='Plot number',
#                   attr='Esri',
#                  tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}')

In [None]:
#plot the first timestep of AusEFlux on a basemap

# # # Create folium Map (ipyleaflet is also supported)
# m = folium.Map(control_scale = True)

# # Plot each sample image with different colormap
# ds.isel(time=0).odc.add_to(m, cmap='RdBu_r', robust=True)


# # Zoom map to Australia
# m.fit_bounds(ds.isel(time=0).odc.map_bounds())

# tile = folium.TileLayer(
#         tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#         attr = 'Esri',
#         name = 'Esri Satellite',
#         overlay = True,
#         control = True
#        ).add_to(m)

# folium.LayerControl().add_to(m)
# display(m)

## Extract pixel values over plot sites

Two sites, `'Olive Creek'` and `'Downfall'`, are adjacent to permanent waterbodies and are therefore masked out in AusEFlux.

In [None]:
# Extract the nearest-pixel flux time series at every CSIRO plot.
# The per-site variable is named `site_ts` (not `da`) so re-running this
# cell cannot clobber the `dask.array as da` import used by the trend code.
td = []
for _, site in sites.iterrows():
    # indexing spatiotemporal values at CSIRO sites
    idx = dict(latitude=site['Lat'], longitude=site['Long'])

    # grab nearest pixel, restricted to the analysis period
    site_ts = ds.sel(idx, method='nearest').sel(time=slice(start_date, end_date))

    # convert to dataframe and drop the scalar coordinate columns
    site_ts = (site_ts.rename(var)
               .to_dataframe()
               .drop(['longitude', 'latitude', 'spatial_ref'], axis=1))

    # Fluxcom values are per day: scale by the actual number of days in each
    # month (rather than a flat 30) to get per-month values
    if product == 'Fluxcom':
        site_ts[var] = site_ts[var] * np.asarray(site_ts.index.days_in_month)

    # add site id
    site_ts['plot_number'] = site['Plot number']
    site_ts['plot_name'] = site['Plot name']
    td.append(site_ts)

# Drop rows with no flux value: this removes the two sites beside water
# (masked out in AusEFlux), so they do not show up as zeros in annual sums.
ts = pd.concat(td).dropna(subset=[var])
ts.head()

## Save full timeseries to csv

In [None]:
# Persist the full per-plot monthly time series
csv_path = f'{output_loc}{var}_{product}_CSIRO_plots.csv'
ts.to_csv(csv_path)

## Generate some plots

### Box plots of annual mean flux

In [None]:
# Annual mean flux per plot, indexed by (plot_number, year-end time).
# Only the flux column is aggregated. Depending on the pandas version and the
# dtype of plot_number, the grouping column could otherwise end up inside the
# aggregation (a TypeError for text, or a duplicate column that breaks a
# later reset_index()).
df_annual_mean = ts.groupby('plot_number')[[var]].resample('Y').mean()

In [None]:
sb.set(font_scale=1.5)

# Box plot of the annual mean flux at each CSIRO plot
fig, ax = plt.subplots(1, 1, figsize=(20, 5))
sb.boxplot(x='plot_number',
           y=var,
           data=df_annual_mean.reset_index(),
           ax=ax, palette='Spectral')
ax.yaxis.grid(True)  # show the horizontal gridlines
ax.xaxis.grid(True)  # show the vertical gridlines
ax.set_ylabel('')
ax.set_xlabel('')
if var == 'NEE':
    # zero reference line for the net flux
    ax.axhline(0, c='grey', linestyle='--')
# flux per square metre per month
fig.supylabel(var + ' (gC m⁻² mon⁻¹)', fontsize=16)
ax.set_title(product + ' Annual Mean ' + var + ' Fluxes at CSIRO Sites')
fig.tight_layout()

fig.savefig(output_loc + product + '_annual_mean_boxplot_' + var + '.png',
            bbox_inches='tight', dpi=300)

### Annual Cumulative NEE flux boxplots

Only run when plotting NEE.

To convert to total flux (rather than flux per m²), we multiply the fluxes by the area of a pixel (AusEFlux pixels are ~1 km or ~5 km depending on the dataset; Fluxcom-RS pixels are ~8 km). To get the pixel area in m², we first reproject the geobox to the Australian Albers equal-area projection (EPSG:3577).

In [None]:
if var == 'NEE':
    # Pixel area in m²: reproject the flux grid's geobox to the Australian
    # Albers equal-area CRS (EPSG:3577) so its resolution is in metres.
    grid = ds.odc.geobox.to_crs('EPSG:3577')
    area_per_pixel = grid.resolution.x ** 2

    # Annual sums of the monthly fluxes per plot. Only the flux column is
    # summed, which keeps the grouping key out of the aggregation.
    # min_count=1 leaves plot-years with no valid data as NaN instead of a
    # misleading 0 (e.g. pixels masked out over water).
    df_annual_sum = ts.groupby('plot_number')[[var]].resample('Y').sum(min_count=1)

    # (gC m⁻² yr⁻¹) * m² = gC per pixel per year; * 1e-6 converts g -> Mg
    df_annual_sum = df_annual_sum * area_per_pixel * 1e-6

    # Box plot of the annual total flux per pixel at each plot
    fig, ax = plt.subplots(1, 1, figsize=(20, 5))
    sb.boxplot(x='plot_number',
               y=var,
               data=df_annual_sum.reset_index(),
               ax=ax, palette='Spectral')
    ax.yaxis.grid(True)
    ax.xaxis.grid(True)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.axhline(0, c='grey', linestyle='--')
    fig.supylabel(var + ' (MgC year⁻¹)', fontsize=16)
    ax.set_title(product + ' Annual ' + var + ' Total Flux at CSIRO Sites')
    fig.tight_layout()
    fig.savefig(output_loc + product + '_annual_sum_boxplot_' + var + '.png',
                bbox_inches='tight', dpi=300)

## Tropical Eddy Covariance Flux Towers

Extract from TERN THREDDS and summarise

Fletcherview has less than a year of data (2022 only), so it is commented out.

### Open EC data

In [None]:
# OzFlux L6 monthly files for the tropical towers, read remotely from
# TERN's THREDDS (OPeNDAP) server
robson_creek_loc = 'https://dap.tern.org.au/thredds/dodsC/ecosystem_process/ozflux/RobsonCreek/2022_v2/L6/default/RobsonCreek_L6_20130801_20220816_Monthly.nc'
cape_tribulation_loc = 'https://dap.tern.org.au/thredds/dodsC/ecosystem_process/ozflux/CapeTribulation/2022_v2/L6/default/CapeTribulation_L6_20100101_20181102_Monthly.nc'
cowbay_loc = 'https://dap.tern.org.au/thredds/dodsC/ecosystem_process/ozflux/CowBay/2022_v2/L6/default/CowBay_L6_20090101_20220816_Monthly.nc'
# fletcher = 'https://dap.tern.org.au/thredds/dodsC/ecosystem_process/ozflux/Fletcherview/2022_v2/L6/default/FletcherviewTropicalRangeland_L6_20220122_20220712_Monthly.nc'

# Pull the SOLO-gap-filled variable from each tower, in the same order
robson_creek, cape_tribulation, cowbay = (
    xr.open_dataset(tower_loc)[var + '_SOLO']
    for tower_loc in (robson_creek_loc, cape_tribulation_loc, cowbay_loc)
)


### Process EC sites into annual means

Only calculate annual means for years with all 12 months of data.

In [None]:
# Annual mean flux per tower, keeping only calendar years with all 12
# monthly values present.
arrs = []
for flux, name in zip([robson_creek, cape_tribulation, cowbay], ['RobsCreek', 'CapeTrib', 'Cowbay']):
    # Count valid months per year with the SAME resample bins as the mean,
    # so the completeness mask aligns by time label rather than by position
    # (a groupby-year count skips years with no timestamps and can shift).
    n_valid = flux.resample(time='Y').count()
    annual = flux.resample(time='Y').mean().where(n_valid == 12)
    annual = annual.to_dataframe()
    annual['plot_name'] = name
    arrs.append(annual.dropna())

df = pd.concat(arrs)

### Plot annual mean fluxes at EC towers 

In [None]:
sb.set(font_scale=1)

# Box plot of the complete-year annual means at each eddy covariance tower
fig, ax = plt.subplots(1, 1, figsize=(5, 4))
sb.boxplot(x='plot_name',
           y=var + '_SOLO',
           data=df,
           ax=ax,
           palette='Spectral')
ax.yaxis.grid(True)
ax.xaxis.grid(True)
ax.set_ylabel('')
ax.set_xlabel('')
if var == 'NEE':
    # zero reference line for the net flux
    ax.axhline(0, c='grey', linestyle='--')
# flux per square metre per month
fig.supylabel(var + ' (gC m⁻² mon⁻¹)', fontsize=16)
ax.set_title('Eddy Covariance Annual Mean ' + var + ' Fluxes')
fig.tight_layout()
fig.savefig(output_loc + 'EC_annual_mean_boxplot_' + var + '.png',
            bbox_inches='tight', dpi=300)

## Calculate annual means & trends over the entire tropical forest areas in northern Qld 

How to do this?
* First delineate the tropical forests areas in that region:
    1. Grab 'tropics' biome
    2. Limit to Northern Qld (~Townsville)
    3. Threshold kNDVI to find thick, green forests on coastal fringe?
    4. Visually validate with interactive plot
* Calculate zonal annual mean fluxes
* Calculate trends in NEE
* Calculate Trends in NDVI (use Sami's product) 

In [None]:
import contextily as cx

import sys
sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.spatial import xr_rasterize

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _collect_prediction_data import round_coords
from dask_utils import start_local_dask


In [None]:
# Start a local Dask cluster for the chunked 1 km reads that follow, using
# the project helper from AusEFlux's dask_utils. mem_safety_margin='2Gb'
# presumably reserves ~2 GB of memory headroom -- NOTE(review): confirm in
# dask_utils.start_local_dask.
client = start_local_dask(mem_safety_margin='2Gb')
client  # last expression in the cell, so Jupyter displays the client summary

In [None]:
# The tropical-forest trend section is NEE-specific (its plot labels say NEE),
# so pin the variable and AusEFlux version regardless of earlier settings.
var, suffix = 'NEE', 'v1.1'  # flux variable, AusEFlux release version

In [None]:
# Alternative NDVI sources (kept for reference):
# ndvi = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/MCD43_AVHRR_NDVI_hybrid_EasternOzWoody.nc')
# ndvi = ndvi['ndvi_mcd_pred'].rename({'x':'longitude', 'y':'latitude'})
# # ndvi = round_coords(ndvi)
# ndvi = assign_crs(ndvi, crs='4326')

# ndvi = xr.open_dataarray('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/NDVI_5km_monthly_1982_2022.nc')

# Read both 1 km grids in spatial chunks so Dask processes them lazily
spatial_chunks = dict(latitude=500, longitude=500)

# Monthly 1 km kNDVI, limited to the 2003-2021 analysis period
ndvi = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/1km/kNDVI_1km_monthly_2002_2022.nc',
                         chunks=spatial_chunks).sel(time=slice('2003', '2021'))

#locations of files
auseflux_flux_loc = '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/AusEFlux_'+var+'_2003_2022_1km_quantiles_'+suffix+'.nc'
ds = assign_crs(xr.open_dataset(auseflux_flux_loc, chunks=spatial_chunks), crs='epsg:4326')[var+'_median']
ds = ds.sel(time=slice('2003', '2021'))

# Limit both grids to the tropics, roughly north of Townsville.
# slice(0, -19.3) assumes latitude is stored north-to-south (descending)
# -- TODO confirm if the input grids change.
tropics = dict(latitude=slice(0, -19.3), longitude=slice(142, 147.5))
ndvi = ndvi.sel(tropics)
ds = ds.sel(tropics)

# Per-pixel minimum kNDVI over 2003-2021 (loaded into memory)
ndvi_min = ndvi.min('time').compute()

# Pixels whose minimum kNDVI stays above the threshold are treated as
# tropical forest. tf_mask is boolean, the dtype xarray's .where() expects
# for a condition. tf_min keeps the minimum kNDVI inside the mask and is
# NaN outside it.
forest_threshold = 0.375
tf_mask = ndvi_min > forest_threshold
tf_min = ndvi_min.where(tf_mask)

## Linear trend (2003-2021)

In [None]:
import dask.array as da
from dask.delayed import delayed
from scipy import stats

def _calc_slope(y):
    """Fit an ordinary least-squares line to *y* against its position index.

    Non-finite entries (NaN/inf) are skipped, but each kept sample retains
    its original position 0..len(y)-1 as its x value, so gaps keep their
    spacing. Returns the scipy ``LinregressResult``
    (slope, intercept, rvalue, pvalue, stderr).
    """
    positions = np.arange(len(y))
    finite = np.isfinite(y)
    return stats.linregress(positions[finite], y[finite])

# regression function definition
def regression(y):
    """Apply _calc_slope to every 1-D series along y's 'time' dimension.

    Returns the dask array produced by dask.array.apply_along_axis, in which
    the regression statistics replace the time axis.
    """
    time_axis = y.get_axis_num('time')
    return da.apply_along_axis(_calc_slope, time_axis, y)

def linregress(ds):
    """Fit a straight line through time at every pixel of ``ds``.

    Each pixel's series is regressed with ``_calc_slope`` against the integer
    step index 0..n-1, so ``slope`` is in units of ``ds`` per time step (per
    year when ``ds`` holds annual means).

    Parameters
    ----------
    ds : xarray.DataArray
        Assumed to have dims ordered (time, latitude, longitude). The output
        shape is built from ``ds.shape[1:]`` with the 5 statistics on axis 0,
        which only lines up when time is the first axis -- TODO confirm for
        new inputs.

    Returns
    -------
    xarray.Dataset
        Variables ``slope``, ``intercept``, ``r_value``, ``p_value`` and
        ``std_err`` on (latitude, longitude). Pixels that are NaN at every
        time step are NaN in every output.
    """

    # fill pixels that are all-NaNs
    # (presumably so _calc_slope never receives an empty series after it
    # drops non-finite values; these pixels are re-masked at the end)
    allnans = ds.isnull().all('time')
    ds = ds.where(~allnans, other=0)

    # regression analysis
    delayed_objs = delayed(regression)(ds).persist()

    # transforms dask.delayed to dask.array
    results = da.from_delayed(delayed_objs, shape=(5, ds.shape[1:][0], ds.shape[1:][1]), dtype=np.float32)
    # NOTE(review): regression() itself returns a dask array (from
    # da.apply_along_axis), so the first compute() presumably yields that
    # inner dask array and the second evaluates it to numpy.
    results = results.compute()
    results = results.compute() #need this twice haven't figured out why

    # statistical variables definition (not used below; records the order of
    # the 5 entries along axis 0 of `results`)
    variables = ['slope', 'intercept', 'r_value', 'p_value', 'std_err']

    # coordinate definition
    coords = {'latitude': ds.latitude, 'longitude': ds.longitude}

    # output xarray.Dataset definition
    ds_out = xr.Dataset(
        data_vars=dict(slope=(["latitude", "longitude"], results[0]),
                       intercept=(["latitude", "longitude"], results[1]),
                       r_value=(["latitude", "longitude"], results[2]),
                       p_value=(["latitude", "longitude"], results[3]),
                       std_err=(["latitude", "longitude"], results[4]),
                      ),
        coords = coords)

    #remask all-NaN pixel
    return ds_out.where(~allnans)

In [None]:
# Per-pixel linear trends through the annual-mean kNDVI and NEE series
ndvi_trend = linregress(ndvi.resample(time='1Y').mean()).compute()
ds_trend = linregress(ds.resample(time='1Y').mean()).compute()

In [None]:
# Spatial-mean kNDVI time series over the tropical-forest mask.
# (The NEE equivalent below is disabled; note it would overwrite `ts`.)
# ts = ds.where(tf_mask).mean(['latitude', 'longitude']).compute()
ts_1 = (
    ndvi.where(tf_mask)
        .mean(dim=['latitude', 'longitude'])
        .compute()
)

In [None]:
# Restrict both trend fields to tropical-forest pixels
ndvi_trend_tf, ds_trend_tf = (trend.where(tf_mask) for trend in (ndvi_trend, ds_trend))

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 12))

# NOTE(review): Stamen tiles are now served by Stadia Maps; recent
# contextily/xyzservices releases may no longer expose cx.providers.Stamen
# (or may require an API key). Swap to another provider if this fails.

# Left panel: kNDVI trend over tropical-forest pixels
im = ndvi_trend_tf.sel(longitude=slice(142, 147.5)).slope.plot(
    ax=ax[0],
    cmap='BrBG',
    vmin=-0.005, vmax=0.005,
    add_colorbar=False,
    add_labels=False,
)
# Attach each colorbar to its own panel explicitly instead of relying on
# pyplot's "current axes".
fig.colorbar(im, ax=ax[0], shrink=0.5, label='ΔkNDVI / yr')
cx.add_basemap(ax[0], crs='epsg:4326',
               # source=cx.providers.OpenStreetMap.Mapnik,
               # source=cx.providers.Stamen.Terrain,
               # source=cx.providers.Stamen.TonerLite,
               source=cx.providers.Stamen.TerrainBackground,
               # source=cx.providers.CartoDB.Voyager,
               # source=cx.providers.CartoDB.Positron,
               # source=cx.providers.OpenTopoMap,
               attribution_size=1
               )
ax[0].set_title('kNDVI Trend 2003-2021')

# Right panel: AusEFlux NEE trend over tropical-forest pixels
im1 = ds_trend_tf.slope.plot(
    ax=ax[1],
    cmap='BrBG_r',
    vmin=-0.5, vmax=0.5,
    add_colorbar=False,
    add_labels=False,
)
# slope of annual-mean monthly NEE per year
fig.colorbar(im1, ax=ax[1], shrink=0.5, label='ΔNEE gC m⁻² mon⁻¹ / yr')
cx.add_basemap(ax[1], crs='epsg:4326',
               source=cx.providers.Stamen.Terrain,
               attribution_size=1
               )
ax[1].set_title('AusEFlux NEE Trend 2003-2021')
fig.tight_layout()
fig.savefig(output_loc + 'trends_tropical_forests.png',
            bbox_inches='tight', dpi=300);

In [None]:
# # # Create folium Map (ipyleaflet is also supported)
# m = folium.Map(control_scale = True)

# # Plot each sample image with different colormap
# tf_min.odc.add_to(m)
# # tf_std.odc.add_to(m)


# # Zoom map to Australia
# m.fit_bounds(tf_min.odc.map_bounds())

# tile = folium.TileLayer(
#         tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#         attr = 'Esri',
#         name = 'Esri Satellite',
#         overlay = True,
#         control = True
#        ).add_to(m)

# folium.LayerControl().add_to(m)
# display(m)