In [None]:
import os
import pickle

import datacube
import geopandas as gpd
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
from datacube.utils import geometry
from datacube.utils.cog import write_cog
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.classification import predict_xr, sklearn_flatten, sklearn_unflatten
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.datahandling import load_ard
from deafrica_tools.plotting import rgb
from deafrica_tools.spatial import xr_rasterize
from deafrica_tools.temporal import temporal_statistics, xr_phenology
from feature_extraction import feature_layers
from sklearn.cluster import DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sampling_functions import stratified_random_sampling

In [None]:
# Load the district polygons along with their cropped-area attributes (GeoJSON)
crop_prop_district_gdf = gpd.read_file("data/cropping_propotion_by_district.geojson")

crop_prop_district_gdf

In [None]:
# Allocate the overall sampling budget across districts in proportion to cropped area.
# The budget is n_samples_baseline per district; it is then redistributed by each
# district's share of the total crop area, so larger cropping districts get more samples.
crop_area_total = crop_prop_district_gdf["crop_area_km2"].sum()
assert crop_area_total > 0, "crop_area_km2 must sum to a positive value"

crop_prop_district_gdf["crop_proportion_of_province"] = (
    crop_prop_district_gdf["crop_area_km2"] / crop_area_total
)

n_samples_baseline = 100
total_samples = n_samples_baseline * len(crop_prop_district_gdf)

# Largest-remainder (Hamilton) apportionment: floor each district's share, then give the
# samples lost to flooring to the districts with the largest fractional remainders.
# Plain truncation (astype(int)) would under-allocate by up to one sample per district.
raw_allocation = crop_prop_district_gdf["crop_proportion_of_province"] * total_samples
n_sample = np.floor(raw_allocation).astype(int)
shortfall = int(total_samples - n_sample.sum())
top_up_index = (
    (raw_allocation - n_sample)
    .sort_values(ascending=False, kind="stable")
    .index[:shortfall]
)
n_sample.loc[top_up_index] += 1
crop_prop_district_gdf["n_sample"] = n_sample

assert crop_prop_district_gdf["n_sample"].sum() == total_samples, "allocation must use the full budget"

crop_prop_district_gdf

## Sampling

In [None]:
# Run stratified random sampling on each district's filtered prediction raster and
# write the sampled points (shapefile) plus a wide lat/lon table per draw (CSV).
area_of_interest_gdf = crop_prop_district_gdf
district_column = "DISTRICT"
results_dir = "results"

# Number of independent sampling draws; also used in the output file names.
n_draws = 3

# Classes excluded from sampling. The integer after the final "_" is the pixel value
# that identifies the class in the prediction raster.
classes_to_remove = ["class_7", "class_11"]


def mask_classes(predictions, class_names):
    """Return ``predictions`` with the pixels of the listed classes set to NaN.

    Parameters
    ----------
    predictions : xarray.DataArray
        Classified raster, e.g. as returned by ``rioxarray.open_rasterio``.
    class_names : list of str
        Names of the form ``"class_<value>"``; pixels equal to ``<value>`` are masked.

    Returns
    -------
    xarray.DataArray
        A new array with the same attrs; masked pixels are NaN.
    """
    for class_name in class_names:
        class_value = int(class_name.split("_")[-1])
        predictions = xr.where(predictions != class_value, predictions, np.nan, keep_attrs=True)
    return predictions


def pivot_draws(samples_df):
    """Reshape long-format samples to one row per (class, sample_no).

    The result has one ``"<strategy>_lat"`` / ``"<strategy>_lon"`` column pair per
    sampling strategy (draw). The columns are sorted so that each draw's lat and lon
    columns are adjacent.

    Parameters
    ----------
    samples_df : pandas.DataFrame
        Must contain the columns ``class``, ``sample_no``, ``strategy``, ``lat``, ``lon``.

    Returns
    -------
    pandas.DataFrame
        Indexed by (class, sample_no).
    """
    wide = samples_df.pivot(index=["class", "sample_no"], columns="strategy", values=["lat", "lon"])
    # Flatten the (value, strategy) MultiIndex into "<strategy>_<value>".
    # astype(str) so that non-string strategy labels (e.g. integers) can be concatenated.
    strategy_level = wide.columns.get_level_values(1).astype(str)
    value_level = wide.columns.get_level_values(0).astype(str)
    wide.columns = strategy_level + "_" + value_level
    # Sort the columns (not the row index) so each draw's lat/lon pair sits together.
    return wide.sort_index(axis=1, ascending=True)


# Make sure the output directory exists before writing any results.
os.makedirs(results_dir, exist_ok=True)

for _, district in area_of_interest_gdf.iterrows():

    district_name = district[district_column]
    print(f"Processing {district_name}")

    # Load the per-class area stats and drop the classes we don't want.
    district_stats_file = f"data/district_{district_name}_cropping_propotion_by_class_filtered.csv"
    district_stats = pd.read_csv(district_stats_file)
    district_stats = district_stats.loc[~district_stats["class"].isin(classes_to_remove)].copy()

    # Renormalise so the remaining class proportions sum to 1.
    district_stats["proportion"] = district_stats["proportion"] / district_stats["proportion"].sum()

    # Per-class share of the district's sample allocation.
    # NOTE(review): this column is not passed to stratified_random_sampling below
    # (manual_class_ratios=None), so it currently has no effect on the draws. Confirm
    # whether it was meant to drive the class allocation.
    district_stats["n_sample"] = (district_stats["proportion"] * district["n_sample"]).astype(int)

    # Load the prediction raster; nodata becomes NaN because masked=True.
    predictions = rioxarray.open_rasterio(
        f"data/district_{district_name}_prediction_filtered.tif", masked=True
    )

    # Record the CRS as an attribute. This is presumably read by
    # stratified_random_sampling (TODO confirm); it does not set the rioxarray CRS.
    predictions = predictions.assign_attrs({"crs": "EPSG:6933"})
    predictions = mask_classes(predictions, classes_to_remove)

    # Run the sampling.
    gdf = stratified_random_sampling(
        predictions,
        district["n_sample"],
        min_sample_n=3,
        min_threshold_prop=0.01,
        n_strategies=n_draws,
        manual_class_ratios=None,
        out_fname=None,
    )

    # Drop raster bookkeeping columns if present.
    gdf = gdf.drop(columns=["band", "spatial_ref"], errors="ignore")

    # Save the sampled points as an ESRI Shapefile.
    gdf.to_file(os.path.join(results_dir, f"district_{district_name}_{n_draws}draws.shp"))

    # Convert to a plain DataFrame and pivot to one lat/lon column pair per draw.
    df = pd.DataFrame(gdf.drop(columns="geometry"))
    pivot = pivot_draws(df)

    # Save to CSV.
    pivot.to_csv(os.path.join(results_dir, f"district_{district_name}_{n_draws}draws.csv"))
    
