# Machine Learning - Training and Mapping

This notebook trains machine learning models (currently random forest classifiers) on several kinds of data to create flood hazard maps.

The algorithm used in this notebook is as follows:
1. Load bands from various satellites and derive relevant information from them (e.g. water indices like MNDWI). This includes elevation data from [SRTM](https://www2.jpl.nasa.gov/srtm/) and precipitation data from CHIRPS ([info here](https://www.chc.ucsb.edu/data/chirps), [data here](https://data.chc.ucsb.edu/products/CHIRPS-2.0/)).
2. Calculate summary statistics for each data variable across time (e.g. a water index from a particular satellite). The statistics are min, mean, max, and standard deviation, or just the mean for binary variables. The result is a collection of 2D (latitude by longitude) layers, which are "composites" in a sense.
3. Run a [linear discriminant analysis](https://en.wikipedia.org/wiki/Linear_discriminant_analysis) (LDA) to identify the summary statistics that best help predict flood hazard areas, including the degree of hazard. **However, the LDA only indicates how well the input data (features) allows the output data (given hazard map) to be classified. It does not indicate how well a model trained on this data may generalize to unseen inputs.**
4. Train the machine learning models.
5. Create flood hazard maps with the trained models.
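
Step 2 can be sketched on a toy array. The example below is only an illustration in plain NumPy, not the notebook's xarray code; the `index_stack` values and the `summarize_over_time` helper are made up for the example.

```python
import numpy as np

# Hypothetical toy stack: 3 acquisitions (time) of a 2x2 water-index image.
# NaN marks a pixel that was masked out (e.g. by cloud) at that time.
index_stack = np.array([
    [[0.1, np.nan], [-0.2, 0.4]],
    [[0.3, 0.0],    [-0.4, 0.4]],
    [[0.5, 0.2],    [np.nan, 0.4]],
])

def summarize_over_time(stack, time_axis=0):
    """Collapse the time axis into per-pixel summary statistics, ignoring NaNs."""
    return {
        "min": np.nanmin(stack, axis=time_axis),
        "mean": np.nanmean(stack, axis=time_axis),
        "std": np.nanstd(stack, axis=time_axis),
        "max": np.nanmax(stack, axis=time_axis),
    }

stats = summarize_over_time(index_stack)
for name, layer in stats.items():
    print(name, layer.shape)
```

Each statistic is a 2D layer with the same spatial shape as one acquisition, which is what lets the layers be stacked into one feature vector per pixel later on.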

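The machine learning models are saved in the **models** directory (at the top level of the repository, not within this notebook's directory; the same applies to the other output directories). The models are named **hazard_classifier_{cross validation score}{model parameter set}.joblib**, where the parameter set is written as a series of `__{parameter}_{value}` segments. Only models with a cross validation score between 0.85 and 0.998 are saved. You can find code to load them in this notebook (**joblib.load()**).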
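
The training cell in this notebook saves each model as `hazard_classifier_{score:0.4f}` followed by one `__{param}_{val}` segment per parameter. As a rough sketch of how such a name can be taken apart again (the `parse_model_filename` helper and the example path are hypothetical, and the parsing assumes integer parameter values, as in this notebook's grid):

```python
from pathlib import Path

MODEL_PREFIX = 'hazard_classifier_'

def parse_model_filename(path):
    """Split a saved model name into its CV score and parameter set."""
    stem = Path(path).stem  # Drops only the final '.joblib' suffix.
    head, *segments = stem.split('__')
    if not head.startswith(MODEL_PREFIX):
        raise ValueError(f"Unexpected model filename: {path}")
    score = float(head[len(MODEL_PREFIX):])
    params = {}
    for segment in segments:
        # 'min_samples_leaf_2' -> ('min_samples_leaf', '2')
        name, value = segment.rsplit('_', 1)
        params[name] = int(value)  # Assumption: all values are integers.
    return score, params

# Hypothetical example path, following the naming pattern above.
example = ('../models/hazard_classifier_0.9123__n_estimators_16__max_depth_24'
           '__min_samples_split_2__min_samples_leaf_5.joblib')
score, params = parse_model_filename(example)
print(score, params)
```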

The hazard maps for Dakar are created in the **outputs** directory. The images are named **hazard_{cross validation score}{model parameter set}.png**, using the same score and parameter set suffix as the model files. Models with higher scores should generalize better than ones with lower scores. Comparing these images to the given flood hazard map (EO4SD) indicates the best performance the model can be expected to reach, provided the given flood hazard map contains enough data for the model to generalize to the coastline of Senegal (or to whatever data the model is run on).

The hazard maps for the coastline of Senegal are created in the **hazard_maps** directory, in subdirectories named for the area (see the `areas` variable in the **Create the maps** section to see and change the areas). The images are named **{cross validation score}{model parameter set}.png**. Models with higher scores should generalize better than ones with lower scores, so areas other than the Dakar training area should have better outputs if the models' input features and training data are suitable for generalization. **Comparing these images to the given flood hazard map (EO4SD) should provide an indication of how well the models will perform, provided the given flood hazard map contains enough data for the models to generalize to the coastline of Senegal (or to whatever data the models are run on).**

# Index

* Import dependencies, setup Dask client, and connect to the data cube
* Load flood hazard data from World Bank
* Show area to load data for
* Load geospatial data
    * Sentinel-2
* Load elevation data (SRTM)
* Load precipitation data from CHIRPS
* Combine datasets
* Train a flood risk classifier
* Create the maps

## Import dependencies, setup Dask client, and connect to the data cube

In [None]:
from collections import ChainMap

import matplotlib.pyplot as plt
import geopandas as gpd
import xarray as xr
import pandas as pd
import numpy as np
import joblib
import os

import sys
sys.path.append('..')
from utils.ceos_utils.dc_display_map import display_map
from utils.deafrica_utils.deafrica_bandindices import \
    calculate_indices
from utils.deafrica_utils.deafrica_datahandling import load_ard

import datacube
dc = datacube.Datacube()

In [None]:
from utils.ceos_utils.dask import create_local_dask_cluster

client = create_local_dask_cluster()

## Load flood hazard data from World Bank

In [None]:
dakar_flood_hazard = gpd.read_file('../floodareas/eo4sd_dakar_fhazard_2018/EO4SD_DAKAR_FHAZARD_2018.shp')

**Remove records with no geometry data**

In [None]:
# `notna()` does not depend on the index labels being 0..n-1.
dakar_flood_hazard = dakar_flood_hazard[dakar_flood_hazard.geometry.notna()]

**Change the CRS to EPSG:4326**

In [None]:
dakar_flood_hazard = dakar_flood_hazard.to_crs("EPSG:4326")

**Get the bounding box of the data**

In [None]:
dakar_bounds = dakar_flood_hazard.bounds
min_lon = dakar_bounds.minx.min()
max_lon = dakar_bounds.maxx.max()
min_lat = dakar_bounds.miny.min()
max_lat = dakar_bounds.maxy.max()
lat = (min_lat, max_lat)
lon = (min_lon, max_lon)

## Show area to load data for

In [None]:
## Dakar, Senegal
# Small test
# lat = (14.8270, 14.8422)
# lon = (-17.2576, -17.2172)
# Tip
# lat = (14.6433, 14.7892)
# lon = (-17.5408, -17.4158)
# Full
lat = (14.6285, 14.8725)
lon = (-17.5348, -17.2068)

## Coast of Senegal
# North
# lat = (14.3559, 16.0974)
# lon = (-17.5683, -16.4543)
# Full
# lat = (12.3016, 16.1810)
# lon = (-17.8198, -16.3257)

In [None]:
display_map(lat, lon)

## Load geospatial data

**Specify time range and common load parameters**

In [None]:
years = range(2013, 2020) # (inclusive, exclusive)
time_ranges = [(f"{year}-01-01", f"{year}-12-31") for year in years]
common_load_params = \
    dict(output_crs="EPSG:4326",
         resolution=(-0.00027,0.00027),
         latitude=lat, longitude=lon,
         dask_chunks={'time':40, 
                      'latitude':1000, 
                      'longitude':1000})

>### WOfS

In [None]:
# from utils.ceos_utils.dc_load import is_dataset_empty

# ls_data = []
# for time_range in time_ranges:
#     data = dc.load(product='ga_ls8c_wofs_2', 
#                    measurements=['water'], 
#                    time=time_range,
#                    **common_load_params)
#     if not is_dataset_empty(data):
        
#         # Formatting water data #
#         # bit 7 indicates water, bit 2 indicates sea.
#         ls_water_cls = (data.water&0b10000010)!=0
#         # Set no_data (missing) values to NaN.
#         ls_water_cls = \
#             ls_water_cls.where(data.water!=1)
#         data['water'] = ls_water_cls
#         # End formatting water data #
        
#         ls_data.append(data.rename({'water':'wofs'}))
# ls_data = xr.concat(ls_data, dim='time')

In [None]:
# ls_data = dc.load(product='ga_ls8c_wofs_2', 
#                   measurements=['water'], 
#                   time=full_time_range,
#                   **common_load_params)
# ls_data = ls_data.sel(time=[list(time_range) for time_range in time_ranges])
# # ls_data = xr.merge((ls_data_red, ls_data_water))

**Rename data variables to distinguish them from those of other datasets when we merge**

In [None]:
# ls_data = ls_data.rename({data_var: f"ls_{data_var}" for data_var in ls_data.data_vars})

**Calculate summary statistics**

In [None]:
# ls_stats = [{
#              f'ls_wofs_mean':  ls_data.ls_wofs.mean('time') 
#            }]
# ls_stats = xr.Dataset(dict(ChainMap(*ls_stats)))

**Impute remaining NaNs with the means**

In [None]:
# for data_var in ls_stats.data_vars:
#     ls_stats[data_var] = ls_stats[data_var]\
#         .where(~np.isnan(ls_stats[data_var]), ls_stats[data_var].mean())

>### Sentinel-2

In [None]:
s2_data = []
for time_range in time_ranges:
    try:
        data = load_ard(dc, products=['s2_l2a'],
                        measurements=[
                            # Used by MNDWI, AWEI_ns, AWEI_sh
                            'green', 'swir_1', 
                            # Used by AWEI_ns, AWEI_sh
                            'nir', 'swir_2',
                            # Used by AWEI_sh
                            'blue',
                            'AOT', 'SCL'], 
                        time=time_range,
                        **common_load_params).persist() # This will likely require a lot of RAM or storage.
        data = calculate_indices(data, index='MNDWI', collection='s2')
        data = calculate_indices(data, index='AWEI_ns', collection='s2')
        data = calculate_indices(data, index='AWEI_sh', collection='s2')
        data = data[['MNDWI', 'AWEI_ns', 'AWEI_sh', 'SCL', 'AOT']]
        s2_data.append(data)
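    except Exception as err:
        # Skip time ranges that fail to load (e.g. no data), but report why.
        print(f"Skipping {time_range}: {err}")
        continue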
s2_data = xr.concat(s2_data, dim='time')    

**Rename data variables to distinguish them from those of other datasets when we merge**

In [None]:
s2_data = s2_data.rename({data_var: f"s2_{data_var}" for data_var in s2_data.data_vars})

**Calculate mean of bare soil, water, and bare soil to water transitions across time**

In [None]:
s2_bare_soil = s2_data.s2_SCL == 5
s2_water = s2_data.s2_SCL == 6

In [None]:
s2_bare_soil_mean = s2_bare_soil.mean('time')
s2_water_mean = s2_water.mean('time')

In [None]:
s2_soil_to_water = s2_water.isel(time=slice(1, len(s2_data.time))) & \
                   s2_bare_soil.isel(time=slice(0, len(s2_data.time)-1)).data   

In [None]:
s2_soil_to_water_mean = s2_soil_to_water.mean('time')
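
The transition feature above compares each acquisition with the one before it: a pixel counts as a transition when it is bare soil at one time step and water at the next. A minimal NumPy sketch of the same idea, on a made-up SCL time series (the `soil_to_water_fraction` helper is hypothetical; the class codes 5 and 6 are the ones used in the cell above):

```python
import numpy as np

# SCL class codes used in this notebook: 5 = bare soil, 6 = water.
BARE_SOIL, WATER = 5, 6

# Hypothetical SCL time series: 4 acquisitions (rows) of 3 pixels (columns).
scl = np.array([
    [5, 5, 6],
    [6, 5, 6],
    [5, 4, 6],
    [6, 6, 5],
])

def soil_to_water_fraction(scl, time_axis=0):
    """Fraction of consecutive time steps where bare soil is followed by water."""
    scl = np.moveaxis(scl, time_axis, 0)
    was_soil = scl[:-1] == BARE_SOIL   # state at time t
    is_water = scl[1:] == WATER        # state at time t + 1
    return (was_soil & is_water).mean(axis=0)

fraction = soil_to_water_fraction(scl)
print(fraction)
```

The first pixel flips from soil to water in two of its three time steps, while the other two never do, so only the first pixel gets a non-zero value.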

**Calculate summary statistics**

In [None]:
from scipy.stats import mode

s2_stats = [{
             f'{data_var}_min':   s2_data[data_var].min('time'), 
             f'{data_var}_mean':  s2_data[data_var].mean('time'), 
             f'{data_var}_std':   s2_data[data_var].std('time'), 
             f'{data_var}_max':   s2_data[data_var].max('time')
            }
            for data_var in s2_data.data_vars if data_var != 's2_SCL'] + \
           [{
             # The most common classification for each pixel.
#              's2_SCL_mode': xr.DataArray(mode(s2_data.s2_SCL, axis=s2_data.s2_SCL.get_axis_num('time'))[0].squeeze(), 
#              coords={'latitude':s2_data.latitude, 'longitude':s2_data.longitude}, 
#              dims=['latitude', 'longitude']),
             's2_soil_to_water_mean': s2_soil_to_water_mean, 
             's2_bare_soil_mean': s2_bare_soil_mean,
             's2_water_mean': s2_water_mean,
            }]

s2_stats = xr.Dataset(dict(ChainMap(*s2_stats)))

**Impute remaining NaNs with the means**

In [None]:
for data_var in s2_stats.data_vars:
    s2_stats[data_var] = s2_stats[data_var]\
        .where(~np.isnan(s2_stats[data_var]), s2_stats[data_var].mean())
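
The imputation above fills every missing pixel of a layer with that layer's overall mean, so the fill value is the same everywhere in the layer. A small NumPy sketch of the same operation (the `impute_with_mean` helper and the sample layer are made up for illustration):

```python
import numpy as np

def impute_with_mean(layer):
    """Replace NaNs with the mean of the layer's non-NaN values."""
    fill_value = np.nanmean(layer)
    return np.where(np.isnan(layer), fill_value, layer)

# Hypothetical 2x2 summary-statistic layer with one missing pixel.
layer = np.array([[1.0, np.nan],
                  [3.0, 5.0]])
filled = impute_with_mean(layer)
print(filled)
```

Because the fill value ignores location, large gaps end up as flat patches in the features, which may be worth keeping in mind when reading the resulting maps.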

## Load elevation data (SRTM)

In [None]:
# Only 1 time, so we remove it with `.isel(time=0)`.
srtm_data = \
    dc.load(product='srtm', 
            **common_load_params).elevation.isel(time=0)
# Remove no_data values.
srtm_data = srtm_data.where(srtm_data!=-32768)

**Impute remaining NaNs with the means**

In [None]:
srtm_data = srtm_data.where(~np.isnan(srtm_data), srtm_data.mean())

## Load precipitation data from CHIRPS

In [None]:
import itertools
chirps_months = list(map(lambda month_str: month_str.zfill(2), map(str, range(1, 13))))
chirps_data = xr.concat([xr.open_rasterio(f'../precipitation/chirps/chirps-v2.0.{year}.{month_str}.tif').squeeze()
                         .sel(y=slice(*lat[::-1]), x=slice(*lon)) 
                         for year, month_str in itertools.product(years, chirps_months)], 
                        dim='time')

**Load and show CHIRPS data for the coast of Senegal**

In [None]:
# chirps_months = map(lambda month_str: month_str.zfill(2), map(str, range(1, 13)))
# coast_lat = (12.3016, 16.1810)
# coast_lon = (-17.8198, -16.3257)
# chirps_data_coast = xr.concat([xr.open_rasterio(f'../precipitation/chirps/chirps-v2.0.{year}.{month_str}.tif').squeeze() 
#                                .sel(y=slice(*coast_lat[::-1]), x=slice(*coast_lon)) 
#                          for year, month_str in itertools.product(years, chirps_months)], 
#                         dim='time')

In [None]:
# chirps_data_coast = chirps_data_coast.where(chirps_data_coast != -9999).rename({'x': 'longitude', 'y':'latitude'})

In [None]:
# plt.figure(figsize=(3,8))
# chirps_data_coast.rename('mm / month').mean('time').plot.imshow(vmin=0, vmax=150)
# plt.title('CHIRPS Precipitation')
# plt.show()

**Set missing data points to NaN**

In [None]:
chirps_data = chirps_data.where(chirps_data != -9999).rename({'x': 'longitude', 'y':'latitude'})

**Calculate summary statistics**

In [None]:
chirps_stats = [{
             f'chirps_min':   chirps_data.min('time'), 
             f'chirps_mean':  chirps_data.mean('time'), 
             f'chirps_std':   chirps_data.std('time'), 
             f'chirps_max':   chirps_data.max('time')
            }]
chirps_stats = xr.Dataset(dict(ChainMap(*chirps_stats)))

**Impute remaining NaNs with the means**

In [None]:
for data_var in chirps_stats.data_vars:
    chirps_stats[data_var] = chirps_stats[data_var]\
        .where(~np.isnan(chirps_stats[data_var]), chirps_stats[data_var].mean())

**Fit the CHIRPS data to the Sentinel-2 grid (nearest neighbor)**

In [None]:
chirps_stats = chirps_stats.reindex(latitude=s2_stats.latitude, longitude=s2_stats.longitude, method='nearest')
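
Nearest-neighbor reindexing copies, for every coordinate on the fine grid, the value at the closest coordinate on the coarse grid. The sketch below shows the idea along one dimension in plain NumPy; it is an illustration rather than what xarray does internally, and the `nearest_reindex` helper and the coordinates are hypothetical.

```python
import numpy as np

def nearest_reindex(values, coarse_coords, fine_coords):
    """For each fine coordinate, take the value at the nearest coarse coordinate."""
    coarse_coords = np.asarray(coarse_coords)
    fine_coords = np.asarray(fine_coords)
    # Distance from every fine coordinate to every coarse coordinate.
    distances = np.abs(fine_coords[:, None] - coarse_coords[None, :])
    nearest = distances.argmin(axis=1)
    return np.asarray(values)[nearest]

# Hypothetical coarse precipitation cells and a finer target grid (longitudes).
coarse_lon = [-17.50, -17.45, -17.40]
coarse_precip = [10.0, 20.0, 30.0]
fine_lon = [-17.51, -17.47, -17.44, -17.41]

fine_precip = nearest_reindex(coarse_precip, coarse_lon, fine_lon)
print(fine_precip)
```

No interpolation takes place, so neighboring fine pixels that fall in the same coarse cell receive identical values.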

## Combine datasets

In [None]:
merged_stats = xr.merge((s2_stats, srtm_data, chirps_stats), join='left').persist()
output_shape = merged_stats.s2_MNDWI_mean.shape
output_coords = merged_stats.coords
output_dims = merged_stats.dims

## Train a flood risk classifier

**Get an encoding for the flood risk classes and a raster mask of the flood risk areas**

In [None]:
# Mask out areas that are commonly water from the risk maps
# `dakar_flood_hazard['RISKCODE_H']==1` is very wrong without this.
s2_land_mask = merged_stats.s2_MNDWI_mean < 0.001

In [None]:
from utils.deafrica_utils.deafrica_spatialtools import xr_rasterize

flood_hazard_enc = {'0':0, 'Low Risk':1, 'Medium Risk':2, 'High Risk':3}
# `np.bool` was removed in NumPy 1.24; the builtin `bool` is equivalent.
flood_hazard_masks = \
{label: xr_rasterize(dakar_flood_hazard[dakar_flood_hazard['RISKCODE_H']==code], 
                     merged_stats).astype(bool).where(s2_land_mask, 0)
 for label, code in flood_hazard_enc.items()}

**Convert values to labels**

In [None]:
flood_hazard_enc_rev = {v: k for k, v in flood_hazard_enc.items()}
dakar_flood_hazard_enc = dakar_flood_hazard.copy()
dakar_flood_hazard_enc['RISKCODE_H'] = dakar_flood_hazard['RISKCODE_H'].map(flood_hazard_enc_rev)

**Note that the border of low risk flooding around the coast is removed in processing later**

In [None]:
fig, ax = plt.subplots(figsize=(12, 8))
dakar_flood_hazard_enc[dakar_flood_hazard_enc['RISKCODE_H']!='0']\
    .plot(column='RISKCODE_H', legend=True, ax=ax)
plt.title('Flood Hazard')
plt.show()

**Format the feature matrix**

In [None]:
X = merged_stats.to_array().transpose('latitude', 'longitude', 'variable')
X = X.stack(row=('latitude', 'longitude')).transpose('row', 'variable').persist()
X_local = X.compute()

**Format the truth matrix (risk classifications)**

In [None]:
y = X.isel(variable=0)
y = y.where(False, 0)
for key, mask in flood_hazard_masks.items():
    mask = mask.stack(row=('latitude', 'longitude'))
    y = y.where(~mask.astype(bool), flood_hazard_enc[key])
y = y.persist()
y_local = y.compute()
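
The label vector starts at class 0 everywhere and is then overwritten class by class wherever that class's mask is set, so a pixel covered by more than one mask keeps the label of the last mask applied. A minimal NumPy sketch with made-up masks (the `encode_labels` helper and the mask values are hypothetical; the encoding is the one defined above):

```python
import numpy as np

flood_hazard_enc = {'0': 0, 'Low Risk': 1, 'Medium Risk': 2, 'High Risk': 3}

# Hypothetical flattened masks for 6 pixels
# (True where the class polygon covers the pixel).
masks = {
    '0':           np.array([True,  False, False, False, False, False]),
    'Low Risk':    np.array([False, True,  True,  False, False, False]),
    'Medium Risk': np.array([False, False, True,  True,  False, False]),
    'High Risk':   np.array([False, False, False, False, True,  False]),
}

def encode_labels(masks, encoding):
    """Build an integer label per pixel; later masks overwrite earlier ones."""
    n_pixels = len(next(iter(masks.values())))
    labels = np.zeros(n_pixels, dtype=int)
    for name, mask in masks.items():
        labels[mask] = encoding[name]
    return labels

labels = encode_labels(masks, flood_hazard_enc)
print(labels)
```

The third pixel is covered by both the low and medium risk masks and ends up as medium risk, and the last pixel, which no mask covers, stays at class 0.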

**Remove unneeded data**

In [None]:
# Clear the persisted data in `merged_stats` to save memory.
if 'merged_stats' in globals():
    del merged_stats

**Determine the relative importance of the data variables**

In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
lda = LinearDiscriminantAnalysis()
lda = lda.fit(X_local, y_local)

In [None]:
# Get the relative frequency of the classes to use as weights for the LDA coefficients.
# Use the in-memory labels that the LDA was fit on.
_, cls_frq_wgts = np.unique(y_local, return_counts=True)
cls_frq_wgts = cls_frq_wgts / cls_frq_wgts.sum()

In [None]:
# Weight the LDA coefficients by class frequency.
lda_coef = (cls_frq_wgts*lda.coef_.T).T.mean(axis=0)
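
For more than two classes, a fitted scikit-learn LDA stores one row of coefficients per class, so the weighting above scales each class's row by how common that class is before averaging over classes. The sketch below reproduces that arithmetic on made-up numbers; `coef` stands in for `lda.coef_`, and the labels and the `weighted_feature_scores` helper are hypothetical.

```python
import numpy as np

def weighted_feature_scores(coef, labels):
    """Average per-class coefficients, weighting each class by its frequency."""
    _, counts = np.unique(labels, return_counts=True)
    weights = counts / counts.sum()            # One weight per class.
    weighted = coef * weights[:, None]         # Scale each class's row.
    return weighted.mean(axis=0)               # One score per feature.

# Hypothetical labels: class 0 six times, class 1 three times, class 2 once.
labels = np.array([0] * 6 + [1] * 3 + [2])
# Hypothetical coefficients, shape (n_classes, n_features).
coef = np.array([[ 1.0, -2.0],
                 [-1.0,  4.0],
                 [ 5.0,  0.0]])

scores = weighted_feature_scores(coef, labels)
ranking = np.argsort(np.abs(scores))[::-1]     # Most influential first.
print(scores, ranking)
```

Note that coefficients of opposite sign can cancel in the average, as they do for the second feature here, so a low score does not necessarily mean a feature is uninformative for every class.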

In [None]:
# The LDA coefficients for the features in descending order.
desc_lda_coef_inds = np.argsort(abs(lda_coef))[::-1]

In [None]:
lda_table = pd.DataFrame({'name': X['variable'].values, 'coef': lda_coef})

In [None]:
# Sort the data variable names by the absolute value of their weighted coefficients.
lda_table['abs_coef'] = abs(lda_table.coef)
lda_table = lda_table.sort_values('abs_coef', ascending=False)
lda_table = lda_table.set_index('name')
lda_table

In [None]:
lda_table.abs_coef.plot()
plt.xticks(ticks=range(len(lda_table)), 
           labels=lda_table.index.values, 
           rotation=70, ha='right')
plt.title('LDA importance of inputs')
plt.show()

**Train and save a classifier for each parameter set in a parameter grid and output the predictions as an image**

In [None]:
y_vis = xr.DataArray(data=y.values.reshape(output_shape), 
                     coords=output_coords, dims=output_dims)

In [None]:
y_vis.plot.imshow(vmin=0, vmax=3, figsize=(12,7))
plt.title('Input Classifications (Given Flood Hazard Map)')
plt.show()

<p style="color:red"><b>Set the parameter grid here</b></p>

In [None]:
import itertools
from pathlib import Path
from joblib import dump, load

param_grid = {
    'n_estimators': [16],
    'max_depth': [24],
    'min_samples_split': [2, 5, 15, 25, 50],
    'min_samples_leaf': [2, 5, 15, 25, 50],
}
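
Every combination of values in the grid is trained as its own model, so the grid above yields 1 x 1 x 5 x 5 = 25 parameter sets. As a sketch of how the combinations and the file name suffixes used in the training cell are formed (the `expand_param_grid` and `param_set_suffix` helpers are hypothetical names wrapping the same `itertools.product` and join logic):

```python
import itertools

param_grid = {
    'n_estimators': [16],
    'max_depth': [24],
    'min_samples_split': [2, 5, 15, 25, 50],
    'min_samples_leaf': [2, 5, 15, 25, 50],
}

def expand_param_grid(grid):
    """List every combination of parameter values as a dict."""
    keys = list(grid)
    return [dict(zip(keys, values))
            for values in itertools.product(*grid.values())]

def param_set_suffix(param_set):
    """Build the '__{param}_{val}' suffix used in the output file names."""
    return ''.join(f'__{param}_{val}' for param, val in param_set.items())

param_sets = expand_param_grid(param_grid)
print(len(param_sets))
print(param_set_suffix(param_sets[0]))
```

Adding values to any list multiplies the number of models to train, so the grid is best grown one parameter at a time.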

In [None]:
%%time

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV
from sklearn.metrics import make_scorer, f1_score

clf = RandomForestClassifier(class_weight='balanced', n_jobs=8)
cv = StratifiedShuffleSplit(n_splits=5)
# `make_scorer` passes extra keyword arguments through to the metric.
scorer = make_scorer(f1_score, average='micro')

param_set_keys = list(param_grid.keys())
param_sets_vals = list(itertools.product(*param_grid.values()))
for i, param_set_vals in enumerate(param_sets_vals):
    # Train
    param_set = {k:v for k, v in zip(param_set_keys, param_set_vals)}
    print(f"Training with param set {param_set}")
    clf.set_params(**param_set)
    # Using GridSearchCV on empty param grid just to use `cv` conveniently.
    grid_search = \
        GridSearchCV(clf, {}, cv=cv, 
                     scoring=scorer, n_jobs=-1, verbose=1)
    grid_search.fit(X_local,y_local)
    from dask_ml.wrappers import ParallelPostFit
    grid_search_parallel_predictor = \
        ParallelPostFit(grid_search)
    y_pred = grid_search_parallel_predictor.predict(X)
    clf = grid_search.best_estimator_
    score = grid_search.best_score_
    print(f"Score: {score}")
    
    # Save model
    if 0.85 < score < 0.998:
        param_set_suffix = ''.join([f'__{param}_{val}' for param, val 
                                    in param_set.items()])
        model_dir = '../models'
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        dump(clf, f'{model_dir}/hazard_classifier_{score:0.4f}{param_set_suffix}.joblib')
    
        # Visualize predictions
        y_pred_vis = xr.DataArray(data=y_pred.reshape(output_shape), 
                                  coords=output_coords, dims=output_dims)
        fig = plt.figure(figsize=(12,7))
        y_pred_vis.plot.imshow(vmin=0, vmax=3)
        plt.title('Output Classifications (S2 MNDWI, AWEI + SRTM DEM + CHIRPS)\n' + \
                  ''.join([f'{param}: {val}, ' for param, val in param_set.items()]) + \
                  f' CV Score: {grid_search.best_score_:.2%}')
        vis_dir = '../outputs'
        Path(vis_dir).mkdir(parents=True, exist_ok=True)
        plt.savefig(f'{vis_dir}/hazard_{score:0.4f}{param_set_suffix}.png')
        plt.close(fig)
    
    print(f"{(i+1)/len(param_sets_vals):.2%} through parameter sets")
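
The training cell scores models with micro-averaged F1. For single-label multiclass problems like this one, micro-averaged F1 equals plain accuracy, which is worth keeping in mind when reading the scores in the file names. The sketch below checks this with NumPy on made-up labels; `micro_f1` is a hypothetical helper, not scikit-learn's implementation.

```python
import numpy as np

def micro_f1(y_true, y_pred):
    """Micro-averaged F1: pool TP, FP, and FN over all classes, then compute F1."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    classes = np.union1d(y_true, y_pred)
    tp = sum(np.sum((y_pred == c) & (y_true == c)) for c in classes)
    fp = sum(np.sum((y_pred == c) & (y_true != c)) for c in classes)
    fn = sum(np.sum((y_pred != c) & (y_true == c)) for c in classes)
    return 2 * tp / (2 * tp + fp + fn)

# Hypothetical hazard classes (0-3) for 8 pixels.
y_true = [0, 0, 0, 1, 1, 2, 3, 3]
y_pred = [0, 0, 1, 1, 1, 2, 3, 0]

f1 = micro_f1(y_true, y_pred)
accuracy = np.mean(np.asarray(y_true) == np.asarray(y_pred))
print(f1, accuracy)
```

Every misclassified pixel counts once as a false positive for the predicted class and once as a false negative for the true class, which is why the two numbers coincide. With heavily imbalanced classes, a macro average would weight the rare hazard classes more strongly.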

## Create the maps

**Specify the areas and time range**

In [None]:
from collections import OrderedDict
areas = OrderedDict(
    [
        ('Area1', ((15.8193, 16.4064), (-16.5475, -15.7261))),
        ('Area2', ((14.9893, 15.8185), (-17.1027, -16.4953))),
        # Dakar
        ('Area3', ((14.4290, 15.0000), (-17.5581, -17.0007))),
        # Delta du Saloum W
        ('Area4', ((13.5361, 14.3098), (-17.0092, -16.3564))),
        # Delta du Saloum E
        ('Area5', ((13.5361, 14.3098), (-16.3564, -15.9413))),
        ('Area6', ((13.0470, 13.5400), (-16.8582, -15.5425)))
    ])

In [None]:
for area_ind, (area_name, (lat, lon)) in enumerate(areas.items()):
    
    common_load_params = \
    dict(output_crs="EPSG:4326",
         resolution=(-0.00027,0.00027),
         latitude=lat, longitude=lon,
         dask_chunks={'time':40, 
                      'latitude':1000, 
                      'longitude':1000})
    
    ## Load geospatial data ##

    ### Load Sentinel-2 ###
    s2_data = []
    for time_range in time_ranges:
        try:
            data = load_ard(dc, products=['s2_l2a'],
                            measurements=[
                                # Used by MNDWI, AWEI_ns, AWEI_sh
                                'green', 'swir_1', 
                                # Used by AWEI_ns, AWEI_sh
                                'nir', 'swir_2',
                                # Used by AWEI_sh
                                'blue',
                                'AOT', 'SCL'], 
                            time=time_range,
                            **common_load_params)
            data = calculate_indices(data, index='MNDWI', collection='s2')
            data = calculate_indices(data, index='AWEI_ns', collection='s2')
            data = calculate_indices(data, index='AWEI_sh', collection='s2')
            data = data[['MNDWI', 'AWEI_ns', 'AWEI_sh', 'SCL', 'AOT']]
            s2_data.append(data)
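        except Exception as err:
            # Skip time ranges that fail to load (e.g. no data), but report why.
            print(f"Skipping {time_range}: {err}")
            continue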
    s2_data = xr.concat(s2_data, dim='time')
    
    # Rename data variables to distinguish them from those of other datasets when we merge
    s2_data = s2_data.rename({data_var: f"s2_{data_var}" for data_var in s2_data.data_vars})

    ### End Load Sentinel-2 ###

    ### Sentinel-2 Stats ###

    # Calculate mean of bare soil, water, and bare soil to water transitions across time
    s2_bare_soil = s2_data.s2_SCL == 5
    s2_water = s2_data.s2_SCL == 6
    s2_bare_soil_mean = s2_bare_soil.mean('time')
    s2_water_mean = s2_water.mean('time')
    s2_soil_to_water = s2_water.isel(time=slice(1, len(s2_data.time))) & \
                       s2_bare_soil.isel(time=slice(0, len(s2_data.time)-1)).data   
    s2_soil_to_water_mean = s2_soil_to_water.mean('time')
    
    # Calculate summary statistics
    from scipy.stats import mode
    s2_stats = [{
                 f'{data_var}_min':   s2_data[data_var].min('time'), 
                 f'{data_var}_mean':  s2_data[data_var].mean('time'), 
                 f'{data_var}_std':   s2_data[data_var].std('time'), 
                 f'{data_var}_max':   s2_data[data_var].max('time')
                }
                for data_var in s2_data.data_vars if data_var != 's2_SCL'] + \
               [{
                 # The most common classification for each pixel.
#                  's2_SCL_mode': xr.DataArray(mode(s2_data.s2_SCL, axis=s2_data.s2_SCL.get_axis_num('time'))[0].squeeze(), 
#                  coords={'latitude':s2_data.latitude, 'longitude':s2_data.longitude}, 
#                  dims=['latitude', 'longitude']),
                 's2_soil_to_water_mean': s2_soil_to_water_mean, 
                 's2_bare_soil_mean': s2_bare_soil_mean,
                 's2_water_mean': s2_water_mean,
                }]

    s2_stats = xr.Dataset(dict(ChainMap(*s2_stats)))
    
    # Impute remaining NaNs with the means
    for data_var in s2_stats.data_vars:
        s2_stats[data_var] = s2_stats[data_var]\
            .where(~np.isnan(s2_stats[data_var]), s2_stats[data_var].mean())

    ### End Sentinel-2 Stats ###

    ### Load Elevation Data (SRTM) ###

    # Only 1 time, so we remove it with `.isel(time=0)`.
    srtm_data = \
        dc.load(product='srtm', 
                **common_load_params).elevation.isel(time=0)
    # Remove no_data values.
    srtm_data = srtm_data.where(srtm_data!=-32768)
    # Impute remaining NaNs with the means
    srtm_data = srtm_data.where(~np.isnan(srtm_data), srtm_data.mean())
    
    ### End Load Elevation Data (SRTM) ###

    ### Load Precipitation Data from CHIRPS ###
    import itertools
    chirps_months = [f"{month:02d}" for month in range(1, 13)]
    chirps_data = xr.concat([xr.open_rasterio(f'../precipitation/chirps/chirps-v2.0.{year}.{month_str}.tif').squeeze()
                             .sel(y=slice(*lat[::-1]), x=slice(*lon)) 
                             for year, month_str in itertools.product(years, chirps_months)], 
                            dim='time')
    
    # Set missing data points to NaN
    chirps_data = chirps_data.where(chirps_data != -9999).rename({'x': 'longitude', 'y':'latitude'})
    
    # Calculate summary statistics
    chirps_stats = [{
             f'chirps_min':   chirps_data.min('time'), 
             f'chirps_mean':  chirps_data.mean('time'), 
             f'chirps_std':   chirps_data.std('time'), 
             f'chirps_max':   chirps_data.max('time')
            }]
    chirps_stats = xr.Dataset(dict(ChainMap(*chirps_stats)))
    
    # Impute remaining NaNs with the means
    for data_var in chirps_stats.data_vars:
        chirps_stats[data_var] = chirps_stats[data_var]\
            .where(~np.isnan(chirps_stats[data_var]), chirps_stats[data_var].mean())
    
    # Fit the CHIRPS data to the Sentinel-2 grid (nearest neighbor)
    chirps_stats = chirps_stats.reindex(latitude=s2_stats.latitude, longitude=s2_stats.longitude, method='nearest')    
    
    ### End Load Precipitation Data from CHIRPS ###
    
    ## End Load geospatial data ##
    
    ## Combine datasets ##
    
    print("\nCalculating stats...")
    # Use the same left join (on the Sentinel-2 grid) as in training.
    merged_stats = xr.merge((s2_stats, srtm_data, chirps_stats), join='left').persist()
    output_shape = merged_stats.s2_MNDWI_mean.shape
    output_coords = merged_stats.coords
    output_dims = merged_stats.dims

    ## End Combine datasets ##

    ## Format the Feature Matrix ##
    print("\nFormatting data for model...\n")
    X = merged_stats.to_array().transpose('latitude', 'longitude', 'variable')
    X = X.stack(row=('latitude', 'longitude')).transpose('row', 'variable').persist()

    ### Remove unneeded data ###

    # Clear the persisted data in `merged_stats` to save memory.
    del merged_stats

    ### End Remove unneeded data ###
    
    ## End Format the Feature Matrix ##
    
    ## Create and save a map for each classifier ##
    ## defined by the parameter grid ##
    
    param_set_keys = list(param_grid.keys())
    param_sets_vals = list(itertools.product(*param_grid.values()))
    for i, param_set_vals in enumerate(param_sets_vals):
        param_set = {k:v for k, v in zip(param_set_keys, param_set_vals)}
        param_set_suffix = ''.join([f'__{param}_{val}' for param, val 
                                    in param_set.items()])
        
        # Load the model if it exists.
        import glob
        clf_filepaths = sorted(glob.glob(f'../models/hazard_classifier_*{param_set_suffix}.joblib'))
        if not clf_filepaths:
            continue
        # If several models were saved for this parameter set,
        # use the last path in sorted order.
        clf_filepath = clf_filepaths[-1]
        from dask_ml.wrappers import ParallelPostFit
        clf = ParallelPostFit(load(clf_filepath))

        print(f"Creating hazard map for param set {param_set}.")
            
        # Get predictions.
        y_pred = clf.predict(X)
        
        # Get the score from the model filename.
        import re
        score = float(re.compile(".*classifier_(.*?)__.*").search(clf_filepath).group(1))
        
        # Save output (flood hazard map)
        y_pred_vis = xr.DataArray(data=y_pred.reshape(output_shape), 
                                  coords=output_coords, dims=output_dims)
        fig = plt.figure(figsize=(12,7))
        y_pred_vis.plot.imshow(vmin=0, vmax=3)
        plt.title('Flood Hazard Map (S2 MNDWI, AWEI + SRTM DEM + CHIRPS)\n' + \
                  ', '.join([f'{param}: {val}' for param, val in param_set.items()]))
        vis_dir = f'../hazard_maps/{area_name}'
        Path(vis_dir).mkdir(parents=True, exist_ok=True)
        plt.savefig(f'{vis_dir}/{score:0.4f}{param_set_suffix}.png')
        fig.clf()
        plt.close(fig)

        print(f"{(i+1)/len(param_sets_vals):.2%} through parameter sets.")

    ## End Create and save a map for each classifier ##
    ## defined by the parameter grid ##
    print(f"\n{(area_ind+1)/len(areas):.2%} through areas.")
    print()
    client.restart() # Clear Dask memory