In [None]:
import datacube
import xarray as xr
from joblib import load
import matplotlib.pyplot as plt
from datacube.utils.cog import write_cog
import pandas as pd
import geopandas as gpd
import numpy as np
import os
import json
import pickle

from deafrica_tools.datahandling import load_ard
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.plotting import rgb, display_map
from deafrica_tools.classification import predict_xr
from deafrica_tools.spatial import xr_rasterize

from datacube.utils import geometry
from datacube.utils.cog import write_cog

from odc.io.cgroups import get_cpu_quota
from odc.algo import geomedian_with_mads, xr_geomedian

from feature_collection import feature_layers

## Create a Dask cluster for running predictions

In [None]:
# Size the single Dask worker to the CPU quota available to this container.
# get_cpu_quota() returns None when no cgroup CPU limit is set, and the quota can be
# fractional (e.g. 0.5 rounds to 0), so fall back to the host CPU count and never
# ask for fewer than one thread.
cpu_quota = get_cpu_quota()
if cpu_quota is None:
    ncpus = os.cpu_count() or 1
else:
    ncpus = max(1, round(cpu_quota))
print("ncpus = " + str(ncpus))

client = create_local_dask_cluster(return_client=True, n_workers=1, threads_per_worker=ncpus)

## Read in training data feature names and class labels

In [None]:
# Identifier of the trained experiment; it prefixes the saved model (.joblib)
# and the saved feature list (.json) under results/.
experiment_name = (
    "exp_multipixel_allfeatures_removecorrfeaturesgt0p9_RandomForest"
)

In [None]:
# Class-label lookup written out by the training notebook.
labels_path = "results/class_labels.json"
with open(labels_path, "r") as labels_fh:
    labels_dict = json.load(labels_fh)

# Names of the feature layers the model was trained on; the prediction loop
# keeps only these variables from the loaded feature data.
features_path = f"results/{experiment_name}_features.json"
with open(features_path, "r") as features_fh:
    features_dict = json.load(features_fh)

features = features_dict["features"]

## Load trained ML model and shapefile for prediction

To manage memory, we provide a vector file (here a GeoPackage) that splits the area of interest into tiles; the prediction loop below processes these tiles one at a time.

In [None]:
# Trained classifier for this experiment (earlier runs used "results/randomforest_model.joblib").
# n_jobs=1 keeps the estimator single-threaded, presumably so that parallelism
# comes from the Dask cluster rather than from the model itself.
model_path = f"results/{experiment_name}.joblib"
model = load(model_path)
model = model.set_params(n_jobs=1)

# Tile polygons to predict over (earlier runs used "data/gridded_province.shp");
# each row is one tile processed by the prediction loop.
districts_file = "data/area_redo_update.gpkg"
districts_gdf = gpd.read_file(districts_file)

# Folder that receives the per-tile prediction and probability GeoTIFFs.
results_path = "data/area_redo"

In [None]:
# Inspect the tile polygons that the prediction loop will iterate over (one row per tile).
districts_gdf

## Create the query for running the predictions
This uses the existing query from the training data collection notebook, and adds `dask_chunks` as an additional parameter.

In [None]:
# Re-use the datacube query saved when the training data was collected, so that
# predictions use the same load parameters as the training features.
# NOTE: pickle.load can execute arbitrary code; only load query files you created yourself.
query_file = "results/query.pickle"

# Chunk sizes that let Dask load and predict each tile in parallel pieces.
dask_chunks = {"x": 2000, "y": 2000}

with open(query_file, "rb") as query_fh:
    query = pickle.load(query_fh)

query["dask_chunks"] = dask_chunks

query

## Run model over grids

The model is run for each tile in the file, producing a prediction file and a probabilities file. These are saved to the folder set in `results_path`. The next notebook then combines the separate files into a single raster map.

If an area has already been processed, it is skipped, so prediction resumes with any incomplete tiles. This is useful if the process fails partway through, or if you are logged out of the sandbox before it completes.

In [None]:
area_of_interest_gdf = districts_gdf
district_column = "id"

# Value written to every pixel outside the crop mask or outside the tile polygon.
# 255 fits in the uint8 output alongside the class codes / probabilities.
NODATA = 255

dc = datacube.Datacube(app="crop_type_ml")

# Make sure the output folder exists before the first write.
os.makedirs(results_path, exist_ok=True)


def _write_masked(da, valid, fname):
    """Replace pixels where ``valid`` is False with NODATA, compute, and write ``da`` as a COG to ``fname``."""
    masked = da.where(valid, NODATA).compute()
    masked.attrs["nodata"] = NODATA
    write_cog(masked, fname=fname, overwrite=True, nodata=NODATA)


for _, district in area_of_interest_gdf.iterrows():

    district_name = str(int(district[district_column]))
    print(f"Processing {district_name}")

    prediction_file = f"{results_path}/district_{district_name}_croptype_prediction.tif"
    probabilities_file = f"{results_path}/district_{district_name}_croptype_probabilities.tif"

    # Skip only when BOTH outputs exist: predictions are written before probabilities,
    # so a run interrupted between the two writes must redo this tile.
    if os.path.exists(prediction_file) and os.path.exists(probabilities_file):
        print("Completed; Skipping")
        continue

    # Per-tile copy of the query so the shared `query` from the earlier cell is not mutated.
    geom = geometry.Geometry(geom=district.geometry, crs=area_of_interest_gdf.crs)
    district_query = query.copy()
    district_query["geopolygon"] = geom

    # Load the feature data, keeping only the model's features *before* persisting
    # so layers the model does not use are never computed.
    print("    Loading feature data")
    data = feature_layers(district_query)[features].persist()

    # Predict classes and class probabilities with the imported model.
    predicted = predict_xr(
        model,
        data.unify_chunks(),
        proba=True,
        persist=True,
        clean=True,
        return_input=False,
    ).astype(np.uint8).persist()

    # The crop mask is loaded for the same area but the 2019 time range.
    crop_mask_query = district_query.copy()
    crop_mask_query.update({"time": "2019"})

    print("    Loading crop_mask")
    crop_mask = dc.load(product="crop_mask", **crop_mask_query)

    # Rasterize the tile polygon onto the prediction grid.
    print("    Getting district mask")
    district_mask = xr_rasterize(
        gdf=gpd.GeoDataFrame({"DISTRICT": [district_name], "geometry": [district.geometry]}, crs=area_of_interest_gdf.crs),
        da=predicted,
        transform=predicted.geobox.transform,
        crs="EPSG:6933",
    )

    # Pixels to keep: flagged 1 in the crop mask's `filtered` band (presumably cropland)
    # and inside the tile polygon. Built once and shared by both outputs.
    valid = ((crop_mask.filtered == 1) & (district_mask == 1)).persist()

    print("    Preparing predictions")
    print(f"    Writing predictions to {prediction_file}")
    _write_masked(predicted.Predictions, valid, prediction_file)

    print(f"    Writing probabilities to {probabilities_file}")
    _write_masked(predicted.Probabilities, valid, probabilities_file)

    # Release this tile's persisted arrays before moving on to the next tile.
    del data, predicted, crop_mask, district_mask, valid

    

## Close the dask client

In [None]:
# Disconnect the Dask client created at the start of the notebook.
client.close()