In [None]:
"""
Full extraction for RGB + NDVI patches using polygon annotations.

Pass 1: grid + geometry classification over the full extent (only the reference
        raster's mask/RGB is read per window, to skip empty or all-black tiles)
Pass 2: read + save the selected windows into dataset_{k:04d}/{label}/...

Requires:
  pip install rasterio geopandas shapely numpy pillow rtree
"""

from pathlib import Path
import csv
from collections import defaultdict
import numpy as np
from PIL import Image
import geopandas as gpd
from shapely.geometry import box
import rasterio
from rasterio.windows import Window
from rasterio.vrt import WarpedVRT
from rasterio.enums import Resampling

# ===================== CONFIG =====================
# Input rasters, annotation file and output directory (placeholders; edit before running).
# NOTE(review): NDVI_PATH is a relative placeholder while the others are absolute — confirm intended.
NDVI_PATH = r"path/to/orthomosaic_ndvi.tif"
RGB_PATH = r"/path/to/orthomosaic_rgb.tif"
GEOJSON  = r"/path/to/annotations.geojson"
OUT_ROOT = r"/path/to/OUT_ROOT"

SAVE_FORMAT = "both"   # 'png' | 'npy' | 'both'
# True: each band of a PNG is min-max stretched per patch; False: whole patch divided by its max.
PNG_SCALE_PER_PATCH = True

PATCH_SIZE = 224       # patch side length, in reference-raster pixels
STRIDE     = 128       # step between patch origins, in pixels (< PATCH_SIZE means overlap)
MIN_VALID  = 0.5   # min fraction of valid pixels (across all rasters)
MIN_COVER_FRAC = 0.05  # min fraction of tile covered by annotations to be positive
# Annotation column holding the class; values starting with "high"/"medium"
# (case-insensitive) are mapped to "Risk", everything else to "No Risk".
LABEL_FIELD  = "Class"  # in GeoJSON: "High Risk", "Medium Risk", "No Risk"

# NOTE(review): RANDOM_SEED is not referenced by any code visible in this file.
RANDOM_SEED = 42
# Resampling used when non-reference rasters are warped onto the reference grid.
RESAMPLING_METHOD = Resampling.bilinear

# Order matters: the first entry is the reference grid, the rest are aligned to it.
RASTER_STACK = [RGB_PATH, NDVI_PATH]

# size of the spatial blocks that define each dataset_k
BLOCK_SIZE = 2048  # adjust as you like

# Treat pure-black RGB pixels as nodata (for black borders with no nodata flag)
TREAT_ZERO_AS_NODATA = True
BLACK_THRESHOLD = 0   # 0 for exact black; set to 1–2 if there are near-black halos

# Fraction of non-black pixels a column/row must reach to mark the start of usable imagery.
EDGE_FRAC = 0.95   # require ≥95% valid pixels to call it “inside”

# ---------- LIGHT LOGGING & VALIDATION HELPERS ----------
def log(msg: str):
    """Write ``msg`` to stdout with an ``[INFO]`` prefix and flush immediately."""
    line = f"[INFO] {msg}"
    print(line, flush=True)

def validate_config():
    """Validate the module-level configuration before any processing starts.

    Uses explicit ``raise`` statements instead of ``assert`` so the checks
    still run when Python is started with ``-O`` (which strips asserts).

    Raises:
        FileNotFoundError: if the RGB raster, NDVI raster or GeoJSON is missing.
        ValueError: if a numeric or enum-like setting is out of range.
    """
    for p in (RGB_PATH, NDVI_PATH, GEOJSON):
        if not Path(p).exists():
            raise FileNotFoundError(f"Missing file: {p}")

    # Sizes must be positive integers: they are used as range() steps and divisors.
    for name, value in (
        ("PATCH_SIZE", PATCH_SIZE),
        ("STRIDE", STRIDE),
        ("BLOCK_SIZE", BLOCK_SIZE),
    ):
        if not (isinstance(value, int) and value > 0):
            raise ValueError(f"{name} must be a positive int, got {value!r}")

    # Fractions are compared against means of boolean masks / area ratios.
    for name, value in (
        ("MIN_VALID", MIN_VALID),
        ("MIN_COVER_FRAC", MIN_COVER_FRAC),
        ("EDGE_FRAC", EDGE_FRAC),
    ):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must be within [0, 1], got {value!r}")

    # Any other value would make the save step write nothing at all.
    if SAVE_FORMAT not in ("png", "npy", "both"):
        raise ValueError(
            f"SAVE_FORMAT must be 'png', 'npy' or 'both', got {SAVE_FORMAT!r}"
        )

def assert_raster_profile(profile: dict):
    """Sanity-check a raster profile dict.

    Asserts, in order, that ``profile`` has a non-None ``"crs"``, that it
    contains the keys ``"transform"``, ``"width"`` and ``"height"``, and that
    both dimensions are strictly positive.

    Raises:
        AssertionError: if any of the checks fails.
    """
    assert profile.get("crs") is not None

    required_keys = ("transform", "width", "height")
    for key in required_keys:
        assert key in profile

    n_cols = profile["width"]
    n_rows = profile["height"]
    assert n_cols > 0 and n_rows > 0

def assert_gdf(gdf: gpd.GeoDataFrame):
    """Sanity-check the annotation table.

    An empty table is accepted as-is. A non-empty table must have a CRS and
    must contain the column named by ``LABEL_FIELD``.

    Raises:
        AssertionError: if a non-empty table fails either check.
    """
    if len(gdf) == 0:
        return

    assert gdf.crs is not None
    assert LABEL_FIELD in gdf.columns, (
        f"LABEL_FIELD '{LABEL_FIELD}' not in annotation columns: {list(gdf.columns)}"
    )


# ===================== UTILITIES =====================
def load_reference_profile(ref_path: str) -> dict:
    """Open ``ref_path`` with rasterio and return its grid description.

    Returns:
        dict with the keys ``"crs"``, ``"transform"``, ``"width"`` and
        ``"height"`` copied from the opened dataset.
    """
    with rasterio.open(ref_path) as dataset:
        profile = dict(
            crs=dataset.crs,
            transform=dataset.transform,
            width=dataset.width,
            height=dataset.height,
        )
    return profile

def load_annotations(geojson_path: str, target_crs) -> gpd.GeoDataFrame:
    """Read the annotation file and return it in ``target_crs``.

    The table is reprojected only when its CRS differs from ``target_crs``;
    otherwise it is returned exactly as read.
    """
    annotations = gpd.read_file(geojson_path)
    already_aligned = not (annotations.crs != target_crs)
    if already_aligned:
        return annotations
    return annotations.to_crs(target_crs)

def build_grid(width: int, height: int, patch: int, stride: int) -> list[tuple[int, int]]:
    """List the top-left ``(x, y)`` pixel origins of all patches that fit.

    Origins start at 0 and advance by ``stride``; only patches lying fully
    inside a ``width`` x ``height`` raster are produced. The result is in
    row-major order (all x for the first y, then the next y, ...). A raster
    smaller than ``patch`` in either dimension yields an empty list.
    """
    x_stop = max(0, width - patch + 1)
    y_stop = max(0, height - patch + 1)

    origins = []
    for row in range(0, y_stop, stride):
        for col in range(0, x_stop, stride):
            origins.append((col, row))
    return origins

def window_bbox(ref_ds, win: Window) -> tuple[float, float, float, float]:
    """Return the window's extent as ``(left, bottom, right, top)``.

    The coordinates come from applying the window's affine transform to its
    pixel corners ``(0, 0)`` and ``(win.width, win.height)``.
    """
    affine = ref_ds.window_transform(win)
    ul_x, ul_y = affine * (0, 0)                        # upper-left pixel corner
    lr_x, lr_y = affine * (win.width, win.height)       # lower-right pixel corner
    return ul_x, lr_y, lr_x, ul_y

# ---------- skip nodata/black windows ----------
def window_has_data(ref_ds, win: Window) -> bool:
    """
    Decide whether a window is worth keeping.

    Step 1: the band-1 mask must flag at least one pixel as valid (non-zero).
    Step 2: when TREAT_ZERO_AS_NODATA is on, at least one value in the first
            (up to) three bands must exceed BLACK_THRESHOLD; otherwise a
            warning is printed and the window is rejected.
    """
    band1_mask = ref_ds.read_masks(1, window=win)   # 0 = nodata, >0 = valid
    if not np.any(band1_mask):
        return False

    if not TREAT_ZERO_AS_NODATA:
        return True

    n_bands = min(3, ref_ds.count)                  # at most the first 3 bands (RGB)
    band_ids = list(range(1, n_bands + 1))
    pixels = ref_ds.read(indexes=band_ids, window=win, masked=False)  # (bands,H,W)

    if np.any(pixels > BLACK_THRESHOLD):
        return True

    print("warning: window rejected due to all-black RGB")
    return False

def classify_overlap_label(bbox_poly, gdf, sindex, min_cover_frac: float = 0.05) -> tuple[str, float]:
    """
    Label one tile from the annotation polygons that intersect it.

    Returns (label, frac_covered), where frac_covered is the summed area of
    every candidate polygon's intersection with the tile divided by the tile
    area (overlapping polygons are each counted).

    - No spatial index, or no candidate polygons  -> ('No Risk', 0.0)
    - frac_covered < min_cover_frac               -> 'No Risk'
    - any intersecting polygon whose LABEL_FIELD value starts with
      'high' or 'medium' (case-insensitive)       -> 'Risk'
    - otherwise                                   -> 'No Risk'
    """
    negative = "No Risk"

    if sindex is None:
        return negative, 0.0

    candidate_rows = list(sindex.intersection(bbox_poly.bounds))
    if not candidate_rows:
        return negative, 0.0

    overlapping = gdf.iloc[candidate_rows]
    clipped = overlapping.geometry.intersection(bbox_poly)

    # Missing or empty intersections contribute zero area.
    areas = np.array(
        [0.0 if (geom is None or geom.is_empty) else geom.area for geom in clipped],
        dtype=float,
    )

    tile_area = bbox_poly.area
    if tile_area > 0:
        covered = float(areas.sum()) / tile_area
    else:
        covered = 0.0

    # Too little annotation inside the tile: negative regardless of class.
    if covered < min_cover_frac:
        return negative, covered

    class_names = overlapping[LABEL_FIELD].astype(str).str.lower().tolist()
    for class_name, area in zip(class_names, areas):
        if area > 0 and class_name.startswith(("high", "medium")):
            return "Risk", covered

    return negative, covered



def read_window_stack_aligned(raster_paths: list[str], window: Window, ref_profile: dict):
    """
    Read one window from every raster and stack the bands.

    The first raster is read on its own grid; every other raster is wrapped in
    a WarpedVRT built from ref_profile (crs/transform/width/height, resampled
    with RESAMPLING_METHOD) so the same window addresses the same pixels.

    Returns:
      stacked: (C,H,W)  all bands concatenated, masked pixels filled with 0
      valid_mask: (H,W) bool, logical AND of each dataset's band-1 mask
    """

    def _read_filled(dataset):
        # Bands with masked pixels replaced by 0, plus the band-1 validity mask.
        data = dataset.read(window=window, masked=True)              # (B,H,W) masked
        is_valid = dataset.read_masks(1, window=window) > 0          # (H,W) True=valid
        return np.ma.filled(data, 0), is_valid

    band_groups, validity = [], []

    for position, path in enumerate(raster_paths):
        with rasterio.open(path) as src:
            if position == 0:
                # Reference raster: native grid, no warping.
                filled, is_valid = _read_filled(src)
            else:
                # Secondary raster: view it through the reference grid.
                with WarpedVRT(
                    src,
                    crs=ref_profile["crs"],
                    transform=ref_profile["transform"],
                    width=ref_profile["width"],
                    height=ref_profile["height"],
                    resampling=RESAMPLING_METHOD,
                ) as vrt:
                    filled, is_valid = _read_filled(vrt)
        band_groups.append(filled)
        validity.append(is_valid)

    stacked = np.concatenate(band_groups, axis=0)                    # (C,H,W)
    valid_mask = np.logical_and.reduce(validity)                     # (H,W)
    return stacked, valid_mask


def to_uint8_per_patch(img_hw_c: np.ndarray) -> np.ndarray:
    """Min-max stretch every band of an (H, W, C) image to uint8 independently.

    For each band the minimum/maximum over its *finite* pixels map to 0/255.
    Non-finite pixels (NaN, +/-inf) are written as 0 instead of being cast to
    uint8 (an undefined conversion) or poisoning the band's min/max. A band
    that is constant, or has no finite pixel at all, is returned as all zeros.

    Args:
        img_hw_c: array of shape (H, W, C), any numeric dtype.

    Returns:
        uint8 array of shape (H, W, C).
    """
    H, W, C = img_hw_c.shape
    out = np.zeros((H, W, C), dtype=np.uint8)
    for b in range(C):
        band = img_hw_c[:, :, b].astype(np.float32)
        finite = np.isfinite(band)
        if not finite.any():
            continue  # nothing to scale; band stays zero
        finite_vals = band[finite]
        mn, mx = finite_vals.min(), finite_vals.max()
        if mx == mn:
            continue  # flat band; avoid dividing by zero
        scaled = (band - mn) / (mx - mn) * 255.0
        # Zero out NaN/inf results before the integer cast.
        scaled = np.where(finite, scaled, 0.0)
        out[:, :, b] = np.clip(scaled, 0, 255).astype(np.uint8)
    return out


# --------- margin filter and grouping into blocks ----------
def margin_filter_meta(meta, block_size: int, patch_size: int, stride: int):
    """Drop patches that lie too close to the border of their spatial block.

    The required margin ("halo") is ``max(0, patch_size - stride)`` pixels. A
    patch is kept only if the gap between each of its four sides and the
    matching side of the ``block_size`` x ``block_size`` block containing its
    origin is at least the halo. When the halo is 0, ``meta`` itself is
    returned unchanged; otherwise a new list is returned.

    Args:
        meta: sequence of dicts with pixel origins under ``"x"`` and ``"y"``.
    """
    halo = max(0, patch_size - stride)
    if halo == 0:
        return meta

    def _has_margin(record) -> bool:
        px, py = record["x"], record["y"]
        block_x0 = (px // block_size) * block_size
        block_y0 = (py // block_size) * block_size
        gaps = (
            px - block_x0,                                   # left
            py - block_y0,                                   # top
            (block_x0 + block_size) - (px + patch_size),     # right
            (block_y0 + block_size) - (py + patch_size),     # bottom
        )
        return min(gaps) >= halo

    return [record for record in meta if _has_margin(record)]


def group_into_blocks(
    meta,
    *,
    block_size: int,
    patch_size: int,
    stride: int,
    apply_margin_filter: bool = True,
):
    """
    Group patches in `meta` into spatial blocks of size `block_size x block_size`.

    A patch belongs to the block containing its top-left pixel, i.e. block
    grid coordinate (x // block_size, y // block_size). Block ids are the row
    indices of the sorted unique grid coordinates, so only non-empty blocks
    receive an id.

    Returns:
      block_groups: dict[int, list[int]]
          block_id -> list of indices into meta_kept
      unique_groups: np.ndarray of shape (n_blocks, 2)
          each row is [gx, gy] block-grid coordinate
      meta_kept: list[dict]
          filtered meta (after optional margin filter)
    """
    meta_kept = margin_filter_meta(meta, block_size, patch_size, stride) if apply_margin_filter else meta

    block_groups: dict[int, list[int]] = defaultdict(list)

    # Nothing survived: return well-formed empties rather than relying on how
    # np.unique treats a zero-row input.
    if len(meta_kept) == 0:
        unique_groups = np.empty((0, 2), dtype=int)
        print(f"{len(unique_groups)} block-datasets created.")
        return block_groups, unique_groups, meta_kept

    xs = np.array([m["x"] for m in meta_kept])
    ys = np.array([m["y"] for m in meta_kept])

    # block grid coordinates (integer)
    groups_xy = np.column_stack([xs // block_size, ys // block_size])  # (N, 2)
    unique_groups, inv = np.unique(groups_xy, axis=0, return_inverse=True)

    # Some NumPy releases return the inverse for axis-based unique as (N, 1)
    # instead of (N,); flatten so every element is a plain scalar index.
    for idx_meta, block_id in enumerate(np.ravel(inv)):
        block_groups[int(block_id)].append(idx_meta)

    print(f"{len(unique_groups)} block-datasets created.")
    return block_groups, unique_groups, meta_kept


def save_patch_png_and_or_npy(base_path: Path, img_hwc: np.ndarray, save_format: str, scale_per_patch=True) -> str:
    """Write one (H, W, C) patch as PNG and/or NPY next to ``base_path``.

    Args:
        base_path: output path without extension; ``.png`` / ``.npy`` is
            substituted via ``with_suffix``.
        img_hwc: patch array of shape (H, W, C).
        save_format: ``'png'``, ``'npy'`` or ``'both'``.
        scale_per_patch: for PNG output, stretch each band independently with
            ``to_uint8_per_patch`` (True) or divide the whole patch by its
            maximum (False).

    Returns:
        The written file's name: the PNG name when a PNG was written, else the
        NPY name. ``None`` when nothing was written, which happens for
        ``save_format='png'`` with more than 4 channels (the PNG is skipped).

    Raises:
        ValueError: if ``save_format`` is not one of the three accepted values
            (previously such a value silently saved nothing).
    """
    if save_format not in ("png", "npy", "both"):
        raise ValueError(
            f"save_format must be 'png', 'npy' or 'both', got {save_format!r}"
        )

    saved_name = None

    # PNG is only attempted for up to 4 channels.
    if save_format in ("png", "both") and img_hwc.shape[2] <= 4:
        if scale_per_patch:
            img8 = to_uint8_per_patch(img_hwc)
        else:
            peak = max(1e-9, np.nanmax(img_hwc))
            scaled = np.nan_to_num(img_hwc / peak * 255, nan=0.0)
            # Clip so negative values (e.g. NDVI < 0) do not wrap around in uint8.
            img8 = np.clip(scaled, 0, 255).astype(np.uint8)
        # Image.fromarray rejects a trailing singleton channel; pass 2-D instead.
        if img8.shape[2] == 1:
            img8 = img8[:, :, 0]
        png_path = base_path.with_suffix(".png")
        Image.fromarray(img8).save(png_path)
        saved_name = png_path.name

    if save_format in ("npy", "both"):
        npy_path = base_path.with_suffix(".npy")
        np.save(str(npy_path), img_hwc)
        if saved_name is None:
            saved_name = npy_path.name

    return saved_name

def write_csv_header(csv_writer):
    """Write the master-CSV header row (15 column names) through ``csv_writer``."""
    identity_cols = ["set", "filename", "label", "frac_covered"]
    geo_cols = ["xmin", "ymin", "xmax", "ymax", "center_x", "center_y"]
    pixel_cols = ["x_pix", "y_pix", "patch_size"]
    source_cols = ["src_rgb", "src_ndvi"]
    csv_writer.writerow(identity_cols + geo_cols + pixel_cols + source_cols)

def csv_row_for_meta(set_name, saved_name, m, patch_size, rgb_name, ndvi_name):
    """Build one master-CSV row for a patch, in header column order.

    ``m`` is a patch record providing ``left``/``bottom``/``right``/``top``
    (extent), ``x``/``y`` (pixel origin), ``label`` and ``frac``. The centre
    is the midpoint of the extent; ``frac`` is formatted with six decimals.
    """
    xmin, ymin = m["left"], m["bottom"]
    xmax, ymax = m["right"], m["top"]
    center_x = (xmin + xmax) / 2.0
    center_y = (ymin + ymax) / 2.0

    row = [set_name, saved_name, m["label"], f"{m['frac']:.6f}"]
    row += [xmin, ymin, xmax, ymax, center_x, center_y]
    row += [m["x"], m["y"], patch_size]
    row += [rgb_name, ndvi_name]
    return row


# ===================== DYNAMIC EDGE BUFFERS =====================

def compute_edges_95(
    ref_path: str,
    frac: float = 0.95,
    black_thr: int = 0,
    use_bands: int = 3,
    *,
    strip_rows: int = 1024,
):
    """
    Dynamic edge detector using 'black==nodata':
      Valid(x,y) = any( band_value > black_thr ) across first `use_bands` bands.

    For each side, finds how many columns/rows must be skipped, counting
    inward from that side, before reaching the first column/row whose
    valid-fraction is >= frac. If no column (row) qualifies, both horizontal
    (vertical) offsets are the full width (height).

    The raster is streamed in horizontal strips of `strip_rows` rows instead of
    being loaded whole, so peak memory is bounded by one strip rather than the
    entire orthomosaic; the result is the same as a full read.

    Args:
        ref_path: raster to analyse.
        frac: required fraction of valid pixels in a column/row.
        black_thr: values <= black_thr count as black.
        use_bands: number of leading bands inspected (capped at band count).
        strip_rows: rows read per chunk; must be >= 1.

    Returns:
        dict with pixel offsets under "left", "right", "top", "bottom".

    Raises:
        ValueError: if the raster has no bands or strip_rows < 1.
    """
    if strip_rows < 1:
        raise ValueError("strip_rows must be >= 1.")

    with rasterio.open(ref_path) as ds:
        H, W = ds.height, ds.width
        b = min(use_bands, ds.count)
        if b == 0:
            raise ValueError("Raster has no bands.")
        band_ids = list(range(1, b + 1))

        col_valid = np.zeros(W, dtype=np.int64)      # valid-pixel count per column
        row_frac = np.empty(H, dtype=np.float64)     # valid fraction per row

        for r0 in range(0, H, strip_rows):
            h = min(strip_rows, H - r0)
            strip = ds.read(
                indexes=band_ids,
                window=Window(0, r0, W, h),
                masked=False,
            )  # (b,h,W)
            valid = np.any(strip > black_thr, axis=0)                # (h,W) bool
            col_valid += valid.sum(axis=0)
            row_frac[r0:r0 + h] = valid.mean(axis=1)

    # Counts are exact integers, so this equals the mean over a full-size mask.
    col_frac = col_valid / H

    def _first_at_least(fractions, fallback):
        # Index of the first entry meeting the threshold, or `fallback` if none.
        hits = np.flatnonzero(fractions >= frac)
        return int(hits[0]) if hits.size else fallback

    return {
        "left": _first_at_least(col_frac, W),
        "right": _first_at_least(col_frac[::-1], W),
        "top": _first_at_least(row_frac, H),
        "bottom": _first_at_least(row_frac[::-1], H),
    }



# ===================== PIPELINE STEPS =====================
def pass1_geometry_classification(raster_stack, patch_size, stride, gdf, label_field, edge_bufs=None):
    """Enumerate candidate patches on the reference raster and label each one.

    The grid comes from ``build_grid`` on ``raster_stack[0]``. A window is
    dropped when it reaches into the edge buffers (if ``edge_bufs`` is given
    and non-empty) or when ``window_has_data`` rejects it. Survivors are
    labelled with ``classify_overlap_label`` using ``MIN_COVER_FRAC``.

    Args:
        raster_stack: raster paths; only the first (reference) one is opened.
        patch_size, stride: grid geometry in pixels.
        gdf: annotation polygons, in the reference raster's CRS.
        label_field: accepted but not used by this function.
        edge_bufs: optional dict of pixel offsets with keys
            ``left``/``right``/``top``/``bottom``.

    Returns:
        list of dicts with keys ``x``, ``y`` (pixel origin), ``left``,
        ``bottom``, ``right``, ``top`` (extent), ``frac`` and ``label``.
    """
    records = []
    spatial_index = gdf.sindex if len(gdf) else None

    with rasterio.open(raster_stack[0]) as ref:
        width, height = ref.width, ref.height

        for x, y in build_grid(width, height, patch_size, stride):
            # Dynamic edge buffer check: the patch must sit fully inside the buffers.
            if edge_bufs and (
                x < edge_bufs["left"]
                or y < edge_bufs["top"]
                or (x + patch_size) > (width - edge_bufs["right"])
                or (y + patch_size) > (height - edge_bufs["bottom"])
            ):
                continue

            win = Window(x, y, patch_size, patch_size)
            if not window_has_data(ref, win):  # skip no-data / black tiles
                continue

            left, bottom, right, top = window_bbox(ref, win)
            tile = box(left, bottom, right, top)
            label, frac = classify_overlap_label(
                tile, gdf, spatial_index, min_cover_frac=MIN_COVER_FRAC
            )
            records.append(
                dict(
                    x=x, y=y,
                    left=left, bottom=bottom, right=right, top=top,
                    frac=frac,
                    label=label,
                )
            )

    return records


def pass2_read_and_save_block_datasets(
    raster_stack,
    ref_profile,
    meta,
    block_groups: dict[int, list[int]],
    out_root: Path,
    save_format: str,
    min_valid: float,
    patch_size: int,
    png_scale_per_patch: bool,
):
    """
    Save patches grouped by block-id into:

      OUT_ROOT/
        labels_master.csv
        dataset_0000/
          <label>/patch_000000.png|.npy
          <label>/...
        dataset_0001/
          ...

    One sub-directory is created per distinct `label` value found in `meta`
    for that block. The 'set' column in the CSV will contain the dataset name
    (e.g. 'dataset_0000').

    A patch is skipped when it would extend past the reference raster, when
    its combined valid mask is empty, or when its valid fraction is below
    `min_valid`. A patch for which the save step wrote no file (it returned
    no filename) gets neither a CSV row nor a slot in the per-label numbering,
    so the CSV never lists files that do not exist.

    Args:
        raster_stack: raster paths; the first one is the reference grid.
        ref_profile: reference grid description used to align other rasters.
        meta: patch records (as produced by pass 1, after filtering).
        block_groups: block_id -> indices into `meta`.
        out_root: output directory (created if missing).
        save_format: 'png' | 'npy' | 'both'.
        min_valid: minimum fraction of valid pixels required to keep a patch.
        patch_size: patch side length in pixels.
        png_scale_per_patch: forwarded to the PNG scaling step.

    Returns:
        defaultdict mapping (block_id, label) -> number of patches saved.
    """
    out_root = Path(out_root)
    out_root.mkdir(parents=True, exist_ok=True)

    counters = defaultdict(int)
    csv_path = out_root / "labels_master.csv"
    not_written = 0  # patches that passed validation but produced no file

    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        write_csv_header(writer)

        rgb_name = Path(raster_stack[0]).name
        ndvi_name = Path(raster_stack[1]).name if len(raster_stack) > 1 else ""

        with rasterio.open(raster_stack[0]) as ref:
            for block_id, idxs in block_groups.items():
                dataset_name = f"dataset_{block_id:04d}"

                # ensure label dirs exist for this dataset
                labels_in_block = sorted({meta[i]["label"] for i in idxs})
                for lab in labels_in_block:
                    (out_root / dataset_name / lab).mkdir(parents=True, exist_ok=True)

                for i in idxs:
                    m = meta[i]
                    x, y, label = m["x"], m["y"], m["label"]

                    # safety: skip if patch would go out of raster bounds
                    if (x + patch_size) > ref.width or (y + patch_size) > ref.height:
                        continue

                    win = Window(x, y, patch_size, patch_size)
                    stacked, valid_mask = read_window_stack_aligned(raster_stack, win, ref_profile)

                    if not valid_mask.any():
                        continue
                    if valid_mask.mean() < min_valid:
                        continue

                    img_hwc = np.transpose(stacked, (1, 2, 0))  # (H,W,C)

                    idx_out = counters[(block_id, label)]
                    base = out_root / dataset_name / label / f"patch_{idx_out:06d}"
                    saved_name = save_patch_png_and_or_npy(
                        base, img_hwc, save_format, png_scale_per_patch
                    )

                    # No file on disk -> do not record a ghost row or burn an index.
                    if saved_name is None:
                        not_written += 1
                        continue

                    writer.writerow(
                        csv_row_for_meta(
                            dataset_name,  # goes into 'set' column
                            saved_name,
                            m,
                            patch_size,
                            rgb_name,
                            ndvi_name,
                        )
                    )
                    counters[(block_id, label)] += 1

    if not_written:
        print(
            f"warning: {not_written} patches were not written "
            f"(save_format={save_format!r} produced no file for them)"
        )
    print(f"Master CSV: {csv_path}")
    return counters


In [None]:
if __name__ == "__main__":
    # Stop early if the configuration is unusable.
    validate_config()

    # The RGB orthomosaic defines the reference grid and CRS for everything else.
    reference = load_reference_profile(RGB_PATH)
    annotations = load_annotations(GEOJSON, target_crs=reference["crs"])
    assert_gdf(annotations)

    # Measure how far the black border reaches in from each side.
    border_offsets = compute_edges_95(
        RGB_PATH, frac=EDGE_FRAC, black_thr=BLACK_THRESHOLD, use_bands=3
    )
    log(f"Dynamic edge buffers: {border_offsets}")

    # Pass 1: enumerate and label candidate windows.
    candidates = pass1_geometry_classification(
        RASTER_STACK, PATCH_SIZE, STRIDE, annotations, LABEL_FIELD,
        edge_bufs=border_offsets,
    )
    log(f"Pass1 produced {len(candidates)} candidate patches.")

    # Assign the surviving patches to spatial blocks (one dataset per block).
    groups, _block_coords, kept = group_into_blocks(
        candidates,
        block_size=BLOCK_SIZE,
        patch_size=PATCH_SIZE,
        stride=STRIDE,
        apply_margin_filter=True,
    )

    # Pass 2: read pixels and write patches plus the master CSV.
    destination = Path(OUT_ROOT)
    pass2_read_and_save_block_datasets(
        RASTER_STACK, reference, kept, groups, destination,
        SAVE_FORMAT, MIN_VALID, PATCH_SIZE, PNG_SCALE_PER_PATCH,
    )
