# Demo: Machine Learning for native species habitat mapping

This notebook demonstrates how to load a number of Earth observation datasets using the `eo-insights` package and train a machine learning model with them. 
The purpose is to demonstrate how the `eo-insights` package can be used to support machine learning workflows.

This notebook has been inspired by work conducted at FrontierSI.
It uses a subset of species occurrence points for the Southern Bell Frog that were extracted from the Atlas of Living Australia.
The subset has been provided in the `data` folder for the purposes of running this demonstration.

## Caveats
At present, the `eo-insights` package focuses on data management. Many of the approaches used in this notebook would be within scope to become a formalised part of the package, supporting machine learning for Earth observation.

The notebook is a demonstration only; the model trained in this notebook should not be used for making predictions.
It has been trained on a small subset of data and has not been fine-tuned.

## Overview

This notebook demonstrates:

1. Loading a GeoJSON file of species occurrence data
1. Querying products from Digital Earth Australia (DEA)
1. Using a subset of bands to run an image segmentation algorithm
1. Calculating zonal statistics for each segment
1. Training a random forest classifier from scikit-learn
1. Predicting habitat across the wider study area
1. Displaying prediction probabilities

## Set up notebook

The following cell should be uncommented and run if you installed the package in editable mode and are actively developing and testing modules. Otherwise, it can be left commented.

In [None]:
# %load_ext autoreload
# %autoreload 2

### Enable logging
This will allow you to see info and warning messages from the package.

In [None]:
import logging
import sys

logging.basicConfig(
    format="%(asctime)s | %(levelname)s : %(message)s",
    level=logging.INFO,
    stream=sys.stdout,
)

## Import configuration and modules

In [None]:
from pathlib import Path

import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from affine import Affine
from rasterstats import zonal_stats
from skimage.segmentation import quickshift, slic
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.preprocessing import StandardScaler

from eo_insights.band_indices import calculate_indices
from eo_insights.raster_base import LoadParams, QueryParams, RasterBase
from eo_insights.spatial import xr_rasterize, xr_vectorize
from eo_insights.stac_configuration import de_australia_stac_config

## Load training data and define study area

In [None]:
CURRENT_DIR = Path.cwd()
DATA_PATH = CURRENT_DIR / "data" / "habitat_mapping"

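# read species occurrence points and reproject to Australian Albers (EPSG:3577)
gdf_fauna = gpd.read_file(
    DATA_PATH / "habitat_mapping_southern_bell_frog.geojson"
).to_crs("EPSG:3577")
# bounding box in longitude/latitude (EPSG:4326) for the STAC queries
bbox = gdf_fauna.to_crs("EPSG:4326").total_bounds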

gdf_fauna.explore(column="year")
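The STAC queries below take this `bbox`. `total_bounds` returns `(minx, miny, maxx, maxy)`; after reprojecting to EPSG:4326 with longitude as x, that reads as (west, south, east, north), the order the STAC API specification uses for bbox filters. A plain-Python sketch of the same calculation on hypothetical coordinates, assuming `QueryParams` passes the box through to a STAC search unchanged:

```python
def bounds_from_points(points):
    """Return (minx, miny, maxx, maxy) for an iterable of (x, y) pairs."""
    points = list(points)
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    return (min(xs), min(ys), max(xs), max(ys))


# Hypothetical (longitude, latitude) pairs, for illustration only
sample_points = [(146.0, -35.0), (145.5, -35.5), (146.5, -34.5)]
west, south, east, north = bounds_from_points(sample_points)
print(west, south, east, north)  # 145.5 -35.5 146.5 -34.5
```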

## Query satellite products

### Get DEA configuration and list collections

In [None]:
config = de_australia_stac_config
config.list_collections()

### Load DEA products
* Landsat 8 & 9 yearly geomedian for the most recent year (2023)

In [None]:
query_params_ls = QueryParams(
    bbox=bbox,
    start_date="2023-01-01",
    end_date="2023-12-31",
)

# Landsat 8 & 9 yearly geomedian: load only geomedian spectral bands for now
params_ls_8_9_gm = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("blue", "green", "red", "nir", "swir_1", "swir_2"),
)
raster_ls_8_9_gm = RasterBase.from_stac_query(
    config=config,
    collections=["ga_ls8cls9c_gm_cyear_3"],
    query_params=query_params_ls,
    load_params=params_ls_8_9_gm,
)

# calculate indices
ds_ls_8_9_gm = calculate_indices(raster_ls_8_9_gm.data, ["ndvi", "ndwi"])

ds_ls_8_9_gm = ds_ls_8_9_gm.isel(time=0)
ds_ls_8_9_gm
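`calculate_indices` adds the index bands to the dataset. Assuming the package uses the common definitions, NDVI is (NIR - Red) / (NIR + Red) and NDWI (McFeeters) is (Green - NIR) / (Green + NIR); check the `eo_insights.band_indices` source to confirm. A NumPy sketch on hypothetical reflectances:

```python
import numpy as np


def normalised_difference(a, b):
    """Return (a - b) / (a + b), with NaN where a + b is zero."""
    a = np.asarray(a, dtype="float64")
    b = np.asarray(b, dtype="float64")
    total = a + b
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(total != 0, (a - b) / total, np.nan)


# Hypothetical surface reflectances for three pixels
green = np.array([0.08, 0.06, 0.05])
red = np.array([0.05, 0.04, 0.03])
nir = np.array([0.30, 0.05, 0.02])

ndvi = normalised_difference(nir, red)  # high for the vegetated first pixel
ndwi = normalised_difference(green, nir)  # positive for the watery last pixel
```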

* Sentinel-2 Barest Earth
* Calculate the Bare Soil Index (BSI)

In [None]:
query_params_s2_be = QueryParams(
    bbox=bbox,
    start_date="2017-01-01",
    end_date="2018-12-31",
)

# load only bands that are used for BSI calculation
params_s2_be = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("blue", "red", "nir", "swir_1"),
)
raster_s2_be = RasterBase.from_stac_query(
    config=config,
    collections=["s2_barest_earth"],
    query_params=query_params_s2_be,
    load_params=params_s2_be,
)

# calculate BSI
ds_s2_be_bsi = calculate_indices(raster_s2_be.data, ["bsi"])

# dropping original bands
ds_s2_be_bsi = ds_s2_be_bsi[["bsi"]].isel(time=0)
ds_s2_be_bsi
# ds_s2_be_bsi['bsi'].plot.imshow()
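The Bare Soil Index contrasts SWIR 1 and red, which respond strongly to exposed soil, with NIR and blue, which help separate vegetation; that is why only those four bands are loaded. A NumPy sketch assuming the widely used formulation BSI = ((SWIR1 + Red) - (NIR + Blue)) / ((SWIR1 + Red) + (NIR + Blue)); the package's implementation may differ in detail:

```python
import numpy as np


def bare_soil_index(blue, red, nir, swir_1):
    """BSI = ((SWIR1 + Red) - (NIR + Blue)) / ((SWIR1 + Red) + (NIR + Blue))."""
    soil = np.asarray(swir_1, dtype="float64") + np.asarray(red, dtype="float64")
    veg = np.asarray(nir, dtype="float64") + np.asarray(blue, dtype="float64")
    total = soil + veg
    # NaN where the denominator is zero (e.g. no-data pixels)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(total != 0, (soil - veg) / total, np.nan)


# Hypothetical reflectances: a bare-soil pixel, then a vegetated pixel
bsi = bare_soil_index(
    blue=[0.10, 0.03], red=[0.20, 0.03], nir=[0.25, 0.35], swir_1=[0.35, 0.15]
)
```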

* SRTM 1 second DEM version 1.0

In [None]:
query_params_dem = QueryParams(
    bbox=bbox,
    start_date="2014-01-01",
    end_date="2014-12-31",
)

params_dem = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("dem",),  # trailing comma makes this a one-element tuple
)
raster_dem = RasterBase.from_stac_query(
    config=config,
    collections=["ga_srtm_dem1sv1_0"],
    query_params=query_params_dem,
    load_params=params_dem,
)
ds_dem = raster_dem.data.isel(time=0)
ds_dem
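When a product has a single band, note that parentheses alone do not create a tuple in Python: `("dem")` is the string `"dem"`, while `("dem",)` is a one-element tuple. Whether `LoadParams` also accepts a bare string is not confirmed here, so the trailing-comma form is the safe choice and matches the multi-band calls. A quick check:

```python
single = ("dem")
single_tuple = ("dem",)

print(type(single).__name__, type(single_tuple).__name__)  # str tuple

# Anything that iterates over the bands would see characters, not band names
print(list(single))  # ['d', 'e', 'm']
print(list(single_tuple))  # ['dem']
```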

* Multi-scale Topographic Position Index (TPI) layer

In [None]:
query_params_tpi = QueryParams(
    bbox=bbox,
    start_date="2018-01-01",
    end_date="2018-12-31",
)

params_tpi = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("regional", "intermediate", "local"),
)
raster_tpi = RasterBase.from_stac_query(
    config=config,
    collections=["multi_scale_topographic_position"],
    query_params=query_params_tpi,
    load_params=params_tpi,
)
ds_tip = raster_tpi.data.isel(time=0)
ds_tip
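For intuition, a topographic position index is commonly defined as a cell's elevation minus the mean elevation of its neighbourhood: positive values suggest ridges and upper slopes, negative values valleys and depressions, and the window size sets the scale. How the DEA product derives its `regional`, `intermediate`, and `local` bands is not described in this notebook; the NumPy sketch below (the function name is hypothetical) only illustrates the general idea on a toy DEM, using a square window that includes the centre cell:

```python
import numpy as np


def topographic_position_index(dem, radius):
    """Elevation minus the mean of a (2r+1) x (2r+1) window around each cell.

    The window includes the centre cell; edges repeat the border values.
    """
    dem = np.asarray(dem, dtype="float64")
    size = 2 * radius + 1
    padded = np.pad(dem, radius, mode="edge")
    rows, cols = dem.shape
    # Sum every window by adding shifted views of the padded array
    window_sum = np.zeros_like(dem)
    for dy in range(size):
        for dx in range(size):
            window_sum += padded[dy:dy + rows, dx:dx + cols]
    return dem - window_sum / size**2


# Toy DEM: flat ground with a single 9 m peak in the middle
dem = np.zeros((5, 5))
dem[2, 2] = 9.0
tpi = topographic_position_index(dem, radius=1)
```

A larger `radius` averages over more terrain, so the index responds to broader landforms rather than local bumps and hollows.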

* Weathering Intensity layer

In [None]:
query_params_weathering = QueryParams(
    bbox=bbox,
    start_date="2018-01-01",
    end_date="2018-12-31",
)

params_weathering = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("intensity",),  # trailing comma makes this a one-element tuple
)
raster_weathering = RasterBase.from_stac_query(
    config=config,
    collections=["weathering_intensity"],
    query_params=query_params_weathering,
    load_params=params_weathering,
)
ds_weathering = raster_weathering.data.isel(time=0)
ds_weathering

### Stack all product bands

In [None]:
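# compat="override" skips comparing variables that share a name (such as the
# scalar `time` coordinate, which differs between products) and keeps the first
ds_all = xr.merge(
    [ds_ls_8_9_gm, ds_s2_be_bsi, ds_dem, ds_tip, ds_weathering], compat="override"
).compute()
ds_all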

## Object-based classification for species habitat prediction

### Image segmentation

In [None]:
arr_data = np.moveaxis(ds_all[["ndvi", "ndwi", "dem"]].to_array().to_numpy(), 0, -1)

# feature normalisation
rows, columns, n_band = arr_data.shape
arr_data = np.reshape(arr_data, (rows * columns, n_band))
arr_data = StandardScaler().fit_transform(arr_data)

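# Reshape back to (rows, columns, bands) and segment with quickshift;
# SLIC is an alternative (define `compactness` before uncommenting it)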
arr_data = np.reshape(arr_data, (rows, columns, n_band))
# da_segments = slic(arr_data,n_segments=500,compactness=compactness,slic_zero=False)
# da_segments = quickshift(arr_data,ratio=0.8,kernel_size=2,max_dist=10,sigma=0,convert2lab=False)
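da_segments = quickshift(
    arr_data, ratio=1.0, kernel_size=3, max_dist=10, sigma=0, convert2lab=False
)
# Wrap the labels with the coordinates of the input stack; name the dims
# explicitly because Dataset.dims does not guarantee an order.
# int32 avoids overflow if there are more than 32,767 segments.
da_segments = xr.DataArray(
    da_segments, coords=ds_all.coords, dims=("y", "x"), attrs=ds_all.attrs
).astype(np.int32)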
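The normalisation step treats every pixel as a sample and every band as a feature, so NDVI, NDWI and elevation (in metres) contribute on comparable scales to quickshift's distance calculation. A NumPy sketch of the same reshape and per-band z-score that `StandardScaler` performs (population standard deviation), on a hypothetical 2 x 3 image with two bands:

```python
import numpy as np

# Hypothetical image shaped (rows, columns, bands)
image = np.arange(12, dtype="float64").reshape(2, 3, 2)

rows, columns, n_band = image.shape
pixels = image.reshape(rows * columns, n_band)  # one row per pixel
# Per-band z-score: subtract each band's mean and divide by its std (ddof=0)
scaled = (pixels - pixels.mean(axis=0)) / pixels.std(axis=0)
scaled_image = scaled.reshape(rows, columns, n_band)  # back to image layout
```

Unlike `StandardScaler`, this sketch does not skip NaNs when computing the statistics; `np.nanmean` and `np.nanstd` would be needed if the stack has gaps.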

#### Vectorise segmentation raster

In [None]:
gdf_segments = xr_vectorize(da_segments).drop(["attribute"], axis=1)
gdf_segments.explore()
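The zonal-statistics cell that follows builds an affine transform from the segment coordinates. Coordinates from a raster load are typically pixel centres, whereas an affine transform's origin is the outer top-left corner of the top-left pixel, half a pixel west and north of the first centre. A plain-Python sketch of that relationship on a made-up 30 m grid, written without the `affine` package (the helper names are hypothetical):

```python
resolution = 30.0
# Hypothetical pixel-centre coordinates: x increases east, y decreases south
x_centres = [1000015.0, 1000045.0, 1000075.0]
y_centres = [-3000015.0, -3000045.0]

# Origin = outer corner of the top-left pixel
origin_x = min(x_centres) - resolution / 2
origin_y = max(y_centres) + resolution / 2


def pixel_to_map(row, col):
    """Affine mapping x = origin_x + col * res, y = origin_y - row * res."""
    return origin_x + col * resolution, origin_y - row * resolution


def pixel_centre(row, col):
    """Map coordinates of the centre of pixel (row, col)."""
    return pixel_to_map(row + 0.5, col + 0.5)
```

Without the half-pixel shift, every segment polygon would be sampled against band values offset by half a pixel in both directions.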

#### Calculate object-level features with zonal statistics

In [None]:
var_names = list(ds_all.keys())
list_stats = ["median", "std", "percentile_10", "percentile_90"]
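# x/y coordinates are pixel centres, so shift the origin by half a pixel to
# the top-left corner of the top-left pixel
resolution = 30
transform = Affine.translation(
    float(da_segments.x.min()) - resolution / 2,
    float(da_segments.y.max()) + resolution / 2,
) * Affine.scale(resolution, -resolution)
# transform = da_segments.rio.transform()  # when rioxarray is available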
gdf_stats_all = None
attr_fields = ["geometry"]

# Calculate zonal statistics for all bands
for var in var_names:

    print("calculating zonal statistics for band", var)
    band = ds_all[var].to_numpy()
    zonestats = zonal_stats(
        gdf_segments,
        band,
        stats=list_stats,
        affine=transform,
        all_touched=True,
        geojson_out=True,
    )

    # convert to geopandas dataframe
    gdf_stats = gpd.GeoDataFrame.from_features(zonestats, crs=gdf_segments.crs)

    # rename stats to use band name as prefix
    for stat in list_stats:
        stat_var = var + "_" + stat
        gdf_stats.rename(columns={stat: stat_var}, inplace=True)
        attr_fields.append(stat_var)

    # append statistics
    if gdf_stats_all is None:
        gdf_stats_all = gdf_stats.copy()
    else:
        gdf_stats_all = pd.concat(
            [gdf_stats_all, gdf_stats.drop(["geometry"], axis=1)], axis=1
        )

# keep only the geometry and the renamed statistics columns
gdf_stats_all = gdf_stats_all[attr_fields]
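Because the segments come from a label raster, the same statistics can also be computed directly from the labels, with no polygons involved. A NumPy sketch (the function name is hypothetical) that ignores NaN pixels and returns the median, standard deviation, and 10th and 90th percentiles per segment:

```python
import numpy as np


def label_zonal_stats(labels, band):
    """Per-label median, std, 10th and 90th percentiles of `band`.

    NaN pixels are ignored. Returns {label: {stat_name: value}}.
    """
    labels = np.asarray(labels).ravel()
    band = np.asarray(band, dtype="float64").ravel()
    stats = {}
    for label in np.unique(labels):
        values = band[labels == label]
        stats[int(label)] = {
            "median": float(np.nanmedian(values)),
            "std": float(np.nanstd(values)),
            "percentile_10": float(np.nanpercentile(values, 10)),
            "percentile_90": float(np.nanpercentile(values, 90)),
        }
    return stats


# Hypothetical 2 x 3 segment labels and band values (one missing pixel)
labels = np.array([[1, 1, 2], [1, 2, 2]])
band = np.array([[0.1, 0.3, 0.8], [0.2, 0.6, np.nan]])
stats = label_zonal_stats(labels, band)
```

Results will differ slightly from the `rasterstats` version: with `all_touched=True`, pixels on a polygon boundary count towards every segment they touch, while here each pixel belongs to exactly one segment.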

#### Export zonal statistics as a vector file for reuse (optional)

In [None]:
gdf_stats_all.to_file(DATA_PATH / "habitat_mapping_segmentation_stats.geojson")

### Classification model training

#### Prepare training samples

In [None]:
field = "Presence"
n_absence = 100

# Presence samples: segments that intersect at least one occurrence point
is_occurrence = gdf_stats_all.intersects(gdf_fauna.unary_union)
occurrence_segs = gdf_stats_all[is_occurrence].reset_index(drop=True)

# Pseudo-absence samples: a random subset of the remaining segments
absence_segs = (
    gdf_stats_all[~is_occurrence]
    .sample(n=n_absence, random_state=1)
    .reset_index(drop=True)
)

# Presence samples labelled as 1; pseudo-absence samples labelled as 2
occurrence_segs[field] = 1
absence_segs[field] = 2

# Merge presence and absence samples
train_segs = pd.concat([occurrence_segs, absence_segs]).reset_index(drop=True)
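Pseudo-absences stand in for absence records that occurrence data does not contain: segments with no recorded sighting are drawn at random and labelled as absent. That is an assumption, since a segment without a record may still be habitat. A NumPy sketch of the selection step with a fixed seed for reproducibility; `sample_pseudo_absences` and its `seed` parameter are hypothetical, and the notebook itself uses pandas' `sample`:

```python
import numpy as np


def sample_pseudo_absences(segment_ids, is_presence, n_absence, seed=1):
    """Draw n_absence distinct segment IDs from segments with no occurrence."""
    segment_ids = np.asarray(segment_ids)
    candidates = segment_ids[~np.asarray(is_presence, dtype=bool)]
    if n_absence > candidates.size:
        raise ValueError("not enough non-occurrence segments to sample from")
    rng = np.random.default_rng(seed)
    return rng.choice(candidates, size=n_absence, replace=False)


# Hypothetical: 10 segments, occurrence points fall in segments 2 and 5
segment_ids = np.arange(10)
is_presence = np.isin(segment_ids, [2, 5])
absences = sample_pseudo_absences(segment_ids, is_presence, n_absence=3)
```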

#### Fit a random forest classifier

In [None]:
train_segs = train_segs.drop(columns=["geometry"])
column_names = train_segs.columns
# features are all statistics columns; the label (`field`) is the last column
X = train_segs.iloc[:, 0:-1]
y = train_segs.iloc[:, -1]
Classifier = RandomForestClassifier(n_estimators=200)
Classifier.fit(X, y)

In [None]:
skf = StratifiedKFold(
    n_splits=5, shuffle=True, random_state=1
)  # stratified K-fold splitting
overall_acc = cross_val_score(Classifier, X, y, cv=skf, scoring="accuracy")
print("Overall accuracy from cv scores: ", np.mean(overall_acc))
f1_macro = cross_val_score(Classifier, X, y, cv=skf, scoring="f1_macro")
print("f1_macro from cv scores: ", np.mean(f1_macro))
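Accuracy alone can look good when one class dominates the samples. Macro F1 averages the per-class F1 scores with equal weight, so a model that ignores the presence class is penalised. A plain-Python sketch of the calculation, for intuition only (the notebook uses scikit-learn's `f1_macro` scorer):

```python
def f1_macro(y_true, y_pred):
    """Unweighted mean of per-class F1 scores."""
    labels = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    scores = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return sum(scores) / len(scores)


# Hypothetical: 4 presence (1) and 8 absence (2) segments, and a model
# that predicts absence everywhere
y_true = [1] * 4 + [2] * 8
y_pred = [2] * 12
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(round(accuracy, 3), round(f1_macro(y_true, y_pred), 3))  # 0.667 0.4
```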

### Prediction across the whole study area
Here we predict the probability of Southern Bell Frog presence for every segment and display it along with the training points:

In [None]:
# fill gaps in the segment statistics before predicting
X_all = gdf_stats_all[column_names[0:-1]].interpolate(method="nearest")
predictions = Classifier.predict(X_all)
probas = Classifier.predict_proba(X_all)
gdf_stats_all[field] = predictions

# predict_proba columns follow Classifier.classes_, i.e. [1 (presence), 2 (absence)]
attrs_prob = ["Prob_presence", "Prob_absence"]
for i, attr in enumerate(attrs_prob):
    gdf_stats_all[attr] = probas[:, i]

m = gdf_stats_all.explore(column=attrs_prob[0])
gdf_fauna.explore(m=m)
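`predict_proba` returns one column per class, ordered like the classifier's `classes_` attribute (the sorted labels, here `[1, 2]`), which is why the first column is the presence probability. Looking columns up by label rather than by position keeps the mapping correct if the label scheme changes. A NumPy sketch with a hypothetical helper and made-up probabilities:

```python
import numpy as np


def probability_columns(classes, probas, names):
    """Map each class label's display name to its predict_proba column.

    `classes` must be in the same order as the probability columns
    (scikit-learn exposes this as `classifier.classes_`).
    """
    probas = np.asarray(probas)
    return {names[label]: probas[:, i] for i, label in enumerate(classes)}


# Hypothetical output for three segments from a model trained on labels 1 and 2
classes = [1, 2]
probas = [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]
columns = probability_columns(
    classes, probas, {1: "Prob_presence", 2: "Prob_absence"}
)
```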

#### Rasterise predictions if vector file size is too large (optional)

In [None]:
prob_rasterised = xr_rasterize(
    gdf=gdf_stats_all, da=da_segments, attribute_name=attrs_prob[0]
)
prob_rasterised.plot()
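Because the segments were derived from a label raster, per-segment values can also be written back to pixels without rasterising polygons, by indexing a lookup table with the label array. This sketch assumes you keep each segment's label ID alongside its statistics (the vectorise cell drops `xr_vectorize`'s `attribute` column, which presumably holds that ID) and that each label maps to a single value; `paint_segments` is a hypothetical helper:

```python
import numpy as np


def paint_segments(labels, value_by_label, fill=np.nan):
    """Return a float raster where each pixel takes its segment's value.

    `labels` holds non-negative integer segment IDs; IDs missing from
    `value_by_label` receive `fill`.
    """
    labels = np.asarray(labels)
    lookup = np.full(int(labels.max()) + 1, fill, dtype="float64")
    for label, value in value_by_label.items():
        lookup[label] = value
    return lookup[labels]


# Hypothetical 2 x 3 label raster and per-segment presence probabilities
labels = np.array([[0, 0, 1], [2, 1, 1]])
prob_raster = paint_segments(labels, {0: 0.9, 1: 0.2})
```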