# Predictions

In [None]:
import xarray as xr
import numpy as np
from joblib import load
from odc.geo.xr import assign_crs
import matplotlib.pyplot as plt
from datacube.utils.dask import start_local_dask

import sys
sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _collect_prediction_data import round_coords

In [None]:
# Start a local Dask cluster via the datacube helper; the returned client is
# kept in `client` for the rest of the notebook.
client = start_local_dask(mem_safety_margin='2Gb')
# Bare expression so the notebook renders the client summary for this cell.
client

In [None]:
# Name of the variable being gap-filled; it is used below to build the model
# path, the input/output file names, and the name of the predicted variable.
model_var = 'LST'

### Load model

In [None]:
# Load the fitted gap-filling model for `model_var` from disk (joblib file).
model = load(
    '/g/data/os22/chad_tmp/climate-carbon-interactions/results/models/gapfill/gapfill_'
    + model_var
    + '_LGBM.joblib'
)
# Force single-threaded prediction inside the estimator.
# NOTE(review): presumably to avoid oversubscribing cores alongside the Dask
# workers started above -- confirm.
model = model.set_params(n_jobs=1)

### Load prediction data

The datasets are listed, and later indexed, so that the variable order matches the order used in the training data.

In [None]:
# Root directory that every relative path in `datasets` is joined onto.
base = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/'

# Input files, relative to `base`. The first entry is the `model_var` product
# that still contains gaps; the rest are the predictor layers.
# The list order is significant: it determines the variable order of the merged
# dataset, which is sliced by position further down the notebook.
datasets = [
    model_var+'_harmonization/'+model_var+'_5km_monthly_1982_2022_wGaps.nc',
    '5km/rain_5km_monthly_1981_2022.nc',
    '5km/rain_cml3_5km_monthly_1982_2022.nc',
    '5km/rain_cml3_anom_5km_monthly_1982_2022.nc',
    '5km/rain_cml6_5km_monthly_1982_2022.nc',
    '5km/rain_cml12_5km_monthly_1982_2022.nc',
    '5km/srad_5km_monthly_1982_2022.nc',
    '5km/srad_anom_5km_monthly_1982_2022.nc',
    '5km/tavg_5km_monthly_1982_2022.nc',
    '5km/tavg_anom_5km_monthly_1982_2022.nc',
    '5km/vpd_5km_monthly_1982_2022.nc',
    '5km/MOY_5km_monthly_1982_2022.nc',
    '5km/Elevation_5km_monthly_1982_2022.nc',
    '5km/CO2_5km_monthly_1982_2022.nc',
    '5km/WCF_5km_monthly_1982_2022.nc',
    # Currently excluded: '5km/VegH_5km_monthly_1982_2022.nc'
    '5km/Aspect_5km_monthly_1982_2022.nc',
    # Currently excluded: '5km/Landcover_5km_monthly_1982_2022.nc'
]

In [None]:
# Open every input file, give each the same CRS/coordinate treatment, and
# merge them into one dataset (`ds`) holding all variables.
dss = []
for d in datasets:
    # Restrict every file to 1982-2022; the rain file name indicates it
    # starts in 1981, so without this the merged time axis would not line up.
    xx = xr.open_dataset(base + d).sel(time=slice('1982', '2022'))
    xx = assign_crs(xx, crs='epsg:4326')
    # Project helper from AusEFlux.
    # NOTE(review): presumably rounds the lat/lon coordinates so the grids
    # align exactly when merged -- confirm against _collect_prediction_data.
    xx = round_coords(xx)
    # Remove the 'spatial_ref' coordinate before merging; the CRS is
    # re-attached to the merged dataset below. `drop_vars` is used because
    # `Dataset.drop` is deprecated for dropping variables/coordinates.
    xx = xx.drop_vars('spatial_ref')
    dss.append(xx)

ds = xr.merge(dss)
ds = assign_crs(ds, crs='epsg:4326')

### Add lat as a variable

Also ensure the variables are in the order the model expects for prediction.

In [None]:
# Broadcast the 1-D latitude coordinate across time and longitude so that it
# can be stored as an ordinary (time, latitude, longitude) data variable.
lat = (
    ds.latitude
    .expand_dims(time=ds.time, longitude=ds.longitude)
    .transpose('time', 'latitude', 'longitude')
)
ds['latitude_gridded'] = lat

# The equivalent gridded longitude variable is currently disabled:
# lon = ds.longitude
# lon = lon.expand_dims(time=ds.time, latitude=ds.latitude)
# lon = lon.transpose('time', 'latitude', 'longitude')
# ds['longitude_gridded'] = lon

In [None]:
# Build the ordered list of prediction variables: 'latitude_gridded' first,
# followed by every data variable except the first and the last.
# NOTE(review): this positional slice assumes the first variable is the
# gap-containing `model_var` layer and the last is 'latitude_gridded' (added in
# the previous cell) -- confirm if the `datasets` list changes.
# ('longitude_gridded' is not used at present.)
columns = ['latitude_gridded'] + list(ds.data_vars)[1:-1]

# Keep only those variables, in that order, and rename the spatial dimensions
# to x/y.
# NOTE(review): presumably predict_xr expects dimensions named x and y -- confirm.
ds = ds[columns].rename({'latitude': 'y', 'longitude': 'x'})

### Create a mask

In [None]:
# Boolean 2-D mask: True wherever the 2015 time-mean of WCF is not NaN.
# It is applied to the predictions after they have been concatenated.
mask = np.logical_not(
    np.isnan(ds['WCF'].sel(time='2015').mean('time'))
)

### Predict

In [None]:
# Predict one time step at a time and collect the per-step results in
# `results`; they are concatenated along time in a later cell.
n_steps = len(ds.time)
results = []
for i in range(n_steps):
    # In-place progress counter (carriage return, no newline).
    print(" {:03}/{:03}\r".format(i + 1, n_steps), end="")

    # Suppress whatever predict_xr prints so the progress counter stays readable.
    with HiddenPrints():
        predicted = predict_xr(
            model,
            # A single 2-D time slice; the now-scalar 'time' coordinate is
            # removed before prediction and re-attached afterwards.
            ds.isel(time=i).drop_vars('time'),
            proba=False,
            clean=True,
            chunk_size=100000,
        ).compute()

    # Restore the time stamp and give the slice a length-1 time dimension so
    # the slices can be concatenated. Masking is applied once, after
    # concatenation, rather than per time step.
    predicted = predicted.assign_coords(time=ds.isel(time=i).time).expand_dims(time=1)
    results.append(predicted.astype('float32'))

In [None]:
# Stack the per-time-step predictions into one dataset ordered by time.
yy = xr.concat(results, dim='time')
yy = yy.sortby('time')
# Name the predicted variable after the variable being gap-filled.
yy = yy.rename({'Predictions': model_var})
# Keep values only where `mask` is True; everything else becomes NaN.
yy = yy.where(mask)

In [None]:
# Write the masked predictions to a NetCDF file named after `model_var`.
yy.to_netcdf(
    '/g/data/os22/chad_tmp/climate-carbon-interactions/results/ml_predictions/'
    + model_var
    + '_predicted_5km_monthly_1982_2022.nc'
)

In [None]:
# Quick visual check: time series of the spatial (x, y) mean of the predictions.
yy[model_var].mean(dim=['x', 'y']).plot(figsize=(11, 5))