# Predicting fluxes on grid


In [None]:
import sys
import xarray as xr
import numpy as np
import pandas as pd
from joblib import load
from matplotlib import pyplot as plt
from odc.geo.geobox import zoom_out
from datacube.utils.dask import start_local_dask

sys.path.append('/g/data/os22/chad_tmp/climate-carbon-interactions/src/')
from _collect_prediction_data import collect_prediction_data, round_coords

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints

In [None]:
# Start a local dask cluster, keeping a 2Gb memory safety margin, and
# display the client (dashboard link) as the cell output.
client = start_local_dask(
    mem_safety_margin='2Gb',
)
client

## Analysis Parameters

In [None]:
# Flux variable to predict, and the version tag used in the model, feature
# and output filenames.
var = 'NEE'
suffix = '20230910'

# Prediction period (start and end year as strings; used for time slicing).
t1 = '2003'
t2 = '2022'

# NOTE(review): `rescale` is not referenced anywhere else in this notebook.
rescale = False

### Set up paths

In [None]:
# Output filename for the predictions, plus input paths for the fitted model
# and for the text file listing the model's feature variables (note:
# `features_list` holds a path, not a list).
results_name = f'{var}_{t1}_{t2}_5km_LGBM_{suffix}.nc'
model_path = f'/g/data/os22/chad_tmp/climate-carbon-interactions/results/models/fluxes/{var}_LGBM_model_{suffix}.joblib'
features_list = f'/g/data/os22/chad_tmp/climate-carbon-interactions/results/variables_{suffix}.txt'

## Open model

In [None]:
# Load the fitted model from disk and restrict it to a single job.
# NOTE(review): presumably n_jobs=1 keeps the model's own threading from
# competing with the dask workers — confirm.
model = load(model_path)
model = model.set_params(n_jobs=1)

## Open predictor data

In [None]:
# Open the predictor variables for the prediction period, chunked one
# time-step at a time with 1000 x 1000 spatial chunks.
data = collect_prediction_data(
    time_start=t1,
    time_end=t2,
    verbose=False,
    export=False,
    chunks={'latitude': 1000, 'longitude': 1000, 'time': 1},
)

## Create a no-data mask

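Load the pre-computed no-data mask. If the mask file doesn't exist yet, uncomment the first two lines of the next cell to compute and save it.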

In [None]:
# To regenerate the mask (True wherever any of NDVI, WCF or tavg is missing),
# uncomment these two lines:
# mask = data[['NDVI', 'WCF', 'tavg']].to_array().isnull().any('variable')
# mask.compute().to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/mask_5km_monthly_1982_2022.nc')

# Load the saved 1982-2022 mask and trim it to the prediction period.
mask = xr.open_dataarray(
    '/g/data/os22/chad_tmp/climate-carbon-interactions/data/mask_5km_monthly_1982_2022.nc'
).sel(time=slice(t1, t2))

In [None]:
# Recipe for the urban mask: reproject the 1 km urban mask onto the grid of
# `mask` using modal resampling, round its coordinates, name it
# 'urban_mask', and save it to disk.
# NOTE(review): this saves to 'urban_mask_5km_.nc' (trailing underscore),
# but the urban-masking cell further down reads 'urban_mask_5km.nc' —
# confirm which filename is intended before regenerating.
# #create an urban mask once, then next time load it.
# mask1 = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/urban_mask_1km.nc')
# mask1 = mask1.odc.reproject(mask.odc.geobox, resampling='mode')
# mask1=round_coords(mask1)
# mask1.name='urban_mask'
# mask1.compute().to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/data/urban_mask_5km_.nc')

### Index by variables and check variable order

In [None]:
# Feature names come from the header row of the features file. The last
# column is dropped, and the final 3 characters of each name are stripped so
# that the names match the variables in `data`.
# NOTE(review): presumably the last column is the target and the 3-character
# suffix is a naming tag from training — confirm against the training code.
train_vars = [name[:-3] for name in list(pd.read_csv(features_list))[:-1]]

# Fail with a readable message (rather than a bare KeyError from the
# indexing below) if any model feature is absent from the predictor data.
missing = [name for name in train_vars if name not in data.data_vars]
if missing:
    raise ValueError(f"Variables don't match; missing from predictor data: {missing}")

# Subset (and reorder) the predictors to the model's feature order.
data = data[train_vars]

# Final safeguard that the predictor order matches the training order.
if train_vars == list(data.data_vars):
    print('Variables match, n:', len(data.data_vars))
else:
    raise ValueError("Variables don't match")

### Predict each time-step separately

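- TODO: fix the time stamps returned by `predict_xr`. For now, each prediction's time coordinate is overwritten with the time stamp of its input time-step.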

In [None]:
%%time
import warnings

# Predict the flux for each time-step independently, blank out no-data
# cells, and collect the float32 results in `results`.
n_steps = len(data.time)

# The no-data mask is selected by position below, so it must have exactly
# one time-step per predictor time-step; otherwise masks would misalign.
if len(mask.time) != n_steps:
    raise ValueError(
        f"mask has {len(mask.time)} time-steps but data has {n_steps}"
    )

results = []

# Suppress warnings only while predicting, rather than for the rest of the
# notebook session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for i in range(n_steps):
        print(" {:03}/{:03}\r".format(i + 1, n_steps), end="")
        step = data.isel(time=i)
        with HiddenPrints():
            predicted = predict_xr(model,
                                   step,
                                   proba=False,
                                   clean=True,
                                   ).compute()

        # Keep predictions only where the mask is False (mask is True where
        # any of the masking inputs is missing).
        predicted = predicted.Predictions.where(~mask.isel(time=i))
        # Overwrite the time coordinate with the input time-step's timestamp.
        predicted['time'] = step.time.values
        results.append(predicted.astype('float32'))

In [None]:
# Stack the per-time-step predictions into one time-sorted float32 DataArray
# named after the flux variable.
ds = (
    xr.concat(results, dim='time')
    .sortby('time')
    .rename(var)
    .astype('float32')
)

In [None]:
# ds.sel(time=slice('2003', '2022')).mean(['x','y']).plot(figsize=(13,5))

## Mask urban areas using landcover dataset

In [None]:
# Load the saved urban mask, renaming its spatial dims to the x/y names used
# by `ds`.
mask1 = (
    xr.open_dataset('/g/data/os22/chad_tmp/climate-carbon-interactions/data/urban_mask_5km.nc')['urban_mask']
    .rename({'latitude': 'y', 'longitude': 'x'})
)

# Blank out pixels flagged as urban (value 1) and keep the float32 dtype.
ds = ds.where(mask1 != 1).astype('float32')

### Save results

In [None]:
# Write the masked predictions to the predictions directory under
# `results_name`.
ds.to_netcdf(f'/g/data/os22/chad_tmp/climate-carbon-interactions/results/predictions/{results_name}')

## Animate results

In [None]:
import xarray as xr
from IPython.display import Image
import matplotlib.pyplot as plt

import sys
sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools')
from dea_tools.plotting import xr_animation


In [None]:
# ds = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/results/predictions/GPP_2003_2021_5km_LGBM.nc')

In [None]:
# Render a GIF of the predictions (3-time-step rolling mean) and show it inline.
path = f'/g/data/os22/chad_tmp/climate-carbon-interactions/results/gifs/{var}_LGBM_5km_{suffix}.gif'

# NEE gets a diverging colour scale centred on zero; other fluxes use a
# sequential 0-150 scale.
imshow = (
    {'vmin': -45, 'vmax': 45, 'cmap': 'Spectral_r'}
    if var == 'NEE'
    else {'vmin': 0, 'vmax': 150, 'cmap': 'gist_earth_r'}
)

smoothed = ds.to_dataset().rolling(time=3, min_periods=1).mean()

xr_animation(
    smoothed,
    bands=[var],
    show_date='%b %Y',
    width_pixels=600,
    output_path=path,
    show_colorbar=True,
    colorbar_kwargs={'colors': 'black'},
    # show_gdf=poly_gdf,
    interval=150,
    show_text=var + ' gC/m2/month',
    # gdf_kwargs={'edgecolor': 'grey', 'linewidth':0.5},
    imshow_kwargs=imshow,
)

# Close the figure left open by the animation, then display the saved GIF.
plt.close()
Image(path, embed=True)

In [None]:
# wcf_1982 = data.WCF.sel(time='2020').compute()

In [None]:
# trees = xr.open_dataset('/g/data/os22/chad_tmp/NEE_modelling/data/1km/trees_1km_monthly_2002_2022.nc')['trees'].isel(time=-1)

In [None]:
# fig,ax=plt.subplots(1,2, figsize=(13,5))
# (wcf_1982/100).isel(time=-1).plot.imshow(robust=True, ax=ax[0], vmax=0.5)
# trees.plot.imshow(robust=True, ax=ax[1], vmax=0.5)
# plt.tight_layout()
# ax[0].set_title('WCF 2020')
# ax[1].set_title('Donohue fraction trees 2020')

In [None]:
# ds.mean('time').plot.imshow(robust=True)

In [None]:
# Monthly climatology (mean for each calendar month) over 2003-2022, plotted
# as a spatial average. The 'early' (1982-2002) comparison is disabled.
# clim_early = ds.sel(time=slice('1982', '2002')).groupby('time.month').mean()
clim_late = (
    ds.sel(time=slice('2003', '2022'))
    .groupby('time.month')
    .mean()
)

fig, ax = plt.subplots()
# clim_early.mean(['x', 'y']).plot(ax=ax, label='early')
clim_late.mean(['x', 'y']).plot(ax=ax, label='late')
ax.legend()