In [None]:
import pickle

import datacube
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
import os
import glob
from joblib import dump
from datacube.utils import geometry
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.classification import predict_xr, sklearn_flatten, sklearn_unflatten
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.datahandling import load_ard
from deafrica_tools.plotting import rgb
from deafrica_tools.spatial import xr_rasterize
from deafrica_tools.temporal import temporal_statistics, xr_phenology
from sklearn.cluster import DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler

In [None]:
# Call the deafrica_tools helper that sets up a local Dask cluster for this
# session; its return value is not kept.
create_local_dask_cluster()

In [None]:
# Fix random seed to ensure reproducibility while developing
rng = np.random.default_rng(17)

# Common prefix of the per-district feature pickles
input_prefix = "data/2021_features_cropmasked"

# Upper bound on the number of rows sampled from each file
nrows_to_save = 10000

# The pattern must be an f-string: a plain string would search for the literal
# text "{input_prefix}_*.pickle" and match nothing.
# glob returns paths in arbitrary order, so sort them; otherwise the seeded
# random draws would be paired with different files from run to run.
files = sorted(glob.glob(f"{input_prefix}_*.pickle"))
if not files:
    raise FileNotFoundError(f"No files found matching '{input_prefix}_*.pickle'")

data_arrays = []

for file in files:
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # from a trusted source.
    with open(file, "rb") as f:
        # Each pickle holds a (labels, data) pair; only `data` is used here.
        labels, data = pickle.load(f)

    # Subsample rows without replacement. Cap the sample size at the number of
    # available rows so files smaller than nrows_to_save do not make
    # rng.choice raise.
    n_sampled = min(nrows_to_save, data.shape[0])
    random_rows = rng.choice(data.shape[0], size=n_sampled, replace=False)
    data = data[random_rows, :]

    data_arrays.append(data.squeeze())

# Concatenate to get all data for k-means
all_data = np.concatenate(data_arrays, axis=0)

In [None]:
# Show the dimensions of the concatenated sample array as the cell output.
all_data.shape

## Fit k-means

In [None]:
# Number of clusters to fit
k = 15

# Build the estimator with a fixed random_state and fit it on the sampled
# features in one expression; fit() returns the estimator itself, so `kmeans`
# ends up bound to the fitted model.
kmeans = KMeans(
    n_clusters=k,
    random_state=42,
).fit(all_data)

## Save model

In [None]:
# Destination file for the fitted k-means model
output_model = "results/ml_model.joblib"

# Create the output directory if it is missing so that dump() does not fail
# on a fresh checkout where "results/" has not been created yet.
os.makedirs(os.path.dirname(output_model), exist_ok=True)

# Persist the fitted model with joblib
dump(kmeans, output_model)