# Predictions

In [None]:
import xarray as xr
import numpy as np
from joblib import load
import warnings
from odc.geo.xr import assign_crs
import matplotlib.pyplot as plt
from datacube.utils.dask import start_local_dask

import sys
sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _collect_prediction_data import round_coords

In [None]:
# Start a local Dask cluster via datacube's helper and echo the client as the
# cell output.
# NOTE(review): mem_safety_margin presumably holds back 2 GB from the workers'
# memory limit - confirm against the datacube documentation.
client = start_local_dask(mem_safety_margin='2Gb')
client

In [None]:
# Name given to the predicted variable; also used in the output file name.
model_var='NDVI'
# Selects which model file is loaded and which tree-cover mask is applied
# ('trees' or 'nontrees').
name='nontrees'
# First and last year of the period that is loaded and predicted.
t1 = '1982'
t2 = '2013'

### Load model

In [None]:
# Location of the fitted harmonization model for the selected class (`name`).
model_path = (
    '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/'
    'Harmonization_LGBM_' + name + '.joblib'
)

# NOTE: joblib.load unpickles the file, so only load model files you trust.
# n_jobs=1 is set on the estimator; presumably to leave parallelism to Dask
# rather than the model - confirm.
model = load(model_path).set_params(n_jobs=1)

### Load prediction data

The variables are added in the same order as in the training data.

In [None]:
# Root folder of all input files read in this cell.
data_dir = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/'

# AVHRR NDVI defines the target grid (EPSG:3577) and the time axis that every
# other predictor is matched to.
ds = xr.open_dataset(data_dir+'NDVI_harmonization/AVHRR_5km_monthly_1982_2013.nc')
ds = assign_crs(ds, crs='epsg:3577')
ds = ds.sel(time=slice(t1, t2))
ds = ds.rename({'NDVI_median': 'NDVI_median_avhrr'})


def _load_covariate(file_stem, file_var):
    """Load one 5 km monthly covariate on the AVHRR grid.

    Parameters
    ----------
    file_stem : str
        Leading part of the file name, i.e. ``<file_stem>_5km_monthly_1982_2022.nc``
        inside the ``5km`` folder.
    file_var : str
        Name of the variable to read from that file.

    Returns
    -------
    xarray.DataArray
        The variable tagged as EPSG:4326, restricted to ``t1``-``t2``,
        reprojected onto the geobox of ``ds`` and ordered ('time', 'y', 'x').
    """
    arr = xr.open_dataset(data_dir+'5km/'+file_stem+'_5km_monthly_1982_2022.nc')[file_var]
    arr = assign_crs(arr, crs='epsg:4326')
    # Every covariate uses the same t1-t2 window as `ds`; a shorter window
    # would leave the variable NaN for the remaining time steps once it is
    # aligned to `ds` on assignment.
    arr = arr.sel(time=slice(t1, t2))
    arr = arr.odc.reproject(how=ds.odc.geobox)
    # Keep one dimension order for all predictors so they line up pixel for
    # pixel.
    # NOTE(review): this matters if predict_xr flattens each variable in its
    # stored order - confirm against dea_tools.
    return arr.transpose('time', 'y', 'x')


# (file stem, variable in file, name in ds). The order of this table is the
# order in which the predictors are added to `ds`, which has to match the
# order used for training.
covariates = (
    ('MOY', 'month', 'MOY'),
    ('rain_cml3', 'rain_cml3', 'rain_cml3'),
    ('rain_cml6', 'rain_cml6', 'rain_cml6'),
    ('rain_cml3_anom', 'rain_cml3_anom', 'rain_cml3_anom'),
    ('vpd', 'vpd', 'vpd'),
    ('srad', 'srad', 'srad'),
    ('tavg', 'tavg', 'tavg'),
    ('CO2', 'CO2', 'CO2'),
)
for file_stem, file_var, ds_name in covariates:
    ds[ds_name] = _load_covariate(file_stem, file_var)

# Temporal mean of MODIS NDVI over 2001-2013, repeated along the time axis of
# `ds` so it can be used as a static predictor at every time step.
mod = xr.open_dataset(data_dir+'NDVI_harmonization/MODIS_NDVI_5km_monthly_2001_2022.nc')['NDVI_median']
mod = assign_crs(mod, crs='epsg:3577')
mod = mod.sel(time=slice('2001', '2013'))

mean = mod.mean('time')
mean = mean.expand_dims(time=ds.time)
ds['NDVI_modis_mean'] = mean

# These variables come with the AVHRR file but are not used as predictors.
ds = ds.drop_vars(['NDVI_stddev', 'n_obs'])

### Add the y coordinate as a variable

Plus ensure order of the variables is correct for predictions

In [None]:
# Broadcast the 1-D y coordinate across time and x so that it becomes a full
# ('time', 'y', 'x') grid, then store it in the dataset as a predictor.
y = (
    ds.y
    .expand_dims(time=ds.time, x=ds.x)
    .transpose('time', 'y', 'x')
)
ds['y_gridded'] = y

# An equivalent gridded longitude variable is not added; the earlier code for
# it is kept below, commented out.
# lon = ds.longitude
# lon = lon.expand_dims(time=ds.time, latitude=ds.latitude)
# lon = lon.transpose('time', 'latitude', 'longitude')
# ds['longitude_gridded'] = lon

In [None]:
# Move 'y_gridded' to the front and keep every other variable in its current
# order, so the predictor order is the one expected for prediction.
# Selecting by name (rather than by position) gives the same result no matter
# where 'y_gridded' currently sits, so re-running this cell is harmless.
columns = ['y_gridded'] + [var for var in ds.data_vars if var != 'y_gridded']
ds = ds[columns]

### Create a mask

In [None]:
# Mean of the 'trees' layer over 2001-2018, reprojected onto the grid of `ds`.
trees = xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/trees_5km_monthly_1982_2022.nc')['trees']
trees = assign_crs(trees, crs='epsg:4326')
trees = trees.sel(time=slice('2001', '2018'))
trees = trees.odc.reproject(how=ds.odc.geobox)
trees = trees.mean('time')

# 1 where a pixel belongs to the class selected by `name`, 0 elsewhere.
# The 0.5 threshold on the mean 'trees' value splits the two classes.
if name == 'trees':
    mask = xr.where(trees > 0.5, 1, 0)
elif name == 'nontrees':
    mask = xr.where(trees <= 0.5, 1, 0)
else:
    # Fail here with a clear message instead of a NameError on `mask` later.
    raise ValueError("name must be 'trees' or 'nontrees', got {!r}".format(name))

# True wherever the AVHRR NDVI input has a value (per time step and pixel).
avhrr_mask = ~np.isnan(ds['NDVI_median_avhrr'])

### Predict

In [None]:
# Silence all warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Predict one time step at a time and collect the results in `results`.
results = []
n_steps = len(ds.time)
for i in range(n_steps):
    # Progress counter, rewritten in place on a single line.
    print(" {:03}/{:03}\r".format(i + 1, n_steps), end="")

    time_slice = ds.isel(time=i)
    # NOTE(review): HiddenPrints presumably suppresses predict_xr's own
    # printed output - confirm against dea_tools.
    with HiddenPrints():
        predicted = predict_xr(model,
                               time_slice.drop_vars('time'),
                               proba=False,
                               clean=True,
                               chunk_size=100000,
                               ).compute()

    # Re-attach the time stamp and turn it into a length-1 dimension so the
    # individual steps can be concatenated along 'time' afterwards.
    predicted = predicted.assign_coords(time=time_slice.time).expand_dims(time=1)
    results.append(predicted.astype('float32'))

In [None]:
# Join the per-time-step predictions into one time-ordered dataset, name the
# predicted variable after `model_var`, and keep only pixels where `mask` is
# non-zero.
yy = (
    xr.concat(results, dim='time')
    .sortby('time')
    .rename({'Predictions': model_var})
    .where(mask)
)

In [None]:
# Additionally blank out every time step/pixel where the AVHRR NDVI input
# itself was NaN.
yy = yy.where(avhrr_mask)

In [None]:
# Write the masked predictions to NetCDF. The period in the file name is
# taken from t1/t2 so that it always reflects the period actually predicted.
out_path = (
    '/g/data/os22/chad_tmp/climate-carbon-interactions/data/NDVI_harmonization/'
    + model_var + '_' + name
    + '_LGBM_harmonize_test_5km_monthly_' + t1 + '_' + t2 + '.nc'
)
yy.to_netcdf(out_path)

In [None]:
# Quick check: plot the mean of the predicted variable over x and y for each
# time step.
yy[model_var].mean(['x', 'y']).plot(figsize=(11,4))