In [None]:
import json
from pathlib import Path

import folium
import geopandas as gpd
import pandas as pd
import rasterio
from matplotlib import pyplot as plt
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score

from estuary.util import broad_band, false_color

In [None]:
# Root directory of the estuary dataset; several data paths below are built from it.
# NOTE(review): absolute path to a mounted volume, so the notebook only runs where
# that volume is available.
BASE = Path("/Volumes/x10pro/estuary/")

In [None]:
# Site codes to exclude; used below to filter the site table, the grid files and
# the label tables.
skipped_regions = pd.read_csv(BASE / "geos/skipped_regions.csv")["Site code"].to_list()

In [None]:
# Site table: drop the skipped site codes, then index the frame by "Site code".
gdf = (
    gpd.read_file(BASE / "geos/ca_data_w_usgs.geojson")
    .loc[lambda sites: ~sites["Site code"].isin(skipped_regions)]
    .copy()
    .set_index("Site code")
)
gdf.head()

In [None]:
# Build one row per site from the per-site grid files in ``ca_grids``.
# Each file stem is parsed as the integer site code and only the first geometry
# of each file is used.
grid_paths = sorted(
    pth
    for pth in (BASE / "ca_grids").iterdir()
    # Hidden entries (names starting with ".") have stems that int() cannot
    # parse, so skip them instead of crashing.
    if not pth.name.startswith(".")
)

rect_rows = []
grid_crs = None
for pth in grid_paths:
    gid = int(pth.stem)
    if gid in skipped_regions:
        continue
    tp_df = gpd.read_file(pth)
    # Track the CRS explicitly rather than relying on the loop variable leaking
    # out of the loop. As before, the CRS of the last file read is used.
    # NOTE(review): this assumes all grid files share one CRS — confirm.
    grid_crs = tp_df.crs
    rect_rows.append(
        {
            "Site code": gid,
            "geometry": tp_df.iloc[0].geometry,
            "Site name": gdf.loc[gid, "Site name"],
        }
    )

if not rect_rows:
    # Previously this surfaced as a NameError on ``tp_df``.
    raise FileNotFoundError(f"No usable grid files found in {BASE / 'ca_grids'}")

rect_df = gpd.GeoDataFrame(rect_rows, geometry="geometry", crs=grid_crs)
rect_df.head()

In [None]:
# Load the site-matching table and build the inverse (value -> key) lookup.
matching_sites = json.loads((BASE / "geos/ca_empa_matching_sites.json").read_text())

revmatching_sites = dict((val, key) for key, val in matching_sites.items())

matching_sites

In [None]:
# SkySat labels. Read through BASE, like the other inputs, instead of repeating
# the absolute volume path.
ss_labels = pd.read_csv(BASE / "skysat" / "labels.csv")
ss_labels["acquired"] = pd.to_datetime(ss_labels["acquired"], errors="coerce").dt.tz_localize(
    "UTC"
)  # interpret naive times as already UTC
# Calendar fields used by the monthly plots below.
ss_labels["year"] = ss_labels.acquired.dt.year
ss_labels["month"] = ss_labels.acquired.dt.month
# Drop the regions listed in skipped_regions.
ss_labels = ss_labels[~ss_labels.region.isin(skipped_regions)].copy()
ss_labels.head()

In [None]:
# --- Params ---
nth = 6  # show every nth month label

# Monthly period for every SkySat label. ``from_fields`` replaces the
# ``PeriodIndex(year=..., month=...)`` constructor form, which newer pandas
# releases deprecate, and matches the other monthly plots in this notebook.
period = pd.PeriodIndex.from_fields(year=ss_labels["year"], month=ss_labels["month"], freq="M")

# Counts per month (sorted)
counts = period.value_counts().sort_index()

# Fill missing months between start and end so gaps show up as zero-height bars
full = pd.period_range(counts.index.min(), counts.index.max(), freq="M")
counts = counts.reindex(full, fill_value=0)

# Plot
ax = counts.plot(kind="bar", figsize=(12, 6))
ax.set_xlabel("Year–Month")
ax.set_ylabel("Count")
ax.set_title("Counts per Month")

# Only label every nth month for readability
for i, label in enumerate(ax.get_xticklabels()):
    label.set_visible(i % nth == 0)

plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

In [None]:
empa_labels = pd.read_csv(BASE / "geos" / "empa_labels.csv")
# utc=True yields a timezone-aware UTC column whether or not the CSV strings
# carry an offset (naive values are taken as UTC). The SkySat and Dove
# timestamps are localized to UTC too, so the columns stay comparable when they
# are concatenated and matched later.
empa_labels["acquired"] = pd.to_datetime(empa_labels["acquired"], errors="coerce", utc=True)
empa_labels.head()

In [None]:
# Stack the SkySat and EMPA label tables into one reference-label table.
# The original row indices are kept, so index values may repeat here.
ss_empa_labels = pd.concat([ss_labels, empa_labels])
ss_empa_labels.head()

In [None]:
# dove = []
# for base in [
#     Path("/Volumes/x10pro/estuary/dove/results"),
#     Path("/Volumes/x10pro/estuary/superdove/results"),
# ]:
#     for pth in base.glob("*/*/*/images_to_download.csv"):
#         df = pd.read_csv(pth).drop(columns=["ordered_idx"])
#         df["region"] = int(pth.parent.name)
#         df["capture_datetime"] = pd.to_datetime(
#             df["capture_datetime"], errors="coerce"
#         ).dt.tz_localize("UTC")  # interpret naive times as already UTC
#         df["instrument"] = base.parent.name
#         dove.append(df)

# dove_df = pd.concat(dove)
# dove_df = dove_df[dove_df.include_image]

# print(len(dove_df))

# dove_df.head()

# Dove labels: parse "acquired" as UTC, expose it as "capture_datetime", and add
# the calendar fields derived from it.
dove_df = (
    pd.read_csv(BASE / "dove" / "labels.csv")
    .assign(
        acquired=lambda df: pd.to_datetime(df["acquired"], errors="coerce").dt.tz_localize("UTC")
    )
    .rename(columns={"acquired": "capture_datetime"})
    .assign(
        year=lambda df: df["capture_datetime"].dt.year,
        month=lambda df: df["capture_datetime"].dt.month,
    )
)
dove_df.head()

In [None]:
# # --- Params ---
# nth = 6  # show every nth month label

# # If you already have 'year' and 'month' cols:
# period = pd.PeriodIndex.from_fields(year=dove_df["year"], month=dove_df["month"], freq="M")

# # Counts per month (sorted)
# counts = period.value_counts().sort_index()

# # Fill missing months between start and end
# full = pd.period_range(counts.index.min(), counts.index.max(), freq="M")
# counts = counts.reindex(full, fill_value=0)

# # Plot
# ax = counts.plot(kind="bar", figsize=(12, 6))
# ax.set_xlabel("Year–Month")
# ax.set_ylabel("Count")
# ax.set_title("Counts per Month")

# # Only label every nth month for readability
# labels = [p.strftime("%Y-%m") for p in counts.index]
# for i, label in enumerate(ax.get_xticklabels()):
#     label.set_visible(i % nth == 0)

# plt.xticks(rotation=45, ha="right")
# plt.tight_layout()
# plt.show()

In [None]:
# Reference labels without the "unsure" class, ordered by acquisition time.
ss_empa_labels = (
    ss_empa_labels.loc[ss_empa_labels.label != "unsure"]
    .sort_values("acquired")
    .reset_index(drop=True)
)
dove_df = dove_df.sort_values("capture_datetime").reset_index(drop=True)

# Maximum time difference between a Dove capture and a reference label.
tol = pd.Timedelta("14h")

# Pair every Dove row with every reference row of the same region, then keep the
# pairs whose reference time lies within +/- tol of the Dove capture (inclusive).
tmp = dove_df.merge(ss_empa_labels, on="region", suffixes=("_dd", "_ss"))
mask = tmp["acquired"].between(tmp["capture_datetime"] - tol, tmp["capture_datetime"] + tol)
pairs = (
    tmp.loc[mask]
    .sort_values(["region", "capture_datetime", "acquired"])
    # Keep the Dove-side calendar fields under their plain names.
    .drop(columns=["year_ss", "month_ss"])
    .rename(columns={"year_dd": "year", "month_dd": "month"})
)
# pairs = pd.merge_asof(
#     dove_df,
#     ss_empa_labels,
#     by="region",
#     right_on="acquired",
#     left_on="capture_datetime",
#     direction="nearest",
#     tolerance=pd.Timedelta("28h"),
#     suffixes=("_dd", "_ss"),
# )
# pairs = pairs[~pairs.source_tif.isna()]
len(pairs)

In [None]:
# Preview the matched Dove / reference-label pairs.
pairs.head()

In [None]:
# Keep one pair per (region, year, month, label_ss).
# drop_duplicates keeps the first row of each group, so the sort decides which
# row survives: instrument_dd is sorted descending, every other key ascending.
pairs_dedup = pairs.sort_values(
    by=["region", "year", "month", "instrument_dd", "label_ss", "instrument_ss"],
    ascending=[True, True, True, False, True, True],
).drop_duplicates(["region", "year", "month", "label_ss"])

len(pairs_dedup)

In [None]:
# Number of pairs whose reference label comes from EMPA.
int((pairs["instrument_ss"] == "empa").sum())

In [None]:
# --- Params ---
nth = 6  # show every nth month label

# Monthly period of every matched pair (Dove-side year/month)
period = pd.PeriodIndex.from_fields(year=pairs["year"], month=pairs["month"], freq="M")

# Counts per month (sorted)
counts = period.value_counts().sort_index()

# Fill missing months between start and end so gaps show up as zero-height bars
full = pd.period_range(counts.index.min(), counts.index.max(), freq="M")
counts = counts.reindex(full, fill_value=0)

# Plot
ax = counts.plot(kind="bar", figsize=(12, 6))
ax.set_xlabel("Year–Month")
ax.set_ylabel("Count")
ax.set_title("Counts per Month")

# Only label every nth month for readability
# (the unused list of formatted month labels was removed)
for i, label in enumerate(ax.get_xticklabels()):
    label.set_visible(i % nth == 0)

plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

In [None]:
# Sites that have a grid rectangle but no matched pair at all.
missed_regions = set(rect_df["Site code"].unique()).difference(pairs.region.unique())

print("num missed regions", len(missed_regions))

# Region x label table of pair counts.
counts = pairs.groupby(["region", "label_ss"]).size().unstack(fill_value=0)

# Add the unmatched sites as all-zero rows so they still appear on the axis.
all_regions = counts.index.union(missed_regions)
counts = counts.reindex(all_regions, fill_value=0)

# Stacked bar chart, one bar per region.
ax = counts.plot(kind="bar", stacked=True, figsize=(10, 6))
ax.set(xlabel="Region", ylabel="Count", title="Counts per Region grouped by Label")
plt.xticks(rotation=45, ha="right")
ax.legend(title="Label")
plt.tight_layout()
plt.show()

In [None]:
# How many matched pairs carry each reference label.
counts = pairs["label_ss"].value_counts()

ax = counts.plot.bar(figsize=(10, 6))
ax.set(xlabel="Label", ylabel="Count", title="Number of Records per Label")
plt.tight_layout()
plt.show()

In [None]:
# How many matched pairs come from each Dove-side instrument.
counts = pairs["instrument_dd"].value_counts()

ax = counts.plot.bar(figsize=(10, 6))
ax.set(xlabel="Satellite", ylabel="Count", title="Number of Records per Satellite Type")
plt.tight_layout()
plt.show()

In [None]:
# Pair counts by the instrument of the reference label.
pairs["instrument_ss"].value_counts()

In [None]:
# to_download = []
# for (region, month, year), df in dove_df.groupby(["region", "month", "year"]):
#     super_label_df = pairs_dedup[
#         (pairs_dedup.region == region) & (pairs_dedup.year == year) & (pairs_dedup.month == month)
#     ]
#     if len(super_label_df):
#         to_download.append(super_label_df[["region", "year", "month", "asset_id", "dove"]])
#     else:
#         to_download.append(df[["region", "year", "month", "asset_id", "dove"]].iloc[0:1])

# to_download = (
#     pd.concat(to_download)
#     .sort_values(["region", "year", "month", "dove", "asset_id"])
#     .reset_index(drop=True)
# )

# print(len(to_download))
# to_download.head()

In [None]:
# ``to_download`` is only created by the download-selection cell, whose code is
# currently commented out. On a fresh kernel the name does not exist, so guard
# the plot instead of failing Run All with a NameError.
if "to_download" not in globals():
    print("to_download is not defined; skipping the per-month download plot")
else:
    # --- Params ---
    nth = 6  # show every nth month label

    # Monthly period built from the year/month columns
    to_download["period"] = pd.PeriodIndex.from_fields(
        year=to_download["year"], month=to_download["month"], freq="M"
    )

    # Group by month + dove, then unstack so each dove value gets its own column
    counts = to_download.groupby(["period", "dove"]).size().unstack(fill_value=0).sort_index()

    # Fill in missing months across the full range
    full = pd.period_range(counts.index.min(), counts.index.max(), freq="M")
    counts = counts.reindex(full, fill_value=0)

    # Plot
    ax = counts.plot(kind="bar", figsize=(12, 6))
    ax.set_xlabel("Year–Month")
    ax.set_ylabel("Count")
    ax.set_title("Counts per Month (stacked by Dove)")

    # Sparse tick labels
    for i, label in enumerate(ax.get_xticklabels()):
        label.set_visible(i % nth == 0)

    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()

In [None]:
# ``to_download`` is only created by the download-selection cell, whose code is
# currently commented out. On a fresh kernel the name does not exist, so guard
# the plot instead of failing Run All with a NameError.
if "to_download" not in globals():
    print("to_download is not defined; skipping the per-region download plot")
else:
    # Sites that have a grid rectangle but nothing selected for download
    missed_regions = set(rect_df["Site code"].unique()) - set(to_download.region.unique())

    print("num missed regions", len(missed_regions))

    # Rows per region
    counts = to_download.region.value_counts()

    # Add the missing sites as zero-count entries so they still appear on the axis
    all_regions = counts.index.union(missed_regions)

    counts = counts.reindex(all_regions, fill_value=0).sort_index()

    # Plot bar chart
    counts.plot(kind="bar", figsize=(10, 6))

    plt.xlabel("Region")
    plt.ylabel("Count")
    plt.title("Counts per Region")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()

In [None]:
# Sites that have a grid rectangle but no row in the Dove table.
missed_regions = set(rect_df["Site code"].unique()).difference(dove_df.region.unique())

print("num missed regions", len(missed_regions))

# Rows per region, with the missing sites added as zero-count entries.
counts = dove_df.region.value_counts()
all_regions = counts.index.union(missed_regions)
counts = counts.reindex(all_regions, fill_value=0).sort_index()

ax = counts.plot.bar(figsize=(10, 6))
ax.set(xlabel="Region", ylabel="Count", title="Counts per Region")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

In [None]:
# base = Path("/Volumes/x10pro/estuary/")
# for (region, year, month, dove), df in to_download.groupby(["region", "year", "month", "dove"]):
#     df = df[["asset_id"]].copy()
#     if not len(df):
#         continue
#     df["include_image"] = True

#     save_path = base / dove / "results" / str(year) / str(month) / str(region) / "subset_images_to_download.csv"
#     assert save_path.parent.exists()
#     df.to_csv(save_path)


In [None]:
# Dove labels for evaluation. Read through BASE, like the earlier load of the
# same file, instead of repeating the absolute volume path.
d_df = pd.read_csv(BASE / "dove" / "labels.csv")
d_df["acquired"] = pd.to_datetime(d_df["acquired"], errors="coerce").dt.tz_localize(
    "UTC"
)  # interpret naive times as already UTC
d_df["year"] = d_df.acquired.dt.year
d_df["month"] = d_df.acquired.dt.month
# Drop the "unsure" class and the regions listed in skipped_regions.
d_df = d_df[d_df["label"] != "unsure"].copy()
d_df = d_df[~d_df.region.isin(skipped_regions)].copy()

print("Num Regions:", len(d_df.region.unique()), "Num images:", len(d_df))
d_df.head()

In [None]:
# Sites that have a grid rectangle but no Dove label.
missed_regions = set(rect_df["Site code"].unique()).difference(d_df.region.unique())

print("num missed regions", len(missed_regions))

# Region x label table of Dove label counts.
counts = d_df.groupby(["region", "label"]).size().unstack(fill_value=0)

# Add the unlabelled sites as all-zero rows so they still appear on the axis.
all_regions = counts.index.union(missed_regions)
counts = counts.reindex(all_regions, fill_value=0)

ax = counts.plot(kind="bar", stacked=True, figsize=(10, 6))
ax.set(xlabel="Region", ylabel="Count", title="Dove Label Counts per Region grouped by Label")
plt.xticks(rotation=45, ha="right")
ax.legend(title="Label")
plt.tight_layout()
plt.show()

In [None]:
# Dove label class balance.
counts = d_df["label"].value_counts()

ax = counts.plot.bar(figsize=(5, 4))
ax.set(xlabel="Label", ylabel="Count", title="Count per Label type")
plt.tight_layout()
plt.show()

In [None]:
# Number of distinct regions among reference labels that have a source_tif value.
ss_empa_labels.loc[ss_empa_labels.source_tif.notna(), "region"].unique().size

In [None]:
# Maximum time difference between a Dove label and a reference label.
tol = pd.Timedelta("14h")

# Pair every Dove label with every reference label of the same region, then keep
# the pairs whose Dove time lies within +/- tol of the reference time (inclusive).
tmp = d_df.merge(ss_empa_labels, on="region", suffixes=("_dd", "_ss"))
mask = tmp["acquired_dd"].between(tmp["acquired_ss"] - tol, tmp["acquired_ss"] + tol)
dove_pairs = (
    tmp.loc[mask]
    .sort_values(["region", "acquired_dd", "acquired_ss"])
    # Keep the Dove-side calendar fields under their plain names.
    .drop(columns=["year_ss", "month_ss"])
    .rename(columns={"year_dd": "year", "month_dd": "month"})
)

print(len(dove_pairs))

dove_pairs.head()

In [None]:
# Three-class agreement between Dove labels (pred) and non-EMPA reference labels (true).
labels = ["closed", "perched open", "open"]
aaa = dove_pairs.loc[dove_pairs.instrument_ss != "empa"]
# Encode each label as its position in ``labels``.
y_true = list(map(labels.index, aaa.label_ss))
y_pred = list(map(labels.index, aaa.label_dd))

print("Accuracy", round(100 * accuracy_score(y_true, y_pred), 1))

ConfusionMatrixDisplay.from_predictions(y_true, y_pred, display_labels=labels)

In [None]:
# Binary agreement on non-EMPA pairs: any label containing "open" counts as open (1).
labels = ["closed", "open"]
aaa = dove_pairs.loc[dove_pairs.instrument_ss != "empa"]
y_true = [1 if "open" in name else 0 for name in aaa.label_ss]
y_pred = [1 if "open" in name else 0 for name in aaa.label_dd]

print("Accuracy", round(100 * accuracy_score(y_true, y_pred), 1))

ConfusionMatrixDisplay.from_predictions(y_true, y_pred, display_labels=labels)

In [None]:
# Binary agreement on EMPA pairs only: any label containing "open" counts as open (1).
aaa = dove_pairs.loc[dove_pairs.instrument_ss == "empa"]
labels = ["closed", "open"]
y_true = list(int("open" in name) for name in aaa.label_ss)
y_pred = list(int("open" in name) for name in aaa.label_dd)

print("Accuracy", round(100 * accuracy_score(y_true, y_pred), 1))

ConfusionMatrixDisplay.from_predictions(y_true, y_pred, display_labels=labels)

In [None]:
import numpy as np


def _render_tif(tif_path):
    """Read a raster file and return it as a PIL image ready for ``imshow``.

    Rasters with four bands are rendered with ``false_color``; anything else
    goes through ``broad_band``. The mask of band 1 is passed to both as the
    nodata mask.
    """
    with rasterio.open(tif_path) as src:
        data = src.read(out_dtype=np.float32)
        nodata = src.read(1, masked=True).mask
    img = false_color(data, nodata) if len(data) == 4 else broad_band(data, nodata)
    return Image.fromarray(img)


# Pairs where the Dove label (_dd) and the EMPA label (_ss) disagree on "closed".
empa_closed = dove_pairs[
    (dove_pairs.instrument_ss == "empa")
    & (dove_pairs.label_dd != "closed")
    & (dove_pairs.label_ss == "closed")
]
empa_open = dove_pairs[
    (dove_pairs.instrument_ss == "empa")
    & (dove_pairs.label_dd == "closed")
    & (dove_pairs.label_ss != "closed")
]

# The grid needs one row per image of the longer column. Sizing it from
# ``empa_open`` alone raised IndexError when ``empa_closed`` had more rows.
n_rows = max(len(empa_closed), len(empa_open))

if n_rows == 0:
    # plt.subplots rejects nrows=0, so report instead of crashing.
    print("No Dove/EMPA disagreements to display")
else:
    # squeeze=False always returns a 2-D axes array, including when n_rows == 1.
    fig, axes = plt.subplots(ncols=2, nrows=n_rows, figsize=(10, 5 * n_rows), squeeze=False)

    # Hide every axis up front so unused cells of the shorter column stay blank.
    for ax in axes.ravel():
        ax.axis("off")

    # Column 0: EMPA says closed; column 1: EMPA says not closed.
    columns = ((empa_closed, "close"), (empa_open, "open"))
    for col, (mismatches, title) in enumerate(columns):
        for i, (_, row) in enumerate(mismatches.iterrows()):
            axes[i][col].imshow(_render_tif(row.source_tif_dd))
            if i == 0:
                axes[i][col].set_title(title)

    fig.tight_layout()
    plt.show()

In [None]:
# from PIL import ImageFont, ImageDraw, Image

# inspect = dove_pairs[
#     ((
#         (dove_pairs.label_ss == "closed") &
#         (dove_pairs.label_dd != "closed")
#     ) |
#     (
#         (dove_pairs.label_ss != "closed") &
#         (dove_pairs.label_dd == "closed")
#     )) &
#     (dove_pairs.instrument_ss == "skysat")
# ]

# save_path = Path(BASE / "inspect_ss_dd")
# save_path.mkdir(exist_ok=True, parents=True)
# for _, row in inspect.iterrows():
#     region = row.region
#     year = row.year
#     month = row.month

#     with rasterio.open(row.source_tif_dd) as src:
#         data = src.read(out_dtype=np.float32)
#         nodata = src.read(1, masked=True).mask
#         if len(data) == 4:
#             img = false_color(data, nodata)
#         else:
#             img = broad_band(data, nodata)
#         dd_img = Image.fromarray(img)

#     with rasterio.open(row.source_tif_ss) as src:
#         data = src.read(out_dtype=np.float32)
#         nodata = src.read(1, masked=True).mask
#         if len(data) == 4:
#             img = false_color(data, nodata)
#         else:
#             img = broad_band(data, nodata)
#         ss_img = Image.fromarray(img)

#     dd_img = dd_img.resize(ss_img.size)

#     # You can load a TTF font; fallback to default
#     try:
#         font = ImageFont.truetype("arial.ttf", 72)
#     except IOError:
#         font = ImageFont.load_default(72)

#     # Top-right for first image
#     # Draw text (top-right corner of each)
#     draw = ImageDraw.Draw(ss_img)
#     text1 = "-".join(row.label_ss.split(" "))
#     w1 = draw.textlength(text1, font=font)
#     draw.text((ss_img.width - w1 - 5, 5), text1, fill="white", font=font)

#     # Top-right for second image
#     draw = ImageDraw.Draw(dd_img)
#     text2 = "-".join(row.label_dd.split(" "))
#     w2 = draw.textlength(text2, font=font)
#     draw.text((dd_img.width - w2 - 5, 5), text2, fill="white", font=font)

#     # Concatenate horizontally
#     total_width = ss_img.width + ss_img.width
#     max_height = max(ss_img.height, ss_img.height)
#     new_im = Image.new("RGB", (total_width, max_height))

#     new_im.paste(ss_img, (0, 0))
#     new_im.paste(dd_img, (ss_img.width, 0))

#     # Save
#     new_im.save(save_path / f"{region}_{year}-{month}_{text1}.png")

In [None]:
# Attach the matching EMPA site id (if any) to every site.
# ``matching_sites`` is loaded from JSON, so its keys are strings. The lookup
# therefore tries the index value as-is first and then its string form.
# NOTE(review): if the JSON actually maps EMPA id -> site code, the lookup should
# go through ``revmatching_sites`` instead — confirm the mapping direction.
gdf["empa_site_id"] = pd.Series(
    [matching_sites.get(code, matching_sites.get(str(code))) for code in gdf.index],
    index=gdf.index,
    dtype=object,  # keep ids and None as-is instead of letting pandas coerce them
)

# Per-region share of each Dove label (every row sums to 1).
label_pcts = d_df.groupby("region")["label"].value_counts(normalize=True).unstack(fill_value=0)
# Rename the index so it lines up with gdf's "Site code" index for the join.
label_pcts.index.name = "Site code"

labeled_gdf = gdf.join(label_pcts)

labeled_gdf.head(5)

In [None]:
# Sites that have both an EMPA site id and a station_nm value.
labeled_gdf[~labeled_gdf.empa_site_id.isna() & ~labeled_gdf.station_nm.isna()]

In [None]:
# Seed the site-level split lists; the following cells extend them.
TEST_SITES = [72]
VAL_SITES = []
TRAIN_SITES = []

In [None]:
# Sites that have an EMPA site id.
labeled_gdf[~labeled_gdf.empa_site_id.isna()]

In [None]:
# Hand-picked site codes added to each split.
# NOTE(review): extend() appends on every run, so re-running this cell without
# re-running the seeding cell adds the same codes again.
TEST_SITES.extend([11, 18, 48, 50])
VAL_SITES.extend([21, 25, 51, 2145])
TRAIN_SITES.extend([28, 43, 84, 2161, 2162, 2163])

In [None]:
# Sites that have a site_no value.
labeled_gdf[~labeled_gdf.site_no.isna()]

In [None]:
# More hand-picked site codes added to each split.
# NOTE(review): extend() appends on every run, so re-running this cell without
# re-running the seeding cell adds the same codes again.
TEST_SITES.extend([15, 27])
VAL_SITES.extend([16, 56, 77])
TRAIN_SITES.extend([17, 2147, 57])

In [None]:
# Number of validation sites chosen so far.
len(VAL_SITES)

In [None]:
# Every labelled Dove region not already assigned to a split goes to train.
# ``remainder`` is a set, so the order in which its sites are appended is not defined.
remainder = set(d_df.region.unique().tolist()) - set(TEST_SITES) - set(VAL_SITES) - set(TRAIN_SITES)
TRAIN_SITES.extend(remainder)

print(len(TRAIN_SITES))

In [None]:
# One row per region with three mutually exclusive membership flags.
# Rows are emitted train -> test -> val, then sorted by region.
split_groups = (
    (TRAIN_SITES, "is_train"),
    (TEST_SITES, "is_test"),
    (VAL_SITES, "is_val"),
)
splits = [
    {
        "region": site,
        "is_train": flag == "is_train",
        "is_test": flag == "is_test",
        "is_val": flag == "is_val",
    }
    for sites, flag in split_groups
    for site in sites
]
splits = pd.DataFrame(splits).sort_values("region")
splits

In [None]:
# splits.to_csv("/Volumes/x10pro/estuary/dataset/region_splits.csv", index=False)

In [None]:
# Row counts in dove_df per value of its ``dove`` column.
# NOTE(review): assumes dove/labels.csv has a ``dove`` column; the pairing cells
# use ``instrument`` for the Dove-side instrument — confirm which one is intended.
len(dove_df[dove_df.dove == "dove"]), len(dove_df[dove_df.dove == "superdove"])

In [None]:
# ``to_download`` is only created by the download-selection cell, whose code is
# currently commented out. Show a message instead of raising NameError on a
# fresh kernel; when the frame exists the (dove, superdove) counts are shown.
(
    (len(to_download[to_download.dove == "dove"]), len(to_download[to_download.dove == "superdove"]))
    if "to_download" in globals()
    else "to_download is not defined; no download counts to show"
)

In [None]:
# Every surface-reflectance AnalyticMS GeoTIFF under ca_all. Built from BASE,
# like the other inputs, instead of repeating the absolute volume path.
all_images = list((BASE / "ca_all").glob("*/results/*/*/*/files/*_AnalyticMS_SR*.tif"))
all_images[:10]

In [None]:
from estuary.model.data import parse_dt_from_pth

# One record per image. Region, month and year are read from the path's parent
# directories; the acquisition time is parsed from the path itself.
image_records = []
for img_path in all_images:
    acquired_at = parse_dt_from_pth(img_path)
    site = int(img_path.parents[1].name)
    if site in skipped_regions:
        continue
    image_records.append(
        dict(
            pth=str(img_path),
            acquired=acquired_at,
            year=int(img_path.parents[3].name),
            month=int(img_path.parents[2].name),
            day=acquired_at.day,
            region=site,
        )
    )

# Sort by time within each region, then keep the first image per region and day.
all_df = (
    pd.DataFrame(image_records)
    .sort_values(by=["region", "acquired"])
    .reset_index(drop=True)
    .drop_duplicates(["region", "year", "month", "day"])
)

print(len(all_df))

all_df.head(5)

In [None]:
# Histogram of the number of rows per region in all_df.
all_df.region.value_counts().hist()

In [None]:
# Rows per region, sorted ascending so head/tail give the extremes.
region_counts = all_df.region.value_counts().sort_values()

# Number of regions to show at each end.
N = 5
largest_counts = region_counts.tail(N)
smallest_counts = region_counts.head(N)

display(smallest_counts)
display(largest_counts)

In [None]:
# Days between consecutive rows within each region (NaN for each region's first row).
all_df["delta_days"] = all_df.groupby("region")["acquired"].diff().dt.total_seconds() / (24 * 3600)

# Median, 95th and 99th percentile of the gap, one row per region.
summary = (
    all_df.groupby("region")["delta_days"]
    .quantile([0.5, 0.95, 0.99])  # 0.5 = p50, 0.95 = p95, 0.99 = p99
    .unstack(level=-1)
    .rename(columns={0.5: "p50_days", 0.95: "p95_days", 0.99: "p99_days"})
)

# NOTE(review): writes to an absolute path in a user's home directory; this
# fails on machines where that directory does not exist.
summary.round(3).to_csv("/Users/kyledorman/data/estuary/display/ca_site_temporal.csv")

summary.round(1)

In [None]:
# Number of regions to list in each table.
N = 5
# Top N worst (largest) p95
worst_p95 = summary.sort_values("p95_days", ascending=False).head(N)

# Top N worst (largest) p99
worst_p99 = summary.sort_values("p99_days", ascending=False).head(N)

display(worst_p95.round(1))
display(worst_p99.round(1))

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Shared bin edges from 0 to the largest p99 (rounded up) so the three panels
# are directly comparable.
mv = np.ceil(summary["p99_days"].max())
bins = np.linspace(0, mv, int(mv) // 2 + 1)

fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)

# (column, bar colour, panel title) for each panel, left to right.
panels = (
    ("p50_days", "skyblue", "Distribution of p50 (days)"),
    ("p95_days", "salmon", "Distribution of p95 (days)"),
    ("p99_days", "teal", "Distribution of p99 (days)"),
)
for panel_ax, (column, colour, title) in zip(axes, panels):
    summary[column].hist(ax=panel_ax, bins=bins, color=colour, edgecolor="k")
    panel_ax.set_title(title)

# Axis labels are set on the first panel only.
axes[0].set_ylabel("Number of regions")
axes[0].set_xlabel("Days between acquisitions")

plt.tight_layout()
plt.savefig("/Users/kyledorman/data/estuary/display/ca_time_between.png")
plt.show()

In [None]:
import geopandas as gpd
from branca.colormap import linear

# Expose the site-code index as a "region" column so it can be merged on.
gdf["region"] = gdf.index

# 1. Join the per-region p95 gap onto the site table.
#    The inner join keeps only sites that have a summary row.
gdf_stats = gdf.merge(summary[["p95_days"]], on="region", how="inner")

# 2. Create the folium map, centred on the mean site coordinate.
m = folium.Map(
    location=[gdf_stats.geometry.y.mean(), gdf_stats.geometry.x.mean()],
    zoom_start=6,
    tiles="cartodbpositron",
    width=600,
    height=600,
)

# 3. Colour scale spanning the observed p95 values.
colormap = linear.YlOrRd_09.scale(gdf_stats["p95_days"].min(), gdf_stats["p95_days"].max())
colormap.caption = "p95 days between acquisitions"

# 4. One circle per site, coloured by its p95 gap; sites with a null p95 are skipped.
for _, row in gdf_stats.iterrows():
    if pd.notnull(row["p95_days"]):
        color = colormap(row["p95_days"])
        folium.CircleMarker(
            location=[row.geometry.y, row.geometry.x],
            radius=5,
            fill=True,
            fill_opacity=0.8,
            color=color,
            # The popup previously said "p99" although the value shown is p95_days.
            popup=f"Region: {row['region']}<br>p95: {row['p95_days']:.1f} days",
        ).add_to(m)

# 5. Add the colour-bar legend.
colormap.add_to(m)

m