In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import math
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
from PIL import Image

from estuary.model.data import parse_dt_from_pth
from estuary.util import broad_band, false_color

In [None]:
# Per-frame predictions exported by a training run.
# The default is the author's local results directory; set ESTUARY_RESULTS_DIR to
# point the notebook at a different run / machine.
RESULTS_DIR = Path(
    os.environ.get(
        "ESTUARY_RESULTS_DIR",
        "/Users/kyledorman/data/results/estuary/train/20251008-151833",
    )
)
tdf = pd.read_csv(RESULTS_DIR / "timeseries_preds.csv")
tdf["acquired"] = tdf.source_tif.apply(lambda a: parse_dt_from_pth(Path(a)))
tdf["year"] = tdf.acquired.dt.year
# Label-derived soft probability of the "open" state. Any label other than
# "open" / "perched open" (including a missing label) maps to 0.05.
LABEL_TO_OPEN_PROB = {"open": 0.95, "perched open": 0.6}
tdf["y_prob_true"] = tdf["orig_label"].map(LABEL_TO_OPEN_PROB).fillna(0.05)
tdf = tdf.sort_values("acquired").reset_index(drop=True)
print(sorted(tdf.region.unique().tolist()))

In [None]:
# Number of rows (frames) per region in 2024.
tdf[tdf.year == 2024].groupby(["region"]).correct.size()

In [None]:
# Mean of the per-frame `correct` flag (fraction correct) per region in 2024.
tdf[tdf.year == 2024].groupby(["region"]).correct.mean().round(3)

In [None]:
# Same as above, broken down further by the original label.
tdf[tdf.year == 2024].groupby(["region", "orig_label"]).correct.mean().round(3)

In [None]:
def low_pri_iter():
    for region, rrdf in tdf[tdf.region == 84].groupby("region"):
        torun = rrdf[
            # ((rrdf.orig_label == "open") & (rrdf.y_prob < 0.2))
            ((rrdf.orig_label == "closed") & (rrdf.correct == False))
        ]
        for _, row in torun.iterrows():
            yield region, row.orig_label, row.source_tif, row.y_prob


iii = low_pri_iter()

In [None]:
region, orig_label, pth, y_prob = next(iii)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))
ax.set_axis_off()

with rasterio.open(pth) as src:
    data = src.read(out_dtype=np.float32)
    nodata = src.read(1, masked=True).mask
    # Same band handling as the other image cells: 4-band rasters get the
    # false-colour composite, anything else goes through broad_band.
    if len(data) == 4:
        img = false_color(data, nodata)
    else:
        img = broad_band(data, nodata)
    img = Image.fromarray(img)
ax.imshow(img)
ax.set_title(f"{region} - {orig_label} - {Path(pth).stem} - {y_prob:0.3}")

fig.tight_layout()
plt.show()

In [None]:
# Restrict the rest of the analysis to a single site.
rdf = tdf.query("region == 2145").copy()
rdf.head(3)

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(10, 5))
# One colour per original label: open=green, closed=red, perched open=blue.
for label, color in (("open", "green"), ("closed", "red"), ("perched open", "blue")):
    rdf[rdf.orig_label == label].plot.scatter(
        x="acquired", y="y_prob", ax=axes, color=color, label=label
    )
axes.set_title("Per-frame model probability by original label")
axes.legend()
plt.show()

In [None]:
def image_iter(df, count):
    """Yield consecutive batches of ``(source_tif, y_prob)`` pairs, ``count`` per batch.

    Prints the number of rows first. Rows are taken in ``df`` order; the final batch
    holds whatever is left over and may be shorter than ``count``.
    """
    print(len(df))
    batch = []
    for pair in zip(df["source_tif"], df["y_prob"]):
        batch.append(pair)
        if len(batch) == count:
            yield batch
            batch = []
    if batch:
        yield batch

In [None]:
# Frames labelled "open" that the model scored below 0.2, lowest probability first.
ii = image_iter(
    rdf.query('orig_label == "open" and y_prob < 0.2').sort_values("y_prob"),
    4,
)

In [None]:
images = next(ii)

assert len(images)

# At most two columns, and enough rows that every image in the batch gets a panel.
# True division before ceil: flooring first would drop the last image of an odd batch.
cols = min(len(images), 2)
rows = math.ceil(len(images) / cols)
# squeeze=False always returns a 2-D array of Axes, whatever the grid shape.
fig, axs = plt.subplots(
    nrows=rows, ncols=cols, figsize=(4 * cols, 4 * rows), squeeze=False
)
panels = axs.ravel()
for ax in panels:
    ax.set_axis_off()  # also blanks the spare panel when the batch size is odd

for (source_tif, y_prob), ax in zip(images, panels):
    with rasterio.open(source_tif) as src:
        data = src.read(out_dtype=np.float32)
        nodata = src.read(1, masked=True).mask
        # 4-band rasters -> false-colour composite; anything else -> broad_band
        if len(data) == 4:
            img = false_color(data, nodata)
        else:
            img = broad_band(data, nodata)
        img = Image.fromarray(img)
    ax.imshow(img)
    ax.set_title(f"{y_prob:0.3}")

fig.tight_layout()
plt.show()

In [None]:
def extract_changes(
    times: pd.Series,
    states: np.ndarray,
) -> list[tuple[pd.Timestamp, pd.Timedelta, int]]:
    """Return one event per state change in a time-ordered sequence of 0/1 states.

    Each event is ``(timestamp, delta, new_state)``:
      - ``timestamp`` is the time of the first sample in the new state;
      - ``delta`` is the time elapsed since the previous change (for the first
        change, since the first sample);
      - ``new_state`` is the state being entered, so callers can tell 0→1 from 1→0.

    ``times`` and ``states`` are matched by position (their pandas index, if any, is
    ignored). No gap handling is applied: a change across a long gap is reported at
    the first sample after the gap. Empty input yields an empty list.
    """
    t = pd.Series(pd.to_datetime(times)).reset_index(drop=True)
    # Positional access, so a pandas Series with a non-default index also works.
    state_arr = np.asarray(states)
    changes: list[tuple[pd.Timestamp, pd.Timedelta, int]] = []
    if len(state_arr) == 0:
        return changes

    prev_state = int(state_arr[0])
    prev_time = t.iat[0]
    for i in range(1, len(state_arr)):
        s = int(state_arr[i])
        if s != prev_state:
            curr_time = t.iat[i]
            changes.append((curr_time, curr_time - prev_time, s))
            prev_state = s
            prev_time = curr_time

    return changes

In [None]:
def rolling_smooth(df, time_col, prob_col, days=2):
    """Smooth ``prob_col`` with a centred time-window mean.

    For each row at time ``t`` the window is the open interval ``(t - delta, t + delta)``
    with ``delta = days * 24 h + 4 h``. A row is only smoothed if at least one
    sample lies strictly inside the earlier half of its window; otherwise it keeps
    its raw value. Means are always taken over the raw ``prob_col`` values, never
    over already-smoothed ones.

    Returns a copy of ``df[[time_col, prob_col]]``, sorted by ``time_col``, with an
    extra ``rolling_<prob_col>`` column. Original index labels are kept so the result
    can be aligned back onto ``df``. Unparseable times become NaT and are left
    unsmoothed.
    """
    delta = pd.Timedelta(hours=days * 24 + 4)
    g = df[[time_col, prob_col]].copy()
    g[time_col] = pd.to_datetime(g[time_col], errors="coerce")
    g = g.sort_values(time_col)
    new_col = f"rolling_{prob_col}"
    g[new_col] = g[prob_col].copy()

    # Use the caller's time column throughout (previously hard-coded to "acquired").
    times = g[time_col]
    for idx, t in times.items():
        after_start = times > t - delta
        num_before = (after_start & (times < t)).sum()
        if not num_before:
            continue
        window = after_start & (times < t + delta)
        g.loc[idx, new_col] = np.mean(g.loc[window, prob_col].to_numpy()).item()

    return g


# Smooth the label-derived probability (y_prob_true) with rolling_smooth's
# ±(days*24 + 4) h window and threshold at 0.5 to get a smoothed label.
rolling = rolling_smooth(rdf, "acquired", "y_prob_true", days=2)
rdf.loc[rolling.index, "rolling_y_prob_true"] = rolling["rolling_y_prob_true"].copy()
rdf.loc[rdf.index, "rolling_y_true"] = (rdf["rolling_y_prob_true"] > 0.5).astype(int).copy()

# Same smoothing for the model probability, then threshold at 0.5.
prolling = rolling_smooth(rdf, "acquired", "y_prob", days=2)
rdf.loc[prolling.index, "rolling_y_prob"] = prolling["rolling_y_prob"].copy()
rdf.loc[rdf.index, "rolling_y_pred"] = (rdf["rolling_y_prob"] > 0.5).astype(int).copy()
# Smoothed prediction vs. raw label.
rdf.loc[rdf.index, "rolling_correct"] = (rdf["rolling_y_pred"] == rdf["y_true"]).copy()
# Smoothed prediction vs. smoothed label.
rdf.loc[rdf.index, "rolling_rolling_correct"] = (
    rdf["rolling_y_pred"] == rdf["rolling_y_true"]
).copy()

In [None]:
# Fraction of frames flagged correct: raw flag, smoothed pred vs raw label,
# smoothed pred vs smoothed label.
rdf[["correct", "rolling_correct", "rolling_rolling_correct"]].astype(np.int32).mean().round(3)

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(10, 5))
# One colour per original label: open=green, closed=red, perched open=blue.
for label, color in (("open", "green"), ("closed", "red"), ("perched open", "blue")):
    rdf[rdf.orig_label == label].plot.scatter(
        x="acquired", y="rolling_y_prob", ax=axes, color=color, label=label
    )
axes.set_title("Smoothed model probability by original label")
axes.legend()
plt.show()

In [None]:
# Frames whose smoothed label differs from the raw label: compare the label-derived
# probability before and after smoothing.
relabelled = rdf[rdf.rolling_y_true != rdf.y_true]
fig, ax = plt.subplots(1, 1)
relabelled.plot.scatter(
    x="acquired", y="rolling_y_prob_true", ax=ax, color="purple", label="smoothed"
)
relabelled.plot.scatter(x="acquired", y="y_prob_true", ax=ax, color="green", label="raw")
ax.set_ylabel("label-derived P(open)")
ax.set_title("Frames whose label changes under smoothing")
ax.legend()

plt.show()

In [None]:
# Smoothed vs raw model probability, coloured by whether the smoothed prediction
# agrees with the smoothed label.
agrees = rdf.rolling_y_true == rdf.rolling_y_pred
fig, ax = plt.subplots(1, 1)
rdf[~agrees].plot.scatter(
    x="rolling_y_prob", y="y_prob", ax=ax, color="purple", label="smoothed pred wrong"
)
rdf[agrees].plot.scatter(
    x="rolling_y_prob", y="y_prob", ax=ax, color="green", label="smoothed pred right"
)
ax.set_title("Smoothed vs raw model probability")
ax.legend()

plt.show()

In [None]:
import numpy as np
import pandas as pd


def hmm_decode_irregular(
    times: pd.Series,  # datetime-like, sorted
    p_open: np.ndarray,  # per-frame probs in [0,1]
    tau_open_h: float = 72.0,  # avg dwell time open (hours)
    tau_closed_h: float = 72.0,  # avg dwell time closed (hours)
    gap_reset_h: float = 72.0,  # break sequence if Δt > gap_reset_h
    reliability: float = 1.0,  # 1.0=trust model fully; 0.0=ignore model (0.5)
    eps: float = 1e-6,  # numerical floor
) -> tuple[np.ndarray, list[tuple[pd.Timestamp, int]]]:
    """
    Two-state HMM (0=closed, 1=open) with time-varying transitions from dwell times.
    Viterbi decoding per contiguous segment; gap midpoints generate a change event
    if the state after the gap differs from the state before it.

    Returns:
      states   : np.ndarray[int] of shape (T,)
      events   : list of (timestamp, new_state) where state changes
    """
    times = pd.to_datetime(times).reset_index(drop=True)
    p_open = np.clip(np.asarray(p_open, dtype=float), eps, 1 - eps)

    # Down-weight unreliable sites by mixing with neutral 0.5
    if reliability < 1.0:
        p_open = reliability * p_open + (1.0 - reliability) * 0.5
        p_open = np.clip(p_open, eps, 1 - eps)

    p_closed = 1.0 - p_open
    loge = np.vstack([np.log(p_closed), np.log(p_open)])  # shape (2, T)

    dt_s = times.diff().dt.total_seconds().fillna(0).to_numpy()
    dt_h = dt_s / 3600.0

    # Identify segment boundaries by large gaps
    breaks = np.where(dt_h > gap_reset_h)[0]  # indices i where gap between i-1 and i is large
    seg_starts = np.r_[0, breaks + 1]
    seg_ends = np.r_[breaks, len(times) - 1]

    states = np.zeros(len(times), dtype=np.int8)
    events: list[tuple[pd.Timestamp, int]] = []

    def _viterbi_segment(s0: int, t_idx: np.ndarray):
        """
        Viterbi on a time slice [t_idx], using time-varying transitions from dt_h.
        Assumes dt_h[t] is the delta from t-1->t in hours.
        """
        idx0, idx1 = t_idx[0], t_idx[-1]
        Tseg = len(t_idx)

        # dp and backpointers
        dp = np.full((2, Tseg), -np.inf, dtype=float)
        prev = np.full((2, Tseg), -1, dtype=np.int8)

        # init with a soft prior favoring s0, but not forcing it (0.8/0.2)
        init_prior = np.array([0.8, 0.2]) if s0 == 0 else np.array([0.2, 0.8])
        dp[:, 0] = np.log(init_prior + eps) + loge[:, idx0]

        for k in range(1, Tseg):
            i = t_idx[k]  # global index
            # time-varying transition probs from dwell times
            p_stay_closed = np.exp(-dt_h[i] / max(tau_closed_h, eps))
            p_stay_open = np.exp(-dt_h[i] / max(tau_open_h, eps))
            A = np.array(
                [
                    [p_stay_closed, 1 - p_stay_closed],  # from closed -> [closed, open]
                    [1 - p_stay_open, p_stay_open],  # from open   -> [closed, open]
                ],
                dtype=float,
            )
            logA = np.log(np.clip(A, eps, 1.0))

            # transition to closed (0)
            cand0 = np.array([dp[0, k - 1] + logA[0, 0], dp[1, k - 1] + logA[1, 0]])
            prev[0, k] = np.argmax(cand0)
            dp[0, k] = loge[0, i] + np.max(cand0)
            # transition to open (1)
            cand1 = np.array([dp[0, k - 1] + logA[0, 1], dp[1, k - 1] + logA[1, 1]])
            prev[1, k] = np.argmax(cand1)
            dp[1, k] = loge[1, i] + np.max(cand1)

        # backtrack
        sT = int(np.argmax(dp[:, -1]))
        path = np.empty(Tseg, dtype=np.int8)
        path[-1] = sT
        for k in range(Tseg - 1, 0, -1):
            path[k - 1] = prev[path[k], k]
        return path

    # Decode each contiguous segment, seeding each by the first frame’s argmax
    last_state: int | None = None
    last_time: pd.Timestamp | None = None

    for start, end in zip(seg_starts, seg_ends, strict=False):
        seg_idx = np.arange(start, end + 1)
        if not len(seg_idx):
            continue
        seed = int(p_open[start] >= 0.5) if last_state is None else last_state
        seg_states = _viterbi_segment(seed, seg_idx)
        states[seg_idx] = seg_states

        # cross-gap event at segment boundary (midpoint) if state changed across the gap
        if last_state is not None and start > 0:
            if states[start] != last_state:
                delta = times[start] - last_time
                events.append((times[start], delta, int(states[start])))

        # within-segment events at boundary samples
        for k in range(start + 1, end + 1):
            if states[k] != states[k - 1]:
                delta = times[k] - times[k - 1]
                events.append((times[k], delta, int(states[k])))

        last_state = int(states[end])
        last_time = times[end]

    return states, events

In [None]:
import numpy as np
import pandas as pd


# ---- helper: compute events from HMM states (already returned by hmm_decode_irregular,
# but we keep this for completeness / unit tests) ----
def _events_from_states(times: pd.Series, states: np.ndarray) -> list[tuple[pd.Timestamp, int]]:
    times = pd.to_datetime(times).reset_index(drop=True)
    ev: list[tuple[pd.Timestamp, int]] = []
    for i in range(1, len(states)):
        if states[i] != states[i - 1]:
            ev.append((times[i], int(states[i])))
    return ev


# ---- decode one site's rows into HMM states + change events ----
def _decode_site_to_events(
    g: pd.DataFrame,
    time_col: str,
    prob_col: str,
    tau_open_h: float,
    tau_closed_h: float,
    gap_reset_h: float,
    reliability: float,
) -> tuple[np.ndarray, list[tuple[pd.Timestamp, pd.Timedelta, int]]]:
    """Run ``hmm_decode_irregular`` on one site's rows and return ``(states, events)``.

    ``g`` is decoded in its current row order; the caller is expected to have sorted
    it by ``time_col``. Events are ``(timestamp, delta, new_state)`` tuples.
    """
    hmm_kwargs = dict(
        tau_open_h=tau_open_h,
        tau_closed_h=tau_closed_h,
        gap_reset_h=gap_reset_h,
        reliability=reliability,
    )
    decoded_states, decoded_events = hmm_decode_irregular(
        g[time_col], g[prob_col].to_numpy(), **hmm_kwargs
    )
    return decoded_states, decoded_events


# ---- score one site's HMM change events against ground truth (Hungarian matching) ----
def _score_site_params(
    g: pd.DataFrame,
    time_col: str,
    prob_col: str,
    ytrue_col: str,
    tau_open_h: float,
    tau_closed_h: float,
    gap_reset_h: float,
    reliability: float,
    tol_hours: float,
    smooth_gt: bool = False,
    hyst_T_high: float = 0.65,
    hyst_T_low: float = 0.45,
    hyst_min_run: int = 2,
    hyst_gap_reset_h: float | None = None,
) -> dict[str, float]:
    """Score HMM-decoded change events for one site against its ground-truth events.

    Rows are sorted by ``time_col``. Predicted events come from
    ``_decode_site_to_events``. Ground-truth events come from ``extract_changes`` on
    ``ytrue_col``; when ``smooth_gt`` is True, the labels are first passed through
    ``hysteresis_decode``. Events are paired with
    ``hungarian_match_events_with_tolerance`` (same direction, within ``tol_hours``).

    Returns tp/fp/fn counts, precision/recall/f1 (0.0 when undefined) and the number
    of GT and predicted events.
    """
    site = g.sort_values(time_col)
    site_times = site[time_col]

    _, pred_events = _decode_site_to_events(
        site, time_col, prob_col, tau_open_h, tau_closed_h, gap_reset_h, reliability
    )

    if smooth_gt:
        # NOTE(review): hysteresis_decode is not defined anywhere in this notebook, so
        # smooth_gt=True raises NameError unless it is provided from elsewhere.
        truth = hysteresis_decode(
            site_times,
            site[ytrue_col].astype(float).to_numpy(),  # 0/1 labels as floats
            T_high=hyst_T_high,
            T_low=hyst_T_low,
            min_run=hyst_min_run,
            gap_reset_hours=gap_reset_h if hyst_gap_reset_h is None else hyst_gap_reset_h,
        )
    else:
        truth = site[ytrue_col].to_numpy()
    gt_events = extract_changes(site_times, truth)

    matched, unmatched_pred, unmatched_gt = hungarian_match_events_with_tolerance(
        gt_events, pred_events, max_hours=tol_hours
    )
    tp, fp, fn = len(matched), len(unmatched_pred), len(unmatched_gt)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    return {
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "gt_events": len(gt_events),
        "pred_events": len(pred_events),
    }


# ---- LOSO grid search over small param sets ----
def evaluate_hmm_loso(
    df: pd.DataFrame,
    site_col: str = "region",
    time_col: str = "acquired",
    prob_col: str = "y_prob",
    ytrue_col: str = "y_true",
    param_grid: list[dict] | None = None,
    gap_reset_h: float = 72.0,
    tol_hours: float = 48.0,
    smooth_gt: bool = True,
    hyst_T_high: float = 0.65,
    hyst_T_low: float = 0.45,
    hyst_min_run: int = 2,
) -> tuple[pd.DataFrame, dict]:
    """
    Leave-one-site-out (LOSO) evaluation of HMM change-event detection.

    For each held-out site, pick the param set that maximizes mean F1 over the other
    sites (ties go to the earliest entry in ``param_grid``). Then score that param set
    on the held-out site.

    param_grid: list of dicts with keys {tau_open_h, tau_closed_h, reliability}.
                If None, a small default grid is used.

    A (param set, site) score does not depend on which site is held out, so each one
    is computed at most once and reused across folds.

    NOTE(review): smooth_gt=True (the default) relies on a ``hysteresis_decode``
    function that is not defined in the visible notebook.

    Returns:
        per_site: one row per held-out site (tp/fp/fn, precision/recall/f1, event
                  counts, chosen params), sorted by site.
        summary:  {"micro": ..., "macro": ..., "best_params": {site: param dict}}.

    Raises:
        ValueError: if ``param_grid`` is an empty list or ``df`` contains no sites.
    """
    if param_grid is None:
        param_grid = [
            {"tau_open_h": 48, "tau_closed_h": 48, "reliability": 1.0},
            {"tau_open_h": 72, "tau_closed_h": 72, "reliability": 1.0},
            {"tau_open_h": 96, "tau_closed_h": 72, "reliability": 1.0},
            {"tau_open_h": 72, "tau_closed_h": 96, "reliability": 1.0},
            {"tau_open_h": 72, "tau_closed_h": 72, "reliability": 0.85},
            {"tau_open_h": 96, "tau_closed_h": 96, "reliability": 0.85},
        ]
    if not param_grid:
        raise ValueError("param_grid must contain at least one parameter set")

    sites = list(df[site_col].unique())
    if not sites:
        raise ValueError(f"no sites found in column {site_col!r}")
    site_frames = [df[df[site_col] == s] for s in sites]

    metrics_cache: dict[tuple[int, int], dict] = {}

    def _site_metrics(p_idx: int, s_idx: int) -> dict:
        """Score ``param_grid[p_idx]`` on ``sites[s_idx]``, computing each pair once."""
        key = (p_idx, s_idx)
        if key not in metrics_cache:
            p = param_grid[p_idx]
            metrics_cache[key] = _score_site_params(
                site_frames[s_idx],
                time_col,
                prob_col,
                ytrue_col,
                tau_open_h=p["tau_open_h"],
                tau_closed_h=p["tau_closed_h"],
                gap_reset_h=gap_reset_h,
                reliability=p["reliability"],
                tol_hours=tol_hours,
                smooth_gt=smooth_gt,
                hyst_T_high=hyst_T_high,
                hyst_T_low=hyst_T_low,
                hyst_min_run=hyst_min_run,
                hyst_gap_reset_h=gap_reset_h,
            )
        return metrics_cache[key]

    per_site_rows = []
    best_params_per_site: dict = {}

    for h_idx, holdout in enumerate(sites):
        dev_idx = [i for i, s in enumerate(sites) if s != holdout]

        # mean dev F1 per param set (0.0 when there are no dev sites)
        grid_scores = []
        for p_idx in range(len(param_grid)):
            f1s = [_site_metrics(p_idx, i)["f1"] for i in dev_idx]
            grid_scores.append((p_idx, float(np.mean(f1s) if f1s else 0.0)))

        # max() returns the first maximal entry, so ties keep grid order
        best_idx, _ = max(grid_scores, key=lambda kv: kv[1])
        best_p = param_grid[best_idx]
        best_params_per_site[holdout] = best_p

        # copy: the cached dict is shared with other folds and must not be mutated
        metrics = dict(_site_metrics(best_idx, h_idx))
        metrics.update(
            {
                "site": holdout,
                "tau_open_h": best_p["tau_open_h"],
                "tau_closed_h": best_p["tau_closed_h"],
                "reliability": best_p["reliability"],
            }
        )
        per_site_rows.append(metrics)

    per_site = pd.DataFrame(per_site_rows).sort_values("site")

    # micro: pool counts over sites
    tp = per_site["tp"].sum()
    fp = per_site["fp"].sum()
    fn = per_site["fn"].sum()
    micro = {
        "precision_micro": tp / (tp + fp) if (tp + fp) else 0.0,
        "recall_micro": tp / (tp + fn) if (tp + fn) else 0.0,
        "f1_micro": (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0,
    }
    # macro: average per-site metrics
    macro = {
        "precision_macro": per_site["precision"].mean(),
        "recall_macro": per_site["recall"].mean(),
        "f1_macro": per_site["f1"].mean(),
    }
    return per_site, {"micro": micro, "macro": macro, "best_params": best_params_per_site}

In [None]:
# Needs columns region, acquired, y_prob, y_true.
# NOTE(review): evaluate_hmm_loso -> _score_site_params calls
# hungarian_match_events_with_tolerance, which is defined in a later cell of this
# notebook, so this cell fails on a fresh "Restart & Run All" until that cell has run.
df = tdf.sort_values(["region", "acquired"])  # needs ['region','acquired','y_prob','y_true']

per_site, summary = evaluate_hmm_loso(
    df,
    site_col="region",
    time_col="acquired",
    prob_col="y_prob",
    ytrue_col="y_true",
    param_grid=None,  # None -> evaluate_hmm_loso's built-in default grid
    gap_reset_h=75,  # split a site's sequence at gaps longer than 75 h
    tol_hours=75,  # predicted and GT changes match if within ±75 h
    smooth_gt=False,  # use raw GT labels; the hyst_* settings below only apply when True
    hyst_T_high=0.65,
    hyst_T_low=0.45,
    hyst_min_run=2,
)

print(per_site)
print(summary)

In [None]:
# Baseline: change-event F1 of the raw per-frame predictions (y_pred), per region.
# NOTE(review): this uses the matcher's default tolerance (max_hours=52), whereas the
# LOSO HMM evaluation uses tol_hours=75, so the two F1s are not directly comparable.
# NOTE(review): hungarian_match_events_with_tolerance is defined in a later cell.
tdf_by_region = tdf.sort_values(["region", "acquired"]).reset_index(drop=True)
for region, g in tdf_by_region.groupby("region"):
    changes = extract_changes(g.acquired, g.y_true.to_numpy())
    events = extract_changes(g.acquired, g.y_pred.to_numpy())

    tp_pairs, fp_events, fn_events = hungarian_match_events_with_tolerance(changes, events)
    n_tp, n_fp, n_fn = len(tp_pairs), len(fp_events), len(fn_events)

    prec = n_tp / (n_tp + n_fp) if (n_tp + n_fp) else 0.0
    rec = n_tp / (n_tp + n_fn) if (n_tp + n_fn) else 0.0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0

    print(region, round(f1, 3))

In [None]:
# Per-site LOSO metrics returned by evaluate_hmm_loso (one row per held-out site).
per_site

In [None]:
# Number of label change events, and how many came more than 5 days after the
# previous change (delta is the middle element of each event): raw labels first,
# then smoothed labels.
changes = extract_changes(rdf.acquired, rdf.y_true.to_numpy())
print(len(changes), sum(delta > pd.Timedelta(days=5) for _, delta, _ in changes))

changes2 = extract_changes(rdf.acquired, rdf.rolling_y_true.to_numpy())
len(changes2), sum(delta > pd.Timedelta(days=5) for _, delta, _ in changes2)

In [None]:
# add near the imports
import math

from scipy.optimize import linear_sum_assignment


def hungarian_match_events_with_tolerance(
    gt: list[tuple[pd.Timestamp, pd.Timedelta, int]],
    pr: list[tuple[pd.Timestamp, pd.Timedelta, int]],
    max_hours: float = 52.0,
):
    """
    Optimal 1-to-1 matching of change events using the Hungarian algorithm.

    Inputs (timestamp, delta, new_state) tuples.
    - Pairs farther than `max_hours` are disallowed.

    Cost = absolute time difference in HOURS. We minimize total cost.

    Returns:
        list tp pairs, fps, fns.
    """
    G, P = len(gt), len(pr)
    if G == 0 and P == 0:
        return [], [], []
    if G == 0:
        return [], pr, []
    if P == 0:
        return [], [], gt

    # --- build cost matrix (G x P) in HOURS ---
    BIG = 1e12  # "infinite" cost to forbid a pairing
    cost = np.full((G, P), BIG, dtype=float)

    for i, (gt_t, _, gt_s) in enumerate(gt):
        for j, (pr_t, _, pr_s) in enumerate(pr):
            # direction constraint
            if gt_s != pr_s:
                continue  # leave as BIG (forbidden)
            # time window constraint
            dhours = abs((pr_t - gt_t).total_seconds()) / 3600.0
            if dhours <= max_hours:
                cost[i, j] = dhours  # smaller is better

    # --- Hungarian solve ---
    row_ind, col_ind = linear_sum_assignment(cost)

    unused_g = [True] * G
    unused_p = [True] * P

    tp = []
    # collect feasible matches (cost < BIG)
    pairs: list[tuple[pd.Timestamp, pd.Timestamp, pd.Timedelta]] = []
    for r, c in zip(row_ind, col_ind, strict=False):
        if cost[r, c] >= BIG:  # forbidden assignments are ignored
            continue
        gt_t = gt[r]
        pr_t = pr[c]
        tp.append((gt_t, pr_t))
        unused_g[r] = False
        unused_p[c] = False
    fp = [a for unused, a in zip(unused_p, pr, strict=False) if unused]
    fn = [a for unused, a in zip(unused_g, gt, strict=False) if unused]

    return tp, fp, fn


# Raw-label change events (as GT) matched against smoothed-label change events.
tp, fp, fn = hungarian_match_events_with_tolerance(changes, changes2)

tuple(len(events) for events in (tp, fp, fn))

In [None]:
# Step through the unmatched smoothed-label events (fp) one at a time.
# Note: this rebinds `iii`, which earlier held the low_pri_iter generator.
iii = iter(fp)

In [None]:
ts, td, s = next(iii)

print(td)
print(s)

# Show every frame within ±(4 days + 3 h) of the event.
delta = pd.Timedelta(hours=4 * 24 + 3)
start = ts - delta
end = ts + delta
cols = rdf[(rdf.acquired > start) & (rdf.acquired < end)]

if cols.empty:
    # plt.subplots(ncols=0) would raise, so report and skip plotting.
    print(f"no frames within {delta} of {ts}")
else:
    fig, axes = plt.subplots(
        nrows=1, ncols=len(cols), figsize=(5 * len(cols), 5), squeeze=False
    )
    # The loop must not reuse the name `ii`: that global holds the image_iter
    # generator consumed by the image-grid cell.
    for (_, row), ax in zip(cols.iterrows(), axes[0]):
        ax.set_axis_off()

        with rasterio.open(row.source_tif) as src:
            data = src.read(out_dtype=np.float32)
            nodata = src.read(1, masked=True).mask
            if len(data) == 4:
                img = false_color(data, nodata)
            else:
                img = broad_band(data, nodata)
            img = Image.fromarray(img)
        ax.imshow(img)
        ax.set_title(row.orig_label + " " + str(round(row.rolling_y_prob_true, 3)))

    plt.tight_layout()
    plt.show()

In [None]:
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd


# Each input is a list of change events shaped (timestamp, delta, new_state), as built
# by extract_changes / hmm_decode_irregular; delta is not used for plotting.
#   gt_tp  : ground-truth events that were matched (TP)
#   gt_fn  : ground-truth events that were missed (FN)
#   pr_tp  : predicted events that were matched (TP)
#   pr_fp  : predicted events that were extra (FP)
def plot_change_events_timeline(
    gt_tp: list[tuple[pd.Timestamp, pd.Timedelta, int]],
    gt_fn: list[tuple[pd.Timestamp, pd.Timedelta, int]],
    pr_tp: list[tuple[pd.Timestamp, pd.Timedelta, int]],
    pr_fp: list[tuple[pd.Timestamp, pd.Timedelta, int]],
    title: str = "Change Events Timeline (0=closed, 1=open)",
    ymin: float = -0.2,
    ymax: float = 1.2,
    gt_offset: float = +0.04,  # vertical offset so GT and PRED don't overlap exactly
    pr_offset: float = -0.04,
    figsize=(11, 3.8),
):
    """Scatter matched / missed / extra change events on a closed(0)/open(1) timeline.

    Each event is drawn at the level of its new state. GT events are triangles nudged
    by ``gt_offset``; predicted events are circles nudged by ``pr_offset``. Matched
    events are green, missed GT events yellow, extra predictions red. A legend lists
    whichever categories are non-empty. The figure is shown and None is returned.
    """

    def _split_xy(events, offset=0.0):
        """Return (x datetimes, y = new_state + offset) for a list of event tuples."""
        xs = [pd.to_datetime(t) for t, _, _ in events]
        ys = [int(state) + offset for _, _, state in events]
        return xs, ys

    fig, ax = plt.subplots(figsize=figsize)

    # Base guide lines for states
    ax.axhline(0, color="0.85", lw=1)
    ax.axhline(1, color="0.85", lw=1)

    # Unpack to X/Y
    gt_tp_x, gt_tp_y = _split_xy(gt_tp, gt_offset)
    gt_fn_x, gt_fn_y = _split_xy(gt_fn, gt_offset)
    pr_tp_x, pr_tp_y = _split_xy(pr_tp, pr_offset)
    pr_fp_x, pr_fp_y = _split_xy(pr_fp, pr_offset)

    # GT events: triangles
    if gt_tp_x:
        ax.scatter(gt_tp_x, gt_tp_y, marker="^", s=60, color="green", label="GT TP")
    if gt_fn_x:
        ax.scatter(
            gt_fn_x,
            gt_fn_y,
            marker="^",
            s=60,
            color="yellow",
            edgecolor="k",
            linewidth=0.5,
            label="GT FN",
        )

    # PRED events: circles
    if pr_tp_x:
        ax.scatter(pr_tp_x, pr_tp_y, marker="o", s=50, color="green", alpha=0.9, label="Pred TP")
    if pr_fp_x:
        ax.scatter(pr_fp_x, pr_fp_y, marker="o", s=50, color="red", alpha=0.9, label="Pred FP")

    # Axis cosmetics
    ax.set_ylim(ymin, ymax)
    ax.set_yticks([0, 1])
    ax.set_yticklabels(["closed (0)", "open (1)"])
    ax.set_title(title)
    ax.set_xlabel("Time")
    ax.set_ylabel("State")

    # Nice date ticks
    ax.xaxis.set_major_locator(mdates.AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(ax.xaxis.get_major_locator()))

    # Legend only for categories that were drawn. "middle left" is not a valid
    # matplotlib location; place the legend just outside the right edge instead.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5), frameon=False)

    plt.tight_layout()
    plt.show()


# tp holds (gt_event, pred_event) pairs; split them into the GT and predicted sides.
plot_change_events_timeline(
    gt_tp=[gt_event for gt_event, _ in tp],
    gt_fn=fn,
    pr_tp=[pred_event for _, pred_event in tp],
    pr_fp=fp,
)