**AIM:** 

Run a k-means fit on the crop pixels of all of Central Province.

**STEPS:**

* Load data for all of Central Province
* Mask to the Central Province boundary
* Mask to the crop mask
* Flatten and remove NaNs
* Fit k-means to all Central Province crop data
* Look at the results in a number of small areas using 5, 10 and 15 crop classes
* Use the model to predict everywhere, using `predict_xr` to deal with areas that aren't needed

In [None]:
import pickle

import datacube
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
import os
from datacube.utils import geometry
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.classification import predict_xr, sklearn_flatten, sklearn_unflatten
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.datahandling import load_ard
from deafrica_tools.plotting import rgb
from deafrica_tools.spatial import xr_rasterize
from deafrica_tools.temporal import temporal_statistics, xr_phenology
from sklearn.cluster import DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler

from feature_extraction import feature_layers

In [None]:
# Start a local Dask cluster. The loads below request `dask_chunks`, so they
# are lazy; presumably this cluster is what computes them when `.load()` is
# called — confirm against deafrica_tools.dask documentation.
create_local_dask_cluster()

In [None]:
# Create the data and results directories if they don't exist.
# `exist_ok=True` makes this a no-op for existing directories and avoids the
# race between checking for a directory and creating it.
for directory in ("data", "results"):
    os.makedirs(directory, exist_ok=True)

## Step 1

Get the district boundaries for Central Province (the crop mask is applied later, per district).

In [None]:
# GRID3 district boundaries for Zambia, reprojected to EPSG:6933 so they match
# the CRS requested when loading data below.
admin_boundaries_file = (
    "data/admin-boundaries/GRID3_Zambia_Administrative_Boundaries_Districts_2020.shp"
)
admin_boundaries_gdf = gpd.read_file(admin_boundaries_file)
admin_boundaries_gdf = admin_boundaries_gdf.to_crs("EPSG:6933")

# Keep only the districts that belong to the chosen province.
province = "Central"
in_province = admin_boundaries_gdf["PROVINCE"] == province
province_boundaries_gdf = admin_boundaries_gdf[in_province]

In [None]:
# Preview the first rows of the selected province's districts.
province_boundaries_gdf.head()

In [None]:
# Dissolve all of the province's districts into one outline, stored as a
# single-row GeoDataFrame in the same CRS as the district boundaries.
# The dissolved geometry is wrapped in a list so pandas treats it as one cell
# value: with shapely < 2, multi-part geometries are iterable, and a bare
# geometry would be mistaken for a list of parts (length mismatch error).
province_gdf = gpd.GeoDataFrame(
    {
        "province": [province],
        "geometry": [province_boundaries_gdf["geometry"].unary_union],
    },
    crs=province_boundaries_gdf.crs,
)

## Step 2

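For each district: lazily load the feature data, mask it to the district boundary and to cropland, flatten it to a pixel-by-feature array with NaN pixels removed, and pickle the result.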

In [None]:
# Connection to the datacube used for every load in this notebook.
dc = datacube.Datacube(app="crop_type_ml")

# General load parameters shared by all per-district queries.
time = "2021"
resolution = (-20, 20)
output_crs = "EPSG:6933"

query = dict(
    time=time,
    resolution=resolution,
    output_crs=output_crs,
    dask_chunks=dict(time=1, x=2000, y=2000),
)

In [None]:
# For every district in the area of interest: load features, keep only crop
# pixels inside the district, flatten to a (pixel, feature) numpy array with
# NaN-containing pixels removed, and pickle (feature_list, array) to disk.
area_of_interest_gdf = province_boundaries_gdf
district_column = "DISTRICT"
output_dir = "data"
output_prefix = "2021_features_cropmasked"

# Paths of the pickle files written below, one per district.
file_list = []

for _, district in area_of_interest_gdf.iterrows():

    district_name = district[district_column]
    print(f"Processing {district_name}")

    # Per-district query built as a copy, so the shared `query` is not
    # mutated between iterations.
    geom = geometry.Geometry(geom=district.geometry, crs=area_of_interest_gdf.crs)
    district_query = {**query, "geopolygon": geom}

    # Get the features for every pixel
    feature_data = feature_layers(district_query).squeeze(dim="time", drop=True)

    # The crop mask comes from a different year than the features.
    crop_mask_query = {**district_query, "time": "2019"}

    # Load the crop mask. dc.load always returns a `time` dimension; drop it,
    # otherwise `where` below broadcasts a time axis back onto feature_data and
    # the final array becomes 3-D instead of (pixel, feature).
    crop_mask = dc.load(
        product="crop_mask_southeast",
        **crop_mask_query
    ).squeeze(dim="time", drop=True)

    # Rasterize the district polygon onto the feature grid.
    district_mask = xr_rasterize(
        gdf=gpd.GeoDataFrame(
            {district_column: [district_name], "geometry": [district.geometry]},
            crs=area_of_interest_gdf.crs,
        ),
        da=feature_data,
        transform=feature_data.geobox.transform,
        crs=output_crs,
    )

    # Filter to crop pixels within the district (everything else becomes NaN)
    district_crop_data = feature_data.where((crop_mask.filtered == 1) & (district_mask == 1))

    # Before reshaping, record the data variable names: these are the features.
    feature_list = list(district_crop_data.data_vars)

    # Reshape so that each (y, x) pixel is one entry along a "pixel" dimension
    crop_data_for_model = district_crop_data.stack(pixel=("y", "x")).load()
    print(f"    Converted to list of pixels. Shape = {crop_data_for_model.to_array().shape}")

    # Drop all pixels containing nan observations
    crop_data_for_model = crop_data_for_model.dropna(dim="pixel", how="any")
    print(f"    Dropped pixels containing nans. Shape = {crop_data_for_model.to_array().shape}")

    # Convert to a (pixel, feature) numpy array for sklearn; name the axis
    # order explicitly rather than relying on a positional transpose.
    crop_data_for_model = (
        crop_data_for_model.to_array().transpose("pixel", "variable").to_numpy()
    )

    # Output the feature list (variable names) and the crop-masked data array
    output_data = (feature_list, crop_data_for_model)

    # Pickle the data for later use
    pickle_file = os.path.join(output_dir, f"{output_prefix}_{district_name}.pickle")
    file_list.append(pickle_file)

    with open(pickle_file, "wb") as f:
        pickle.dump(output_data, f)

    print(f"    Size of pickled output: {os.path.getsize(pickle_file)}")