In [None]:
import os
import pickle

import datacube
import geopandas as gpd
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
from datacube.utils import geometry
from datacube.utils.cog import write_cog
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.classification import predict_xr, sklearn_flatten, sklearn_unflatten
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.datahandling import load_ard
from deafrica_tools.plotting import rgb
from deafrica_tools.spatial import xr_rasterize
from deafrica_tools.temporal import temporal_statistics, xr_phenology
from feature_extraction import feature_layers
from sklearn.cluster import DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from skimage.measure import label, regionprops

In [None]:
# Call the deafrica_tools Dask helper before any data is loaded (presumably starts
# a local Dask cluster, per its name); the query defined below sets `dask_chunks`.
create_local_dask_cluster()

In [None]:
# Load the previously saved model; its `cluster_centers_` attribute is used later
# to determine the number of classes.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code --
# only load model files that come from a trusted source.
model_file = "results/ml_model.joblib"
sklearn_model = joblib.load(model_file)

# Read the district polygons; the loop further down iterates over these rows and
# uses each row's "DISTRICT" value and geometry.
districts_file = "data/cropping_propotion_by_district.geojson"
districts_gdf = gpd.read_file(districts_file)

In [None]:
# Display the districts table for a visual check of the loaded polygons
districts_gdf

## Estimate proportion of each class per district

In [None]:
# Connect to data cube
dc = datacube.Datacube(app="crop_type_ml")

# Write a general query
# These are the base parameters reused for each district's loads in the loop
# below (the crop-mask load overrides `time`).
time = "2021"
# NOTE(review): presumably 20 m pixels in the units of `output_crs` -- confirm
resolution = (-20, 20)
output_crs = "EPSG:6933"

query = {
    "time": time,
    "resolution": resolution,
    "output_crs": output_crs,
    # One time step per chunk, 2000 x 2000 pixel spatial chunks
    "dask_chunks": {"time": 1, "x": 2000, "y": 2000},
}

In [None]:
# function to remove areas smaller than 10 pixels

def identify_and_filter_regions(prediction_xr, pixel_count_threshold=10, connectivity=1):
    """Set connected same-class regions smaller than a pixel threshold to NaN.

    Pixels are grouped into connected regions of equal class value with
    ``skimage.measure.label``. Regions containing at least
    ``pixel_count_threshold`` pixels keep their class value; every other
    pixel (including pixels that were NaN on input) is NaN in the result.

    Parameters
    ----------
    prediction_xr : xarray.DataArray
        Class values, with NaN marking pixels that have no prediction.
    pixel_count_threshold : int, optional
        Minimum number of pixels a connected region must contain to be kept
        (regions with ``area >= pixel_count_threshold`` are retained).
    connectivity : int, optional
        Passed through to ``skimage.measure.label``. The default of 1 reduces
        the chance of keeping long connected roads/edges and small connected
        patches.

    Returns
    -------
    xarray.DataArray
        float64 array carrying the coords, dims and attrs of ``prediction_xr``.
    """
    # Work on a float64 numpy copy so NaN can be used as the "removed" value
    class_values = np.asarray(prediction_xr.to_numpy(), dtype=np.float64)

    # Shift classes up by one so class 0 is not confused with the background
    # value (0) used by skimage, then send NaN pixels to that background.
    shifted = np.nan_to_num(class_values + 1, nan=0)

    # Label connected regions of equal (shifted) class value; background gets label 0
    labelled = label(shifted, connectivity=connectivity, background=0)

    # Pixel count of every label in a single pass; index == label id.
    # minlength=1 keeps index 0 valid even for an empty input.
    region_sizes = np.bincount(labelled.ravel(), minlength=1)

    # Per-label keep flag; the background label is never kept
    keep_label = region_sizes >= pixel_count_threshold
    keep_label[0] = False

    # Look the flag up for every pixel to build the mask of regions to keep
    mask = keep_label[labelled]

    # Keep original class values inside retained regions, NaN elsewhere
    prediction_masked = np.where(mask, class_values, np.nan)

    # Reformat as an xarray and return
    return xr.DataArray(
        prediction_masked,
        coords=prediction_xr.coords,
        dims=prediction_xr.dims,
        attrs=prediction_xr.attrs,
    )

In [None]:
# For every district: load features and the crop mask, predict a class per pixel,
# remove small regions, write the prediction raster, and write a CSV with the
# pixel count and proportion of each class.
area_of_interest_gdf = districts_gdf
district_column = "DISTRICT"
# One class per cluster centre of the fitted model
n_classes = sklearn_model.cluster_centers_.shape[0]

classes = [f"class_{i}" for i in range(n_classes)]

for index, district in area_of_interest_gdf.iterrows():

    # Set up geometry
    district_name = district[district_column]
    print(f"Processing {district_name}")

    # Check if district has already been processed. If so, skip.
    # The proportions CSV is the last file written for a district, so its
    # presence marks the district as complete.
    output_filename = f"data/district_{district_name}_cropping_propotion_by_class_filtered.csv"
    if os.path.exists(output_filename):
        print("Completed; Skipping")
        continue

    # Build a per-district query from the shared parameters plus the district
    # polygon, leaving the shared `query` dict itself untouched.
    geom = geometry.Geometry(geom=district.geometry, crs=area_of_interest_gdf.crs)
    district_query = {**query, "geopolygon": geom}

    # Load the feature data
    print("    Loading feature data")
    feature_data = feature_layers(district_query).squeeze(dim="time", drop=True).load()

    # Load the crop mask: same query, but for the crop mask's own time period
    crop_mask_query = {**district_query, "time": "2019"}
    print("    Loading crop_mask")
    crop_mask = dc.load(product="crop_mask_southeast", **crop_mask_query).load()

    # Create a mask for the district
    district_mask = xr_rasterize(
        gdf=gpd.GeoDataFrame({"DISTRICT": [district_name], "geometry": [district.geometry]}, crs=area_of_interest_gdf.crs),
        da=feature_data,
        transform=feature_data.geobox.transform,
        crs=output_crs,
    )

    # Pixels of interest: flagged as crop AND inside the district polygon
    crop_in_district = (crop_mask.filtered == 1) & (district_mask == 1)

    # Clip the feature data to the crop mask and the district mask
    district_crop_data = feature_data.where(crop_in_district).squeeze()

    # Predict the classes
    district_crop_class = predict_xr(sklearn_model, district_crop_data, persist=True)

    # Mask the predictions to the same crop/district pixels
    print("    Preparing predictions")
    district_crop_class_masked = district_crop_class.Predictions.where(crop_in_district)

    district_crop_class_masked = (
        district_crop_class_masked
        .squeeze()
        .assign_attrs({"crs": output_crs})
        .rename("class")
        # drop_vars replaces the deprecated DataArray.drop
        .drop_vars(["time", "spatial_ref"])
    )

    # Filter small regions
    print("    Filtering regions with fewer than 10 pixels (at connectivity=1)")
    district_crop_class_masked_filtered = identify_and_filter_regions(district_crop_class_masked)

    # Write to cog
    prediction_file = f"data/district_{district_name}_prediction_filtered.tif"
    print(f"    Writing predictions to {prediction_file}")
    write_cog(
        district_crop_class_masked_filtered,
        fname=prediction_file,
        overwrite=True,
    )

    # Count the pixels assigned to each class (NaN pixels match no class)
    print("    Counting pixels in each class")
    district_class_counts = [
        (district_crop_class_masked_filtered == i).sum().item()
        for i in range(n_classes)
    ]
    total_pixels = sum(district_class_counts)

    district_crop_propotions_df = pd.DataFrame({'class': classes, 'count': district_class_counts})
    # If no pixels survived filtering, total_pixels is 0 and the proportions are NaN
    district_crop_propotions_df["proportion"] = district_crop_propotions_df["count"]/total_pixels

    # Written last so that an interrupted district is re-processed on the next run
    district_crop_propotions_df.to_csv(output_filename)