In [1]:
!pip install dask-ml

UnboundLocalError: cannot access local variable 'child' where it is not associated with a value

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import earthaccess
import h5netcdf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyinterp.backends.xarray  # Module that handles the filling of undefined values.
import pyinterp.fill
import seaborn as sns
import xarray as xr
import rioxarray as rxr
from matplotlib.patches import Rectangle
import logging
import xarray as xr
import dask.array as da
from dask_ml.decomposition import PCA
from dask_ml.cluster import KMeans
from dask.diagnostics import ProgressBar
from dask.distributed import Client

In [None]:
# Root-logger configuration: INFO and above, timestamped. `log` is used by every later cell.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
log = logging.getLogger()

# Earthdata login (needed by earthaccess search/open below) and a Dask distributed client.
auth = earthaccess.login()
client = Client()

In [None]:
tspan = ("2024-06-01", "2024-06-30")  # temporal search window (start, end)
figname = "results/na_prairies"       # output path prefix for rasters and figures
# Area of interest in earthaccess bounding-box order: (west lon, south lat, east lon, north lat).
bbox = (-97.5, 36.0, -83.0, 45.5)

# Training window (currently the whole AOI).
# NOTE: the *_lat_train names are slice start/stop, not geographic min/max:
# min_lat_train is the NORTH edge and max_lat_train the SOUTH edge, so
# slice(min_lat_train, max_lat_train) runs north -> south. That assumes latitude is
# stored in descending order in the dataset (as in L3M grids) -- confirm if the source changes.
min_lon_train, max_lon_train = bbox[0], bbox[2]
min_lat_train, max_lat_train = bbox[3], bbox[1]

In [None]:
# Monthly 0.1-degree PACE OCI surface-reflectance (L3M SFREFL) granules covering the AOI.
res_rf = earthaccess.search_data(
    short_name="PACE_OCI_L3M_SFREFL",
    temporal=tspan,
    granule_name='*.MO.*0p1deg*',
    bounding_box=bbox
)
if not res_rf:
    raise RuntimeError(f"No PACE_OCI_L3M_SFREFL granules found for {tspan} within {bbox}")
path_rf = earthaccess.open(res_rf)
# Granules are stacked along a new "date" dimension.
ds_rf = xr.open_mfdataset(path_rf, combine="nested", concat_dim="date")

# Cut out AOI. bbox is (west lon, south lat, east lon, north lat).
min_lon, min_lat, max_lon, max_lat = bbox
# Label-based slices must follow the coordinate's stored order; L3M latitude usually
# runs north -> south, but handle either direction rather than relying on it.
lat_descending = bool(ds_rf.lat[0] > ds_rf.lat[-1])
lat_slice = slice(max_lat, min_lat) if lat_descending else slice(min_lat, max_lat)
ds_rf = ds_rf.sel(lat=lat_slice, lon=slice(min_lon, max_lon))

In [None]:
# Display the AOI-subset dataset (variables, dimensions, coordinates) for inspection.
ds_rf

In [None]:
# True-colour preview: take the bands nearest the R/G/B centres for the first date,
# scale to 8 bit and save as a 3-band GeoTIFF.
wavelengths_rgb = [660, 550, 470]  # nm, approximate visible R, G, B centres
# Positional index of the band closest to each target wavelength.
wavelength_idx = [int(abs(ds_rf.wavelength - wl).argmin()) for wl in wavelengths_rgb]

# First date, three RGB bands (own name so the module-level `rhos` is not shadowed).
rhos_rgb = ds_rf['rhos'].isel(date=0, wavelength=wavelength_idx)

# Band-first (band, y, x) layout for the multi-band raster.
rgb = rhos_rgb.transpose('wavelength', 'lat', 'lon')

# Clip negative reflectances and zero-fill missing pixels: NaN has no defined uint8
# value, so casting it would write platform-dependent garbage into the image.
rgb = rgb.clip(min=0).fillna(0)
rgb_uint8 = (rgb / rgb.max() * 255).astype("uint8")

# Declare the spatial dims and the lat/lon (WGS84) CRS for rioxarray.
rgb_uint8.rio.set_spatial_dims(x_dim='lon', y_dim='lat', inplace=True)
rgb_uint8.rio.write_crs("EPSG:4326", inplace=True)

rgb_uint8.rio.to_raster("pace_rgb_truecolor.tif")

In [None]:
# Drop variables that carry 'rgb' / 'eightbitcolor' dimensions (not reflectance data).
dims_to_remove = {'rgb', 'eightbitcolor'}
vars_to_drop = [var for var in ds_rf.data_vars
                if dims_to_remove & set(ds_rf[var].dims)]
ds_rf = ds_rf.drop_vars(vars_to_drop)

# Keep only the wavelengths used as clustering features. The selection must be assigned
# back to ds_rf to take effect; method="nearest" tolerates band centres that are not
# exactly these integers.
feature_wavelengths = [550, 704, 804]  # nm
ds_rf = ds_rf.sel(wavelength=feature_wavelengths, method="nearest")

# Subset out training dataset
ds_train = ds_rf.sel(lat=slice(min_lat_train, max_lat_train), lon=slice(min_lon_train, max_lon_train))

In [None]:
def prep_tile(tile):
    """Flatten a reflectance tile into a 2-D (samples, features) matrix.

    Each (lat, lon) pixel becomes one row; the date and wavelength axes are
    flattened together into the feature columns. NaN values are replaced by 0.

    Parameters
    ----------
    tile : xarray.DataArray
        Must have exactly the dims ``lat``, ``lon``, ``date`` and ``wavelength``
        (in any order). Lazy (dask) and in-memory (numpy) data are both accepted.

    Returns
    -------
    X_flat : dask.array.Array
        Shape ``(n_lat * n_lon, n_date * n_wavelength)``, held in a single chunk.
    stacked : xarray.DataArray
        ``tile`` stacked to ``(samples, date, wavelength)``; its ``samples`` index
        maps rows of ``X_flat`` back to (lat, lon).
    """
    log.info(f"Preparing tile with shape {tile.sizes}")
    stacked = tile.stack(samples=("lat", "lon")).transpose("samples", "date", "wavelength")
    # Wrap in dask so numpy-backed tiles work too (a no-op for data that is already dask).
    X = da.asarray(stacked.data)
    log.info(f"  -> stacked shape: {X.shape}, chunks: {X.chunks}")
    X = da.where(da.isnan(X), 0, X)
    # One chunk, so the reshape below never has to merge across chunk boundaries.
    X = X.rechunk({0: -1, 1: -1, 2: -1})
    X_flat = X.reshape((X.shape[0], -1))
    log.info(f"  -> reshaped to: {X_flat.shape}, chunks: {X_flat.chunks}")
    return X_flat, stacked

def predict_on_tiles(ds, model, tile_size=200):
    """Predict a cluster label for every pixel of ``ds['rhos']``, one spatial tile at a time.

    The lat/lon grid is walked in ``tile_size`` x ``tile_size`` blocks so that only
    one tile's features are materialised at once. Tiles with no valid data are
    skipped; a tile that raises is logged and skipped (best effort).

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``rhos`` laid out as ``prep_tile`` expects
        (dims ``lat``, ``lon``, ``date``, ``wavelength``).
    model : fitted estimator
        Object with ``predict(X)`` accepting the matrix built by ``prep_tile``
        and returning one label per row.
    tile_size : int, default 200
        Tile edge length in grid cells.

    Returns
    -------
    xarray.DataArray or None
        Float labels with dims ``("lat", "lon")`` on ``ds``'s full lat/lon grid.
        Pixels with no valid reflectance, and pixels in skipped or failed tiles,
        are NaN. ``None`` if no tile could be processed.
    """
    n_lat = ds.sizes["lat"]
    n_lon = ds.sizes["lon"]

    predicted_rows = []

    for lat_start in range(0, n_lat, tile_size):
        lat_end = min(lat_start + tile_size, n_lat)
        row_tiles = []
        log.info(f"  Processing lat slice {lat_start}:{lat_end}")

        for lon_start in range(0, n_lon, tile_size):
            lon_end = min(lon_start + tile_size, n_lon)
            log.info(f"  → lon slice {lon_start}:{lon_end}")
            tile = ds["rhos"].isel(lat=slice(lat_start, lat_end), lon=slice(lon_start, lon_end))

            if int(tile.count()) == 0:
                log.info("Skipping empty tile")
                continue

            try:
                row_tiles.append(_predict_tile(tile, model))
                log.info("  Finished prediction on tile")
            except Exception as e:
                # Best effort: one bad tile should not abort the whole map; keep the traceback.
                log.warning(f"  Failed to process tile: {e}", exc_info=True)

        if row_tiles:
            log.info("Concatenating row tiles along 'lon'")
            predicted_rows.append(xr.concat(row_tiles, dim="lon"))

    if not predicted_rows:
        log.warning("No tiles were processed successfully.")
        return None

    log.info("Concatenating all rows into full label map")
    # join="outer": rows may cover different lon sets when some tiles were skipped.
    full_labels = xr.concat(predicted_rows, dim="lat", join="outer")
    # Skipped tiles/rows leave gaps in the coordinates; put the result back on ds's grid.
    return full_labels.reindex(lat=ds["lat"].values, lon=ds["lon"].values)


def _predict_tile(tile, model):
    """Label one tile; returns a (lat, lon) float DataArray, NaN where the pixel has no data."""
    X_flat, _ = prep_tile(tile)
    labels = np.asarray(model.predict(X_flat.compute()))
    # prep_tile stacks ("lat", "lon") with lat as the outer level, so rows map back
    # onto the tile grid with a plain C-order reshape.
    labels = labels.reshape(tile.sizes["lat"], tile.sizes["lon"]).astype(float)
    # prep_tile zero-fills NaN; a pixel that is NaN in every date/band has nothing to cluster.
    has_data = tile.notnull().any(dim=("date", "wavelength")).transpose("lat", "lon").values
    labels[~has_data] = np.nan
    return xr.DataArray(
        labels,
        coords={"lat": tile["lat"], "lon": tile["lon"]},
        dims=("lat", "lon"),
    )

In [None]:
# Fit KMeans on the training window.
N_CLUSTERS = 10

log.info("Training model on subset...")
X_train, _ = prep_tile(ds_train['rhos'])
X_train_np = X_train.compute()
log.info(f"Training data shape: {X_train_np.shape}")

# prep_tile replaces NaN with 0, so a pixel with no valid data in any band/date becomes an
# all-zero row. Leave those out so fill values do not claim a cluster or drag centroids to 0.
has_data = np.any(X_train_np != 0, axis=1)
X_train_np = X_train_np[has_data]
log.info(f"Using {X_train_np.shape[0]} of {has_data.size} training pixels (all-zero rows dropped)")
if X_train_np.shape[0] < N_CLUSTERS:
    raise ValueError(
        f"Only {X_train_np.shape[0]} training pixels with valid data; need at least {N_CLUSTERS}."
    )

model = KMeans(n_clusters=N_CLUSTERS, random_state=42)  # fixed seed for reproducible clusters
model.fit(X_train_np)
log.info("KMeans model fitted.")

In [None]:
# Label the full AOI tile by tile, then show the cluster map.
log.info("Starting tile-wise prediction...")
labels_2d = predict_on_tiles(ds_rf, model, tile_size=200)

if labels_2d is None:
    log.error("Clustering failed: no valid output generated.")
else:
    labels_2d.name = "kmeans_cluster"
    log.info("Prediction complete. Plotting result...")
    labels_2d.plot.imshow(cmap='tab10')


In [None]:
# Save the cluster map as a north-up GeoTIFF in WGS84.
if labels_2d is None:
    raise RuntimeError("No cluster labels to write: tile-wise prediction produced no output.")

log.info("Writing output to GeoTIFF...")
labels_2d = labels_2d.sortby("lat", ascending=False)  # north-up row order
# inplace=True: otherwise set_spatial_dims returns a new object and this call has no effect.
labels_2d.rio.set_spatial_dims(x_dim='lon', y_dim='lat', inplace=True)
labels_2d.rio.write_crs("EPSG:4326", inplace=True)
out_path = f"{figname}_kmeans_clusters.tif"
labels_2d.rio.to_raster(out_path)
log.info(f"GeoTIFF saved as {out_path}")

In [None]:
# Shallow copy of the reflectance dataset with the per-pixel cluster labels attached as `cluster`.
ds_with_labels = ds_rf.assign(cluster=labels_2d)
ds_with_labels

In [None]:
# Mean reflectance spectrum of each cluster.
# Drop the single "date" by name and fix the axis order explicitly, so pixels line up
# with `clusters` regardless of how the dataset's dimensions happen to be ordered.
rhos = (
    ds_with_labels['rhos']
    .squeeze('date', drop=True)
    .transpose('lat', 'lon', 'wavelength')
    .values
)
clusters = ds_with_labels['cluster'].transpose('lat', 'lon').values
wavelengths = ds_with_labels['wavelength'].values

rhos_flat = rhos.reshape(-1, rhos.shape[-1])  # (lat*lon, wavelength)
clusters_flat = clusters.ravel()              # (lat*lon,)

# Unlabelled pixels (NaN) are excluded from the cluster list.
cluster_ids = np.unique(clusters_flat[~np.isnan(clusters_flat)]).astype(int)

spectra_per_cluster = []
for cid in cluster_ids:
    cluster_rhos = rhos_flat[clusters_flat == cid]  # (N, wavelength)
    # nanmean: a pixel missing in one band would otherwise turn that band's mean into NaN.
    mean_spectrum = np.nanmean(cluster_rhos, axis=0)
    spectra_per_cluster.append((cid, mean_spectrum))


In [None]:
# One line per cluster: mean reflectance against wavelength; saved next to the GeoTIFF.
fig, ax = plt.subplots(figsize=(10, 6))
for cid, spectrum in spectra_per_cluster:
    ax.plot(wavelengths, spectrum, label=f'Cluster {cid}')

ax.set(
    xlabel='Wavelength (nm)',
    ylabel='Mean Reflectance',
    title='Mean Reflectance Spectrum per Cluster',
)
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig(f'{figname}_mrs.png')
plt.show()