# Per pixel phenology modelling for Australia

This is very compute heavy and we can only return summaries (long term average, and/or trends etc.) as every pixel has a different length of seasons across the 40+ year archive, but it works and gives robust phenometrics for Aus.  Australia is split into eight tiles and each tile is run sequentially.

In this notebook, all aspects of the analysis are run (phenology, trends, partial least squares regressions), but which analysis is run is configurable.

Use a large local dask cluster; the recommended resources are the `normalsr` queue with `104 cpus 496 GiB`. It will take about 10 hours to loop through the 8 tiles.

Some references for optimising processing:
* https://github.com/NCI900-Training-Organisation/Distributed-Dask-Cluster-on-Gadi
* https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask

In [None]:
%matplotlib inline

import os
import sys
import dask
import scipy
import warnings
import dask.array
import numpy as np
import xarray as xr
import pandas as pd
from scipy import stats
from dask import delayed

import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs

import sys
sys.path.append('/g/data/os22/chad_tmp/Aus_phenology/src')
from phenology_pixel import _preprocess, xr_phenometrics, phenology_trends, _mean, regression_attribution, IOS_analysis

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import round_coords

%load_ext autoreload
%autoreload 2

## Dask cluster

Local or Dynamic?

Dynamic can be fickle, so stick with local for now.

In [None]:
# Make the AusEFlux helper module importable, then start a local Dask cluster
# with 100 single-threaded workers.
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import start_local_dask
# NOTE(review): if `memory_limit` is forwarded unchanged to dask's LocalCluster
# it is a *per-worker* limit, so 100 workers x 400GiB would exceed the 496 GiB
# node recommended above -- confirm how start_local_dask interprets this value.
start_local_dask(n_workers=100, threads_per_worker=1, memory_limit='400GiB')

In [None]:
# from dask.distributed import Client
# from dask_jobqueue import PBSCluster

# cpus=13
# mem='60GB'
# extra = ['-q normalsr',
#          '-P w97', 
#          '-l ncpus='+str(cpus), 
#          '-l mem='+mem,
#         '-l storage=gdata/os22+gdata/w97'
#         ]
# setup_commands = ["module load python3/3.10.0", "source /g/data/os22/chad_tmp/AusENDVI/env/py310/bin/activate"]

# cluster = PBSCluster(walltime="01:00:00", 
#                      cores=cpus,
#                      processes=cpus,
#                      memory=mem,
#                      shebang='#!/usr/bin/env bash',
#                      job_extra_directives=extra, 
#                      local_directory='/g/data/os22/chad_tmp/Aus_phenology/data', 
#                      job_directives_skip=["select"], 
#                      interface="ib0",
#                      job_script_prologue=setup_commands,
#                     )

# print(cluster.job_script())
# cluster.scale(jobs=2)
# client = Client(cluster)
# client

# client.shutdown()
# cluster.close()

## Analysis Parameters

Which aspects of the analysis should be run?

In [None]:
# Switches for the optional stages of the per-tile loop further down. A stage
# only runs when its switch is True and its output file is not already present
# in `results_path`.

# summarise the per-pixel phenometrics (written to mean_phenology_perpixel_<tile>.nc)
average = True
# trends in the phenometrics (written to trends_phenology_perpixel_<tile>.nc)
trends = True
# IOS analysis (written to IOS_analysis_perpixel_<tile>.nc)
ios = True
# regression attribution (written to attribution_<regress_var>_<model_type>_perpixel_<tile>.nc)
regression = True
# phenometric handed to regression_attribution as `pheno_var`
regress_var = 'vPOS'
# passed to regression_attribution as `model_type`; also selects template_<model_type>.nc
model_type='delta_slope'
# covariables used as regression features (selected along the template's `feature` dim)
modelling_vars=['co2', 'srad', 'rain', 'tavg', 'vpd']

# directory prefix for every output file; must end with '/' because file names
# are appended directly to it
results_path = '/g/data/os22/chad_tmp/Aus_phenology/results/combined_tiles/'

## Open data

In [None]:
# NDVI variable of the smoothed NDVI file, plus the covariables dataset
# (the 'wcf' variable is dropped).
ds = xr.open_dataset('/g/data/os22/chad_tmp/Aus_phenology/data/NDVI/NDVI_smooth_AusENDVI-clim_MCD43A4.nc')['NDVI']
covariables =  xr.open_dataset('/g/data/os22/chad_tmp/Aus_phenology/data/covars.nc')
covariables = covariables.drop_vars('wcf')

#testing slices
# NOTE(review): this spatial subset is still active, so everything below runs
# on a 152 x 252 pixel window rather than the full extent -- remove or comment
# out these two lines for a continental run. Both objects must be sliced
# identically so the NDVI and covariable tiles stay aligned.
ds = ds.isel(latitude=slice(200,352), longitude=slice(50,302)) 
covariables = covariables.isel(latitude=slice(200,352), longitude=slice(50,302))

## Split data into tiles

Running all of Aus just takes too long, >500,000 pixels * > 14,000 time steps - dask graph is huge

In [None]:
# Function to split into spatial tiles
def split_spatial_tiles(data_array, lat_dim='latitude', lon_dim='longitude', n_lat=2, n_lon=4):
    """
    Split an array into an ``n_lat`` x ``n_lon`` grid of contiguous spatial tiles.

    Tile edges are placed at ``k * length // n`` along each dimension, so every
    index belongs to exactly one tile even when the dimension length is not an
    exact multiple of the number of tiles (tile sizes then differ by at most
    one). When the length divides evenly the result is identical to an
    equal-width split.

    Parameters
    ----------
    data_array : object
        Anything exposing ``.sizes`` (mapping of dimension name to length) and
        ``.isel(dict_of_slices)``, e.g. an xarray DataArray or Dataset.
    lat_dim, lon_dim : str
        Names of the two dimensions to split along.
    n_lat, n_lon : int
        Number of tiles along ``lat_dim`` and ``lon_dim``; both must be >= 1.

    Returns
    -------
    list
        ``n_lat * n_lon`` tiles in row-major order: all ``lon_dim`` tiles of
        the first ``lat_dim`` band, then the next band, and so on.

    Raises
    ------
    ValueError
        If ``n_lat`` or ``n_lon`` is smaller than 1.
    """
    if n_lat < 1 or n_lon < 1:
        raise ValueError(
            f'n_lat and n_lon must both be >= 1, got n_lat={n_lat}, n_lon={n_lon}'
        )

    def _edges(length, n_parts):
        # n_parts + 1 monotonically increasing edges from 0 to `length`;
        # integer arithmetic guarantees the final edge equals `length`, so no
        # trailing rows/columns are dropped.
        return [k * length // n_parts for k in range(n_parts + 1)]

    lat_edges = _edges(data_array.sizes[lat_dim], n_lat)
    lon_edges = _edges(data_array.sizes[lon_dim], n_lon)

    tiles = []
    for lat_lo, lat_hi in zip(lat_edges[:-1], lat_edges[1:]):
        for lon_lo, lon_hi in zip(lon_edges[:-1], lon_edges[1:]):
            tiles.append(
                data_array.isel({
                    lat_dim: slice(lat_lo, lat_hi),
                    lon_dim: slice(lon_lo, lon_hi),
                })
            )

    return tiles

# Carve the NDVI array and the covariables into the same 2 (latitude) x 4
# (longitude) grid so the two tile lists line up one-to-one.
grid_shape = {'n_lat': 2, 'n_lon': 4}
tiles = split_spatial_tiles(ds, **grid_shape)
covars_tiles = split_spatial_tiles(covariables, **grid_shape)

# Sanity check: mosaicking the tiles must reproduce the original coordinates
# exactly, i.e. no overlapping and no missing pixels.
mosaic = xr.combine_by_coords(tiles)
for dim in ('longitude', 'latitude'):
    assert np.sum(mosaic[dim] == ds[dim]) == len(ds[dim])

# Label each tile; the names follow the row-major order of the tile lists
# (first row, then second row).
tile_names = [
    'NW', 'NNW', 'NNE', 'NE',
    'SW', 'SSW', 'SSE', 'SE',
]
tiles_dict = {name: tile for name, tile in zip(tile_names, tiles)}
covars_tiles_dict = {name: tile for name, tile in zip(tile_names, covars_tiles)}

#create a plot to visualise tiles
# fig,axes = plt.subplots(2, 4, figsize=(10,8))
# for t,ax in zip(tiles, axes.ravel()):
#     t.isel(time=range(0,20)).mean('time').plot(ax=ax, add_colorbar=False, add_labels=False)
#     ax.set_title(None);

## Per pixel phenometrics with dask.delayed

Loop through the eight tiles and compute the time series of phenometrics, the average phenometrics, the trends in phenometrics, and the PLS regression modelling.

The tiles can be combined thereafter to have our continental per pixel phenology analysis.

In [None]:
def _tile_mean_phenology(pheno_results):
    """
    Summarise the per-pixel phenometrics of one tile into a single dataset.

    Applies `_mean` to every pixel result, computes them with dask, mosaics the
    pixels with `xr.combine_by_coords`, keeps only values greater than -99,
    casts to float32, masks pixels where `vPOS` is NaN and attaches an
    EPSG:4326 CRS.
    """
    summary = [_mean(x) for x in pheno_results]
    summary = dask.compute(summary)[0]
    summary = xr.combine_by_coords(summary)

    #remove NaN areas that have a fill value
    summary = summary.where(summary > -99).astype('float32')
    summary = summary.where(~np.isnan(summary.vPOS)) #and again for the n_seasons layer
    return assign_crs(summary, crs='EPSG:4326') # add geobox


for (n, d), (_covars_name, dd) in zip(tiles_dict.items(), covars_tiles_dict.items()):

    # Output files for this tile
    mean_file = f'{results_path}mean_phenology_perpixel_{n}.nc'
    trends_file = f'{results_path}trends_phenology_perpixel_{n}.nc'
    ios_file = f'{results_path}IOS_analysis_perpixel_{n}.nc'
    attribution_file = f'{results_path}attribution_{regress_var}_{model_type}_perpixel_{n}.nc'

    # First check if the analysis has already been done. The attribution file
    # is the last one written, so it is used as the "tile finished" marker.
    # NOTE(review): with `regression = False` this marker is never created and
    # the phenometrics are recomputed on every run.
    if os.path.exists(attribution_file):
        continue

    print('Working on tile: '+ n)

    # transform the data and return all the objects we need. This code smooths and
    # interpolates the data, then stacks the pixels into a spatial index
    d, dd, Y, idx_all_nan, nan_mask, shape = _preprocess(d, dd)

    # Open template arrays which we'll use whenever we encounter an all-NaN index.
    # This speeds up the analysis by not running pixels that are empty.
    # Created the template using one of the output results
    # bb = xr.full_like(results[0], fill_value=np.nan, dtype='float32')
    template_path='/g/data/os22/chad_tmp/Aus_phenology/data/templates/'
    phen_template = xr.open_dataset(f'{template_path}template.nc')
    ios_template = xr.open_dataset(f'{template_path}template_IOS.nc')
    regress_template = xr.open_dataset(f'{template_path}template_{model_type}.nc').sel(feature=modelling_vars)

    # Build the lookup once: testing membership against the raw index array
    # would be a linear scan for every pixel.
    empty_pixels = set(np.asarray(idx_all_nan).ravel().tolist())

    #now we start the real processing
    results=[]
    for i in range(shape[1]): #loop through all spatial indexes.

        #select pixel
        data = Y.isel(spatial=i)

        # First, check if spatial index has data. If it is one of the all-NaN
        # indexes, use a copy of the template relabelled with this pixel's
        # coordinates instead of running the phenometrics.
        if i in empty_pixels:
            xx = phen_template.copy() #use our template
            xx['latitude'] = [data.latitude.values.item()] #update coords
            xx['longitude'] = [data.longitude.values.item()]
        else:
            #run the phenometrics
            xx = xr_phenometrics(data,
                              rolling=90,
                              distance=90,
                              prominence='auto',
                              plateau_size=10,
                              amplitude=0.20
                             )

        #append results, either data or the empty template
        results.append(xx)

    # bring into memory.
    results = dask.compute(results)[0]

    # Mean phenology of THIS tile. Reset for every tile so the trends stage can
    # never pick up a mask left over from a previous tile.
    p_average = None

    ### ----Summarise phenology------------
    ## This is one of the slowest steps.
    if average and not os.path.exists(mean_file):
        p_average = _tile_mean_phenology(results)

        #export results
        p_average.to_netcdf(mean_file)

    ## ----Find the trends in phenology--------------
    #now find trends in phenometrics
    if trends and not os.path.exists(trends_file):
        # The trends are masked with the tile's mean vPOS. If the summary was
        # not produced above (stage disabled, or its file already existed from
        # an earlier run) read it back from disk, or build it in memory.
        if p_average is None:
            if os.path.exists(mean_file):
                with xr.open_dataset(mean_file) as saved:
                    p_average = saved.load()
            else:
                p_average = _tile_mean_phenology(results)

        trend_vars = ['POS','vPOS','TOS','vTOS','AOS','SOS','vSOS','EOS',
                    'vEOS','LOS','IOS','ROG','ROS','LOS*vPOS','IOS:(LOS*vPOS)']
        p_trends = [phenology_trends(x, trend_vars) for x in results]
        p_trends = dask.compute(p_trends)[0]
        p_trends = xr.combine_by_coords(p_trends)

        #remove NaNs
        p_trends = p_trends.where(~np.isnan(p_average.vPOS)).astype('float32')

        # assign crs and export
        p_trends = assign_crs(p_trends, crs='EPSG:4326')
        p_trends.to_netcdf(trends_file)

    # ----Partial correlation analysis etc. on IOS -----------------
    if ios and not os.path.exists(ios_file):
        p_parcorr = [IOS_analysis(pheno, ios_template) for pheno in results]
        p_parcorr = dask.compute(p_parcorr)[0]
        p_parcorr = xr.combine_by_coords(p_parcorr).astype('float32')
        p_parcorr.to_netcdf(ios_file)

    # -----PLS regression-------------------------------
    if regression and not os.path.exists(attribution_file):
        p_attribution = []
        for pheno in results:
            lat = pheno.latitude.item()
            lon = pheno.longitude.item()

            fi = regression_attribution(pheno,
                               X=dd.sel(latitude=lat, longitude=lon),
                               template=regress_template,
                               model_type=model_type,
                               pheno_var=regress_var,
                               modelling_vars=modelling_vars,
                              )
            p_attribution.append(fi)

        p_attribution = dask.compute(p_attribution)[0]
        p_attribution = xr.combine_by_coords(p_attribution).astype('float32')
        p_attribution.to_netcdf(attribution_file)

    # #do an extra one----------------------------------------------------------------------------------
    # extra_template = xr.open_dataset(f'{template_path}template_PCMCI.nc').sel(feature=modelling_vars)
    # if os.path.exists(f'{results_path}attribution_{regress_var}_PCMCI_perpixel_{n}.nc'):
    #             pass
    # else:
    #     p_attribution = []
    #     for pheno in results:
    #         lat = pheno.latitude.item()
    #         lon = pheno.longitude.item()

    #         fi = regression_attribution(pheno,
    #                            X=dd.sel(latitude=lat, longitude=lon),
    #                            template=extra_template,
    #                            model_type='PCMCI',
    #                            pheno_var=regress_var,
    #                            modelling_vars=modelling_vars,
    #                           )
    #         p_attribution.append(fi)

    #     p_attribution = dask.compute(p_attribution)[0]
    #     p_attribution = xr.combine_by_coords(p_attribution).astype('float32')
    #     p_attribution.to_netcdf(f'{results_path}attribution_{regress_var}_PCMCI_perpixel_{n}.nc')

    # NOTE(review): this `break` stops after the first tile that still needed
    # processing (the testing plots below rely on `n` being that tile).
    # Remove it to run all tiles in one go.
    break

## testing plots

In [None]:
# Re-open the attribution output of the tile processed by the loop above (`n`).
# NOTE(review): the file name hard-codes 'delta_slope' instead of using
# `model_type`, so it only matches the loop's output while model_type == 'delta_slope'.
p_attribution = xr.open_dataset(f'{results_path}attribution_{regress_var}_delta_slope_perpixel_{n}.nc')

In [None]:
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _prediction import allNaN_arg

In [None]:
# One panel per modelling feature, each showing that feature's `delta_slope` layer.
fig, axes = plt.subplots(1, 5, figsize=(18,4))
for feature, panel in zip(p_attribution.feature.values, axes.ravel()):
    layer = p_attribution.sel(feature=feature).delta_slope
    layer.plot(add_labels=False, ax=panel, cmap='viridis')
    # maps only: drop the coordinate ticks on both axes
    panel.xaxis.set_ticks([])
    panel.yaxis.set_ticks([])
    panel.set_title(feature)

In [None]:
# Map, per pixel, of the `stat='max'` result of allNaN_arg over the absolute
# delta_slope along the `feature` dimension.
# NOTE(review): the colorbar tick labels assume the result holds feature
# positions 0-4 -- confirm that this is what allNaN_arg returns with idx=False.
ss = allNaN_arg(np.abs(p_attribution.delta_slope), dim='feature',stat='max', idx=False)
im = ss.plot(add_colorbar=False, figsize=(5,8), add_labels=False)
# `ss.plot(figsize=...)` draws on a brand-new figure, so ask *that* figure for
# the colorbar instead of whatever `fig` an earlier cell left behind.
cbar = im.axes.figure.colorbar(im, ax=im.axes, ticks=[0,1,2,3,4], orientation='horizontal')
cbar.ax.set_xticklabels(list(p_attribution.feature.values));

In [None]:
# Re-open the IOS analysis output of tile `n`, keeping only the two
# partial-correlation layers that are plotted below.
p_corr = xr.open_dataset(f'{results_path}IOS_analysis_perpixel_{n}.nc')[['vPOS_parcorr', 'LOS_parcorr']]

In [None]:
# Side-by-side maps of the two partial-correlation layers on a shared 0-1 colour scale.
fig, axes = plt.subplots(1, 2, figsize=(7,3))
for name, panel in zip(p_corr.data_vars, axes.ravel()):
    layer = p_corr[name]
    layer.plot(add_labels=False, ax=panel, cmap='viridis', vmin=0, vmax=1)
    # maps only: drop the coordinate ticks on both axes
    panel.xaxis.set_ticks([])
    panel.yaxis.set_ticks([])
    panel.set_title(name)

In [None]:
# Map of the `stat='max'` result of allNaN_arg across the two parcorr
# variables, with the colorbar ticks labelled by variable name.
fig, ax = plt.subplots(1, 1, figsize=(4,6))
stacked = p_corr.to_array()
dominant = allNaN_arg(stacked, dim='variable',stat='max', idx=False)
im = dominant.plot(ax=ax, add_colorbar=False)
cbar = fig.colorbar(im, ticks=[0,1], orientation='horizontal')
tick_labels = list(p_corr.data_vars)
cbar.ax.set_xticklabels(tick_labels);

In [None]:
# p_trends.vPOS_slope.plot()

In [None]:
        # ### --------Handle NaNs---------
        # # Due to issues with xarray quadratic interpolation, we need to remove
        # # every NaN or else the daily interpolation function will fail
        
        # ##remove last ~6 timesteps that are all-NaN (from S-G smoothing).
        # times_to_keep = d.mean(['latitude','longitude']).dropna(dim='time',how='any').time
        # d = d.sel(time=times_to_keep)
        
        # #Find where NaNs are >10 % of data, will use this mask to remove pixels later.
        # #and include any NaNs in the climate data.
        # ndvi_nan_mask = np.isnan(d).sum('time') >= len(d.time) / 10
        # clim_nan_mask = dd[['rain','vpd','tavg','srad']].to_array().isnull().any('variable')
        # clim_nan_mask = (clim_nan_mask.sum('time')>0)
        # nan_mask = (clim_nan_mask | ndvi_nan_mask)

        # d = d.where(~nan_mask)
        # dd = dd.where(~nan_mask)
        # # nan_mask.to_netcdf(f'/g/data/os22/chad_tmp/Aus_phenology/data/ndvi_tiles/nan_mask_{n}.nc')
        
        # #fill the mostly all NaN slices with a fill value
        # d = xr.where(nan_mask, -99, d)
        
        # #interpolate away any remaining NaNs
        # d = d.interpolate_na(dim='time', method='cubic', fill_value="extrapolate")
        
        # #now we can finally interpolate to daily
        # d = d.resample(time='1D').interpolate(kind='quadratic').astype('float32')
        
        # # We also need the shape of the stacked array
        # shape = d.stack(spatial=('latitude', 'longitude')).values.shape
        
        # #stack spatial indexes, this makes it easy to loop through data
        # y_stack = d.stack(spatial=('latitude', 'longitude'))
        # Y = y_stack.transpose('time', 'spatial')
        
        # # x_stack = dd.stack(spatial=('latitude', 'longitude'))
        # # X = x_stack.transpose('time', 'spatial')
        
        # # find spatial indexes where values are mostly NaN (mostly land-sea mask)
        # # This is where the nan_mask we created earlier = True
        # idx_all_nan = np.where(nan_mask.stack(spatial=('latitude', 'longitude'))==True)[0]