# Batch run analysis

Split the input datasets into 8 spatial tiles, then submit each tile to the PBS queue with `qsub`. An interactive workflow for testing is provided further below.

In [None]:
%matplotlib inline

import os
import sys
import warnings
import numpy as np
import xarray as xr
import pandas as pd

import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs


## Open data

In [None]:
# NDVI data: open the AusENDVI file, tag it with EPSG:4326, rename the
# product variable to 'NDVI' and keep only that variable as a DataArray.
ds_path = (
    '/g/data/os22/chad_tmp/AusENDVI/results/publication/'
    'AusENDVI-clim_MCD43A4_gapfilled_1982_2022_0.2.0.nc'
)
ds = (
    assign_crs(xr.open_dataset(ds_path), crs='EPSG:4326')
    .rename({'AusENDVI_clim_MCD43A4': 'NDVI'})['NDVI']
)
# Drop the 'grid_mapping' attribute from the DataArray
# (presumably so it does not interfere with the netCDF export later - TODO confirm).
del ds.attrs['grid_mapping']

# GPP datasets: DIFFUSE, MODIS, AusEFlux
# ds_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/DIFFUSE_GPP_5km_2003_2021.nc'
# ds_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/MODIS_GPP_5km_2002_2021.nc'
# ds_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/AusEFlux_GPP_5km_2003_2023.nc'
# ds_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/PML_GPP_5km_2001_2023.nc'
# ds_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/AusEFlux_GPP_5km_1982_2022_v0.2.nc'
# ds_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/AusEFlux_versions/AusEFlux_GPP_5km_1982_2022_v0.4.nc'
# ds = assign_crs(xr.open_dataset(ds_path)['GPP'], crs='EPSG:4326')
# del ds.attrs['grid_mapping']

# Covariables: open the covariate dataset and discard the variables
# ('wcf', 'smrz') that are not wanted downstream.
covars_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/covars.nc'
covariables = xr.open_dataset(covars_path).drop_vars(['wcf', 'smrz'])

# Soil signal (only needed when the modelled variable is NDVI):
# NDVI of bare soil, tagged with EPSG:4326 and explicitly named 'NDVI'.
ss_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/ndvi_of_baresoil_5km.nc'
soil_ndvi = xr.open_dataset(ss_path)['NDVI']
ss = assign_crs(soil_ndvi, crs='epsg:4326')
ss.name = 'NDVI'

#---------testing slices-----------------------------------------
# ds = ds.isel(latitude=slice(200,352), longitude=slice(50,302)) 
# covariables = covariables.isel(latitude=slice(200,352), longitude=slice(50,302))
# ss = ss.isel(latitude=slice(200,352), longitude=slice(50,302))

## Split data into tiles

<!-- Running all of Aus just takes too long, >500,000 pixels * > 14,000 time steps - dask graph is huge! -->

In [None]:
# Function to split into spatial tiles
def split_spatial_tiles(data_array, lat_dim='latitude', lon_dim='longitude', n_lat=2, n_lon=4):
    """
    Split an array into an ``n_lat`` x ``n_lon`` grid of contiguous tiles.

    Tile edges along each dimension are placed at ``k * size // n`` for
    ``k = 0..n``, so every index belongs to exactly one tile even when the
    dimension length is not an exact multiple of the tile count (tile lengths
    then differ by at most one). When the length divides evenly, the tiles
    are the same equal-width blocks that plain floor division would give.

    Parameters
    ----------
    data_array : xarray.DataArray or xarray.Dataset
        Any object exposing ``.sizes`` (mapping of dimension name to length)
        and ``.isel(indexers)``.
    lat_dim, lon_dim : str
        Names of the two dimensions to split along.
    n_lat, n_lon : int
        Number of tiles along ``lat_dim`` and ``lon_dim`` respectively.

    Returns
    -------
    list
        ``n_lat * n_lon`` tiles in row-major order: the ``lat_dim`` block index
        varies slowest and the ``lon_dim`` block index fastest.

    Raises
    ------
    ValueError
        If ``n_lat`` or ``n_lon`` is smaller than 1.
    """
    if n_lat < 1 or n_lon < 1:
        raise ValueError(
            f"n_lat and n_lon must both be >= 1, got n_lat={n_lat}, n_lon={n_lon}"
        )

    def _edges(size, n_parts):
        # Integer arithmetic only: edges[0] == 0 and edges[-1] == size, so the
        # slices below cover the full dimension with no gaps or overlaps.
        return [k * size // n_parts for k in range(n_parts + 1)]

    lat_edges = _edges(data_array.sizes[lat_dim], n_lat)
    lon_edges = _edges(data_array.sizes[lon_dim], n_lon)

    return [
        data_array.isel({
            lat_dim: slice(lat_edges[i], lat_edges[i + 1]),
            lon_dim: slice(lon_edges[j], lon_edges[j + 1]),
        })
        for i in range(n_lat)
        for j in range(n_lon)
    ]

# Split data into spatial tiles (2 latitude x 4 longitude)
tiles = split_spatial_tiles(ds, n_lat=2, n_lon=4)
covars_tiles = split_spatial_tiles(covariables, n_lat=2, n_lon=4)
ss_tiles = split_spatial_tiles(ss, n_lat=2, n_lon=4)

# Verify no overlaps or missing pixels: stitch the tiles back together once
# (recombining is costly, so the result is reused for both checks) and compare
# the recombined coordinates against the original ones.
recombined = xr.combine_by_coords(tiles)
assert np.sum(recombined.longitude == ds.longitude) == len(ds.longitude), \
    'recombined tiles do not reproduce the original longitude coordinates'
assert np.sum(recombined.latitude == ds.latitude) == len(ds.latitude), \
    'recombined tiles do not reproduce the original latitude coordinates'

# Create named dictionaries: each label is paired with the tile at the same
# position in the lists returned by split_spatial_tiles.
tile_names = [
    'NW', 'NNW', 'NNE', 'NE',
    'SW', 'SSW', 'SSE', 'SE',
]
tiles_dict, covars_tiles_dict, ss_tiles_dict = (
    dict(zip(tile_names, tile_group))
    for tile_group in (tiles, covars_tiles, ss_tiles)
)

# Visualise the tiling: one panel per tile on a 2 x 4 grid, each showing the
# mean over the first 20 time steps, with colourbars and axis labels hidden.
fig, axes = plt.subplots(2, 4, figsize=(10, 8))
for panel, tile in zip(axes.ravel(), tiles):
    preview = tile.isel(time=range(0, 20)).mean('time')
    preview.plot(ax=panel, add_colorbar=False, add_labels=False)
    panel.set_title(None);

## Export

In [None]:
# Write every tile of the main dataset to its own netCDF file, named
# '<variable name>_<tile label>.nc'.
tiles_dir = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/tiles/'
for tile_label, tile in tiles_dict.items():
    tile.to_netcdf(f'{tiles_dir}{ds.name}_{tile_label}.nc')

In [None]:
# Write every covariable tile to its own netCDF file, named 'COVARS_<tile label>.nc'.
covars_template = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/tiles/COVARS_{}.nc'
for tile_label, covars_tile in covars_tiles_dict.items():
    covars_tile.to_netcdf(covars_template.format(tile_label))

In [None]:
# Write every bare-soil NDVI tile to its own netCDF file, named 'SS_<tile label>.nc'.
for k, v in ss_tiles_dict.items():
    # pop() with a default instead of `del`: the attribute is already gone if
    # this cell is re-run in the same session, and `del` would then raise KeyError.
    v.attrs.pop('grid_mapping', None)
    v.to_netcdf(f'/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/tiles/SS_{k}.nc')

## Submit tiles to PBS job queue

In [None]:
import os

In [None]:
import subprocess

# Tile labels to submit, one PBS job per tile. The labels are passed to qsub as
# a plain argument (no shell involved), so they need no extra quoting.
# NOTE: this rebinds `tiles`, which earlier cells use for the list of tile arrays.
tiles = ['NW', 'NNW', 'NNE', 'NE', 'SW', 'SSW', 'SSE', 'SE']
# tiles = ['NNW', 'NNE', 'SE', 'SSW']
# tiles = ['SW']

# The job script is referenced by a path relative to the project root.
os.chdir('/g/data/os22/chad_tmp/Aus_CO2_fertilisation/')

failed = []
for t in tiles:
    # Argument-list form avoids building a shell command by string concatenation.
    result = subprocess.run(
        ['qsub', '-v', f'TILENAME={t}', 'src/run_single_tile.sh'],
        capture_output=True,
        text=True,
    )
    # Show what qsub reported for this tile (stdout, or stderr if stdout is empty).
    print(t, (result.stdout or result.stderr).strip())
    if result.returncode != 0:
        failed.append(t)

# Every tile is attempted first; only then are failed submissions reported,
# so one rejected job does not prevent the others from being queued.
if failed:
    raise RuntimeError(f'qsub submission failed for tiles: {failed}')

In [4]:
!qstat

Job id                 Name             User              Time Use S Queue
---------------------  ---------------- ----------------  -------- - -----
136430877.gadi-pbs     sys-dashboard-s* cb3058            00:22:39 R normalsr-exec   
136447363.gadi-pbs     sys-dashboard-s* cb3058            01:45:41 R normalsr-exec   
136447454.gadi-pbs     run_single_tile* cb3058            02:02:22 R normalsr-exec   
136447455.gadi-pbs     run_single_tile* cb3058            03:09:32 R normalsr-exec   
136447456.gadi-pbs     run_single_tile* cb3058            02:05:47 R normalsr-exec   
136447457.gadi-pbs     run_single_tile* cb3058            04:21:25 R normalsr-exec   
136447458.gadi-pbs     run_single_tile* cb3058            03:27:15 R normalsr-exec   
136447459.gadi-pbs     run_single_tile* cb3058            03:31:16 R normalsr-exec   
136447460.gadi-pbs     run_single_tile* cb3058            02:47:10 R normalsr-exec   
136447461.gadi-pbs     run_single_tile* cb3058            02:11:33 R normals

## Run interactively instead

Useful for testing and debugging a single tile without going through the job queue.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _utils import start_local_dask

sys.path.append('/g/data/os22/chad_tmp/Aus_CO2_fertilisation/src/')
from batch_run_analysis import attribution_etal

### Variables for script

In [None]:
# --- Analysis settings (passed to attribution_etal) ---
# Variable being modelled; the alternative is 'NDVI'.
model_var = 'GPP'  # 'NDVI'
# Variables handed to the attribution routine as `modelling_vars`.
modelling_vars = ['co2', 'srad', 'rain', 'tavg', 'vpd', 'cwd']
# Model types to run; available choices: 'PLS', 'delta_slope', 'ML'.
model_types = ['delta_slope']  # 'PLS', 'delta_slope', 'ML'

# --- Input/output locations (passed to attribution_etal) ---
results_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/results/tiles/testing/'
template_path = '/g/data/os22/chad_tmp/Aus_CO2_fertilisation/data/templates/'

# --- Dask settings (passed to start_local_dask) ---
n_workers = 13
memory_limit = '60GiB'

In [None]:
# Call the project helper with the configured worker count and memory limit,
# using one thread per worker (presumably this starts a local dask cluster -
# see _utils.start_local_dask to confirm).
dask_settings = {
    'n_workers': n_workers,
    'threads_per_worker': 1,
    'memory_limit': memory_limit,
}
start_local_dask(**dask_settings)

In [None]:
tiles = ['NW', 'NNW', 'NNE', 'NE', 'SW', 'SSW', 'SSE', 'SE']

# Keyword arguments that are the same for every tile.
shared_kwargs = {
    'results_path': results_path,
    'template_path': template_path,
    'modelling_vars': modelling_vars,
    'model_var': model_var,
    'model_types': model_types,
}

for t in tiles:
    print(t)
    attribution_etal(n=t, **shared_kwargs)
    # Stop after the first tile: the loop exits unconditionally here, so only
    # tiles[0] is processed. Remove the break to process every tile.
    break

In [None]:
import xarray as xr
import numpy as np
# Open the per-pixel trend results for the NW tile and map the slope
# (robust colour limits, brown-green diverging colormap).
trends_path = f'{results_path}GPP_trends_perpixel_NW.nc'
trends = xr.open_dataset(trends_path)
trends['slope'].plot(robust=True, cmap='BrBG');

In [None]:
# Open the per-pixel delta-slope attribution for the NW tile and map the
# entry for the 'co2' feature.
attr_path = f'{results_path}attribution_delta_slope_perpixel_NW.nc'
attr = xr.open_dataset(attr_path)
co2_delta_slope = attr['delta_slope'].sel(feature='co2')
co2_delta_slope.plot(robust=True)

In [None]:
# Open the per-pixel beta coefficient results for the NW tile.
beta_path = f'{results_path}beta_coefficient_perpixel_NW.nc'
beta = xr.open_dataset(beta_path)

In [None]:
# Map the 'beta_relative' variable with robust colour limits.
beta_relative = beta['beta_relative']
beta_relative.plot(robust=True)