In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Complete, runnable postprocess script for one ELMFIRE case.

What it does:
- Loads/plots fuel map (categorized colormap)
- Loads weather rasters and plots histograms (WS, WD, M1, M10, M100)
- Loads TOA stack, builds burn-area history
- Loads VIIRS points, filters to extent, groups by half-day, builds concave hulls (convex fallback), burn-area history
- Overlays VIIRS dots on base map and simulated burned extent for selected times
- Computes Cohen's Kappa per selected frames by rasterizing VIIRS hulls and comparing to φ-binary
- Saves figures + metrics.json

Adjust the PATH CONSTANTS block for your case.
"""

from __future__ import annotations
import os, re, glob, json, math
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch

import rasterio
from rasterio.features import rasterize
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.enums import Resampling as ResampEnum

import geopandas as gpd
from shapely.geometry import Point, Polygon, MultiPoint
from shapely.ops import unary_union

from sklearn.metrics import cohen_kappa_score
from datetime import timedelta
from zoneinfo import ZoneInfo
from datetime import datetime

# -------------------------
# ===== PATH CONSTANTS ====
# -------------------------
# CASE_ROOT = Path("/global/home/users/yirenqin712/scratch/yirenqin712/ELMFIRE_SIMULATION/VnV_Suite/tubbs_fire").resolve()

# # Inputs
# FUELMAP_PATH   = CASE_ROOT / "data/fuels_and_topography/fbfm40b.tif"
# WX_WS_PATH     = CASE_ROOT / "data/weather/ws.tif"
# WX_WD_PATH     = CASE_ROOT / "data/weather/wd.tif"
# WX_M1_PATH     = CASE_ROOT / "data/weather/m1.tif"
# WX_M10_PATH    = CASE_ROOT / "data/weather/m10.tif"
# WX_M100_PATH   = CASE_ROOT / "data/weather/m100.tif"
# TOA_GLOB       = str(CASE_ROOT / "outputs" / "time_of_arrival_*.tif")
# PHI_GLOB       = str(CASE_ROOT / "outputs" / "phi_0000001_*.tif")
# VIIRS_DIR      = CASE_ROOT / "data/observation/viirs"                      # expects viirs_*.shp

# # Outputs
# FIG_DIR = CASE_ROOT / "figures"
# REP_DIR = CASE_ROOT / "report"
# for p in [FIG_DIR, REP_DIR]:
#     p.mkdir(parents=True, exist_ok=True)

# # Map extent + CRS for plotting overlays (adjust if needed)
# # [left, right, bottom, top] in projected CRS (e.g., UTM)
# # If you already know your extent, set it here; otherwise it's computed from fuel map reproject.
# MAP_CRS = "EPSG:32610"  # UTM zone (adjust per case)
OUT_CRS = "EPSG:4326"   # Default destination CRS (lon/lat) that plot_fuel_map reprojects the base map into
# MAP_EXTENT_OVERRIDE: Optional[List[float]] = None  # or [xmin, xmax, ymin, ymax] in MAP_CRS

# # -------------------------
# # ====== UTILITIES =========
# # -------------------------

# def savefig(fig, name: str):
#     """Save a figure into the case-local figures/ folder as PDF and close it."""
#     out = FIG_DIR / f"{name}.pdf"
#     fig.savefig(out, format="pdf")  # no bbox_inches
#     plt.close(fig)
#     return out

def ensure_1band_array(data: np.ndarray) -> np.ndarray:
    """Return a 2D [H, W] array from a raster read result.

    A 2D input is returned unchanged; a 3D input whose leading (band) axis
    has length 1 is reduced to that single band.

    Raises:
        ValueError: if the input has any other shape.
    """
    rank = data.ndim
    if rank == 2:
        return data
    if rank == 3 and data.shape[0] == 1:
        return data[0]
    raise ValueError(f"Unexpected raster array shape: {data.shape}")

def read_raster(path: Path) -> Tuple[np.ndarray, rasterio.Affine, dict]:
    """Read band 1 of a raster file.

    Returns:
        (band, transform, meta) where ``band`` is the masked array produced by
        ``read(1, masked=True)``, ``transform`` is the dataset transform and
        ``meta`` is a copy of the dataset metadata dict.
    """
    with rasterio.open(path) as dataset:
        band = dataset.read(1, masked=True)
        geo_transform = dataset.transform
        metadata = dataset.meta.copy()
        return band, geo_transform, metadata

def reproject_to(src_path: Path, dst_crs: str) -> Tuple[np.ndarray, rasterio.Affine, dict]:
    """Reproject band 1 of a raster into ``dst_crs`` using nearest-neighbour resampling.

    The output grid (transform, width, height) is the one proposed by
    ``calculate_default_transform`` for the source bounds.

    Returns:
        (array, transform, meta): the reprojected 2D array (same dtype as the
        source metadata reports), the destination transform, and a copy of the
        source metadata with ``crs``, ``transform``, ``width`` and ``height``
        replaced by the destination values.
    """
    with rasterio.open(src_path) as src:
        dst_transform, dst_width, dst_height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds
        )
        warped = np.empty((dst_height, dst_width), dtype=src.meta["dtype"])
        reproject(
            source=rasterio.band(src, 1),
            destination=warped,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=dst_transform,
            dst_crs=dst_crs,
            resampling=Resampling.nearest,
        )
        out_meta = src.meta.copy()
        out_meta.update(
            crs=dst_crs,
            transform=dst_transform,
            width=dst_width,
            height=dst_height,
        )
    return warped, dst_transform, out_meta

def array_extent(transform: rasterio.Affine, width: int, height: int) -> List[float]:
    """Return ``[left, right, bottom, top]`` for a grid of ``width`` x ``height`` pixels.

    The lower-left corner is obtained by applying ``transform`` to pixel
    coordinate ``(0, height)`` and the upper-right corner by applying it to
    ``(width, 0)``.
    """
    lower_left = transform * (0, height)
    upper_right = transform * (width, 0)
    x_min, y_min = lower_left
    x_max, y_max = upper_right
    return [x_min, x_max, y_min, y_max]

# -------------------------
# ====== FUEL MAP =========
# -------------------------
def create_custom_fuel_colormap(fuel_data: np.ndarray):
    """Reclassify fuel-model codes into four display classes and build a matching colormap.

    Classes: 0 = structure (91), 1 = water (98), 2 = non-burnable (92/93/99),
    3 = vegetation (every other code).

    Parameters
    ----------
    fuel_data : np.ndarray or np.ma.MaskedArray
        Array of fuel-model codes. If it is a masked array, the mask is
        carried over to the result so masked (nodata) cells are not drawn as
        vegetation.

    Returns
    -------
    (reclassified, cmap, norm)
        ``reclassified`` is a uint8 array of class ids (masked when the input
        was masked), ``cmap`` a 4-colour ``ListedColormap`` and ``norm`` a
        ``BoundaryNorm`` with one bin per class id.
    """
    # Classify on the raw data; the mask (if any) is re-applied afterwards.
    codes = np.ma.getdata(fuel_data)
    reclassified = np.full(codes.shape, 3, dtype=np.uint8)
    reclassified[codes == 91] = 0
    reclassified[codes == 98] = 1
    reclassified[np.isin(codes, [92, 93, 99])] = 2

    # Without this step the nodata mask is dropped and masked cells keep the
    # default class 3 ("vegetation").
    if np.ma.isMaskedArray(fuel_data):
        reclassified = np.ma.array(reclassified, mask=np.ma.getmaskarray(fuel_data))

    colors = ["saddlebrown", "lightblue", "gray", "lightgreen"]
    cmap = mcolors.ListedColormap(colors, name="custom_fuel")
    # Boundaries -0.5, 0.5, ..., 3.5 put each integer class id in its own bin.
    norm = mcolors.BoundaryNorm(np.arange(-0.5, 4.5, 1), cmap.N)
    return reclassified, cmap, norm

def plot_fuel_map(fuel_map_path: Path, ax: Optional[plt.Axes] = None, show_colorbar: bool = True,
                  dst_crs: str = OUT_CRS, **imshow_kwargs):
    """Draw the categorical fuel map reprojected into ``dst_crs``.

    The raster is reprojected with ``reproject_to``, cells equal to the
    metadata ``nodata`` value (when one is set) are masked, and the codes are
    reclassified with ``create_custom_fuel_colormap`` before being shown.

    Parameters
    ----------
    fuel_map_path : Path
        Fuel-model raster to plot.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 10x8 figure is created when omitted.
    show_colorbar : bool
        Whether to attach a colorbar labelled with the four classes.
    dst_crs : str
        CRS the raster is reprojected into before plotting.
    **imshow_kwargs
        Forwarded to ``ax.imshow``; ``extent`` and ``origin`` are filled in
        only when the caller did not provide them.

    Returns
    -------
    (im, ax, extent)
        The image artist, the axes used, and ``[left, right, bottom, top]`` of
        the reprojected raster.
    """
    raster, transform, meta = reproject_to(fuel_map_path, dst_crs)

    nodata_value = meta.get("nodata", None)
    if nodata_value is not None:
        raster = np.ma.masked_equal(raster, nodata_value)

    classes, class_cmap, class_norm = create_custom_fuel_colormap(raster)

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 8))

    extent = array_extent(transform, meta["width"], meta["height"])
    for key, fallback in (("extent", extent), ("origin", "upper")):
        imshow_kwargs.setdefault(key, fallback)

    im = ax.imshow(classes, cmap=class_cmap, norm=class_norm, **imshow_kwargs)
    ax.set_aspect("equal")

    if show_colorbar:
        cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.02, ticks=[0, 1, 2, 3])
        cbar.set_label("Fuel Type")
        cbar.ax.set_yticklabels(["S(91)", "W(98)", "N(92,93,99)", "V(others)"])

    return im, ax, extent

# -------------------------
# ====== WEATHER HISTS =====
# -------------------------
def plot_wx_hist(ax: plt.Axes, filepath: Path, bins: int = 60, title: Optional[str] = None) -> None:
    """Plot a density-normalised histogram of the valid cells of a weather raster.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw the histogram on.
    filepath : Path
        Raster to read (band 1, via ``read_raster``).
    bins : int
        Number of histogram bins.
    title : str, optional
        Axes title; left untouched when falsy.
    """
    arr, _, _ = read_raster(filepath)

    # Drop masked (nodata) cells *before* converting to a plain ndarray:
    # np.asarray() on a masked array silently discards the mask, which would
    # let nodata fill values into the histogram.
    data = np.ma.compressed(arr)
    data = data[np.isfinite(data)]

    ax.hist(data, bins=bins, density=True)
    ax.set_ylabel("PDF [-]")
    if title:
        ax.set_title(title)

# -------------------------
# ===== TOA + PHI LOADING ==
# -------------------------
def extract_timestamp_from_phi_name(fname: str) -> Optional[int]:
    """Return the trailing integer of a ``phi_0000001_<digits>.tif`` file name.

    Returns ``None`` when the name does not end with that pattern.
    """
    match = re.search(r"phi_0000001_(\d+)\.tif$", fname)
    if match is None:
        return None
    return int(match.group(1))

def load_toa_stack(toa_glob: str) -> Tuple[List[Path], List[np.ndarray], List[float], rasterio.Affine, dict]:
    """Load every time-of-arrival raster matching ``toa_glob``.

    Parameters
    ----------
    toa_glob : str
        Glob pattern for the TOA rasters.

    Returns
    -------
    (paths, arrays, times, transform, meta)
        ``paths`` are the matched files in sorted order; ``arrays`` holds band 1
        of each file exactly as returned by ``read_raster`` (masked arrays, so
        nodata cells stay masked); ``times`` holds the integer parsed from a
        ``time_of_arrival_<digits>.tif`` name, or the file's position in
        ``paths`` when the name does not match; ``transform`` and ``meta`` are
        taken from the first file.

    Raises
    ------
    FileNotFoundError
        If the pattern matches no file.
    """
    paths = sorted(Path(p) for p in glob.glob(toa_glob))
    if not paths:
        raise FileNotFoundError(f"No TOA rasters found with pattern: {toa_glob}")

    arrays, times = [], []
    transform, meta = None, None
    for index, path in enumerate(paths):
        arr, tform, file_meta = read_raster(path)
        # Keep the masked array: np.array(arr) would drop the nodata mask and
        # expose fill values to later `TOA <= t` comparisons.
        arrays.append(arr)
        # Infer time from the file name if present, else fall back to the index.
        match = re.search(r"time_of_arrival_(\d+)\.tif$", path.name)
        times.append(float(match.group(1)) if match else float(index))
        if transform is None:
            transform, meta = tform, file_meta
    return paths, arrays, times, transform, meta

def load_phi_paths(phi_glob: str) -> List[Path]:
    """Return the sorted list of files matching ``phi_glob``.

    Raises:
        FileNotFoundError: if the pattern matches nothing.
    """
    matches = glob.glob(phi_glob)
    if len(matches) == 0:
        raise FileNotFoundError(f"No PHI rasters found with pattern: {phi_glob}")
    return sorted(map(Path, matches))

def phi_to_binary(phi_path: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Turn a φ raster into a binary burn map on its own grid.

    A cell counts as burned when ``phi <= 0`` (the assumption stated by the
    original author). Cells that are masked in the raster are reported as not
    valid and are set to 0 in the binary map.

    Returns:
        (binary, valid_mask): a uint8 array of 0/1 and the validity mask.
    """
    with rasterio.open(phi_path) as dataset:
        level_set = dataset.read(1, masked=True)

    if np.ma.isMaskedArray(level_set):
        valid = ~level_set.mask
    else:
        valid = np.ones_like(level_set, dtype=bool)

    inside_front = level_set <= 0
    binary = np.where(valid, inside_front, 0).astype(np.uint8)
    return binary, valid

# -------------------------
# ===== VIIRS PROCESSING ===
# -------------------------
def load_viirs_points(viirs_dir: Path) -> gpd.GeoDataFrame:
    """Read every ``*.shp`` directly under ``viirs_dir`` and concatenate them.

    Files are read in sorted order with the fiona engine and the result gets
    a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: if the directory contains no shapefile.
    """
    shapefiles = sorted(viirs_dir.glob("*.shp"))
    if len(shapefiles) == 0:
        raise FileNotFoundError(f"No VIIRS shapefiles found under: {viirs_dir}")

    frames = []
    for shapefile in shapefiles:
        frames.append(gpd.read_file(shapefile, engine="fiona"))
    return gpd.pd.concat(frames, ignore_index=True)

def add_viirs_obstime(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """Add ``observation_time`` and ``half_day`` columns to a VIIRS table.

    Expects:
      - ACQ_DATE as datetime-like (or parseable by ``pandas.to_datetime``)
      - ACQ_TIME as 'HHMM' (e.g. '1324'); values that lost their leading
        zeros, such as the integer 924 or the string '924', are zero-padded
        to four digits before parsing.

    Parameters
    ----------
    gdf : GeoDataFrame (any pandas DataFrame with the two columns works)
        Input table. It is not modified; a copy is returned.

    Returns
    -------
    A copy sorted by ``observation_time`` with a fresh index and two new columns:
      - ``observation_time``: ACQ_DATE plus the hours/minutes of ACQ_TIME.
      - ``half_day``: ``day_of_month * 2 + (0 for hours 0-11, 1 for hours 12-23)``.
        NOTE(review): because it uses the day of month, ids repeat across
        months and are not monotonic over a month boundary.
    """
    import pandas as pd  # function-scope import; pandas is what geopandas builds on

    gdf = gdf.copy()  # do not mutate the caller's frame

    # to_datetime is a no-op on datetime columns, so no dtype probing is needed
    # (probing with numpy fails on pandas extension dtypes).
    gdf["ACQ_DATE"] = pd.to_datetime(gdf["ACQ_DATE"])

    # Zero-pad so that '924' / 924 is read as 09:24 rather than 92 h 4 min.
    hhmm = gdf["ACQ_TIME"].astype(str).str.strip().str.zfill(4)
    hours = hhmm.str[:2].astype(int)
    minutes = hhmm.str[2:].astype(int)
    gdf["observation_time"] = (
        gdf["ACQ_DATE"]
        + pd.to_timedelta(hours, unit="h")
        + pd.to_timedelta(minutes, unit="min")
    )

    gdf = gdf.sort_values("observation_time").reset_index(drop=True)

    # half-day bin index (integer): day_number*2 + (0 or 1)
    obs = gdf["observation_time"].dt
    gdf["half_day"] = (obs.day * 2 + (obs.hour // 12)).astype(int)
    return gdf

def gnp_is_datetime64_any(series) -> bool:
    """Return True when ``series`` has a datetime64 dtype (tz-naive or tz-aware).

    Uses pandas' own dtype predicate instead of ``np.issubdtype``, which raises
    ``TypeError`` for pandas extension dtypes such as tz-aware datetimes,
    categoricals or the string dtype.
    """
    from pandas.api.types import is_datetime64_any_dtype

    return bool(is_datetime64_any_dtype(series.dtype))

def filter_viirs_to_extent(gdf: gpd.GeoDataFrame, extent_xyxy: List[float], target_crs: str) -> gpd.GeoDataFrame:
    """Reproject points to ``target_crs`` and keep those strictly inside an extent.

    Parameters
    ----------
    gdf : GeoDataFrame
        Point observations. When the frame carries no CRS it is assumed to be
        EPSG:4326 (VIIRS lon/lat), as the original docstring states.
    extent_xyxy : list of float
        ``[xmin, xmax, ymin, ymax]`` in ``target_crs`` (the same ordering that
        ``array_extent`` returns).
    target_crs : str
        CRS the extent is expressed in.

    Returns
    -------
    GeoDataFrame
        Copy of the rows whose geometry is within the extent rectangle,
        expressed in ``target_crs``. Points lying exactly on the rectangle
        boundary are excluded by ``within``.
    """
    if gdf.crs is None:
        # to_crs() refuses CRS-less frames; apply the documented assumption.
        gdf = gdf.set_crs("EPSG:4326")
    projected = gdf.to_crs(target_crs)

    xmin, xmax, ymin, ymax = extent_xyxy
    bbox = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
    return projected[projected.within(bbox)].copy()

from pyproj import Geod

# --- concave hulls per half-day (fallback to convex if needed) ---
def viirs_concave_hulls_by_halfday(gdf: gpd.GeoDataFrame, ratio: float = 0.5, allow_holes: bool = True) -> Dict[int, Polygon]:
    """Build one hull per ``half_day`` group of points.

    The hull is computed in whatever CRS ``gdf`` carries (the original author
    states EPSG:4326). A concave hull is used for groups of three or more
    points; smaller groups, a failed concave hull, or a Shapely version
    without ``concave_hull`` fall back to the convex hull. Degenerate results
    (a point or a line) are buffered by a tiny amount so that every value is
    a polygon.

    Parameters
    ----------
    gdf : GeoDataFrame
        Points with a ``half_day`` column.
    ratio : float
        Forwarded to ``shapely.concave_hull``.
    allow_holes : bool
        Forwarded to ``shapely.concave_hull``.

    Returns
    -------
    dict
        ``half_day`` id -> hull geometry.
    """
    import warnings

    # `concave_hull` was never imported at module level, so the previous broad
    # `except` swallowed a NameError and every hull silently became convex.
    try:
        from shapely import concave_hull  # available in Shapely >= 2.0
    except ImportError:
        concave_hull = None
        warnings.warn(
            "shapely.concave_hull is unavailable (requires Shapely >= 2.0); "
            "falling back to convex hulls.",
            RuntimeWarning,
            stacklevel=2,
        )

    hulls = {}
    for half_id, sub in gdf.groupby("half_day"):
        pts = MultiPoint(list(sub.geometry))

        hull = None
        if concave_hull is not None and len(sub) >= 3:
            try:
                hull = concave_hull(pts, ratio=ratio, allow_holes=allow_holes)
            except Exception as exc:  # best-effort: report, then use the convex hull
                warnings.warn(
                    f"concave_hull failed for half_day={half_id} ({exc!r}); "
                    "using convex hull instead.",
                    RuntimeWarning,
                    stacklevel=2,
                )
        if hull is None:
            hull = pts.convex_hull

        if hull.geom_type in ("Point", "LineString"):
            hull = hull.buffer(1e-9)  # tiny buffer (CRS units) to make a polygon
        hulls[half_id] = hull
    return hulls

# --- geodesic areas (km²) from EPSG:4326 hulls ---
_geod = Geod(ellps="WGS84")


def _geodesic_area_m2(poly: Polygon) -> float:
    """Return the geodesic area of ``poly``: exterior ring minus interior rings.

    Each ring's area comes from ``_geod.polygon_area_perimeter`` and is taken
    in absolute value, so ring orientation does not matter.
    """
    def ring_area(ring) -> float:
        xs, ys = ring.xy
        signed_area, _perimeter = _geod.polygon_area_perimeter(xs, ys)
        return abs(signed_area)

    total = ring_area(poly.exterior)
    for hole in poly.interiors:
        total -= ring_area(hole)
    return total

def viirs_burn_area_history_from_hulls(hulls: Dict[int, Polygon]) -> Tuple[List[int], List[float]]:
    """Compute the area of each half-day hull.

    Parameters
    ----------
    hulls : dict
        ``half_day`` id -> hull geometry (Polygon, MultiPolygon or another
        collection; only polygon members of a collection contribute).

    Returns
    -------
    (half_day_ids, areas_km2)
        Sorted ids and, for each, the area from ``_geodesic_area_m2`` divided
        by 1e6. Empty geometries give 0.0.
    """
    times = sorted(hulls)
    areas_km2 = []
    for t in times:
        geom = hulls[t]
        if geom.is_empty:
            areas_km2.append(0.0)
            continue
        # Dispatch on geom_type rather than isinstance(): `MultiPolygon` was
        # never imported in this file, so the isinstance check raised NameError.
        if geom.geom_type == "Polygon":
            parts = [geom]
        else:
            parts = [p for p in getattr(geom, "geoms", ()) if p.geom_type == "Polygon"]
        areas_km2.append(sum(_geodesic_area_m2(p) for p in parts) / 1e6)
    return times, areas_km2


def rasterize_polygon_to_ref(poly: Polygon, ref_path: Path, burn_value: int = 1) -> np.ndarray:
    """Burn ``poly`` onto the grid (shape and transform) of the raster at ``ref_path``.

    Cells covered by the polygon receive ``burn_value`` and all others 0; an
    empty polygon yields an all-zero uint8 array of the reference shape.
    No reprojection is done here, so ``poly`` must already be expressed in the
    reference raster's coordinates.
    """
    with rasterio.open(ref_path) as ref:
        grid_shape = (ref.height, ref.width)
        grid_transform = ref.transform

    if not poly.is_empty:
        return rasterize(
            [(poly, burn_value)],
            out_shape=grid_shape,
            transform=grid_transform,
            fill=0,
            dtype="uint8",
        )
    return np.zeros(grid_shape, dtype=np.uint8)
# -------------------------
# ======= BURN HISTORY =====
# -------------------------
from typing import Optional, Tuple, List
import numpy as np

def burn_area_history_from_toa(
    toa_array: np.ndarray,
    pixel_area_m2: Optional[float] = None,
    n_steps: Optional[int] = None
) -> Tuple[List[float], List[float], List[float]]:
    """
    Compute cumulative burned area vs time from a single per-pixel TOA raster.

    Parameters
    ----------
    toa_array : np.ndarray
        Array (or masked array) where each finite element is the time of arrival.
        NaNs / infinities / masked cells are treated as no-data, and values
        that are not strictly positive are dropped as well.
    pixel_area_m2 : float, optional
        Area of one pixel (m^2). If None (or 0), areas are returned in pixel counts.
    n_steps : int, optional
        If provided, compute areas at this many evenly spaced time thresholds
        between min and max TOA (faster, coarser). If None, use every unique TOA.

    Returns
    -------
    times : List[float]
        Time thresholds (either unique TOA values or evenly spaced thresholds).
    counts : List[int]
        Cumulative number of burned pixels at each time.
    areas : List[float]
        Cumulative burned area at each time (``counts`` scaled by the pixel area).

    All three lists have the same length; they are all empty when the raster
    holds no usable value. Every code path returns this 3-tuple (previously
    two of them returned only two items, which broke 3-way unpacking).
    """
    # Normalise to float data + explicit boolean mask, then keep usable cells.
    arr = np.ma.asarray(toa_array, dtype=float)
    data = np.ma.getdata(arr)
    usable = np.isfinite(data) & ~np.ma.getmaskarray(arr)
    vals = data[usable]

    # Drop non-positive arrival times.
    vals = vals[vals > 0]
    if vals.size == 0:
        return [], [], []

    # Sorted values let a cumulative count give burned pixels vs time.
    vals = np.sort(vals)

    scale = pixel_area_m2 if pixel_area_m2 else 1.0

    if n_steps is None:
        unique_times, per_time = np.unique(vals, return_counts=True)
        cum_counts = np.cumsum(per_time)
        times = unique_times
    else:
        # Coarse curve at evenly spaced thresholds (much faster if many unique TOAs).
        times = np.linspace(float(vals[0]), float(vals[-1]), int(n_steps))
        # Number of vals <= threshold via binary search.
        cum_counts = np.searchsorted(vals, times, side="right")

    areas = (cum_counts * scale).astype(float)
    return times.tolist(), cum_counts.tolist(), areas.tolist()


# -------------------------
# ====== PLOTTING (TOA vs VIIRS)
# -------------------------
def plot_burnt_map_from_toa(
    ax: plt.Axes,
    toa_array: np.ndarray,
    time_now: float,
    extent,
    alpha: float = 0.4
) -> None:
    """
    Overlay the cells with TOA <= time_now on ``ax``, using the TOA grid's extent.

    Only finite, unmasked cells can count as burned; every other cell is set
    to NaN so that it is not drawn.
    """
    if np.ma.isMaskedArray(toa_array):
        unmasked = ~toa_array.mask
    else:
        unmasked = True
    in_domain = np.isfinite(toa_array) & unmasked

    arrived = toa_array <= time_now
    overlay = (in_domain & arrived).astype(float)
    overlay[overlay == 0] = np.nan

    ax.imshow(overlay, extent=extent, origin="upper", alpha=alpha, cmap="Greys", vmin=0, vmax=1)


def plot_viirs_points(ax: plt.Axes, gdf_proj: gpd.GeoDataFrame, **kwargs):
    """Scatter the point geometries of ``gdf_proj`` on ``ax``.

    Extra keyword arguments are forwarded to ``ax.scatter``.
    """
    points = gdf_proj.geometry
    ax.scatter(points.x.values, points.y.values, **kwargs)

# -------------------------
# ====== KAPPA METRIC ======
# -------------------------
def calc_cohen_kappa_for_case(
    phi_paths: List[Path],
    hulls_by_halfday: Dict[int, Polygon],
    ts_to_half_day: "Optional[Callable[[int], float]]" = None,
) -> List[float]:
    """
    For each φ raster, find the nearest half_day hull, rasterize it on the φ grid,
    and compute Cohen's kappa between (φ<=0) and (hull raster == 1) over valid pixels.

    Parameters
    ----------
    phi_paths : list of Path
        φ rasters; the trailing integer of each file name is its timestamp.
    hulls_by_halfday : dict
        ``half_day`` id -> hull polygon. The hulls must already be expressed in
        the coordinates of the φ grid, since no reprojection happens here.
    ts_to_half_day : callable, optional
        Converts a φ file-name timestamp into the half-day id scale before the
        nearest-id search. When omitted the raw timestamp is compared directly
        to the ids (the historical behaviour).
        NOTE(review): the two quantities look like different units (file
        timestamp vs. ``day*2 + 0/1``); without a conversion the largest id is
        likely to win for every large timestamp — confirm and pass a converter.

    Returns
    -------
    list of float
        One kappa per φ raster, NaN when there is no hull to compare against.
    """
    half_ids_sorted = sorted(hulls_by_halfday)
    kappas = []
    for phi_path in phi_paths:
        # Nothing to compare against: record NaN and move on.
        if not half_ids_sorted:
            kappas.append(float("nan"))
            continue

        # Match by numeric timestamp in the file name; fall back to the first id.
        ts = extract_timestamp_from_phi_name(phi_path.name)
        if ts is None:
            target = half_ids_sorted[0]
        elif ts_to_half_day is not None:
            target = ts_to_half_day(ts)
        else:
            target = ts
        best = min(half_ids_sorted, key=lambda h: abs(h - target))

        viirs_bin = rasterize_polygon_to_ref(hulls_by_halfday[best], phi_path, burn_value=1)
        phi_bin, valid = phi_to_binary(phi_path)

        y_true = phi_bin[valid].ravel()
        y_pred = viirs_bin[valid].ravel()
        kappa = cohen_kappa_score(y_true, y_pred)
        kappas.append(float(kappa))
        print(f"✅ {phi_path.name}: half_day={best}, Kappa={kappa:.4f}, Burned_px={int(phi_bin.sum())}")
    return kappas


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Complete, runnable postprocess script for one ELMFIRE case.

What it does:
- Loads/plots fuel map (categorized colormap)
- Loads weather rasters and plots histograms (WS, WD, M1, M10, M100)
- Loads TOA stack, builds burn-area history
- Loads VIIRS points, filters to extent, groups by half-day, builds concave hulls (convex fallback), burn-area history
- Overlays VIIRS dots on base map and simulated burned extent for selected times
- Computes Cohen's Kappa per selected frames by rasterizing VIIRS hulls and comparing to φ-binary
- Saves figures + metrics.json

Adjust the PATH CONSTANTS block for your case.
"""

from __future__ import annotations
import os, re, glob, json, math
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch

import rasterio
from rasterio.features import rasterize
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.enums import Resampling as ResampEnum

import geopandas as gpd
from shapely.geometry import Point, Polygon, MultiPoint
from shapely.ops import unary_union

from sklearn.metrics import cohen_kappa_score
from datetime import timedelta

# -------------------------
# ===== PATH CONSTANTS ====
# -------------------------
CASE_ROOT = Path("/global/home/users/yirenqin712/scratch/yirenqin712/ELMFIRE_SIMULATION/VnV_Suite/tubbs_fire").resolve()

# Sub-folders of the case directory
_WEATHER_DIR = CASE_ROOT / "data/weather"
_OUTPUTS_DIR = CASE_ROOT / "outputs"

# Inputs: fuel map, weather rasters, simulation outputs, VIIRS observations
FUELMAP_PATH   = CASE_ROOT / "data/fuels_and_topography/fbfm40b.tif"
WX_WS_PATH     = _WEATHER_DIR / "ws.tif"
WX_WD_PATH     = _WEATHER_DIR / "wd.tif"
WX_M1_PATH     = _WEATHER_DIR / "m1.tif"
WX_M10_PATH    = _WEATHER_DIR / "m10.tif"
WX_M100_PATH   = _WEATHER_DIR / "m100.tif"
TOA_GLOB       = str(_OUTPUTS_DIR / "time_of_arrival_*.tif")
PHI_GLOB       = str(_OUTPUTS_DIR / "phi_0000001_*.tif")
VIIRS_DIR      = CASE_ROOT / "data/viirs_observation"                      # directory holding the VIIRS *.shp files

# Outputs: created up front so later cells can write into them
FIG_DIR = CASE_ROOT / "figures"
REP_DIR = CASE_ROOT / "report"
FIG_DIR.mkdir(parents=True, exist_ok=True)
REP_DIR.mkdir(parents=True, exist_ok=True)

# CRS settings for plotting overlays (adjust per case)
MAP_CRS = "EPSG:32610"  # projected CRS of the case (UTM zone)
OUT_CRS = "EPSG:4326"   # lon/lat CRS used for the base map
# Optional manual extent, [xmin, xmax, ymin, ymax] in MAP_CRS; None means "derive it from the data"
MAP_EXTENT_OVERRIDE: Optional[List[float]] = None

# -------------------------
# ====== UTILITIES =========
# -------------------------

def save_fig(fig, name: str):
    """Write ``fig`` to ``FIG_DIR/<name>.pdf``, close it, and return the output path."""
    target = FIG_DIR.joinpath(f"{name}.pdf")
    fig.savefig(target, format="pdf")  # default bounding box (no bbox_inches)
    plt.close(fig)
    return target

# ---- 1) Fuel map base layer
# Draw the categorical fuel map in lon/lat and save it as a PDF.
fig, ax = plt.subplots(figsize=(11, 10))
_, ax, extent_ll = plot_fuel_map(FUELMAP_PATH, ax=ax, show_colorbar=True, dst_crs=OUT_CRS)
# Raw strings: "\c" is an invalid escape sequence in a normal string literal
# (SyntaxWarning on recent Python versions); the resulting text is unchanged.
ax.set_xlabel(r'Longitude [$^\circ$]')
ax.set_ylabel(r'Latitude [$^\circ$]')
out = save_fig(fig, "fuelmap_categorical")
print(f"Fuel map saved to {out}.")

In [None]:
# ---- 2) Weather histograms
# (raster path, x-axis label) for every weather variable to summarise
wx_list = [
    (WX_WS_PATH, "Wind Speed [mph]"),
    (WX_WD_PATH, "Wind Direction [°]"),
    (WX_M1_PATH, "1-hr Dead Fuel Moisture [%]"),
    (WX_M10_PATH, "10-hr Dead Fuel Moisture [%]"),
    (WX_M100_PATH, "100-hr Dead Fuel Moisture [%]"),
]

# One figure per variable, saved as hist_<raster stem>.pdf
for wx_path, axis_label in wx_list:
    fig, ax = plt.subplots(figsize=(6, 4))
    plot_wx_hist(ax, wx_path, bins=60)
    ax.set_xlabel(axis_label)
    save_fig(fig, f"hist_{wx_path.stem}")

In [None]:
# ---- 3) Load TOA + compute pixel area
# If you have one grid definition, we can estimate pixel area from transform
# Pixel area is estimated from the first φ raster's transform.
phi_paths = load_phi_paths(PHI_GLOB)
with rasterio.open(phi_paths[0]) as src0:
    px_area = abs(src0.transform.a * src0.transform.e)  # (m/px)*(m/px) if UTM/proj in meters

toa_paths, toa_arrays, toa_times, toa_transform, toa_meta = load_toa_stack(TOA_GLOB)

# burn_area_history_from_toa expects ONE per-pixel TOA raster. Passing the
# whole list would stack every file and count each pixel once per file.
# The first raster is used, matching the overlay cell (toa_arrays[0]).
toa_field = toa_arrays[0]
toa_times_hist, count, toa_area_hist = burn_area_history_from_toa(toa_field, pixel_area_m2=px_area)


In [None]:
# ---- 4) VIIRS: load, timebin (half-day), filter to map extent, build hulls
# Load all VIIRS detections and attach observation_time / half_day columns.
viirs_gdf = add_viirs_obstime(load_viirs_points(VIIRS_DIR))

# Spatial filtering to the map extent is currently disabled; the hulls below
# are built from every loaded detection. Original (disabled) steps:
# fmap_arr, fmap_transform, fmap_meta = read_raster(FUELMAP_PATH)
# fmap_extent_mapcrs = array_extent(fmap_transform, fmap_meta["width"], fmap_meta["height"])
# target_extent = MAP_EXTENT_OVERRIDE if MAP_EXTENT_OVERRIDE else fmap_extent_mapcrs
# viirs_filtered = filter_viirs_to_extent(viirs_gdf, target_extent, target_crs=MAP_CRS)

# One hull per half-day bin, then the area of each hull.
hulls_by_half = viirs_concave_hulls_by_halfday(viirs_gdf)
viirs_times, viirs_areas = viirs_burn_area_history_from_hulls(hulls_by_half)  # half-day ids, hull areas in km²

In [None]:
# ---- 5) VIIRS vs TOA overlays for first 3 half-day bins (if available)
# A proper timezone-aware start time (Los Angeles)
# Timezone-aware start time (Los Angeles) used as the origin of the TOA clock.
start_time = datetime(2017, 10, 8, 21, 43, tzinfo=ZoneInfo("America/Los_Angeles"))

# Representative datetime for each half_day bin (median of the observations in
# that bin). Built from `viirs_gdf`, the frame the hulls and `viirs_times`
# come from; `viirs_filtered` is never defined because filtering is disabled.
half2time = (viirs_gdf
             .groupby("half_day")["observation_time"]
             .median()
             .to_dict())

# TOA grid extent + CRS (everything is drawn in the TOA CRS).
with rasterio.open(toa_paths[0]) as ref:
    target_crs = ref.crs.to_string()
    extent = array_extent(ref.transform, ref.width, ref.height)

# Loop-invariant work: the TOA field to threshold and the reprojected points.
toa_field = toa_arrays[0]
viirs_in_toa_crs = viirs_gdf.to_crs(target_crs)

n_show = min(3, len(viirs_times))
if n_show == 0:
    # plt.subplots(1, 0) would raise; there is simply nothing to draw.
    print("No VIIRS half-day bins available; skipping TOA vs VIIRS overlays.")
else:
    # squeeze=False always yields a 2D array of axes, even for one panel.
    fig, axs = plt.subplots(1, n_show, figsize=(6 * n_show, 6), squeeze=False)

    for ax, half_id in zip(axs[0], viirs_times):
        # Convert the bin's representative time to seconds since start_time.
        t_obs = half2time[half_id]
        if t_obs.tzinfo is None:
            # add_viirs_obstime produces naive timestamps. VIIRS/FIRMS
            # ACQ_DATE and ACQ_TIME are published in UTC, so localise to UTC
            # (tz_convert would raise on a naive timestamp).
            t_obs = t_obs.tz_localize("UTC")
        time_now_sec = (t_obs - start_time).total_seconds()

        # Base fuel map in the TOA CRS (no colorbar in small multiples).
        plot_fuel_map(FUELMAP_PATH, ax=ax, show_colorbar=False, dst_crs=target_crs)

        # Simulated burned extent at the observation time.
        plot_burnt_map_from_toa(ax, toa_field, time_now_sec, extent, alpha=0.25)

        # VIIRS points of this bin, already in the TOA CRS.
        sub_pts = viirs_in_toa_crs[viirs_in_toa_crs["half_day"] == half_id]
        plot_viirs_points(ax, sub_pts, s=4, marker='s')

        ax.set_title(f"TOA vs VIIRS (half_day={half_id})")
        ax.set_aspect("equal")

    save_fig(fig, "simu_vs_viirs_examples")

In [None]:
# ---- 6) Burn area history plot
# Put the VIIRS curve on the same time axis as the simulation: hours since
# start_time. Plotting the raw half-day ids against TOA hours (and labelling
# the axis in seconds) made the two curves incomparable.
obs_median = viirs_gdf.groupby("half_day")["observation_time"].median()
viirs_hours = []
for half_id in viirs_times:
    t_obs = obs_median.loc[half_id]
    if t_obs.tzinfo is None:
        # VIIRS/FIRMS acquisition date and time are published in UTC.
        t_obs = t_obs.tz_localize("UTC")
    viirs_hours.append((t_obs - start_time).total_seconds() / 3600.0)

fig, ax = plt.subplots(figsize=(8, 5))
# TOA history: seconds -> hours, m² -> km².
ax.plot(np.array(toa_times_hist) / 3600.0, np.array(toa_area_hist) / 1e6, label="Simulated (TOA)", lw=2)
# VIIRS hull areas are already in km².
ax.plot(viirs_hours, np.array(viirs_areas), label="Observed (VIIRS hull)", lw=2)
ax.set_xlabel("Time since start [h]")
ax.set_ylabel("Burned area [km²]")
ax.legend()
save_fig(fig, "burn_area_history")

In [None]:
# ---- 7) Cohen's Kappa (per φ slice vs nearest half-day VIIRS hull)
# The hulls were built in the CRS of the VIIRS points, while they are
# rasterised on the φ grid. Bring them into the φ CRS first so polygon and
# raster coordinates agree (a no-op when both CRSs are already the same).
with rasterio.open(phi_paths[0]) as ref:
    phi_crs_wkt = ref.crs.to_wkt() if ref.crs else None

hulls_for_kappa = hulls_by_half
if hulls_by_half and viirs_gdf.crs is not None and phi_crs_wkt is not None:
    hull_series = gpd.GeoSeries(
        list(hulls_by_half.values()),
        index=list(hulls_by_half.keys()),
        crs=viirs_gdf.crs,
    )
    hulls_for_kappa = hull_series.to_crs(phi_crs_wkt).to_dict()

kappas = calc_cohen_kappa_for_case(phi_paths, hulls_for_kappa)

# Keep at most the first three kappas; NaN is stored as null in the JSON.
metrics = {
    f"kappa_{i}": (None if np.isnan(k) else round(float(k), 6))
    for i, k in enumerate(kappas[:3], start=1)
}

# ---- 8) Save metrics + simple LaTeX snippets (optional)
# `OUT_DIR` was never defined; the report folder created with the path
# constants is used for the metrics file.
metrics_path = REP_DIR / "metrics.json"
metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
print("[OK] Postprocess complete.")
print(f"  - Figures: {FIG_DIR}")
print(f"  - Metrics JSON: {metrics_path}")