In [None]:
import datetime
import json
from pathlib import Path

import imageio.v2 as imageio
import numpy as np
import pandas as pd
import rasterio
import torch
import tqdm
import yaml
from IPython.display import Video
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from sklearn.metrics import balanced_accuracy_score
from torch.utils.data import DataLoader

from estuary.clay.data import EstuaryDataModule, EstuaryDataset, load_labels, parse_dt_from_pth
from estuary.clay.module import EstuaryModule
from estuary.util import masked_contrast_stretch

In [None]:
# Training run under inspection and the labelling exports it is compared against.
BASE = Path("/Users/kyledorman/data/results/estuary/train/20250827-145944/")
CKPT = BASE.joinpath("checkpoints", "last.ckpt")
LABEL_PATH = Path("/Users/kyledorman/data/estuary/label_studio/00025/labels.csv")
CROP_PATH = Path("/Users/kyledorman/data/estuary/label_studio/region_crops.json")

# Crop boxes keyed by region name; later cells unpack each value as
# (start_w, start_h, end_w, end_h).
with CROP_PATH.open("rb") as crop_file:
    region_crops = json.load(crop_file)

In [None]:
def draw_label(
    img: Image.Image, text: str, color: tuple[int, int, int], add_border=True
) -> Image.Image:
    """Annotate ``img`` in place with a caption banner and an optional coloured frame.

    Args:
        img: Image to draw on. It is modified in place and also returned.
        text: Caption rendered inside a dark banner in the top-left corner.
        color: RGB colour of the border; any 3-item sequence is accepted.
        add_border: When true, outline the whole image with ``color``.

    Returns:
        The same ``img`` object that was passed in.
    """
    # Prefer the macOS Arial Bold face. Fall back to Pillow's built-in font only
    # when the file cannot be read or FreeType support is unavailable, instead of
    # hiding every possible error behind the fallback.
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Supplemental/Arial Bold.ttf", 20)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    # "RGBA" drawing mode lets the banner fill be alpha-blended over the image.
    draw = ImageDraw.Draw(img, "RGBA")
    w, h = img.size

    # Banner box
    pad_x, pad_y = 10, 8
    # textbbox returns (left, top, right, bottom); measured from (0, 0), the
    # right/bottom values are used as the extent of the text.
    text_w, text_h = draw.textbbox((0, 0), text, font=font)[2:]
    # The banner is clamped to the image width; the text itself is not truncated.
    box_w = min(w - 2 * pad_x, text_w + 2 * pad_x)
    box_h = text_h + 2 * pad_y

    # Top-left anchor for banner
    x0, y0 = pad_x, pad_y
    x1, y1 = x0 + box_w, y0 + box_h

    # Semi-transparent dark banner
    draw.rounded_rectangle([x0, y0, x1, y1], radius=10, fill=(0, 0, 0, 110))

    # Outlined text (stroke) for readability
    draw.text(
        (x0 + pad_x, y0 + pad_y),
        text,
        font=font,
        fill=(255, 255, 255, 255),
        stroke_width=2,
        stroke_fill=(0, 0, 0, 220),
    )

    # Optional border matching class color. Unpacking (rather than tuple
    # concatenation) also accepts lists and other sequences.
    if add_border:
        draw.rectangle([0, 0, w - 1, h - 1], outline=(*color, 255), width=4)

    return img

In [None]:
# Read valid.csv from the run directory, parse the acquisition timestamps
# (unparseable values become NaT) and derive a plain calendar-date column.
valid_df = pd.read_csv(BASE / "valid.csv").assign(
    acquired=lambda frame: pd.to_datetime(frame["acquired"], errors="coerce"),
    acquired_date=lambda frame: frame["acquired"].dt.date,
)
valid_df.head(3)

In [None]:
# Number of distinct acquisition dates per region in valid_df.
counts = valid_df.groupby("region").acquired_date.nunique()
counts.head(15)

In [None]:
# Read preds.csv from the run directory with the same timestamp handling as
# valid_df: parse `acquired` (bad values become NaT) and add a calendar date.
preds = pd.read_csv(BASE / "preds.csv").assign(
    acquired=lambda frame: pd.to_datetime(frame["acquired"], errors="coerce"),
    acquired_date=lambda frame: frame["acquired"].dt.date,
)

preds.head(3)

In [None]:
# Load the label export, derive each row's acquisition time from its source_tif
# path, drop rows labelled "unsure", and order by region then time.
labels = (
    pd.read_csv(LABEL_PATH)
    .assign(
        acquired=lambda frame: pd.to_datetime(
            frame["source_tif"].apply(lambda tif: parse_dt_from_pth(Path(tif))),
            errors="coerce",
        ),
        acquired_date=lambda frame: frame["acquired"].dt.date,
    )
    .loc[lambda frame: frame["label"] != "unsure"]
    .sort_values(by=["region", "acquired"])
    .reset_index(drop=True)
)

labels.head(3)

In [None]:
# Number of distinct acquisition timestamps per region among the kept labels.
counts = labels.groupby("region").acquired.nunique()
counts.head(15)

In [None]:
# Index the pansharpened SkySat clips. For each file the region is the name of
# the directory two levels above it, and the acquisition date is the leading
# YYYYMMDD token of the file name.
skysat_root = Path("/Users/kyledorman/data/estuary/skysat/results/")
high_res = [
    [tif, tif.parent.parent.name, pd.to_datetime(tif.stem.split("_")[0], format="%Y%m%d")]
    for tif in skysat_root.glob("*/*/files/*_pansharpened_clip.tif")
]
high_res_df = (
    pd.DataFrame(high_res, columns=["path", "region", "acquired"])
    .sort_values(by=["region", "acquired"])
    .reset_index(drop=True)
)

high_res_df.head(3)

In [None]:
# Number of distinct SkySat acquisition dates per region.
counts = high_res_df.groupby("region").acquired.nunique()
counts.head(15)

In [None]:
# Restore the trained model from CKPT in eval mode, then build the data module
# from the configuration stored on the model.
module = EstuaryModule.load_from_checkpoint(CKPT, batch_size=1, strict=False).eval()
dm = EstuaryDataModule(module.conf)
dm.prepare_data()
dm.setup()


def test_acc():
    """Score the notebook-level ``module`` on the test dataloader of ``dm``.

    Returns:
        A ``(region_balanced_acc, acc)`` tuple. ``region_balanced_acc`` is a
        DataFrame indexed by region with a ``balanced_accuracy`` column, and
        ``acc`` is the balanced accuracy over all test samples.
    """
    dl = dm.test_dataloader()

    pred_classes = []
    true_classes = []
    # Inference only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for batch, blabel in tqdm.tqdm(dl, total=len(dl)):
            for k in batch.keys():
                batch[k] = batch[k].to(module.device)
            pred_batch = module.forward(batch)
            pred_classes.extend(pred_batch.argmax(axis=1).detach().cpu().numpy().tolist())
            true_classes.extend(blabel.detach().cpu().numpy().tolist())

    # NOTE(review): pairing predictions with dm.test_ds.df.region assumes the
    # dataloader yields samples in dataframe order. strict=False would silently
    # truncate on a length mismatch - confirm.
    test_pred = pd.DataFrame(
        list(zip(pred_classes, true_classes, dm.test_ds.df.region.tolist(), strict=False)),
        columns=["pred", "label", "region"],
    )
    # Compute balanced accuracy per region
    region_balanced_acc = (
        test_pred.groupby("region")
        .apply(lambda df: balanced_accuracy_score(df["label"], df["pred"]))
        .reset_index(name="balanced_accuracy")
    ).set_index("region")

    acc = balanced_accuracy_score(true_classes, pred_classes)

    return region_balanced_acc, acc


# Store the overall score under its own name: binding it to `test_acc` would
# replace the function of that name with a float for the rest of the session.
test_region_balanced_acc, test_balanced_acc = test_acc()
print(round(100 * test_balanced_acc, 1))

test_region_balanced_acc.head(3)

In [None]:
# Collect the training runs that were configured with a held-out region.
# Sorting makes the scan order reproducible, since glob order is arbitrary.
runs = []
for pth in sorted(Path("/Users/kyledorman/data/results/estuary/train/").glob("202*")):
    config_path = pth / "cli_diff.yaml"
    # Skip stray files and run folders that never wrote a config.
    if not pth.is_dir() or not config_path.is_file():
        continue
    with open(config_path) as f:
        # An empty YAML file parses to None; treat it as "no overrides".
        config = yaml.safe_load(f) or {}
    if "holdout_region" not in config:
        continue
    runs.append(pth)

print(f"Found {len(runs)} holdout regions")

# Evaluate each run on the region it never saw during training.
holdout_records = []
for pth in runs:
    # Use the first epoch checkpoint in sorted order and fail with a clear
    # message when there is none.
    checkpoints = sorted((pth / "checkpoints").glob("epoch*"))
    if not checkpoints:
        raise FileNotFoundError(f"No epoch* checkpoint found under {pth / 'checkpoints'}")
    ckpt = checkpoints[0]

    # A separate name keeps the notebook-level `module` from being overwritten.
    run_module = EstuaryModule.load_from_checkpoint(ckpt, batch_size=1, strict=False).eval()
    holdout_region = run_module.conf.holdout_region
    run_labels = load_labels(run_module.conf)

    ds = EstuaryDataset(
        run_labels[run_labels.region == holdout_region],
        crops_map=region_crops,
        conf=run_module.conf,
    )
    dl = DataLoader(ds, batch_size=1)
    rpreds = []
    rlabels = []
    # Inference only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for batch, blabel in tqdm.tqdm(dl, total=len(dl)):
            for k in batch.keys():
                batch[k] = batch[k].to(run_module.device)
            pred_batch = run_module.forward(batch)
            rpreds.extend(pred_batch.argmax(axis=1).detach().cpu().numpy().tolist())
            rlabels.extend(blabel.detach().cpu().numpy().tolist())

    holdout_records.append((holdout_region, balanced_accuracy_score(rlabels, rpreds)))

holdout_acc = pd.DataFrame(holdout_records, columns=["region", "holdout_acc"]).set_index("region")

In [None]:
# One row per region: image counts from each source next to the model scores.
data_stats = pd.concat(
    [
        valid_df.groupby("region")["acquired_date"].nunique().rename("dove_total"),
        labels.groupby("region").acquired.nunique().rename("dove_labeled"),
        high_res_df.groupby("region").acquired.nunique().rename("skysat_total"),
        test_region_balanced_acc.balanced_accuracy.rename("test_accuracy").round(2),
        holdout_acc.round(2),
    ],
    axis=1,
)

# Create the output folder first so the export does not fail when it is missing.
stats_path = Path("/Users/kyledorman/data/estuary/display/region_stats.csv")
stats_path.parent.mkdir(parents=True, exist_ok=True)
data_stats.to_csv(stats_path)
data_stats.head(15)

In [None]:
# Row of each region's predictions used as its thumbnail (0-based).
SAMPLE_ROW = 6

fig, axes = plt.subplots(3, 5, figsize=(5 * 2.9, 3 * 3), constrained_layout=True)
# Blank every panel up front so grid slots without a region stay empty instead
# of showing default axes.
for ax in axes.flatten():
    ax.axis("off")

# zip stops at the shorter input: regions beyond the 15 panels are not drawn.
for (region, rows), ax in zip(preds.groupby("region"), axes.flatten(), strict=False):
    ax.set_title(" ".join([part.capitalize() for part in region.split("_")]))
    crop = region_crops[region]
    start_w, start_h, end_w, end_h = crop
    # Regions with fewer than SAMPLE_ROW + 1 rows fall back to their last row
    # rather than raising IndexError.
    pth = rows.iloc[min(SAMPLE_ROW, len(rows) - 1)].source_tif
    with rasterio.open(pth) as src:
        data = src.read(out_dtype=np.float32)[:, start_h:end_h, start_w:end_w]
        nodata = src.read(1, masked=True).mask[start_h:end_h, start_w:end_w]
    # Log-scale, then stretch between the 1st and 99th percentiles of valid pixels.
    data = np.log10(data + 1)
    imgd = masked_contrast_stretch(data, ~nodata, p_low=1, p_high=99)
    # Bands at index 2, 1, 0 become the display channels
    # (presumably R, G, B - TODO confirm band order).
    rgb = imgd[[2, 1, 0]].transpose((1, 2, 0))
    img = Image.fromarray(np.array(np.clip(rgb * 255, 0, 255), dtype=np.uint8)).resize((256, 256))
    ax.imshow(img)

# Create the output folder first so saving does not fail when it is missing.
dove_fig_path = Path("/Users/kyledorman/data/estuary/display/all_sites_dove.png")
dove_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(dove_fig_path)
plt.show()

In [None]:
fig, axes = plt.subplots(3, 5, figsize=(5 * 2.9, 3 * 3), constrained_layout=True)
# Blank every panel up front so grid slots without a SkySat region stay empty
# instead of showing default axes.
for ax in axes.flatten():
    ax.axis("off")

# zip stops at the shorter input: regions beyond the 15 panels are not drawn.
for (region, rows), ax in zip(high_res_df.groupby("region"), axes.flatten(), strict=False):
    ax.set_title(" ".join([part.capitalize() for part in region.split("_")]))
    # high_res_df is sorted by region then date, so this is the earliest clip.
    pth = rows.iloc[0].path
    with rasterio.open(pth) as src:
        # Bands 3, 2, 1 (1-based) are read straight into display-channel order
        # (presumably R, G, B - TODO confirm band order).
        data = src.read([3, 2, 1], out_dtype=np.float32)
        nodata = src.read(1, masked=True).mask
    # Log-scale, then stretch between the 1st and 99th percentiles of valid pixels.
    data = np.log10(data + 1)
    imgd = masked_contrast_stretch(data, ~nodata, p_low=1, p_high=99)
    rgb = imgd.transpose((1, 2, 0))
    img = Image.fromarray(np.array(np.clip(rgb * 255, 0, 255), dtype=np.uint8)).resize((512, 512))
    ax.imshow(img)

# Create the output folder first so saving does not fail when it is missing.
skysat_fig_path = Path("/Users/kyledorman/data/estuary/display/all_sites_skysat.png")
skysat_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(skysat_fig_path)
plt.show()

In [None]:
# Region and time window for the time-lapse video.
region = "topanga"
start = datetime.datetime(2024, 9, 1)
end = datetime.datetime(2024, 12, 31)

# Crop box for the chosen region and its pixel size.
crop = region_crops[region]
start_w, start_h, end_w, end_h = crop
w, h = end_w - start_w, end_h - start_h

# Both bounds are exclusive. `end` is midnight on 31 December, so images
# acquired on that day are left out.
ddf = preds
gif_df = ddf.loc[
    (ddf["region"] == region) & (ddf["acquired"] > start) & (ddf["acquired"] < end)
]

len(gif_df)

In [None]:
save_path = Path(f"/Users/kyledorman/data/estuary/display/gifs/{region}_{start.date()}.mp4")
save_path.parent.mkdir(exist_ok=True, parents=True)

# Render one annotated 256x256 frame per prediction row in the selected window.
frames = []
for _, row in tqdm.tqdm(gif_df.iterrows(), total=len(gif_df)):
    pth = row.source_tif
    is_open = row.pred == 0
    pred_name = "open" if is_open else "close"
    pred_color = (44, 160, 44) if is_open else (214, 39, 40)  # green/red

    # Format the acquisition time as YYYY-MM-DD. Missing or NaT values give an
    # empty date; anything that cannot be parsed is shown verbatim.
    acquired = row.get("acquired")
    if acquired is None or pd.isna(acquired):
        date_disp = ""
    else:
        try:
            date_disp = pd.to_datetime(acquired).strftime("%Y-%m-%d")
        except (ValueError, TypeError):
            date_disp = str(acquired)

    with rasterio.open(pth) as src:
        data = src.read(out_dtype=np.float32)[:, start_h:end_h, start_w:end_w]
        nodata = src.read(1, masked=True).mask[start_h:end_h, start_w:end_w]
    # Log-scale, then stretch between the 1st and 99th percentiles of valid pixels.
    data = np.log10(data + 1)
    imgd = masked_contrast_stretch(data, ~nodata, p_low=1, p_high=99)
    # Bands at index 2, 1, 0 become the display channels
    # (presumably R, G, B - TODO confirm band order).
    rgb = imgd[[2, 1, 0]].transpose((1, 2, 0))
    img = Image.fromarray(np.array(np.clip(rgb * 255, 0, 255), dtype=np.uint8)).resize((256, 256))

    # Caption: "<date> • <prediction>", or just the prediction without a date.
    label_text = f"{pred_name}"
    if date_disp:
        label_text = f"{date_disp} • " + label_text

    img = draw_label(img, label_text, pred_color, add_border=True)

    frames.append(img)

# Convert each PIL frame to a NumPy array (imageio needs ndarray or PIL)
frame_arrays = [np.array(im.convert("RGB")) for im in frames]

In [None]:
# Encode the rendered frames as an H.264 MP4 at one frame per second.
fps = 1
video_options = {
    "fps": fps,
    "codec": "libx264",  # H.264 for compatibility
    "quality": 10,  # 0 (lowest) - 10 (highest) for libx264
    "macro_block_size": None,  # keeps original frame size
}
imageio.mimsave(save_path, frame_arrays, **video_options)
print(f"Saved video → {save_path}")

In [None]:
# Preview the rendered clip inline; embed=True stores the video data in the
# notebook output itself. `Video` is imported in the notebook's import cell.
Video(str(save_path), embed=True, width=600)

In [None]:
def _render_rgb(data, nodata, size):
    """Log-scale and contrast-stretch a band-first raster into a resized PIL image.

    Args:
        data: Band-first array of pixel values.
        nodata: Boolean mask that is true where pixels are invalid.
        size: ``(width, height)`` of the returned image.

    Returns:
        A PIL image built from bands at index 2, 1, 0
        (presumably R, G, B - TODO confirm band order).
    """
    stretched = masked_contrast_stretch(np.log10(data + 1), ~nodata, p_low=1, p_high=99)
    rgb = stretched[[2, 1, 0]].transpose((1, 2, 0))
    return Image.fromarray(np.array(np.clip(rgb * 255, 0, 255), dtype=np.uint8)).resize(size)


save_base = Path("/Users/kyledorman/data/estuary/display/skysat")
ddf = valid_df
for region in ddf.region.unique():
    # merge_asof needs both sides sorted on the key and rejects null keys, so
    # rows whose timestamp failed to parse (NaT) are dropped before sorting.
    hdf_s = (
        high_res_df[high_res_df.region == region]
        .dropna(subset=["acquired"])
        .sort_values("acquired")
        .reset_index(drop=True)
    )
    pdf_s = (
        ddf[ddf.region == region]
        .dropna(subset=["acquired"])
        .sort_values("acquired")
        .reset_index(drop=True)
    )
    # For every SkySat clip, take the nearest row of valid_df within two days.
    pairs = pd.merge_asof(
        hdf_s,
        pdf_s,
        on="acquired",
        direction="nearest",
        tolerance=pd.Timedelta("2D"),
        suffixes=("_h", "_p"),
    )
    # Keep only rows that found a match (otherwise columns from pdf will be NaN)
    pairs = pairs.dropna(subset=["path", "source_tif"])

    save = save_base / region
    save.mkdir(exist_ok=True, parents=True)

    start_w, start_h, end_w, end_h = region_crops[region]

    # NOTE(review): assumes valid_df carries a `pred` column with 0 = open and
    # 1 = close - confirm.
    for state in [0, 1]:
        for _, row in pairs[pairs.pred == state].iterrows():
            with rasterio.open(row.path) as src:
                sky_data = src.read(out_dtype=np.float32)
                sky_nodata = src.read(1, masked=True).mask
            sky_img = _render_rgb(sky_data, sky_nodata, (512, 512))

            with rasterio.open(row.source_tif) as src:
                dove_data = src.read(out_dtype=np.float32)[:, start_h:end_h, start_w:end_w]
                dove_nodata = src.read(1, masked=True).mask[start_h:end_h, start_w:end_w]
            dove_img = _render_rgb(dove_data, dove_nodata, (256, 256))

            pred_name = "open" if row.pred == 0 else "close"

            # Left: SkySat clip. Right: cropped image from source_tif.
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            axes[0].imshow(sky_img)
            axes[0].axis("off")
            axes[1].imshow(dove_img)
            axes[1].axis("off")

            # `acquired` is the merge key taken from the SkySat side.
            fig.suptitle(f"Prediction: {pred_name}   Date: {row.acquired.date()}")
            fig.tight_layout()
            # Two matches with the same prediction on the same date share a
            # file name, so the later one overwrites the earlier.
            fig.savefig(save / f"{pred_name}_{row.acquired.date()}.png")
            # Close explicitly so figures do not accumulate across the loop.
            plt.close(fig)