In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import math
from pathlib import Path

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import seaborn as sns
import torch
import tqdm
from PIL import Image
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)
from torch.utils.data import DataLoader

from estuary.model.data import EstuaryDataModule, EstuaryDataset, _load_labels
from estuary.model.module import EstuaryModule
from estuary.util import broad_band, false_color

In [None]:
# Restore the trained model from its checkpoint file and switch it to eval mode.
checkpoint_path = (
    "/Users/kyledorman/data/results/estuary/train/20251008-151833/checkpoints/epoch=15-step=704.ckpt"
)
# Keyword overrides forwarded to load_from_checkpoint. They look like a
# CPU-only, single-process, batch-size-1 setup for interactive evaluation;
# how each key is consumed is decided by EstuaryModule - TODO confirm.
load_overrides = {
    "accelerator": "cpu",
    "workers": 0,
    "persistent_workers": False,
    "batch_size": 1,
    "prefetch_factor": 0,
    "strict": False,
}
module = EstuaryModule.load_from_checkpoint(checkpoint_path, **load_overrides)
module = module.eval()

In [None]:
# Build the data module from the configuration carried by the restored model
# (`module.conf`), then run its prepare/setup hooks in that order.
# Presumably this is what makes the val/test/train dataloaders used below
# available - confirm against EstuaryDataModule.
dm = EstuaryDataModule(module.conf)
dm.prepare_data()
dm.setup()

In [None]:
# Sweep the decision threshold over the validation split and report the
# threshold that maximises F1.
dl = dm.val_dataloader()

y_prob = []
y_true = []
# Evaluation only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for batch in tqdm.tqdm(dl, total=len(dl)):
        batch = dm.val_aug(batch)
        for k in batch.keys():
            # List-valued entries are left as-is; everything else is moved
            # to the model's device.
            if isinstance(batch[k], list):
                continue
            batch[k] = batch[k].to(module.device)
        logits = module.forward(batch)
        probs_pos = torch.sigmoid(logits)
        # Column 0 of the sigmoid output is used as the positive-class probability.
        y_prob.extend(probs_pos.cpu().numpy()[:, 0])
        y_true.extend(batch["label"].cpu().numpy().tolist())

# Convert once so every threshold comparison below is a plain vectorised
# ndarray operation rather than a list coerced on each iteration.
y_prob_arr = np.asarray(y_prob)

thresholds = np.linspace(0, 1, 101)  # e.g., 0.00, 0.01, ..., 1.00
scores = []

for t in thresholds:
    y_pred = (y_prob_arr >= t).astype(int)
    # zero_division=0.0 everywhere: at extreme thresholds a class may never be
    # predicted, and all metrics should then report 0 without warnings.
    f1 = f1_score(y_true, y_pred, zero_division=0.0)
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0.0)
    rec = recall_score(y_true, y_pred, zero_division=0.0)
    scores.append((t, f1, acc, prec, rec))

# put into DataFrame for analysis
df_scores = pd.DataFrame(scores, columns=["threshold", "f1", "accuracy", "precision", "recall"])

# best threshold by F1
best_f1_row = df_scores.loc[df_scores["f1"].idxmax()]
print(best_f1_row)

In [None]:
# Histogram of the predicted probabilities currently held in `y_prob`.
# Labelled so the figure stands alone, and ended with plt.show() so the raw
# (counts, bins, patches) return value is not dumped as cell output.
fig, ax = plt.subplots()
ax.hist(y_prob)
ax.set(
    xlabel="predicted probability (sigmoid output)",
    ylabel="count",
    title="Distribution of predicted probabilities",
)
plt.show()

In [None]:
# F1 as a function of the decision threshold.
df_scores.plot(kind="scatter", x="threshold", y="f1")

In [None]:
# Metrics at the default 0.5 threshold. The thresholds come from np.linspace,
# so match with a tolerance instead of exact float equality.
df_scores[np.isclose(df_scores["threshold"], 0.5)]

In [None]:
# Accumulator for per-image prediction records (one dict per image). The
# val/test/train evaluation cells below append to it, and it is later turned
# into a DataFrame under the same name.
results_df = []

In [None]:
# Score the validation split and record one row per image in `results_df`.
dl = dm.val_dataloader()
split_name = "val"

# Make the cell safe to re-run: drop rows a previous execution of this cell
# already appended for this split, so they are not duplicated.
results_df = [row for row in results_df if row["dataset"] != split_name]

y_prob = []
y_true = []
# Evaluation only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for batch in tqdm.tqdm(dl, total=len(dl)):
        batch = dm.val_aug(batch)
        for k in batch.keys():
            # List-valued entries (such as "source_tif") are left as-is.
            if isinstance(batch[k], list):
                continue
            batch[k] = batch[k].to(module.device)
        logits = module.forward(batch)
        # One host transfer per batch; column 0 is the positive-class probability.
        probs = torch.sigmoid(logits).cpu().numpy()[:, 0]
        labels = batch["label"].cpu().numpy()
        y_prob.extend(probs)
        y_true.extend(labels.tolist())

        for source_tif, label, prob in zip(batch["source_tif"], labels, probs):
            results_df.append(
                {
                    "source_tif": source_tif,
                    # Plain Python scalars keep the DataFrame columns numeric
                    # instead of object columns holding 0-d arrays.
                    "y_true": label.item(),
                    "y_prob": float(prob),
                    "y_pred": int(prob > 0.5),
                    # The region id is the name of the directory three levels
                    # above the file.
                    "region": int(Path(source_tif).parents[2].name),
                    "dataset": split_name,
                }
            )

accuracy_score(y_true, np.int32(np.array(y_prob) > 0.5))

In [None]:
# Score the test split and record one row per image in `results_df`.
dl = dm.test_dataloader()
split_name = "test"

# Make the cell safe to re-run: drop rows a previous execution of this cell
# already appended for this split, so they are not duplicated.
results_df = [row for row in results_df if row["dataset"] != split_name]

y_prob = []
y_true = []
# Evaluation only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for batch in tqdm.tqdm(dl, total=len(dl)):
        batch = dm.val_aug(batch)
        for k in batch.keys():
            # List-valued entries (such as "source_tif") are left as-is.
            if isinstance(batch[k], list):
                continue
            batch[k] = batch[k].to(module.device)
        logits = module.forward(batch)
        # One host transfer per batch; column 0 is the positive-class probability.
        probs = torch.sigmoid(logits).cpu().numpy()[:, 0]
        labels = batch["label"].cpu().numpy()
        y_prob.extend(probs)
        y_true.extend(labels.tolist())

        for source_tif, label, prob in zip(batch["source_tif"], labels, probs):
            results_df.append(
                {
                    "source_tif": source_tif,
                    # Plain Python scalars keep the DataFrame columns numeric
                    # instead of object columns holding 0-d arrays.
                    "y_true": label.item(),
                    "y_prob": float(prob),
                    "y_pred": int(prob > 0.5),
                    # The region id is the name of the directory three levels
                    # above the file.
                    "region": int(Path(source_tif).parents[2].name),
                    "dataset": split_name,
                }
            )

accuracy_score(y_true, np.int32(np.array(y_prob) > 0.5))

In [None]:
# Score the training split and record one row per image in `results_df`.
# Batches go through `dm.val_aug`, the same transform used for val/test above.
dl = dm.train_dataloader()
split_name = "train"

# Make the cell safe to re-run: drop rows a previous execution of this cell
# already appended for this split, so they are not duplicated.
results_df = [row for row in results_df if row["dataset"] != split_name]

y_prob = []
y_true = []
# Evaluation only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for batch in tqdm.tqdm(dl, total=len(dl)):
        batch = dm.val_aug(batch)
        for k in batch.keys():
            # List-valued entries (such as "source_tif") are left as-is.
            if isinstance(batch[k], list):
                continue
            batch[k] = batch[k].to(module.device)
        logits = module.forward(batch)
        # One host transfer per batch; column 0 is the positive-class probability.
        probs = torch.sigmoid(logits).cpu().numpy()[:, 0]
        labels = batch["label"].cpu().numpy()
        y_prob.extend(probs)
        y_true.extend(labels.tolist())

        for source_tif, label, prob in zip(batch["source_tif"], labels, probs):
            results_df.append(
                {
                    "source_tif": source_tif,
                    # Plain Python scalars keep the DataFrame columns numeric
                    # instead of object columns holding 0-d arrays.
                    "y_true": label.item(),
                    "y_prob": float(prob),
                    "y_pred": int(prob > 0.5),
                    # The region id is the name of the directory three levels
                    # above the file.
                    "region": int(Path(source_tif).parents[2].name),
                    "dataset": split_name,
                }
            )

accuracy_score(y_true, np.int32(np.array(y_prob) > 0.5))

In [None]:
# Turn the accumulated records into a DataFrame and flag, per image, whether
# the thresholded prediction matches the label.
results_df = pd.DataFrame(results_df).assign(
    correct=lambda frame: frame.y_true == frame.y_pred
)

results_df.head(5)

In [None]:
# Load previously saved per-image predictions from disk.
# NOTE(review): this rebinds `results_df`, replacing whatever the evaluation
# cells above produced. The CSV lives under run 20251021-151419, while the
# checkpoint loaded at the top of the notebook is from run 20251008-151833 -
# confirm this is the intended source.
preds_csv = Path("/Users/kyledorman/data/results/estuary/train/20251021-151419/preds.csv")
results_df = pd.read_csv(preds_csv)

results_df.head()

In [None]:
# Per (region, dataset) summary: mean of `correct` (accuracy) and mean of
# `y_true` (share of positive labels, stored as float32).
by_region_split = results_df.groupby(["region", "dataset"])
acc_by_region = by_region_split["correct"].mean()
open_pct = by_region_split["y_true"].mean().astype(np.float32)

# Flat table with `region` and `dataset` as ordinary columns.
acc_df = pd.concat({"accuracy": acc_by_region, "open_pct": open_pct}, axis=1).reset_index()

# Same data indexed by (region, dataset), worst accuracy first.
region_stats = acc_df.set_index(["region", "dataset"]).sort_values(by="accuracy")

region_stats.head(5).round(2)

In [None]:
# Region/split pairs whose accuracy is below 0.9.
region_stats.loc[region_stats["accuracy"] < 0.9].round(2)

In [None]:
# region_stats has columns accuracy and open_pct; `dataset` is an index level,
# so flatten the index to expose it as a regular column for the hue mapping.
fig, ax = plt.subplots(figsize=(6, 4))
sns.scatterplot(
    data=region_stats.reset_index(),
    x="open_pct",
    y="accuracy",
    hue="dataset",  # color by dataset
    palette="Set1",  # you can pick another palette
    s=20,  # point size
    ax=ax,
)
ax.set_title("Accuracy vs. Open % by dataset")
plt.show()

In [None]:
# region_stats is sorted by ascending accuracy, so the first N rows are the
# worst-performing (region, dataset) pairs and the last N the best.
N = 5
worse_counts = region_stats.iloc[:N]
best_counts = region_stats.iloc[-N:]

display(worse_counts.round(2))

In [None]:
# Look up the site records for the regions that appear in the test split.
gdf = gpd.read_file("/Users/kyledorman/data/estuary/geos/ca_data_w_usgs.geojson")

# `region_stats` is indexed by (region, dataset), so `dataset` is an index
# level rather than a column: select on the index and keep only the region
# ids instead of (region, dataset) tuples.
stats_index = region_stats.index
is_test = stats_index.get_level_values("dataset") == "test"
test_regions = stats_index.get_level_values("region")[is_test].unique()

# Assumes "Site code" holds the same integer ids as `region` - TODO confirm.
gdf[gdf["Site code"].isin(test_regions)]

In [None]:
def image_iter(df, dataset, count):
    """Yield misclassified images of one split, region by region, in chunks.

    Rows of ``df`` are kept when their ``dataset`` column equals ``dataset``
    and their ``correct`` column is False. The kept rows are grouped by
    ``region``; within each region they are emitted as lists of
    ``(source_tif, y_true)`` pairs holding at most ``count`` entries.

    Yields:
        ``(region, pairs)`` tuples. A region with more than ``count`` rows is
        yielded several times, the final chunk possibly being shorter.
    """
    in_split = df["dataset"] == dataset
    is_wrong = ~df.correct
    misses = df[in_split & is_wrong]
    for region, region_rows in misses.groupby("region"):
        chunk = []
        for pair in zip(region_rows["source_tif"], region_rows["y_true"]):
            chunk.append(pair)
            if len(chunk) == count:
                yield region, chunk
                chunk = []
        # Emit whatever is left over once the region is exhausted.
        if chunk:
            yield region, chunk


# One lazy iterator per split, each yielding up to 6 misclassified images at a time.
val_iter, test_iter, train_iter = (
    image_iter(results_df, split, 6) for split in ("val", "test", "train")
)

In [None]:
# Show the next chunk of misclassified validation images for one region.
region, images = next(val_iter)

assert len(images), region

# Up to three images per row. The row count is derived from the column count
# so the grid always has at least len(images) cells.
cols = min(len(images), 3)
rows = math.ceil(len(images) / cols)
# squeeze=False always returns a 2-D array of Axes, so single-image and
# single-row grids need no special-casing.
fig, axs = plt.subplots(
    nrows=rows, ncols=cols, figsize=(5 * cols, 5 * rows), squeeze=False
)
axs = axs.ravel()
# Hide ticks/frames on every cell, including cells that stay empty.
for ax in axs:
    ax.set_axis_off()

# The loop variable is `true_label` rather than `y_true` so the `y_true` list
# built by the evaluation cells is not overwritten.
for (source_tif, true_label), ax in zip(images, axs, strict=False):
    with rasterio.open(source_tif) as src:
        data = src.read(out_dtype=np.float32)
        nodata = src.read(1, masked=True).mask
        # Arrays with a leading dimension of 4 are rendered with false_color,
        # anything else with broad_band.
        if len(data) == 4:
            img = false_color(data, nodata)
        else:
            img = broad_band(data, nodata)
        img = Image.fromarray(img)
    label = "Closed" if true_label == 0 else "Open"
    ax.imshow(img)
    ax.set_title(f"Region: {region} Label: {label}")

plt.tight_layout()
plt.show()

In [None]:
# Show the next chunk of misclassified test images for one region.
region, images = next(test_iter)

assert len(images), region

# Up to three images per row. The row count is derived from the column count
# so the grid always has at least len(images) cells.
cols = min(len(images), 3)
rows = math.ceil(len(images) / cols)
# squeeze=False always returns a 2-D array of Axes, so single-image and
# single-row grids need no special-casing.
fig, axs = plt.subplots(
    nrows=rows, ncols=cols, figsize=(5 * cols, 5 * rows), squeeze=False
)
axs = axs.ravel()
# Hide ticks/frames on every cell, including cells that stay empty.
for ax in axs:
    ax.set_axis_off()

# The loop variable is `true_label` rather than `y_true` so the `y_true` list
# built by the evaluation cells is not overwritten.
for (source_tif, true_label), ax in zip(images, axs, strict=False):
    with rasterio.open(source_tif) as src:
        data = src.read(out_dtype=np.float32)
        nodata = src.read(1, masked=True).mask
        # Arrays with a leading dimension of 4 are rendered with false_color,
        # anything else with broad_band.
        if len(data) == 4:
            img = false_color(data, nodata)
        else:
            img = broad_band(data, nodata)
        img = Image.fromarray(img)
    label = "Closed" if true_label == 0 else "Open"
    ax.imshow(img)
    ax.set_title(f"Region: {region} Label: {label}")

plt.tight_layout()
plt.show()

In [None]:
# Show the next chunk of misclassified training images for one region.
region, images = next(train_iter)

assert len(images), region

# Up to three images per row. The row count is derived from the column count
# so the grid always has at least len(images) cells.
cols = min(len(images), 3)
rows = math.ceil(len(images) / cols)
# squeeze=False always returns a 2-D array of Axes, so single-image and
# single-row grids need no special-casing.
fig, axs = plt.subplots(
    nrows=rows, ncols=cols, figsize=(5 * cols, 5 * rows), squeeze=False
)
axs = axs.ravel()
# Hide ticks/frames on every cell, including cells that stay empty.
for ax in axs:
    ax.set_axis_off()

# The loop variable is `true_label` rather than `y_true` so the `y_true` list
# built by the evaluation cells is not overwritten.
for (source_tif, true_label), ax in zip(images, axs, strict=False):
    with rasterio.open(source_tif) as src:
        data = src.read(out_dtype=np.float32)
        nodata = src.read(1, masked=True).mask
        # Arrays with a leading dimension of 4 are rendered with false_color,
        # anything else with broad_band.
        if len(data) == 4:
            img = false_color(data, nodata)
        else:
            img = broad_band(data, nodata)
        img = Image.fromarray(img)
    label = "Closed" if true_label == 0 else "Open"
    ax.imshow(img)
    ax.set_title(f"Region: {region} Label: {label}")

plt.tight_layout()
plt.show()

In [None]:
# Attach each image's original label to the prediction table, matching on
# `source_tif`.
all_labels = _load_labels(module.conf.classes, module.conf.data)
results_df = pd.merge(
    # Drop any `orig_label` left by a previous run of this cell (or already
    # present in a loaded CSV); otherwise the merge would emit
    # `orig_label_x` / `orig_label_y` instead of a single `orig_label` column.
    results_df.drop(columns="orig_label", errors="ignore"),
    all_labels[["source_tif", "orig_label"]],
    on="source_tif",
    how="left",
)

In [None]:
# Disabled export step, kept for reference: when uncommented, it renders every
# misclassified row of `results_df` and saves it as a JPEG in a folder named
# after the row's `orig_label` (spaces replaced by underscores).
# for _, row in tqdm.tqdm(
#     results_df[~results_df.correct].iterrows(), total=(~results_df.correct).sum()
# ):
#     label = "_".join(row.orig_label.split(" "))
#     with rasterio.open(row.source_tif) as src:
#         data = src.read(out_dtype=np.float32)
#         nodata = src.read(1, masked=True).mask
#         if len(data) == 4:
#             img = false_color(data, nodata)
#         else:
#             img = broad_band(data, nodata)
#         img = Image.fromarray(img)

#     save_dir = Path("/Volumes/x10pro/estuary/ca_all/inspect_all_sites/") / label
#     save_dir.mkdir(exist_ok=True, parents=True)
#     img.save(save_dir / f"{Path(row.source_tif).stem}.jpg")

In [None]:
# Score every labelled image listed in the ca_all labels CSV and record one
# row per image in `ca_results_list`.
ca_all_labels = _load_labels(
    module.conf.classes, "/Volumes/x10pro/estuary/ca_all/dove/labeling/labels.csv"
)
ca_results_list = []

ds = EstuaryDataset(
    df=ca_all_labels,
    conf=module.conf,
    train=False,
)
dl = DataLoader(
    ds,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

y_prob = []
y_true = []
# Evaluation only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for batch in tqdm.tqdm(dl, total=len(dl)):
        batch = ds.transforms(batch)
        for k in batch.keys():
            # List-valued entries (such as "source_tif") are left as-is.
            if isinstance(batch[k], list):
                continue
            batch[k] = batch[k].to(module.device)
        logits = module.forward(batch)
        # One host transfer per batch; column 0 is the positive-class probability.
        probs = torch.sigmoid(logits).cpu().numpy()[:, 0]
        labels = batch["label"].cpu().numpy()
        y_prob.extend(probs)
        y_true.extend(labels.tolist())

        for source_tif, label, prob in zip(batch["source_tif"], labels, probs):
            ca_results_list.append(
                {
                    "source_tif": source_tif,
                    # Plain Python scalars keep the DataFrame columns numeric
                    # instead of object columns holding 0-d arrays.
                    "y_true": label.item(),
                    "y_prob": float(prob),
                    "y_pred": int(prob > 0.5),
                    # Here the region id is the directory two levels above the
                    # file (parents[1]), unlike the parents[2] used for the
                    # val/test/train splits.
                    "region": int(Path(source_tif).parents[1].name),
                    # NOTE(review): every row is tagged "train" even though the
                    # data comes from the ca_all labels - confirm intended.
                    "dataset": "train",
                }
            )

accuracy_score(y_true, np.int32(np.array(y_prob) > 0.5))

In [None]:
# Build the ca_all prediction table: records -> DataFrame, left-join the
# original labels on `source_tif`, then flag correct predictions.
ca_results_df = (
    pd.DataFrame(ca_results_list)
    .merge(ca_all_labels[["source_tif", "orig_label"]], on="source_tif", how="left")
    .assign(correct=lambda frame: frame.y_true == frame.y_pred)
)
ca_results_df.head()

In [None]:
# Load previously saved timeseries predictions from disk.
# NOTE(review): this rebinds `ca_results_df`, replacing whatever the cells
# above computed - confirm the CSV is the intended source.
timeseries_csv = Path(
    "/Users/kyledorman/data/results/estuary/train/20251021-151419/timeseries_preds.csv"
)
ca_results_df = pd.read_csv(timeseries_csv)

In [None]:
# Accuracy (mean of `correct`) for each original label.
ca_results_df.groupby("orig_label")["correct"].mean()

In [None]:
# Accuracy (mean of `correct`) for each (region, original label) pair.
ca_results_df.groupby(["region", "orig_label"])["correct"].mean()