# Generate an ensemble of gridded predictions

Using the 30 models produced in `3_Generate_ensemble_of_models.ipynb`, we will generate an ensemble of 30 predictions. From this ensemble we will produce an uncertainty envelope, and a median prediction.


In [1]:
import sys
import os
import warnings
import xarray as xr
import numpy as np
import pandas as pd
from joblib import load
from odc.geo.geobox import zoom_out
from odc.algo import xr_reproject
from datacube.utils.dask import start_local_dask
from odc.geo.xr import assign_crs
import odc.geo.xr
# from dask.distributed import Client,Scheduler
# from dask_jobqueue import SLURMCluster

sys.path.append('/g/data/os22/chad_tmp/NEE_modelling/')
from _collect_prediction_data import round_coords, collect_prediction_data 

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints

In [2]:
client = start_local_dask(mem_safety_margin='2Gb')
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 1
Total threads: 24,Total memory: 95.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:45345,Workers: 1
Dashboard: /proxy/8787/status,Total threads: 24
Started: Just now,Total memory: 95.00 GiB

0,1
Comm: tcp://127.0.0.1:38797,Total threads: 24
Dashboard: /proxy/37969/status,Memory: 95.00 GiB
Nanny: tcp://127.0.0.1:33623,
Local directory: /jobfs/84846300.gadi-pbs/dask-worker-space/worker-v160somn,Local directory: /jobfs/84846300.gadi-pbs/dask-worker-space/worker-v160somn


## Analysis Parameters

In [3]:
var = 'ET'
suffix='20230320'
results_path = '/g/data/os22/chad_tmp/NEE_modelling/results/predictions_uncertainty/'+var+'/'
models_folder = '/g/data/os22/chad_tmp/NEE_modelling/results/models_uncertainty/'+var+'/'
features_list = '/g/data/os22/chad_tmp/NEE_modelling/results/variables_'+suffix+'.txt'

t1, t2='2003','2022'
rescale=False

## Get paths to models

In [4]:
model_list = [file for file in os.listdir(models_folder) if file.endswith(".joblib")]

## Open predictor data

At 1 km resolution, we need to pull the gridded feature layers in as dask arrays and compute on each time-step individually as the total memory requirements are very large. At 5 km resolution, its better to load the entire feature layer data into memory as it speeds up predictions.

In [5]:
# data = xr.open_dataset('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/data_5km.nc')
# mask = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/mask_5km.nc')

## open data
data = collect_prediction_data(time_start=t1,
                             time_end=t2,
                             verbose=False,
                             export=False,
                             chunks=dict(latitude=680, longitude=1050, time=1) #chunks optimised
                             )

#precomputed the mask to save a little time
mask = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/1km/mask_1km_monthly_2003_2022.nc')

### Check training and prediction variable order

In [6]:
train_vars = list(pd.read_csv(features_list))[0:-1]
train_vars=[i[:-3] for i in train_vars]

data = data[train_vars]

if train_vars == list(data.data_vars):
    print('Variables match, n: ', len(data.data_vars))
else:
    raise ValueError('Variables dont match')

Variables match, n:  20


## Predict

Loop through each model, and each time-step.  Mask the output with the urban mask.


In [7]:
#mask urban (5km res)
# mask1 = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/urban_mask_1km.nc')
# mask1 = xr_reproject(mask1, geobox=data.odc.geobox.compat, resampling='mode')
# mask1=round_coords(mask1)
# mask1 = mask1.rename({'latitude':'y', 'longitude':'x'})

#mask urban (1km res)
mask1 = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/urban_mask_1km.nc')
mask1 = mask1.rename({'latitude':'y', 'longitude':'x'})

In [None]:
%%time

# Loop through the 30 models
for m in model_list:
    name = m.split('.')[0]
    
    if os.path.exists(results_path+name+'.nc'):
        print('skipping model '+name)
        continue
    
    print('Model: ', name)
    
    warnings.filterwarnings("ignore")
    model = load(models_folder+m).set_params(n_jobs=1)
    
    results = []
    i=0
    #loop through the time-steps
    for i in range(0, len(data.time)): 
        print("  {:03}/{:03}\r".format(i + 1, len(range(0, len(data.time)))), end="")

        with HiddenPrints():
            warnings.filterwarnings("ignore")
            predicted = predict_xr(model,
                                data.isel(time=i),
                                proba=False,
                                clean=True,
                                chunk_size=875000, #this number is optimized to maximise pred speed.
                                  ).compute()

        predicted = predicted.Predictions.where(~mask.isel(time=i).compute())
        predicted['time'] = data.isel(time=i).time.values
        results.append(predicted.astype('float32'))
        i+=1 
    
    ds = xr.concat(results, dim='time').sortby('time').rename(var).astype('float32')
    
    #mask urban
    ds = ds.where(mask!=1).astype('float32')

    #save results
    ds.to_netcdf(results_path+name+'.nc')
    

skipping model _lgbm_7
skipping model _lgbm_8
skipping model _rf_13
skipping model _lgbm_11
skipping model _rf_10
skipping model _lgbm_15
skipping model _lgbm_1
skipping model _rf_15
skipping model _rf_1
skipping model _rf_3
skipping model _lgbm_12
skipping model _rf_7
skipping model _rf_8
skipping model _lgbm_4
skipping model _lgbm_14
skipping model _rf_2
skipping model _lgbm_3
skipping model _lgbm_9
skipping model _rf_5
skipping model _rf_12
skipping model _rf_4
skipping model _lgbm_13
Model:  _lgbm_2
Model:  _rf_6
  003/234



  005/234



  007/234



  009/234



  012/234



  014/234



  016/234



  018/234



  021/234



  023/234



  025/234



  027/234



  030/234



  032/234



  034/234



  036/234



  038/234



  041/234



  043/234



  045/234



  047/234



  050/234



  052/234



  054/234



  056/234



  059/234



  061/234



  063/234



  065/234



  068/234



  070/234



  072/234



  074/234



  076/234



  079/234



  081/234



  083/234



  085/234



  088/234



  090/234



  092/234



  094/234



  097/234



  099/234



  101/234



  103/234



  106/234



  108/234



  110/234



  112/234



  114/234



  117/234



  119/234



  121/234



  123/234



  126/234



  128/234



  130/234



  132/234



  135/234



  137/234



  139/234



  141/234



  144/234



  146/234



  148/234



  150/234



  153/234



  155/234



  157/234



  159/234



  161/234



  164/234



  166/234



  168/234



  170/234



  173/234



  175/234



  177/234



  179/234



  182/234



  184/234



  186/234



  188/234



  191/234



  193/234



  195/234



  197/234



  199/234



  202/234



  204/234



  206/234



  208/234



  211/234



  213/234



  215/234



  217/234



  220/234



  222/234



  224/234



  226/234



  229/234



  231/234



  233/234



Model:  _rf_9
  001/234



  003/234



  006/234



  008/234



  010/234



  012/234



  015/234



  017/234

