# Gapfilling NDVI/LST with machine learning


### Load packages
Import Python packages that are used for the analysis.

In [None]:
%matplotlib inline

import xarray as xr
import numpy as np
import pandas as pd
import seaborn as sb
from joblib import dump
from scipy import stats
import geopandas as gpd
from pprint import pprint
from odc.geo.xr import assign_crs
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

import lightgbm as lgbm
from lightgbm import LGBMRegressor

import shap
from sklearn.model_selection import RandomizedSearchCV, KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

import sys
sys.path.append('/g/data/os22/chad_tmp/dea-notebooks/Tools/')
from dea_tools.classification import predict_xr, HiddenPrints
from dea_tools.spatial import xr_rasterize

sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')
from _collect_prediction_data import round_coords

import warnings
warnings.filterwarnings("ignore")

### Analysis parameters
* `model_var`: The variable to gap-fill, either `'LST'` or `'NDVI'`.
* `n_samples`: The total number of samples to draw, split equally among the bioclimatic regions.

In [None]:
# Variable to gap-fill; set to 'NDVI' to model NDVI instead
model_var = 'LST'
# Total number of samples to draw (shared out between the regions further down)
n_samples = 7000

### Assemble datasets for training and predicting

In [None]:
# Directory holding the 5 km monthly NetCDF inputs
base = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/'

# The first entry is the target variable with its gaps
# (alternative target file: 'NDVI_5km_monthly_1982_2022.nc')
datasets = [model_var + '_5km_monthly_1982_2022_wGaps.nc']

# The remaining entries are the other variables merged into the dataset
datasets += [
    'rain_5km_monthly_1981_2022.nc',
    'rain_cml3_5km_monthly_1982_2022.nc',
    'rain_cml6_5km_monthly_1982_2022.nc',
    'rain_cml12_5km_monthly_1982_2022.nc',
    'srad_5km_monthly_1982_2022.nc',
    'tavg_5km_monthly_1982_2022.nc',
    'vpd_5km_monthly_1982_2022.nc',
    'MOY_5km_monthly_1982_2022.nc',
    'Elevation_5km_monthly_1982_2022.nc',
    'CO2_5km_monthly_1982_2022.nc',
    'WCF_5km_monthly_1990_2022.nc',
]

In [None]:
# Open every input file, restrict it to 1990-2021 and collect the pieces
dss = []
for fname in datasets:
    xx = xr.open_dataset(base + fname).sel(time=slice('1990', '2021'))
    xx = assign_crs(xx, crs='epsg:4326')
    xx = round_coords(xx)
    # `Dataset.drop` is deprecated for removing variables; `drop_vars`
    # is the supported spelling and behaves the same here.
    xx = xx.drop_vars('spatial_ref')
    dss.append(xx)

# Merge all variables into one dataset and re-attach the CRS
ds = xr.merge(dss)
ds = assign_crs(ds, crs='epsg:4326')

## Training & testing data: equal random sampling of bioclimatic regions

In [None]:
# Polygons of the bioclimatic regions used to stratify the sampling
regions_file = '/g/data/os22/chad_tmp/NEE_modelling/data/bioclimatic_regions.geojson'
gdf = gpd.read_file(regions_file)

In [None]:
# List collecting one DataFrame of samples per region
results = []

# Every region receives the same share of the total sample budget
samples_per_region = int(n_samples / len(gdf))

for position, (_, row) in enumerate(gdf.iterrows()):
    print(row['region_name'])

    # Rasterize this region's polygon onto the grid of the target variable.
    # `iloc` is positional, so the row is selected by its enumeration
    # position rather than by the index label yielded by `iterrows()`.
    mask = xr_rasterize(gdf.iloc[[position]], ds[model_var])
    mask = round_coords(mask)

    # Mask dataset to set pixels outside the polygon to `NaN`
    # (a separate name keeps the `dss` list of input datasets intact)
    region_ds = ds.where(mask)

    # Drop rows containing any NaN, then draw this region's random sample.
    # `sample` raises if the region has fewer valid rows than requested.
    region_df = (
        region_ds.to_dataframe()
        .dropna()
        .sample(n=samples_per_region, random_state=0)
        .reset_index()
    )

    results.append(region_df)


In [None]:
# Stack the per-region samples into one table with a fresh integer index
df = pd.concat(results, ignore_index=True)
# Calendar year of every sample, taken from its timestamp
df['year'] = pd.DatetimeIndex(df['time']).year

In [None]:
# Remove the timestamp and CRS columns
df = df.drop(columns=['time', 'spatial_ref'])

### Independent validation samples

In [None]:
# Set aside 1000 randomly chosen rows (fixed seed) for validation
validation = df.sample(
    n=1000,
    random_state=0,
)

In [None]:
# Remove the validation rows so they are not part of the training set
df = df.drop(index=validation.index)
print(f"{len(df)} training samples")

### Plot the location of the samples

In [None]:
# Point geometry for every training sample (note: `gdf` is re-used here
# and no longer holds the region polygons)
train_points = gpd.points_from_xy(df.longitude, df.latitude)
gdf = gpd.GeoDataFrame(df, geometry=train_points, crs="EPSG:4326")

# Same for the validation samples
val_points = gpd.points_from_xy(validation.longitude, validation.latitude)
gdf_val = gpd.GeoDataFrame(validation, geometry=val_points, crs="EPSG:4326")

# gdf_val.explore(column='year', cmap='inferno')

In [None]:
# Write the training and validation points to GeoJSON so the sampling
# step can be skipped on later runs
samples_dir = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/'
gdf.to_file(samples_dir + 'training_data.geojson')
gdf_val.to_file(samples_dir + 'validation_data.geojson')

### Import training data if skipping above

In [None]:
samples_dir = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/'

# Training samples: read the saved points and keep only the attribute
# columns, without the geometry and the year
gdf = gpd.read_file(samples_dir + 'training_data.geojson')
df = pd.DataFrame(gdf.drop(columns='geometry')).drop(columns=['year'])

# Validation samples: same treatment
gdf_val = gpd.read_file(samples_dir + 'validation_data.geojson')
validation = pd.DataFrame(gdf_val.drop(columns='geometry')).drop(columns=['year'])

In [None]:
# Target: the variable being gap-filled
y = df[model_var]
# Features: every remaining column except longitude
x = df.drop(columns=[model_var, 'longitude'])

## Testing model using nested CV

In [None]:
# Create the parameter grid using distributions
# Search space for the randomized hyperparameter search: scipy
# distributions are sampled, plain lists are chosen from
param_grid = dict(
    num_leaves=stats.randint(5, 50),
    min_child_samples=stats.randint(10, 30),
    boosting_type=['gbdt', 'dart'],
    max_depth=stats.randint(5, 25),
    n_estimators=[200, 300, 400, 500],
)

In [None]:
# Nested cross-validation: the outer folds estimate the skill of the whole
# procedure, the inner folds tune the hyperparameters on each outer
# training split.
n_outer_splits = 5
outer_cv = KFold(n_splits=n_outer_splits, shuffle=True, random_state=0)

# Per-fold scores on the outer test split
acc = []   # mean absolute error
rmse = []  # root mean squared error
r2 = []    # r2 score

cv_dir = "/g/data/os22/chad_tmp/climate-carbon-interactions/results/cross_validation/"

for i, (train_index, test_index) in enumerate(outer_cv.split(x, y), start=1):
    print(f"Working on {i}/{n_outer_splits} outer CV split", end='\r')
    model = LGBMRegressor(random_state=1,
                          verbose=-1,
                          # n_jobs=-1
                          )

    # Split features and target into this fold's training and testing parts
    X_tr, X_tt = x.iloc[train_index, :], x.iloc[test_index, :]
    y_tr, y_tt = y.iloc[train_index], y.iloc[test_index]

    # Simple random split on the inner fold
    inner_cv = KFold(n_splits=3, shuffle=True, random_state=0)

    # NOTE(review): no random_state is passed to the search, so the sampled
    # parameter candidates may differ between runs.
    clf = RandomizedSearchCV(
        model,
        param_grid,
        verbose=0,
        n_iter=100,
        # n_jobs=-1,
        cv=inner_cv.split(X_tr, y_tr),
    )

    # callbacks=None is kept from the original, where it was noted as
    # preventing extensive print statements
    clf.fit(X_tr, y_tr, callbacks=None)

    # Predict the outer test split using the best model from the search
    best_model = clf.best_estimator_
    pred = best_model.predict(X_tt)

    # Evaluate the fold with several metrics
    r2.append(r2_score(y_tt, pred))
    acc.append(mean_absolute_error(y_tt, pred))
    rmse.append(np.sqrt(mean_squared_error(y_tt, pred)))

    # Save this fold's observed/predicted pairs so a 1:1 plot can be made
    # later. A dedicated name is used so the training table `df` is not
    # overwritten by the loop.
    fold_df = pd.DataFrame({'Test': y_tt, 'Pred': pred}).reset_index(drop=True)
    fold_df.to_csv(cv_dir + str(i) + "_" + model_var + "_lgbm.csv")

### Create a single 1:1 plot out of the folds 

The test samples do not overlap between folds and every sample is tested exactly once, so the per-fold predictions can be combined into a single 1:1 plot.


In [None]:
# Read back the observed/predicted pairs saved for each of the 5 folds
cv_dir = '/g/data/os22/chad_tmp/climate-carbon-interactions/results/cross_validation/'

dffs = []
for i in range(1, 5 + 1):
    # The fold number is an int and must be converted before it is
    # concatenated into the file name
    fold_file = cv_dir + str(i) + '_' + model_var + '_lgbm.csv'
    df = pd.read_csv(fold_file, usecols=['Test', 'Pred'])
    dffs.append(df)

# One table holding the test predictions of every fold
cross_df = pd.concat(dffs)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 6))

# Colour each point by the kernel-density estimate at its (Test, Pred) position
xy = np.vstack([cross_df['Test'], cross_df['Pred']])
z = gaussian_kde(xy)(xy)

sb.scatterplot(data=cross_df, x='Test', y='Pred', c=z, s=50, lw=1, alpha=0.5, ax=ax)
# Regression of predictions on observations (blue) and the dashed 1:1 line (black)
sb.regplot(data=cross_df, x='Test', y='Pred', scatter=False, color='darkblue', ax=ax)
sb.regplot(data=cross_df, x='Test', y='Test', color='black', scatter=False,
           line_kws={'linestyle': 'dashed'}, ax=ax)

ax.set_xlabel('Observation ' + model_var, fontsize=16)
ax.set_ylabel('Prediction ' + model_var, fontsize=16)

# Annotate with the scores averaged over the outer folds
r2_label = 'r\N{SUPERSCRIPT TWO}={:.2f}'.format(np.mean(r2))
mae_label = 'MAE={:.2g}'.format(np.mean(acc))
ax.text(.05, .95, r2_label, transform=ax.transAxes, fontsize=16)
ax.text(.05, .9, mae_label, transform=ax.transAxes, fontsize=16)

# NOTE(review): the fixed 0-1 limits look NDVI-specific - confirm they
# suit the value range when model_var is 'LST'
ax.set_ylim(0, 1)
ax.set_xlim(0, 1)

ax.tick_params(axis='both', labelsize=16)

plt.tight_layout()
# fig.savefig("/g/data/os22/chad_tmp/NEE_modelling/results/cross_validation/cross_val_"+model_var+"_lgbm_"+suffix+".png")

## Optimize model using all training data

Using a randomized strategy so we can search through more variables, with 500 iterations


In [None]:
# 5-fold shuffled split used to score each sampled parameter combination
outer_cv = KFold(n_splits=5, shuffle=True, random_state=0)

# Randomized search over `param_grid` with 500 sampled candidates
search_options = dict(
    verbose=1,
    n_iter=500,
    # n_jobs=-1,
    cv=outer_cv,
)
clf = RandomizedSearchCV(LGBMRegressor(verbose=-1), param_grid, **search_options)

clf.fit(x, y, callbacks=None)

In [None]:
# Report the winning parameter combination and its score from the search
best_params = clf.best_params_
best_score = round(clf.best_score_, 2)

print("The most accurate combination of tested parameters is: ")
pprint(best_params)
print("\n\nThe best score using these parameters is: ")
print(best_score)

## Fit on all data using best params

In [None]:
# Build a fresh regressor from the best parameters found by the search
# and train it on the full training set
best_params = clf.best_params_
model = LGBMRegressor(**best_params)

model.fit(x, y)

## Compare with independent validation data

In [None]:
# Target values of the held-out validation samples
y_val = validation[model_var]
# Validation features: same column selection as for training
x_val = validation.drop(columns=[model_var, 'longitude'])

In [None]:
# Predict the validation samples with the final model
pred = model.predict(x_val)

# Scores on the validation set (note: `r2` now holds a single value
# instead of the list of per-fold scores)
r2, ac = r2_score(y_val, pred), mean_absolute_error(y_val, pred)

# Observed/predicted pairs for plotting
df_val = pd.DataFrame(dict(Test=y_val, Pred=pred)).reset_index(drop=True)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 6))

# Colour each point by the kernel-density estimate at its (Test, Pred) position
xy = np.vstack([df_val['Test'], df_val['Pred']])
z = gaussian_kde(xy)(xy)

sb.scatterplot(data=df_val, x='Test', y='Pred', c=z, s=50, lw=1, alpha=0.5, ax=ax)
# Regression of predictions on observations (blue) and the dashed 1:1 line (black)
sb.regplot(data=df_val, x='Test', y='Pred', scatter=False, color='darkblue', ax=ax)
sb.regplot(data=df_val, x='Test', y='Test', color='black', scatter=False,
           line_kws={'linestyle': 'dashed'}, ax=ax)

ax.set_xlabel('Observation ' + model_var, fontsize=16)
ax.set_ylabel('Prediction ' + model_var, fontsize=16)

# Annotate with the validation scores
r2_label = 'r\N{SUPERSCRIPT TWO}={:.2f}'.format(r2)
mae_label = 'MAE={:.2g}'.format(ac)
ax.text(.05, .95, r2_label, transform=ax.transAxes, fontsize=16)
ax.text(.05, .9, mae_label, transform=ax.transAxes, fontsize=16)

# NOTE(review): the fixed 0-1 limits look NDVI-specific - confirm they
# suit the value range when model_var is 'LST'
ax.set_ylim(0, 1)
ax.set_xlim(0, 1)

ax.tick_params(axis='both', labelsize=16)

plt.tight_layout()
# fig.savefig("/g/data/os22/chad_tmp/NEE_modelling/results/cross_validation/cross_val_"+model_var+"_lgbm_"+suffix+".png")

### Save the model

In [None]:
# Persist the fitted model; the file name carries the modelled variable
model_file = '/g/data/os22/chad_tmp/climate-carbon-interactions/results/models/gapfill/gapfill_' + model_var + '_LGBM.joblib'
dump(model, model_file)

## Examine feature importance using SHAP

SHAP (SHapley Additive exPlanations) is a game-theoretic approach to explaining the output of any machine learning model.

In [None]:
# explain the model's predictions using SHAP
explainer = shap.Explainer(model)
shap_values = explainer(x)

In [None]:
# Mean absolute SHAP value per feature, averaged over the first axis
vals = np.abs(shap_values.values).mean(axis=0)

# Pair each feature name with its importance and rank from most to least important
feature_importance = pd.DataFrame(
    {'col_name': x.columns, 'feature_importance_vals': vals}
)
feature_importance = feature_importance.sort_values(
    by='feature_importance_vals', ascending=False
)

# Strip a trailing "_RS" from the feature names for display
feature_importance['col_name'] = feature_importance['col_name'].str.removesuffix("_RS")

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 7))

# Feature labels must follow the column order of `x`, the data the SHAP
# values were computed on. The importance-sorted `feature_importance`
# column only yields correct labels through integer label lookup, so the
# names are built directly from `x.columns` instead.
display_names = list(x.columns.str.removesuffix("_RS"))

shap.summary_plot(shap_values, max_display=11, show=False, feature_names=display_names)

# The last axes of the figure is adjusted here - presumably the colour bar
# added by the summary plot
plt.gcf().axes[-1].set_aspect('auto')
plt.gcf().axes[-1].set_box_aspect(15)

ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=16)
ax.set_xlabel(model_var + ' SHAP Value', fontsize=16)
plt.tight_layout()
# fig.savefig("/g/data/os22/chad_tmp/NEE_modelling/results/cross_validation/feature_importance_"+model_var+"_lgbm_"+suffix+".png")

## Predictions

In [None]:
from joblib import load
from datacube.utils.dask import start_local_dask

In [None]:
# Start a local Dask cluster via the datacube helper, with a 2Gb memory
# safety margin, and keep the returned client
client = start_local_dask(mem_safety_margin='2Gb')
# Bare expression so the notebook displays the client
client

In [None]:
# Variable to predict; it selects both the saved model and the input file
model_var = 'LST'

### Load model

In [None]:
# Load the saved gap-filling model for `model_var` and set n_jobs=1
model_file = '/g/data/os22/chad_tmp/climate-carbon-interactions/results/models/gapfill/gapfill_' + model_var + '_LGBM.joblib'
model = load(model_file)
model = model.set_params(n_jobs=1)

### Load prediction data

Load the same datasets that were used for training, then order the variables to match the training features.

In [None]:
# Directory holding the 5 km monthly NetCDF inputs
base = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/'

# The first entry is the target variable with its gaps
# (alternative target file: 'NDVI_5km_monthly_1982_2022.nc')
datasets = [model_var + '_5km_monthly_1982_2022_wGaps.nc']

# The remaining entries are the other variables merged into the dataset
datasets += [
    'rain_5km_monthly_1981_2022.nc',
    'rain_cml3_5km_monthly_1982_2022.nc',
    'rain_cml6_5km_monthly_1982_2022.nc',
    'rain_cml12_5km_monthly_1982_2022.nc',
    'srad_5km_monthly_1982_2022.nc',
    'tavg_5km_monthly_1982_2022.nc',
    'vpd_5km_monthly_1982_2022.nc',
    'MOY_5km_monthly_1982_2022.nc',
    'Elevation_5km_monthly_1982_2022.nc',
    'CO2_5km_monthly_1982_2022.nc',
    'WCF_5km_monthly_1990_2022.nc',
]

In [None]:
# Open every input file, restrict it to 1990-2021 and collect the pieces
dss = []
for fname in datasets:
    xx = xr.open_dataset(base + fname).sel(time=slice('1990', '2021'))
    xx = assign_crs(xx, crs='epsg:4326')
    xx = round_coords(xx)
    # `Dataset.drop` is deprecated for removing variables; `drop_vars`
    # is the supported spelling and behaves the same here.
    xx = xx.drop_vars('spatial_ref')
    dss.append(xx)

# Merge all variables into one dataset and re-attach the CRS
ds = xr.merge(dss)
ds = assign_crs(ds, crs='epsg:4326')

### Add latitude as a variable

In [None]:
# Broadcast the 1-D latitude coordinate over time and longitude so it can
# be stored as a (time, latitude, longitude) data variable
lat = (
    ds.latitude
    .expand_dims(time=ds.time, longitude=ds.longitude)
    .transpose('time', 'latitude', 'longitude')
)
ds['latitude_gridded'] = lat

In [None]:
# Put the gridded latitude first, followed by every data variable except
# the first and the last one. Presumably those two are the target variable
# and `latitude_gridded` itself; this relies on the merge/insertion order
# of the data variables - verify.
columns = ['latitude_gridded'] + list(ds.data_vars)[1:-1]

# Keep only those variables and rename the spatial dimensions to x/y
ds = ds[columns].rename({'latitude': 'y', 'longitude': 'x'})

### Create a mask

In [None]:
# True wherever the 2015 time-mean of WCF is not NaN
wcf_mean_2015 = ds.WCF.sel(time='2015').mean('time')
mask = ~np.isnan(wcf_mean_2015)

### Predict

In [None]:
# Predict one time step at a time and collect the results
results = []
n_steps = len(ds.time)
for i in range(n_steps):
    # Progress counter, overwritten in place on the same line
    print(" {:03}/{:03}\r".format(i + 1, n_steps), end="")

    # Silence the prints made while predicting
    with HiddenPrints():
        predicted = predict_xr(
            model,
            ds.isel(time=i),
            proba=False,
            clean=True,
            chunk_size=10000,
        ).compute()

    # predicted = predicted.Predictions.where(~mask.isel(time=i))
    # Stamp the prediction with the time of the step it belongs to
    predicted['time'] = ds.isel(time=i).time.values
    results.append(predicted.astype('float32'))

In [None]:
# Stack the per-step predictions along time, in chronological order
yy = xr.concat(results, dim='time')
yy = yy.sortby('time')
# Name the predicted variable after the modelled variable
yy = yy.rename({'Predictions': model_var})  # .astype('float32')

# Re-use the time coordinate of the input dataset
yy['time'] = ds.isel(time=range(len(yy.time))).time

# Blank out pixels outside the mask
yy = yy.where(mask)

In [None]:
# Name the output after the modelled variable so that, for example, LST
# predictions are not written to a file labelled NDVI
yy.to_netcdf('/g/data/os22/chad_tmp/climate-carbon-interactions/results/ml_predictions/'+model_var+'_predicted_5km_monthly_1990_2022.nc')

### Compare with observations

In [None]:
# Observations, gaps included, of the modelled variable (the name `ndvi`
# is kept even when model_var is not NDVI)
obs_file = '/g/data/os22/chad_tmp/climate-carbon-interactions/data/5km/' + model_var + '_5km_monthly_1982_2022_wGaps.nc'
ndvi = xr.open_dataset(obs_file)

In [None]:
# Restrict the observations to the predicted time steps
ndvi = ndvi.sel(time=yy.time)
# Select the observed variable by `model_var` rather than a hard-coded
# NDVI attribute, and rename the spatial dimensions to match the predictions
ndvi = ndvi[model_var].rename({'latitude':'y', 'longitude':'x'})
# True where an observation exists (i.e. not NaN)
gaps_mask = ~np.isnan(ndvi)

In [None]:
# Prediction minus observation; the predicted variable is named after
# `model_var`, so it is looked up by that name
diff = yy[model_var] - ndvi

In [None]:
# Spatial-mean time series of predictions (only where observations exist)
# against the observations
fig, ax = plt.subplots(1,1, figsize=(13,5))
yy[model_var].where(gaps_mask).mean(['x','y']).plot(ax=ax, label='predictions')
ndvi.mean(['x','y']).plot(ax=ax, label='observed')
ax.legend()
# Title follows the modelled variable instead of always saying NDVI
ax.set_title('Aus-Wide '+model_var+' matching data gaps')

In [None]:
# Per-pixel correlation through time between observations and predictions;
# the predicted variable is looked up by `model_var`
corr = xr.corr(ndvi, yy[model_var], dim='time')

In [None]:
# Time-mean of the predictions; the predicted variable is looked up by `model_var`
pred_mean = yy[model_var].mean('time')

In [None]:
# Time-mean of the observations
ndvi_mean = ndvi.mean(dim='time')

In [None]:
# 2x2 panel: observed mean, predicted mean, their difference and the
# per-pixel correlation
fig, ax = plt.subplots(2,2, figsize=(10,10), sharey=True)
# NOTE(review): the colour limits below (0.05-0.7 and +/-0.1) look
# NDVI-specific - confirm they suit the value range when model_var is 'LST'
ndvi_mean.plot.imshow(ax=ax[0,0], vmin=0.05, vmax=0.7, add_labels=False)
pred_mean.plot.imshow(ax=ax[0,1], vmin=0.05, vmax=0.7, add_labels=False)
(ndvi_mean - pred_mean).plot.imshow(ax=ax[1,0], cmap='RdBu', vmin=-0.1, vmax=0.1, add_labels=False)
corr.plot.imshow(ax=ax[1,1], cmap='magma', robust=True, add_labels=False)
# Titles follow the modelled variable instead of always saying NDVI
ax[0,0].set_title('Observed Mean '+model_var+' (1990-2021)')
ax[0,1].set_title('Predicted Mean '+model_var+' (1990-2021)')
ax[1,0].set_title('Difference (Obs-Pred)')
ax[1,1].set_title('Correlation')
plt.tight_layout()

### Animations

## Gapfill with synthetic data