In [13]:
# Required for interactive plots

%matplotlib widget

In [None]:
import asyncio
import base64
import json
import os
import shutil
import sys
import uuid
from collections.abc import Callable, Coroutine
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

import geopandas as gpd
import IPython.display
import ipyvuetify as vue
import numpy as np
import xarray as xr
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType, MlTensorShape3D, MlModelInputNoDataHandling, MlModelInputNoDataHandlingVariant, MlModelOutputNoDataHandling, MlModelOutputNoDataHandlingVariant
from matplotlib import pyplot as plt
from matplotlib.patches import Circle
from onnx.checker import check_model
from skl2onnx import to_onnx
from skl2onnx.common.data_types import FloatTensorType, UInt8TensorType
from sklearn.ensemble import RandomForestClassifier

import geoengine as ge
from geoengine.workflow_builder.operators import (
    Expression,
    GdalSource,
    OgrSource,
    Onnx,
    RasterStacker,
    RasterTypeConversion,
    RasterVectorJoin,
    RenameBands,
    TemporalRasterAggregation,
)

sys.path.insert(1, "..")
import labeling

In [None]:
# Type definitions


@dataclass(frozen=True)
class Location:
    """A named place and the start time of the data to query there.

    Attributes:
        display_name: Human-readable name, shown in the location selector.
        center: Center coordinate as ``(x, y)``; the app queries it in EPSG:32632.
        time: Start of the queried time interval.
    """

    display_name: str
    center: tuple[int, int]
    time: np.datetime64

    def __repr__(self) -> str:
        return f"{self.display_name} at {self.center} on {self.time}"

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict (used as the location selector's value)."""
        return {"display_name": self.display_name, "center": self.center, "time": self.time.astype(str)}

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "Location":
        """Inverse of `to_dict`; `center` may be any sequence, e.g. a list coming back from JSON."""
        return Location(d["display_name"], tuple(d["center"]), np.datetime64(d["time"]))


@dataclass(frozen=True)
class Config:
    """Settings of the app: Geo Engine connection, selectable locations and Sentinel-2 inputs."""

    # Base URL of the Geo Engine API.
    instance_url: str
    # Session token for the instance, or None; passed to `ge.initialize` as `token`.
    instance_session_token: str | None
    # Locations offered in the location selector; the first one is preselected.
    locations: list[Location]
    # Optional initial point labels for a location, keyed by the location's display name.
    location_labels: dict[str, Path]
    # Geo Engine data names of the Sentinel-2 input bands. The mask band is the scene
    # classification band whose values are used to mask invalid pixels.
    sentinel2_red_band: str
    sentinel2_green_band: str
    sentinel2_blue_band: str
    sentinel2_nir_band: str
    sentinel2_mask_band: str

In [None]:
##########################
##### CONFIGURE THIS #####
##########################

__config = Config(
    # Geo Engine connection; both values can be overridden via environment variables.
    instance_url=os.environ.get("GEOENGINE_INSTANCE_URL", "http://localhost:3030/api"),
    instance_session_token=os.environ.get("GEOENGINE_SESSION_TOKEN"),
    # Selectable locations: center (x, y) in EPSG:32632 and the start of the queried period.
    locations=[
        Location(display_name="Cologne", center=(356766, 5644819), time=np.datetime64("2021-07-01T00:00:00")),
        Location(display_name="Marburg", center=(483843, 5628614), time=np.datetime64("2021-07-01T00:00:00")),
    ],
    # Initial point labels per location display name.
    location_labels={"Cologne": Path("assets/cologne.geojson")},
    # Sentinel-2 bands (UTM zone 32N) on the Geo Engine instance.
    sentinel2_red_band="_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B04`",
    sentinel2_green_band="_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B03`",
    sentinel2_blue_band="_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B02`",
    sentinel2_nir_band="_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B08`",
    sentinel2_mask_band="_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:SCL`",
)

In [None]:
class Model:
    """Application state and Geo Engine processing of the two-class classifier.

    Holds the query rectangle of the selected location, the two class labels, the
    Sentinel-2 feature workflow and its data, and the predicted class raster.
    """

    # Output bands of the Sentinel-2 feature workflow, in band order. The same names select
    # the training features, so the workflow and the classifier cannot drift apart.
    FEATURE_BANDS = ("red", "green", "blue", "ndvi")

    config: Config
    location: Location | None
    query_rectangle: ge.QueryRectangle | None
    class_column: str
    class_value_1: str
    class_value_2: str
    sentinel_workflow: ge.Workflow | None
    sentinel_data: xr.DataArray | None
    training_data_file: str
    class_data: xr.DataArray | None

    def __init__(self, config: Config):
        self.config = config
        self.location = None
        self.query_rectangle = None
        self.class_column = "class"
        self.class_value_1 = "water"
        self.class_value_2 = "non-water"
        self.sentinel_workflow = None  # set by `load_sentinel_data`
        self.sentinel_data = None
        self.training_data_file = f"training_data/{uuid.uuid4()}.geojson"
        self.class_data = None

    def set_location(self, location: Location) -> None:
        """Set the location and the query rectangle around it.

        The query covers about 1024 x 1024 pixels at 10 m resolution centered on the
        location, in EPSG:32632, for four weeks starting at ``location.time``. If initial
        labels are configured for the location and the default classes are used, they are
        copied into the training data file.
        """

        resolution = ge.SpatialResolution(10, 10)
        radius_px = 512
        bbox = ge.BoundingBox2D(
            xmin=location.center[0] - resolution.x_resolution * radius_px,
            xmax=location.center[0] + resolution.x_resolution * radius_px - 1,
            ymin=location.center[1] - resolution.y_resolution * radius_px,
            ymax=location.center[1] + resolution.y_resolution * radius_px - 1,
        )
        self.location = location
        self.query_rectangle = ge.QueryRectangle(
            spatial_bounds=bbox,
            time_interval=ge.TimeInterval(location.time, location.time + np.timedelta64(4, "W")),
            srs="EPSG:32632",
            resolution=resolution,
        )

        # The training data file lives in a sub-directory that may not exist yet.
        Path(self.training_data_file).parent.mkdir(parents=True, exist_ok=True)

        # DEMO CASE: seed the training data with the configured initial labels. This is only
        # done for the default classes, presumably because the stored labels refer to them.
        initial_labels = self.config.location_labels.get(location.display_name)
        uses_default_classes = (self.class_value_1, self.class_value_2) == ("water", "non-water")
        if initial_labels is not None and uses_default_classes:
            shutil.copyfile(initial_labels, self.training_data_file)

    def _dataset_suffix(self) -> str:
        """Identify the current query in the names of preloaded datasets.

        Combines the minimum corner of the query and its start day, so that different places
        or time periods never reuse each other's stored data.
        """

        bounds = self.query_rectangle.spatial_bounds
        suffix = f"{bounds.xmin}_{bounds.ymin}"
        if self.location is not None:
            day = np.datetime_as_string(self.location.time, unit="D").replace("-", "")
            suffix = f"{suffix}_{day}"
        return suffix

    async def _preload_sentinel_data(self) -> list[str]:
        """Store the configured Sentinel-2 bands for the current query as datasets.

        Returns the dataset names in the order red, green, blue, NIR, mask. Datasets that
        already exist (same user and query) are reused instead of being stored again.
        """

        user_id = ge.get_session().user_id
        suffix = self._dataset_suffix()
        band_names = [f"{user_id}:{band}_{suffix}" for band in ("B04", "B03", "B02", "B08", "SCL")]
        source_bands = [
            self.config.sentinel2_red_band,
            self.config.sentinel2_green_band,
            self.config.sentinel2_blue_band,
            self.config.sentinel2_nir_band,
            self.config.sentinel2_mask_band,
        ]

        save_tasks = []
        for new_band_name, band_name in zip(band_names, source_bands, strict=True):
            workflow = ge.register_workflow(GdalSource(band_name))
            try:
                save_tasks.append(
                    workflow.save_as_dataset(
                        query_rectangle=self.query_rectangle.as_raster_query_rectangle_api_dict(),  # TODO: remove conversion  # noqa: E501
                        name=new_band_name,
                        display_name=new_band_name,
                    ).as_future()
                )
            except ge.BadRequestException as e:
                inner_error = ge.GeoEngineException(json.loads(e.body))
                # Re-use existing data; any other error is fatal (bare `raise` keeps the traceback).
                if inner_error.error != "DatasetNameAlreadyExists":
                    raise

        await asyncio.gather(*save_tasks)

        return band_names

    async def load_sentinel_data(self) -> None:
        """Compute the monthly mean R, G, B and NDVI features and load them as an xarray.

        Pixels flagged by the mask band are set to no-data before the aggregation. The
        workflow is stored in `sentinel_workflow`, the data in `sentinel_data`.
        """

        assert self.query_rectangle is not None, "Query rectangle must be set before loading data"

        # TODO: do not store bands as first step
        band_names = await self._preload_sentinel_data()
        assert len(band_names) == 5, f"Expected 5 bands, got {len(band_names)}"

        red_band, green_band, blue_band, nir_band = (GdalSource(name) for name in band_names[:4])
        mask_band = RasterTypeConversion(
            GdalSource(band_names[4]),
            output_data_type="U16",
        )

        # NOTE(review): the masked values are presumably the SCL classes for cloud shadow (3)
        # and unclassified/clouds/cirrus/snow (7-11) — confirm against the SCL legend.
        def masked(band) -> Expression:
            """Convert `band` to F32 with pixels flagged by the mask band set to no-data."""
            return Expression(
                expression="if (B == 3 || (B >= 7 && B <= 11)) { NODATA } else { A }",
                output_type="F32",
                source=RasterStacker([band, mask_band]),
            )

        ndvi = Expression(
            expression="if (C == 3 || (C >= 7 && C <= 11)) { NODATA } else { (A - B) / (A + B) }",
            output_type="F32",
            source=RasterStacker([nir_band, red_band, mask_band]),
        )

        workflow = TemporalRasterAggregation(
            aggregation_type="mean",
            granularity="months",
            window_size=1,
            ignore_no_data=True,
            source=RasterStacker(
                sources=[masked(red_band), masked(green_band), masked(blue_band), ndvi],
                rename=RenameBands.rename(list(self.FEATURE_BANDS)),
            ),
        )

        self.sentinel_workflow = ge.register_workflow(workflow)

        self.sentinel_data = await self.sentinel_workflow.raster_stream_into_xarray(
            query_rectangle=self.query_rectangle,
            clip_to_query_rectangle=True,
            bands=list(range(len(self.FEATURE_BANDS))),
        )

    def _load_training_data(self) -> gpd.GeoDataFrame:
        """Sample the Sentinel-2 features at the labeled points of the training data file.

        The labels are uploaded to Geo Engine and joined with the feature workflow's raster
        values, taking the first value per point.

        Raises:
            FileNotFoundError: if no training data has been created yet.
        """

        assert self.query_rectangle is not None, "Query rectangle must be set before loading data"
        assert self.sentinel_workflow is not None, "Sentinel workflow must be fixed before training data"

        if not Path(self.training_data_file).exists():
            raise FileNotFoundError("No training data yet: label some pixels on the map first")

        labels_name = ge.upload_dataframe(
            gpd.read_file(self.training_data_file).set_crs(self.query_rectangle.srs, allow_override=True)
        )

        training_workflow = RasterVectorJoin(
            raster_sources=[self.sentinel_workflow.workflow_definition().operator],
            vector_source=OgrSource(labels_name),
            names=ge.workflow_builder.operators.ColumnNames.default(),
            temporal_aggregation="none",
            feature_aggregation="first",
            temporal_aggregation_ignore_nodata=True,
            feature_aggregation_ignore_nodata=True,
        )

        training_workflow = ge.register_workflow(training_workflow)

        return training_workflow.get_dataframe(self.query_rectangle)

    def _train_classifier(self, training_df: gpd.GeoDataFrame) -> RandomForestClassifier:
        """Fit a random forest on the feature values of the labeled points.

        Rows without feature values or class (presumably points on no-data pixels) are
        dropped first: they carry no usable information, and the deployed model skips
        no-data input anyway.

        Raises:
            ValueError: if no usable training rows remain.
        """

        features = list(self.FEATURE_BANDS)
        usable = training_df.dropna(subset=[*features, self.class_column])
        if usable.empty:
            raise ValueError("No usable training data: label pixels that have valid Sentinel-2 data")

        X = usable[features].to_numpy().astype(np.float32)
        y = usable[self.class_column].to_numpy().astype(np.uint8)

        clf = RandomForestClassifier(random_state=42)  # fixed seed for reproducible models
        clf.fit(X, y)

        return clf

    async def train_and_predict(self) -> None:
        """Train a random forest on the labels, register it as ONNX model and predict.

        The predicted class raster for the current query rectangle is stored in `class_data`.
        """

        training_df = self._load_training_data()
        clf = self._train_classifier(training_df)
        num_features = len(self.FEATURE_BANDS)

        onnx_clf = to_onnx(
            clf,
            initial_types=[("X", FloatTensorType((None, num_features)))],
            final_types=[
                ("label", UInt8TensorType((None, 1))),
                ("probability", FloatTensorType((None, 1))),
            ],
            options={"zipmap": False},  # `probability` is a matrix instead of dictionaries
            target_opset=9,  # ONNX operator set (opset) version to target
        )

        check_model(onnx_clf, full_check=True)

        model_name = f"{ge.get_session().user_id}:rf_{uuid.uuid4()}"

        ge.register_ml_model(
            onnx_model=onnx_clf,
            model_config=ge.ml.MlModelConfig(
                name=model_name,
                file_name="model.onnx",
                metadata=MlModelMetadata(
                    inputType=RasterDataType.F32,
                    outputType=RasterDataType.U8,
                    # per pixel: all feature bands in, one class band out
                    inputShape=MlTensorShape3D(x=1, y=1, bands=num_features),
                    outputShape=MlTensorShape3D(x=1, y=1, bands=1),
                    inputNoDataHandling=MlModelInputNoDataHandling(variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA),
                    outputNoDataHandling=MlModelOutputNoDataHandling(variant=MlModelOutputNoDataHandlingVariant.NANISNODATA),
                ),
                display_name="Random Forest",
                description="A simple random forest model",
            ),
        )

        model_workflow = ge.register_workflow(
            Onnx(
                source=self.sentinel_workflow.workflow_definition().operator,
                model=model_name,
            )
        )

        self.class_data = await model_workflow.raster_stream_into_xarray(
            query_rectangle=self.query_rectangle,
            clip_to_query_rectangle=True,
            bands=[0],
        )

In [None]:
# Show the Geo Engine favicon in the browser tab.
favicon = base64.b64encode(Path("assets/favicon.ico").read_bytes()).decode()

# Create the icon link if the page has none (instead of failing on `null.href`); the IIFE
# keeps `link` out of the global scope so the script can run more than once.
favicon_changer = IPython.display.Javascript(f"""
    (() => {{
        let link = document.querySelector("link[rel='icon']");
        if (link === null) {{
            link = document.createElement("link");
            link.rel = "icon";
            document.head.appendChild(link);
        }}
        link.href = "data:image/x-icon;base64,{favicon}";
    }})();
""")

IPython.display.display(favicon_changer)

In [None]:
%%HTML

<!-- Strip the notebook's own margins, paddings and borders so the app fills the page -->

<style>
  body {
    margin: 0 !important;
    padding: 0 !important;
  }
  div.jp-Cell,
  div.jp-Cell-inputWrapper,
  div.jp-Cell-outputWrapper,
  div.lm-Widget {
    border: none;
    margin: 0 !important;
    padding: 0 !important;
    overflow: auto;
    /* required for the app's fixed footer to be placed correctly */
    contain: none !important;
  }
</style>

In [None]:
class View(vue.App):
    """User interface of the app: three step cards and an error alert.

    1. "Set Location and Time": location selector, two class-label inputs and a query button.
    2. "Draw Training Data": the point labeling tool on top of the RGB image.
    3. "Train Model and Predict": the predicted classes on top of the RGB image.

    Buttons do nothing by themselves; callbacks are attached via the ``register_*`` methods.
    """

    def __init__(self, config: Config, title="Title"):
        """Build the widgets and list the locations of ``config`` in the selector.

        Args:
            config: Provides the selectable locations; the first one is preselected.
            title: Text shown in the app bar.
        """
        super().__init__()

        # Widgets are created per instance: as class attributes they (and the handlers
        # registered on them) would be shared by every `View`.
        self.location_btn = vue.Btn(children=["Query Sentinel-2 Data"])
        self.location_select = vue.Select(label="Location and Time")
        self.class1_input = vue.TextField(label="Class 1 Label")
        self.class2_input = vue.TextField(label="Class 2 Label")
        self.slot1 = vue.Content(
            children=[
                self.location_select,
                vue.Row(children=[vue.Col(children=[self.class1_input]), vue.Col(children=[self.class2_input])]),
                self.location_btn,
            ]
        )

        self.slot2 = vue.Html(
            tag="div",
            children=[
                vue.Html(tag="em", children=["Select a location and time to query Sentinel-2"]),
            ],
        )
        self.predict_btn = vue.Btn(children=["Train Model and Predict"])

        self.slot3 = vue.Html(
            tag="div",
            children=[
                vue.Html(tag="em", children=["You need to create training data first"]),
            ],
        )

        self.error_bar = vue.Alert(children=[], color="")

        # Figure of the latest prediction plot; closed when a new one replaces it.
        self._result_fig = None

        self.location_select.items = [
            {"text": str(location), "value": location.to_dict()} for location in config.locations
        ]
        self.location_select.v_model = config.locations[0].to_dict()

        self.children = [
            vue.AppBar(
                children=[
                    vue.Img(
                        src="assets/GeoEngine_Mainlogo_Line_Schutzraum.svg",
                        max_width=str(2048 / 12),
                        max_height=str(531 / 12),
                        contain=True,
                    ),
                    vue.Spacer(),
                    vue.ToolbarTitle(children=[title]),
                ],
            ),
            vue.Container(
                children=[
                    vue.Card(
                        children=[
                            vue.CardTitle(children=["Set Location and Time"]),
                            vue.CardSubtitle(
                                children=[
                                    "Select a location and time to query Sentinel-2's R,G,B and NDVI data. The data will be used to train a model. "  # noqa: E501
                                    "RGB bands serve as background for the labeling tool, NDVI is used as an additional feature."  # noqa: E501
                                ]
                            ),
                            vue.CardText(children=[self.slot1]),
                        ]
                    ),
                    vue.Html(tag="br"),
                    vue.Card(
                        children=[
                            vue.CardTitle(children=["Draw Training Data"]),
                            vue.CardSubtitle(
                                children=[
                                    "Here you can draw the training data for the model. Just click on the map to label the pixels. "  # noqa: E501
                                    "Switch between the classes with the buttons below the map."
                                ]
                            ),
                            vue.CardText(children=[self.slot2]),
                        ]
                    ),
                    vue.Html(tag="br"),
                    vue.Card(
                        children=[
                            vue.CardTitle(children=["Train Model and Predict"]),
                            vue.CardSubtitle(
                                children=[
                                    "We will train a simple random forest classifier on the training data and predict the classes. "  # noqa: E501
                                    "The predictor will use the RGB bands and the NDVI index to predict the class of each pixel."  # noqa: E501
                                ]
                            ),
                            vue.CardText(children=[self.slot3]),
                        ]
                    ),
                    vue.Html(tag="br"),
                    self.error_bar,
                    vue.Html(tag="br"),
                ]
            ),
            vue.Footer(children=[vue.Col(children=["2025 – Geo Engine GmbH"], class_="text-center")], app=True),
        ]

    def register_location_button(self, callback: Callable[[Location, str, str], None]) -> None:
        """Call ``callback(location, class_label_1, class_label_2)`` when the query button is clicked."""

        def on_click(_widget, _event, _data):
            location = Location.from_dict(self.location_select.v_model)
            callback(location, self.class1_input.v_model, self.class2_input.v_model)

        self.location_btn.on_event("click", on_click)

    def register_predict_button(self, callback: Callable[[], None]) -> None:
        """Call ``callback()`` when the predict button is clicked."""

        def on_click(_widget, _event, _data):
            callback()

        self.predict_btn.on_event("click", on_click)

    def disable_1(self) -> None:
        """Lock all inputs of step 1."""
        self.location_btn.disabled = True
        self.location_select.disabled = True
        self.class1_input.disabled = True
        self.class2_input.disabled = True

    def disable_2(self) -> None:
        """Disable the predict button."""
        self.predict_btn.disabled = True

    def enable_2(self) -> None:
        """Enable the predict button."""
        self.predict_btn.disabled = False

    def set_2_loading(self) -> None:
        """Show a progress bar in step 2."""
        self.slot2.children = [vue.ProgressLinear(indeterminate=True)]

    def set_3_loading(self) -> None:
        """Show a progress bar in step 3."""
        self.slot3.children = [vue.ProgressLinear(indeterminate=True)]

    def set_1_labels(self, class_label_1: str, class_label_2: str) -> None:
        """Fill the two class-label inputs of step 1."""
        self.class1_input.v_model = class_label_1
        self.class2_input.v_model = class_label_2

    def set_2_plot(
        self, *, filename: str, class_column: str, label_1: str, label_2: str, crs: str, background: xr.DataArray
    ) -> None:
        """Show the point labeling tool over the RGB bands of ``background`` in step 2.

        Points of ``label_1`` get class value 1 (blue), points of ``label_2`` value 0 (green);
        the labeling tool stores them in ``filename`` under ``class_column``.
        """
        point_labeling_tool = labeling.PointLabelingTool(
            filename=filename,
            class_column=class_column,
            classes={
                label_1: {
                    "value": 1,
                    "color": "blue",
                },
                label_2: {
                    "value": 0,
                    "color": "green",
                },
            },
            crs=crs,
            background=lambda ax: (
                ax.set_facecolor("black"),
                background.isel(time=0, band=[0, 1, 2]).plot.imshow(
                    rgb="band",
                    vmax=4000,
                    ax=ax,
                ),
            ),
            figsize=(10, 10),
        )
        self.slot2.children = [
            point_labeling_tool.children[0],
            point_labeling_tool.fig.canvas,
            self.predict_btn,
        ]

    def set_3_plot(self, data: xr.DataArray, classes: list[tuple[str, str]], background: xr.DataArray) -> None:
        """Show the predicted classes over the RGB bands of ``background`` in step 3.

        Args:
            data: Predicted class raster; its first band and time step are plotted.
            classes: ``(label, color)`` pairs, presumably ordered by class value (0 first).
            background: Raster whose first three bands are drawn as RGB image.
        """
        previous_fig = self._result_fig

        with plt.ioff():
            fig, ax = plt.subplots(
                figsize=(10, 10),
                constrained_layout=True,
            )

        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False

        ax.set_facecolor("black")
        background.isel(time=0, band=[0, 1, 2]).plot.imshow(
            rgb="band",
            vmax=4000,
            ax=ax,
        )

        data.isel(band=0, time=0).plot.imshow(
            alpha=0.6,
            add_colorbar=False,
            colors=[c for (_, c) in classes],
            levels=len(classes) + 1,
            vmax=len(classes),
            ax=ax,
        )

        legend_handles = [Circle((0.5, 0.5), 1, label=name, facecolor=color) for (name, color) in classes]
        ax.legend(
            handles=legend_handles,
            loc="center left",
            bbox_to_anchor=(1, 0.5),
        )

        self.slot3.children = [fig.canvas]
        self._result_fig = fig

        # Repeated predictions would otherwise keep every old figure open in pyplot.
        if previous_fig is not None:
            plt.close(previous_fig)

    def set_error(self, error: str | None) -> None:
        """Show ``error`` in the error alert, or hide the alert when ``error`` is None."""
        if error is None:
            self.error_bar.children = []
            self.error_bar.color = ""
            return
        self.error_bar.children = [error]
        self.error_bar.color = "error"


# Light theme colors of the app.
vue.theme.themes.light.secondary = "#0087bc"
vue.theme.themes.light.primary = "#3fab39"

# Build the app and display it as this cell's output.
__view = View(__config, title="Simple Random Forest Two-Class Classifier on Sentinel-2 Images")
__view

View(children=[AppBar(children=[Img(contain=True, layout=None, max_height='44.25', max_width='170.666666666666…

In [None]:
# Stages of the app, in the order the user goes through them (values 1-5).
# `ViewController` renders the view according to the current stage.
Mode = Enum(
    "Mode",
    ["SETUP", "LABELING_LOADING", "LABELING", "CLASSIFICATION_LOADING", "CLASSIFICATION"],
)


class ViewController:
    view: View
    model: Model
    mode: Mode = Mode.SETUP

    def __init__(self, view: View, model: Model):
        self.view = view
        self.model = model

        ge.initialize(server_url=self.model.config.instance_url, token=self.model.config.instance_session_token)

    def update(self) -> None:
        """Update the view based on the current mode"""

        self._wrap_error(self._update)()

    def _wrap_error(self, f: Callable) -> Callable:
        """Wrap a function to catch exceptions and set the error message"""

        def wrapped(*args, **kwargs):
            try:
                f(*args, **kwargs)
            except Exception as e:
                self.view.set_error(str(e))

        return wrapped

    def _wrap_error_async(self, f: Coroutine) -> Coroutine:
        """Wrap an async function to catch exceptions and set the error message"""

        async def wrapped(*args, **kwargs):
            try:
                await f(*args, **kwargs)
            except Exception as e:
                self.view.set_error(str(e))

        return wrapped

    def _update(self) -> None:
        """Update the view based on the current mode"""

        match self.mode:
            case Mode.SETUP:

                def _setup(location: Location, class_value_1: str, class_value_2: str) -> None:
                    self.model.class_value_1 = class_value_1
                    self.model.class_value_2 = class_value_2
                    self.model.set_location(location)
                    self.mode = Mode.LABELING_LOADING
                    self.update()

                self.view.set_1_labels(self.model.class_value_1, self.model.class_value_2)
                self.view.register_location_button(self._wrap_error(_setup))
            case Mode.LABELING_LOADING:

                async def _load_sentinel_data() -> None:
                    await self.model.load_sentinel_data()
                    self.mode = Mode.LABELING
                    self.update()

                self.view.disable_1()
                self.view.set_2_loading()
                asyncio.create_task(self._wrap_error_async(_load_sentinel_data)())
            case Mode.LABELING:

                def _train_and_predict() -> None:
                    self.mode = Mode.CLASSIFICATION_LOADING
                    self.update()

                self.view.set_2_plot(
                    filename=self.model.training_data_file,
                    class_column=self.model.class_column,
                    label_1=self.model.class_value_1,
                    label_2=self.model.class_value_2,
                    crs=self.model.query_rectangle.srs,
                    background=self.model.sentinel_data,
                )
                self.view.enable_2()
                self.view.register_predict_button(self._wrap_error(_train_and_predict))
            case Mode.CLASSIFICATION_LOADING:

                async def _train_and_predict() -> None:
                    await self.model.train_and_predict()
                    self.mode = Mode.CLASSIFICATION
                    self.update()

                self.view.disable_2()
                self.view.set_3_loading()
                asyncio.create_task(self._wrap_error_async(_train_and_predict)())
            case Mode.CLASSIFICATION:
                self.view.enable_2()
                self.view.set_3_plot(
                    self.model.class_data,
                    classes=[
                        (self.model.class_value_2, "green"),
                        (self.model.class_value_1, "blue"),
                    ],
                    background=self.model.sentinel_data,
                )


# Wire the view to a fresh model and render the initial setup step.
__controller = ViewController(view=__view, model=Model(config=__config))
__controller.update()