In [None]:
# Cell 1: Imports, configuration, and small helpers
import numpy as np
import xarray as xr
import geopandas as gpd
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from skimage.transform import resize
import joblib
import folium
import matplotlib.pyplot as plt

# optional: rioxarray for saving GeoTIFF
# The module is imported only for its side effect (hence the F401 suppression): the
# save cell calls `.rio.write_crs` / `.rio.to_raster` when RIO_AVAILABLE is True and
# falls back to writing NetCDF when it is False.
try:
    import rioxarray  # noqa: F401
    RIO_AVAILABLE = True
except Exception:
    # Any failure while importing (not only ImportError) is treated as "not available".
    RIO_AVAILABLE = False

# User-configurable values - edit these
# Everything a reader is expected to tune lives here; later cells only read these names.
STAC_CATALOG = "https://earth-search.aws.element84.com/v1"   # STAC endpoint
AOI_PATH = "AOI/EfateAOI.geojson"                           # path to your AOI GeoJSON
# The training file must contain a "randomforest" column: it is used as the class label.
TRAINING_PATH = "Training_Data/InvasiveClean6.geojson"      # training points GeoJSON
DATETIME = "2024-05/2024-09"                                # date range for search
# Passed as `collections=` to the STAC search.
SENTINEL_COLLECTION = ["sentinel-2-c1-l2a"]

# Mask/band options
# MASK_NAME is the name given to the combined mask DataArray and to its variable in
# the output dataset (and to the extra band when the mask is used as a feature).
MASK_NAME = "combined_mask"
INCLUDE_MASK_AS_FEATURE = False   # True to include mask as a predictor band, False to only use to exclude pixels

# Model / outputs
# RF_* are forwarded to RandomForestClassifier(n_estimators=..., random_state=...).
RF_N_ESTIMATORS = 100
RF_RANDOM_STATE = 42
MODEL_OUT = "rf_model.joblib"          # target of joblib.dump for the fitted classifier
PRED_GTIFF = "predicted.tif"           # GeoTIFF written when rioxarray is available
OUT_DS_NETCDF = "predicted_dataset.nc" # NetCDF written when the GeoTIFF path is unavailable or fails

# Tiny logging shim used by every cell for progress messages
def info(msg, *args):
    """Print ``msg`` followed by any extra values, prefixed with an ``[INFO]`` tag.

    Values are separated by single spaces, exactly as ``print`` would do.
    Returns ``None``.
    """
    parts = ("[INFO]", msg) + args
    print(*parts)

In [None]:
# Cell 2: STAC client, AOI load, and Dask local client
from pystac_client import Client
from dask.distributed import Client as DaskClient
from odc.stac import load, configure_s3_access

# Connect to the STAC API that later cells search for imagery
info("Opening STAC catalog:", STAC_CATALOG)
catalog = Client.open(STAC_CATALOG)

# Read the area of interest; its bounding box drives the search, the load and the
# spatial filter applied when reading the training points.
aoi_gdf = gpd.read_file(AOI_PATH)
bbox = aoi_gdf.total_bounds  # [minx, miny, maxx, maxy]
info("AOI bounds:", bbox)

# Visual sanity check of the AOI outline
aoi_gdf.plot(edgecolor="red", facecolor="none")
plt.title("AOI Check")
plt.show()

# Local Dask cluster: one worker process with many threads (tune to the host machine)
dask_client = DaskClient(
    n_workers=1,
    threads_per_worker=16,
    memory_limit="16GB",
)
info("Dask client started:", dask_client)

# Let ODC/rasterio read the imagery bucket (requester-pays access enabled)
configure_s3_access(cloud_defaults=True, requester_pays=True)
info("Configured S3 access (cloud_defaults=True, requester_pays=True)")

In [None]:
# Cell 3: Read the labelled training points, restricted to the AOI bounding box
gdf = gpd.read_file(
    TRAINING_PATH,
    bbox=tuple(bbox),
)
info("Training points loaded:", gdf.shape[0], "records")

# Interactive map coloured by class label (renders as the cell output)
gdf.explore(legend=True, column="randomforest")

In [None]:
# Cell 4: STAC search and load into an xarray.Dataset
# Find scenes intersecting the AOI within DATETIME that report < 25 % cloud cover.
items = catalog.search(
    collections=SENTINEL_COLLECTION,
    bbox=bbox,
    datetime=DATETIME,
    query={"eo:cloud_cover": {"lt": 25}},
).item_collection()

info("Found items:", len(items))

# Fail early with an actionable message: loading an empty item collection would
# otherwise surface as a much less obvious error further down.
if len(items) == 0:
    raise ValueError(
        f"STAC search returned no items for collections={SENTINEL_COLLECTION}, "
        f"datetime={DATETIME!r}, bbox={list(bbox)}. "
        "Widen DATETIME, relax the cloud-cover filter, or check the AOI."
    )

# Load relevant measurements into an xarray.Dataset; chunk sizes tuned to your environment.
# "scl" is loaded only because the next cell derives the cloud mask from it.
data = load(
    items,
    measurements=["red", "green", "blue", "nir08", "swir16", "scl"],
    bbox=bbox,
    chunks={"x": 2048, "y": 2048},
    groupby="solar_day",  # merge scenes acquired on the same solar day into one time step
)
info("Loaded data variables:", list(data.data_vars))
data

In [None]:
# Cell 5: Cloud mask, scaling, indices, median composite
# SCL mask values to exclude: 1 (defective), 3 (shadow), 9 (cloud), 10 (thin cirrus)
mask_flags = [1, 3, 9, 10]
cloud_mask = ~data.scl.isin(mask_flags)

# Apply the mask, then drop SCL itself. It is a categorical QA layer, not a reflectance
# band: leaving it in would scale it by 0.0001, median-composite it, and feed it to the
# classifier as a predictor through `median.to_array()`.
masked = data.where(cloud_mask).drop_vars("scl")

# scale to 0-1 and clip (Sentinel-2 L2A uses 10000 scaling); zeros are treated as nodata
# NOTE(review): some Sentinel-2 processing baselines also carry an additive offset -
# confirm against the collection's raster metadata whether one applies here.
scaled = (masked.where(masked != 0) * 0.0001).clip(0, 1)

# add NDVI as an example (keeps coords/dims); computed after clipping so it is not clipped itself
scaled["ndvi"] = (scaled.nir08 - scaled.red) / (scaled.nir08 + scaled.red)

info("Computing median composite (this may take a few minutes)...")
# Per-pixel median over time; .compute() triggers the Dask graph and pulls the result into memory
median = scaled.median("time").compute()
info("Median composite computed. Variables:", list(median.data_vars))

# Quick check
median.odc.explore(vmin=0, vmax=0.3)

In [None]:
# Cell 6: Derive a combined exclusion mask (water, built-up, bare/roads) from the median composite
# Band names below must match the variables present in `median`.
green, red = median["green"], median["red"]
nir, swir = median["nir08"], median["swir16"]


def _normalized_difference(a, b):
    """Return the element-wise normalized difference (a - b) / (a + b)."""
    return (a - b) / (a + b)


# Spectral indices (xarray arithmetic keeps coordinates intact)
ndwi = _normalized_difference(green, nir)
ndbi = _normalized_difference(swir, nir)
ndbai = _normalized_difference(swir, red)

# Threshold each index into a boolean layer
water_mask = ndwi > 0.2
building_mask = ndbi > 0.1
road_mask = (ndbai > 0.15) & (ndwi < 0) & (ndbi < 0.2)

# A pixel is masked (1) when any of the three layers flags it
combined_mask_da = (water_mask | building_mask | road_mask).astype("uint8").rename(MASK_NAME)
info("Combined mask created with dims:", combined_mask_da.dims)

# Interactive preview when possible, otherwise a numeric summary
try:
    combined_mask_da.odc.to_rgba(palette=["none","red"], alpha=0.3).odc.explore()  # interactive
except Exception:
    print("Preview: mask min/max:", float(combined_mask_da.min()), float(combined_mask_da.max()))

In [None]:
# Cell 7: Decide how the combined mask is used - as an extra predictor band, or purely as a filter
if not INCLUDE_MASK_AS_FEATURE:
    # Filter mode: the mask stays outside `median` and is applied when selecting pixels/points
    mask_da = combined_mask_da  # we'll use mask_da for filtering
    info("Mask kept separate; will be used to exclude pixels from training/prediction.")
else:
    # Feature mode: store the mask inside `median` so `to_array()` picks it up as a band
    median = median.assign({MASK_NAME: combined_mask_da})
    info("Mask added to median dataset as a band. Variables now:", list(median.data_vars))

In [None]:
# Cell 8 (replacement): Prepare training samples - reproject, sample raster values at point locations, filter masked points

# Reproject training points to the raster CRS so their x/y are in the raster's coordinate units
gdf_pts = gdf.to_crs(median.odc.geobox.crs)
gx = gdf_pts.geometry.x.values
gy = gdf_pts.geometry.y.values

# Convert median (Dataset) to a single DataArray with a leading "variable" dimension
arr = median.to_array()
print("arr.dims:", arr.dims)

# Spatial dims are taken to be the last two dims of the stacked array
y_dim, x_dim = arr.dims[-2], arr.dims[-1]
points_dim = "points"

# Pointwise (vectorized) indexers: both share the "points" dimension, so xarray pairs
# gy[i] with gx[i] and returns one value per point. Plain NumPy arrays would instead
# trigger orthogonal indexing and return every y combined with every x.
gx_da = xr.DataArray(gx, dims=points_dim)
gy_da = xr.DataArray(gy, dims=points_dim)

sampled = arr.sel({y_dim: gy_da, x_dim: gx_da}, method="nearest")
print("sampled.dims (after sel):", sampled.dims)

# No squeeze here: squeezing would silently drop the points dimension when there is a
# single training point (or the variable dimension when there is a single band).
if set(sampled.dims) != {"variable", points_dim}:
    raise ValueError(
        f"Expected sampled dims ('variable', '{points_dim}') but got {sampled.dims}; "
        "check that `median` has no dimensions other than y/x."
    )

# rows = points, columns = variables (same order as `median.to_array()` used at prediction time)
var_names = [str(v) for v in arr.coords["variable"].values]
sampled_df = pd.DataFrame(
    sampled.transpose(points_dim, "variable").values,
    columns=var_names,
)

# Combine with class labels (first column = label, remaining columns = features)
training_df = pd.concat(
    [gdf_pts["randomforest"].reset_index(drop=True), sampled_df],
    axis=1,
)

# Optionally filter out points that fall into combined_mask==1 (if not using mask as feature)
if not INCLUDE_MASK_AS_FEATURE:
    # Same pointwise indexers as above, so the result is 1-D with one mask value per point
    mask_at_pts = combined_mask_da.sel(
        {combined_mask_da.dims[-2]: gy_da, combined_mask_da.dims[-1]: gx_da},
        method="nearest",
    ).values
    keep = (mask_at_pts == 0)
    training_df = training_df.loc[keep].reset_index(drop=True)
    info(f"Kept {training_df.shape[0]} training points after mask filtering (out of {len(gx)})")

# Drop rows with any missing label/feature (e.g. points over pixels with no valid observations)
training_df = training_df.dropna()
info("Training rows used for model fitting after dropna:", len(training_df))
training_df.head()

In [None]:
# Cell 9: Fit the RandomForest classifier on the sampled training table
# Column 0 holds the class label; every remaining column is a predictor.
classes = training_df.iloc[:, 0].values
observations = training_df.iloc[:, 1:].values
info("Observation shape:", observations.shape)

# Build and fit in one expression (`fit` returns the estimator itself)
clf = RandomForestClassifier(
    n_estimators=RF_N_ESTIMATORS,
    random_state=RF_RANDOM_STATE,
    n_jobs=-1,  # use all available cores
).fit(observations, classes)
info("RandomForest fitted. Estimators:", RF_N_ESTIMATORS)

# Persist the fitted model to disk
joblib.dump(clf, MODEL_OUT)
info("Saved model to", MODEL_OUT)

In [None]:
# Cell 10: Classify every usable pixel of the composite
# Flatten the composite into a (pixels, variables) table in the same band order used for training
arr = median.to_array()  # dims: variable, y, x
y_dim, x_dim = arr.dims[-2:]
ny = arr.sizes[y_dim]
nx = arr.sizes[x_dim]

stacked = arr.stack(pixels=(y_dim, x_dim)).transpose("pixels", "variable")  # pixels x variables
X = stacked.values  # numpy array shape (n_pixels, n_bands)
info("Full predictor array shape:", X.shape)

# A pixel is usable only if none of its features is NaN
valid_data_mask = ~np.isnan(X).any(axis=1)

# In filter mode, additionally exclude pixels flagged by the combined mask
predict_mask = valid_data_mask
if not INCLUDE_MASK_AS_FEATURE:
    mask_flat = combined_mask_da.stack(pixels=(y_dim, x_dim)).values.astype(bool)
    predict_mask = valid_data_mask & (~mask_flat)

info("Pixels to predict:", predict_mask.sum(), "out of", len(predict_mask))

# Start from an all-NaN result and fill in predictions for the selected pixels only
pred_flat = np.full(X.shape[0], np.nan, dtype=np.float32)
if predict_mask.any():
    pred_flat[predict_mask] = clf.predict(X[predict_mask]).astype(np.float32)

# Back to the 2-D grid, reusing the composite's y/x coordinates
pred_2d = pred_flat.reshape(ny, nx)
predicted_da = xr.DataArray(
    pred_2d,
    coords={y_dim: median[y_dim], x_dim: median[x_dim]},
    dims=(y_dim, x_dim),
    name="predicted_class",
)
info("Prediction completed. Dtype:", predicted_da.dtype)

In [None]:
# Cell 11: Build output Dataset and save outputs (GeoTIFF / NetCDF)
# The GeoTIFF is preferred; NetCDF is written when rioxarray is missing or the GeoTIFF write fails.
out_vars = {"predicted": predicted_da, MASK_NAME: combined_mask_da}

out_ds = xr.Dataset(out_vars)
info("Output dataset prepared. Variables:", list(out_ds.data_vars))

if RIO_AVAILABLE:
    try:
        # Keep the CRS-tagged object returned by write_crs and write *that* one.
        # Each `out_ds["predicted"]` access builds a fresh DataArray, so tagging one
        # temporary in place and then writing another does not reliably carry the CRS.
        predicted_geo = out_ds["predicted"].rio.write_crs(median.odc.geobox.crs)
        predicted_geo.rio.to_raster(PRED_GTIFF)
        info("Saved predicted GeoTIFF to", PRED_GTIFF)
    except Exception as e:
        # Best-effort fallback: report the reason, then still persist the results as NetCDF
        info("Could not save GeoTIFF via rioxarray:", e)
        out_ds.to_netcdf(OUT_DS_NETCDF)
        info("Saved NetCDF to", OUT_DS_NETCDF)
else:
    out_ds.to_netcdf(OUT_DS_NETCDF)
    info("rioxarray not available; saved NetCDF to", OUT_DS_NETCDF)

In [None]:
# Cell 12: Visualize using folium (interactive)
# Every layer is added on a best-effort basis: a failure is reported together with the
# underlying exception and the remaining layers are still attempted.

# Map centre: mean of the training points in lon/lat, falling back to the AOI centroid
try:
    # gdf_pts is in the raster CRS; folium expects geographic WGS84 coordinates
    gdf_wgs84 = gdf_pts.to_crs(epsg=4326)
    if gdf_wgs84.empty:
        # the mean of zero points would be NaN, which folium rejects as a location
        raise ValueError("no training points available")
    center = [float(gdf_wgs84.geometry.y.mean()), float(gdf_wgs84.geometry.x.mean())]
except Exception as exc:
    info("Could not centre map on training points; using AOI centroid instead:", exc)
    gdf_wgs84 = None  # makes the training-points layer below skip cleanly
    centroid = aoi_gdf.to_crs(epsg=4326).geometry.centroid.iloc[0]
    center = [centroid.y, centroid.x]

m = folium.Map(location=center, zoom_start=11)

# Median composite rendered as RGBA
try:
    median.odc.to_rgba(vmin=0, vmax=0.3).odc.add_to(m, name="Median Composite")
except Exception as exc:
    info("Could not add median composite layer:", exc)

# Predicted classes
try:
    predicted_da.odc.add_to(m, name="Predicted")
except Exception as exc:
    info("Could not add predicted layer:", exc)

# Combined mask as a semi-transparent red overlay
try:
    combined_mask_da.astype("uint8").odc.to_rgba(palette=["none","red"], alpha=0.3).odc.add_to(m, name="Combined Mask")
except Exception as exc:
    info("Could not add combined mask layer:", exc)

# Training points (already converted to WGS84 above, when that conversion succeeded)
if gdf_wgs84 is None:
    info("Skipping training points layer: points could not be converted to WGS84")
else:
    try:
        gdf_wgs84.explore(m=m, column="randomforest", legend=True, name="Training Data")
    except Exception as exc:
        info("Could not add training points layer to map:", exc)

folium.LayerControl().add_to(m)
m

In [None]:
# Cell 13: (Optional) Quick evaluation on training data (be careful: this is training accuracy)
from sklearn.metrics import classification_report, confusion_matrix

# Score the fitted model on the very rows it was trained on.
# These numbers are optimistic by construction and are not a held-out accuracy estimate.
train_preds = clf.predict(observations)
print("Classification report (training data):", classification_report(classes, train_preds), sep="\n")
print("Confusion matrix:", confusion_matrix(classes, train_preds), sep="\n")