# Per pixel phenology across Australia

This is very compute-heavy, so we only return summaries (long-term averages and/or trends), but it works and gives robust phenometrics for Australia.

Use a large local Dask cluster; we recommend the `normalsr` queue with `104 CPUs, 496 GiB`. Looping through the 8 tiles takes about 10 hours.

In [None]:
%matplotlib inline

import os
import sys
import warnings

import scipy
import numpy as np
import xarray as xr
import pandas as pd
from scipy import stats

import dask
import dask.array
from dask import delayed

import seaborn as sb
import contextily as ctx
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from odc.geo.xr import assign_crs

import sys
sys.path.append('/g/data/os22/chad_tmp/Aus_phenology/src')
from phenology_pixel import xr_phenometrics, phenology_trends, _mean

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import round_coords

%load_ext autoreload
%autoreload 2

## Dask cluster

Local or Dynamic?

Dynamic clusters can be fickle, so stick with a local cluster for now.

https://github.com/NCI900-Training-Organisation/Distributed-Dask-Cluster-on-Gadi

In [None]:
# Make the AusEFlux helper module importable, then start the local cluster.
utils_dir = '/g/data/os22/chad_tmp/AusEFlux/src/'
sys.path.append(utils_dir)
from _utils import start_local_dask

# 102 single-threaded workers, sized for the recommended 104-CPU node.
# NOTE(review): whether memory_limit is a total or a per-worker limit depends on
# start_local_dask (dask's LocalCluster treats it as per worker) -- confirm.
start_local_dask(
    n_workers=102,
    threads_per_worker=1,
    memory_limit='440GiB',
)

In [None]:
# from dask.distributed import Client
# from dask_jobqueue import PBSCluster

# cpus=52
# mem='240GB'
# extra = ['-q normalsr',
#          '-P w97', 
#          '-l ncpus='+str(cpus), 
#          '-l mem='+mem,
#         '-l storage=gdata/os22+gdata/w97'
#         ]
# setup_commands = ["module load python3/3.10.0", "source /g/data/os22/chad_tmp/AusENDVI/env/py310/bin/activate"]

# cluster = PBSCluster(walltime="01:00:00", 
#                      cores=cpus,
#                      processes=cpus,
#                      memory=mem,
#                      shebang='#!/usr/bin/env bash',
#                      job_extra_directives=extra, 
#                      local_directory='/g/data/os22/chad_tmp/Aus_phenology/data', 
#                      job_directives_skip=["select"], 
#                      interface="ib0",
#                      job_script_prologue=setup_commands,
#                     )

# # print(cluster.job_script())
# cluster.scale(jobs=1)
# client = Client(cluster)
# client

# client.shutdown()
# cluster.close()

## Open data

In [None]:
ndvi_path = '/g/data/os22/chad_tmp/Aus_phenology/data/NDVI/NDVI_smooth_AusENDVI-clim_MCD43A4.nc'
# Open the smoothed NDVI file lazily and keep only the 'NDVI' variable.
ds = xr.open_dataset(ndvi_path)['NDVI']

# Small spatial subsets, handy for testing:
# ds = ds.isel(latitude=slice(200,352), longitude=slice(50,302))
# ds = ds.isel(latitude=slice(400,425), longitude=slice(100,125))

### Split data into tiles

Running all of Australia at once takes too long: >500,000 pixels × >14,000 time steps makes the Dask graph huge.

In [None]:
# Function to split into spatial tiles
def split_spatial_tiles(data_array, lat_dim='latitude', lon_dim='longitude', n_lat=2, n_lon=4):
    """Split an array into an ``n_lat`` x ``n_lon`` grid of spatial tiles.

    Tiles are returned in row-major order: the ``n_lon`` tiles of the first
    band along ``lat_dim`` (index 0), then those of the next band, and so on.
    When a dimension's length is not divisible by the number of tiles, the
    last tile along that dimension absorbs the remainder, so every pixel
    lands in exactly one tile (previously the remainder was silently dropped).

    Parameters
    ----------
    data_array : xarray.DataArray or xarray.Dataset
        Any object exposing ``.sizes`` and ``.isel`` with dict indexers.
    lat_dim, lon_dim : str
        Names of the dimensions to split along.
    n_lat, n_lon : int
        Number of tiles along ``lat_dim`` and ``lon_dim``.

    Returns
    -------
    list
        ``n_lat * n_lon`` tiles, each produced by ``data_array.isel``.

    Raises
    ------
    ValueError
        If a tile count is < 1 or exceeds the dimension length (which would
        yield empty tiles).
    """
    def _slices(size, n, dim):
        # One slice per tile along a single dimension.
        if n < 1 or n > size:
            raise ValueError(f"cannot split {dim!r} of length {size} into {n} tiles")
        step = size // n
        # The last slice is open-ended so the `size % n` leftover pixels are kept.
        return [slice(k * step, (k + 1) * step if k < n - 1 else None) for k in range(n)]

    lat_slices = _slices(data_array.sizes[lat_dim], n_lat, lat_dim)
    lon_slices = _slices(data_array.sizes[lon_dim], n_lon, lon_dim)

    return [
        data_array.isel({lat_dim: lat_slice, lon_dim: lon_slice})
        for lat_slice in lat_slices
        for lon_slice in lon_slices
    ]


# Split data into spatial tiles (2 latitude bands x 4 longitude bands)
n_lat, n_lon = 2, 4
tiles = split_spatial_tiles(ds, n_lat=n_lat, n_lon=n_lon)

# Verify no overlaps or missing pixels by comparing coordinate values only.
# The previous check ran xr.combine_by_coords(tiles) twice, concatenating
# every tile's data each time. Tiles are in row-major order, so the first
# tile of each row gives the latitude bands and the first row gives the
# longitude bands.
lat_cover = np.concatenate([tiles[i * n_lon].latitude.values for i in range(n_lat)])
lon_cover = np.concatenate([tiles[j].longitude.values for j in range(n_lon)])
assert np.array_equal(lat_cover, ds.latitude.values), 'latitude tiles overlap or miss pixels'
assert np.array_equal(lon_cover, ds.longitude.values), 'longitude tiles overlap or miss pixels'

# Create a named dictionary of tiles.
# NOTE(review): the N/S names assume the first latitude band is the northern
# one (latitude stored in descending order) -- confirm for this dataset.
tile_names = ['NW', 'NNW', 'NNE', 'NE',
              'SW', 'SSW', 'SSE', 'SE']
assert len(tile_names) == len(tiles), 'need exactly one name per tile'
tiles_dict = dict(zip(tile_names, tiles))

# Visualise the tiles: mean of the first 20 time steps, titled by tile name
fig, axes = plt.subplots(n_lat, n_lon, figsize=(10, 8))
for (name, tile), ax in zip(tiles_dict.items(), axes.ravel()):
    tile.isel(time=slice(0, 20)).mean('time').plot(ax=ax, add_colorbar=False, add_labels=False)
    ax.set_title(name)

## Per pixel phenometrics with dask.delayed

Loop through the eight tiles and compute the per-pixel phenometrics, their long-term averages, and their trends.

The tiles can be combined thereafter to give continental per-pixel phenology.

https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask

In [None]:
# Inputs/outputs that are identical for every tile
pheno_tiles_dir = '/g/data/os22/chad_tmp/Aus_phenology/results/pheno_tiles'
# Template returned for masked pixels. It was created from one of the outputs:
# xr.full_like(results[0], fill_value=-99, dtype='float32')
template_path = '/g/data/os22/chad_tmp/Aus_phenology/data/template.nc'
trend_vars = ['POS', 'vPOS', 'TOS', 'vTOS', 'AOS', 'SOS',
              'vSOS', 'EOS', 'vEOS', 'LOS', 'IOS', 'ROG', 'ROS']

# Silence pandas FutureWarnings (catch_warnings could not suppress them).
# This modifies the global warnings filter.
warnings.filterwarnings("ignore", category=FutureWarning, module="pandas")

for n, d in tiles_dict.items():

    mean_path = f'{pheno_tiles_dir}/aus_mean_phenology_perpixel_{n}.nc'
    trends_path = f'{pheno_tiles_dir}/aus_trends_phenology_perpixel_{n}.nc'

    # The trends file is written last, so its existence means the tile is done.
    if os.path.exists(trends_path):
        continue

    print('Working on tile: ' + n)

    ### --------Handle NaNs---------
    # xarray's quadratic interpolation fails on NaNs, so every NaN must be
    # removed before interpolating to daily.

    # Drop the trailing all-NaN time steps (from S-G smoothing): the spatial
    # mean skips NaNs, so it is NaN only when every pixel is NaN.
    times_to_keep = d.mean(['latitude', 'longitude']).dropna(dim='time', how='any').time
    d = d.sel(time=times_to_keep)

    # Pixels with NaNs in >= 10 % of time steps; these are skipped later and
    # receive the -99 template instead of phenometrics.
    nan_mask = np.isnan(d).sum('time') >= len(d.time) / 10

    # Fill the mostly-NaN pixels with a fill value
    d = xr.where(nan_mask, -99, d)

    # Interpolate away any remaining NaNs
    d = d.interpolate_na(dim='time', method='cubic', fill_value="extrapolate")

    # Now we can finally interpolate to daily
    d = d.resample(time='1D').interpolate(kind='quadratic').astype('float32')

    ### --------Calculate the phenometrics on each pixel--------
    # Stack the spatial dims so each pixel is one index along 'spatial'.
    # (The pixel count is read from .sizes; the old `.values.shape` copied
    # the entire stacked cube just to get its shape.)
    Y = d.stack(spatial=('latitude', 'longitude')).transpose('time', 'spatial')
    n_pixels = Y.sizes['spatial']

    # Per-pixel boolean flag in the same ('latitude', 'longitude') stacking
    # order as Y. Indexing it is O(1); the old `i in np.where(...)[0]`
    # scanned the whole index array for every pixel.
    is_masked = nan_mask.stack(spatial=('latitude', 'longitude')).values

    # load_dataset reads the small template into memory and closes the file
    # (open_dataset left a handle open on every tile).
    ds_template = xr.load_dataset(template_path)

    results = []
    for i in range(n_pixels):
        data = Y.isel(spatial=i)

        if is_masked[i]:
            # Mostly-NaN pixel: the -99 template relabelled with this pixel's coords
            xx = ds_template.copy()
            xx['latitude'] = [data.latitude.values.item()]
            xx['longitude'] = [data.longitude.values.item()]
        else:
            xx = xr_phenometrics(data,
                                 rolling=90,
                                 distance=90,
                                 prominence='auto',
                                 plateau_size=10,
                                 amplitude=0.20
                                 )

        # Either the phenometrics or the -99 template
        results.append(xx)

    # Bring results into memory; this takes a long time.
    # NOTE(review): presumably xr_phenometrics returns dask.delayed objects -- confirm.
    results = dask.persist(results)[0]

    ### ----Summarise phenology with a long-term average------------
    p_average = [_mean(x) for x in results]
    p_average = dask.compute(p_average)[0]
    p_average = xr.combine_by_coords(p_average)

    # Remove masked pixels that carry the -99 fill value
    p_average = p_average.where(p_average > -99).astype('float32')
    p_average = p_average.where(~np.isnan(p_average.vPOS))  # and again for the n_seasons layer
    p_average = assign_crs(p_average, crs='EPSG:4326')  # add geobox
    p_average.to_netcdf(mean_path)

    ### ----Find the trends in phenology--------------
    p_trends = [phenology_trends(x, trend_vars) for x in results]
    p_trends = dask.compute(p_trends)[0]
    p_trends = xr.combine_by_coords(p_trends)

    # Mask the same pixels as the average
    p_trends = p_trends.where(~np.isnan(p_average.vPOS)).astype('float32')

    # Assign crs and export (written last; marks the tile as complete)
    p_trends = assign_crs(p_trends, crs='EPSG:4326')
    p_trends.to_netcdf(trends_path)

## Join tiles together

In [None]:
tiles_path = '/g/data/os22/chad_tmp/Aus_phenology/results/pheno_tiles/'

# Open every per-tile trends NetCDF. combine_by_coords orders them by
# coordinate; sorting only makes the file listing deterministic.
trend_tiles = [xr.open_dataset(os.path.join(tiles_path, fname))
               for fname in sorted(os.listdir(tiles_path))
               if 'trends' in fname and fname.endswith('.nc')]

p_trends = xr.combine_by_coords(trend_tiles)
p_trends = assign_crs(p_trends, crs='EPSG:4326')

# Strip the per-variable 'grid_mapping' attribute. pop() with a default
# avoids a KeyError for variables that lack it.
for var in p_trends.data_vars:
    p_trends[var].attrs.pop('grid_mapping', None)

In [None]:
# Open every per-tile long-term-average NetCDF from tiles_path (defined in
# the previous cell). Sorting only makes the file listing deterministic.
mean_tiles = [xr.open_dataset(os.path.join(tiles_path, fname))
              for fname in sorted(os.listdir(tiles_path))
              if 'mean' in fname and fname.endswith('.nc')]

p_average = xr.combine_by_coords(mean_tiles)
p_average = assign_crs(p_average, crs='EPSG:4326')

# Strip the per-variable 'grid_mapping' attribute. pop() with a default
# avoids a KeyError for variables that lack it.
for var in p_average.data_vars:
    p_average[var].attrs.pop('grid_mapping', None)

In [None]:
# p_average.n_seasons.plot(cmap='RdYlBu', vmin=34, vmax=45)

## Mask urban, water, irrigated regions

In [None]:
# Build boolean "keep" masks (True = keep pixel, False = mask it out) so the
# later `&` combinations are logical ANDs. Previously crops/water were 0/1
# integers, and `~` on an integer array is a bitwise NOT (0 -> -1, 1 -> -2).

# Irrigated croplands (class 2) -> False
crops = xr.open_dataset('/g/data/os22/chad_tmp/Aus_phenology/data/croplands_5km.nc')['croplands']
crops = round_coords(crops != 2)

# Urban pixels -> False. astype(bool) keeps `~` a logical NOT even if the
# mask is stored as 0/1 integers.
# NOTE(review): assumes nonzero/True in the file means urban -- confirm.
urban = xr.open_dataarray('/g/data/os22/chad_tmp/AusEFlux/data/urban_mask_5km.nc').rename({'y':'latitude','x':'longitude'})
urban = ~urban.astype(bool)

# Inland water (NVIS v6 class 24) -> False
water = xr.open_dataarray('/g/data/os22/chad_tmp/Aus_phenology/data/NVISv6_5km.nc')
water = water != 24

In [None]:
# Urban and inland-water pixels are excluded from both products. Irrigated
# croplands are excluded only from the trends, because the long-term
# average is still fine for irrigated areas.
mask_average = urban & water
mask_trends = mask_average & crops

In [None]:
# Apply the masks; pixels where the mask is falsy become NaN
p_trends, p_average = p_trends.where(mask_trends), p_average.where(mask_average)

In [None]:
# p_trends['vPOS_slope'].odc.explore(
#             tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#            attr = 'Esri',
#            name = 'Esri Satellite'
# )

## Export

In [None]:
# Write the continental (joined and masked) products to disk
for out_path, product in [
    ('/g/data/os22/chad_tmp/Aus_phenology/results/aus_mean_phenology_perpixel.nc', p_average),
    ('/g/data/os22/chad_tmp/Aus_phenology/results/aus_trends_phenology_perpixel.nc', p_trends),
]:
    product.to_netcdf(out_path)