In [1]:
import os
import json
import numpy as np
import pandas as pd
from pycocotools.coco import COCO
import cv2
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from pycocotools import mask as maskUtils
from PIL import Image, ImageDraw

Create a CSV (and per-split patch caches) that excludes patches containing NaN or infinite values

In [None]:
import os
import pandas as pd
import numpy as np
import xarray as xr
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# === CONFIG ===
# Every CSV artefact of this step lives in the same directory.
_CSV_DIR = "/home/ubuntu/mucilage_pipeline/mucilage-detection/csv"
LABELS_FILE = f"{_CSV_DIR}/patches_final.csv"            # input: all candidate patches
OUTPUT_CSV = f"{_CSV_DIR}/patches_final_filtered.csv"     # output: patches without NaN/inf values
REMOVED_IDX_FILE = f"{_CSV_DIR}/removed_indices.npy"      # output: indices of the dropped patches

TARGET_RES = "r10m"  # reflectance grid every band is brought onto
PATCH_SIZE = 256     # patch side length in pixels
# Band names, in the channel order used by the stacked / cached patches.
BANDS = "b01 b02 b03 b04 b05 b06 b07 b08 b8a b11 b12".split()

# Same split function that was used to build the caches.
def split_data(labels_file, test_size=0.3, val_size=0.5, seed=42):
    """Split the labels CSV into stratified train / val / test frames.

    The CSV is first split into train and a held-out part (``test_size``); the
    held-out part is then split again, with ``val_size`` of it going to val and
    the rest to test. Both splits are stratified on the ``label`` column and use
    the same ``seed``.

    Returns:
        ``(df_train, df_val, df_test)``, each re-indexed 0..n-1 (the original
        CSV row index is discarded).
    """
    labels = pd.read_csv(labels_file)
    train_part, held_out = train_test_split(
        labels, test_size=test_size, stratify=labels["label"], random_state=seed
    )
    # The first chunk of the second split becomes test, the second (val_size) becomes val.
    test_part, val_part = train_test_split(
        held_out, test_size=val_size, stratify=held_out["label"], random_state=seed
    )
    return tuple(part.reset_index(drop=True) for part in (train_part, val_part, test_part))

def resample_band(ds, band, target_res="r10m", ref="b04", crs="EPSG:32632"):
    """
    Resample any band (reflectance or classification) to target resolution.

    Args:
        ds: datatree / mapping giving access to ``measurements/reflectance/<res>/<band>``
            and ``conditions/mask/l2a_classification/r20m/scl``.
        band: band name; ``"scl"`` is read from the classification group (r20m),
            anything else from the first of r10m / r20m / r60m that contains it.
        target_res: resolution group of the reference grid.
        ref: reference band defining the output grid.
        crs: CRS written onto both the reference and the band before matching.

    Returns:
        The band unchanged when it already lives at ``target_res``, otherwise the
        band reprojected onto the reference band's grid.

    Raises:
        ValueError: if a non-``scl`` band is not present in any reflectance group.
    """
    # NOTE(review): `.rio` is the rioxarray accessor; it only exists once rioxarray
    # has been imported somewhere in the session — confirm it is.
    reference = ds[f"measurements/reflectance/{target_res}/{ref}"].rio.write_crs(crs)

    if band == "scl":
        src_res = "r20m"
        key = f"conditions/mask/l2a_classification/{src_res}/{band}"
    else:
        # Pick the first reflectance resolution group that contains the band.
        for src_res in ("r10m", "r20m", "r60m"):
            if band in ds[f"measurements/reflectance/{src_res}"]:
                break
        else:
            raise ValueError(f"Band {band} not found in reflectance or scl folder")
        key = f"measurements/reflectance/{src_res}/{band}"

    data = ds[key].rio.write_crs(crs)
    # Already on the target grid: no resampling needed.
    return data if src_res == target_res else data.rio.reproject_match(reference)

def build_stack(ds, bands, target_res="r10m", ref_band="b04", crs="EPSG:32632"):
    """
    Build a lazy (H, W, C) stack of reflectance bands on a common grid.

    Every requested band is resampled onto ``ref_band``'s grid at ``target_res``
    (via ``resample_band``), divided by 10000 and stacked along a new ``band``
    dimension in the order given by ``bands``. Whether the result is dask-backed
    depends on how ``ds`` was opened (e.g. ``chunks={}``).

    Args:
        ds: xarray Dataset or DataTree
        bands: list of band names to include (reflectance bands only; ``"scl"``
            is rejected here even though ``resample_band`` supports it)
        target_res: desired output resolution for all bands
        ref_band: reference band for resampling (default: 'b04')
        crs: CRS written onto the bands before resampling

    Returns:
        xarray.DataArray with dimensions (y, x, band)

    Raises:
        ValueError: if a band is not present in r10m, r20m or r60m.
    """
    reflectance_groups = ("r10m", "r20m", "r60m")
    layers = []
    for b in bands:
        if not any(b in ds[f"measurements/reflectance/{res}"] for res in reflectance_groups):
            raise ValueError(f"Band {b} not found or not supported.")
        # Scale factor 1/10000; presumably turns raw integer values into reflectance
        # (the caller opens the store with mask_and_scale=False) — confirm.
        scaled = resample_band(ds, b, target_res=target_res, ref=ref_band, crs=crs) / 10000.0
        layers.append(scaled.expand_dims(band=[b]))

    # NOTE(review): xr.concat aligns the layers on their x/y coordinates. If the
    # reprojected bands' coordinates differ even slightly from the native r10m ones,
    # the result gets padded with NaN — worth checking the grids match exactly.
    return xr.concat(layers, dim="band").transpose("y", "x", "band")

# Load splits. split_data resets each split's index to 0..n-1, so an index is only
# meaningful together with its split — never compare raw indices across splits.
df_train, df_val, df_test = split_data(LABELS_FILE)
splits = {"train": df_train, "val": df_val, "test": df_test}

CACHE_DIR = "/home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy"

# Valid rows of each split, in the exact order their patches are written to the cache,
# so the filtered CSV lines up row-for-row with the train -> val -> test caches.
valid_rows_per_split = []
# Invalid patches as positions in the train -> val -> test concatenation of the splits
# (the same index space the split-annotation cell uses for its removed_indices).
invalid_indices_all = []
split_offset = 0  # number of rows in the splits already processed

for split_name, df_split in splits.items():
    print(f"\n🔍 Checking split: {split_name} ({len(df_split)} rows)")
    all_patches = []
    all_labels = []
    valid_indices = []
    invalid_indices = []

    # groupby iterates zarr files in sorted order, so the cache order follows this loop,
    # not df_split's row order.
    for zarr_path, group in tqdm(df_split.groupby("zarr_path"), desc=f"Scanning {split_name}"):
        if not os.path.exists(zarr_path):
            # Rows of missing files are neither cached nor counted as removed.
            print(f"⚠️ Missing file: {zarr_path}")
            continue
        ds = xr.open_datatree(zarr_path, engine="zarr", mask_and_scale=False, chunks={})
        try:
            stack = build_stack(ds, BANDS, target_res=TARGET_RES, ref_band="b04")

            for idx, row in group.iterrows():
                px, py, label = row["x"], row["y"], row["label"]  # pixel offsets of the patch corner
                patch = stack.isel(
                    y=slice(py, py + PATCH_SIZE),
                    x=slice(px, px + PATCH_SIZE),
                ).to_numpy().astype(np.float32)

                # A patch cut off at the scene border cannot be stacked with full-size ones.
                truncated = patch.shape[:2] != (PATCH_SIZE, PATCH_SIZE)
                if truncated or not np.isfinite(patch).all():
                    invalid_indices.append(idx)
                    continue  # skip invalid patch

                all_patches.append(patch)
                all_labels.append(label)
                valid_indices.append(idx)
        finally:
            ds.close()

    if all_patches:
        X = np.stack(all_patches, axis=0)  # (N, H, W, C)
    else:
        X = np.empty((0, PATCH_SIZE, PATCH_SIZE, len(BANDS)), dtype=np.float32)
    y = np.array(all_labels)
    np.savez_compressed(f"{CACHE_DIR}/{split_name}_cache.npz", X=X, y=y)

    valid_rows_per_split.append(df_split.loc[valid_indices])
    invalid_indices_all.extend(split_offset + i for i in invalid_indices)
    split_offset += len(df_split)
    print(f"Found {len(valid_indices)} valid patches in {split_name}")

df_concat = pd.concat([df_train, df_val, df_test])

# Filtered rows in cache order (train -> val -> test, each in groupby order).
df_filtered_in_split_order = pd.concat(valid_rows_per_split).reset_index(drop=True)
df_filtered_in_split_order.to_csv(OUTPUT_CSV, index=False)

# Save invalid indices (positions in the train -> val -> test concatenation) for reference
np.save(REMOVED_IDX_FILE, np.array(invalid_indices_all, dtype=np.int64))

print(f"\n✅ Saved filtered CSV with {len(df_filtered_in_split_order)} rows (out of {len(df_concat)}).")
print(f"🗑️ Removed {len(invalid_indices_all)} invalid patches.")

In [52]:
# Create mapping of old patches with csv rows:
# each positive patch of the *old* caches gets its CSV row and the RGB file name
# (patch_XXXX.png) it was exported under.

OLD_LABELS_CSV = "/home/ubuntu/mucilage_pipeline/mucilage-detection/csv/patches_final.csv"
OLD_CACHE_DIR = "/home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy"

# Recreate the old splits the same way
df_train_old, df_val_old, df_test_old = split_data(OLD_LABELS_CSV)

# Combine them in the same order as before (train → val → test)
df_all_old = pd.concat([df_train_old, df_val_old, df_test_old], axis=0).reset_index(drop=True)

# Labels of the old caches (the ones used to make rgb_patches); `with` closes each npz file.
old_labels = []
for split_name in ("train", "val", "test"):
    with np.load(os.path.join(OLD_CACHE_DIR, f"{split_name}_cache_old.npz")) as cache:
        old_labels.append(cache["y"])
y_all = np.concatenate(old_labels)

# Positions in y_all are used as row positions in df_all_old below, which is only
# meaningful if both hold the same patches in the same order.
# NOTE(review): equal length does not prove equal order — confirm the old caches were
# written in split-row order and were not NaN-filtered.
if len(y_all) != len(df_all_old):
    raise ValueError(
        f"Old caches hold {len(y_all)} patches but the splits have {len(df_all_old)} rows; "
        "positional mapping would be wrong."
    )

# Get the indices of positive patches (mucilage)
pos_indices = np.flatnonzero(y_all == 1)
df_pos = df_all_old.iloc[pos_indices].reset_index(drop=True)

# Each positive corresponds 1:1 to the patch_XXXX.png export order
df_pos["rgb_filename"] = [f"patch_{i:04d}.png" for i in range(len(df_pos))]

df_pos.to_csv("/home/ubuntu/mucilage_pipeline/mucilage-detection/csv/rgb_patch_mapping.csv", index=False)

In [8]:
# Add the train/val/test split of every row to the filtered CSV.

ORIG_CSV = "/home/ubuntu/mucilage_pipeline/mucilage-detection/csv/patches_final.csv"
df_filt = pd.read_csv("/home/ubuntu/mucilage_pipeline/mucilage-detection/csv/patches_final_filtered.csv")
OUTPUT_FILTERED_CSV = "/home/ubuntu/mucilage_pipeline/mucilage-detection/csv/patches_final_filtered_split.csv"

# Rows are identified by (zarr_path, x, y). x/y are normalised to float on both sides so
# that an int column in one CSV still matches a float column in the other — the same
# key is used for counting dropped rows and for the merge.
KEY_COLS = ["zarr_path", "_key_x", "_key_y"]


def _with_patch_key(frame):
    """Return a copy of ``frame`` with float-normalised ``_key_x`` / ``_key_y`` columns."""
    return frame.assign(_key_x=frame["x"].astype(float), _key_y=frame["y"].astype(float))


# Recreate original splits (if not already stored)
df_train_orig, df_val_orig, df_test_orig = split_data(ORIG_CSV)
df_train_orig["split"] = "train"
df_val_orig["split"] = "val"
df_test_orig["split"] = "test"
df_orig_splits = _with_patch_key(
    pd.concat([df_train_orig, df_val_orig, df_test_orig], axis=0).reset_index(drop=True)
)
df_filt_keyed = _with_patch_key(df_filt)

# Positions (in the train -> val -> test concatenation) of rows absent from the filtered CSV.
filt_keys = set(df_filt_keyed[KEY_COLS].itertuples(index=False, name=None))
orig_keys = df_orig_splits[KEY_COLS].itertuples(index=False, name=None)
removed_indices = [i for i, key in enumerate(orig_keys) if key not in filt_keys]
print(f"🧩 {len(removed_indices)} patches dropped due to NaNs")

df_filtered_split = df_filt_keyed.merge(
    df_orig_splits[KEY_COLS + ["split"]],
    on=KEY_COLS,
    how="left",
)
n_unmatched = int(df_filtered_split["split"].isna().sum())
if n_unmatched:
    print(f"⚠️ {n_unmatched} filtered rows have no matching row in the original splits.")
if len(df_filtered_split) != len(df_filt):
    # Happens when the same (zarr_path, x, y) appears more than once in the original CSV.
    print(f"⚠️ Merge produced {len(df_filtered_split)} rows for {len(df_filt)} filtered rows (duplicate patches).")

df_filtered_split = df_filtered_split.drop(columns=["_key_x", "_key_y"])
df_filtered_split.to_csv(OUTPUT_FILTERED_CSV, index=False)

🧩 39 patches dropped due to NaNs


Save the mucilage (label 1) patches as a NumPy archive

In [2]:
# Extract the positive (mucilage) patches of one cached split.
cache_file = "/home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy/train_cache.npz"  # or val_cache.npz / test_cache.npz
out_file = "/home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy/train_positive_new.npz"
# Writing out_file was disabled in this notebook; set to True to save the positives.
SAVE_POSITIVES = False

# Load the cached dataset; `with` closes the npz file once the arrays are read.
with np.load(cache_file) as data:
    X, y = data["X"], data["y"]  # X.shape = (N, H, W, C), y.shape = (N,)

# Select only patches with label 1 (mucilage)
mask = y == 1
X_pos = X[mask]
y_pos = y[mask]  # all ones; kept so the saved file has the same X/y layout as the caches

print(f"Selected {len(X_pos)} positive patches out of {len(y)} total.")

if SAVE_POSITIVES:
    np.savez_compressed(out_file, X=X_pos, y=y_pos)

Selected 154 positive patches out of 955 total.


Convert the patches to RGB and save them as PNG images

In [None]:
# Export every patch of the positives archive as an 8-bit RGB PNG (patch_XXXX.png).
# NOTE(review): this reads train_positive.npz, while the extraction cell above targets
# train_positive_new.npz — confirm which archive is intended.
path = "/home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy/train_positive.npz"
with np.load(path) as arr:
    X = arr['X']  # (N, H, W, C), channels in BANDS order
output_dir = "/home/ubuntu/mucilage_pipeline/mucilage-detection/rgb_patches"
os.makedirs(output_dir, exist_ok=True)

# Channel positions of b04, b03, b02 in BANDS -> red, green, blue.
RGB_CHANNELS = [3, 2, 1]

for i, patch in enumerate(X):
    rgb = patch[:, :, RGB_CHANNELS]
    # Per-patch 2-98 percentile stretch to [0, 1]; the epsilon avoids division by zero on flat patches.
    p2, p98 = np.nanpercentile(rgb, (2, 98))
    rgb = np.clip((rgb - p2) / (p98 - p2 + 1e-6), 0, 1)
    rgb = (rgb * 255).astype(np.uint8)
    # Save each patch as PNG (OpenCV expects BGR channel order).
    filename = os.path.join(output_dir, f"patch_{i:04d}.png")
    if not cv2.imwrite(filename, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
        # A missing file would shift the patch_XXXX <-> annotation mapping downstream.
        raise OSError(f"Failed to write {filename}")

(203, 256, 256, 11)


Convert annotations from Roboflow to binary masks

In [None]:
# Rasterise the Roboflow COCO annotations of each split into 0/255 binary mask images.
base_dir = "/home/ubuntu/mucilage_pipeline/mucilage-detection/roboflow_dataset"
splits = ["train", "valid", "test"]

for split in splits:
    img_dir = os.path.join(base_dir, split)
    ann_path = os.path.join(img_dir, "_annotations.coco.json")
    mask_dir = os.path.join(base_dir, f"masks_{split}")
    os.makedirs(mask_dir, exist_ok=True)

    coco = COCO(ann_path)

    for img_id in coco.getImgIds():
        img_info = coco.loadImgs(img_id)[0]
        h, w = img_info["height"], img_info["width"]

        mask = np.zeros((h, w), dtype=np.uint8)
        for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
            # annToMask handles both polygon and RLE segmentations.
            mask[coco.annToMask(ann) > 0] = 255

        # Always save as PNG: cv2.imwrite picks the format from the extension, so
        # reusing an image name ending in .jpg would store the mask lossily.
        stem = os.path.splitext(img_info["file_name"])[0]
        out_path = os.path.join(mask_dir, f"{stem}.png")
        if not cv2.imwrite(out_path, mask):
            raise OSError(f"Failed to write {out_path}")

    print(f"✅ Masks for {split} saved to {mask_dir}")

In [7]:
# Roboflow dataset with all images together (not split): one COCO file covers every RGB patch.

PATCH_SIZE = 256
CSV_PATH = "/home/ubuntu/mucilage_pipeline/mucilage-detection/csv/patches_final_filtered_split.csv"
CACHE_DIR = "/home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy"
COCO_PATH = "/home/ubuntu/mucilage_pipeline/mucilage-detection/roboflow_dataset/train/_annotations.coco.json"
MASKS_OUT = "/home/ubuntu/mucilage_pipeline/mucilage-detection/roboflow_dataset/saved_masks"
os.makedirs(MASKS_OUT, exist_ok=True)

# Load the filtered CSV with split info (not used by the mask generation below)
df = pd.read_csv(CSV_PATH)

# Load COCO annotations
with open(COCO_PATH, "r", encoding="utf-8") as f:
    coco = json.load(f)

# image_id -> list of its annotations
anns_by_image = {}
for ann in coco["annotations"]:
    anns_by_image.setdefault(ann["image_id"], []).append(ann)

img_map = {img["id"]: img["file_name"] for img in coco["images"]}

# Map RGB patch index (cumulative, '0000', '0001', ...) -> annotations.
# Assumes Roboflow file names look like 'patch_0000_png...'; the index is the 2nd '_' field.
rgb_ann_map = {}
for img_id, filename in img_map.items():
    idx_str = os.path.basename(filename).split("_")[1]
    if idx_str in rgb_ann_map:
        print(f"⚠️ Several COCO images map to patch index {idx_str}; keeping the last one.")
    rgb_ann_map[idx_str] = anns_by_image.get(img_id, [])

def ann_to_mask(annotations, h, w):
    """Rasterise the COCO annotations of one image into a binary mask.

    Args:
        annotations: list of COCO annotation dicts. ``segmentation`` is either a
            list of polygons (flat ``[x0, y0, x1, y1, ...]`` lists) or an RLE dict
            (compressed, or uncompressed with a list of ``counts``).
        h, w: mask height and width in pixels.

    Returns:
        ``np.ndarray`` of shape ``(h, w)``, dtype uint8, 1 inside any annotation
        and 0 elsewhere.
    """
    mask = np.zeros((h, w), dtype=np.uint8)
    for ann in annotations:
        seg = ann['segmentation']
        if isinstance(seg, list):
            for poly in seg:
                coords = np.array(poly).reshape((-1, 2)).flatten().tolist()
                img = Image.new('L', (w, h), 0)
                ImageDraw.Draw(img).polygon(coords, outline=1, fill=1)
                # maximum instead of += so overlapping polygons cannot wrap around uint8.
                mask = np.maximum(mask, np.array(img, dtype=np.uint8))
        else:
            if isinstance(seg.get('counts'), list):
                # Uncompressed RLE must be encoded before maskUtils.decode can read it.
                seg = maskUtils.frPyObjects(seg, h, w)
            mask = np.maximum(mask, maskUtils.decode(seg))
    return np.clip(mask, 0, 1)

# Process all splits: build a mask array per cached split, aligned patch-for-patch with
# the cache. Negatives get an empty mask; positives get the annotations of the RGB patch
# with the same cumulative positive index (train -> val -> test).
global_pos_counter = 0
for split_name in ["train", "val", "test"]:
    cache_path = os.path.join(CACHE_DIR, f"{split_name}_cache.npz")
    # Only the labels are needed; `with` closes the npz file afterwards.
    with np.load(cache_path) as data:
        y = data["y"]

    total_masks = []
    for label in y:
        if label == 0:
            mask = np.zeros((PATCH_SIZE, PATCH_SIZE), dtype=np.uint8)
        else:
            idx_str = f"{global_pos_counter:04d}"  # cumulative index
            anns = rgb_ann_map.get(idx_str, [])
            if anns:
                mask = ann_to_mask(anns, PATCH_SIZE, PATCH_SIZE)
            else:
                mask = np.zeros((PATCH_SIZE, PATCH_SIZE), dtype=np.uint8)
                print(f"⚠️ Missing annotation for positive patch {idx_str}, using zeros.")
            global_pos_counter += 1
        total_masks.append(mask)

    if total_masks:
        total_masks = np.stack(total_masks)
    else:
        total_masks = np.zeros((0, PATCH_SIZE, PATCH_SIZE), dtype=np.uint8)
    np.savez_compressed(os.path.join(MASKS_OUT, f"{split_name}_masks.npz"), masks=total_masks)
    print(f"✅ {split_name} masks generated: {total_masks.shape}")

✅ train masks generated: (955, 256, 256)
✅ val masks generated: (203, 256, 256)
✅ test masks generated: (210, 256, 256)


Reconstruct the original splits (keeping the original CSV row indices)

In [3]:
def split_data(labels_file, test_size=0.3, val_size=0.5, seed=42):
    """Split the labels CSV into stratified train / val / test frames.

    Unlike the earlier definition of ``split_data`` in this notebook, the frames
    keep the ORIGINAL CSV row index (no ``reset_index``), so each row can be
    traced back to its line in ``labels_file``. The splits themselves are
    identical: same sizes, stratification on ``label`` and ``seed``.

    Returns:
        ``(df_train, df_val, df_test)``.
    """
    df = pd.read_csv(labels_file)

    # First split: train vs. a held-out part of size test_size.
    df_train, df_tmp = train_test_split(
        df, test_size=test_size, stratify=df["label"], random_state=seed
    )
    # Then split the held-out part into test and val (val gets val_size of it).
    df_test, df_val = train_test_split(
        df_tmp, test_size=val_size, stratify=df_tmp["label"], random_state=seed
    )
    return df_train, df_val, df_test

# Recreate the splits from the original labels CSV (original row indices kept).
labels_csv = "/home/ubuntu/mucilage_pipeline/mucilage-detection/csv/patches_final.csv"
df_train, df_val, df_test = split_data(labels_csv)

In [8]:
base_dir = "/home/ubuntu/mucilage_pipeline/mucilage-detection"
# Processing order must match the order the positives were exported in (train -> val -> test).
splits = {
    "train": df_train,
    "val": df_val,
    "test": df_test
}

# Refined masks from the Roboflow export. They share one directory and are numbered
# cumulatively over the positives of all splits, so lookup is built once.
mask_dir = os.path.join(base_dir, "roboflow_dataset/masks")
mask_lookup = {}
for fname in os.listdir(mask_dir):
    if fname.startswith("patch_"):
        prefix = fname.split("_png")[0]  # e.g. "patch_0000"
        mask_lookup[prefix] = os.path.join(mask_dir, fname)

H, W = 256, 256
pos_offset = 0  # number of positives in the splits already processed

for split in splits:
    print(f"\n=== Processing {split.upper()} ===")

    # Only the labels of the cache are needed; `with` closes the npz file afterwards.
    cache_file = os.path.join(base_dir, f"saved_npy/{split}_cache.npz")
    with np.load(cache_file) as data:
        y = data["y"]

    # Cache positions of the positive patches
    pos_indices = np.flatnonzero(y == 1)

    # One empty mask per cached patch (negatives stay all-zero)
    M = np.zeros((len(y), H, W), dtype=np.uint8)

    print(f"Found {len(mask_lookup)} refined masks for {split}.")

    # The k-th positive of this split is the (pos_offset + k)-th exported patch overall.
    for i, idx in enumerate(tqdm(pos_indices), start=pos_offset):
        prefix = f"patch_{i:04d}"  # matches the exported patch names
        mask_path = mask_lookup.get(prefix)
        if mask_path:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is not None:
                M[idx] = (mask > 127).astype(np.uint8)
            else:
                print(f"⚠️ Could not read {mask_path}")
        else:
            print(f"⚠️ Missing mask for {prefix} ({split})")
    pos_offset += len(pos_indices)

    # Save mask array aligned with the cached X of this split
    out_path = os.path.join(base_dir, f"saved_npy/{split}_masks_refined.npz")
    np.savez_compressed(out_path, M=M)
    print(f"✅ Saved refined masks to {out_path}")


=== Processing TRAIN ===
Found 222 refined masks for train.


100%|██████████| 154/154 [00:00<00:00, 4052.98it/s]


✅ Saved refined masks to /home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy/train_masks_refined.npz

=== Processing VAL ===
Found 222 refined masks for val.


100%|██████████| 33/33 [00:00<00:00, 6921.99it/s]

✅ Saved refined masks to /home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy/val_masks_refined.npz

=== Processing TEST ===





Found 222 refined masks for test.


100%|██████████| 35/35 [00:00<00:00, 3230.58it/s]

✅ Saved refined masks to /home/ubuntu/mucilage_pipeline/mucilage-detection/saved_npy/test_masks_refined.npz



