# Regression CNN Pipeline for VHR CHM

This notebook illustrates the full development lifecycle of a regression Convolutional Neural Network (CNN) that produces Canopy Height Models (CHMs) from very high-resolution (VHR) data. The imagery comes from WorldView, while the labels are taken from highly preprocessed ICESat-2 ATL08 points generated by Montesano et al. The initial study area is Senegal.

Science Question:
- Can we generate VHR CHM models using ICESat-2 points as training data and CNNs as the regression algorithm?

Possible Research Directions:
- Datacube format (e.g. optical, land cover, DEMs, veg. indices, resolution, photon-level data)
- Algorithm (e.g. random forest, neural network, 2d-CNN, 3d-CNN)
- Training data (e.g. tile size, matching training data, transfer learning for more-than-regional)

Challenges:
- Resolution (e.g. ICESat-2 30m vs WorldView 2m)
- CNNs are mostly used for classification; we are making in-house modifications for regression
- Forest patch vs. sparse forest patch

Let's discuss and have some fun!

## Data Science Development Phases from a Computer Science Perspective

- Data Gathering
- Exploratory Data Analysis
- Preprocessing
- Training
- Inference
- Validation

## 1. Data Gathering

Below we illustrate the available ATL08 points, together with the footprint of the WorldView imagery available for the selected study area (future, gather from Maggie). Some of the local filtering is performed with the following code:

```bash
pdsh -g ilab,forest do_extract_filter_atl08.sh "2018 2019 2020 2021" /att/nobackup/pmontesa/userfs02/data/icesat2/list_atl08.005 senegal

pdsh -g ilab,forest do_extract_filter_atl08.sh "2018 2019 2020 2021" /att/nobackup/pmontesa/userfs02/data/icesat2/list_atl08.005 senegal_no_filt
```

In [None]:
import os
import sys
import omegaconf
from glob import glob
from pathlib import Path

sys.path.append('/adapt/nobackup/people/jacaraba/development/tensorflow-caney')
sys.path.append('/adapt/nobackup/people/jacaraba/development/vhr-cnn-chm')
sys.path.append('/home/pmontesa/code/icesat2')

import rasterio
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from vhr_cnn_chm.model.geoscitools import maplib, atl08lib
from vhr_cnn_chm.model.cnn_regression_pipeline import CNNRegressionPipeline
from shapely.geometry import Point
from rasterio.plot import show

%matplotlib inline

### 1.1 Define General Variables

To match the CLI development script, we use omegaconf to structure the variables required for this work. Ideally, the CLI script would read these from a configuration file. The main purpose of the CLI script is to avoid the timeout issues encountered in JupyterHub and to allow the use of more than one GPU.

In [None]:
# General configurations
omega_conf_string = """
# General CRS to leverage across the data
general_crs: "EPSG:32628"

# ATL08 configurations
atl08_dir: '/adapt/nobackup/people/pmontesa/userfs02/data/icesat2/atl08.005/senegal'
chm_footprints_fn: '/adapt/nobackup/people/pmontesa/chm_work/hrsi_chm_senegal/merge.shp'

# ATL08 available data
atl08_start_year: 2018
atl08_end_year: 2022

# WorldView available data
wv_data_regex:
  - '/adapt/nobackup/people/mwooten3/Senegal_LCLUC/VHR/CAS/M1BS/*.tif'
  - '/adapt/nobackup/people/mwooten3/Senegal_LCLUC/VHR/ETZ/M1BS/*.tif'
  - '/adapt/nobackup/people/mwooten3/Senegal_LCLUC/VHR/SRV/M1BS/*.tif'

# Output directories to store data
intersection_output_dir: '/adapt/nobackup/projects/ilab/projects/Senegal/CNN_CHM/v1/intersection_metadata_evhrtoa'
tiles_output_dir: '/adapt/nobackup/projects/ilab/projects/Senegal/CNN_CHM/v1/intersection_tiles_evhrtoa'

# Data extraction metadata
tile_buffer: 520
"""
conf = omegaconf.OmegaConf.create(omega_conf_string)

In [None]:
# print the configuration as YAML so line breaks render properly
print(omegaconf.OmegaConf.to_yaml(conf))

### 1.2 Build ATL08 geodataframe from extracted CSVs

These CSVs were heavily filtered with land-cover-specific thresholds for h_can (canopy height).

In [None]:
atl08_gdf = []
for year in range(conf.atl08_start_year, conf.atl08_end_year):
    atl08_gdf.append(atl08lib.atl08_io(conf.atl08_dir, str(year), do_pickle=False))
atl08_gdf = pd.concat(atl08_gdf)
atl08_gdf.info()

### 1.3 Make an interactive map to view the heavily filtered set of ATL08 observations

Make sure to set SAMP_FRAC so you don't map all the points. The overlaid layer is the footprints vector of the ~2m HRSI DSM-derived "CHM" data we are working with.

In [None]:
# UNCOMMENT if you want to visualize the points
# %%time
# SAMP_FRAC=0.25
# maplib.MAP_ATL08_FOLIUM(atl08_gdf.sample(frac=SAMP_FRAC), MAP_COL='h_can', DO_NIGHT=False, LAYER_FN=conf.chm_footprints_fn, RADIUS=3)

In [None]:
%%time
SAMP_FRAC=0.25
maplib.MAP_ATL08_FOLIUM(atl08_gdf.sample(frac=SAMP_FRAC), MAP_COL='h_can', DO_NIGHT=False, LAYER_FN=conf.chm_footprints_fn, RADIUS=3)

## 2. Data Preprocessing

In this section we gather the training data and the EVHR + ICESat-2 intersections. We start from the already generated EVHR scenes and then find which of the above-mentioned points intersect them.

Each ATL08 segment is roughly 100 m long and 12 m wide.

Questions here:
- How do we select the tile size?
- What are the benefits?
- What are the downsides or consequences from these selections?
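
A quick way to reason about these questions is to convert ground distances to pixels. The sketch below assumes the 2 m WorldView resolution and the roughly 100 m x 12 m ATL08 segment mentioned in this notebook; the helper name `footprint_in_pixels` is hypothetical and not part of the pipeline.

```python
import math


def footprint_in_pixels(length_m, width_m, resolution_m):
    """Return (rows, cols) needed to cover a ground footprint.

    Hypothetical helper for illustration; not part of vhr_cnn_chm.
    """
    rows = math.ceil(length_m / resolution_m)
    cols = math.ceil(width_m / resolution_m)
    return rows, cols


# An ATL08 segment of about 100 m x 12 m on a 2 m WorldView grid
segment_px = footprint_in_pixels(100, 12, 2.0)
print(segment_px)

# Ground extent of a 128 x 128 pixel tile at 2 m resolution
tile_ground_m = (128 * 2.0, 128 * 2.0)
print(tile_ground_m)
```

A tile larger than the segment adds spatial context but also includes pixels that the label does not describe; a smaller tile does the opposite.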

We have created a pipeline object to ease development regardless of the environment (e.g. CLI, JupyterHub), which we initialize below.

In [None]:
cnn_pipeline = CNNRegressionPipeline(conf)
dir(cnn_pipeline)

### 2.1 Read ATL08 points

Here we read ATL08 points.

In [None]:
%%time
cnn_pipeline.atl08_gdf = cnn_pipeline.get_atl08_gdf(
    conf.atl08_dir,
    conf.atl08_start_year,
    conf.atl08_end_year,
    conf.general_crs
)
print(f'Load ATL08 GDF files, {cnn_pipeline.atl08_gdf.shape[0]} rows.')

### 2.2 Read and Filter WorldView Imagery

We go straight to the data. We tried reading from the footprints database, but the polygons were not good enough to find intersections between the EVHR output and the ATL08 data.

In [None]:
%%time
cnn_pipeline.wv_evhr_gdf = cnn_pipeline.get_wv_evhr_gdf(
    conf.wv_data_regex, crs=conf.general_crs)
print(f'Load WorldView GDF, {cnn_pipeline.wv_evhr_gdf.shape[0]} rows.')

Filter the WorldView data based on the years available from ICESat-2.

In [None]:
%%time
cnn_pipeline.wv_evhr_gdf = cnn_pipeline.filter_gdf_by_list(
    cnn_pipeline.wv_evhr_gdf, 'acq_year', list(range(conf.atl08_start_year, conf.atl08_end_year)))
print(f'Filter WorldView GDF by year, {cnn_pipeline.wv_evhr_gdf.shape[0]} rows.')
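
The year filter is conceptually a membership test on one column. Below is a minimal pandas sketch, assuming `filter_gdf_by_list` behaves like `isin` (its implementation is not shown in this notebook) and using made-up rows:

```python
import pandas as pd

# Toy stand-in for the WorldView GeoDataFrame; all values are made up
scenes = pd.DataFrame({
    'scene_id': ['a', 'b', 'c', 'd'],
    'acq_year': [2016, 2018, 2021, 2022],
})

# Same year range as the notebook: 2018 through 2021 (the range end is exclusive)
years = list(range(2018, 2022))

# Keep only the rows whose acquisition year is in the list
filtered = scenes[scenes['acq_year'].isin(years)].reset_index(drop=True)
print(filtered)
```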

In [None]:
cnn_pipeline.wv_evhr_gdf.plot(color='white', edgecolor='black')

### 2.3 Get the Intersection of both datasets

Get the intersection of the two GeoDataFrames, and output GeoPackage files.
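
Conceptually, this step assigns each ATL08 point to the scenes that contain it. The sketch below is a simplification that tests points against a scene's bounding box with NumPy; the pipeline's `get_point_in_polygon_by_scene` is not shown here and presumably works on the actual footprint geometries. The helper name and all coordinates are made up.

```python
import numpy as np


def points_in_bounds(x, y, bounds):
    """Boolean mask of points inside (left, bottom, right, top) bounds."""
    left, bottom, right, top = bounds
    return (x >= left) & (x <= right) & (y >= bottom) & (y <= top)


# Made-up projected coordinates (metres) for five points
x = np.array([100.0, 250.0, 400.0, 900.0, 500.0])
y = np.array([100.0, 300.0, 450.0, 200.0, 700.0])

# Made-up scene bounding box
scene_bounds = (200.0, 200.0, 600.0, 600.0)

mask = points_in_bounds(x, y, scene_bounds)
print(mask.sum(), 'points fall inside the scene')
```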

In [None]:
# UNCOMMENT if you want to extract all world view files that have ATL08 points
# this step only needs to be done once
# cnn_pipeline.get_point_in_polygon_by_scene()

Let's visualize some of the intersections.

In [None]:
# UNCOMMENT if you want to visualize some of the WorldView ICESAT-2 intersections
# wv_data_dir = '/adapt/nobackup/people/mwooten3/Senegal_LCLUC/VHR/CAS/M1BS'
# intersection_gpkg_filenames = glob(os.path.join(conf.intersection_output_dir, f'*/*.gpkg'))
# print(f'{len(intersection_gpkg_filenames)} intersected WorldView scenes')

### 2.4 WorldView and ICESat Intersection Analysis

Now that we know the locations of the ICESat-2 points, we can visualize the points that intersect the already generated Senegal EVHR scenes.

In [None]:
"""
for intersected_atl08_filename in intersection_gpkg_filenames[:5]:
    
    wv_filename = os.path.join(wv_data_dir, f'{Path(intersected_atl08_filename).stem}.tif')
    
    raster_data = rasterio.open(wv_filename)
    atl08_data = gpd.read_file(intersected_atl08_filename, layer='ATL08_WorldView')

    fig, ax = plt.subplots(figsize=(20,15))
    
    # transform rasterio plot to real world coords
    extent=[raster_data.bounds[0], raster_data.bounds[2], raster_data.bounds[1], raster_data.bounds[3]]
    ax = rasterio.plot.show(raster_data, extent=extent, ax=ax, cmap='pink')
    atl08_data.plot(ax=ax)

    plt.title(f'{wv_filename}, {atl08_data.shape[0]}')
    plt.show()
"""
# UNCOMMENT if you want to visualize some of the WorldView ICESAT-2 intersections

### 2.5 Extract Tiles Matching ICESat-2 extent

In this section we extract training and validation tiles matching the desired extent. Let's look at the script.
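
The extraction script is not reproduced in this notebook. As a rough sketch of the idea, the snippet below cuts a square window centred on a pixel out of a `(bands, rows, cols)` array and skips windows that would cross the image edge. The function name and the skip-at-edge policy are assumptions made for illustration.

```python
import numpy as np


def extract_tile(image, row, col, tile_size):
    """Return a (bands, tile_size, tile_size) window centred near (row, col),
    or None if the window would fall outside the image."""
    half = tile_size // 2
    r0, c0 = row - half, col - half
    r1, c1 = r0 + tile_size, c0 + tile_size
    if r0 < 0 or c0 < 0 or r1 > image.shape[1] or c1 > image.shape[2]:
        return None
    return image[:, r0:r1, c0:c1]


# Synthetic 8-band image standing in for a WorldView scene
image = np.arange(8 * 300 * 300, dtype=np.float32).reshape(8, 300, 300)

tile = extract_tile(image, row=150, col=150, tile_size=128)
edge = extract_tile(image, row=10, col=150, tile_size=128)
print(tile.shape, edge)
```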

## 3. Training the CNN Model

In this section we proceed to train the CNN model.

### 3.1 CNN Model Explained (what needs to change for regression?)

Common 2D CNN models operate on gridded, image-like data. They are generally used for classification problems, but we can modify the architectures for regression. Let's walk through how to define a network, or how to reuse someone else's network.
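
The main change is at the end of the network: a classifier ends in a softmax over classes and is trained with cross-entropy, while a regressor ends in a single linear unit and is trained with an error loss such as MSE or MAE. The NumPy sketch below contrasts the two heads on made-up feature vectors; the weights are random and only illustrate the shapes involved.

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.normal(size=(5, 16))  # 5 samples, 16 features from the conv stack

# Classification head: one logit per class, softmax turns them into probabilities
w_cls = rng.normal(size=(16, 3))
logits = features @ w_cls
logits -= logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Regression head: a single linear unit, the output is an unbounded real value
w_reg = rng.normal(size=(16, 1))
heights = features @ w_reg

# Regression loss: mean squared error against made-up target heights
targets = np.array([[2.0], [5.5], [0.0], [12.3], [7.1]])
mse = np.mean((heights - targets) ** 2)

print(probs.shape, heights.shape, mse)
```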

In [None]:
import os
import glob
import logging
import fiona
import rasterio
import numpy as np
import pandas as pd
import rioxarray as rxr
import geopandas as gpd
import shapely.speedups
from omegaconf.listconfig import ListConfig
from multiprocessing import Pool, Lock, cpu_count

import cuspatial
import tensorflow as tf
from tqdm import tqdm
from pathlib import Path
from shapely.geometry import box
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import mean_squared_error

shapely.speedups.enable()

from vhr_cnn_chm.model.geoscitools.atl08lib import atl08_io
from vhr_cnn_chm.model.cnn_model import get_2d_cnn_tf

Let's take a look at some of these tiles.

In [None]:
data_dir = '/adapt/nobackup/projects/ilab/projects/Senegal/CNN_CHM/tiles_cas/*.tif'
data_filenames = glob.glob(data_dir)
print(f'{len(data_filenames)} extracted data tiles')

In [None]:
for tile in data_filenames[:3]:
    print(tile)

In [None]:
cnn_pipeline = CNNRegressionPipeline(conf)

In [None]:
# TODO:
# - Consider adding the spatial location as a band

# get data and labels
data_array, labels_df = cnn_pipeline.get_data_labels(data_filenames[:100])
print(data_array.shape, labels_df.shape)

### 3.2 Normalization of the Data

CNNs are trained with gradient-based optimization, so they are sensitive to the scale of the inputs. If we do not normalize or standardize the data, large variations in the distribution of the data can make training slow or unstable. There are several techniques we can use for this. The simplest one for initial testing is normalization by a constant. When fine-tuning the model, we proceed to standardize based on the data mean and standard deviation.

Question for the group:
- Do we normalize the height (the labels) as well?

In [None]:
data_array = data_array / 10000.0
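
The two options mentioned above can be compared side by side. This sketch uses a random array in place of `data_array` and assumes a channels-last layout `(samples, rows, cols, bands)`; in practice the statistics would usually be computed on the training split only.

```python
import numpy as np

rng = np.random.default_rng(42)
# Stand-in for data_array: 10 tiles, 16 x 16 pixels, 8 bands, values in [0, 10000)
data = rng.uniform(0, 10000, size=(10, 16, 16, 8)).astype(np.float32)

# Option 1: normalization by a constant, as in the cell above
normalized = data / 10000.0

# Option 2: standardization with per-band mean and standard deviation
band_mean = data.mean(axis=(0, 1, 2), keepdims=True)
band_std = data.std(axis=(0, 1, 2), keepdims=True)
standardized = (data - band_mean) / band_std

print(normalized.min(), normalized.max())
print(standardized.mean(axis=(0, 1, 2)).round(3))
```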

Now we split our dataset: we need a training dataset and a validation/testing dataset. The validation dataset matters during training because the checkpoint, learning-rate and early-stopping callbacks used later all monitor the validation loss.

In [None]:
# split data in training and validation
split = train_test_split(labels_df, data_array, test_size=0.25, random_state=42)
(trainAttrY, testAttrY, trainImagesX, testImagesX) = split
print(trainAttrY.shape, testAttrY.shape, trainImagesX.shape, testImagesX.shape)

In [None]:
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import BatchNormalization, Conv2D
from tensorflow.keras.layers import MaxPooling2D, Activation, Dropout
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split

Let's learn how to define a neural network and add convolutional components to it.

In [None]:
def get_2d_cnn_tf(input_size=(256, 256, 3), filters=(16, 32, 64), regression=False, chanDim=-1):

    # define the model input
    inputs = Input(shape=input_size)

    # loop over the number of filters
    for (i, f) in enumerate(filters):
        
        # if this is the first CONV layer, initialize with input
        if i == 0:
            x = inputs

        # CONV => RELU => BN => POOL
        x = Conv2D(f, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=chanDim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # flatten the volume, then FC => RELU => BN => DROPOUT
    x = Flatten()(x)
    x = Dense(32)(x)
    x = Activation("relu")(x)
    x = BatchNormalization(axis=chanDim)(x)
    x = Dropout(0.5)(x)

    x = Dense(16)(x)
    x = Activation("relu")(x)
    x = BatchNormalization(axis=chanDim)(x)
    x = Dropout(0.5)(x)

    # apply another FC layer that reduces the features to 4 nodes
    # before the optional regression output
    x = Dense(4)(x)
    x = Activation("relu")(x)

    # regression execution
    if regression:
        x = Dense(1, activation="linear")(x)

    # construct the CNN
    model = Model(inputs=inputs, outputs=x, name="SimpleRegression_2dCNN")

    # return the CNN
    return model

In [None]:
model = get_2d_cnn_tf(
    input_size=(128, 128, 8), filters=(16, 32, 64, 128, 256, 512, 1024), regression=True)
model.summary()
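
It helps to check how the tile shrinks through the network. Each block uses a `padding="same"` convolution, which keeps the spatial size, followed by 2 x 2 max pooling, which halves it (rounding down). The plain-Python sketch below traces this for the configuration above; the helper name is hypothetical.

```python
def spatial_sizes(input_size, n_blocks, pool=2):
    """Spatial size after each CONV => POOL block
    (same-padded convolution, pooling stride equal to the pool size)."""
    sizes = [input_size]
    for _ in range(n_blocks):
        sizes.append(sizes[-1] // pool)
    return sizes


filters = (16, 32, 64, 128, 256, 512, 1024)
sizes = spatial_sizes(128, len(filters))
print(sizes)

# Number of values passed to the first Dense layer after Flatten
flattened = sizes[-1] * sizes[-1] * filters[-1]
print(flattened)
```

With seven blocks, a 128 x 128 tile is reduced to 1 x 1, which leaves no room for further pooling at this tile size.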

In [None]:
opt = Adam(learning_rate=1e-3, decay=1e-3 / 200)
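
In the legacy Keras optimizers, the `decay` argument applies inverse-time decay per batch update: `lr = lr0 / (1 + decay * iterations)`. The sketch below evaluates that formula for the values above. Newer Keras releases no longer accept `decay` in `Adam`, so treat this as a description of the legacy behaviour rather than of every TensorFlow version.

```python
def legacy_decayed_lr(initial_lr, decay, iterations):
    """Inverse-time decay as applied by legacy Keras optimizers."""
    return initial_lr / (1.0 + decay * iterations)


initial_lr = 1e-3
decay = 1e-3 / 200

# Learning rate after a given number of batch updates
for step in (0, 1000, 100000, 200000):
    print(step, legacy_decayed_lr(initial_lr, decay, step))
```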

In [None]:
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        save_best_only=True, mode='min', monitor='val_loss',
        filepath='/adapt/nobackup/projects/ilab/projects/Senegal/CNN_CHM/model/test-{epoch:02d}-{val_loss:.2f}.hdf5'),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
]

metrics = [
    tf.keras.metrics.MeanSquaredError(),
    tf.keras.metrics.RootMeanSquaredError(),
    tf.keras.metrics.MeanAbsoluteError(),
    tf.keras.metrics.MeanAbsolutePercentageError(),
    tf.keras.metrics.CosineSimilarity(axis=1)
]

In [None]:
model.compile(
    loss="mean_absolute_percentage_error", optimizer=opt,
    metrics=metrics
)

# train the model
model.fit(
    x=trainImagesX, y=trainAttrY, 
    validation_data=(testImagesX, testAttrY),
    epochs=10,  # short run for the demo; a full run used 6000
    batch_size=32,
    callbacks=callbacks
)

ypred = model.predict(testImagesX)
print(model.evaluate(testImagesX, testAttrY))

print("MSE: %.4f" % mean_squared_error(testAttrY, ypred))
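
The metrics tracked above can be reproduced by hand, which also shows why MAPE can be a fragile loss for canopy height: it divides by the true value, so labels near zero inflate it. The values below are made up, and the small `eps` guard is an assumption added here to avoid division by zero.

```python
import numpy as np

y_true = np.array([0.5, 2.0, 8.0, 15.0])  # made-up canopy heights (m)
y_pred = np.array([1.5, 3.0, 9.0, 16.0])  # every prediction is off by 1 m

err = y_pred - y_true
mse = np.mean(err ** 2)
rmse = np.sqrt(mse)
mae = np.mean(np.abs(err))

eps = 1e-7  # guard against division by zero
ape = 100.0 * np.abs(err) / np.maximum(np.abs(y_true), eps)
mape = ape.mean()

print(mse, rmse, mae)
print(ape.round(1), mape.round(1))
```

The absolute error is identical for all four samples, yet the percentage error of the shortest canopy dominates the mean.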

## Let's visualize some of these feature maps

In [None]:
layer_outputs = [layer.output for layer in model.layers]
layer_outputs

In [None]:
tile = rxr.open_rasterio(data_filenames[0]).values
tile = np.moveaxis(tile, 0, -1)
tile = resize(tile, (128, 128))
tile /= 10000.0
tile = np.expand_dims(tile, 0)
tile.shape

In [None]:
layer = model.layers
filters, biases = model.layers[1].get_weights()
print(layer[1].name, filters.shape)

In [None]:
print(layer)

In [None]:
fig = plt.figure(figsize=(8, 12))
columns = 4
rows = 4
n_filters = columns * rows
for i in range(1, n_filters + 1):
    f = filters[:, :, :, i - 1]
    ax = plt.subplot(rows, columns, i)
    ax.set_xticks([])  # turn off axis ticks
    ax.set_yticks([])
    # show only the weights applied to the first input band (index 0)
    plt.imshow(f[:, :, 0], cmap='gray')
plt.show()

In [None]:
# indices of the first three Conv2D layers, used to define a shorter model
conv_layer_index = [1, 5, 9]
outputs = [model.layers[i].output for i in conv_layer_index]
model_short = Model(inputs=model.inputs, outputs=outputs)
model_short.summary()

In [None]:
feature_output = model_short.predict(tile)
feature_output

In [None]:
layer_names = [layer.name for layer in model.layers]
layer_outputs = [layer.output for layer in model.layers]
layer_outputs

In [None]:
feature_map_model = tf.keras.models.Model(inputs=model.inputs, outputs=layer_outputs)

In [None]:
feature_maps = feature_map_model.predict(tile)

In [None]:
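for layer_name, feature_map in zip(layer_names, feature_maps):
    # only 4D outputs (batch, rows, cols, channels) can be shown as images
    if len(feature_map.shape) != 4:
        continue
    k = feature_map.shape[-1]     # number of feature maps in this layer
    size = feature_map.shape[1]   # spatial size of each feature map
    # one horizontal strip holding all feature maps of the layer
    image_belt = np.zeros((size, size * k), dtype='uint8')
    for i in range(k):
        feature_image = feature_map[0, :, :, i].copy()
        # standardize and rescale to the 0-255 range for display
        feature_image -= feature_image.mean()
        std = feature_image.std()
        if std > 0:
            feature_image /= std
        feature_image *= 64
        feature_image += 128
        feature_image = np.clip(feature_image, 0, 255).astype('uint8')
        image_belt[:, i * size:(i + 1) * size] = feature_image

    scale = 20. / k
    plt.figure(figsize=(scale * k, scale))
    plt.title(layer_name)
    plt.grid(False)
    plt.imshow(image_belt, aspect='auto')
    plt.show()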