In [None]:
# Reload edited modules (e.g. the local eo_insights package) before each cell
# runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage
import xarray as xr
from rasterstats import zonal_stats
from shapely.geometry import shape
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    confusion_matrix,
    f1_score,
)
from sklearn.preprocessing import LabelEncoder

from eo_insights.band_indices import calculate_indices
from eo_insights.raster_base import LoadParams, QueryParams, RasterBase
from eo_insights.spatial import xr_vectorize
from eo_insights.stac_configuration import de_australia_stac_config

In [None]:
import logging

# Only surface warnings and errors from library loggers.
logging.basicConfig(level="WARNING")

In [None]:
# All inputs/outputs live in ../data/mdba relative to the notebook's directory.
CURRENT_DIR = Path.cwd()
DATA_PATH = CURRENT_DIR.parent.joinpath("data", "mdba")

## Generate subset of data
Keep only river red gum (*Eucalyptus camaldulensis*) and black box (*Eucalyptus largiflorens*) records in Victoria.

Skip to next section if subset has already been generated and saved.

In [None]:
veg_species_data = gpd.read_file(DATA_PATH.joinpath("all_states_concatenated.shp"))

In [None]:
# Keep only records for the two target species.
# - `na=False` treats a missing scientific name as a non-match; otherwise the
#   mask would contain NaN and boolean indexing would fail.
# - `.copy()` makes this an independent frame, so the `commonname` assignments
#   in the next cell do not trigger SettingWithCopyWarning.
veg_species_subset = veg_species_data[
    veg_species_data["sciname"].str.contains(
        "Eucalyptus camaldulensis|Eucalyptus largiflorens", regex=True, na=False
    )
].copy()

In [None]:
# Scientific name -> common name used as the class label.
mapping = {
    "Eucalyptus camaldulensis": "river red gum",
    "Eucalyptus largiflorens": "black box",
}

# Overwrite `commonname` for every record whose scientific name mentions the
# species (case-insensitive; missing names never match).
for scientific_name, common_name in mapping.items():
    matches_species = veg_species_subset["sciname"].str.contains(
        scientific_name, case=False, na=False
    )
    veg_species_subset.loc[matches_species, "commonname"] = common_name

In [None]:
veg_species_vic = veg_species_subset.loc[veg_species_subset["statedb"].eq("vic")]

In [None]:
veg_species_vic.head(n=5)

In [None]:
# Interactive map of the Victorian subset, coloured by species.
veg_species_vic.explore(
    column="commonname",
    cmap=["red", "green"],
)

In [None]:
# Cache the subset so the modelling section can start from this file.
vic_subset_path = DATA_PATH / "vic_twospecies.gpkg"
veg_species_vic.to_file(vic_subset_path, driver="GPKG", layer="commonname")

## Train and test a model

Setting up training and testing data

input: xarray data to use for segmentation, xarray data to use for features, training/testing polygons/points/labels
output: X_train, y_train, X_test, y_test

Steps:
1. Get train and test datasets
2. Load segmentation data
3. Run segmentation to get vectors
4. Match segmentation vectors with occurrence points (done separately for train and test, because here the two datasets cover different locations)
5. Clean up training data (e.g. randomly sample from "other")
6. Load training data from STAC
7. Run zonal statistics
8. Prepare data for sklearn classifier
9. Train on training data
10. Predict on test data

### Step 1: Get training and testing datasets

In [None]:
veg_species_vic = gpd.read_file(DATA_PATH.joinpath("vic_twospecies.gpkg"))

In [None]:
# Polygons outlining the model areas; the "purpose" column marks train vs test.
areas = gpd.read_file(DATA_PATH.joinpath("vic_model_areas.gpkg"))
areas.explore()

In [None]:
def _species_points_within(purpose):
    """Return the species records that fall inside the areas tagged `purpose`."""
    selected_areas = areas.loc[areas["purpose"] == purpose].to_crs(veg_species_vic.crs)
    return gpd.sjoin(veg_species_vic, selected_areas, how="inner", predicate="within")


training_data = _species_points_within("train")
testing_data = _species_points_within("test")

### Step 2: Load data for segmentation

In [None]:
def load_segmentation_data(
    bounding_box, start_date="2019-12-01", end_date="2020-11-30"
):
    """Load a seasonal NDVI median stack to drive image segmentation.

    Queries the Sentinel-2A/2B ARD collections over ``bounding_box``, computes
    NDVI and reduces it to quarterly medians (quarters starting in December).

    Parameters
    ----------
    bounding_box : array-like
        Bounding box for the STAC query. Callers pass a GeoDataFrame's
        ``total_bounds``. NOTE(review): those bounds are in the points' CRS;
        confirm ``QueryParams`` expects that CRS.
    start_date, end_date : str
        Date range for the STAC query. The defaults cover December 2019 to
        November 2020, i.e. four seasonal quarters.

    Returns
    -------
    xarray.DataArray
        Computed NDVI medians with one ``time`` step per quarter.
    """
    segmentation_raster = RasterBase.from_stac_query(
        config=de_australia_stac_config,
        collections=["ga_s2am_ard_3", "ga_s2bm_ard_3"],
        query_params=QueryParams(
            bbox=bounding_box,
            start_date=start_date,
            end_date=end_date,
        ),
        load_params=LoadParams(
            crs="EPSG:3577", resolution=30, bands=["red", "nir", "fmask"]
        ),
    )

    segmentation_raster.data = calculate_indices(segmentation_raster.data, ["ndvi"])

    # TODO: apply the fmask cloud mask before taking medians. It is disabled for
    # now because NaN values remained after the medians.
    # segmentation_raster.apply_mask("fmask", nodata=np.nan)

    # Only NDVI feeds the segmentation; reduce it to quarterly medians and
    # load into memory.
    segmentation_data = (
        segmentation_raster.data["ndvi"].resample(time="1QS-Dec").median().compute()
    )

    return segmentation_data

In [None]:
training_segmentation_data = load_segmentation_data(
    bounding_box=training_data.total_bounds
)

### Step 3: Run segmentation

In [None]:
def run_segmentation(
    segmentation_data,
    ratio=1.0,
    kernel_size=2,
    max_dist=10,
    sigma=0,
    rng=42,
):
    """Segment a (time, y, x) stack with quickshift and vectorise the segments.

    Parameters
    ----------
    segmentation_data : xarray.DataArray
        Array whose dimensions are exactly ``y``, ``x`` and ``time``. Each
        time step is used as one quickshift channel.
    ratio, kernel_size, max_dist, sigma, rng :
        Passed straight to ``skimage.segmentation.quickshift``. The defaults
        reproduce the original hard-coded settings; ``rng`` fixes the seed so
        segmentation is reproducible.

    Returns
    -------
    geopandas.GeoDataFrame
        Polygons produced by ``xr_vectorize``. Presumably these have an
        ``attribute`` column holding the segment label (that column is what
        ``match_segments_to_points`` groups on).
    """
    # quickshift treats the last axis as channels, so move time to the end.
    segmentation_data_np = segmentation_data.transpose("y", "x", "time").data

    segments = skimage.segmentation.quickshift(
        segmentation_data_np,
        ratio=ratio,
        kernel_size=kernel_size,
        max_dist=max_dist,
        sigma=sigma,
        convert2lab=False,
        rng=rng,
    )

    # Use int32, not int16: over a large area quickshift can produce more than
    # 32,767 segments, and an int16 cast silently wraps, giving unrelated
    # segments the same label. (A 64-bit label array is presumably avoided
    # because the vectoriser does not accept it. TODO confirm.)
    xr_segments = xr.DataArray(
        segments,
        coords={
            "y": segmentation_data.coords["y"],
            "x": segmentation_data.coords["x"],
        },
        dims=("y", "x"),
        attrs=dict(segmentation_data.attrs),
    ).astype(np.int32)

    gdf_segments = xr_vectorize(xr_segments)

    return gdf_segments

In [None]:
training_segments = run_segmentation(segmentation_data=training_segmentation_data)

### Step 4: Match segmentation vectors with occurrence points

Assign the species label as the most common species that appears in the segment.
If no species labels are present, assign the label "other".

In [None]:
def match_segments_to_points(segments, points, label_column):
    """Label each segment with the most common point label it contains.

    Segments that contain no points are labelled ``"other"``. Ties between
    equally common labels are broken by taking the first value returned by
    ``Series.mode`` (which is sorted), so every segment gets exactly one label.

    Parameters
    ----------
    segments : geopandas.GeoDataFrame
        Segment polygons with an ``attribute`` column identifying the segment
        (e.g. the output of ``run_segmentation``).
    points : geopandas.GeoDataFrame
        Occurrence points carrying ``label_column``. Reprojected to the
        segments' CRS before the join.
    label_column : str
        Name of the label column in ``points``. It becomes the label column in
        the result.

    Returns
    -------
    geopandas.GeoDataFrame
        ``segments`` with ``label_column`` added, one row per input row.
    """
    # `predicate=` replaces the removed `op=` keyword (and matches the sjoin
    # calls elsewhere in this notebook).
    segments_with_points = gpd.sjoin(
        segments,
        points[[label_column, "geometry"]].to_crs(segments.crs),
        how="left",
        predicate="contains",
    )

    # The left join keeps segments without points, with a NaN label -> "other".
    segments_with_points[label_column] = segments_with_points[label_column].fillna(
        "other"
    )

    # Reduce to a single label per segment. Returning every tied mode here
    # would duplicate segments in the merge below.
    modal_label = (
        segments_with_points.groupby("attribute")[label_column]
        .agg(lambda labels: labels.mode().iloc[0])
        .reset_index()
    )

    segments_labelled = segments.merge(modal_label, on="attribute")

    return segments_labelled

In [None]:
training_polygons = match_segments_to_points(
    segments=training_segments,
    points=training_data,
    label_column="commonname",
)

In [None]:
# Map every training segment coloured by its assigned label.
training_polygons.explore(
    column="commonname",
)

### Step 5: Clean up the training data

Keep all samples for the two species classes.
Limit the "other" class to a random sample of 30 segments to help balance it against the two species classes.

In [None]:
# Cap the "other" class to help balance it against the two species classes.
# - The fixed seed makes the sample (and everything downstream) reproducible.
# - min() avoids a ValueError when fewer polygons than the cap are available.
N_OTHER_SAMPLES = 30
SAMPLE_SEED = 42

river_red_gum = training_polygons[training_polygons["commonname"] == "river red gum"]
black_box = training_polygons[training_polygons["commonname"] == "black box"]
other_candidates = training_polygons[training_polygons["commonname"] == "other"]
other = other_candidates.sample(
    n=min(N_OTHER_SAMPLES, len(other_candidates)), random_state=SAMPLE_SEED
)

balanced_training_polygons = pd.concat([river_red_gum, black_box, other])

In [None]:
# Map the class-balanced training sample.
balanced_training_polygons.explore(
    column="commonname",
)

### Step 6: Load data for features

In [None]:
def load_feature_data(bounding_box, start_date="2019-12-01", end_date="2020-11-30"):
    """Load the raster datasets that per-segment features are computed from.

    Parameters
    ----------
    bounding_box : array-like
        Bounding box for the STAC queries. Callers pass a GeoDataFrame's
        ``total_bounds``.
    start_date, end_date : str
        Date range for the Sentinel-2 query. The defaults cover December 2019
        to November 2020. The DEM query is fixed to 2014.

    Returns
    -------
    list
        ``[eo_data, dem_data]``, both computed into memory:

        - ``eo_data``: Sentinel-2 red/green/blue/nir/fmask plus NDVI, reduced
          to quarterly (December-start) medians.
        - ``dem_data``: the SRTM ``dem_s`` band.

        They are returned as separate items rather than one merged dataset
        because they do not concatenate cleanly (presumably different bands
        and time coordinates). ``get_zonal_stats`` processes each item in turn.
    """
    eo_feature_raster = RasterBase.from_stac_query(
        config=de_australia_stac_config,
        collections=["ga_s2am_ard_3", "ga_s2bm_ard_3"],
        query_params=QueryParams(
            bbox=bounding_box,
            start_date=start_date,
            end_date=end_date,
        ),
        load_params=LoadParams(
            crs="EPSG:3577",
            resolution=30,
            bands=["red", "green", "blue", "nir", "fmask"],
        ),
    )

    eo_feature_raster.data = calculate_indices(eo_feature_raster.data, ["ndvi"])

    # TODO: apply the fmask cloud mask before taking medians. It is disabled for
    # now because NaN values remained after the medians.
    # eo_feature_raster.apply_mask("fmask", nodata=np.nan)

    # Quarterly medians of every band, including NDVI.
    # NOTE(review): fmask is still present here, so its medians become model
    # features downstream. Confirm that is intended.
    eo_feature_raster.data = (
        eo_feature_raster.data.resample(time="1QS-Dec").median().compute()
    )

    # Elevation from the SRTM DEM, queried separately from Sentinel-2.
    dem_feature_raster = RasterBase.from_stac_query(
        config=de_australia_stac_config,
        collections=["ga_srtm_dem1sv1_0"],
        query_params=QueryParams(
            bbox=bounding_box,
            start_date="2014",
            end_date="2014",
        ),
        load_params=LoadParams(
            crs="EPSG:3577",
            resolution=30,
            bands=["dem_s"],
        ),
    )

    dem_feature_raster.data = dem_feature_raster.data.compute()

    return [eo_feature_raster.data, dem_feature_raster.data]

In [None]:
feature_data = load_feature_data(bounding_box=training_data.total_bounds)

In [None]:
# Show the [EO medians, DEM] pair returned by load_feature_data.
feature_data

### Step 7: Run zonal statistics

In [None]:
def get_zonal_stats(
    feature_datasets: list[xr.Dataset],
    geometries,
    zonalstats_list: list[str] = ["median", "std", "percentile_10", "percentile_90"],
):
    """Summarise every band of every feature dataset over each geometry.

    Parameters
    ----------
    feature_datasets : list of xarray.Dataset
        Datasets to summarise (e.g. the output of ``load_feature_data``). A
        dataset without a ``time`` dimension is treated as one time step.
    geometries :
        Polygons to summarise over, passed straight to
        ``rasterstats.zonal_stats`` (callers pass a GeoSeries).
    zonalstats_list : list of str
        Statistic names passed to ``rasterstats``. The list is not modified.

    Returns
    -------
    geopandas.GeoDataFrame or None
        One row per geometry, containing:

        - ``id`` and ``geometry`` columns;
        - one column per (band, time step, statistic), named
          ``{band}_{timestep}_{stat}`` for multi-step datasets and
          ``{band}_{stat}`` otherwise.

        Returns ``None`` if ``feature_datasets`` is empty.
    """
    stat_frames = []

    for feature_data in feature_datasets:
        # Use `.sizes` for dimension lengths. `Dataset.dims` is becoming a set
        # of names, and `.dims.get("time")` gave None (crashing `range`) when
        # a dataset had no time axis.
        has_time = "time" in feature_data.dims
        n_timesteps = feature_data.sizes["time"] if has_time else 1

        for timestep in range(n_timesteps):
            print(f"Computing stats for timestep = {timestep}")

            selected = feature_data.isel(time=timestep) if has_time else feature_data
            timestep_xr = selected.squeeze()

            for band_name in list(timestep_xr.keys()):
                print(f"    Computing stats for {band_name}")

                zonalstats = zonal_stats(
                    geometries,
                    timestep_xr[band_name].data,
                    stats=zonalstats_list,
                    all_touched=True,
                    geojson_out=True,
                    affine=timestep_xr.odc.affine,
                )

                # With geojson_out=True each result is a GeoJSON-like feature;
                # the computed statistics sit in its "properties" dict.
                df_stats = pd.DataFrame.from_dict(zonalstats)
                geoms = [shape(feature_geom) for feature_geom in df_stats["geometry"]]
                gdf_stats = gpd.GeoDataFrame(
                    df_stats, geometry=geoms, crs=timestep_xr.odc.crs
                )

                for stat in zonalstats_list:
                    # Only encode the timestep in the name when there are several.
                    if n_timesteps > 1:
                        stat_var = f"{band_name}_{timestep}_{stat}"
                    else:
                        stat_var = f"{band_name}_{stat}"
                    gdf_stats[stat_var] = [
                        properties[stat] for properties in gdf_stats["properties"]
                    ]

                # "bbox" is only present when the input features carry one.
                gdf_stats = gdf_stats.drop(
                    columns=["properties", "type", "bbox"], errors="ignore"
                )

                # Keep `id`/`geometry` once (from the first band). Later bands
                # contribute only their statistic columns.
                if stat_frames:
                    gdf_stats = gdf_stats.drop(columns=["id", "geometry"])
                stat_frames.append(gdf_stats)

    if not stat_frames:
        return None

    # Concatenate once at the end instead of growing a frame inside the loop.
    return pd.concat(stat_frames, axis=1)

In [None]:
# Compute features for the class-balanced polygons from Step 5, not for every
# segment. Otherwise the cap on the "other" class never takes effect.
training_features = get_zonal_stats(
    feature_data,
    balanced_training_polygons.geometry,
    zonalstats_list=["median", "std", "percentile_10", "percentile_90"],
)

In [None]:
# Preview a few rows rather than dumping the whole feature table.
training_features.head()

Attach the training polygons to the features.

In [None]:
# Attach the species label to each feature row by matching polygon geometry.
# Merging against the balanced polygons keeps only the Step 5 sample.
training_segments = training_features.merge(
    balanced_training_polygons[["geometry", "commonname"]], on="geometry"
)

### Step 8: Prepare data for sklearn classifier

In [None]:
# Drop identifier/geometry columns; what remains is the label plus features.
training_data = training_segments.drop(columns=["id", "geometry"])

X_train = training_data.drop(columns=["commonname"])
y_train = training_data["commonname"]

# Encode the string class labels as integers for scikit-learn.
le = LabelEncoder()
y_train_transformed = le.fit_transform(y_train)
print(le.classes_)

### Step 9: Train the Random Forest Classifier

In [None]:
# A fixed random_state makes the fitted forest reproducible between runs,
# and with it the feature importances and test metrics below.
Classifier = RandomForestClassifier(n_estimators=200, random_state=42)
Classifier.fit(X_train, y_train_transformed)

Examine feature importance

In [None]:
# X_train holds exactly the columns the classifier was fitted on, in order.
column_names = X_train.columns
# np.argsort is ascending; reverse it so the MOST important features come first
# (previously the "top 10" were actually the 10 least important).
feat_importance_indices = np.argsort(Classifier.feature_importances_)[::-1]
feat_importance_ordered = np.array(column_names)[feat_importance_indices]
print("Top 10 features: \n", feat_importance_ordered[0:10])

fig, ax = plt.subplots(figsize=(5, 3))
ax.barh(
    y=feat_importance_ordered[0:10],
    width=Classifier.feature_importances_[feat_importance_indices][0:10],
)
# barh draws the first bar at the bottom; flip so the top feature is on top.
ax.invert_yaxis()
# Bars are horizontal: importance runs along x, variable names along y.
ax.set_xlabel("Importance", labelpad=6)
ax.set_ylabel("Variable", labelpad=6)
ax.set_title("Top 10 feature importances");

### Step 10: Repeat above steps to prepare test data

In [None]:
# Rerun Steps 2-7 over the test area. The test segments are not rebalanced;
# every matched segment is kept for evaluation.
print("Loading segmentation data")
testing_segmentation_data = load_segmentation_data(
    bounding_box=testing_data.total_bounds
)

print("Running segmentation")
testing_segments = run_segmentation(segmentation_data=testing_segmentation_data)

print("Matching segments to points")
testing_polygons = match_segments_to_points(
    segments=testing_segments,
    points=testing_data,
    label_column="commonname",
)

print("Loading feature data")
test_feature_data = load_feature_data(bounding_box=testing_data.total_bounds)

print("Getting zonal statistics")
testing_features = get_zonal_stats(
    feature_datasets=test_feature_data,
    geometries=testing_polygons.geometry,
    zonalstats_list=["median", "std", "percentile_10", "percentile_90"],
)

print("Assembling final dataset for ML")
testing_input = testing_features.merge(
    testing_polygons[["geometry", "commonname"]], on="geometry"
)

print("Saving test data")
testing_input.to_file(DATA_PATH.joinpath("test_samples.gpkg"), driver="GPKG")

Prepare test data for classifier

In [None]:
# Build X/y for the test set the same way as for training.
testing_prediction_input = testing_input.drop(columns=["id", "geometry"])

X_test = testing_prediction_input.drop(columns=["commonname"])
y_test = testing_prediction_input["commonname"]

# Reuse the encoder fitted on the training labels.
y_test_transformed = le.transform(y_test)

Run predictions and append to segments

In [None]:
# Predict encoded classes, then map them back to species names.
y_test_pred = Classifier.predict(X_test)
y_test_pred_labels = le.inverse_transform(y_test_pred)

# Attach the predictions as a new column alongside the true label.
prediction_column = pd.Series(y_test_pred_labels, name="commonname_prediction")
predicted_segments = testing_input.join(prediction_column)

Compute accuracy metrics and confusion matrix

In [None]:
# Macro-averaged F1 weights every class equally, so one large class
# (e.g. "other") cannot dominate the score.
f1_metric = f1_score(y_test_transformed, y_test_pred, average="macro")
accuracy_metric = accuracy_score(y_test_transformed, y_test_pred)

print(f"F1-Score = {f1_metric}")
print(f"Accuracy Score = {accuracy_metric}")

# Pin the label set to every encoded class so the matrix rows/columns stay
# aligned with the species names, even when a class is absent from both
# y_test and the predictions.
class_indices = np.arange(len(le.classes_))
ConfusionMatrixDisplay(
    confusion_matrix(y_test_transformed, y_test_pred, labels=class_indices),
    display_labels=le.classes_,
).plot();

Display predictions on map, along with species sample points

In [None]:
# Base layer: predicted segments. Overlay: the original test sample points.
segment_map = predicted_segments[
    ["commonname", "commonname_prediction", "geometry"]
].explore(column="commonname_prediction")
testing_data.explore(column="commonname", cmap=["red", "green"], m=segment_map)