In [None]:
import json
import os
import pickle

import datacube
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from datacube.utils import geometry
from datacube.utils.cog import write_cog
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.classification import predict_xr
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.datahandling import load_ard
from deafrica_tools.plotting import display_map, rgb
from deafrica_tools.spatial import xr_rasterize
from feature_collection import feature_layers
from joblib import load
from odc.algo import geomedian_with_mads, xr_geomedian
from odc.io.cgroups import get_cpu_quota

## Create Dask cluster for running predictions

In [None]:
# Size the Dask worker's thread pool from the container CPU quota.
# get_cpu_quota() returns None when it finds no quota (e.g. running outside a
# cgroup-limited container), which would make round() raise; a fractional
# quota below 0.5 would round to 0 threads. Fall back to the machine CPU count
# and never request fewer than one thread.
cpu_quota = get_cpu_quota()
if cpu_quota is not None:
    ncpus = max(1, round(cpu_quota))
else:
    ncpus = os.cpu_count() or 1
print("ncpus = " + str(ncpus))

client = create_local_dask_cluster(return_client=True, n_workers=1, threads_per_worker=ncpus)

## Read in training data feature names and class labels

In [None]:
# Identifier of the trained experiment; it selects the model file, the
# feature-list file and the names of the saved result figures
experiment_name = (
    "exp_multipixel_allfeatures_removecorrfeaturesgt0p9_RandomForest"
)

In [None]:
# Class label dictionary; presumably maps class name -> integer class code
# (this is how the plotting cell uses it for colorbar ticks) — TODO confirm
labels_path = "results/class_labels.json"
# JSON is UTF-8 by spec; be explicit so the read does not depend on the
# platform's default encoding
with open(labels_path, "r", encoding="utf-8") as json_file:
    labels_dict = json.load(json_file)

# Names of the features the model was trained on, in training order
features_path = f"results/{experiment_name}_features.json"
with open(features_path, "r", encoding="utf-8") as json_file:
    features_dict = json.load(json_file)

features = features_dict["features"]

In [None]:
# Number of features the model expects as input
n_features = len(features)
n_features

## Load trained ML model and areas to test

In [None]:
# Choose model and load
# Trained classifier for this experiment.
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code —
# only load model files produced by this project.
model_path = f"results/{experiment_name}.joblib"
model = load(model_path)
# Run the estimator single-threaded (set_params updates it in place);
# presumably the Dask cluster supplies the parallelism instead
model.set_params(n_jobs=1)

# Polygons of the areas the model is evaluated over
test_areas_file = "data/border_striping_testarea.gpkg"
test_areas_gdf = gpd.read_file(test_areas_file)

In [None]:
# Display the loaded estimator and its parameters
model

In [None]:
importances = model.feature_importances_

# The feature-name list must line up with the model's importances, otherwise
# the bars would silently be labelled with the wrong feature names
assert len(importances) == len(features), (
    f"model has {len(importances)} feature importances but "
    f"{len(features)} feature names were loaded"
)

# Ascending order, so the most important feature ends up at the top of the
# horizontal bar chart
order = np.argsort(importances)
sorted_names = np.array(features)[order]
positions = np.arange(len(order))

fig, ax = plt.subplots(figsize=(10, 8))
ax.barh(positions, importances[order])
ax.set_xlabel("Importance", fontsize=14)
ax.set_ylabel("Feature", fontsize=14)
ax.set_yticks(positions)
ax.set_yticklabels(sorted_names, rotation=0, fontsize=8)
ax.set_ylim(-1, len(order))
ax.set_xlim(0, importances.max() + 0.001)
ax.set_title(f"Feature importance: {experiment_name}")

fig.savefig("feature_importance_finalmodel_all.png", bbox_inches="tight")

## Create the query for running the predictions
This uses the existing query from the training data collection notebook, and adds `dask_chunks` as an additional parameter.

In [None]:
# Load the query used for fitting
# Reuse the data query saved by the training-data collection notebook.
# NOTE(review): pickle.load can execute arbitrary code — only load query files
# produced by this project's own notebooks.
query_file = "results/query.pickle"
with open(query_file, "rb") as pickle_file:
    query = pickle.load(pickle_file)

# Prediction-time additions: dask_chunks makes the load lazy and chunked so
# it can be computed in parallel on the Dask cluster
dask_chunks = {"x": 1000, "y": 1000}
query["dask_chunks"] = dask_chunks

query

## Run model over test areas

In [None]:
dc = datacube.Datacube()

# Fill value for pixels removed by the crop mask or lying outside the region polygon
NODATA = np.nan

predictions = []
area_of_interest_gdf = test_areas_gdf

for index, district in area_of_interest_gdf.iterrows():

    print("working on test region " + str(index))

    # Per-region copy of the query: the shared `query` is left untouched, so
    # no region-specific geopolygon leaks into later cells or re-runs
    geom = geometry.Geometry(geom=district.geometry, crs=area_of_interest_gdf.crs)
    region_query = {**query, "geopolygon": geom}

    # Calculate all features, then keep only those the model was trained on
    data = feature_layers(region_query)
    print(len(data.data_vars))
    data = data[features]
    print(len(data.data_vars))

    # Predict classes and probabilities; return_input=True keeps the input
    # feature layers in the output alongside the predictions
    predicted = predict_xr(
        model, data, proba=True, persist=True, clean=True, return_input=True
    ).persist()

    # Load the 2019 crop_mask product over the same region
    print("    Loading crop_mask")
    crop_mask_query = {**region_query, "time": "2019"}
    crop_mask = dc.load(product="crop_mask", **crop_mask_query)

    # Rasterise the region polygon onto the prediction grid (non-zero inside).
    # NOTE(review): assumes the query's output CRS is EPSG:6933 — TODO confirm
    print("    Getting district mask")
    district_mask = xr_rasterize(
        gdf=gpd.GeoDataFrame({"DISTRICT": [index], "geometry": [district.geometry]}, crs=area_of_interest_gdf.crs),
        da=predicted,
        transform=predicted.geobox.transform,
        crs="EPSG:6933",
    )

    # Mask the predictions to pixels where the crop mask's `filtered` band is 1
    # AND that lie inside the region polygon. The load covers the polygon's
    # extent, so without the district mask pixels outside the polygon survive.
    print("    Preparing predictions")
    keep = (crop_mask.filtered == 1) & district_mask.astype(bool)
    predicted_masked = predicted.where(keep, NODATA).compute()

    predicted_masked.attrs["nodata"] = NODATA

    predictions.append(predicted_masked)

## Visualise predictions over test areas

The resulting figures are saved to the "results" folder.

In [None]:
# Class names and codes ordered by code, so colorbar ticks line up with names.
# Presumes labels_dict maps class name -> integer code — TODO confirm.
class_codes = sorted(labels_dict.values())
class_names = sorted(labels_dict, key=labels_dict.get)
min_code, max_code = class_codes[0], class_codes[-1]

# One discrete colour per integer code between min and max; limits sit 0.5
# outside the code range so every colour bin is centred on an integer code
# (a hard-coded vmax only works for one specific number of classes)
cmap = plt.get_cmap("Set3", max_code - min_code + 1)
vmin, vmax = min_code - 0.5, max_code + 0.5

figures = [
    plt.subplots(1, 3, figsize=(18, 5), gridspec_kw={"width_ratios": [1, 0.90, 1]})
    for _ in range(len(predictions))
]

for i, (fig, axes) in enumerate(figures):
    region = predictions[i]

    # Plot the classes without casting to int: masked pixels are NaN and
    # render blank, whereas NaN cast to int becomes a huge negative value that
    # is drawn in the first class colour
    mat = region.Predictions.plot(
        ax=axes[0],
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        add_labels=False,
        add_colorbar=False,
    )

    # Colorbar with one tick per class code, labelled with the class name
    cbar = plt.colorbar(
        mat,
        ticks=class_codes,
        ax=axes[0],
        fraction=0.046,
        pad=0.04,
    )
    cbar.ax.set_yticklabels(class_names)

    # NDVI input feature layer (returned alongside the predictions) for context
    region["NDVI_s2_Q1_2022"].plot(ax=axes[1], add_colorbar=False)
    axes[1].set_xlabel("")
    axes[1].set_ylabel("")

    region.Probabilities.plot(
        ax=axes[2],
        cmap="magma",
        vmin=0,
        vmax=100,
        add_labels=False,
        add_colorbar=True,
        cbar_kwargs={"fraction": 0.046, "pad": 0.04},
    )

    # Remove ticks and tick labels on all panels
    for ax in axes:
        ax.set_aspect("equal")
        ax.tick_params(
            axis="both",
            which="both",
            top=False,
            bottom=False,
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
        )

    axes[0].set_title("Classified Image")
    axes[1].set_title("NDVI_s2_Q1_2022")
    axes[2].set_title("Probabilities")

    # Save each finished figure to the results folder
    fig.savefig(
        f"results/{experiment_name}_test_region_{i}.png",
        dpi=300,
        bbox_inches="tight",
        facecolor="white",
    )

## Close the dask client

In [None]:
# Close the Dask client now that all predictions have been computed
client.close()