In [1]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import earthaccess
import h5netcdf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyinterp.backends.xarray  # Module that handles the filling of undefined values.
import pyinterp.fill
import seaborn as sns
import xarray as xr
import rioxarray as rxr
from matplotlib.patches import Rectangle
import logging
import xarray as xr
import dask.array as da
from dask_ml.decomposition import PCA
from dask_ml.cluster import KMeans
from dask.diagnostics import ProgressBar
from dask.distributed import Client

In [2]:
# Create a Dask distributed client with default settings; it becomes the
# scheduler for the lazy computations below.
client = Client()

# Search window and region of interest.
tspan = ("2024-04-01", "2025-04-30")
bbox = (123.084, -29.566, 138.024, -17.050)

# Find the matching PACE OCI L3M surface-reflectance granules
# (the name pattern presumably selects monthly "MO" composites on the
# 0.1-degree grid -- confirm against the product file-naming convention).
res_rf = earthaccess.search_data(
    short_name="PACE_OCI_L3M_SFREFL",
    temporal=tspan,
    granule_name='*.MO.*0p1deg*',
    bounding_box=bbox,
)
path_rf = earthaccess.open(res_rf)

# NOTE(review): the latitude names are swapped relative to the values they
# receive -- `max_lat` gets -29.566 and `min_lat` gets -17.050. The slice
# below therefore runs from the larger latitude to the smaller one, and the
# displayed result (126 latitude rows) indicates the `lat` coordinate is
# stored in descending order, so the selection is non-empty. The names are
# left unchanged in case other cells read them.
min_lon, max_lat, max_lon, min_lat = bbox
lat_window = slice(min_lat, max_lat)
lon_window = slice(min_lon, max_lon)

# Open every granule lazily, stack them along a new "date" dimension, and
# keep only the region of interest.
ds_rf = xr.open_mfdataset(path_rf, combine="nested", concat_dim="date").sel(
    lat=lat_window,
    lon=lon_window,
)
ds_rf

QUEUEING TASKS | :   0%|          | 0/13 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/13 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/13 [00:00<?, ?it/s]

Unnamed: 0,Array,Chunk
Bytes,113.58 MiB,54.00 kiB
Shape,"(13, 126, 149, 122)","(1, 16, 108, 8)"
Dask graph,3744 chunks in 41 graph layers,3744 chunks in 41 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 113.58 MiB 54.00 kiB Shape (13, 126, 149, 122) (1, 16, 108, 8) Dask graph 3744 chunks in 41 graph layers Data type float32 numpy.ndarray",13  1  122  149  126,

Unnamed: 0,Array,Chunk
Bytes,113.58 MiB,54.00 kiB
Shape,"(13, 126, 149, 122)","(1, 16, 108, 8)"
Dask graph,3744 chunks in 41 graph layers,3744 chunks in 41 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,9.75 kiB,768 B
Shape,"(13, 3, 256)","(1, 3, 256)"
Dask graph,13 chunks in 40 graph layers,13 chunks in 40 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 9.75 kiB 768 B Shape (13, 3, 256) (1, 3, 256) Dask graph 13 chunks in 40 graph layers Data type uint8 numpy.ndarray",256  3  13,

Unnamed: 0,Array,Chunk
Bytes,9.75 kiB,768 B
Shape,"(13, 3, 256)","(1, 3, 256)"
Dask graph,13 chunks in 40 graph layers,13 chunks in 40 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [3]:
# Discard every data variable that is defined on one of these dimensions,
# leaving only the variables laid out on the remaining dimensions.
dims_to_remove = {'rgb', 'eightbitcolor'}
vars_to_drop = [
    name
    for name, variable in ds_rf.data_vars.items()
    if not dims_to_remove.isdisjoint(variable.dims)
]
ds_rf = ds_rf.drop_vars(vars_to_drop)

In [None]:



# ----------------------------------------
# Setup logging
# ----------------------------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger()

# ----------------------------------------
# Tunable settings
# ----------------------------------------
SPATIAL_CHUNKS = {'lat': 100, 'lon': 100}
USE_PCA = False            # PCA was disabled in the original cell; set True to re-enable it
N_PCA_COMPONENTS = 10
N_CLUSTERS = 15
# NOTE(review): one k-means|| initialisation round and one Lloyd iteration
# will not converge -- this looks like a smoke-test setting. Raise both
# before interpreting the resulting clusters.
KMEANS_INIT_MAX_ITER = 1
KMEANS_MAX_ITER = 1
KMEANS_OVERSAMPLING = 10
RANDOM_STATE = 0

# ----------------------------------------
# Begin preprocessing
# ----------------------------------------
log.info("Starting clustering pipeline...")

# Step 1: Chunk the dataset
log.info("Chunking dataset...")
chunked = ds_rf.chunk(SPATIAL_CHUNKS)

# Step 2: Select variable and reshape
# Each (lat, lon) pixel becomes one sample; its features are the `rhos`
# values over every (date, wavelength) pair, flattened into a single row.
log.info("Stacking spatial dimensions and reshaping...")
stacked_da = chunked['rhos'].stack(samples=('lat', 'lon')).transpose('samples', 'date', 'wavelength')
X = stacked_da.data
X_flat = X.reshape((X.shape[0], -1))
log.info(f"Flattened data shape: {X_flat.shape}")

# Step 3: Rechunk
log.info("Rechunking to ensure one chunk along feature axis...")
X_flat = X_flat.rechunk({1: -1})

# Step 4: Fill NaNs
# NOTE(review): zero-filling makes pixels with no valid data look identical
# to each other, so they will tend to end up in a cluster of their own.
log.info("Filling NaNs with zeros...")
X_flat = da.where(da.isnan(X_flat), 0, X_flat)

# Materialise the prepared matrix once. The estimator below triggers several
# separate computations on its input; without this, each of them would
# replay the whole lazy load/stack/reshape graph from the remote granules.
log.info("Persisting feature matrix...")
X_flat = X_flat.persist()

# Step 5: PCA (optional dimensionality reduction)
if USE_PCA:
    log.info("Fitting PCA for dimensionality reduction...")
    pca = PCA(n_components=N_PCA_COMPONENTS)
    X_reduced = pca.fit_transform(X_flat)
    X_reduced.compute_chunk_sizes()
    log.info(f"PCA reduced shape: {X_reduced.shape}")
else:
    X_reduced = X_flat

# Step 6: KMeans
# dask.diagnostics.ProgressBar only reports for the local (single-machine)
# schedulers, so it shows nothing while a dask.distributed Client is active.
# Watch progress on the client's dashboard instead (client.dashboard_link).
log.info("Fitting KMeans clustering...")
kmeans = KMeans(
    n_clusters=N_CLUSTERS,
    init_max_iter=KMEANS_INIT_MAX_ITER,
    max_iter=KMEANS_MAX_ITER,
    oversampling_factor=KMEANS_OVERSAMPLING,
    random_state=RANDOM_STATE,
)
kmeans.fit(X_reduced)
log.info("KMeans clustering complete.")

# Step 7: Convert labels back to spatial grid
# Re-attach the stacked (lat, lon) index to the per-sample labels, then
# unstack it so the labels form a 2-D lat/lon map.
log.info("Converting labels to xarray...")
labels = kmeans.labels_
labels_xr = xr.DataArray(labels, coords={'samples': stacked_da['samples']}, dims='samples')
labels_2d = labels_xr.unstack('samples')

log.info("Clustering pipeline complete.")

2025-08-05 18:05:20,551 - INFO - Starting clustering pipeline...
2025-08-05 18:05:20,552 - INFO - Chunking dataset...
2025-08-05 18:05:20,621 - INFO - Stacking spatial dimensions and reshaping...
2025-08-05 18:05:20,690 - INFO - Flattened data shape: (18774, 1586)
2025-08-05 18:05:20,691 - INFO - Rechunking to ensure one chunk along feature axis...
2025-08-05 18:05:20,696 - INFO - Filling NaNs with zeros...
2025-08-05 18:05:20,699 - INFO - Fitting KMeans clustering...
2025-08-05 18:05:20,701 - INFO - Starting _check_array
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
