# Predicting fluxes on grid


In [1]:
import sys
import xarray as xr
import numpy as np
import pandas as pd
from joblib import load
from matplotlib import pyplot as plt
from odc.geo.geobox import zoom_out
from odc.algo import xr_reproject
from datacube.utils.dask import start_local_dask

# from dask.distributed import Client,Scheduler
# from dask_jobqueue import SLURMCluster

sys.path.append('/g/data/os22/chad_tmp/NEE_modelling/')
from collect_prediction_data import collect_prediction_data, round_coords

sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints

In [2]:
# Start a local dask cluster, leaving ~2 GB of memory headroom for the system.
# Alternative: run on a SLURM cluster instead of the local machine.
# cluster = SLURMCluster(processes=2, cores=2, memory="47GB", walltime='02:00:00')
# client = Client(cluster)
# cluster.scale(cores=18)
client = start_local_dask(mem_safety_margin='2Gb')

# Display the client summary (dashboard link, workers, memory) in the notebook.
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 1
Total threads: 16,Total memory: 44.92 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:40325,Workers: 1
Dashboard: /proxy/8787/status,Total threads: 16
Started: Just now,Total memory: 44.92 GiB

0,1
Comm: tcp://127.0.0.1:38335,Total threads: 16
Dashboard: /proxy/41719/status,Memory: 44.92 GiB
Nanny: tcp://127.0.0.1:46875,
Local directory: /local/u46/cb3058/tmp/dask-worker-space/worker-49ppkutq,Local directory: /local/u46/cb3058/tmp/dask-worker-space/worker-49ppkutq


## Analysis Parameters

In [3]:
# Flux variable to predict, and the version tag shared by the model and feature-list files.
var = 'NEE'
suffix = '20230306'

# Start and end of the prediction period, passed to collect_prediction_data.
t1 = '2003'
t2 = '2022'

# Whether predictors should be rescaled to 5 km; note the rescaling cell
# further down is currently commented out.
rescale = False

### Set up paths

In [4]:
# Resolution label used in output names. The rescaling cell is currently
# commented out, so without it predictions stay on the native grid
# (presumably 1 km: the validity and urban masks opened below are 1 km files).
res_label = '5km' if rescale else '1km'

# Derive the output file name from the analysis period and resolution so it
# cannot drift out of sync with t1/t2/rescale (it was previously hard-coded
# as '2003_2021_5km' while predicting 2003-2022 on the unrescaled grid).
results_name = f"{var}_{t1}_{t2}_{res_label}_LGBM_{suffix}.nc"

# Trained model and the list of feature names it was trained on.
model_path = '/g/data/os22/chad_tmp/NEE_modelling/results/models/AUS_'+var+'_LGBM_model_'+suffix+'.joblib'
# mask_path = '/g/data/os22/chad_tmp/NEE_modelling/data/1km/mask_1km_monthly_2003_2021.nc'
features_list = '/g/data/os22/chad_tmp/NEE_modelling/results/variables_'+suffix+'.txt'

## Open model

In [5]:
# Load the fitted LGBM model and pin it to a single thread for inference.
# NOTE(review): presumably to avoid oversubscribing cores alongside dask — confirm.
model = load(model_path)
model = model.set_params(n_jobs=1)

## Open predictor data

In [6]:
# Open the predictor dataset for the analysis period, chunked 900x900 in space
# with a single chunk spanning the whole time axis.
chunking = {'latitude': 900, 'longitude': 900, 'time': -1}
data = collect_prediction_data(
    time_start=t1,
    time_end=t2,
    verbose=False,
    export=False,
    chunks=chunking,
)

In [7]:
# Validity mask: True where any of NDWI, LST, tree_cover or TWI is missing.
# It was generated once with the two commented lines below and cached to disk.
# mask = data[['NDWI', 'LST', 'tree_cover', 'TWI']].to_array().isnull().any('variable')
# mask.compute().to_netcdf('/g/data/os22/chad_tmp/NEE_modelling/data/1km/mask_1km_monthly_2003_2022.nc')
mask_file = '/g/data/os22/chad_tmp/NEE_modelling/data/1km/mask_1km_monthly_2003_2022.nc'
mask = xr.open_dataarray(mask_file)

## Optionally rescale datasets to 5km

In [8]:
# if rescale: 
#     gbox_5km = zoom_out(data.odc.geobox, 5)
#     data.attrs['nodata'] = np.nan
#     data = xr_reproject(data, geobox=gbox_5km.compat, resampling='average')
#     # data = data.odc.reproject(how=gbox_5km, resampling='average') # no support yet for dask

#     #make sure the coords aren't too precise
#     data = round_coords(data)
#     data = data.rename({'latitude':'y', 'longitude':'x'}) #this helps with predict_xr
    
# data.attrs['nodata'] = np.nan
# data

In [9]:
# data = data.compute()
# mask = data[['VegH','NDWI', 'LST', 'tree_cover', 'TWI']].to_array().isnull().any('variable').compute()
# data.to_netcdf('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/data_5km__.nc')
# mask.to_netcdf('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/mask_5km__.nc')

In [10]:
# data = xr.open_dataset('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/data_5km.nc')
# mask = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/results/prediction_data/mask_5km.nc')

In [11]:
# mask.isel(time=1).plot.imshow(size=7)

### Check training and prediction variable order

In [12]:
# Feature names are stored as the header row of the features file.
# NOTE(review): the last column is dropped (presumably the target) and a
# 3-character suffix is stripped from every name — confirm against the file
# written at training time.
train_vars = list(pd.read_csv(features_list))[0:-1]
train_vars = [i[:-3] for i in train_vars]

# Check for missing predictors *before* selecting: indexing with an unknown
# name would otherwise raise an opaque KeyError. The old post-selection
# equality check could never fail, because data[train_vars] already returns
# the variables in train_vars order.
missing = [v for v in train_vars if v not in data.data_vars]
if missing:
    raise ValueError("Variables don't match; missing from prediction data: "
                     + ", ".join(missing))

# Select and reorder predictors to exactly match the training feature order.
data = data[train_vars]
print('Variables match, n: ', len(data.data_vars))

Variables match, n:  20


### Predict each time-step separately

- TODO: fix the time coordinates of the predictions returned by `predict_xr` (currently overwritten manually in the loop below)

In [13]:
# %time
# data = data.compute()

In [None]:
%%time
import warnings
warnings.filterwarnings("ignore")

results = []

i=0
#start from 3 as these time-steps doesn't have rainfall lag values
for i in range(0, len(data.time)): 
    print(" {:03}/{:03}\r".format(i + 1, len(range(0, len(data.time)))), end="")
    with HiddenPrints():
        predicted = predict_xr(model,
                            data.isel(time=i),#.chunk(dict(y=900, x=900)),
                            proba=False,
                            clean=True,
                            chunk_size=1e6,
                              ).compute()
    
    predicted = predicted.Predictions.where(~mask.isel(time=i))
    predicted['time'] = data.isel(time=i).time.values
    results.append(predicted.astype('float32'))
    i+=1 

 012/234

In [None]:
# Stitch the per-time-step predictions into a single time series named after
# the target variable, ordered by time.
ds = xr.concat(results, dim='time')
ds = ds.sortby('time').rename(var).astype('float32')
ds

## Mask urban areas using the land-cover dataset

In [None]:
# Urban mask: pixels that are True are set to NaN in the predictions.
# Its dimensions are renamed to match the y/x naming of the prediction output.
mask1 = (
    xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/data/urban_mask_1km.nc')
    .rename({'latitude': 'y', 'longitude': 'x'})
)

# Alternative (disabled): reproject the mask onto the prediction grid first.
# if rescale:
#     mask1 = xr_reproject(mask1, geobox=gbox_5km.compat, resampling='mode')
#     mask1=round_coords(mask1)
# else:
#     mask1 = xr_reproject(mask1, geobox=data.odc.geobox.compat, resampling='mode')
#     mask1=round_coords(mask1)

ds = ds.where(~mask1).astype('float32')


### Save results

In [None]:
# Write the masked prediction time series to the predictions directory.
out_dir = '/g/data/os22/chad_tmp/NEE_modelling/results/predictions/'
ds.to_netcdf(out_dir + results_name)


## Animate results

In [None]:
import xarray as xr
from IPython.display import Image
import matplotlib.pyplot as plt

import sys
sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools')
from dea_tools.plotting import xr_animation


In [None]:
# ds = xr.open_dataarray('/g/data/os22/chad_tmp/NEE_modelling/results/predictions/GPP_2003_2021_5km_LGBM.nc')

In [None]:
# Resolution label for the GIF name; the rescaling cell is currently commented
# out, so unless it is re-enabled the predictions are on the native grid
# (presumably 1 km, matching the 1 km mask files) — not 5 km as previously
# hard-coded.
res_label = '5km' if rescale else '1km'
path = ('/g/data/os22/chad_tmp/NEE_modelling/results/figs/'
        + var + '_mystudy_LGBM_' + res_label + '_' + suffix + '.gif')

# NEE is signed (source/sink), so use a symmetric diverging colormap;
# other fluxes get a sequential colormap starting at 0.
if var == 'NEE':
    imshow = {'vmin': -50, 'vmax': 50, 'cmap': 'RdBu_r'}
else:
    imshow = {'vmin': 0, 'vmax': 150, 'cmap': 'viridis'}

xr_animation(ds.to_dataset(),
             bands=[var],
             show_date='%b %Y',
             width_pixels=600,
             output_path=path,
             show_colorbar=True,
             colorbar_kwargs={'colors': 'black'},
             # show_gdf=poly_gdf,
             interval=200,
             show_text=var+' gC/m2/month',
             # gdf_kwargs={'edgecolor': 'grey', 'linewidth':0.5},
             imshow_kwargs=imshow
             )

# Close the static figure and display the saved GIF inline instead.
plt.close()
Image(path, embed=True)