# Predicting fluxes on grid

To do:
 
* Implement confidence intervals using `forestci` https://github.com/scikit-learn-contrib/forest-confidence-interval


In [None]:
import sys
import xarray as xr
import numpy as np
import pandas as pd
from joblib import load
from matplotlib import pyplot as plt
from datacube.utils.dask import start_local_dask
# from dask.distributed import Client,Scheduler
# from dask_jobqueue import SLURMCluster

sys.path.append('/g/data/os22/chad_tmp/NEE_modelling/')
from collect_prediction_data import collect_prediction_data

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints

In [None]:
# cluster = SLURMCluster(processes=2, cores=2, memory="47GB", walltime='01:00:00')
# client = Client(cluster)
# cluster.scale(cores=18)

In [None]:
# Start a local dask cluster via datacube's helper and keep the returned
# handle; `mem_safety_margin` is forwarded unchanged to `start_local_dask`.
client = start_local_dask(mem_safety_margin='2Gb')
# Bare expression as the last line of the cell so the notebook displays it.
client

## Analysis Parameters

In [None]:
# Target variable; later cells use it to name the predicted DataArray.
var = 'NEE'

# First and last year of the prediction period, as strings. They are passed
# as `time_start` / `time_end` when the predictor data is collected.
t1, t2 = '2003', '2021'

# Output filename, derived from the variable and period so that changing
# `var`, `t1` or `t2` cannot leave a stale, mismatched name behind.
results_name = f'{var}_{t1}_{t2}.nc'

# Location of the model file that is loaded with joblib.
model_path = '/g/data/os22/chad_tmp/NEE_modelling/results/models/AUS_NEE_RF_model.joblib'

# Optional pre-exported predictor stack (only used by a commented-out loader).
#data_path = '/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/prediction_data_2003_2021.nc'

## Open model

In [None]:
# Read the saved model from `model_path`, then set its `n_jobs` parameter
# to 1 and keep whatever `set_params` returns as the working model.
model = load(model_path)
model = model.set_params(n_jobs=1)

## Open predictor data

In [None]:
# data = xr.open_dataset(data_path).set_coords('spatial_ref')

# Collect the predictor data between `t1` and `t2` with the project helper,
# with its `verbose` and `export` options both switched off.
data = collect_prediction_data(
    time_start=t1, time_end=t2, verbose=False, export=False
)

# Last expression of the cell: lets the notebook render the dataset summary.
data

### Predict each time-step separately

- TO DO: fix timesteps that come back from `predict_xr`

In [None]:
# mask = data.to_array().isnull().any('variable')
# mask.isel(time=20).plot.imshow()

# data = data.drop('PFT')

### Check training and prediction variable order

In [None]:
# Show the prediction variables first, then a blank spacer, then the column
# names read from the training variables file, so both orderings can be
# compared by eye.
print(list(data.data_vars))
print('\n')
training_columns = pd.read_csv('/g/data/os22/chad_tmp/NEE_modelling/results/variables.txt')
print(list(training_columns))

In [None]:
# Predict every time-step on its own and collect the per-step results.
#
# NOTE(review): an earlier comment here said to start from index 3 because the
# first time-steps lack rainfall lag values, but the loop starts at 0. That
# behaviour is kept -- confirm whether the first steps should be skipped.
results = []
n_steps = len(data.time)

for i in range(n_steps):
    # Progress counter; the trailing carriage return with end="" makes each
    # update overwrite the previous one.
    print(" {:03}/{:03}\r".format(i + 1, n_steps), end="")

    # Predictor slice for this time-step only.
    step = data.isel(time=i)

    # HiddenPrints (from dea_tools) wraps the call -- presumably to silence
    # the messages `predict_xr` prints on every invocation.
    with HiddenPrints():
        predicted = predict_xr(model,
                               step,
                               proba=False,
                               clean=True,
                               ).compute()

    predicted = predicted.Predictions#.where(~mask.isel(time=i))
    # Overwrite the time coordinate with this step's own timestamp so the
    # results can later be concatenated along `time`.
    predicted['time'] = step.time.values
    results.append(predicted.astype('float32'))

In [None]:
# Join the per-step predictions along `time`, order them by time, name the
# array after the target variable and store it as float32.
ds = xr.concat(results, dim='time')
ds = ds.sortby('time')
ds = ds.rename(var)
ds = ds.astype('float32')

### Mask water and urban areas using landcover dataset

In [None]:
# Take the first time slice of the 1 km land-cover product as the mask source.
mask = xr.open_dataarray(
    '/g/data/os22/chad_tmp/NEE_modelling/data/1km/Landcover_1km_monthly_2002_2021.nc'
).isel(time=0)

# Use the x/y dimension names instead of longitude/latitude.
mask = mask.rename({'latitude': 'y', 'longitude': 'x'})

# Cast each spatial coordinate to float32 and round it to 4 decimals.
# NOTE(review): presumably this makes the coordinates line up with the
# prediction grid when the mask is applied -- confirm.
for dim in ('x', 'y'):
    mask[dim] = mask[dim].astype('float32')
    mask[dim] = np.array([round(value, 4) for value in mask[dim].values])

# True wherever the land-cover value is not NaN.
mask = ~np.isnan(mask)

In [None]:
# Keep predictions only where the mask is True, then store as float32.
ds = ds.where(mask)
ds = ds.astype('float32')

In [None]:
# Write the predictions as float32 to the predictions folder, using the
# filename held in `results_name`.
output_file = '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/'+results_name
ds.astype('float32').to_netcdf(output_file)

## Animate result for fun

In [3]:
import xarray as xr
from IPython.display import Image
import matplotlib.pyplot as plt

import sys
sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools')
from dea_tools.plotting import xr_animation


In [8]:
# Variable to animate, and the saved prediction file it is read from.
var = 'NEE'
predictions_file = '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/NEE_2003_2021.nc'
ds = xr.open_dataarray(predictions_file)

In [None]:
# Destination of the rendered GIF.
path = '/g/data/os22/chad_tmp/NEE_modelling/results/figs/'+var+'_mystudy.gif'

# Keyword options handed to `xr_animation`, gathered in one place.
animation_options = dict(
    bands=[var],
    show_date='%b %Y',
    width_pixels=600,
    output_path=path,
    show_colorbar=True,
    colorbar_kwargs={'colors': 'black'},
    # show_gdf=poly_gdf,
    interval=200,
    show_text=var+' gC/m2/month',
    # gdf_kwargs={'edgecolor': 'grey', 'linewidth':0.5},
    imshow_kwargs={'cmap': 'RdBu_r', 'vmin': -50, 'vmax': 50},
)

# `xr_animation` is given a Dataset, so convert the DataArray first.
xr_animation(ds.to_dataset(), **animation_options)

# Plot animation: close the current matplotlib figure, then display the GIF
# file embedded in the notebook (must stay the last expression of the cell).
plt.close()
Image(path, embed=True)