# ML Pipeline

In [37]:
import geoengine as ge

from geoengine_openapi_client.models import MlModelMetadata, RasterDataType

from sklearn.tree import DecisionTreeClassifier
import numpy as np
from skl2onnx import to_onnx


In [38]:
ge.initialize("http://localhost:3030/api")

## Train a dummy model (TODO: feed with data from Geo Engine)

In [39]:
np.random.seed(0) 
X = np.random.rand(100, 2).astype(np.float32)  # 100 instances, 2 features
y = np.where(X[:, 0] > X[:, 1], 42, 33)  # 1 if feature 0 > feature 42, else 33

clf = DecisionTreeClassifier()
clf.fit(X, y)

test_samples = np.array([[0.1, 0.2], [0.2, 0.1]])
predictions = clf.predict(test_samples)
print("Predictions:", predictions)

# Convert into ONNX format.
from skl2onnx import to_onnx

onx = to_onnx(clf, X[:1], target_opset=9) # target_opset is the ONNX version to use

Predictions: [33 42]


## Register it with Geo Engine

In [40]:
model_name = f"{ge.get_session().user_id}:decision_tree"

metadata = MlModelMetadata(
    file_name="model.onnx",
    input_type=RasterDataType.U8,
    num_input_bands=2,
    output_type=RasterDataType.U8,
)

ge.register_ml_model(onnx_model=onx, name=model_name, metadata=metadata, display_name="Decision Tree", description="A simple decision tree model")

# Apply model using the ONNX operator

In [41]:
bands = [ge.workflow_builder.operators.GdalSource("ndvi"),
         ge.workflow_builder.operators.TimeShift(source=ge.workflow_builder.operators.GdalSource("ndvi"), 
                                                 shift_type="relative", 
                                                 granularity="months", 
                                                 value=1)]

stack = ge.workflow_builder.operators.RasterStacker(sources = bands)

onnx = ge.workflow_builder.operators.Onnx(source=stack, model=model_name)

workflow_dict = onnx.to_workflow_dict()

workflow = ge.register_workflow(workflow_dict)

query = ge.QueryRectangle(
    ge.BoundingBox2D(-111.533203125, -4.482421875, 114.345703125, 73.388671875),
    ge.TimeInterval(np.datetime64('2014-04-01')),
    ge.SpatialResolution(0.1, 0.1)
)

data = workflow.get_xarray(
    query
)

data.plot()

BadRequestException: Operator: Operator: MachineLearning error: Raster data types of source (U8) does not match model input type (F32).