# GeoTIFF Data Inspection Toolkit

Flexible satellite image inspection for **any folder structure**.

Rather than forcing everything into one UI, this notebook provides
**separate entry points** depending on how your data is organized:

| Section | Use When | Input |
|---------|----------|-------|
| **§1 — Quick Look** | You have a single `.tif` file path | One file path string |
| **§2 — Flat Folder** | A folder with `.tif` files (no subfolders) | One directory path |
| **§3 — Nested ROI Browser** | `root/roi_name/*.tif` structure | Root directory path |
| **§4 — Glob Pattern** | Scattered files matching a pattern | Glob expression |
| **§5 — Multi-Path Comparison** | Compare 2-4 specific files side by side | List of file paths |

All sections share a common **visualization engine** (§0) — run it once, then
jump to whichever section fits your data.
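The table can also be read as a small decision procedure. The sketch below is illustrative only: `suggest_section` is a hypothetical helper that is not part of this notebook, and it assumes files use the `.tif` extension and that glob patterns are recognizable by the characters `*`, `?`, or `[`.

```python
import glob
import os
import tempfile


def suggest_section(target):
    """Return the section label from the table that fits `target` (hypothetical helper)."""
    if isinstance(target, (list, tuple)):
        return "§5 Multi-Path Comparison"
    if any(ch in target for ch in "*?["):
        return "§4 Glob Pattern"
    if os.path.isfile(target):
        return "§1 Quick Look"
    if os.path.isdir(target):
        # .tif files directly inside the folder -> flat layout
        if glob.glob(os.path.join(target, "*.tif")):
            return "§2 Flat Folder"
        # .tif files one level down -> root/roi_name/*.tif layout
        if glob.glob(os.path.join(target, "*", "*.tif")):
            return "§3 Nested ROI Browser"
    return None


# Build a throwaway folder layout so every branch can be exercised.
with tempfile.TemporaryDirectory() as root:
    flat = os.path.join(root, "flat")
    nested = os.path.join(root, "nested")
    os.makedirs(flat)
    os.makedirs(os.path.join(nested, "roi_a"))
    single = os.path.join(flat, "scene.tif")
    for path in (single, os.path.join(nested, "roi_a", "scene.tif")):
        open(path, "wb").close()  # empty placeholder, never opened as a raster

    results = {
        "file": suggest_section(single),
        "flat": suggest_section(flat),
        "nested": suggest_section(nested),
        "glob": suggest_section(os.path.join(root, "**", "*.tif")),
        "list": suggest_section([single, single]),
    }

for kind, section in results.items():
    print(f"{kind:>6} -> {section}")
```

The placeholder files are empty, so this only checks folder layout; it says nothing about whether a file is a readable GeoTIFF.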


---
## §0 · Visualization Engine (run this first)

Core functions shared by every section below.  
Includes: auto data-type detection, band info registry, and all plot routines.
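The detection order (filename keywords first, then band count and value range) can be previewed on synthetic arrays without opening any file. This is a simplified sketch, not the engine's own function: `guess_type` is a hypothetical stand-in that takes a filename and a `(bands, rows, cols)` array and uses only a subset of the keywords.

```python
import numpy as np


def guess_type(fname, data):
    """Simplified, hypothetical version of the filename-then-shape heuristic."""
    name = fname.lower()
    keyword_table = [
        (("cld", "shdw", "cloud"), "Cloud/Shadow"),
        (("s1", "sar"), "S1 SAR"),
        (("label", "water"), "Water Label"),
    ]
    # Step 1: a filename keyword wins over anything inferred from the pixels.
    for keys, label in keyword_table:
        if any(k in name for k in keys):
            return label

    # Step 2: fall back to band count and value range.
    n_bands = data.shape[0]
    finite = data[np.isfinite(data)]
    if n_bands == 13:
        return "S2 TOA"
    if n_bands == 2 and finite.size and finite.min() < -10:
        return "S1 SAR"  # strongly negative values suggest backscatter in dB
    if n_bands == 1 and np.unique(finite).size <= 12:
        is_multiclass = finite.size > 0 and finite.max() > 1
        return "Dynamic World" if is_multiclass else "Water Label"
    return "Unknown"


thirteen = np.zeros((13, 4, 4))
print(guess_type("roi_s1_2024.tif", thirteen))                 # filename keyword wins
print(guess_type("scene.tif", thirteen))                       # 13 bands
print(guess_type("tile.tif", np.full((2, 4, 4), -18.0)))       # 2 bands, dB-like values
print(guess_type("tile.tif", np.array([[[0.0, 1.0], [1.0, -1.0]]])))
print(guess_type("tile.tif", np.array([[[0.0, 4.0], [6.0, 8.0]]])))
```

Because keywords are matched as substrings, a short key such as `s1` can match unrelated filenames, so the filename step is best treated as a hint rather than a guarantee.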


In [None]:
"""
GeoTIFF Inspection Toolkit — Shared Visualization Engine
Author : Beomsik Kim (Dept. of Geoinformatics, UOS / GIST Hydro AI Intern)
Date   : 2026-02-05
"""

import os, glob
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import ipywidgets as widgets
from IPython.display import display

# =============================================================================
# Band information registry
# =============================================================================

BAND_INFO = {
    "S2 TOA": [
        "B1 Coastal 443nm", "B2 Blue 490nm", "B3 Green 560nm", "B4 Red 665nm",
        "B5 RE1 705nm", "B6 RE2 740nm", "B7 RE3 783nm", "B8 NIR 842nm",
        "B8A NIR-N 865nm", "B9 WV 945nm", "B10 Cirrus 1375nm",
        "B11 SWIR1 1610nm", "B12 SWIR2 2190nm",
    ],
    "S1 SAR": ["VV (co-pol)", "VH (cross-pol)"],
    "Cloud/Shadow": ["Cloud Prob %", "Flag", "Shadow", "Snow/Ice", "Haze/Cirrus"],
    "Dynamic World": ["Land Cover (0-8)"],
    "Water Label": ["Water=1 / Non-water=0 / NoData=-1"],
}


# =============================================================================
# Auto data-type identification
# =============================================================================

def identify_data_type(src, data):
    """Infer data type from band count, value range, and filename."""
    n = src.count
    fname = os.path.basename(src.name).lower()
    vmin, vmax = float(np.nanmin(data)), float(np.nanmax(data))
    uniq = len(np.unique(data[np.isfinite(data)]))

    # Priority 1 — filename keywords
    for keys, label in [
        (("cld", "shdw", "cloud"), "Cloud/Shadow"),
        (("dw", "dynamic"),        "Dynamic World"),
        (("s1", "sar"),            "S1 SAR"),
        (("label", "mask", "water", "jrc"), "Water Label"),
        (("s2", "optical", "toa", "sr"),    "S2 TOA"),
    ]:
        if any(k in fname for k in keys):
            return label

    # Priority 2 — band count + value range
    if n == 13:                    return "S2 TOA"
    if n == 2 and vmin < -10:      return "S1 SAR"
    if n == 5:                     return "Cloud/Shadow"
    if n == 1 and uniq <= 12:
        return "Dynamic World" if vmax > 1 else "Water Label"

    return f"Unknown ({n} bands, [{vmin:.1f}~{vmax:.1f}])"


def get_sensor_category(dtype):
    if dtype in ("S2 TOA",):  return "optical"
    if dtype in ("S1 SAR",):  return "sar"
    return "other"


def _pnorm(arr, lo=2, hi=98):
    """Percentile-based [0,1] normalization."""
    vlo, vhi = np.nanpercentile(arr, [lo, hi])
    return np.clip((arr - vlo) / max(vhi - vlo, 1e-6), 0, 1)


# =============================================================================
# inspect_tif()  —  THE main function, works on a single file path
# =============================================================================

def inspect_tif(fpath, mode="auto"):
    """
    One-shot inspection of a single GeoTIFF file.

    Parameters
    ----------
    fpath : str
        Path to a .tif file.
    mode : str
        "auto"        — pick the best view based on detected type
        "single"      — single band + histogram (band 1)
        "all"         — grid of all bands
        "rgb"         — optical composites / SAR composites
        "label"       — label mask inspection
        "metadata"    — text metadata + stats table
        "compare_idx" — spectral indices side-by-side (NDVI, NDWI, MNDWI, NDBI)
    """
    if not os.path.isfile(fpath):
        print(f"❌ File not found: {fpath}")
        return

    with rasterio.open(fpath) as src:
        data = src.read().astype(np.float64)
        dtype = identify_data_type(src, data)
        cat = get_sensor_category(dtype)

        header = f" {dtype}  |  {src.count} bands  |  {src.width}×{src.height}  |  CRS: {src.crs}"
        print(header)
        print("─" * len(header))

        # Auto mode selection
        if mode == "auto":
            if cat == "optical":   mode = "rgb"
            elif cat == "sar":     mode = "rgb"
            elif dtype in ("Water Label", "Cloud/Shadow", "Dynamic World"):
                                   mode = "label"
            else:                  mode = "all"

        # Dispatch
        if mode == "metadata":
            _show_metadata(src, data, dtype)
        elif mode == "all":
            _show_all_bands(src, data, dtype)
        elif mode == "rgb":
            if cat == "sar":       _show_sar(src, data, dtype)
            elif cat == "optical": _show_optical(src, data, dtype)
            else:                  _show_other(src, data, dtype)
        elif mode == "label":
            _show_label_inspect(src, data, dtype)
        elif mode == "compare_idx":
            _show_index_comparison(src, data, dtype)
        else:  # "single", also the fallback for any unrecognized mode
            _show_single(src, data, dtype, band_idx=0)


# =============================================================================
# Plot: Single band
# =============================================================================

def _show_single(src, data, dtype, band_idx=0):
    band = data[band_idx]
    bnames = BAND_INFO.get(dtype, [f"Band {i+1}" for i in range(src.count)])
    bname = bnames[band_idx] if band_idx < len(bnames) else f"Band {band_idx+1}"

    fig, axes = plt.subplots(1, 3, figsize=(18, 5),
                             gridspec_kw={"width_ratios": [3, 1, 1]})

    cmap = "gray" if get_sensor_category(dtype) == "sar" else "viridis"
    vp = np.nanpercentile(band, [2, 98])
    im = axes[0].imshow(band, cmap=cmap, vmin=vp[0], vmax=vp[1])
    axes[0].set_title(bname, fontsize=13); axes[0].axis("off")
    plt.colorbar(im, ax=axes[0], fraction=0.046)

    valid = band[np.isfinite(band)].ravel()
    if len(valid) > 0:
        axes[1].hist(valid, bins=80, color="steelblue", edgecolor="none", alpha=0.8)
        axes[1].axvline(np.nanmean(band), color="red", ls="--", lw=1,
                        label=f"mean={np.nanmean(band):.2f}")
        axes[1].legend(fontsize=9)
    axes[1].set_title("Histogram"); axes[1].set_xlabel("Value")

    if len(valid) > 0:
        axes[2].hist(valid, bins=80, color="darkorange", edgecolor="none",
                     alpha=0.8, cumulative=True, density=True)
        axes[2].axhline(0.5, color="gray", ls=":", lw=1, label="median")
        axes[2].legend(fontsize=9)
    axes[2].set_title("CDF"); axes[2].set_xlabel("Value")

    plt.tight_layout(); plt.show()
    nan_pct = np.isnan(band).sum() / band.size * 100
    print(f" Min={np.nanmin(band):.4f}  Max={np.nanmax(band):.4f}  "
          f"Mean={np.nanmean(band):.4f}  Std={np.nanstd(band):.4f}  "
          f"Median={np.nanmedian(band):.4f}  NaN={nan_pct:.1f}%")


# =============================================================================
# Plot: All bands grid + correlation heatmap
# =============================================================================

def _show_all_bands(src, data, dtype):
    n = src.count
    cols = min(n, 5)
    rows = (n + cols - 1) // cols
    bnames = BAND_INFO.get(dtype, [f"Band {i+1}" for i in range(n)])
    cmap = "gray" if get_sensor_category(dtype) == "sar" else "viridis"

    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 3.5*rows))
    if n == 1: axes = np.array([[axes]])
    axes = np.atleast_2d(axes)

    for i in range(rows * cols):
        ax = axes[i // cols, i % cols]
        if i < n:
            b = data[i]; vlo, vhi = np.nanpercentile(b, [2, 98])
            ax.imshow(b, cmap=cmap, vmin=vlo, vmax=vhi)
            nm = bnames[i] if i < len(bnames) else f"Band {i+1}"
            ax.set_title(f"{nm}\n[{np.nanmin(b):.1f}~{np.nanmax(b):.1f}]", fontsize=9)
        ax.axis("off")
    plt.suptitle(f"{dtype} — All {n} Bands", fontsize=14, fontweight="bold")
    plt.tight_layout(); plt.show()

    # Correlation heatmap
    if n >= 2:
        flat = np.array([data[i].ravel() for i in range(n)])
        mask = np.all(np.isfinite(flat), axis=0)
        if mask.sum() > 10:
            corr = np.corrcoef(flat[:, mask])
            short = [(bnames[i][:12] if i < len(bnames) else f"B{i+1}") for i in range(n)]
            fig2, ax2 = plt.subplots(figsize=(max(6, n*0.7), max(5, n*0.6)))
            im2 = ax2.imshow(corr, cmap="RdBu_r", vmin=-1, vmax=1)
            ax2.set_xticks(range(n)); ax2.set_yticks(range(n))
            ax2.set_xticklabels(short, fontsize=8, rotation=45, ha="right")
            ax2.set_yticklabels(short, fontsize=8)
            plt.colorbar(im2, ax=ax2, fraction=0.046, label="Pearson r")
            ax2.set_title("Inter-Band Correlation", fontsize=13)
            plt.tight_layout(); plt.show()


# =============================================================================
# Plot: Optical composites — True Color, False Color, NDVI, NDWI, MNDWI, NDBI
# =============================================================================

def _show_optical(src, data, dtype):
    n = src.count

    if n >= 13:
        r, g, b = 3, 2, 1;  tc_label = "True Color (B4-B3-B2)"
    elif n >= 3:
        r, g, b = 2, 1, 0;  tc_label = "RGB (B3-B2-B1)"
    else:
        print("⚠️ Need ≥3 bands for RGB"); _show_all_bands(src, data, dtype); return

    fig, axes = plt.subplots(2, 3, figsize=(18, 11))

    # True Color
    rgb = np.stack([_pnorm(data[r]), _pnorm(data[g]), _pnorm(data[b])], axis=-1)
    axes[0,0].imshow(rgb); axes[0,0].set_title(tc_label, fontsize=12); axes[0,0].axis("off")

    # False Color (NIR-R-G)
    if n >= 8:
        fc = np.stack([_pnorm(data[7]), _pnorm(data[3]), _pnorm(data[2])], axis=-1)
        axes[0,1].imshow(fc); axes[0,1].set_title("False Color (B8-B4-B3)", fontsize=12)
    else:
        axes[0,1].text(0.5, 0.5, "N/A (<8 bands)", ha="center", va="center", fontsize=12)
    axes[0,1].axis("off")

    # NDVI
    if n >= 8:
        nir, red = data[7].astype(np.float64), data[3].astype(np.float64)
        d = nir + red; ndvi = np.where(d != 0, (nir - red) / d, 0)
        im = axes[0,2].imshow(ndvi, cmap="RdYlGn", vmin=-0.5, vmax=0.8)
        axes[0,2].set_title("NDVI (B8−B4)/(B8+B4)", fontsize=12)
        plt.colorbar(im, ax=axes[0,2], fraction=0.046)
    else:
        axes[0,2].text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=12)
    axes[0,2].axis("off")

    # NDWI
    if n >= 8:  # NDWI only needs B3 and B8 (indices 2 and 7)
        grn, nir = data[2].astype(np.float64), data[7].astype(np.float64)
        d = grn + nir; ndwi = np.where(d != 0, (grn - nir) / d, 0)
        im = axes[1,0].imshow(ndwi, cmap="RdYlBu", vmin=-0.5, vmax=0.5)
        axes[1,0].set_title("NDWI (B3−B8)/(B3+B8)", fontsize=12)
        plt.colorbar(im, ax=axes[1,0], fraction=0.046)
    else:
        axes[1,0].text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=12)
    axes[1,0].axis("off")

    # MNDWI (index 11 is B11 only in the 13-band S2 TOA order from BAND_INFO)
    if n >= 12:
        grn, swir = data[2].astype(np.float64), data[11].astype(np.float64)
        d = grn + swir; mndwi = np.where(d != 0, (grn - swir) / d, 0)
        im = axes[1,1].imshow(mndwi, cmap="RdYlBu", vmin=-0.5, vmax=0.5)
        axes[1,1].set_title("MNDWI (B3−B11)/(B3+B11)", fontsize=12)
        plt.colorbar(im, ax=axes[1,1], fraction=0.046)
    else:
        axes[1,1].text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=12)
    axes[1,1].axis("off")

    # NDBI
    if n >= 12:
        swir, nir = data[11].astype(np.float64), data[7].astype(np.float64)
        d = swir + nir; ndbi = np.where(d != 0, (swir - nir) / d, 0)
        im = axes[1,2].imshow(ndbi, cmap="RdGy_r", vmin=-0.5, vmax=0.5)
        axes[1,2].set_title("NDBI (B11−B8)/(B11+B8)", fontsize=12)
        plt.colorbar(im, ax=axes[1,2], fraction=0.046)
    else:
        axes[1,2].text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=12)
    axes[1,2].axis("off")

    plt.suptitle(f" Optical — {dtype}", fontsize=14, fontweight="bold")
    plt.tight_layout(); plt.show()

    # Spectral profile
    if n >= 4:
        bnames = BAND_INFO.get(dtype, [f"B{i+1}" for i in range(n)])
        means = [np.nanmean(data[i]) for i in range(n)]
        stds  = [np.nanstd(data[i]) for i in range(n)]
        short = [bnames[i].split()[0] if i < len(bnames) else f"B{i+1}" for i in range(n)]
        fig2, ax2 = plt.subplots(figsize=(12, 4))
        ax2.errorbar(range(n), means, yerr=stds, fmt="o-", color="steelblue",
                     ecolor="lightcoral", capsize=3, markersize=5)
        ax2.set_xticks(range(n))
        ax2.set_xticklabels(short, fontsize=9, rotation=45, ha="right")
        ax2.set_ylabel("Value"); ax2.set_title("Spectral Profile (mean ± std)")
        ax2.grid(True, alpha=0.3); plt.tight_layout(); plt.show()


# =============================================================================
# Plot: Spectral index comparison (standalone)
# =============================================================================

def _show_index_comparison(src, data, dtype):
    """Side-by-side spectral index comparison for optical data."""
    n = src.count
    if n < 8:
        print("⚠️ Need ≥8 bands for index comparison")
        _show_all_bands(src, data, dtype); return

    nir = data[7].astype(np.float64)
    red = data[3].astype(np.float64)
    grn = data[2].astype(np.float64)

    indices = {}
    # NDVI
    d = nir + red
    indices["NDVI\n(B8−B4)/(B8+B4)"] = (np.where(d != 0, (nir - red) / d, 0), "RdYlGn", -0.5, 0.8)
    # NDWI
    d = grn + nir
    indices["NDWI\n(B3−B8)/(B3+B8)"] = (np.where(d != 0, (grn - nir) / d, 0), "RdYlBu", -0.5, 0.5)

    if n >= 12:
        swir = data[11].astype(np.float64)
        # MNDWI
        d = grn + swir
        indices["MNDWI\n(B3−B11)/(B3+B11)"] = (np.where(d != 0, (grn - swir) / d, 0), "RdYlBu", -0.5, 0.5)
        # NDBI
        d = swir + nir
        indices["NDBI\n(B11−B8)/(B11+B8)"] = (np.where(d != 0, (swir - nir) / d, 0), "RdGy_r", -0.5, 0.5)

    ncols = len(indices)
    fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5))
    if ncols == 1: axes = [axes]
    for ax, (title, (arr, cmap, vmin, vmax)) in zip(axes, indices.items()):
        im = ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=11); ax.axis("off")
        plt.colorbar(im, ax=ax, fraction=0.046)
    plt.suptitle(" Spectral Index Comparison", fontsize=14, fontweight="bold")
    plt.tight_layout(); plt.show()


# =============================================================================
# Plot: SAR — VV / VH / composite / difference + histogram
# =============================================================================

def _show_sar(src, data, dtype):
    n = src.count
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    vv = data[0]; vlo, vhi = np.nanpercentile(vv, [2, 98])
    im0 = axes[0,0].imshow(vv, cmap="gray", vmin=vlo, vmax=vhi)
    axes[0,0].set_title(f"VV [{np.nanmin(vv):.1f}~{np.nanmax(vv):.1f} dB]", fontsize=12)
    axes[0,0].axis("off"); plt.colorbar(im0, ax=axes[0,0], fraction=0.046, label="dB")

    if n >= 2:
        vh = data[1]; vlo2, vhi2 = np.nanpercentile(vh, [2, 98])
        im1 = axes[0,1].imshow(vh, cmap="gray", vmin=vlo2, vmax=vhi2)
        axes[0,1].set_title(f"VH [{np.nanmin(vh):.1f}~{np.nanmax(vh):.1f} dB]", fontsize=12)
        plt.colorbar(im1, ax=axes[0,1], fraction=0.046, label="dB")
    else:
        axes[0,1].text(0.5, 0.5, "VH — N/A", ha="center", va="center", fontsize=12)
    axes[0,1].axis("off")

    if n >= 2:
        sar_rgb = np.stack([_pnorm(vv), _pnorm(vh),
                            _pnorm(np.where(vh!=0, vv/vh, 0))], axis=-1)
        axes[1,0].imshow(sar_rgb)
        axes[1,0].set_title("Color Composite (R=VV, G=VH, B=VV/VH)", fontsize=11)
    else:
        axes[1,0].text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=12)
    axes[1,0].axis("off")

    if n >= 2:
        diff = vv - vh; dlo, dhi = np.nanpercentile(diff, [2, 98])
        im3 = axes[1,1].imshow(diff, cmap="coolwarm", vmin=dlo, vmax=dhi)
        axes[1,1].set_title("VV − VH Difference", fontsize=12)
        plt.colorbar(im3, ax=axes[1,1], fraction=0.046, label="dB")
    else:
        axes[1,1].text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=12)
    axes[1,1].axis("off")

    plt.suptitle(f" SAR — {dtype}", fontsize=14, fontweight="bold")
    plt.tight_layout(); plt.show()

    if n >= 2:
        fig2, ax2 = plt.subplots(figsize=(10, 4))
        ax2.hist(vv[np.isfinite(vv)].ravel(), bins=80, alpha=0.6, color="steelblue",
                 edgecolor="none", label="VV", density=True)
        ax2.hist(vh[np.isfinite(vh)].ravel(), bins=80, alpha=0.6, color="darkorange",
                 edgecolor="none", label="VH", density=True)
        ax2.set_title("SAR Backscatter Distribution"); ax2.set_xlabel("dB"); ax2.set_ylabel("Density")
        ax2.legend(); ax2.grid(True, alpha=0.3); plt.tight_layout(); plt.show()


# =============================================================================
# Plot: Other types (Cloud / DW / Water Label)
# =============================================================================

def _show_other(src, data, dtype):
    if dtype == "Cloud/Shadow":   _show_cloud(src, data)
    elif dtype == "Dynamic World": _show_dw(src, data)
    elif dtype == "Water Label":   _show_water_label(src, data)
    else:                          _show_all_bands(src, data, dtype)


def _show_cloud(src, data):
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    im = axes[0].imshow(data[0], cmap="Reds", vmin=0, vmax=100)
    axes[0].set_title("Cloud Probability (%)"); plt.colorbar(im, ax=axes[0], fraction=0.046)
    if src.count >= 3:
        axes[1].imshow(data[2], cmap="gray_r", vmin=0, vmax=1)
        axes[1].set_title("Shadow Mask")
        combined = ((data[0] > 50) | (data[2] == 1)).astype(float)
        axes[2].imshow(combined, cmap="Reds", vmin=0, vmax=1)
        axes[2].set_title("Cloud + Shadow (thresh=50%)")
    else:
        axes[1].text(0.5,0.5,"N/A",ha="center",va="center")
        axes[2].text(0.5,0.5,"N/A",ha="center",va="center")
    for ax in axes: ax.axis("off")
    plt.suptitle(" Cloud / Shadow", fontsize=14, fontweight="bold")
    plt.tight_layout(); plt.show()


def _show_dw(src, data):
    DW_COLORS = {0:"#419BDF",1:"#397D49",2:"#88B053",3:"#7A87C6",
                 4:"#E49635",5:"#DFC35A",6:"#C4281B",7:"#A59B8F",8:"#B39FE1"}
    DW_NAMES  = {0:"Water",1:"Trees",2:"Grass",3:"Flooded Veg",
                 4:"Crops",5:"Shrub",6:"Built",7:"Bare",8:"Snow/Ice"}
    dw = data[0]; fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    axes[0].imshow(dw, cmap=ListedColormap([DW_COLORS[i] for i in range(9)]),
                   vmin=-0.5, vmax=8.5)  # same palette as the bar chart
    axes[0].set_title("Dynamic World Classes"); axes[0].axis("off")
    unique = np.unique(dw[np.isfinite(dw)]).astype(int)
    total = np.isfinite(dw).sum()
    bars, labels, colors = [], [], []
    for c in sorted(unique):
        if c in DW_NAMES:
            pct = (dw==c).sum()/total*100
            bars.append(pct); labels.append(f"{DW_NAMES[c]} ({pct:.1f}%)")
            colors.append(DW_COLORS.get(c, "gray"))
    axes[1].barh(labels, bars, color=colors, edgecolor="white")
    axes[1].set_xlabel("Coverage %"); axes[1].set_title("Class Distribution")
    plt.suptitle(" Dynamic World", fontsize=14, fontweight="bold")
    plt.tight_layout(); plt.show()


def _show_water_label(src, data):
    fig, ax = plt.subplots(figsize=(8, 6))
    label = data[0]
    cmap = ListedColormap(["gray", "white", "dodgerblue"])
    im = ax.imshow(label, cmap=cmap, vmin=-1.5, vmax=1.5)
    ax.set_title("Water Label (-1=NoData, 0=Land, 1=Water)"); ax.axis("off")
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, ticks=[-1, 0, 1])
    cbar.set_ticklabels(["NoData", "Land", "Water"])
    total = label.size
    print(f" Water: {(label==1).sum()/total*100:.1f}%  |  "
          f" Land: {(label==0).sum()/total*100:.1f}%  |  "
          f" NoData: {(label==-1).sum()/total*100:.1f}%")
    plt.tight_layout(); plt.show()


# =============================================================================
# Plot: Label inspection (multi-panel)
# =============================================================================

def _show_label_inspect(src, data, dtype):
    fig = plt.figure(figsize=(20, 6))
    gs = fig.add_gridspec(1, 4, width_ratios=[1, 1, 1, 0.8])
    band = data[0]

    # Panel 1 — Raw
    ax0 = fig.add_subplot(gs[0])
    im0 = ax0.imshow(band, cmap="viridis"); ax0.set_title("Band 1 (Raw)"); ax0.axis("off")
    plt.colorbar(im0, ax=ax0, fraction=0.046)

    # Panel 2 — Colored mask
    ax1 = fig.add_subplot(gs[1])
    if dtype == "Water Label":
        vis = np.full((*band.shape, 3), 0.9)
        vis[band==0]=[0.95,0.9,0.8]; vis[band==1]=[0.1,0.4,0.9]; vis[band==-1]=[0.3,0.3,0.3]
        ax1.imshow(vis); ax1.set_title("Water Label")
    elif dtype == "Cloud/Shadow":
        cp = data[0]; sh = data[2] if src.count>=3 else np.zeros_like(cp)
        vis = np.full((*cp.shape, 3), 0.2)
        vis[cp>50]=[1,1,1]; vis[sh==1]=[0.4,0.4,0.6]; vis[(cp>50)&(sh==1)]=[1,0.6,0.6]
        ax1.imshow(vis); ax1.set_title("Cloud(white) + Shadow(gray)")
    elif dtype == "Dynamic World":
        DW_RGB = {0:[.25,.61,.87],1:[.22,.49,.29],2:[.53,.69,.33],3:[.48,.53,.78],
                  4:[.89,.59,.21],5:[.87,.76,.35],6:[.77,.16,.11],7:[.65,.61,.56],8:[.70,.62,.88]}
        vis = np.full((*band.shape,3), 0.5)
        for c, rgb_v in DW_RGB.items(): vis[band==c] = rgb_v
        ax1.imshow(vis); ax1.set_title("DW Classes")
    else:
        thresh = np.nanmean(band)
        ax1.imshow((band>thresh).astype(float), cmap="gray_r", vmin=0, vmax=1)
        ax1.set_title(f"Binary (thresh={thresh:.2f})")
    ax1.axis("off")

    # Panel 3 — Boundary
    ax2 = fig.add_subplot(gs[2])
    try:
        from scipy.ndimage import binary_dilation, binary_erosion
        if dtype == "Water Label":     target = (band==1).astype(np.uint8)
        elif dtype == "Dynamic World": target = (data[0]==0).astype(np.uint8)
        elif dtype == "Cloud/Shadow":  target = (data[0]>50).astype(np.uint8)
        else:                          target = (band>np.nanmean(band)).astype(np.uint8)
        dilated = binary_dilation(target, iterations=2)
        eroded  = binary_erosion(target, iterations=2)
        boundary = (dilated.astype(int) - eroded.astype(int)).clip(0, 1)
        overlay = np.zeros((*band.shape, 3))
        overlay[target==1]=[0.2,0.5,0.9]; overlay[boundary==1]=[1,0.2,0.2]
        ax2.imshow(overlay); ax2.set_title("Mask + Boundary (red)")
    except ImportError:
        ax2.imshow(band, cmap="gray"); ax2.set_title("(scipy not available)")
    ax2.axis("off")

    # Panel 4 — Pie chart
    ax3 = fig.add_subplot(gs[3]); ax3.axis("off")
    if dtype == "Water Label":
        cats = {"Water(1)":(band==1).sum(),"Land(0)":(band==0).sum(),"NoData(-1)":(band==-1).sum()}
        colors = ["dodgerblue","sandybrown","gray"]
    elif dtype == "Dynamic World":
        DW_N = {0:"Water",1:"Trees",2:"Grass",3:"Flooded",4:"Crops",5:"Shrub",6:"Built",7:"Bare",8:"Snow"}
        DW_C = ["#419BDF","#397D49","#88B053","#7A87C6","#E49635","#DFC35A","#C4281B","#A59B8F","#B39FE1"]
        uniq = sorted(np.unique(data[0][np.isfinite(data[0])]).astype(int))
        cats = {DW_N.get(c,f"C{c}"):(data[0]==c).sum() for c in uniq if c in DW_N}
        colors = [DW_C[c] if c<len(DW_C) else "gray" for c in uniq if c in DW_N]
    elif dtype == "Cloud/Shadow":
        cp = data[0]
        cats = {"Clear(<20%)":(cp<20).sum(),"Hazy(20-50%)":((cp>=20)&(cp<50)).sum(),
                "Cloudy(50-80%)":((cp>=50)&(cp<80)).sum(),"Thick(>80%)":(cp>=80).sum()}
        colors = ["lightgreen","khaki","lightcoral","gray"]
    else:
        v = band[np.isfinite(band)]
        cats = {f">mean({np.nanmean(band):.1f})":(v>np.nanmean(band)).sum(),
                "<=mean":(v<=np.nanmean(band)).sum()}
        colors = ["steelblue","lightgray"]

    total = sum(cats.values())
    if total > 0:
        # Build the legend labels only when there are pixels to count (avoids 0/0).
        labels = [f"{k}\n{v:,} ({v/total*100:.1f}%)" for k,v in cats.items()]
        wedges, _ = ax3.pie(list(cats.values()), colors=colors, startangle=90,
                            wedgeprops={"edgecolor":"white","linewidth":1.5})
        ax3.legend(wedges, labels, loc="center left", bbox_to_anchor=(-0.2,-0.3),
                   fontsize=9, frameon=False)
    ax3.set_title("Distribution")
    plt.suptitle(f" Label Inspection — {dtype}", fontsize=14, fontweight="bold")
    plt.tight_layout(); plt.show()

    print("\n Summary:")
    for k, v in cats.items():
        print(f"   {k}: {v:>10,} px ({v/max(total, 1)*100:>5.1f}%)")
    print(f"   Total: {total:>10,} px")


# =============================================================================
# Plot: Metadata
# =============================================================================

def _show_metadata(src, data, dtype):
    cat = get_sensor_category(dtype)
    cat_labels = {"optical":" Optical", "sar":" SAR", "other":" Other"}
    print(f"\n{'='*65}")
    print(f"  FILE:        {os.path.basename(src.name)}")
    print(f"  TYPE:        {dtype}")
    print(f"  SENSOR:      {cat_labels.get(cat, cat)}")
    print(f"{'='*65}")
    print(f"  Size:        {src.width} × {src.height} px")
    print(f"  Bands:       {src.count}")
    print(f"  Dtype:       {src.dtypes[0]}")
    print(f"  CRS:         {src.crs.to_string() if src.crs else 'Undefined'}")
    print(f"  Resolution:  {abs(src.res[0]):.2f} × {abs(src.res[1]):.2f} (CRS units)")
    print(f"  Bounds:      {src.bounds}")
    print(f"  NoData:      {src.nodata}")
    print(f"  Compression: {src.profile.get('compress', 'None')}")
    t = src.transform
    print(f"  Transform:   [{t.a:.4f}  {t.b:.4f}  {t.c:.2f}]")
    print(f"               [{t.d:.4f}  {t.e:.4f}  {t.f:.2f}]")

    bnames = BAND_INFO.get(dtype, [f"Band {i+1}" for i in range(src.count)])
    print(f"\n{'─'*95}")
    print(f"  {'#':>3s}  {'Name':<22s}  {'Min':>10s}  {'Max':>10s}  "
          f"{'Mean':>10s}  {'Std':>10s}  {'Median':>10s}  {'NaN%':>6s}")
    print(f"  {'─'*92}")
    for i in range(src.count):
        b = data[i]; nan_pct = np.isnan(b).sum()/b.size*100
        nm = bnames[i] if i < len(bnames) else f"Band {i+1}"
        print(f"  {i+1:>3d}  {nm:<22s}  {np.nanmin(b):>10.2f}  {np.nanmax(b):>10.2f}  "
              f"{np.nanmean(b):>10.2f}  {np.nanstd(b):>10.2f}  {np.nanmedian(b):>10.2f}  "
              f"{nan_pct:>5.1f}%")
    if dtype in ("Dynamic World", "Water Label"):
        uniq = np.unique(data[0][np.isfinite(data[0])])
        print(f"\n  Unique values: {uniq}")
    print(f"{'─'*95}")


# =============================================================================
# compare_tifs()  —  side-by-side comparison of 2-4 files
# =============================================================================

def compare_tifs(fpaths, band_idx=0):
    """
    Quick side-by-side comparison of multiple GeoTIFF files.

    Parameters
    ----------
    fpaths : list of str
        Paths to 2-4 .tif files.
    band_idx : int
        Which band to display (0-indexed).
    """
    fpaths = [f for f in fpaths if os.path.isfile(f)]
    n = len(fpaths)
    if n < 2:
        print("❌ Need at least 2 valid files"); return
    if n > 4:
        print("⚠️ Showing first 4 only"); fpaths = fpaths[:4]; n = 4

    fig, axes = plt.subplots(1, n, figsize=(6*n, 5))
    if n == 1: axes = [axes]

    for ax, fp in zip(axes, fpaths):
        with rasterio.open(fp) as src:
            data = src.read().astype(np.float64)
            dtype = identify_data_type(src, data)
            cat = get_sensor_category(dtype)
            idx = min(band_idx, src.count - 1)
            band = data[idx]
            cmap = "gray" if cat == "sar" else "viridis"
            vlo, vhi = np.nanpercentile(band, [2, 98])
            im = ax.imshow(band, cmap=cmap, vmin=vlo, vmax=vhi)
            plt.colorbar(im, ax=ax, fraction=0.046)
            fname = os.path.basename(src.name)
            ax.set_title(f"{fname}\n{dtype} | Band {idx+1}", fontsize=10)
            ax.axis("off")

    plt.suptitle(" Side-by-Side Comparison", fontsize=14, fontweight="bold")
    plt.tight_layout(); plt.show()


# =============================================================================
# Interactive widget viewer (for nested ROI folder structure)
# =============================================================================

class TiffBrowser:
    """Widget-based browser for root/roi_name/*.tif folder structures."""

    def __init__(self, root_dir: str):
        self.root_dir = root_dir
        if not os.path.exists(self.root_dir):
            print(f"⚠️ '{self.root_dir}' not found")
            self.root_dir = "."

        style = {"description_width": "initial"}
        self.roi_dd = widgets.Dropdown(description=" ROI:", style=style,
                                       layout=widgets.Layout(width="350px"))
        self.file_dd = widgets.Dropdown(description=" File:", disabled=True,
                                        style=style, layout=widgets.Layout(width="450px"))
        self.band_slider = widgets.IntSlider(description="Band:", min=1, max=1, value=1,
                                             style=style, layout=widgets.Layout(width="280px"))
        self.mode_btn = widgets.ToggleButtons(
            options=["Single Band", "All Bands", "RGB / SAR", "Label", "Metadata"],
            value="Single Band", style=style)
        self.output = widgets.Output()
        self._init_ui()

    def _get_rois(self):
        if not os.path.exists(self.root_dir): return []
        return sorted(d for d in os.listdir(self.root_dir)
                      if os.path.isdir(os.path.join(self.root_dir, d)) and not d.startswith("."))

    def _get_files(self, roi):
        return sorted(os.path.basename(f)
                      for f in glob.glob(os.path.join(self.root_dir, roi, "*.tif")))

    def _on_roi(self, change):
        roi = change["new"]
        if roi:
            files = self._get_files(roi)
            self.file_dd.options = files; self.file_dd.disabled = not files
            if files: self.file_dd.value = files[0]

    def _on_file(self, change):
        if change["new"] and self.roi_dd.value:
            self._update_band_range(); self._render()

    def _on_mode(self, change):  self._render()
    def _on_band(self, change):
        if self.mode_btn.value == "Single Band": self._render()

    def _update_band_range(self):
        fpath = os.path.join(self.root_dir, self.roi_dd.value, self.file_dd.value)
        try:
            with rasterio.open(fpath) as src:
                self.band_slider.max = src.count
                self.band_slider.value = min(self.band_slider.value, src.count)
        except Exception: pass

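    def _init_ui(self):
        # Attach observers before populating the ROI dropdown: assigning
        # `options` typically selects the first entry already, so a later
        # `value = rois[0]` is a no-op and would never fire `_on_roi`.
        self.roi_dd.observe(self._on_roi, names="value")
        self.file_dd.observe(self._on_file, names="value")
        self.mode_btn.observe(self._on_mode, names="value")
        self.band_slider.observe(self._on_band, names="value")
        rois = self._get_rois()
        self.roi_dd.options = rois
        if rois:
            self.roi_dd.value = rois[0]
            # Fallback in case the assignments above did not trigger the observer.
            if not self.file_dd.options:
                self._on_roi({"new": rois[0]})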

    def show(self):
        display(widgets.VBox([
            widgets.HBox([self.roi_dd, self.file_dd]),
            widgets.HBox([self.mode_btn, self.band_slider]),
            self.output,
        ]))

    def _render(self):
        roi, fname = self.roi_dd.value, self.file_dd.value
        if not roi or not fname: return
        fpath = os.path.join(self.root_dir, roi, fname)
        mode_map = {
            "Single Band": "single", "All Bands": "all", "RGB / SAR": "rgb",
            "Label": "label", "Metadata": "metadata",
        }
        mode = mode_map.get(self.mode_btn.value, "auto")

        self.output.clear_output(wait=True)
        with self.output:
            if mode == "single":
                with rasterio.open(fpath) as src:
                    data = src.read().astype(np.float64)
                    dtype = identify_data_type(src, data)
                    header = f" {dtype}  |  {src.count} bands  |  {src.width}×{src.height}  |  CRS: {src.crs}"
                    print(header); print("─" * len(header))
                    _show_single(src, data, dtype, self.band_slider.value - 1)
            else:
                inspect_tif(fpath, mode=mode)


# =============================================================================
# Flat folder browser (widget for a folder with .tif files, no subfolders)
# =============================================================================

class FlatFolderBrowser:
    """Widget-based browser for a single folder containing .tif files."""

    def __init__(self, folder: str):
        self.folder = folder
        if not os.path.isdir(folder):
            print(f"⚠️ '{folder}' is not a directory"); return

        style = {"description_width": "initial"}
        self.file_dd = widgets.Dropdown(description=" File:", style=style,
                                        layout=widgets.Layout(width="500px"))
        self.mode_btn = widgets.ToggleButtons(
            options=["Auto", "Single Band", "All Bands", "RGB / SAR",
                     "Indices", "Label", "Metadata"],
            value="Auto", style=style)
        self.output = widgets.Output()

        files = sorted(os.path.basename(f) for f in glob.glob(os.path.join(folder, "*.tif")))
        if not files:
            print(f"⚠️ No .tif files in {folder}"); return
        self.file_dd.options = files; self.file_dd.value = files[0]
        self.file_dd.observe(self._on_change, names="value")
        self.mode_btn.observe(self._on_change, names="value")

    def _on_change(self, change):
        self.output.clear_output(wait=True)
        with self.output:
            mode_map = {"Auto":"auto", "Single Band":"single", "All Bands":"all",
                        "RGB / SAR":"rgb", "Indices":"compare_idx",
                        "Label":"label", "Metadata":"metadata"}
            fpath = os.path.join(self.folder, self.file_dd.value)
            inspect_tif(fpath, mode=mode_map.get(self.mode_btn.value, "auto"))

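    def show(self):
        # __init__ returns early (after printing a warning) when the folder is
        # missing or holds no .tif files; in that case there is nothing to display.
        file_dd = getattr(self, "file_dd", None)
        if file_dd is None or not file_dd.options:
            return
        display(widgets.VBox([self.file_dd, self.mode_btn, self.output]))
        self._on_change(None)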


print("✅ Visualization engine loaded")
print("   Functions:  inspect_tif(path, mode)  |  compare_tifs([path1, path2, ...])")
print("   Widgets:    TiffBrowser(root_dir)    |  FlatFolderBrowser(folder)")


---
## §1 · Quick Look — Single File

Just point at **one `.tif`** and see it instantly.  
No folder structure needed. Pick a `mode` or let `"auto"` decide.

```python
# modes: "auto", "single", "all", "rgb", "label", "metadata", "compare_idx"
inspect_tif("/path/to/file.tif", mode="auto")
```


In [None]:
# Quick look at a single .tif file
inspect_tif(
    "/path/to/your/file.tif",
    mode="rgb",  # options: "auto", "single", "all", "rgb", "label", "metadata", "compare_idx"
)

---
## §2 · Flat Folder Browser

Your `.tif` files live directly inside **one folder** (no ROI subfolders).

```
my_folder/
├── optical_2024_01.tif
├── sar_2024_01.tif
└── label.tif
```
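
`FlatFolderBrowser` collects files with a `*.tif` glob, so names ending in `.tiff` are not listed, and upper-case extensions such as `.TIF` may be missed on case-sensitive file systems. The sketch below is one way to check what a folder actually contains before opening the browser. `list_tifs` is a hypothetical helper used only for illustration; it is not part of the toolkit.

```python
from pathlib import Path


def list_tifs(folder):
    """Return sorted names of .tif/.tiff files directly inside `folder`.

    Hypothetical helper (not part of the toolkit). The extension check is
    case-insensitive and subfolders are ignored.
    """
    exts = {".tif", ".tiff"}
    return sorted(
        p.name
        for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in exts
    )
```

If this returns names that the browser's dropdown does not show, renaming those files to a lower-case `.tif` extension is probably the simplest workaround.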


In [None]:
# Browse .tif files inside a single folder
browser = FlatFolderBrowser("/path/to/your/folder")
browser.show()

---
## §3 · Nested ROI Browser

Classic dataset layout with **one subfolder per ROI**:

```
root/
├── Bolivia/
│   ├── Bolivia_103757_S2Hand.tif
│   ├── Bolivia_103757_S1Hand.tif
│   └── Bolivia_103757_LabelHand.tif
├── India/
│   └── ...
└── USA/
    └── ...
```
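
To see how this layout is read, the sketch below builds an ROI-to-files index with the same rules `TiffBrowser` applies in `_get_rois` and `_get_files`: hidden folders are skipped and only `*.tif` files one level down are listed. `index_rois` is a hypothetical helper for illustration and is not part of the toolkit.

```python
import glob
import os


def index_rois(root_dir):
    """Map each ROI subfolder of `root_dir` to its sorted .tif file names.

    Hypothetical helper (not part of the toolkit). Returns an empty dict
    when `root_dir` is not a directory.
    """
    if not os.path.isdir(root_dir):
        return {}
    index = {}
    for name in sorted(os.listdir(root_dir)):
        path = os.path.join(root_dir, name)
        # Skip hidden entries and anything that is not a folder
        if name.startswith(".") or not os.path.isdir(path):
            continue
        index[name] = sorted(
            os.path.basename(f) for f in glob.glob(os.path.join(path, "*.tif"))
        )
    return index
```

Looping over `index_rois(root).items()` and printing `len(files)` per ROI is a quick way to spot empty ROI folders before launching the widget.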


In [None]:
# Example: Sen1Floods11 HandLabeled root; replace with your own dataset path
viewer = TiffBrowser(
    "/home/beomsik/data_2_intern_students/beomsik/floods/data/v1.1/data/flood_events/HandLabeled"
)
viewer.show()


---
## §4 · Glob Pattern Search

Files are scattered across subdirectories?  
Use a **glob pattern** to find and inspect them all.
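
In Python's `glob` module, `**` only spans multiple directory levels when `recursive=True` is passed; without it, `**` matches like a plain `*`. The sketch below groups the matches by parent folder, which helps to see where the files came from before inspecting them. `group_by_folder` is a hypothetical helper for illustration and is not part of the toolkit.

```python
import glob
import os
from collections import defaultdict


def group_by_folder(pattern):
    """Group files matching a recursive glob pattern by their parent folder.

    Hypothetical helper (not part of the toolkit).
    """
    groups = defaultdict(list)
    # recursive=True is what lets "**" match any number of subdirectories
    for path in sorted(glob.glob(pattern, recursive=True)):
        groups[os.path.dirname(path)].append(os.path.basename(path))
    return dict(groups)
```

Each key is a folder and each value is the list of matching file names in it, so a folder with unexpectedly many or few matches stands out immediately.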


In [None]:
# Find and inspect files matching a glob pattern
pattern = "/path/to/data/**/*S2*.tif"
files = sorted(glob.glob(pattern, recursive=True))

print(f"Found {len(files)} files")
for f in files[:5]:  # preview only the first 5 matches to keep the output short
    inspect_tif(f, mode="auto")

---
## §5 · Multi-Path Comparison

Pick **2-4 specific files** and compare them side by side.  
Great for comparing the same ROI across sensors, dates, or processing levels.
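
A mistyped path is easy to miss in a list of several files, so it can help to validate the list first. The sketch below assumes the 2-4 file limit described above is the one to enforce; whether `compare_tifs` performs its own checks is not shown in this section. `check_compare_inputs` is a hypothetical helper and is not part of the toolkit.

```python
import os


def check_compare_inputs(paths, min_n=2, max_n=4):
    """Return `paths` as a list if it holds 2-4 existing files, else raise ValueError.

    Hypothetical pre-check (not part of the toolkit); the default limits
    follow the 2-4 file range described in this section.
    """
    paths = list(paths)
    if not (min_n <= len(paths) <= max_n):
        raise ValueError(f"expected {min_n}-{max_n} files, got {len(paths)}")
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        raise ValueError(f"missing files: {missing}")
    return paths
```

The result can be passed straight through, for example `compare_tifs(check_compare_inputs(my_paths), band_idx=0)`.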


In [None]:
# Side-by-side comparison of 2-4 files
compare_tifs([
    "/path/to/optical.tif",
    "/path/to/sar.tif",
    "/path/to/label.tif",
], band_idx=0)

---
## Summary of Entry Points

| Function / Class | Usage |
|------------------|-------|
| `inspect_tif(path, mode)` | One-shot view of a single file |
| `compare_tifs([p1, p2, ...])` | Side-by-side comparison |
| `FlatFolderBrowser(folder)` | Widget for flat folder of `.tif` files |
| `TiffBrowser(root_dir)` | Widget for `root/roi/*.tif` structure |
| `glob.glob(pattern, recursive=True)` → loop `inspect_tif` | Batch inspection via pattern |
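
The choice between these entry points can be written down as a small decision helper. The sketch below only returns a label for the section that seems to fit a given path; it does not call any toolkit function. `suggest_entry_point` is a hypothetical name, and §5 is left out because comparing files is a deliberate choice that cannot be inferred from a single path.

```python
import glob
import os


def suggest_entry_point(path):
    """Suggest which section fits `path`; returns a label and calls nothing.

    Hypothetical helper (not part of the toolkit). A real path that contains
    "*", "?" or "[" would be mistaken for a glob pattern by this check.
    """
    if any(ch in path for ch in "*?["):
        return "§4 glob pattern"
    if os.path.isfile(path):
        return "§1 inspect_tif"
    if os.path.isdir(path):
        # .tif files directly inside the folder: flat layout
        if glob.glob(os.path.join(path, "*.tif")):
            return "§2 FlatFolderBrowser"
        # .tif files one level down: root/roi/*.tif layout
        if glob.glob(os.path.join(path, "*", "*.tif")):
            return "§3 TiffBrowser"
    return "no match"
```

A folder that holds both top-level `.tif` files and ROI subfolders is reported as flat here; that ordering is an arbitrary choice made for this sketch.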
