### Imports & Config

import os, json
import numpy as np
import rasterio
import cv2
from tqdm import tqdm

# --- Paths & run settings ---------------------------------------------------
RAW_DIR = "data/raw"            # input rasters + subplot JSON
PROC_DIR = "data/processed"     # derived outputs are written below this
JSON_PATH = os.path.join(RAW_DIR, "TeamName_Subplot.json")

# Acquisition date tag embedded in the raster file names (e.g. "0626", "0731").
DATE = "0626"
RGB_PATH, NIR_PATH, RED_PATH = (
    os.path.join(RAW_DIR, f"{DATE}_{band}.tif") for band in ("RGB", "NIR", "Red")
)

# Output chip size as (width, height); keep it identical across every date.
CHIP_SIZE = (640, 640)
# When True, an NDVI channel is appended to each chip.
ADD_NDVI = True

# One chip directory per date; created up front so the extraction loop can write.
OUT_CHIP_DIR = os.path.join(PROC_DIR, "subplots", f"chips_{DATE}")
os.makedirs(OUT_CHIP_DIR, exist_ok=True)

### Load subplots (same parser, independent)

def load_subplots_simple(json_path: str):
    """Load subplot ids and integer pixel bounding boxes from a JSON file.

    The file may hold either a top-level list of subplot records or a dict
    with a ``"subplots"`` list. For each record (a dict):

    * the id is the first non-null value of ``subplot_id``, ``id``,
      ``tile_id`` (converted to ``str``);
    * the box is the first non-null value of ``bbox``, ``bbox_xyxy``,
      expected to be four numbers ``x1, y1, x2, y2``.

    Records that are not dicts, or lack an id or a box, are skipped.

    Args:
        json_path: Path to the subplot JSON file.

    Returns:
        List of ``(subplot_id, (x1, y1, x2, y2))`` tuples. Coordinates are
        rounded to ``int`` and reordered so that ``x1 <= x2`` and
        ``y1 <= y2``.

    Raises:
        ValueError: if a box does not have exactly four elements.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Check the container type *before* calling .get(): a bare list has no
    # .get(), so evaluating data.get(...) first crashed on list-form files.
    if isinstance(data, list):
        items = data
    elif isinstance(data, dict):
        items = data.get("subplots", [])
    else:
        items = []

    def _first_present(record, keys):
        # First value among `keys` that exists and is not JSON null.
        return next((record[k] for k in keys if record.get(k) is not None), None)

    out = []
    for it in items:
        if not isinstance(it, dict):
            continue
        raw_id = _first_present(it, ("subplot_id", "id", "tile_id"))
        # Treat null as missing instead of producing the literal id "None".
        sid = "" if raw_id is None else str(raw_id)
        if not sid:
            continue
        bbox = _first_present(it, ("bbox", "bbox_xyxy"))
        if bbox is None:
            continue
        x1, y1, x2, y2 = (int(round(v)) for v in bbox)
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        out.append((sid, (x1, y1, x2, y2)))
    return out

subplots = load_subplots_simple(JSON_PATH)
# Quick sanity preview (count, first entry). Guarded so that a JSON file with no
# parseable subplots shows (0, None) instead of raising IndexError.
len(subplots), (subplots[0] if subplots else None)

### Windowed read + resize + NDVI

def read_window(ds: rasterio.DatasetReader, bbox):
    """Read the pixel window ``bbox`` from ``ds``, clipped to the raster extent.

    Args:
        ds: An open rasterio dataset.
        bbox: ``(x1, y1, x2, y2)`` in pixel coordinates (column/row), with
            ``x2``/``y2`` exclusive.

    Returns:
        The array returned by ``ds.read`` for the clipped window (bands,
        rows, cols), or ``None`` when the clipped window is empty.
    """
    left, top, right, bottom = bbox
    col0, row0 = max(0, left), max(0, top)
    col1, row1 = min(ds.width, right), min(ds.height, bottom)
    n_cols, n_rows = col1 - col0, row1 - row0
    # No overlap (or a degenerate box) after clipping -> nothing to read.
    if n_cols <= 0 or n_rows <= 0:
        return None
    return ds.read(window=rasterio.windows.Window(col0, row0, n_cols, n_rows))

def robust01(x: np.ndarray, p2=2, p98=98):
    """Percentile-stretch ``x`` into the range [0, 1].

    Values at or below the ``p2``-th percentile map to 0, values at or above
    the ``p98``-th percentile map to 1, and values in between are scaled
    linearly. The parameter names are historical; any low/high percentile
    pair may be passed.

    NaNs are ignored when computing the percentiles, so a few NaN pixels no
    longer turn the whole output into NaN. NaN inputs stay NaN in the output.

    Args:
        x: Input array (any shape).
        p2: Lower percentile, in [0, 100].
        p98: Upper percentile, in [0, 100].

    Returns:
        Array of the same shape as ``x`` with values clipped to [0, 1].
    """
    # One call computes both percentiles (one partition instead of two), and
    # nanpercentile equals percentile when x contains no NaN.
    lo, hi = np.nanpercentile(x, [p2, p98])
    x = (x - lo) / (hi - lo + 1e-6)  # epsilon guards constant inputs (hi == lo)
    return np.clip(x, 0, 1)

def compute_ndvi(nir01, red01):
    """Return the element-wise normalized difference (NIR - Red) / (NIR + Red).

    Inputs are promoted to a floating dtype (at least float32) before the
    arithmetic. Without this, unsigned-integer bands (e.g. raw uint8/uint16)
    would wrap around on ``nir - red`` and ``nir + red``. Float inputs keep
    their precision (float32 stays float32, float64 stays float64).

    Args:
        nir01: NIR values (array-like).
        red01: Red values (array-like), broadcast-compatible with ``nir01``.

    Returns:
        Floating array of NDVI values. A small epsilon in the denominator
        avoids division by zero where NIR + Red == 0.
    """
    dtype = np.result_type(nir01, red01, np.float32)
    nir = np.asarray(nir01, dtype=dtype)
    red = np.asarray(red01, dtype=dtype)
    return (nir - red) / (nir + red + 1e-6)

def resize_chw(chw, out_wh):
    """Resize each channel of a channel-first stack to a common size.

    Args:
        chw: Array shaped (C, H, W).
        out_wh: Target size as (width, height), OpenCV's ``dsize`` order.

    Returns:
        Array shaped (C, out_height, out_width). Every channel plane is
        resized independently with ``cv2.INTER_AREA``.
    """
    # Unpacking the shape also rejects non-3-D input, exactly as before.
    n_channels, _, _ = chw.shape
    width, height = out_wh
    planes = [
        cv2.resize(chw[idx], (width, height), interpolation=cv2.INTER_AREA)
        for idx in range(n_channels)
    ]
    return np.stack(planes, axis=0)

def save_chip_npz(path, x_chw, meta: dict):
    """Write one chip to a compressed ``.npz`` archive.

    The archive holds two entries:
      * ``x``: the chip array cast to float32;
      * ``meta``: ``meta`` serialized as a JSON string.

    Args:
        path: Destination file path.
        x_chw: Chip array (any numeric ndarray).
        meta: JSON-serializable metadata dict.
    """
    entries = {
        "x": x_chw.astype(np.float32),
        "meta": json.dumps(meta),
    }
    np.savez_compressed(path, **entries)

### Extract all chips for one date

# %%
# Extract one chip per subplot:
#   1. crop each band with the subplot bbox;
#   2. scale the bands to roughly [0, 1] and optionally append NDVI;
#   3. resize to CHIP_SIZE and save as <OUT_CHIP_DIR>/<subplot_id>.npz.
# NOTE(review): the same pixel bbox is applied to all three rasters. That is
# only correct when RGB, NIR and Red share one pixel grid. If they do not, a
# registration/resampling step is needed before this loop.
n_saved = 0
n_skipped = 0
with rasterio.open(RGB_PATH) as ds_rgb, rasterio.open(NIR_PATH) as ds_nir, rasterio.open(RED_PATH) as ds_red:
    for sid, bbox in tqdm(subplots):
        rgb = read_window(ds_rgb, bbox)    # (bands, h, w)
        nir = read_window(ds_nir, bbox)
        red = read_window(ds_red, bbox)

        # read_window returns None when the bbox has no overlap with a raster.
        if rgb is None or nir is None or red is None:
            n_skipped += 1
            continue

        # Defensive: read_window is expected to return (bands, h, w) already.
        if nir.ndim == 2:
            nir = nir[None, ...]
        if red.ndim == 2:
            red = red[None, ...]
        # Keep at most the first three RGB bands (drops e.g. an alpha band);
        # slicing is a no-op when there are fewer.
        rgb = rgb[:3]

        # Windows are clipped to each raster's own extent, so rasters of different
        # size give differently-sized crops. Fail with a clear message rather than
        # an opaque np.concatenate error.
        hw = rgb.shape[1:]
        if nir.shape[1:] != hw or red.shape[1:] != hw:
            raise ValueError(
                f"subplot {sid}: band crops differ in size "
                f"(RGB {hw}, NIR {nir.shape[1:]}, Red {red.shape[1:]}); "
                "the RGB/NIR/Red rasters must be co-registered on the same grid"
            )

        # Normalize RGB by the full range of its integer dtype: 255 for uint8
        # (unchanged), 65535 for uint16 (previously divided by 255, giving
        # values far above 1). Non-integer RGB keeps the original /255 assumption.
        rgb_f = rgb.astype(np.float32)
        if np.issubdtype(rgb.dtype, np.integer):
            rgb01 = rgb_f / np.float32(np.iinfo(rgb.dtype).max)
        else:
            rgb01 = rgb_f / 255.0

        nir_f = nir.astype(np.float32)
        red_f = red.astype(np.float32)
        nir01 = robust01(nir_f)
        red01 = robust01(red_f)

        x = np.concatenate([rgb01, nir01, red01], axis=0)  # (5,h,w) for 3-band RGB
        if ADD_NDVI:
            # Compute NDVI from the un-stretched bands. robust01 rescales NIR and
            # Red independently per chip, so a ratio of the stretched values is not
            # NDVI and is not comparable between chips or dates.
            # NOTE(review): assumes the NIR and Red rasters share radiometric units
            # (e.g. both reflectance) — confirm for this sensor.
            ndvi = np.clip(compute_ndvi(nir_f[0], red_f[0]), -1.0, 1.0)[None, ...]
            x = np.concatenate([x, ndvi], axis=0)          # (6,h,w)

        x = resize_chw(x, CHIP_SIZE)

        meta = {"subplot_id": sid, "bbox": bbox, "date": DATE}
        out_path = os.path.join(OUT_CHIP_DIR, f"{sid}.npz")
        save_chip_npz(out_path, x, meta)
        n_saved += 1

print(f"Saved {n_saved} chips to {OUT_CHIP_DIR}; skipped {n_skipped} subplots with no raster overlap.")