<a href="https://colab.research.google.com/github/MScEcologyAndDataScienceUCL/BIOS0032_AI4Environment/blob/main/06_DL_for_Remote_Sensing/06_DL_for_Remote_Sensing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Week 6: Deep Learning for Remote Sensing

This week, we will learn about remote sensing! We will explore the information content we can obtain
from optical satellite imagery, and how to train a deep learning model to make pixel-wise
predictions (semantic segmentation) using such data.

More specifically, we will train one of the most popular semantic segmentation models, U-net, to
predict forest cover in the Brazilian rainforest using Sentinel-2 imagery. The output will not
only be a spatial forest map, but also a prediction of **change** over time due to deforestation.


## Contents

1. [Setup](#1-setup)
2. [Optical Remote Sensing Data](#2-optical-remote-sensing-data)
3. [Forest Mapping](#3-forest-mapping)
4. [Change Detection & Deforestation Monitoring](#4-change-detection--deforestation-monitoring)
5. [Summary and Outlook](#5-summary-and-outlook)


## Notes

- If a line starts with the paintbrush symbol (🖌️), it asks you to implement a code part or
answer a question.
- Lines starting with the light bulb symbol (💡) provide important information or tips and tricks.

---

## 1. Setup

### 1.1 Enable GPU Runtime

Go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator.

### 1.2 Install and Import Dependencies

In [None]:
%pip install rasterio

In [None]:
import os
import matplotlib
try:
  matplotlib.use('widget')
except:
  %pip install -q ipympl
  os.kill(os.getpid(), 9)

**NOTE:**

If you run the above code block for the first time, you will get a message that Google Colab has
crashed. This is deliberate, because we need to restart the environment after installation of one of
the packages.

Just ignore this message and continue with the next code block.

In [None]:
import os
import matplotlib
matplotlib.use('widget')

import glob
import numpy as np
import torch
import rasterio
import ipywidgets as widgets
from google.colab import output
output.enable_custom_widget_manager()
import matplotlib.pyplot as plt

### 1.3 Mount Google Drive

In [None]:
from google.colab import drive

drive.mount("/content/drive")

Add a shortcut in your Google Drive to this [shared folder](https://drive.google.com/drive/folders/1k2PyRm9AhYYyS_3vGJEqw3a9hAsn737X?usp=sharing).

This will allow you to access the data we will use in this practical.

In [None]:
%%capture

!unzip /content/drive/MyDrive/BIOS0032/2025/data/lab6_data.zip -d /content/week6_data

---

## 2. Optical Remote Sensing Data

As you have seen in the lecture, we have different types of remote sensing data at our disposal:
* Optical
* Synthetic Aperture Radar (SAR)
* Airborne Laser Scanning (ALS), a form of Light Detection and Ranging (LiDAR)

These sensor types use different portions of the electromagnetic spectrum. Optical sensors cover
the range from the thermal infrared (longest wavelength) via short-wave and near-infrared to the
visible range (red, green, blue), and rarely the ultraviolet (shortest wavelength). SAR operates at
much longer, microwave wavelengths, whereas LiDAR typically uses near-infrared laser pulses.
Furthermore, SAR and LiDAR (as well as sonar, which is used for underwater sensing) are **active**
sensors, that is, they emit their own signal and measure properties of what bounces back from the
Earth. Optical sensors, in turn, are passive: they measure sunlight reflected by the Earth's surface
(and, in the thermal infrared, radiation emitted by it).

As you can imagine, optical sensors can be categorised by their _spatial_ resolution (just like your
camera: how many megapixels, _etc._). However, they can also differ in their **spectral**
resolution, which includes the number of spectral _bands_ and the bandwidth:
* RGB sensors just measure one band each in the red, green, and blue wavelengths.
* **Multispectral** sensors include at least one band beyond the visible range.
* **Hyperspectral** sensors have many narrow, evenly spaced bands across large ranges of the
  electromagnetic spectrum.


### 2.1 Sentinel-2

Below, we will take a look at data captured by the
[Sentinel-2](https://www.esa.int/Applications/Observing_the_Earth/Copernicus/Sentinel-2) satellite
missions. Sentinel-2 consists of two satellites, recording multispectral data in 13 bands at 10 m,
20 m, or 60 m resolution depending on the band. The Level-2A products we use below contain 12 of
these bands (the cirrus band B10 is omitted).

Sentinel-2 and related missions, such as the SAR-based Sentinel-1, are part of the Copernicus
programme, and their data are free of charge to use.

💡 You can download Sentinel-1 and -2 imagery free of charge from the [Copernicus
browser](https://browser.dataspace.copernicus.eu/) (registration required).

Let us now first take a look at a satellite dataset downloaded from Copernicus over the Brazilian
Rainforest:

1. Open the "Files" tab in Google Colab (click the folder icon to the left).
2. Navigate to folder
   `week6_data/deforestation/Sentinel-2_stripes/2017-06-28-00:00_2017-06-28-23:59_Sentinel-2_L2A`

As you can see, this folder contains twelve `*.tiff` files. TIFF (Tagged Image File Format) is an
image format that allows storing pixel data without any loss due to compression, as well as
information like geospatial position (**GeoTIFF**). Unlike JPEG, it can also store more than three
channels per image, which is important for multi- and hyperspectral remote sensing data.

💡 You will notice the suffix "_L2A". It designates the processing **level** of the satellite
product. Several levels exist, each denoting the processing steps the data have undergone (error
correction, radiometric correction, geocoding, _etc._). Unless you are interested in very specific
parts of the processing pipeline, or in atmospheric data (_e.g._, on clouds), you usually want to use
Level-2A (L2A) products, which provide atmospherically corrected surface (bottom-of-atmosphere)
reflectance. More information on Sentinel-2 processing levels can be found
[here](https://sentiwiki.copernicus.eu/web/s2-processing).

Let's take a look at one of those files:

In [None]:
BASE_FOLDER = '/content/week6_data/deforestation'

In [None]:
file_path = os.path.join(BASE_FOLDER, 'Sentinel-2_stripes', '2017-06-28-00:00_2017-06-28-23:59_Sentinel-2_L2A/2017-06-28-00:00_2017-06-28-23:59_Sentinel-2_L2A_B01_(Raw).tiff')


with rasterio.open(file_path, 'r') as f_band:
    band = f_band.read()

print(f'Data type: {type(band)}')
print(f'Data shape: {band.shape}')
print(f'Data number type: {band.dtype}')
print(f'Data min/max: {band.min()}/{band.max()}')

You will have noticed that we are no longer using standard image libraries like PIL. This is
because such libraries do not handle GeoTIFFs properly: they cannot reliably read images with more
than a few bands, and they ignore the geospatial metadata. Hence, we are using
[Rasterio](https://rasterio.readthedocs.io/).

💡 In the background, Rasterio uses the [Geospatial Data Abstraction Library
(GDAL)](https://gdal.org/), the largest and most versatile open-source framework for geospatial
data.

💡 [QGIS](https://qgis.org/) is the best-known open-source Geographic Information System (GIS) and
uses GDAL for anything related to spatial processing.

💡 The closest equivalent to Rasterio in R is [terra](https://rspatial.github.io/terra/),
which also builds on GDAL.


Rasterio returns a NumPy array of size `BxHxW` (bands x height x width). We can visualise it:

In [None]:
%matplotlib widget

plt.figure(figsize=(10,8))

@widgets.interact(brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0))
def vis_band(brightness=1.0):
    plt.gca().clear()
    plt.imshow(brightness * band.squeeze().astype(float) / 10000.0,
               cmap='gray',
               vmin=0,
               vmax=1)
    plt.show()

You will notice a couple of things from the above figure:
* The image is in greyscale. As you might have guessed already, this TIFF file just contains a
  single band.
* It is very dark unless you turn up the brightness. Standard photos, like those from your
  smartphone, are typically stored as 8-bit unsigned integers, which can hold values from 0 to 255.
  That would be very limiting for Earth observation, where we want to resolve many more intensity
  levels than just 256. The above Sentinel-2 band is encoded as 16-bit unsigned integers, which can
  store 65,536 different values. Dividing by 10,000 converts these digital numbers to (approximate)
  reflectance, and since many surfaces reflect only a small fraction of the incoming light in this
  band, most values end up close to zero on the `[0, 1]` display scale.
* The image contains some strange, diagonal stripes. These are artefacts from the data processing.
  Their diagonal orientation comes from the fact that the satellite orbits the Earth while the Earth
  rotates around its own axis underneath it.
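To make the bit-depth argument concrete, the sketch below prints the value ranges of 8-bit and 16-bit unsigned integers and then applies a percentile contrast stretch, a common alternative to a manual brightness slider. The band values are synthetic (hypothetical), the 2nd/98th percentiles are merely a common choice, and `percentile_stretch` is a helper written for this example rather than a library function.

```python
import numpy as np

# value ranges of 8-bit vs. 16-bit unsigned integers
print(np.iinfo(np.uint8).max)    # 255   -> 256 possible values
print(np.iinfo(np.uint16).max)   # 65535 -> 65,536 possible values


def percentile_stretch(band, low=2.0, high=98.0):
    """Linearly map the `low`/`high` percentiles of a band to 0/1 and clip the rest."""
    band = band.astype(float)
    lo, hi = np.nanpercentile(band, [low, high])
    stretched = (band - lo) / max(hi - lo, 1e-12)   # guard against division by zero
    return np.clip(stretched, 0.0, 1.0)


# synthetic 16-bit band with low digital numbers (hypothetical values)
rng = np.random.default_rng(0)
digital_numbers = rng.integers(0, 1500, size=(100, 100), dtype=np.uint16)

reflectance = digital_numbers / 10000.0    # same scaling as in the cells above
display = percentile_stretch(reflectance)

print(f'Max. scaled value: {reflectance.max():.4f}')                   # just under 0.15: very dark on [0, 1]
print(f'Stretched range: {display.min():.1f} to {display.max():.1f}')  # 0.0 to 1.0
```

Unlike a fixed brightness factor, the stretch adapts to each band's own value distribution, at the cost of making absolute brightness no longer comparable between images.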

It may be difficult to see what we are looking at in the image above. However, we have almost all
the other bands recorded by the satellite, so perhaps we can visualise a colour composite instead?

🖌️ Go to [this Web page](https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/bands/)
and take a look at all the bands Sentinel-2 measures. Let's write them down in a list below.

In [None]:
# Sentinel-2 band wavelengths and names in order; see here:
# https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/bands/
S2_LAMBDAS = [  # wavelengths [nm] (wavelength is often denoted as lambda)
    443,
    490,
    560,
    665,
    705,
    740,
    783,
    842,
    865,
    945,
    # 1375,
    1610,
    2190
]

S2_BANDS = [
    '443 nm (aerosol)',
    '490 nm (Blue)',
    '560 nm (Green)',
    '665 nm (Red)',
    '705 nm (Red Edge)',
    '740 nm',
    '783 nm',
    '842 nm (NIR)',
    '865 nm',
    '945 nm',
    # '1375 nm',            # band 10 is for cirrus cloud detection; we don't have it for L2A
    '1610 nm (SWIR 1)',
    '2190 nm (SWIR 2)'
]

Now, we can load all the bands in the correct order for a given satellite scene:

In [None]:
def load_composite(folder):
    # list all files in the scene folder
    files = os.listdir(folder)

    # band identifiers in the same order as S2_LAMBDAS and S2_BANDS
    bands = ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12')

    # create layer stack
    layer_stack = []
    for band_id in bands:
        # find the file belonging to this band
        file_name = [file for file in files if band_id in file][0]
        with rasterio.open(os.path.join(folder, file_name), 'r') as f_layer:
            # read the (1xHxW) band and scale digital numbers to reflectance
            layer = f_layer.read().astype(float) / 10000.0
            layer_stack.append(layer)
    # concatenate along the band axis -> BxHxW
    return np.concatenate(layer_stack, 0)


# load all bands in order and stack them together into a layer stack
FOLDER_2017 = os.path.join(BASE_FOLDER, 'Sentinel-2_stripes', '2017-06-28-00:00_2017-06-28-23:59_Sentinel-2_L2A/')

data_2017 = load_composite(FOLDER_2017)

print(f'Layer stack shape: {data_2017.shape}')

We now have a **layer stack** of twelve bands in order. From the above list of band names, you
should be able to see which bands we need to obtain a true colour image.


Let's visualise it!

In [None]:
%matplotlib widget

plt.close('all')

points, spectra_2017 = [], []                   # for click events

def redraw_plot():
    global points, spectra_2017
    ax1 = plt.subplot(1,2,2)
    ax1.clear()
    for spectrum in spectra_2017:
        ax1.plot(S2_LAMBDAS, spectrum, '-')
    ax1.set_ylabel('Normalised value')
    ax1.set_xticks(S2_LAMBDAS, S2_BANDS, rotation=90)
    plt.margins(0.05)
    plt.tight_layout(pad=2)


def click(event):
    global points, spectra_2017
    if event.xdata is not None and event.ydata is not None:
        # clip to the last valid index so that clicks at the image border stay in bounds
        x = int(np.clip(event.xdata, 0, data_2017.shape[2] - 1))
        y = int(np.clip(event.ydata, 0, data_2017.shape[1] - 1))
        spectral_vals_2017 = data_2017[:,y,x]
        spectra_2017.append(spectral_vals_2017)
        points.append([x,y])
        plt.subplot(1,2,1)
        plt.scatter(x, y)
        redraw_plot()
figure = plt.figure(figsize=(16, 6))
figure.canvas.mpl_connect('button_press_event', click)


@widgets.interact(
        red=S2_BANDS,
        green=S2_BANDS,
        blue=S2_BANDS,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0)
)
def vis_composite(red=S2_BANDS[3],
                  green=S2_BANDS[2],
                  blue=S2_BANDS[1],
                  brightness=1.0):
    global tile_file_name, tile, ts, points, spectra_2017
    band_indices = [S2_BANDS.index(band) for band in [red, green, blue]]
    plt.subplot(1,2,1)
    arr = brightness * np.transpose(data_2017[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))
    plt.axis('off')
    plt.title('2017-06-28')

Play around with the controls of the above widget as follows:
1. Drag the brightness slider up or down.
2. Select different band combinations for visualisation. Tip: besides the true-colour combination
   (red: Red, green: Green, blue: Blue), also try the false-colour near-infrared combination
   (red: NIR, green: Red, blue: Green).
3. Click on the image at different locations. A second plot should appear to the right, showing the
   reflectance values of each clicked point across all wavelengths.
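Under the hood, a colour composite is just a choice of three bands from the layer stack, reordered from the `BxHxW` layout Rasterio returns to the `HxWx3` layout `plt.imshow` expects. The sketch below builds a true-colour and a false-colour near-infrared composite from a small synthetic stack; the band names follow `S2_BANDS` above, while `make_composite` and the reflectance values are hypothetical.

```python
import numpy as np

# band names in layer-stack order (same as S2_BANDS above)
band_names = ['443 nm (aerosol)', '490 nm (Blue)', '560 nm (Green)', '665 nm (Red)',
              '705 nm (Red Edge)', '740 nm', '783 nm', '842 nm (NIR)', '865 nm',
              '945 nm', '1610 nm (SWIR 1)', '2190 nm (SWIR 2)']


def make_composite(stack, names, red, green, blue, brightness=1.0):
    """Select three named bands from a BxHxW stack and return an HxWx3 image in [0, 1]."""
    indices = [names.index(band) for band in (red, green, blue)]
    rgb = np.transpose(stack[indices, ...], (1, 2, 0))   # 3xHxW -> HxWx3
    return np.clip(brightness * rgb, 0.0, 1.0)


# synthetic 12-band stack of 4x5 pixels (hypothetical reflectance values)
rng = np.random.default_rng(42)
stack = rng.uniform(0.0, 0.5, size=(12, 4, 5))

true_colour = make_composite(stack, band_names, '665 nm (Red)', '560 nm (Green)', '490 nm (Blue)')
false_colour = make_composite(stack, band_names, '842 nm (NIR)', '665 nm (Red)', '560 nm (Green)')

print(true_colour.shape)                             # (4, 5, 3)
print(np.allclose(false_colour[..., 0], stack[7]))   # True: NIR is shown in the red channel
```

In the false-colour composite, vegetation, which reflects strongly in the NIR, therefore shows up in bright red tones.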


🖌️ What does this scene depict?

_Answer:_

...

🖌️ What are typical band reflectance characteristics of the different **land cover** classes you
can see?

_Answer:_

...

### 2.2 Spectral Indices

From the plots above, you may have noticed that some land cover categories show very characteristic
patterns in how they reflect radiation at each wavelength. Moreover, it can be tricky to
disentangle land cover types based on absolute reflectance values; it is often much more
straightforward to look at the _relative relation_ between spectral bands.

This is the principle of **spectral indices**, that is, ratios or other arithmetic combinations of
spectral bands.
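One reason why relative measures help is illumination: shadows or terrain slopes scale all bands of a pixel by roughly the same factor, and such a factor cancels out in a ratio. The toy sketch below assumes a purely multiplicative illumination effect (a simplification) and uses made-up reflectance values for two generic bands `a` and `b`.

```python
import numpy as np

# hypothetical reflectance of the same two surfaces in two bands, fully lit
band_a = np.array([0.40, 0.30])
band_b = np.array([0.05, 0.10])

# the same surfaces in partial shadow: all bands scaled by the same factor (assumption)
shade = 0.5
band_a_shaded, band_b_shaded = shade * band_a, shade * band_b


def ratio(a, b):
    return a / b


def normalised_difference(a, b):
    return (a - b) / (a + b)


print(band_a_shaded)                                     # absolute values are halved
print(np.allclose(ratio(band_a, band_b),
                  ratio(band_a_shaded, band_b_shaded)))  # True
print(np.allclose(normalised_difference(band_a, band_b),
                  normalised_difference(band_a_shaded, band_b_shaded)))  # True
```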

Perhaps the most famous spectral index ever proposed is the **Normalised Difference Vegetation
Index (NDVI)**.

🖌️ Look up the definition of NDVI and complete the code below to calculate it for our Sentinel-2
scene.

In [None]:
def ndvi(data_array):
    return ...          # implement NDVI computation here


# let's test it
ndvi_2017 = ndvi(data_2017)

# check whether the output shape is correct
assert ndvi_2017.shape == data_2017.shape[1:], \
    f'Error: output shape should be {data_2017.shape[1:]}, got {ndvi_2017.shape}.'

# check whether the values are within the right range
assert np.nanmin(ndvi_2017) >= -1 and np.nanmax(ndvi_2017) <= 1, \
        'Error: NDVI output values should be within [-1, 1].'

In [None]:
%matplotlib widget


plt.figure(figsize=(15,7))


@widgets.interact(
        red=S2_BANDS,
        green=S2_BANDS,
        blue=S2_BANDS,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0)
)
def vis_composite_ndvi(red=S2_BANDS[3],
                       green=S2_BANDS[2],
                       blue=S2_BANDS[1],
                       brightness=1.0):
    band_indices = [S2_BANDS.index(band) for band in [red, green, blue]]
    arr = brightness * np.transpose(data_2017[band_indices,...], (1,2,0))
    plt.subplot(1,2,1)
    plt.imshow(np.clip(arr, 0, 1))
    plt.title('Sentinel-2 Image')
    plt.subplot(1,2,2)
    plt.imshow(ndvi_2017, cmap='vanimo', vmin=-1, vmax=1)
    plt.colorbar()
    plt.title('NDVI')

As you can see, NDVI shows very high values for vegetation, low values for bare soil and built-up
areas, and values around zero or below for water. This allows us not only to perform a simple
classification of the pixels into vegetation/rest, but also to estimate plant productivity, trace
phenological cycles, _etc._ NDVI is used for many downstream analyses (one might argue a few too
many).

**Bonus: extra indices**

As said, there are many more indices to compute. For example, [this
list](https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/indexdb/) gives a good
overview of indices available for Sentinel-2.

🖌️ Pick three indices, look them up, implement them below and observe the result. Note: keep the
function names as they are (`index_1`, `index_2`, `index_3`). Each function should return two
variables: the calculated index and a string denoting the name of the chosen index. You may also
need to modify the `vmin` and `vmax` parameters in the code cell afterwards if an index returns
values outside the default `[-1, 1]` range.

In [None]:
def index_1(data_array):
    val = ...
    name = 'Index 1'
    return val, name


def index_2(data_array):
    val = ...
    name = 'Index 2'
    return val, name


def index_3(data_array):
    val = ...
    name = 'Index 3'
    return val, name

In [None]:
%matplotlib widget


plt.figure(figsize=(15,7))


@widgets.interact(
        red=S2_BANDS,
        green=S2_BANDS,
        blue=S2_BANDS,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0)
)
def vis_composite_indices(red=S2_BANDS[3],
                       green=S2_BANDS[2],
                       blue=S2_BANDS[1],
                       brightness=1.0):
    band_indices = [S2_BANDS.index(band) for band in [red, green, blue]]
    arr = brightness * np.transpose(data_2017[band_indices,...], (1,2,0))
    plt.subplot(2,2,1)
    plt.imshow(np.clip(arr, 0, 1))
    plt.title('Sentinel-2 Image')
    plt.subplot(2,2,2)
    index_1_vals, index_1_name = index_1(data_2017)
    plt.imshow(index_1_vals, cmap='vanimo', vmin=-1, vmax=1)
    plt.colorbar()
    plt.title(index_1_name)
    plt.subplot(2,2,3)
    index_2_vals, index_2_name = index_2(data_2017)
    plt.imshow(index_2_vals, cmap='vanimo', vmin=-1, vmax=1)
    plt.colorbar()
    plt.title(index_2_name)
    plt.subplot(2,2,4)
    index_3_vals, index_3_name = index_3(data_2017)
    plt.imshow(index_3_vals, cmap='vanimo', vmin=-1, vmax=1)
    plt.colorbar()
    plt.title(index_3_name)

From this list and your choices, you may be able to make the following observations:
* Some indices can be really useful for one or another type of land cover class (especially water
  and vegetation).
* Some are downright useless for our scene (but remember that they might work in other settings,
  such as with more bare soil).
* Not all indices follow the normalised-difference form "(band a - band b) / (band a + band b)";
  they can become quite elaborate.
* Some indices have been developed with specific sensors in mind. For example, the Enhanced
  Vegetation Index (EVI) is a very popular choice and should perform much better than NDVI over
  rainforest (where the latter saturates, as you have seen above). However, EVI was developed
  for the [MODIS](https://modis.gsfc.nasa.gov/about/) mission and contains correction factors that
  do not automatically transfer to other sensors. Thus, if you calculate EVI for Sentinel-2
  using the default formula, it likely won't give you good results.
* Finally, some indices include adjustment factors that we often do not know a priori. For example,
  the Soil Adjusted Vegetation Index (SAVI) contains a soil brightness correction factor _L_
  (often set to 0.5) whose best value depends on vegetation density, and related indices such as
  the Transformed SAVI (TSAVI) require the slope and intercept of the so-called "soil line". We
  could estimate the latter from samples of bare-soil pixels. Other quantities are not directly
  defined by spectral bands at all but may be correlated with them. An example is the Leaf Area
  Index (LAI), the one-sided green leaf area per unit ground area (m²/m²). LAI is 0 for bare ground
  and, because leaves overlap in several layers, can reach values of several m²/m² in dense forest
  canopies. We do not know the LAI for our scene above, since we cannot identify individual leaves;
  the best we could do is to regress it empirically (_e.g._, against field measurements).
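The "soil line" is commonly estimated as a linear regression of near-infrared on red reflectance over bare-soil pixels. The sketch below fits such a line with `np.polyfit`; the soil samples are synthetic, generated from an assumed slope of 1.2 and intercept of 0.03 purely for illustration. In practice, you would sample pixels you have identified as bare soil in your own scene.

```python
import numpy as np

# synthetic bare-soil samples (hypothetical): NIR roughly linear in red reflectance
rng = np.random.default_rng(1)
red_soil = rng.uniform(0.05, 0.30, size=500)
nir_soil = 1.2 * red_soil + 0.03 + rng.normal(0.0, 0.005, size=500)

# least-squares fit of NIR = slope * red + intercept (polyfit returns the highest power first)
slope, intercept = np.polyfit(red_soil, nir_soil, deg=1)
print(f'Soil line: NIR = {slope:.2f} * red + {intercept:.3f}')
```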


Ultimately, spectral indices can be very powerful (and we will see this below). They can certainly
be worth using, but it's always important to be aware of their limitations, too.

---

## 3. Amazon Deforestation Dataset

The above scene showed a rather ominous mixture of natural forest and man-made structures, chief
among them clear-cuts and plots of land. Indeed, we are looking at a scene of the Amazon rainforest
that has been subject to deforestation.

An important task in remote sensing is to map land cover, such as forests. We have seen above that
we can do so pretty well visually. We cannot use NDVI alone, since it does not allow us to separate
forest from, for example, pasture. Instead, we can use a subset of bands together with a powerful
machine learning model.

If you think back to Session 4, you will remember that this task is known as **semantic
segmentation**, _i.e._, pixel-wise classification. As you can imagine, this would require
downloading lots of Sentinel-2 data and labelling every pixel as forest or non-forest, which is very
tedious. Luckily, such datasets have been curated and are readily available.

In the following, we will be using the [Amazon and Atlantic
Forest](https://www.kaggle.com/datasets/catiowiec/amazon-and-atlantic-forest-sentinel-2-multiband/data)
dataset.

If you look into the `AMAZON` folder in `week6_data/deforestation`, you will find that the images
are already divided into training, validation, and test sets. Also, each of these folders contains
subfolders for images and masks (ground truth annotations).

If you look further into the `images` folders, you will find a lot of TIFF files. This time, it's
not one for each band, but one for each scene, containing multiple bands in one file.

Moreover, the Sentinel-2 images in there contain only four bands instead of the original twelve:
blue, green, red, near-infrared (NIR).

In [None]:
S2_BANDS_SUBSET = [
    '490 nm (Blue)',
    '560 nm (Green)',
    '665 nm (Red)',
    '842 nm (NIR)'
]

We can now define a function to load such an image tile as follows:

In [None]:
def load_tile(file_path):
    with rasterio.open(file_path, 'r') as f_file:
        data = f_file.read()
    data = data / 10000.0       # we can use the same normalisation here because it's the same format (that isn't always the case; always double-check)
    return data

Next, we can visualise them just as we did with the big file above.

In [None]:
%matplotlib widget

# find all ".tif" files in the training subfolder
tile_file_names = glob.glob(os.path.join(BASE_FOLDER, 'AMAZON', 'Training', 'images', '*.tif'))

# sort them alphabetically
tile_file_names.sort()

# keep track of the tile we are currently visualising
tile_file_name = None
tile = None

points, spectra = [], []                # for click events

def redraw_plot():
    global points, spectra
    plt.subplot(1,2,2)
    plt.gca().clear()
    for spectrum in spectra:
        plt.plot(range(len(spectrum)), spectrum, '-')
    plt.ylim([0, 1])
    plt.ylabel('Normalised value')
    plt.xticks(range(len(S2_BANDS_SUBSET)), S2_BANDS_SUBSET, rotation=90)
    plt.margins(0.05)
    plt.tight_layout(pad=2)

def hover(event):
    global tile
    redraw_plot()
    if tile is not None and event.xdata is not None and event.ydata is not None:
        x = int(np.clip(event.xdata, 0, tile.shape[2] - 1))
        y = int(np.clip(event.ydata, 0, tile.shape[1] - 1))
        spectral_vals = tile[:,y,x]
        plt.plot(range(len(spectral_vals)), spectral_vals, 'k-')
        plt.ylim([0, 1])
        plt.ylabel('Normalised value')
        plt.xticks(range(len(spectral_vals)), S2_BANDS_SUBSET, rotation=90)

def click(event):
    global tile, points, spectra
    if tile is not None and event.xdata is not None and event.ydata is not None:
        # clip to the last valid index so that clicks at the image border stay in bounds
        x = int(np.clip(event.xdata, 0, tile.shape[2] - 1))
        y = int(np.clip(event.ydata, 0, tile.shape[1] - 1))
        spectral_vals = tile[:,y,x]
        spectra.append(spectral_vals)
        points.append([x,y])
        plt.subplot(1,2,1)
        plt.scatter(x, y)
        redraw_plot()

figure = plt.figure(figsize=(14,7))
# figure.canvas.mpl_connect('motion_notify_event', hover)
figure.canvas.mpl_connect('button_press_event', click)


@widgets.interact(
        file_name=tile_file_names,
        red=S2_BANDS_SUBSET,
        green=S2_BANDS_SUBSET,
        blue=S2_BANDS_SUBSET,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0)
)
def vis_tiles_all_median(file_name=tile_file_names[0],
                         red=S2_BANDS_SUBSET[2],
                         green=S2_BANDS_SUBSET[1],
                         blue=S2_BANDS_SUBSET[0],
                         brightness=1.0):
    global tile_file_name, tile, ts, points, spectra
    plt.subplot(1,2,1)
    if file_name != tile_file_name:
        plt.gca().clear()
        points, spectra = [], []
        redraw_plot()
        plt.subplot(1,2,1)
    plt.subplot(1,2,1)
    tile = load_tile(file_name)
    band_indices = [S2_BANDS_SUBSET.index(band) for band in [red, green, blue]]
    arr = brightness * np.transpose(tile[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))
    if file_name != tile_file_name:
        for point in points:
            plt.scatter(point[0], point[1])
    tile_file_name = file_name

Just like above, you can adjust the brightness, band combination, and click into the tile to plot
spectra.

🖌️ Some of these files contain clouds; take a look at them. What can you say about the spectral
response of clouds?

_Answer:_
...

We further have ground truth annotations in the `masks` folder, containing information for each
pixel on whether it is forest or not. Let us first create a list of these label classes and a colour
map for visualisation:

In [None]:
LABEL_CLASSES = (
    "non-Forest",
    "Forest"
)

LABELCLASS_COLOURS = (
    (0, 0, 0),
    (0.5, 1, 0.5)
)

# create colour map for Matplotlib
cmap = matplotlib.colors.ListedColormap(LABELCLASS_COLOURS)


Then, we can visualise the ground truth side-by-side with the image tiles:

In [None]:
%matplotlib widget

plt.figure(figsize=(15,7))


@widgets.interact(
        tile_name=tile_file_names,
        red=S2_BANDS_SUBSET,
        green=S2_BANDS_SUBSET,
        blue=S2_BANDS_SUBSET,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0)
)
def vis_tiles_with_annotations(tile_name=tile_file_names[0],
                               red=S2_BANDS_SUBSET[2],
                               green=S2_BANDS_SUBSET[1],
                               blue=S2_BANDS_SUBSET[0],
                               brightness=1.0):
    plt.subplot(1,2,1)
    tile = load_tile(tile_name)
    band_indices = [S2_BANDS_SUBSET.index(band) for band in [red, green, blue]]
    arr = brightness * np.transpose(tile[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))

    plt.subplot(1,2,2)
    tile_name_anno = tile_name.replace('images', 'masks')
    with rasterio.open(tile_name_anno, 'r') as f_anno:
        tile_anno = f_anno.read()

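    # fix the value range so that tiles containing only one class are still coloured correctly
    plt.imshow(tile_anno.squeeze(), cmap=cmap, vmin=-0.5, vmax=len(LABEL_CLASSES) - 0.5)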
    cbar = plt.colorbar(ticks=np.arange(len(LABEL_CLASSES)))
    cbar.ax.set_yticklabels(LABEL_CLASSES)

You might see where this is leading: we have many image-ground truth tuples, split into three sets.
Let's train a model to predict forest masks!
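Before training, it can be worth checking how balanced the two classes are, since strong imbalance affects both the loss and how we should interpret overall accuracy. A minimal sketch, assuming masks are integer arrays with 0 = non-forest and 1 = forest as in `LABEL_CLASSES`; the mask here is synthetic, and with real data you would read each file in the `masks` folder with `rasterio` as above.

```python
import numpy as np

class_names = ('non-Forest', 'Forest')

# synthetic 1xHxW mask (hypothetical): left quarter non-forest, rest forest
mask = np.ones((1, 8, 8), dtype=np.uint8)
mask[:, :, :2] = 0

# count pixels per class; minlength guarantees a count for every class, even absent ones
counts = np.bincount(mask.ravel(), minlength=len(class_names))
fractions = counts / counts.sum()

for name, frac in zip(class_names, fractions):
    print(f'{name}: {frac:.1%}')    # non-Forest: 25.0%, Forest: 75.0%
```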

---

## 4. Semantic Segmentation with Deep Learning

### 4.1 U-net

The next major ingredient we need for predicting forest masks is a prediction model.

In classical remote sensing, this often consisted of a simple model, such as a random forest, that
would take pixel-wise spectral values as inputs and predict outputs for each pixel individually.
This usually works if our spatial resolution is "low enough" (_e.g._, Landsat: 30 m). However, as
spatial resolution increased and land cover (or land use) classes became more fine-grained,
per-pixel approaches no longer cut it: a grey pixel could be a road or the top of a building; to
really be sure, spatial **texture** started to become important, too.

💡 This is why pixel-wise semantic segmentation was historically just referred to as
"classification" in remote sensing.

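The classical per-pixel approach described above typically starts by reshaping the image from `BxHxW` into a table with one row per pixel and one column per band, and reshapes the predictions back into a map afterwards. The sketch below shows only this reshaping on a synthetic stack; the "classifier" is a hypothetical placeholder (a threshold on one band) standing in for a model such as a random forest.

```python
import numpy as np

# synthetic 4-band stack (B x H x W), hypothetical values
rng = np.random.default_rng(0)
stack = rng.uniform(0.0, 0.5, size=(4, 3, 5))
num_bands, height, width = stack.shape

# one row per pixel, one column per band: (H*W) x B
features = stack.reshape(num_bands, -1).T
print(features.shape)    # (15, 4)

# placeholder per-pixel "classifier" (hypothetical): threshold on the last band
pixel_predictions = (features[:, -1] > 0.25).astype(int)

# reshape predictions back into an H x W map
prediction_map = pixel_predictions.reshape(height, width)
print(prediction_map.shape)    # (3, 5)
```

Because each row is classified on its own, such a model never sees a pixel's neighbours, which is exactly the texture information that becomes important at higher resolution.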

Many semantic segmentation approaches have been proposed over time, and, as you might have guessed,
deep learning has spawned the most successful models. Among those, arguably the most popular one is
[U-net](https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28):

🖌️ Read up about U-net online and provide a brief explanation of its working principles.

_Answer:_

...

U-net was originally proposed for the segmentation of biomedical images. However, it is now used in
many other fields, including remote sensing.

💡 Have you heard of DALL-E 2 and (in particular) Stable Diffusion? Well, at their core lies... a
U-net. 🙂

🖌️ Find an online implementation (_e.g._, on GitHub) of U-net in PyTorch and copy the relevant code
blocks into the code cell below.

In [None]:
# initiate U-net model instance with correct number of input channels and number of predicted
# classes.
# Of course, you may need to find an implementation of U-net first (class UNet ...)
model = ...

Test it out:

In [None]:
# load image
s2_img = load_tile(tile_file_names[0])

# prepare image for prediction with U-net:
# 1. Convert from numpy.array to torch.Tensor
# 2. Add leading dimension for batch index (.unsqueeze(0))
# 3. Convert to 32-bit float (single precision)
data = torch.from_numpy(s2_img).unsqueeze(0).float()

# obtain model prediction (forward pass)
with torch.no_grad():
    pred = model(data)

print(f'Input size:\t\t{data.size()}')
print(f'Prediction size:\t{pred.size()}')

If everything worked correctly, you should get a prediction of size `BxCxHxW` (batch size x no.
classes x height x width), with height and width being identical to the input.
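Whether the output really matches the input size depends on the implementation. The original U-net uses unpadded convolutions, so its output is smaller than its input; many PyTorch implementations online pad their convolutions instead. In such padded variants with `depth` downsampling steps (four max-pooling steps in the original paper, but this is a choice of the implementation you pick), height and width are halved at each step, so they should be divisible by `2**depth`, otherwise the upsampled decoder features no longer line up with the skip connections. The sketch below traces the sizes and pads an image accordingly; `encoder_sizes` and `pad_to_multiple` are hypothetical helpers, and reflect padding is only one of several reasonable choices.

```python
import numpy as np


def encoder_sizes(size, depth=4):
    """Spatial size at each encoder level when every step halves it (floor division)."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)
    return sizes


def pad_to_multiple(img, multiple):
    """Reflect-pad a BxHxW array at the bottom/right so H and W become multiples of `multiple`."""
    _, height, width = img.shape
    pad_h = (-height) % multiple
    pad_w = (-width) % multiple
    return np.pad(img, ((0, 0), (0, pad_h), (0, pad_w)), mode='reflect')


print(encoder_sizes(512))    # [512, 256, 128, 64, 32]: halves cleanly
print(encoder_sizes(500))    # [500, 250, 125, 62, 31]: 125 // 2 loses a pixel

img = np.zeros((4, 500, 500))
padded = pad_to_multiple(img, 2 ** 4)
print(padded.shape)          # (4, 512, 512)
```

After predicting on a padded input, the prediction can be cropped back to the original size, _e.g._ with `pred[..., :500, :500]`.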

### 4.2 Deep Learning Ingredients

Everything that follows should look very familiar: we have seen it all in Sessions 2-4, where we
learnt how to train and test deep learning models.

The next ingredient we need is a Dataset class definition. Here, we first find all the images and
then load them together with the corresponding ground truth mask in the `__getitem__` function.

In [None]:
from torch.utils.data import Dataset



class S2Dataset(Dataset):
    def __init__(self,
                 data_dir,
                 transform,
                 split='Training'):
        self.data_dir = data_dir
        self.transform = transform

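        # find all label files in data dir (sorted for a reproducible order)
        self.label_files = sorted(glob.glob(os.path.join(data_dir, split, 'masks', '*.tif')))
        # the matching image files have the same name in the "images" folder
        self.image_files = [file.replace('masks', 'images') for file in self.label_files]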

    def __len__(self):
        return len(self.label_files)

    def __getitem__(self, idx):
        # load and normalise image tile
        img = load_tile(self.image_files[idx])

        # convert to torch.Tensor
        img = torch.from_numpy(img).float()

        # transform
        img = self.transform(img)

        # load segmentation ground truth
        with rasterio.open(self.label_files[idx], 'r') as f_label:
            target = f_label.read()
        target = torch.from_numpy(target.squeeze()).long()

        return img, target

Next, we need a loss function. Since semantic segmentation is nothing other than (pixel-wise)
classification, we can use a cross-entropy loss. Luckily, PyTorch's [Cross-Entropy
loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) handles
multidimensional data directly: it accepts predictions of shape `NxCxHxW` together with targets of
class indices of shape `NxHxW`.

In [None]:
from torch import nn

# load target mask for our sample image
target_path = tile_file_names[0].replace('images', 'masks')
with rasterio.open(target_path, 'r') as f_target:
    target = f_target.read()

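# prepare for usage with criterion: class indices as int64 (torch.long).
# The mask file has a single band, so its shape is (1, H, W), which matches the
# (N, H, W) layout expected for a batch of N=1 predictions of shape (1, C, H, W).
target = torch.from_numpy(target).long()


# initialise cross-entropy loss instance
criterion = nn.CrossEntropyLoss()

# compute the loss between prediction and target
loss = criterion(pred, target)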

print(loss)

As you can see, `loss` still only contains one value. If you go to the
[documentation](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) of the
loss function, you will notice argument `reduction`. By default, this is set to `"mean"`, _i.e._,
per-pixel cross-entropy loss values are averaged. This is important for gradient computation (which
requires a single loss value to begin with). If you want, you can try setting `reduction` to
something else and observe what happens to the output.
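To see what `reduction` does, the sketch below re-implements pixel-wise cross-entropy in NumPy for logits of shape `NxCxHxW` and integer targets of shape `NxHxW`: keeping one value per pixel corresponds to `'none'`, averaging them to `'mean'`, and a boolean mask shows how individual pixels could be excluded. This is an illustration under these assumptions, not PyTorch's implementation; for excluding pixels, `nn.CrossEntropyLoss` also offers an `ignore_index` argument for target values that should not contribute to the loss.

```python
import numpy as np


def pixelwise_cross_entropy(logits, target):
    """Per-pixel cross-entropy for logits (N, C, H, W) and integer targets (N, H, W)."""
    # numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # log-probability of the true class at every pixel -> (N, H, W)
    picked = np.take_along_axis(log_probs, target[:, None, ...], axis=1)[:, 0]
    return -picked


rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 2, 4, 4))         # batch of 1, 2 classes, 4x4 pixels
target = rng.integers(0, 2, size=(1, 4, 4))    # 0 = non-forest, 1 = forest

loss_none = pixelwise_cross_entropy(logits, target)   # like reduction='none'
loss_mean = loss_none.mean()                         # like reduction='mean'
print(loss_none.shape, loss_mean.ndim)               # (1, 4, 4) 0

# hypothetical mask: ignore the top row of pixels and average over the rest
valid = np.ones_like(target, dtype=bool)
valid[:, 0, :] = False
loss_masked = loss_none[valid].mean()
```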

💡 Setting `reduction='none'` can be used to, _e.g._, modify loss values for each pixel individually.
This can be useful when you have areas in an image you don't want the model to learn from.

💡 Since we are technically doing binary classification, we could also use a sigmoid + binary
cross-entropy instead of softmax + multi-class cross-entropy.
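The two options are equivalent because, for two classes, the softmax probability of class 1 equals the sigmoid of the logit difference: $\mathrm{softmax}(z)_1 = e^{z_1}/(e^{z_0}+e^{z_1}) = 1/(1+e^{-(z_1-z_0)}) = \sigma(z_1 - z_0)$. The quick numerical check below uses random logits; a binary model would instead predict the single logit directly (in PyTorch, _e.g._ with `nn.BCEWithLogitsLoss`), so treating $z_1 - z_0$ as that logit is just one way to connect the two views.

```python
import numpy as np

rng = np.random.default_rng(3)
z = rng.normal(size=(2, 1000))          # random logits: class 0 in row 0, class 1 in row 1
target = rng.integers(0, 2, size=1000)

# two-class softmax + cross-entropy
softmax_p1 = np.exp(z[1]) / (np.exp(z[0]) + np.exp(z[1]))
ce = -np.log(np.where(target == 1, softmax_p1, 1.0 - softmax_p1))

# sigmoid of the logit difference + binary cross-entropy
sigmoid_p1 = 1.0 / (1.0 + np.exp(-(z[1] - z[0])))
bce = -(target * np.log(sigmoid_p1) + (1 - target) * np.log(1.0 - sigmoid_p1))

print(np.allclose(softmax_p1, sigmoid_p1), np.allclose(ce, bce))   # True True
```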


The following code block just contains some convenience functions for us to load and save model
states after each epoch.

In [None]:
import copy     # for deep copies of nested objects
import re       # package for regular expressions


# this is a default dict of what we store after completion of each epoch
DEFAULT_STATE_DICT = {
    'model': None,
    'loss_train': [],
    'loss_val': [],
    'oa_train': [],
    'oa_val': []
}


def load_model(model,
               save_path='model_states',
               epoch='latest'):
    os.makedirs(save_path, exist_ok=True)

    # find all model states in directory
    model_files = glob.glob(os.path.join(save_path, '*.pt'))
    # extract epoch from file names (e.g., "model_states/7.pt" -> 7), independent of path separator
    model_epochs = [int(re.sub(r'([0-9]+)\.pt', '\\1', os.path.basename(file))) for file in model_files]

    if len(model_epochs) == 0 or (isinstance(epoch, int) and epoch <= 0):
        # nothing saved yet or forcing creation of new model
        print('Initialising new model...')
        # deep copy: a shallow .copy() would share the (mutable) history lists with DEFAULT_STATE_DICT
        return model, 0, copy.deepcopy(DEFAULT_STATE_DICT)

    if epoch == 'latest':
        model_epoch = max(model_epochs)
    
    elif epoch not in model_epochs:
        raise Exception(f'Invalid model epoch specified (epoch {epoch} not found).')
    else:
        model_epoch = epoch

    # load model state
    print(f'Loading model state at epoch {model_epoch}...')
    with open(os.path.join(save_path, f'{model_epoch}.pt'), 'rb') as f_state:
        state_dict = torch.load(f_state, map_location='cpu', weights_only=False)
        model.load_state_dict(state_dict['model'])

    return model, model_epoch, state_dict


def save_model(state_dict,
               epoch,
               save_path='model_states'):
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, f'{epoch}.pt'), 'wb') as f_state:
        torch.save(state_dict, f_state)
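The helpers above store the model weights and training history, but not the optimiser state. With
SGD and momentum, resuming from a checkpoint therefore restarts the momentum buffers from zero. If
you want to resume training exactly, a sketch like the following could also store the optimiser
state; the function names `save_checkpoint`/`load_checkpoint` are hypothetical and not part of the
notebook, and the tiny `nn.Linear` model merely stands in for the U-net.

```python
import os
import tempfile

import torch
from torch import nn
from torch.optim import SGD


def save_checkpoint(path, model, optimiser, epoch, history):
    # store weights, optimiser state (incl. SGD momentum buffers) and training history in one file
    torch.save({'epoch': epoch,
                'model': model.state_dict(),
                'optimiser': optimiser.state_dict(),
                'history': history}, path)


def load_checkpoint(path, model, optimiser):
    # map_location='cpu' allows loading a GPU checkpoint on a CPU-only machine;
    # weights_only=False mirrors load_model above (only use it for files you trust)
    checkpoint = torch.load(path, map_location='cpu', weights_only=False)
    model.load_state_dict(checkpoint['model'])
    optimiser.load_state_dict(checkpoint['optimiser'])
    return checkpoint['epoch'], checkpoint['history']


# usage with a tiny stand-in model (toy_* names avoid overwriting the notebook's U-net)
toy_model = nn.Linear(4, 2)
toy_optimiser = SGD(toy_model.parameters(), lr=0.01, momentum=0.9)
toy_loss = toy_model(torch.randn(8, 4)).sum()
toy_loss.backward()
toy_optimiser.step()                      # first step creates the momentum buffers

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, '1.pt')
    save_checkpoint(ckpt_path, toy_model, toy_optimiser, epoch=1,
                    history={'loss_train': [toy_loss.item()]})

    toy_model_new = nn.Linear(4, 2)
    toy_optimiser_new = SGD(toy_model_new.parameters(), lr=0.01, momentum=0.9)
    toy_epoch, toy_history = load_checkpoint(ckpt_path, toy_model_new, toy_optimiser_new)
```

When loading into a model on the GPU, move the model to the device before creating the optimiser and
calling `optimiser.load_state_dict`, which moves the loaded state to the parameters' device.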

Below follows the main training block. Again, pretty much all of this should be very familiar to
you.

In [None]:
%matplotlib widget

from torch import nn
from torch.utils.data import DataLoader
from torch.optim import SGD
from torchvision import transforms as T
from tqdm import tqdm
from IPython.display import clear_output


# hyperparameters
NUM_EPOCHS = 10
START_EPOCH = 'latest'      # set to "latest", a number (for specific epoch) or zero (to start training a new model)
BATCH_SIZE = 8
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0
MOMENTUM = 0.9              # momentum keeps parts of gradients from previous batches, which can help to "keep the ball rolling" over minor undulations of the gradient landscape
DEVICE = 'cuda'


# init transforms
transforms = T.Normalize(mean=0.5*torch.ones(len(S2_BANDS_SUBSET)),     # see below for an explanation of these values
                         std=torch.ones(len(S2_BANDS_SUBSET)))

# init dataset and data loader
data_folder = os.path.join(BASE_FOLDER, 'AMAZON')
dl_train = DataLoader(S2Dataset(data_folder, transforms, 'Training'),
                      batch_size=BATCH_SIZE,
                      shuffle=True)
dl_val = DataLoader(S2Dataset(data_folder, transforms, 'Validation'),
                    batch_size=BATCH_SIZE,
                    shuffle=False)

# load model from pre-trained state if available
model, start_epoch, state_dict = load_model(model, 'model_states', epoch=START_EPOCH)

# move model to device (PyTorch recommends doing this before constructing the optimiser)
model = model.to(DEVICE)

# init optimiser
optimiser = SGD(model.parameters(),
                lr=LEARNING_RATE,
                weight_decay=WEIGHT_DECAY,
                momentum=MOMENTUM)

# init criterion (with reduction="mean" as default)
criterion = nn.CrossEntropyLoss()



# helper function to plot training progress
plt.figure(figsize=(8, 5))
def plot_training_progress():
    clear_output(wait=True)
    plt.clf()                   # clear previous curves, so that lines and legend entries don't pile up
    plt.subplot(1,2,1)
    plt.plot(state_dict['loss_train'], 'b-', label='train')
    plt.plot(state_dict['loss_val'], 'r-', label='val')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1,2,2)
    plt.plot(state_dict['oa_train'], 'b-', label='train')
    plt.plot(state_dict['oa_val'], 'r-', label='val')
    plt.xlabel('Epoch')
    plt.ylabel('Overall Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.show()

plot_training_progress()

# iterate over epochs
for epoch in range(start_epoch+1, NUM_EPOCHS+1):
    # train
    loss_epoch_train, accuracy_epoch_train = 0.0, 0.0

    # put model in training mode (never forget!)
    model.train()
    with tqdm(dl_train) as pbar:
        for idx, (data, target) in enumerate(dl_train):
            data, target = data.to(DEVICE), target.to(DEVICE)

            # standard deep learning training here, just like you have seen in Session 3
            pred = model(data)
            loss = criterion(pred, target)
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            loss_epoch_train += loss.item()

            # overall accuracy:
            # 1. predicted label (y_hat) = position of the predicted max logits in class dimension (pred.argmax(1))
            # 2. pixel-wise comparison with ground truth (y_hat == target): returns tensor of ones where identical and zeros elsewhere
            # 3. convert from bool (True/False) to float (1.0/0.0)
            # 4. flatten into 1-D tensor (.view(-1))
            # 5. take average across all pixels (torch.mean)
            accuracy_epoch_train += torch.mean((pred.argmax(1) == target).float().view(-1)).item()

            pbar.set_description(f'[Ep. {epoch} train] Loss: {loss_epoch_train/(idx+1):.2f}, OA: {accuracy_epoch_train/(idx+1):.2%}')
            pbar.update(1)

    # average loss and overall accuracy values by number of batches (length of data loader),
    # append to correct list
    state_dict['loss_train'].append(loss_epoch_train / len(dl_train))
    state_dict['oa_train'].append(accuracy_epoch_train / len(dl_train))

    # validate
    loss_epoch_val, accuracy_epoch_val = 0.0, 0.0

    # put model in evaluation mode (never forget!)
    model.eval()
    with tqdm(dl_val) as pbar:
        for idx, (data, target) in enumerate(dl_val):
            with torch.no_grad():       # skip calculating gradients; we don't need them for predictions
                data, target = data.to(DEVICE), target.to(DEVICE)

                pred = model(data)
                loss = criterion(pred, target)

                loss_epoch_val += loss.item()
                accuracy_epoch_val += torch.mean((pred.argmax(1) == target).float().view(-1)).item()

                pbar.set_description(f'[Ep. {epoch}   val] Loss: {loss_epoch_val/(idx+1):.2f}, OA: {accuracy_epoch_val/(idx+1):.2%}')
                pbar.update(1)
    state_dict['loss_val'].append(loss_epoch_val / len(dl_val))
    state_dict['oa_val'].append(accuracy_epoch_val / len(dl_val))

    # save model
    state_dict['model'] = model.state_dict()        # get model parameters & assign under correct key
    save_model(state_dict, epoch)

    # plot
    plot_training_progress()


💡 In the code above you may notice that we use mean values of 0.5 and standard deviation values of
1.0 for normalising the tiles instead of per-band statistics calculated a priori, as we usually
would. This choice is simply the result of empirical experimentation: because remote sensing
datasets are not "standardised" like natural images (_i.e._, 8-bit values from 0 to 255), there is
no established guidance on how to perform data normalisation. Ideally, inputs to deep learning
models should still be roughly zero-centred with unit variance, though. When loading Sentinel-2
tiles, we divided values by 10,000 because that is how L2A reflectance is encoded ([see
here](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/)); other products might need
different treatment.
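If you prefer the usual approach of per-band statistics, they could be computed once over the
training tiles, for instance with the sketch below. `per_band_stats` is a hypothetical helper; it
assumes tiles are NumPy arrays of shape (bands, height, width) without no-data pixels, and it
accumulates sums so that the whole dataset never needs to be in memory at once.

```python
import numpy as np


def per_band_stats(tiles):
    """Per-band mean and (population) standard deviation over an iterable of (bands, H, W) arrays."""
    total, total_sq, count = None, None, 0
    for tile in tiles:
        flat = tile.reshape(tile.shape[0], -1).astype(np.float64)   # (bands, pixels)
        if total is None:
            total = np.zeros(flat.shape[0])
            total_sq = np.zeros(flat.shape[0])
        total += flat.sum(axis=1)
        total_sq += (flat ** 2).sum(axis=1)
        count += flat.shape[1]
    mean = total / count
    # variance = E[x^2] - E[x]^2; clip tiny negative values caused by floating-point rounding
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean, std
```

The result could then be passed to the normalisation transform, _e.g._
`T.Normalize(mean=torch.from_numpy(mean).float(), std=torch.from_numpy(std).float())`, with the
statistics computed on the training split only so that no information leaks from validation or test
tiles.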

💡 We didn't use any other data transforms this time. However, if you want, you can of course use
data augmentation like we did in Session 4. Just be careful to augment the ground truth in exactly
the same way: if you perform a horizontal flip of an image, you also need to flip the ground truth
this time. There are libraries to help you with that, such as
[Albumentations](https://github.com/albumentations-team/albumentations).
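A minimal sketch of such a paired augmentation in plain PyTorch, assuming the image is a tensor of
shape (bands, H, W) and the mask a tensor of shape (H, W), as returned by `S2Dataset`. The helper
name `random_flip_pair` is hypothetical; one way to use it would be inside `__getitem__` for the
training split only, before returning `img` and `target`.

```python
import torch


def random_flip_pair(image, mask, generator=None):
    """Randomly flip an image (bands, H, W) and its mask (H, W) in exactly the same way."""
    # horizontal flip (reverse the width axis) with probability 0.5
    if torch.rand(1, generator=generator).item() < 0.5:
        image, mask = image.flip(-1), mask.flip(-1)
    # vertical flip (reverse the height axis) with probability 0.5
    if torch.rand(1, generator=generator).item() < 0.5:
        image, mask = image.flip(-2), mask.flip(-2)
    return image, mask


# usage: a mask derived from the image must still match it after augmentation
gen = torch.Generator().manual_seed(0)
toy_image = torch.rand(4, 6, 6)
toy_mask = (toy_image[3] > 0.5).long()           # toy "forest" label where band 3 is bright
aug_image, aug_mask = random_flip_pair(toy_image, toy_mask, generator=gen)
```

Flips (and 90° rotations, _e.g._ via `torch.rot90` applied to both tensors) are natural choices for
overhead imagery, since there is no canonical "up" direction in a satellite image.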


When running the code above, you may notice that performance improves rather quickly, even after
just a few epochs. Let's see how well our model works on the test set:

In [None]:
%matplotlib widget

# load tiles from test set
tile_file_names = glob.glob(os.path.join(BASE_FOLDER, 'AMAZON', 'Test', 'images', '*.tif'))

# sort alphabetically
tile_file_names.sort()

# cache data for current tile
current_file_name = None
tile = None
confidence, y_hat = None, None
tile_anno = None

# put model on right device and into evaluation mode
model = model.to(DEVICE)
model.eval()

plt.figure(figsize=(9,6))

@widgets.interact(
        tile_name=tile_file_names,
        red=S2_BANDS_SUBSET,
        green=S2_BANDS_SUBSET,
        blue=S2_BANDS_SUBSET,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0)
)
def vis_tiles_with_predictions(tile_name=tile_file_names[0],
                               red=S2_BANDS_SUBSET[2],
                               green=S2_BANDS_SUBSET[1],
                               blue=S2_BANDS_SUBSET[0],
                               brightness=1.0):

    global current_file_name, tile, confidence, y_hat, tile_anno

    if tile_name != current_file_name:
        # load tile and remember its name, so that we only re-predict when the tile changes
        tile = load_tile(tile_name)
        current_file_name = tile_name

        # get prediction
        with torch.no_grad():
            # prepare model input: convert to torch.Tensor, convert to 32-bit float,
            # add leading batch dimension, put tensor on right device
            data = torch.from_numpy(tile).float().unsqueeze(0).to(DEVICE)

            # remember that we have a normalisation transform to apply.
            # If you use data augmentation, make sure to not apply that during testing.
            data = transforms(data)

            # obtain prediction/model logits (forward pass)
            pred = model(data)

            # 1. get pseudo-probabilities via softmax
            # 2. get confidence (per-pixel max value) and predicted class y_hat (argument of max)
            # In PyTorch, specifying a dimension with max (.max(1)) returns both the values and
            # arguments together.
            confidence, y_hat = pred.softmax(dim=1).max(1)

            # convert back: remove any extra dimensions (e.g., batch), move tensors back to CPU,
            # convert to NumPy array
            confidence = confidence.squeeze().cpu().numpy()
            y_hat = y_hat.squeeze().cpu().numpy()

        # load ground truth: same file name, but different folder, so we can simply replace that
        tile_name_anno = tile_name.replace('images', 'masks')
        with rasterio.open(tile_name_anno, 'r') as f_label:
            tile_anno = f_label.read().squeeze()

    # clear figure on every redraw, so that colour bars don't stack up when only the bands change
    plt.clf()

    # show image tile
    plt.subplot(2,2,1)
    band_indices = [S2_BANDS_SUBSET.index(band) for band in [red, green, blue]]
    arr = brightness * np.transpose(tile[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))
    plt.title('Sentinel-2 image')

    # show ground truth
    plt.subplot(2,2,2)
    plt.imshow(tile_anno, cmap=cmap)
    cbar = plt.colorbar(ticks=np.arange(len(LABEL_CLASSES)))
    cbar.ax.set_yticklabels(LABEL_CLASSES)
    plt.title(r'Ground truth $y$')

    # show model confidence
    plt.subplot(2,2,3)
    plt.imshow(confidence, cmap='viridis')
    plt.colorbar()
    plt.title('Model confidence')

    # show model predictions
    plt.subplot(2,2,4)
    plt.imshow(y_hat, cmap=cmap)
    plt.title(r'Model prediction $\hat{y}$')

    plt.tight_layout()
    

🖌️ Can you explain the significance of the "Model confidence" panel and the model behaviour it shows?

_Answer:_
...

Also, we should of course calculate accuracy metrics. During training above, we already calculated
overall accuracy on the training and validation sets. While this indicates whether the model under-
or overfits, it may not tell us the whole story.

Remember accuracy metrics from Session 2 and how misleading some of them can be on imbalanced data?
For example, if 80% of our pixels are "forest", a model that trivially always predicts that class
still scores an overall accuracy of 80%.

Traditionally, you would often find the following metrics in remote sensing tasks:
* A confusion matrix
* User's accuracy (precision)
* Producer's accuracy (recall)
* Cohen's kappa / kappa coefficient: this measures the agreement between a model prediction and
  ground truth beyond the agreement expected by chance, and is very popular in remote sensing (but
  see [this recent paper](https://www.sciencedirect.com/science/article/abs/pii/S0034425719306509)).

"User's accuracy" and "producer's accuracy" are very common terms in remote sensing, but they are
simply precision and recall, computed per class.

See [here](https://gsp.humboldt.edu/olm/courses/GSP_216/lessons/accuracy/metrics.html) for more
information on accuracy metrics in remote sensing.
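To make these definitions concrete, here is a small NumPy sketch (hypothetical helper names) that
computes a confusion matrix, user's and producer's accuracy, and Cohen's kappa. It reproduces the
example from above: a "model" that always predicts forest on data with 80% forest pixels reaches an
overall accuracy of 0.8, but a kappa of 0. `sklearn.metrics.confusion_matrix` and
`sklearn.metrics.cohen_kappa_score` offer ready-made alternatives.

```python
import numpy as np


def compute_confusion_matrix(y_true, y_pred, num_classes=2):
    """Confusion matrix with ground truth in rows and predictions in columns."""
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    counts = np.bincount(num_classes * y_true + y_pred, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def safe_divide(num, den):
    # element-wise division that yields NaN instead of a warning where the denominator is zero
    out = np.full(np.shape(num), np.nan)
    np.divide(num, den, out=out, where=den != 0)
    return out


def accuracy_report(cm):
    total = cm.sum()
    overall = np.trace(cm) / total
    users = safe_divide(np.diag(cm), cm.sum(axis=0))       # correct / predicted as class (precision)
    producers = safe_divide(np.diag(cm), cm.sum(axis=1))   # correct / truly of class (recall)
    # agreement expected by chance, from the row (truth) and column (prediction) totals
    chance = (cm.sum(axis=0) * cm.sum(axis=1)).sum() / total ** 2
    kappa = (overall - chance) / (1 - chance) if chance < 1 else np.nan
    return overall, users, producers, kappa


# usage: 80% forest (class 1), and a "model" that always predicts forest
toy_true = np.array([1] * 80 + [0] * 20)
toy_pred = np.ones(100, dtype=int)
toy_cm = compute_confusion_matrix(toy_true, toy_pred)
toy_oa, toy_ua, toy_pa, toy_kappa = accuracy_report(toy_cm)
print(toy_cm)                  # [[ 0 20]
                               #  [ 0 80]]
print(toy_oa, toy_kappa)       # 0.8 0.0
```

The same functions could be applied to flattened test-set labels and thresholded forest scores,
_e.g._ predicting forest wherever the softmax score is at least 0.5.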


Our model outputs confidence scores via softmax (as shown above). Thus, we can perform one of the
most complete analyses and calculate a precision-recall curve across all confidence thresholds. You
have already seen this in Session 2. To do so, we predict all test set images and "flatten" the
predictions and ground truth into 1-D tensors. By concatenating them across all images, we can then
calculate precision and recall values for all test set pixels together. The code block below does
this and then plots the precision-recall curve.

In [None]:
# prepare empty lists of confidence scores and ground truth labels
predictions, targets = [], []

model = model.eval()

# iterate over test set
dl_test = DataLoader(S2Dataset(data_folder, transforms, 'Test'),
                     batch_size=BATCH_SIZE,
                     shuffle=False)
for data, target in tqdm(dl_test):
    with torch.no_grad():
        # get prediction (forward pass)
        pred = model(data.to(DEVICE))
        # get pseudo-probabilities with Softmax
        pred = pred.softmax(dim=1)
        # take second channel ("forest")
        pred = pred[:,1,:,:]
        # flatten into 1-D array (batch size * number of pixels per image)
        pred = pred.flatten()
        target = target.flatten()
        # append to lists
        predictions.append(pred.cpu())
        targets.append(target)

# concatenate prediction and target lists across images
predictions = torch.cat(predictions, 0)
targets = torch.cat(targets, 0)


# calculate precision-recall curve
from sklearn.metrics import precision_recall_curve

prec, rec, _ = precision_recall_curve(y_true=targets.numpy(),
                                      y_score=predictions.numpy())

plt.figure()
plt.plot(rec, prec, 'k-')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall curve')
plt.show()
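A precision-recall curve also helps to pick a decision threshold other than the implicit 0.5 of
`argmax`. One common choice is the threshold that maximises the F1 score, sketched below with a
hypothetical helper `best_f1_threshold`. Note that a threshold chosen on the test set gives an
optimistic estimate; in practice you would select it on the validation set.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def best_f1_threshold(y_true, y_score):
    """Return the confidence threshold that maximises F1, together with that F1 score."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    # the last precision/recall pair (1, 0) has no associated threshold, so drop it
    precision, recall = precision[:-1], recall[:-1]
    denom = precision + recall
    f1 = np.zeros_like(denom)
    np.divide(2 * precision * recall, denom, out=f1, where=denom > 0)
    best = np.argmax(f1)
    return thresholds[best], f1[best]


# toy example: perfectly separable scores, best threshold is 0.7
toy_thr, toy_f1 = best_f1_threshold(np.array([0, 0, 1, 1]), np.array([0.1, 0.2, 0.7, 0.9]))
```

With the tensors from the cell above, the call would be
`best_f1_threshold(targets.numpy(), predictions.numpy())`.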

If you have let the model train for long enough, the precision-recall curve should indicate really good performance.

What's the catch?

Well, maybe there is none. In this case, the problem might simply be _too easy_. Remember when we
looked at the spectral values of forested _vs._ non-forested areas above? We saw pretty strong
differences. We also looked at NDVI, and although the contrast between forest and other vegetated
areas wasn't super clear, it was still there. Perhaps we could use NDVI to predict forest cover,
_i.e._, assume that any pixel with an NDVI at or above a pre-defined threshold is forest?

🖌️ Run the widget below and try adjusting the `ndvi_threshold` slider so that the thresholded image
in the bottom left matches the ground truth in the bottom right. Try doing this for multiple images.

In [None]:
%matplotlib widget


# we need to redefine the NDVI function for our tiles, since the required bands now sit at different
# positions in the layer stack compared to the full scenes
# (tile band order: 0 = Blue, 1 = Green, 2 = Red, 3 = NIR; NDVI = (NIR - Red) / (NIR + Red))
def ndvi_tile(tile):
    return (tile[3,...] - tile[2,...]) / (tile[3,...] + tile[2,...])


plt.figure(figsize=(15,7))


@widgets.interact(
        tile_name=tile_file_names,
        red=S2_BANDS_SUBSET,
        green=S2_BANDS_SUBSET,
        blue=S2_BANDS_SUBSET,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0),
        ndvi_threshold=widgets.FloatSlider(value=0.5, min=-1.0, max=1.0, step=0.01)
)
def vis_tiles_with_ndvi_threshold(tile_name=tile_file_names[0],
                                  red=S2_BANDS_SUBSET[2],
                                  green=S2_BANDS_SUBSET[1],
                                  blue=S2_BANDS_SUBSET[0],
                                  brightness=1.0,
                                  ndvi_threshold=0.5):
    plt.subplot(2,2,1)
    tile = load_tile(tile_name)
    band_indices = [S2_BANDS_SUBSET.index(band) for band in [red, green, blue]]
    arr = brightness * np.transpose(tile[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))
    plt.title('Sentinel-2 image')

    plt.subplot(2,2,2)
    tile_ndvi = ndvi_tile(tile)
    plt.imshow(tile_ndvi, cmap='vanimo', vmin=-1, vmax=1)
    plt.colorbar()
    plt.title('NDVI')

    plt.subplot(2,2,3)
    tile_ndvi_threshold = (tile_ndvi >= ndvi_threshold).astype(int)
    plt.imshow(tile_ndvi_threshold, cmap=cmap)
    cbar = plt.colorbar(ticks=np.arange(len(LABEL_CLASSES)))
    cbar.ax.set_yticklabels(LABEL_CLASSES)
    plt.title('NDVI thresholded')

    plt.subplot(2,2,4)
    tile_name_anno = tile_name.replace('images', 'masks')
    with rasterio.open(tile_name_anno, 'r') as f_anno:
        tile_anno = f_anno.read()

    plt.imshow(tile_anno.squeeze(), cmap=cmap)
    cbar = plt.colorbar(ticks=np.arange(len(LABEL_CLASSES)))
    cbar.ax.set_yticklabels(LABEL_CLASSES)
    plt.title('Ground Truth')


While not perfect, the result can be really good, and that from a **non-machine learning** approach
that is orders of magnitude cheaper to compute than a U-net!
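Instead of tuning the threshold by eye, it could be chosen on the labelled training tiles and then
evaluated on the test tiles. The sketch below (hypothetical helpers `ndvi_from_bands` and
`best_ndvi_threshold`) sweeps candidate thresholds and keeps the one whose "NDVI >= threshold" map
has the highest intersection over union (IoU) with the forest mask. IoU is used here as an
assumption, to avoid the class-imbalance issue of overall accuracy discussed above.

```python
import numpy as np


def ndvi_from_bands(nir, red, eps=1e-8):
    # NDVI = (NIR - Red) / (NIR + Red); eps avoids division by zero where both bands are 0
    return (nir - red) / (nir + red + eps)


def best_ndvi_threshold(ndvi_values, forest_mask, thresholds=np.linspace(-1, 1, 201)):
    """Threshold whose 'NDVI >= threshold' map best matches a binary forest mask (by IoU)."""
    ndvi_values = np.ravel(ndvi_values)
    forest_mask = np.ravel(forest_mask).astype(bool)
    ious = []
    for t in thresholds:
        pred = ndvi_values >= t
        intersection = np.logical_and(pred, forest_mask).sum()
        union = np.logical_or(pred, forest_mask).sum()
        ious.append(intersection / union if union > 0 else 0.0)
    best = int(np.argmax(ious))
    return thresholds[best], ious[best]


# toy example: two non-forest pixels (low NDVI) and two forest pixels (high NDVI)
toy_ndvi = ndvi_from_bands(np.array([0.30, 0.30, 0.40, 0.50]), np.array([0.25, 0.20, 0.10, 0.05]))
toy_t, toy_iou = best_ndvi_threshold(toy_ndvi, np.array([0, 0, 1, 1]))
```

For the real data, the NDVI values and masks of all training tiles (_e.g._ from `ndvi_tile` and the
mask files) would be concatenated before the sweep.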

---

## 4. Change Detection & Deforestation Monitoring

For this last part, let us go back to our trained U-net, as well as the larger Sentinel-2 image we
investigated at the beginning of the exercise.

Mapping forest cover is more important than ever, but in the wake of our alterations to the
environment, predicting **change** takes this to yet another level. Wouldn't it be great if we could
map _deforestation_ across time?

Above, we looked at a single Sentinel-2 scene from 2017. Sentinel-2 satellites have been in orbit
since 2015 (Sentinel-2A) and still are, so you are also provided with another dataset over the exact
same area, but from 2024. Perhaps we could see how much has changed in seven years?

Let us first re-load the twelve bands and create a composite (layer stack) for both timestamps.

In [None]:
FOLDER_2017 = os.path.join(BASE_FOLDER, 'Sentinel-2_stripes', '2017-06-28-00:00_2017-06-28-23:59_Sentinel-2_L2A')
FOLDER_2024 = os.path.join(BASE_FOLDER, 'Sentinel-2_stripes', '2024-06-06-00:00_2024-06-06-23:59_Sentinel-2_L2A')

# we're re-using the load_composite function we defined above
data_2017 = load_composite(FOLDER_2017)
data_2024 = load_composite(FOLDER_2024)

It's always a good idea to visualise our data first. The following widget shows both scenes side by
side.

Just as above, you can adjust brightness and band configuration, and click on either image to
display reflectance spectra for both scenes (underneath each respective scene).

Click on areas of change and see how the spectra differ.

In [None]:
%matplotlib widget

plt.close('all')

points, spectra_2017, spectra_2024 = [], [], []                # for click events

def redraw_plot():
    global points, spectra_2017, spectra_2024
    ax1 = plt.subplot(2,2,3)
    ax1.clear()
    for spectrum in spectra_2017:
        ax1.plot(S2_LAMBDAS, spectrum, '-')
    ax1.set_ylabel('Normalised value')
    ax1.set_xticks(S2_LAMBDAS, S2_BANDS, rotation=90)
    ax2 = plt.subplot(2,2,4, sharey=ax1)
    ax2.clear()
    for spectrum in spectra_2024:
        ax2.plot(S2_LAMBDAS, spectrum, '-')
    ax2.set_ylabel('Normalised value')
    ax2.set_xticks(S2_LAMBDAS, S2_BANDS, rotation=90)
    plt.margins(0.05)
    plt.tight_layout(pad=2)


def click(event):
    global points, spectra_2017, spectra_2024
    if event.xdata is not None and event.ydata is not None:
        # round to the nearest pixel centre and keep indices inside the image bounds
        x = int(np.clip(np.round(event.xdata), 0, data_2017.shape[2] - 1))
        y = int(np.clip(np.round(event.ydata), 0, data_2017.shape[1] - 1))
        spectral_vals_2017 = data_2017[:,y,x]
        spectra_2017.append(spectral_vals_2017)
        spectral_vals_2024 = data_2024[:,y,x]
        spectra_2024.append(spectral_vals_2024)
        plt.subplot(2,2,1)
        plt.scatter(x, y)
        plt.subplot(2,2,2)
        plt.scatter(x, y)
        redraw_plot()

figure = plt.figure(figsize=(15, 10))
figure.canvas.mpl_connect('button_press_event', click)


@widgets.interact(
        red=S2_BANDS,
        green=S2_BANDS,
        blue=S2_BANDS,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0)
)
def vis_composites(red=S2_BANDS[3],
                   green=S2_BANDS[2],
                   blue=S2_BANDS[1],
                   brightness=1.0):
    band_indices = [S2_BANDS.index(band) for band in [red, green, blue]]
    plt.subplot(2,2,1)
    arr = brightness * np.transpose(data_2017[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))
    plt.axis('off')
    plt.title('2017-06-28')
    plt.subplot(2,2,2)
    arr = brightness * np.transpose(data_2024[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))
    plt.axis('off')
    plt.title('2024-06-06')    

Now, let us use our trained U-net model from above to predict these changes. To do so, we will:
1. Obtain pixel-wise predictions for each timestep.
2. Create a difference map by subtracting predicted classes (forest/non-forest) from each other.


💡 This is a very crude way of performing change detection. It may not always work, for example
because the two acquisitions differ in their characteristics ("domain shift", see the final section
of the previous practical). Many alternative approaches have been proposed, including:
* Unsupervised change detection: comparing scenes without any labels.
* Models that ingest both acquisitions at once and predict changes directly.
* _etc._


We now have just one obstacle: our two scenes are much larger than the $512\times512$ tiles our
U-net was trained on. The standard solution is therefore to split the full scenes into **patches**
(or tiles, windows) of the correct size, predict each patch individually (or in batches), and stitch
the predictions back together:

<img src="https://raw.githubusercontent.com/obss/sahi/main/resources/sliced_inference.gif" />

[image source](https://github.com/obss/sahi)

To do so, we have to perform the following steps:
1. Create target tensor/array of same size as full scene to store model predictions in
2. Create list of positions in East/North (x/y) direction that denote the top-left corner of each
   patch.
3. For each coordinate in x/y: extract patch, obtain prediction, "burn" prediction into output
   tensor
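The positions from step 2 can be generated in different ways. `predict_composite` below uses
non-overlapping windows and zero-pads the ones at the scene border. As an alternative, and in the
spirit of the animation above, the hypothetical helper sketched here creates windows that may
overlap and shifts the last window of each row/column so that it ends exactly at the scene border,
which avoids padding whenever the scene is at least one tile large.

```python
def window_starts(length, tile_size, stride):
    """Start offsets along one axis; the last window is shifted to end exactly at the border."""
    starts = list(range(0, max(length - tile_size, 0) + 1, stride))
    if starts[-1] + tile_size < length:
        starts.append(length - tile_size)
    return starts


def tile_windows(height, width, tile_size, overlap=0):
    """Pixel bounds (y0, y1, x0, x1) of windows that together cover a height x width scene."""
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError('overlap must be smaller than tile_size')
    return [(y0, min(y0 + tile_size, height), x0, min(x0 + tile_size, width))
            for y0 in window_starts(height, tile_size, stride)
            for x0 in window_starts(width, tile_size, stride)]


# usage: 1000 x 700 pixel scene, 512 x 512 windows, 64 pixels overlap
toy_windows = tile_windows(1000, 700, tile_size=512, overlap=64)
```

Where windows overlap, one option is to average the softmax scores of all windows covering a pixel;
another is to keep only the central part of each window, since predictions near window borders tend
to be less reliable. Scenes smaller than one tile still produce a smaller window that would need
padding.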


Finally, remember that our U-net has only been trained on blue, green, red, and NIR bands, while our
scenes contain twelve of them. However, since we know exactly which band is where, we can just take
a subset of the required bands in correct order.

Let's do all of that in one go below:

In [None]:
def predict_composite(composite,
                      model,
                      tile_size):
    model = model.eval().to(DEVICE)

    # normalise satellite data for model by z-scoring over entire composite
    data_mean = np.nanmean(composite.reshape((composite.shape[0], -1)), 1)[:,np.newaxis,np.newaxis]
    data_std = np.nanstd(composite.reshape((composite.shape[0], -1)), 1)[:,np.newaxis,np.newaxis]

    data = (np.copy(composite) - data_mean) / data_std

    # create positions in East/North (x/y) from 0 to shape (width/height) of data in tile_size steps
    pos_x = np.arange(0, data.shape[2], tile_size[0])
    pos_y = np.arange(0, data.shape[1], tile_size[1])

    # prepare arrays to store predictions in
    conf_comp = np.zeros(data.shape[1:])
    y_hat_comp = np.zeros(data.shape[1:], dtype=int)

    # iterate over positions
    for loc_x in pos_x:
        for loc_y in pos_y:
            # determine patch size: we may overshoot the data boundaries at the end, so we take
            # the minimum of either position + tile size or else width/height of data array
            end_x = min(loc_x+tile_size[0], data.shape[2])
            end_y = min(loc_y+tile_size[1], data.shape[1])

            # extract patch: all bands, y position to end in height, x position to end in width
            patch = data[:,loc_y:end_y,loc_x:end_x]

            # at the scene borders, some of our patches might be smaller than the target size;
            # to prevent this from happening, we will pad them with zeros.
            # This is important so that the input to our U-net always is the same size.
            patch_shape = patch.shape
            patch = np.pad(patch, [(0,0),                               # no padding in band dim
                                   (0,tile_size[1]-patch_shape[1]),     # remainder in height
                                   (0,tile_size[0]-patch_shape[2])],    # remainder in width
                                   mode='constant')

            # prepare patch as usual: convert to Tensor, 32-bit float, apply transforms,
            # add leading batch dimension and move to device
            patch = torch.from_numpy(patch).float()
            patch = transforms(patch)
            patch = patch.unsqueeze(0).to(DEVICE)

            # predict
            with torch.no_grad():
                pred = model(patch)
                conf, y_hat = pred.softmax(dim=1).max(1)

            # throw away padded zeros again (if there are any): take only the subset of the
            # actual patch shape as originally noted pre-zero padding
            conf = conf[:,:patch_shape[1],:patch_shape[2]]
            y_hat = y_hat[:,:patch_shape[1],:patch_shape[2]]

            # store in target arrays at correct position
            conf_comp[loc_y:end_y,loc_x:end_x] = conf.squeeze().cpu().numpy()
            y_hat_comp[loc_y:end_y,loc_x:end_x] = y_hat.squeeze().cpu().numpy()
    return conf_comp, y_hat_comp


# we only need four bands out of the twelve for our deep learning model
composite_bands = [1,2,3,7]     # Blue, Green, Red, NIR

# predict two composites
conf_2017, y_hat_2017 = predict_composite(data_2017[composite_bands,...],
                                          model,
                                          tile_size=[512, 512])

conf_2024, y_hat_2024 = predict_composite(data_2024[composite_bands,...],
                                          model,
                                          tile_size=[512, 512])

Visualised side-by-side:

In [None]:
%matplotlib widget

plt.close('all')


figure = plt.figure(figsize=(15, 10))


@widgets.interact(
        red=S2_BANDS,
        green=S2_BANDS,
        blue=S2_BANDS,
        brightness=widgets.FloatSlider(value=1.0, min=0.1, max=5.0)
)
def vis_composites(red=S2_BANDS[3],
                   green=S2_BANDS[2],
                   blue=S2_BANDS[1],
                   brightness=1.0):
    # clear figure first, so that colour bars don't stack up when widget values change
    plt.clf()
    band_indices = [S2_BANDS.index(band) for band in [red, green, blue]]
    plt.subplot(2,2,1)
    arr = brightness * np.transpose(data_2017[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))
    plt.axis('off')
    plt.title('2017-06-28')
    plt.subplot(2,2,2)
    arr = brightness * np.transpose(data_2024[band_indices,...], (1,2,0))
    plt.imshow(np.clip(arr, 0, 1))
    plt.axis('off')
    plt.title('2024-06-06')
    plt.subplot(2,2,3)
    plt.imshow(y_hat_2017, cmap=cmap)
    cbar = plt.colorbar(ticks=np.arange(len(LABEL_CLASSES)))
    cbar.ax.set_yticklabels(LABEL_CLASSES)
    plt.title('Prediction 2017-06-28')
    plt.subplot(2,2,4)
    plt.imshow(y_hat_2024, cmap=cmap)
    cbar = plt.colorbar(ticks=np.arange(len(LABEL_CLASSES)))
    cbar.ax.set_yticklabels(LABEL_CLASSES)
    plt.title('Prediction 2024-06-06')
    

If your U-net has been trained long enough, that should look pretty good! Now for the final part:
the difference map.

For starters, we could just subtract the 2017 prediction from the 2024 prediction. Remember that we
assigned value 0 as "non-forest" and value 1 as "forest". That difference map would then contain
three possible values:
* -1: forest loss (0 in 2024, 1 in 2017)
* 0: no change
* 1: forest gain (1 in 2024, 0 in 2017)

However, we can go one step further: "no change" could mean either that there has been forest at
both timestamps, or that neither timestamp shows forest. Thus, we can move "forest gain" one value up
to 2 and introduce two new values:
* 0: no change (no forest)
* 1: no change (forest)

We'll do all of that below and then visualise the result with a custom colour map.

In [None]:
# calculate difference in prediction
pred_diff = y_hat_2024 - y_hat_2017                 # difference: -1 = forest loss; 0 = no change; 1 = forest gain

# improve information content w.r.t. no change
pred_diff[pred_diff > 0] = 2                        # new value: 2 = forest gain
no_change = pred_diff == 0                          # find pixels with zero-difference (no change)
pred_diff[no_change] = y_hat_2024[no_change]        # new "no change" scenario: 0 = no forest; 1 = forest

# create custom colour map
cmap_diff = matplotlib.colors.ListedColormap([
    [1.0, 0.7, 0.4],        # forest loss
    [0.6, 0.6, 0.6],        # no change (no forest)
    [0.2, 0.6, 0.1],        # no change (forest)
    [0.1, 0.9, 0.9]         # forest gain
])


plt.close('all')

plt.figure(figsize=(10,6))
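# fixed value range: each of the four values -1..2 gets its own colour, centred on its tick,
# even if a class happens to be absent from the map
plt.imshow(pred_diff, cmap=cmap_diff, vmin=-1.5, vmax=2.5)
cbar = plt.colorbar()
cbar.set_ticks(ticks=[-1, 0, 1, 2],
               labels=['loss', 'no change (no forest)', 'no change (forest)', 'gain'])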
plt.title('Predicted change 2017-2024')
plt.show()

And, just like that, we have computed a change map for forest cover.
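The same four-class map can also be built in a single step by treating each (2017, 2024) pair of
predictions as a transition code and looking up its change class. The sketch below uses a
hypothetical helper `transition_map` and the same values as above; its advantage is that it extends
naturally to more than two land cover classes (code = number of classes × before + after).

```python
import numpy as np


def transition_map(y_hat_before, y_hat_after):
    """Four-class change map from two binary forest maps (0 = no forest, 1 = forest).

    Values: -1 = forest loss, 0 = no change (no forest), 1 = no change (forest), 2 = forest gain
    """
    # 2 * before + after enumerates the (before, after) pairs: 0 = (0,0), 1 = (0,1), 2 = (1,0), 3 = (1,1)
    code = 2 * np.asarray(y_hat_before, dtype=int) + np.asarray(y_hat_after, dtype=int)
    # look-up table from transition code to change class
    lut = np.array([0, 2, -1, 1])
    return lut[code]


# tiny example: loss, no change (no forest), no change (forest), gain
print(transition_map(np.array([1, 0, 1, 0]), np.array([0, 0, 1, 1])))   # [-1  0  1  2]
```

Applied to our scenes, `transition_map(y_hat_2017, y_hat_2024)` should reproduce `pred_diff`.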

Finally, we can calculate the total forest area gained and lost.

🖌️ Implement this below. Remember that Sentinel-2 has a resolution of 10 m, _i.e._, each pixel
covers 10 m × 10 m = 100 m², so you should be able to calculate the total area predicted as forest
gain and forest loss in km², as well as in percentage of all pixels. Report all four quantities to
two decimal places.

In [None]:
forest_gain_km2 = ...
forest_gain_percentage = ...
forest_loss_km2 = ...
forest_loss_percentage = ...

Do the values you obtain sound reasonable? Also look at where (and why) difference maps predict
forest gain/loss. Depending on the model you have used, you might for example see those diagonal
stripes again erroneously influencing the result. Compare with your neighbours to see what kind of
spread you get, and think about what could have influenced this variation.


**Bonus: geospatial rasters**

We can save the change map now as a GeoTIFF that you can open in QGIS, ArcGIS Pro, _etc._:

In [None]:
# open one of the GeoTIFFs in read mode to get all the metadata (we chose this file at the start)
with rasterio.open(file_path, 'r') as f_band:
    meta = f_band.meta
    # let's print some of the metadata
    print(f'Data type: {meta["dtype"]}')
    print(f'Size: {meta["width"]} x {meta["height"]}')
    print(f'Count: {meta["count"]}')            # number of bands
    print(f'CRS: {meta["crs"]}')                # CRS: coordinate reference system (https://docs.qgis.org/3.34/en/docs/gentle_gis_introduction/coordinate_reference_systems.html)
    print(f'Transform:\n{meta["transform"]}')   # info: https://pygis.io/docs/d_affine.html

# we need to update some of the metadata for our change map
meta.update({
    'count': 1,         # number of bands: we only have one (the change map)
    'dtype': 'uint8'    # data type: 8-bit unsigned integer is enough for the four values we have
})

# save change detection map
with rasterio.open('change_map.tiff', 'w', **meta) as f_out:
    f_out.write(pred_diff[np.newaxis,...].astype(np.uint8) + 1)     # +1 to make values start at zero (required for unsigned integer data type & colour map)
    f_out.write_colormap(
        1,                                                          # define colour map for the first (and only) band
        dict([idx, [int(255*val) for val in col] + [255]]
             for idx, col in enumerate(cmap_diff.colors))           # take all colours in order, convert values to 0-255 int, add extra 255 for alpha value (opacity)
    )
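Why add 1 _before_ converting to `uint8`? Unsigned integers cannot represent negative values, and casting a negative number does not fail loudly: for integer arrays NumPy silently wraps it around (-1 becomes 255), and for floating-point arrays the result is not well defined (it may differ between platforms). The small sketch below illustrates this with a made-up array. It assumes the change map uses the values -1 to 2, which is what the "+1 to make values start at zero" comment implies; adjust if your `pred_diff` is coded differently.

```python
import numpy as np

# Hypothetical change map with values -1, 0, 1, 2 (an assumption for illustration;
# the actual coding of pred_diff is defined earlier in the notebook)
pred_diff = np.array([[-1, 0], [1, 2]])

# Shift first, then cast: every value is already non-negative when converted
shifted_then_cast = (pred_diff + 1).astype(np.uint8)

# Cast without shifting: -1 cannot be stored in uint8 and wraps around to 255
cast_only = pred_diff.astype(np.uint8)

print(shifted_then_cast.tolist())
print(cast_only.tolist())
```

Shifting first keeps every value inside the valid `uint8` range (0 to 255), so the cast is exact and does not rely on wrap-around behaviour.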

💡 The metadata printed above is part of the (Geo)TIFF format. It defines not only basic image
properties, such as size (width, height), count (number of bands/colour channels) and data type,
but also _geospatial_ properties. There are two main ones to know in this context:
1. The [Coordinate Reference System
   (CRS)](https://docs.qgis.org/3.34/en/docs/gentle_gis_introduction/coordinate_reference_systems.html):
   this defines how coordinates relate to locations on Earth, including the _ellipsoid_ that serves
   as a "reference" (a zero-point, if you will) for the image in geographic space. Earth is not a
   perfect sphere: it is flattened at the poles and its surface is irregular (valleys, mountains);
   therefore, an ellipsoid that approximates its shape is used. Many such ellipsoids have been
   defined, such as [WGS84](https://epsg.io/4326) (a global one, used here), as well as others that
   are more accurate for particular regions of the world. Most countries define several of their own
   ([references for the
   UK](https://www.gov.uk/guidance/uk-geospatial-data-standards-coordinate-reference-systems)).
2. Once the CRS is defined, we need to know where the image lies in relation to it. For regularly
   gridded datasets, this can be done with an [affine transformation
   matrix](https://pygis.io/docs/d_affine.html) (as printed above). Affine transformations combine
   linear operations (scaling, rotation, shearing/skewing) with translation. That is accurate enough
   for a satellite that always follows a steady orbit around the Earth. As soon as the acquisition
   geometry becomes irregular, more advanced geocoding methods are required (think about an
   aeroplane or drone being caught by a gust of wind during data acquisition).
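To make point 2 a bit more concrete: the six numbers of an affine transform (in rasterio's `Affine(a, b, c, d, e, f)` order, `x = a*col + b*row + c` and `y = d*col + e*row + f`) can be written as a 3x3 matrix acting on _homogeneous_ pixel coordinates `(col, row, 1)`. The extra 1 is what lets the translation (`c`, `f`) become part of a single matrix product. The sketch below uses plain NumPy and made-up transform values (a hypothetical north-up grid with 0.0001° pixels), not the ones from our GeoTIFF.

```python
import numpy as np

# Hypothetical north-up transform in WGS84 degrees (made-up values):
# a = pixel width, e = -pixel height (row index grows downwards, latitude upwards),
# c, f = lon/lat of the upper-left corner of the upper-left pixel
a, b, c = 0.0001, 0.0, -55.0
d, e, f = 0.0, -0.0001, -9.0

# The six coefficients as a 3x3 matrix acting on homogeneous coordinates (col, row, 1)
T = np.array([[a, b, c],
              [d, e, f],
              [0.0, 0.0, 1.0]])


def pixel_to_geo(col, row):
    """Pixel (col, row) -> geospatial (x, y); (lon, lat) for a geographic CRS."""
    x, y, _ = T @ np.array([col, row, 1.0])
    return x, y


def geo_to_pixel(x, y):
    """Geospatial (x, y) -> continuous pixel (col, row) via the inverse matrix."""
    col, row, _ = np.linalg.inv(T) @ np.array([x, y, 1.0])
    return col, row


lon, lat = pixel_to_geo(400, 565)                  # upper-left corner of that pixel
lon_c, lat_c = pixel_to_geo(400 + 0.5, 565 + 0.5)  # centre of that pixel
print(lon, lat)
print(lon_c, lat_c)
print(geo_to_pixel(lon, lat))                      # back to (400, 565), up to rounding
```

Note that integer pixel indices map to the upper-left _corner_ of a pixel; adding 0.5 to both indices gives its centre. Multiplying `meta['transform']` by a coordinate pair follows the same corner convention.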

If you are interested: the above information allows us to easily convert between pixel and
geospatial coordinates! For example:

In [None]:
# pixel coordinate we want to translate to geospatial one
pixel_coord = (400, 565)        # (column, row) pixel indices, zero-based; note the (x, y) order

# convert to geospatial coordinate: multiply by the affine transform
# (integer pixel indices give the upper-left corner of that pixel)
spatial_coord = meta['transform'] * pixel_coord

print(f'Geospatial coordinates (lat, lon): {spatial_coord[1]}, {spatial_coord[0]}')

# convert back to pixel coordinates: multiply with inverse transform
pixel_coord_back = ~meta['transform'] * spatial_coord

print(f'Backtransformed to pixel coordinates: {pixel_coord_back[0]}, {pixel_coord_back[1]}')

Open a map viewer (_e.g._, [Google Maps](https://www.google.co.uk/maps)) and copy-paste the
geospatial coordinate pair into the search function. Check whether the location matches what you
see at that pixel in the satellite image.

💡 These coordinates are in degrees of latitude/longitude because that is the unit defined by our
CRS (WGS84). Other CRSs, such as projected ones (_e.g._, UTM), express coordinates in metres or other
units.
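Why does the unit matter? Degrees are angles, not distances: one degree of latitude is always roughly 111 km, but one degree of longitude shrinks towards the poles. A pixel that is square in degrees is therefore neither square nor constant in size on the ground, which matters when you want to turn pixel counts into areas (_e.g._, hectares of deforestation). The sketch below approximates this with a spherical Earth of mean radius 6,371 km; that is an assumption for illustration, since WGS84 actually uses an ellipsoid and real values differ slightly.

```python
import math

# Mean Earth radius in metres (spherical approximation, not the WGS84 ellipsoid)
EARTH_RADIUS_M = 6_371_000.0


def degree_lengths_m(lat_deg):
    """Approximate ground length (m) of one degree of latitude and one degree of
    longitude at the given latitude, assuming a spherical Earth."""
    one_deg_lat = math.pi * EARTH_RADIUS_M / 180.0                    # arc length of 1 degree
    one_deg_lon = one_deg_lat * math.cos(math.radians(lat_deg))       # parallels shrink with cos(lat)
    return one_deg_lat, one_deg_lon


for lat in (0.0, -10.0, 51.5, 80.0):
    dlat, dlon = degree_lengths_m(lat)
    print(f'lat {lat:6.1f}: 1 deg lat ~ {dlat / 1000:.1f} km, 1 deg lon ~ {dlon / 1000:.1f} km')
```

For accurate area calculations, rasters are therefore usually reprojected into a projected CRS (for example a UTM zone or an equal-area projection), whose units are metres.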

💡 Now you know how to sample pixel values from geospatial rasters based on, _e.g._, GPS
coordinates. 😊


You could now go ahead and calculate change maps for other timestamps of the same area if you
wanted. Of course, to go back before 2015 (when the first Sentinel-2 satellite was launched), you
would have to use a different sensor (Landsat, for example, goes back to the 1970s). Some approaches
even allow you to do **data fusion**, such as combining scenes from different satellites or even
different modalities (for example Synthetic Aperture Radar/SAR, which can also be used to map
forests and can see through clouds). This and many, many more things are possible with remote
sensing!

---

## 5. Summary and Outlook

We have taken a look at:
1. Optical, multispectral satellite data (from Sentinel-2)
2. Spectral indices (in this case, NDVI)
3. Semantic segmentation with deep learning (using U-net)
4. Mapping of forest cover in satellite imagery
5. Change detection (mapping & quantifying deforestation across time)


As you have also seen, machine learning (and especially deep learning) isn't always the optimal
solution; sometimes it's simply overkill. Nonetheless, there are countless cases, including in
remote sensing, where a more sophisticated model is needed. For example:
* Detecting objects of very complex appearance or very small size (_e.g._, cars): here, spectral
  reflectance signatures alone are no longer sufficient to distinguish them from their
  surroundings; we need high spatial resolution and texture information (and a model that can
  exploit it, such as a U-net).
* Mapping vegetation phenology throughout a year: such cases require explicit time series of
  satellite products (we cannot reliably measure biomass in winter, for example).


### Further resources

**Forest Monitoring with Remote Sensing**
* Nguyen, T.A., Rußwurm, M., Lenczner, G. and Tuia, D., 2024. Multi-temporal forest monitoring in
  the Swiss Alps with knowledge-guided deep learning. Remote Sensing of Environment, 305, p.114109.
  [https://www.sciencedirect.com/science/article/pii/S0034425724001202](https://www.sciencedirect.com/science/article/pii/S0034425724001202).
* Waldeland, A.U., Trier, Ø.D. and Salberg, A.B., 2022. Forest mapping and monitoring in Africa
  using Sentinel-2 data and deep learning. International Journal of Applied Earth Observation and
  Geoinformation, 111, p.102840.
  [https://doi.org/10.1016/j.jag.2022.102840](https://doi.org/10.1016/j.jag.2022.102840).
  _After today's exercise, this work should sound very intuitive to you._


**Deep Learning and Remote Sensing in general**

These are some useful publications:
* Camps-Valls, G., Tuia, D., Zhu, X.X. and Reichstein, M. eds., 2021. Deep learning for the Earth
  Sciences: A comprehensive approach to remote sensing, climate science and geosciences. John Wiley
  & Sons. [Google books
  preview](https://books.google.ch/books?hl=en&lr=&id=e2c4EAAAQBAJ&oi=fnd&pg=PR16&dq=deep+learning+for+remote+sensing+camps-valls&ots=gHjF752TBl&sig=YYRWkkXFOuYg2GYALeeibJbEuNc#v=onepage&q=deep%20learning%20for%20remote%20sensing%20camps-valls&f=false).
* Zhu, X.X., Tuia, D., Mou, L., Xia, G.S., Zhang, L., Xu, F. and Fraundorfer, F., 2017. Deep
  learning in remote sensing: A comprehensive review and list of resources. IEEE Geoscience and
  Remote Sensing Magazine, 5(4), pp.8-36.
  [https://doi.org/10.1109/MGRS.2017.2762307](https://doi.org/10.1109/MGRS.2017.2762307).

There is also an active community that hosts conferences and workshops in the field:
* [EarthVision](https://www.grss-ieee.org/events/earthvision-2025/): yearly workshop at the CVPR
  conference
* [ML4RS](https://ml-for-rs.github.io/iclr2025/): yearly workshop at ICLR


Ultimately, there are things we cannot do with remote sensing, as is the case with every modality
we have seen in this course. However, when used correctly and (where necessary) in conjunction with
other data, remote sensing can be an extremely powerful tool for all sorts of ecological analyses.
I invite you to think about research questions you would like to answer, and how geospatial
analyses and remote sensing can contribute. The sky's the limit (no pun intended)! 😃