# Demo: Calculating band indices

This notebook demonstrates how to calculate band indices for multispectral data using the `eo-insights` package.
The notebook was inspired by work done by FrontierSI for .id on detecting residential development using satellite imagery.

The notebook demonstrates:

1. Importing the relevant module and function
1. Loading relevant bands from Digital Earth Australia's Sentinel-2 product
1. Calculating key indices for development detection (NDWI, BSI, MSAVI, and BAEI)
1. Calculating temporal statistics for the indices (median and maximum)
1. Displaying the calculated indices

## Set up

The following cell should be uncommented and run if you installed the package in editable mode and are actively developing and testing modules.
Otherwise, it can be left commented.

In [None]:
# %load_ext autoreload
# %autoreload 2

### Enable logging

This will allow you to see info and warning messages from the package.

In [None]:
import logging
import sys

logging.basicConfig(
    format="%(asctime)s | %(levelname)s : %(message)s",
    level=logging.INFO,
    stream=sys.stdout,
)
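The cell above sets INFO as the level for every logger. If the package output becomes too verbose, one option is to raise the level for the package's own loggers only. This minimal sketch assumes eo-insights creates its loggers with the common `logging.getLogger(__name__)` pattern, so they sit under an `eo_insights` parent logger. That naming is an assumption, not something this notebook confirms.

```python
import logging

# Assumption: eo-insights modules name their loggers after the module path,
# so they all inherit from the "eo_insights" parent logger.
package_logger = logging.getLogger("eo_insights")

# Show only warnings and errors from the package; other loggers keep
# the level configured by logging.basicConfig.
package_logger.setLevel(logging.WARNING)
```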

### Import the relevant packages

This demo uses the Digital Earth Australia Sentinel-2 product, so it imports `de_australia_stac_config`.
For more information on available configurations, see [configuration_demo.ipynb](configuration_demo.ipynb).

In [None]:
import typing

import matplotlib.pyplot as plt
import numpy as np

from eo_insights.band_indices import calculate_indices, IndexName
from eo_insights.raster_base import RasterBase, QueryParams, LoadParams
from eo_insights.stac_configuration import de_australia_stac_config

You can check the available collections using the `.list_collections()` method.

In [None]:
de_australia_stac_config.list_collections()

## Load data
### Set up query and load parameters

Date range and bounding box are set as part of the `QueryParams` class.
CRS, resolution, and desired bands are set as part of the `LoadParams` class.

In [None]:
# set up AOI and time period
query_params = QueryParams(
    # (min longitude, min latitude, max longitude, max latitude)
    bbox=(
        144.95,
        -37.47,
        144.97,
        -37.45,
    ),
    start_date="2020-01-01",
    end_date="2020-06-30",
)
# specify bands for calculating NDWI, BSI, MSAVI, and BAEI
load_params = LoadParams(
    crs="EPSG:3577",
    resolution=10,
    bands=("blue", "green", "red", "nir", "swir_1", "swir_2", "fmask"),
)
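Most STAC-based tools expect a bounding box ordered as (min longitude, min latitude, max longitude, max latitude), and a box with its minimum and maximum swapped may return no data or raise an error. The hypothetical helper `check_bbox` below is a minimal sketch for catching that before querying, assuming `QueryParams` follows the same STAC ordering in decimal degrees.

```python
def check_bbox(
    bbox: tuple[float, float, float, float],
) -> tuple[float, float, float, float]:
    """Validate a (min_lon, min_lat, max_lon, max_lat) bounding box.

    Assumption: QueryParams uses STAC-style (west, south, east, north)
    ordering in decimal degrees.
    """
    min_lon, min_lat, max_lon, max_lat = bbox
    if not (-180 <= min_lon < max_lon <= 180):
        raise ValueError(f"longitudes out of order or range: {min_lon}, {max_lon}")
    if not (-90 <= min_lat < max_lat <= 90):
        raise ValueError(f"latitudes out of order or range: {min_lat}, {max_lat}")
    return bbox


# A correctly ordered box passes and is returned unchanged.
checked = check_bbox((144.95, -37.47, 144.97, -37.45))
```

Swapping the two longitudes, as in `(144.97, -37.47, 144.95, -37.45)`, raises a `ValueError`.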

### Load using RasterBase class

In [None]:
# query data from DE Aus
s2_raster = RasterBase.from_stac_query(
    config=de_australia_stac_config,
    collections=["ga_s2am_ard_3", "ga_s2bm_ard_3"],
    query_params=query_params,
    load_params=load_params,
)

#### View the loaded bands

In [None]:
s2_raster.data

#### View the loaded masks

In [None]:
s2_raster.masks

### Apply masks before calculating indices

This step applies the default masking parameters specified in the relevant collection configuration.
Supplying `nodata=np.nan` overrides the default nodata value, so masked pixels become NaN and are skipped when the temporal statistics are calculated later.

In [None]:
s2_raster.apply_mask("fmask", nodata=np.nan)
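The actual masking rules come from the collection configuration and are not shown in this notebook. As a rough illustration only, the sketch below keeps pixels whose Fmask class is valid or water and sets everything else to NaN. It assumes DEA's Fmask encoding (0 nodata, 1 valid, 2 cloud, 3 shadow, 4 snow, 5 water) and a choice of kept classes that may not match the package defaults. The `mask_with_fmask` helper is hypothetical. Note the cast to float, because integer arrays cannot hold NaN.

```python
import numpy as np

# Assumed DEA Fmask class values.
FMASK_NODATA, FMASK_VALID, FMASK_CLOUD, FMASK_SHADOW, FMASK_SNOW, FMASK_WATER = range(6)


def mask_with_fmask(band: np.ndarray, fmask: np.ndarray,
                    keep=(FMASK_VALID, FMASK_WATER)) -> np.ndarray:
    """Return a float copy of `band` with pixels whose class is not in `keep` set to NaN."""
    good = np.isin(fmask, keep)
    # Cast first: NaN only exists for floating-point dtypes.
    return np.where(good, band.astype("float64"), np.nan)


red = np.array([[500, 600], [700, 800]], dtype="int16")
fmask = np.array([[1, 2], [5, 0]], dtype="uint8")  # valid, cloud / water, nodata
masked = mask_with_fmask(red, fmask)  # [[500., nan], [700., nan]]
```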

## Calculate band indices required for development detection

This section demonstrates how to calculate the MSAVI, BAEI, BSI and NDWI indices using the `calculate_indices` function from the `eo_insights.band_indices` module.

Band indices are passed in as lowercase strings.
You can see the available indices by running the next cell.

In [None]:
available_indices = typing.get_args(IndexName)
available_indices

In [None]:
s2_raster.data = calculate_indices(s2_raster.data, ["ndwi", "bsi", "msavi", "baei"])
s2_raster.data
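For reference, the sketch below writes out commonly published definitions of the four indices as plain NumPy arithmetic. It is not the `eo_insights.band_indices` implementation, which may differ in detail (for example in constants or reflectance scaling). The hypothetical `index_sketch` function assumes surface reflectance in the 0-1 range. NDWI and BSI are ratios and do not depend on scaling, but MSAVI and BAEI include constants (the `+ 1` and `+ 0.3` terms) that only make sense for 0-1 reflectance.

```python
import numpy as np


def index_sketch(blue, green, red, nir, swir_1):
    """Commonly published index definitions; works on scalars or NumPy arrays."""
    # NDWI (McFeeters): open water is positive, vegetation negative.
    ndwi = (green - nir) / (green + nir)
    # Bare Soil Index.
    bsi = ((swir_1 + red) - (nir + blue)) / ((swir_1 + red) + (nir + blue))
    # Modified Soil-Adjusted Vegetation Index.
    msavi = (2 * nir + 1 - np.sqrt((2 * nir + 1) ** 2 - 8 * (nir - red))) / 2
    # Built-up Area Extraction Index, with the published constant of 0.3.
    baei = (red + 0.3) / (green + swir_1)
    return {"ndwi": ndwi, "bsi": bsi, "msavi": msavi, "baei": baei}


# A vegetated pixel with made-up reflectances.
pixel = index_sketch(blue=0.04, green=0.08, red=0.05, nir=0.40, swir_1=0.20)
# roughly: ndwi -0.667, bsi -0.275, msavi 0.568, baei 1.25
```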

Once these indices are calculated, they are reduced over time: the median is calculated for MSAVI, BAEI and BSI, and the maximum is calculated for NDWI.

In [None]:
# calculate the per-pixel median over time
ds_medians = (
    s2_raster.data[["msavi", "baei", "bsi"]]
    .median(dim="time")
    .assign_coords(time=query_params.end_date)
).compute()

ds_medians

In [None]:
# calculate the per-pixel maximum NDWI over time
ds_ndwi_max = (
    s2_raster.data["ndwi"].max(dim="time").assign_coords(time=query_params.end_date)
)

ds_ndwi_max
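Both reductions operate per pixel along the `time` dimension. For floating-point data, xarray's `.median()` and `.max()` skip NaN values by default, so pixels masked to NaN do not contribute. The sketch below shows the equivalent NumPy reductions on a toy `(time, y, x)` array with made-up values.

```python
import numpy as np

# Toy stack of one index: 3 time steps, 1 row, 2 columns.
# NaN marks an observation removed by the cloud mask.
stack = np.array([
    [[0.1, np.nan]],
    [[0.3, 0.2]],
    [[0.2, 0.4]],
])

median = np.nanmedian(stack, axis=0)  # per-pixel median over time, ignoring NaN
maximum = np.nanmax(stack, axis=0)    # per-pixel maximum over time, ignoring NaN
# median -> [[0.2, 0.3]], maximum -> [[0.3, 0.4]]
```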

In [None]:
# combine the medians and the NDWI maximum into a single xarray Dataset
ds_combined = ds_medians.merge(ds_ndwi_max)

In [None]:
# keep pixels where at least one index is non-zero (drops pixels that are zero in every index)
ds_masked_nonzero = ds_combined.where(
    (ds_combined["msavi"] != 0)
    | (ds_combined["baei"] != 0)
    | (ds_combined["bsi"] != 0)
    | (ds_combined["ndwi"] != 0)
).compute()

ds_masked_nonzero
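The `where` condition joins the per-index tests with `|` (logical OR), so a pixel is dropped only when every index is exactly zero; a pixel where just one index is zero is kept. A toy NumPy version of the same logic, with made-up values and two indices instead of four:

```python
import numpy as np

# Toy 2 x 2 rasters for two indices.
msavi = np.array([[0.0, 0.5], [0.0, 0.3]])
ndwi = np.array([[0.0, 0.0], [-0.2, 0.1]])

# Keep a pixel if at least one index is non-zero; otherwise set it to NaN.
keep = (msavi != 0) | (ndwi != 0)
msavi_masked = np.where(keep, msavi, np.nan)
ndwi_masked = np.where(keep, ndwi, np.nan)
# Only the top-left pixel (zero in both indices) becomes NaN;
# the bottom-left pixel keeps msavi == 0.0 because its ndwi is non-zero.
```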

### Plot indices

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# one panel per index, filled row by row: NDWI, BSI / MSAVI, BAEI
for ax, index_name in zip(axs.flat, ["ndwi", "bsi", "msavi", "baei"]):
    ds_masked_nonzero[index_name].plot.imshow(ax=ax)
    ax.set_title(index_name.upper())
    ax.set_aspect("equal", adjustable="box")

plt.tight_layout()
plt.show()