In [1]:
!pip install dask-ml



In [2]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import earthaccess
import h5netcdf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyinterp.backends.xarray  # Module that handles the filling of undefined values.
import pyinterp.fill
import seaborn as sns
import xarray as xr
import rioxarray as rxr
from matplotlib.patches import Rectangle
import logging
import xarray as xr
import dask.array as da
from dask_ml.decomposition import PCA
from dask_ml.cluster import KMeans
from dask.diagnostics import ProgressBar
from dask.distributed import Client
from dask_ml.decomposition import PCA

In [3]:
# Root logger configuration used by every helper in this notebook.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger()

# NASA Earthdata authentication (credentials are resolved by earthaccess itself).
auth = earthaccess.login()

# Reuse an already-running dask client when this cell is executed again.
# Unconditionally calling Client() on re-run starts a second local cluster,
# which is what produces the "Perhaps you already have a cluster running?"
# warning and the fallback HTTP port.
try:
    client = Client.current()
except ValueError:
    # No client exists yet in this process: start a fresh local cluster.
    client = Client()

2025-08-06 21:37:43,465 - INFO - You're now authenticated with NASA Earthdata Login
Perhaps you already have a cluster running?
Hosting the HTTP server on port 35527 instead


In [4]:
# Time window of the monthly composites to load.
tspan = ("2024-05-01", "2025-05-01")
# Area of interest as (west, south, east, north) in degrees, i.e. the
# lower-left / upper-right corner order used for the search bounding box.
bbox = (113.338953078, -43.6345972634, 153.569469029, -10.6681857235)
res_rf = earthaccess.search_data(
    short_name="PACE_OCI_L3M_SFREFL",
    temporal=tspan,
    granule_name='*.MO.*0p1deg*',
    bounding_box=bbox
)
if not res_rf:
    # Fail here with a readable message instead of an opaque error from
    # open_mfdataset when it is handed an empty file list.
    raise RuntimeError(
        f"No PACE_OCI_L3M_SFREFL granules found for {tspan} within {bbox}"
    )
path_rf = earthaccess.open(res_rf)
# One granule per month, concatenated along a new "date" dimension.
# NOTE(review): presumably no "date" coordinate is attached by this call, so
# "date" acts as a positional index (0..n-1) downstream - confirm.
ds_rf = xr.open_mfdataset(path_rf, combine="nested", concat_dim="date")


# Cut out AOI
min_lon, min_lat, max_lon, max_lat = bbox
# The latitude axis of this product runs north -> south, so a label-based
# slice has to go from the larger latitude to the smaller one.
ds_rf = ds_rf.sel(lat=slice(max_lat, min_lat), lon=slice(min_lon, max_lon))

2025-08-06 21:37:51,208 - INFO - Granules found: 12
2025-08-06 21:38:02,263 - INFO - Opening 12 granules, approx size: 18.4 GB
2025-08-06 21:38:02,264 - INFO - using endpoint: https://obdaac-tea.earthdatacloud.nasa.gov/s3credentials


QUEUEING TASKS | :   0%|          | 0/12 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/12 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/12 [00:00<?, ?it/s]

In [5]:
# Discard every data variable that is defined along one of the
# colour-table dimensions; only the science variables are kept.
dims_to_remove = {'rgb', 'eightbitcolor'}
vars_to_drop = [
    name
    for name, variable in ds_rf.data_vars.items()
    if not dims_to_remove.isdisjoint(variable.dims)
]
ds_rf = ds_rf.drop_vars(vars_to_drop)

In [20]:
def prep_tile(tile):
    """Turn a spatial tile into a 2-D (samples, wavelength) feature matrix.

    The ``lat`` and ``lon`` dimensions of ``tile`` are stacked into a single
    ``samples`` dimension, NaNs are replaced by 0 and the result is rechunked
    into one chunk along both axes.

    Parameters
    ----------
    tile : xarray.DataArray
        Array with ``lat``, ``lon`` and ``wavelength`` dimensions.
        Assumed to be dask-backed (its ``.chunks`` are logged) - TODO confirm
        for other callers.

    Returns
    -------
    tuple
        ``(X_flat, stacked)``: the NaN-filled, single-chunk feature matrix and
        the stacked (samples, wavelength) DataArray it was derived from.
    """
    log.info(f"Preparing tile with shape {tile.sizes}")

    # Collapse the two spatial dimensions, then put samples on the rows.
    stacked = tile.stack(samples=("lat", "lon"))
    stacked = stacked.transpose("samples", "wavelength")

    pixels = stacked.data
    log.info(f"  -> stacked shape: {pixels.shape}, chunks: {pixels.chunks}")

    # Missing reflectances become zeros so the matrix contains no NaNs.
    filled = da.where(da.isnan(pixels), 0, pixels)

    # ensure fully rechunked before reshape
    single_chunk = filled.rechunk({0: -1, 1: -1})
    n_samples = single_chunk.shape[0]
    X_flat = single_chunk.reshape((n_samples, -1))
    log.info(f"  -> reshaped to: {X_flat.shape}, chunks: {X_flat.chunks}")

    return X_flat, stacked

In [16]:
def train_and_transform_pca(ds_rf, month, n_components=5, train_bbox=None, random_state=0):
    """Fit a PCA on one month of surface reflectance and project that month onto it.

    Parameters
    ----------
    ds_rf : xarray.Dataset
        Dataset holding a ``rhos`` variable with ``date``, ``lat``, ``lon``
        and ``wavelength`` dimensions.
    month : scalar
        Label passed to ``ds_rf.sel(date=...)`` to pick the month to process.
    n_components : int, optional
        Number of principal components to keep (default 5).
    train_bbox : tuple, optional
        ``(west, south, east, north)`` window used for training. Defaults to
        the notebook-level ``bbox``.
    random_state : int or None, optional
        Seed forwarded to the PCA estimator so repeated runs give the same
        components when a randomized solver is selected.

    Returns
    -------
    xarray.Dataset
        A copy of ``ds_rf`` with two extra variables:
        ``pca`` (``lat``, ``lon``, ``component``) - the component scores of the
        selected month, and ``pca_components`` (``component``, ``wavelength``)
        - the fitted loadings.

    Raises
    ------
    ValueError
        If the training window selects no pixels.
    """
    west, south, east, north = bbox if train_bbox is None else train_bbox

    # Work on a single month: PCA needs a 2-D (pixels, wavelength) matrix.
    ds_month = ds_rf.sel(date=month)

    # Label-based slices must follow the direction of the coordinate, so pick
    # the slice order from the actual orientation of the latitude axis.
    lat_values = ds_month["lat"].values
    lat_descending = lat_values.size > 1 and lat_values[0] > lat_values[-1]
    lat_window = slice(north, south) if lat_descending else slice(south, north)
    ds_train = ds_month.sel(lat=lat_window, lon=slice(west, east))
    if ds_train.sizes["lat"] == 0 or ds_train.sizes["lon"] == 0:
        raise ValueError(
            f"Training window {(west, south, east, north)} selects no pixels"
        )

    log.info("Training model on subset...")

    # 1. Prepare training data
    X_train, _ = prep_tile(ds_train['rhos'])

    # 2. Train PCA (fit returns the estimator itself, so nothing is captured)
    pca = PCA(n_components=n_components, random_state=random_state)
    pca.fit(X_train)
    log.info(f"Explained variance ratio: {np.asarray(pca.explained_variance_ratio_)}")

    # 3. Project every pixel of the month. The same flattening used for
    #    training is applied, so the feature layout matches the fitted model.
    X_full, _ = prep_tile(ds_month['rhos'])
    scores = pca.transform(X_full)  # (n_lat * n_lon, n_components)

    # Samples were stacked lat-major (lat outer, lon inner), so a plain
    # row-major reshape restores the map layout.
    n_lat = ds_month.sizes["lat"]
    n_lon = ds_month.sizes["lon"]
    n_kept = np.asarray(pca.components_).shape[0]
    da_pca = scores.reshape((n_lat, n_lon, n_kept))

    # 4. Wrap result into xarray DataArray with correct dims/coords
    ds_pca = ds_rf.copy()
    ds_pca["pca"] = xr.DataArray(
        da_pca,
        dims=("lat", "lon", "component"),
        coords={
            "lat": ds_rf.coords["lat"],
            "lon": ds_rf.coords["lon"],
            "component": np.arange(n_kept)
        }
    )
    # Keep the loadings with the result so they can be plotted later without
    # relying on a model object living in the notebook namespace.
    ds_pca["pca_components"] = xr.DataArray(
        np.asarray(pca.components_),
        dims=("component", "wavelength"),
    )
    return ds_pca

In [9]:
def plot_monthly_value(month, ds_pca, components=None):
    """Plot PCA loadings and per-component score maps for one month.

    Two figures are produced: a line plot of each component's loading across
    wavelength, and a grid of maps (one per component) of the ``pca`` scores.
    The map grid is saved to ``pace_monthly/pca_month_{month}.png``.

    Parameters
    ----------
    month : scalar
        Label passed to ``ds_pca.sel(date=...)``; also used in the file name.
    ds_pca : xarray.Dataset
        Dataset with a ``pca`` variable (``lat``, ``lon``, ``component``) and
        a ``wavelength`` coordinate/dimension.
    components : array-like, optional
        Loadings of shape (n_components, n_wavelengths). When omitted they are
        read from ``ds_pca["pca_components"]`` if that variable exists, else
        from a fitted ``pca`` object in the notebook namespace if there is
        one. If none of these is available the loadings plot is skipped.
    """
    import os  # local import: only needed to create the output directory

    # Select the dataset and squeeze out the 'date' dimension
    ds = ds_pca.sel(date=month)

    # --- resolve the loadings -------------------------------------------
    if components is None:
        if "pca_components" in ds:
            components = ds["pca_components"].values
        else:
            # Legacy fallback: a fitted estimator left in the global namespace.
            fitted = globals().get("pca")
            components = getattr(fitted, "components_", None)

    # --- figure 1: loadings per component --------------------------------
    if components is None:
        log.warning("No PCA loadings available; skipping the loadings plot")
    else:
        fig_load, ax_load = plt.subplots(figsize=(10, 6))
        wavelengths = ds['wavelength'].values
        for cid, spectrum in enumerate(components):
            # These are principal components, not clusters.
            ax_load.plot(wavelengths, spectrum, label=f'Component {cid}')
        ax_load.set_xlabel('Wavelength (nm)')
        ax_load.set_ylabel('Reflectance')
        ax_load.set_title('PCA Reflectance')
        ax_load.legend()
        ax_load.grid(True)
        fig_load.tight_layout()
        plt.show()
        plt.close(fig_load)

    # --- figure 2: score maps, one panel per component --------------------
    scores = ds['pca']
    component_ids = ds['component'].values
    n_panels = len(component_ids)
    n_cols = 3
    n_rows = max(1, -(-n_panels // n_cols))  # ceiling division

    # squeeze=False keeps `axes` 2-D for any grid size; ravel() then gives
    # one Axes per panel (indexing the 2-D grid with a scalar returns a row).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 2 * n_rows), squeeze=False)
    flat_axes = axes.ravel()

    for idx, comp in enumerate(component_ids):
        ax = flat_axes[idx]
        img = scores.isel(component=idx)
        img.plot.imshow(ax=ax, cmap='plasma', robust=True, add_colorbar=False)
        ax.set_title(f"Component: {comp}")
        ax.axis('off')

    # Blank out any panels of the grid that have no component to show.
    for ax in flat_axes[n_panels:]:
        ax.axis('off')

    # savefig does not create missing directories.
    os.makedirs('pace_monthly', exist_ok=True)
    fig.savefig(f'pace_monthly/pca_month_{month}.png')
    plt.show()
    # Release the figure; this function is called once per month in a loop.
    plt.close(fig)

In [None]:
# Fit, project and plot each entry along the "date" dimension in turn.
month_labels = ds_rf['date'].values
for month_label in month_labels:
    ds_pca = train_and_transform_pca(ds_rf, month_label)
    plot_monthly_value(month_label, ds_pca)

2025-08-06 21:40:56,353 - INFO - Training model on subset...
2025-08-06 21:40:56,354 - INFO - Preparing tile with shape Frozen({'lat': 329, 'lon': 403, 'wavelength': 122})
2025-08-06 21:40:56,402 - INFO -   -> stacked shape: (132587, 122), chunks: ((403, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 3224, 1612, 1612), (8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 2))
2025-08-06 21:40:56,407 - INFO -   -> reshaped to: (132587, 122), chunks: ((132587,), (122,))
