# Prepare dataset for modeling

In [1]:
import os
import numpy as np
import xarray as xr
import glob
import rioxarray

# 1. Parameters

In [2]:
# Directory containing .zarr files
DATA_DIR = "/home/ubuntu/mucilage_pipeline/mucilage-detection/data/adr_test/target"

# Bands to use for analytics
BANDS = ["b02", "b03", "b04", "b8a", "b11", "b12", 'amei', 'ndwi']  # Blue, Green, Red, NIR, NIR narrow, SWIR1, SWIR2

# zarr files
zarr_files = glob.glob(os.path.join(DATA_DIR, "*.zarr"))

# 2. Helper functions

In [3]:
def resample_to_10m(ds, band, ref):
    """
    Resample band to match the resolution & grid of reference band.
    ds: opened .zarr datatree
    band: name of band to resample (string)
    ref: name of reference band (string)
    """
    crs_code = "EPSG:32632"

    # Define reference band
    ref_band = ds[f"measurements/reflectance/r10m/{ref}"]  # reference (10m red)
    ref_band = ref_band.rio.write_crs(crs_code, inplace=True)

    # Band to convert
    band_20m = ds[f"measurements/reflectance/r20m/{band}"]
    band_10m = band_20m.rio.write_crs(crs_code, inplace=True)  # ensure CRS

    return band_10m.rio.reproject_match(ref_band)

def compute_amei(ds, eps=1e-6):
    red   = ds["measurements/reflectance/r10m/b04"].values.astype(np.float32)
    green = ds["measurements/reflectance/r10m/b03"].values.astype(np.float32)
    nir   = resample_to_10m(ds, 'b8a', 'b04')
    nir = nir.values.astype(np.float32)
    swir  = resample_to_10m(ds, 'b11', 'b04')
    swir = swir.values.astype(np.float32)  # "B11" or "B12"

    # AMEI = (2*red + nir - 2*swir) / (green + 0.25*swir)
    denom = green + 0.25 * swir
    amei  = (2*red + nir - 2*swir) / (denom + eps)  # eps avoids divide-by-zero

    return amei

def compute_ndwi(ds, eps=1e-6):
    green = ds["measurements/reflectance/r10m/b03"].values.astype(np.float32)
    nir   = resample_to_10m(ds, 'b8a', 'b04')
    nir = nir.values.astype(np.float32)

    # AMEI = (green - nir) / (green + nir)
    ndwi  = (green - nir) / (green + nir + eps)  # eps avoids divide-by-zero

    return ndwi


def build_stack_10m(ds, bands):
    """
    Create (H, W, C) stack from selected bands/indices.
    Assumes ds has reflectance bands and you may also compute indices like AMEI/NDWI.
    """
    stack = []
    ref_10m = ds["measurements/reflectance/r10m/b04"]  # reference (10m red)
    
    for b in bands:
        if b in ds['measurements/reflectance/r10m']:   # reflectance at 10m
            arr = ds['measurements/reflectance/r10m'][b].values.astype(np.float32)
        elif b in ds['measurements/reflectance/r20m']: # reflectance at 20m
            arr = resample_to_10m(ds, b, 'b04')
            arr = arr.values.astype(np.float32)
        elif b == "amei":
            arr = compute_amei(ds)  # your function
        elif b == "ndwi":
            arr = compute_ndwi(ds)  # you’d need to define
        else:
            raise ValueError(f"Band {b} not found or not supported.")
        
        stack.append(arr)
    stack = np.stack(stack, axis=-1)  # (H, W, C)
    return stack


def extract_patches_3d(array, patch_size=256, stride=256):
    """
    Extract patches from (H, W, C) array.
    Returns a list of patches with shape (patch_size, patch_size, C).
    """
    H, W, C = array.shape
    patches = []

    for i in range(0, H - patch_size + 1, stride):
        for j in range(0, W - patch_size + 1, stride):
            patch = array[i:i+patch_size, j:j+patch_size, :]
            patches.append(patch)
    return patches


def process_folder(zarr_files, bands, patch_size=256, stride=256):
    """
    Loop through all zarr files in a folder and extract patches.
    """
    all_patches = []

    for zf in zarr_files:
        print(f"Processing {zf} ...")
        ds = xr.open_datatree(zf, engine="zarr", mask_and_scale=False)

        stack = build_stack_10m(ds, bands)
        patches = extract_patches_3d(stack, patch_size, stride)
        
        all_patches.extend(patches)  # you might want to save to disk instead
    
    return all_patches

# 3. Patchify

Prepare 256x256xbands patches from each image

In [8]:
patches = process_folder(zarr_files[:2], BANDS)

print(f"Total patches: {len(patches)}")
print(f"Patch shape: {patches[0].shape}")

Processing /home/ubuntu/mucilage_pipeline/mucilage-detection/data/adr_test/target/S2B_MSIL2A_20240728T095549_N0511_R122_T32TQQ_20240728T114034.zarr ...
Processing /home/ubuntu/mucilage_pipeline/mucilage-detection/data/adr_test/target/S2A_MSIL2A_20240723T100031_N0511_R122_T32TQR_20240723T155949.zarr ...


: 

# 4. Prepare dataset

Split into train/test/validation