![Digital Earth Pacific](dep.png)

### Digital Earth Pacific Notebook 2: Predict with a Random Forest Machine Learning (ML) Model

The objective of this notebook is to run the prediction based on the machine learning model that you trained in notebook 1. 

In this notebook you will work through the following steps:

1. **Setting up your area of interest**  
2. **Setting up your time of interest** 
3. **Running the trained model from notebook 1 to predict a class for every pixel within your area of interest at your time of interest**  
4. **Exploring the model outputs in an interactive map**  

In [None]:
import geopandas as gpd
import joblib
import numpy as np
import odc.geo.xr  # noqa
from dask import config
from dask.distributed import Client as dask_client
from odc.stac import load
from pystac_client import Client
from shapely import geometry
import depal_mh as dep
from depal_mh import predict_xr
from model import get_overlay
from ipyleaflet import Map, LayersControl, basemaps

In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so edits to local helper modules are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Set the Dask `distributed.worker.daemon` option to False.
# NOTE(review): presumably so worker processes may start child processes
# of their own — confirm this is still required.
config.set({'distributed.worker.daemon': False})

## Find and load data

Load the data and set up the array that will be used for prediction.

In [None]:
# Up-front settings shared by the cells below: Dask chunk sizes, the year
# of interest and the country code used to build file names.
chunks = {"x": 2048, "y": 2048}
year = "2023"
country_code = "mh"
model_file_name = f"{country_code}_lulc.model"

# Area of interest: look up the admin boundary, take its bounding box and
# turn that box into a shapely geometry.
aoi = dep.get_country_admin_boundary("Marshall Islands", "Atoll", "Majuro")
bbox = dep.get_bbox(aoi)
bbox_geometry = geometry.box(*bbox)

# Show the bounding box on an interactive map.
gdf = gpd.GeoDataFrame({"geometry": [bbox_geometry]}, crs="EPSG:4326")
gdf.explore()

In [None]:
catalog = "https://stac.staging.digitalearthpacific.org"
client = Client.open(catalog)

# Search for Sentinel-2 GeoMAD data covering the bounding box and year of
# interest. The results are materialised into a list so that an empty
# search can be reported clearly instead of failing later inside `load`.
items = list(
    client.search(
        collections=["dep_s2_geomad"],
        bbox=bbox,
        datetime=year,
    ).items()
)
if not items:
    raise ValueError(
        f"No dep_s2_geomad items found in {catalog} "
        f"for bbox={bbox} and datetime={year}"
    )

# Load the data lazily using the configured chunk sizes, then drop the
# "time" dimension (this expects the search to yield a single time step).
data = load(items, chunks=chunks, bbox=bbox, resolution=10).squeeze("time")

# Coastal clip
data = dep.do_coastal_clip(aoi, data, buffer=0)
data

In [None]:
# Load the trained model stored under `model_file_name`.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load model files from a source you trust.
loaded_model = joblib.load(model_file_name)

## Run the prediction (this takes some time)

In [None]:
# Replace missing values (NaN) with -9999.0 before handing the data to the model.
filled = data.fillna(-9999.0)

# Run the actual prediction while a Dask distributed client is active for
# the duration of this block.
with dask_client(n_workers=8, threads_per_worker=8, memory_limit="12GB"):
    predicted = predict_xr(loaded_model, filled, proba=True)

    # Work on a deep copy and cast each variable: predicted classes to
    # int8 and class probabilities to float32.
    cleaned_predictions = predicted.copy(deep=True)
    cleaned_predictions["predictions"].data = (
        predicted["predictions"].data.astype(np.int8)
    )
    cleaned_predictions["probabilities"].data = (
        predicted["probabilities"].data.astype(np.float32)
    )

    # Shorten the variable names and compute the result while the client
    # is still open.
    cleaned_predictions = cleaned_predictions.rename(
        predictions="lulc", probabilities="prob"
    ).compute()

cleaned_predictions

In [None]:
from matplotlib import colors

# Class definitions for the land use / land cover map. Each entry is indexed
# below with element 0 as the class value and element 2 as its colour.
classes = dep.get_lulc_class_colours()

values_list = [c[0] for c in classes]
color_list = [c[2] for c in classes]

# Build a listed colormap.
c_map = colors.ListedColormap(color_list)

# BoundaryNorm needs one more boundary than there are colours. Close the last
# bin just above the last class value rather than at len(classes): the two are
# the same when class values run 0..n-1, but only this form stays valid when
# the class values start at another number.
bounds = values_list + [values_list[-1] + 1]
norm = colors.BoundaryNorm(bounds, c_map.N)

cleaned_predictions.lulc.plot.imshow(cmap=c_map, norm=norm, size=10)

## Save Generated Output as GeoTIFF

In [None]:
# Save the classified layer as "<country_code>_lulc_<year>.tif", replacing
# any existing file of the same name.
cleaned_predictions.lulc.odc.write_cog(
    f"{country_code}_lulc_{year}.tif",
    overwrite=True,
)