# Predicting fluxes on grid

To do:
 
* Implement confidence intervals using `forestci` https://github.com/scikit-learn-contrib/forest-confidence-interval


In [None]:
import sys
import xarray as xr
import numpy as np
import pandas as pd
from joblib import load
from matplotlib import pyplot as plt
from odc.geo.geobox import zoom_out
from odc.algo import xr_reproject
from datacube.utils.dask import start_local_dask

# from dask.distributed import Client,Scheduler
# from dask_jobqueue import SLURMCluster

sys.path.append('/g/data/os22/chad_tmp/NEE_modelling/')
from collect_prediction_data import collect_prediction_data, round_coords

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints

In [None]:
# Alternative (disabled): run the workers on a SLURM-managed dask cluster.
# This needs the dask.distributed / dask_jobqueue imports that are commented out
# in the import cell.
# cluster = SLURMCluster(processes=2, cores=2, memory="47GB", walltime='02:00:00')
# client = Client(cluster)
# cluster.scale(cores=18)

# Start a local dask cluster with datacube's helper and keep the handle it
# returns so it can be inspected in the next cell.
client = start_local_dask(mem_safety_margin='2Gb')

In [None]:
# Show the dask client summary as the cell output.
client

## Analysis Parameters

In [None]:
# Target variable; also used to name the output array and the animation file.
var = 'NEE'

# File name of the NetCDF written to the predictions directory.
results_name = 'NEE_2003_2021_5km_LGBM_lower.nc'

# Serialized model that is loaded with joblib.
model_path = '/g/data/os22/chad_tmp/NEE_modelling/results/models/AUS_NEE_LGBM_model_lower.joblib'

# Missing-data mask opened together with the predictor data.
mask_path = '/g/data/os22/chad_tmp/NEE_modelling/data/1km/mask_1km_monthly_2003_2021.nc'

# First and last year handed to collect_prediction_data.
t1 = '2003'
t2 = '2021'

# When True, predictors and mask are resampled onto a grid zoomed out by 5.
rescale = False

## Open model

In [None]:
# Load the fitted model from disk and set its `n_jobs` parameter to 1.
# NOTE(review): presumably single-threaded so that prediction does not compete
# with the dask workers for cores - confirm.
# joblib.load unpickles the file, so `model_path` must point to a trusted file.
model = load(model_path).set_params(n_jobs=1)

In [None]:
# Show the loaded model as the cell output.
model

## Open predictor data

In [None]:
# Collect the predictor dataset for the requested period (quiet, no export).
data = collect_prediction_data(time_start=t1, time_end=t2, verbose=False, export=False)

# Open the missing-data mask in chunks of 750 x 750 pixels, one time-step each.
mask_chunks = {'x': 750, 'y': 750, 'time': 1}
mask = xr.open_dataarray(mask_path, chunks=mask_chunks)

# Show the predictor dataset as the cell output.
data

In [None]:
# mask = data[['vpd', 'SOC', 'NDWI', 'LST', 'tree_cover']].to_array().isnull().any('variable')
# mask.compute().to_netcdf('/g/data/os22/chad_tmp/NEE_modelling/data/1km/mask_1km_monthly_2003_2021.nc')

## Optionally rescale datasets to 5km

In [None]:
# Optionally resample the predictors and the mask onto a 5x coarser grid.
if rescale:
    # Target grid: the predictors' geobox zoomed out by a factor of 5.
    gbox_5km = zoom_out(data.odc.geobox, 5)
    # Renaming the spatial dims to 'x'/'y' helps with predict_xr.
    xy_names = {'latitude': 'y', 'longitude': 'x'}

    # Predictors: record NaN as the nodata value, average onto the coarse grid,
    # then make sure the coords aren't too precise and rename the dims.
    data.attrs['nodata'] = np.nan
    data = xr_reproject(data, geobox=gbox_5km.compat, resampling='average')
    data = round_coords(data).rename(xy_names)

    # Mask: same steps, but resampled with the modal value.
    mask = xr_reproject(mask, geobox=gbox_5km.compat, resampling='mode')
    mask = round_coords(mask).rename(xy_names)

    # Only the mask is rechunked; the matching rechunk of `data` is disabled.
    # data = data.chunk(chunks=dict(x=1000, y=1000, time=1))
    mask = mask.chunk(chunks={'x': 1000, 'y': 1000, 'time': 1})

In [None]:
# data = data.compute()
# mask = mask.compute()
# data.to_netcdf('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/data_5km.nc')
# mask.to_netcdf('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/mask_5km.nc')

In [None]:
# Load the 5 km predictors and mask that were previously exported to disk
# (see the commented-out export cell above).
# NOTE: this replaces `data` and `mask` from the earlier cells regardless of
# the value of `rescale`.
data = xr.open_dataset('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/data_5km.nc')
mask = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/mask_5km.nc')

### Check training and prediction variable order

In [None]:
# Column headers of the training-variables file.
variables_csv = '/g/data/os22/chad_tmp/NEE_modelling/results/variables.txt'
header = list(pd.read_csv(variables_csv))

# Ignore the last column and strip the 3-character suffix from each name.
train_vars = [name[:-3] for name in header[:-1]]

# Subset the predictors to the training variables.
data = data[train_vars]

# Stop if the prediction variables are not in the training order.
# NOTE(review): this runs after subsetting by the same list, so it presumably
# can only fail if xarray returns the variables in a different order - confirm.
if train_vars != list(data.data_vars):
    raise ValueError('Variables dont match')
print('All good')

### Predict each time-step separately

- TO DO: fix timesteps that come back from `predict_xr`

In [None]:
import warnings

# Silence library warnings so they do not swamp the progress counter.
warnings.filterwarnings("ignore")

# One float32 prediction array per time-step, in loop order.
results = []
n_steps = len(data.time)

# Predict every time-step on its own, starting from the first one.
# NOTE(review): an earlier comment said to start from 3 because the first
# time-steps lack rainfall lag values, but the loop has always started at 0;
# change the range start if those steps should be skipped.
for i in range(n_steps):
    # Progress counter rewritten in place on a single line.
    print(f" {i + 1:03}/{n_steps:03}\r", end="")

    # Predictors for this time-step only.
    step = data.isel(time=i)

    # Run predict_xr inside HiddenPrints so the progress line stays clean.
    with HiddenPrints():
        predicted = predict_xr(model,
                               step,
                               proba=False,
                               clean=True,
                               ).compute()

    # Keep predictions only where the mask for this time-step is False.
    predicted = predicted.Predictions.where(~mask.isel(time=i).compute())
    # Re-attach the time stamp of the predictors to the prediction.
    predicted['time'] = step.time.values
    results.append(predicted.astype('float32'))

In [None]:
# Stack the per-time-step predictions along 'time', order them by time,
# name the array after the target variable and store it as float32.
ds = (
    xr.concat(results, dim='time')
    .sortby('time')
    .rename(var)
    .astype('float32')
)

# Show the combined array as the cell output.
ds

## Mask urban areas using landcover dataset

In [None]:
# First time-step of the landcover grid, used to find pixels that carry a
# landcover value.
mask1 = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/1km/Landcover_1km_monthly_2002_2021.nc').isel(time=0)
# 1 where landcover has a value, 0 where it is NaN.
mask1 = (~np.isnan(mask1)).astype(np.int8)

# Bring the landcover mask onto the prediction grid: the rescaled geobox when
# `rescale` is on, otherwise the geobox of the predictors themselves.
target_geobox = gbox_5km if rescale else data.odc.geobox
mask1 = xr_reproject(mask1, geobox=target_geobox.compat, resampling='mode')
mask1 = round_coords(mask1)
mask1 = mask1.rename({'latitude':'y', 'longitude':'x'})

# Re-apply the missing-data mask, then blank out pixels without a landcover
# value. Previously `mask1` was built but never applied, so this cell only
# repeated the missing-data mask.
# NOTE(review): assumes the pixels to exclude (urban areas, per the section
# heading) are NaN in the landcover file - confirm.
ds = ds.where(~mask).where(mask1 == 1).astype('float32')


### Save results

In [None]:
# Write the predictions to the predictions directory as NetCDF.
out_path = '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/' + results_name
ds.to_netcdf(out_path)


## Animate result for fun

In [None]:
import xarray as xr
from IPython.display import Image
import matplotlib.pyplot as plt

import sys
sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools')
from dea_tools.plotting import xr_animation


In [None]:
# var='NEE'
# ds = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/results/predictions/NEE_2003_2021_5km_LGBM.nc')

In [None]:
# Output location of the animated GIF.
path = '/g/data/os22/chad_tmp/NEE_modelling/results/figs/'+var+'_mystudy_LGBM_5km.gif'

# Colour mapping passed to imshow for every frame
# (alternatives noted previously: 'vmin': 0, 'vmax': 50, 'cmap': 'RdBu_r').
frame_style = {'cmap': 'viridis', 'vmin': 0, 'vmax': 150}

# Render one frame per time-step and write the animation to `path`.
xr_animation(
    ds.to_dataset(),
    bands=[var],
    output_path=path,
    width_pixels=600,
    interval=200,
    show_date='%b %Y',
    show_text=var + ' gC/m2/month',
    show_colorbar=True,
    colorbar_kwargs={'colors': 'black'},
    imshow_kwargs=frame_style,
    # show_gdf=poly_gdf,
    # gdf_kwargs={'edgecolor': 'grey', 'linewidth':0.5},
)

# Close the current matplotlib figure, then show the saved GIF inline.
plt.close()
Image(path, embed=True)