# Predicting fluxes on grid

To do:
* Need to mask NaNs in any variable at any time-step, as the fill-value will be nonsensical

In [None]:
import sys
import xarray as xr
import numpy as np
from joblib import load
from matplotlib import pyplot as plt
from datacube.utils.dask import start_local_dask

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/dea_tools')
from classification import predict_xr, HiddenPrints

In [None]:
# Start a local Dask cluster for this session.
# NOTE(review): mem_safety_margin presumably reserves 2Gb of memory outside
# the Dask workers — confirm against the datacube documentation.
client = start_local_dask(mem_safety_margin='2Gb')
# Last expression in the cell, so the notebook displays the client summary.
client

## Analysis Parameters

In [None]:
# Inputs: the fitted model and the gridded predictor data.
model_path = '/g/data/os22/chad_tmp/NEE_modelling/results/models/AUS_ER_model.joblib'
data_path = '/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/prediction_data_2002-10_2021.nc'

# Outputs: name given to the predicted variable, and the file it is saved to.
var = 'ER'
results_name = 'ER_2003_2021.nc'

## Open model

In [None]:
# Load the fitted estimator from disk. joblib.load unpickles the file, so
# only point model_path at files from a trusted source.
model = load(model_path)
# Force the estimator's own n_jobs parameter to 1.
# NOTE(review): presumably so that parallelism is left to Dask — confirm.
model = model.set_params(n_jobs=1)

## Open predictor data

In [None]:
# Open the predictor dataset; each time-step of it is passed to the model below.
data = xr.open_dataset(data_path)

### Predict each time-step separately

- TO DO: fix timesteps that come back from `predict_xr`

In [None]:
# Boolean mask that is True wherever at least one predictor variable is
# null: stack the variables along 'variable', test for nulls, then reduce
# over that dimension.
mask = (
    data
    .to_array()
    .isnull()
    .any(dim='variable')
)

In [None]:
# Predict every time-step on its own and collect the per-step results.
results = []

# Start from index 3, as the first three time-steps don't have rainfall
# lag values.
first_step = 3
steps = range(first_step, len(data.time))

for count, i in enumerate(steps, start=1):
    # Progress counter over the steps actually predicted (001/NNN .. NNN/NNN).
    # Previously the raw time index was printed against the reduced total,
    # so the counter started at 004 and overshot the denominator.
    print(" {:03}/{:03}\r".format(count, len(steps)), end="")

    # Select the time-step once and reuse it for both the prediction and
    # the timestamp below.
    time_slice = data.isel(time=i)

    # NOTE(review): HiddenPrints presumably silences predict_xr's own
    # output so the progress counter stays readable — confirm.
    with HiddenPrints():
        predicted = predict_xr(
            model,
            time_slice,
            proba=False,
            clean=True,
        ).compute()

    # Keep predictions only where no predictor is null at this time-step;
    # everywhere else becomes NaN.
    predicted = predicted.Predictions.where(~mask.isel(time=i))
    # Tag the result with its timestamp so the steps can be concatenated
    # along 'time' afterwards.
    predicted['time'] = time_slice.time.values
    results.append(predicted)

In [None]:
# Join the per-step predictions along 'time', put them in chronological
# order, and name the resulting array after the predicted variable.
ds = xr.concat(results, dim='time')
ds = ds.sortby('time')
ds = ds.rename(var)

In [None]:
# Write the predictions to a NetCDF file in the results directory.
output_path = '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/' + results_name
ds.to_netcdf(output_path)