# Generate an ensemble of gridded predictions

Using the 30 models produced in `3_Generate_ensemble_of_models.ipynb`, we will generate an ensemble of 30 predictions. From this ensemble we will produce an uncertainty envelope and a median prediction.


In [None]:
import sys
import os
import warnings
import xarray as xr
import numpy as np
import pandas as pd
from joblib import load
from odc.geo.geobox import zoom_out
from odc.algo import xr_reproject
from datacube.utils.dask import start_local_dask
from odc.geo.xr import assign_crs
import odc.geo.xr
# from dask.distributed import Client,Scheduler
# from dask_jobqueue import SLURMCluster

sys.path.append('/g/data/os22/chad_tmp/NEE_modelling/')
from _collect_prediction_data import round_coords, collect_prediction_data 

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints

In [None]:
# Start a local dask cluster with a 2Gb memory safety margin; the bare `client`
# on the last line displays the cluster summary in the notebook.
client = start_local_dask(mem_safety_margin="2Gb")
client

## Analysis Parameters

In [None]:
# Flux variable being predicted, and the date-stamp identifying the feature list
# the models were trained with.
var = 'ER'
suffix = '20230320'

# Output folder for predictions, input folder holding the model ensemble, and the
# file listing the training variables (all keyed on `var` / `suffix`).
results_path = f'/g/data/os22/chad_tmp/NEE_modelling/results/predictions_uncertainty/{var}/'
models_folder = f'/g/data/os22/chad_tmp/NEE_modelling/results/models_uncertainty/{var}/'
features_list = f'/g/data/os22/chad_tmp/NEE_modelling/results/variables_{suffix}.txt'

# Start / end of the prediction period passed to collect_prediction_data.
t1, t2 = '2003', '2022'
# NOTE(review): `rescale` is not referenced by any cell visible in this notebook.
rescale = False

## Get paths to models

In [None]:
# Collect the serialised (.joblib) ensemble models. os.listdir returns entries in
# arbitrary order, so sort them to make the processing order (and the progress
# output of the prediction loop) reproducible between runs.
model_list = sorted(file for file in os.listdir(models_folder) if file.endswith(".joblib"))

## Open predictor data

At 1 km resolution, the total memory requirements are very large, so we need to pull the gridded feature layers in as dask arrays and compute on each time-step individually. At 5 km resolution, it's better to load the entire feature layer data into memory, as this speeds up predictions.

In [None]:
# 5 km alternative: the precomputed feature stack and mask are small enough to
# load directly from disk:
#   data = xr.open_dataset('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/data_5km.nc')
#   mask = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/mask_5km.nc')

# 1 km: collect the feature layers for the prediction period as chunked (dask)
# arrays, one time-step per chunk; the spatial chunk sizes were tuned for speed.
data = collect_prediction_data(
    time_start=t1,
    time_end=t2,
    verbose=False,
    export=False,
    chunks={'latitude': 680, 'longitude': 1050, 'time': 1},
)

# Mask precomputed and stored on disk to save a little time.
mask = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/1km/mask_1km_monthly_2003_2022.nc')

### Check training and prediction variable order

In [None]:
# Only the header row of the features file is needed: its column names are the
# training variables, in the order the models were fitted with. The last column
# is dropped (presumably the response variable — TODO confirm).
train_vars = list(pd.read_csv(features_list, nrows=0))[0:-1]
# Strip a trailing 3-character suffix from each name so it matches the variable
# names in `data` (NOTE(review): meaning of the suffix is not visible here).
train_vars = [i[:-3] for i in train_vars]

# Report any training variable that is absent from the prediction data with a
# clear message, instead of a bare KeyError from the subset below.
missing = [v for v in train_vars if v not in data.data_vars]
if missing:
    raise ValueError('Variables dont match, missing from prediction data: ' + ', '.join(missing))

# Subset AND reorder the prediction data so its variables are in the same order
# the models were trained with (sklearn-style models are order-sensitive).
data = data[train_vars]

if train_vars == list(data.data_vars):
    print('Variables match, n: ', len(data.data_vars))
else:
    raise ValueError('Variables dont match')

## Predict

Loop through each model and each time-step, then mask the output with the urban mask.


In [None]:
# Urban mask applied to the predictions. Its latitude/longitude dims are renamed
# to y/x (presumably to match the dims of the predict_xr output — confirm).
#
# 5 km alternative: reproject the 1 km urban mask onto the 5 km grid (mode
# resampling) and round its coordinates before renaming:
#   mask1 = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/urban_mask_1km.nc')
#   mask1 = xr_reproject(mask1, geobox=data.odc.geobox.compat, resampling='mode')
#   mask1 = round_coords(mask1)
#   mask1 = mask1.rename({'latitude':'y', 'longitude':'x'})

# 1 km: the urban mask is used without reprojection.
mask1 = (
    xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/urban_mask_1km.nc')
    .rename({'latitude': 'y', 'longitude': 'x'})
)

In [None]:
%%time

# Loop through the 30 models
for m in model_list:
    name = m.split('.')[0]
    
    if os.path.exists(results_path+name+'.nc'):
        print('skipping model '+name)
        continue
    
    print('Model: ', name)
    
    warnings.filterwarnings("ignore")
    model = load(models_folder+m).set_params(n_jobs=1)
    
    results = []
    i=0
    #loop through the time-steps
    for i in range(0, len(data.time)): 
        print("  {:03}/{:03}\r".format(i + 1, len(range(0, len(data.time)))), end="")

        with HiddenPrints():
            warnings.filterwarnings("ignore")
            predicted = predict_xr(model,
                                data.isel(time=i),
                                proba=False,
                                clean=True,
                                chunk_size=875000, #this number is optimized to maximise pred speed.
                                  ).compute()

        predicted = predicted.Predictions.where(~mask.isel(time=i).compute())
        predicted['time'] = data.isel(time=i).time.values
        results.append(predicted.astype('float32'))
        i+=1 
    
    ds = xr.concat(results, dim='time').sortby('time').rename(var).astype('float32')
    
    #mask urban
    ds = ds.where(mask!=1).astype('float32')

    #save results
    ds.to_netcdf(results_path+name+'.nc')
    