In [None]:
%load_ext autoreload

from dotenv import load_dotenv

# Load environment variables from the .env file one directory up
# (the relative path is resolved against the kernel's working directory).
load_dotenv("../.env")

In [None]:
%autoreload 2

import json
from pathlib import Path

import geopandas as gpd
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd
import polars as pl

from estuary.model.data import parse_dt_from_pth

In [None]:
# Site codes listed in the skip file are removed from the site table.
skipped_csv = pd.read_csv("/Volumes/x10pro/estuary/geos/skipped_regions.csv")
skipped_regions = skipped_csv["Site code"].to_list()

# Load every site, drop the skipped ones, then index the table by site code.
sites_all = gpd.read_file("/Users/kyledorman/data/estuary/geos/ca_data_w_usgs.geojson")
keep_mask = ~sites_all["Site code"].isin(skipped_regions)
gdf = sites_all.loc[keep_mask].copy().set_index("Site code")
gdf.head()

In [None]:
# Mapping loaded from JSON; later cells use the keys as region codes and the
# values as EMPA site ids.
matching_sites = json.loads(
    Path("/Users/kyledorman/data/estuary/geos/ca_empa_matching_sites.json").read_text()
)

matching_sites

In [None]:
empa = pl.read_csv("/Volumes/x10pro/estuary/ca_all/empa/logger-raw-publish.csv")

# Two timestamp layouts are attempted (with and without fractional seconds).
# Parsing is non-strict, so a layout that does not match yields null and the
# first successful parse wins.
timestamp_formats = ["%d/%m/%Y %H:%M:%S", "%d/%m/%Y %H:%M:%S%.f"]
parsed_dt = pl.coalesce(
    [
        pl.col("samplecollectiontimestamp").str.strptime(pl.Datetime, fmt, strict=False)
        for fmt in timestamp_formats
    ]
)
empa = empa.with_columns(parsed_dt.alias("samplecollectiontimestamp_parsed"))

# Shift local readings to UTC using fixed offsets (PST = UTC-8, PDT = UTC-7).
# Rows labelled "UTC", and rows with any other or missing zone label, keep
# their parsed timestamp unchanged.
zone = pl.col("samplecollectiontimezone")
local_ts = pl.col("samplecollectiontimestamp_parsed")
utc_ts = (
    pl.when(zone == "PST")
    .then(local_ts + pl.duration(hours=8))
    .when(zone == "PDT")
    .then(local_ts + pl.duration(hours=7))
    .otherwise(local_ts)
)
empa = empa.with_columns(utc_ts.alias("samplecollectiontimestamp_utc2"))

# Keep only the logger sites that appear as a value in matching_sites.
empa = empa.filter(pl.col("siteid").is_in(list(matching_sites.values())))
empa.head()

In [None]:
# Reverse lookup: for every EMPA site id, the first region key that maps to it.
region_for_site = {}
for region_key, empa_site in matching_sites.items():
    region_for_site.setdefault(empa_site, region_key)

site_ids = empa["siteid"].unique().sort()

# Print each logger site next to its region code and the region's display name.
for si in site_ids:
    region = region_for_site[si]
    name = gdf.loc[int(region), "Site name"]
    print(si, region, name)

In [None]:
# Number of logger rows for every (site, sensor) pair, ordered by site then sensor.
counts = (
    empa.group_by("siteid", "sensorid")
    .agg(n_samples=pl.len())
    .sort("siteid", "sensorid")
)

counts

In [None]:
# Earliest and latest parsed (unshifted) timestamp for every (site, sensor) pair.
ranges = (
    empa.group_by("siteid", "sensorid")
    .agg(
        start=pl.col("samplecollectiontimestamp_parsed").min(),
        end=pl.col("samplecollectiontimestamp_parsed").max(),
    )
    .sort("siteid", "sensorid")
)
print(ranges)

# One row per sensor, labelled "<siteid> - <sensorid>".
pdf = ranges.to_pandas()
pdf = pdf.assign(label=pdf["siteid"] + " - " + pdf["sensorid"])

# Draw one horizontal bar per sensor, spanning its first to last sample.
fig, ax = plt.subplots(figsize=(8, 10))
for y_pos, (start, end, label) in enumerate(zip(pdf["start"], pdf["end"], pdf["label"])):
    ax.plot([start, end], [y_pos, y_pos], lw=6, label=label)

ax.set_yticks(range(len(pdf)))
ax.set_yticklabels(pdf["label"])
ax.set_xlabel("Date")
ax.set_title("Sensor Availability Timeline")
plt.tight_layout()
plt.show()

In [None]:
corr = pl.read_csv("/Volumes/x10pro/estuary/ca_all/empa/logger-raw-depth-correction-publish.csv")

# Try both timestamp layouts (with and without fractional seconds); a layout
# that does not match yields null and the first successful parse is kept.
corr_formats = ("%d/%m/%Y %H:%M:%S", "%d/%m/%Y %H:%M:%S%.f")
corr_parsed = pl.coalesce(
    [
        pl.col("samplecollectiontimestamp").str.strptime(pl.Datetime, fmt, strict=False)
        for fmt in corr_formats
    ]
)

# Attach the parsed timestamp and keep only sites present in matching_sites.
corr = corr.with_columns(corr_parsed.alias("samplecollectiontimestamp_parsed")).filter(
    pl.col("siteid").is_in(list(matching_sites.values()))
)

corr

In [None]:
# First and last parsed timestamp per (site, sensor) in the depth-correction
# table, keeping only sensors whose coverage spans more than 90 days.
ranges = (
    corr.group_by(["siteid", "sensorid"])
    .agg(
        [
            pl.col("samplecollectiontimestamp_parsed").min().alias("start"),
            pl.col("samplecollectiontimestamp_parsed").max().alias("end"),
        ]
    )
    # Filter groups where (end - start) > 90 days
    .filter((pl.col("end") - pl.col("start")) > pl.duration(days=90))
    .sort(["siteid", "sensorid"])
)

pdf = ranges.to_pandas()
# A stable sort keeps the sensor order within each site. The index is reset
# because sort_values keeps the old labels: tick labels are placed by position,
# so the bars must be placed by position too.
pdf = pdf.sort_values(by="siteid", kind="stable").reset_index(drop=True)

# Combine siteid & sensorid as label
pdf["label"] = pdf["siteid"] + " - " + pdf["sensorid"]

fig, ax = plt.subplots(figsize=(8, 10))
# y_pos is the row position, which is what set_yticks/set_yticklabels use below.
for y_pos, row in enumerate(pdf.itertuples(index=False)):
    ax.plot([row.start, row.end], [y_pos, y_pos], lw=6, label=row.label)

ax.set_yticks(range(len(pdf)))
ax.set_yticklabels(pdf["label"])
ax.set_xlabel("Date")
ax.set_title("Sensor Availability Timeline")
plt.tight_layout()
plt.show()

In [None]:
preds_path = Path("/Volumes/x10pro/estuary/ca_all/") / "preds.csv"
preds_all = pd.read_csv(preds_path)

# The acquisition time of each prediction is derived from its source_tif path.
preds_all["acquired"] = preds_all["source_tif"].map(lambda tif: parse_dt_from_pth(Path(tif)))
preds_all

In [None]:
# Raw temperature / conductivity readings for sensor X1683 at site SC-VEN,
# keyed by the offset-shifted timestamp.
a = empa.filter((pl.col("siteid") == "SC-VEN") & (pl.col("sensorid") == "X1683"))[
    ["samplecollectiontimestamp_utc2", "raw_h2otemp", "raw_conductivity"]
].to_pandas()
# Convert whole columns at once instead of calling pd.to_numeric per element.
a["raw_h2otemp"] = pd.to_numeric(a["raw_h2otemp"])
a["raw_conductivity"] = pd.to_numeric(a["raw_conductivity"])
a = a.rename(columns={"samplecollectiontimestamp_utc2": "acquired"})

# Corrected depth for the same sensor, keyed by the parsed (unshifted) timestamp.
b = corr.filter((pl.col("siteid") == "SC-VEN") & (pl.col("sensorid") == "X1683"))[
    ["samplecollectiontimestamp_parsed", "corrected_depth"]
].to_pandas()
b["corrected_depth"] = pd.to_numeric(b["corrected_depth"])

b = b.rename(columns={"samplecollectiontimestamp_parsed": "acquired"})

# NOTE(review): `a` is keyed by the offset-shifted "_utc2" timestamp while `b`
# is keyed by the unshifted "_parsed" one. Confirm both are in the same clock,
# otherwise rows will not line up in this merge.
# Count merged rows that have no raw conductivity value.
pd.merge(a, b, on="acquired", how="outer").raw_conductivity.isna().sum()

In [None]:
# Per merged row: True when any column is missing after the outer join.
a.merge(b, how="outer", on="acquired").isna().any(axis=1)

In [None]:
def _find_gaps(
    df_time_sorted,
    time_col="acquired",
    min_gap=pd.Timedelta(days=2.5),
    edge_buffer=pd.Timedelta(hours=6),
):
    """
    Return a list of (gap_start, gap_end) where time delta between consecutive
    samples exceeds `min_gap`. Each interval is trimmed by `edge_buffer` on both ends.
    """
    t = df_time_sorted[time_col].values
    if len(t) < 2:
        return []

    # compute diffs
    ts = df_time_sorted[time_col].reset_index(drop=True)
    d = ts.diff()

    gaps = []
    for i in range(1, len(ts)):
        if d.iloc[i] > min_gap:
            start_raw = ts.iloc[i - 1]
            end_raw = ts.iloc[i]

            # dynamic safety: don’t over-trim for short-but-qualifying gaps
            # use the smaller of requested buffer and 10% of gap
            dyn_buf = min(edge_buffer, (end_raw - start_raw) / 10)

            start = start_raw + dyn_buf
            end = end_raw - dyn_buf
            if start < end:
                gaps.append((start, end))
    return gaps


def plot_metric(sdf, col, site, save=False):
    """
    Plot one logger metric over time with state-change markers overlaid.

    Two figures are drawn from the same metric series: one marked with the
    change points of `y_true`, one with the change points of `y_pred`. Both
    figures shade the same gap bands.

    Args:
        sdf: frame with an "acquired" time column and the metric column `col`.
        col: name of the metric column in `sdf` to plot.
        site: value compared against `preds_all.region` to select predictions.
        save: when True, each figure is also written as a PNG.

    Reads the notebook-level `preds_all` frame (columns used: "acquired",
    "region", "y_true", "y_pred").
    """

    def _as_naive_utc(values):
        # Parse as UTC (unparseable entries become NaT), then drop the tz info.
        return (
            pd.to_datetime(values, errors="coerce", utc=True)
            .dt.tz_convert("UTC")
            .dt.tz_localize(None)
            .astype("datetime64[ns]")
        )

    def _state_changes(frame, state_col, include_first=True):
        # Rows where the integer state differs from the previous row.
        states = frame[state_col].astype("int64")
        flipped = states.ne(states.shift(1))
        if not include_first and len(flipped):
            flipped.iloc[0] = False
        picked = frame.loc[flipped, ["acquired", state_col]]
        return picked.rename(columns={state_col: "new_state"})

    def _attach_metric(changes):
        # Nearest-in-time metric row for every change point.
        return pd.merge_asof(
            changes.sort_values("acquired"),
            metric.sort_values("acquired"),
            on="acquired",
            direction="nearest",
        )

    # Predictions for the requested site, time-normalised and sorted.
    site_preds = preds_all[preds_all.region == site].copy()
    site_preds["acquired"] = _as_naive_utc(site_preds["acquired"])
    site_preds = site_preds.sort_values("acquired").dropna(subset=["acquired"])

    # Metric series: time-normalised, sorted, one row per timestamp.
    metric = sdf.copy()
    metric["acquired"] = _as_naive_utc(metric["acquired"])
    metric = (
        metric.sort_values("acquired")
        .dropna(subset=["acquired"])
        .drop_duplicates(subset="acquired", keep="first")
    )

    # Restrict predictions to the time span covered by the metric series.
    if not metric.empty:
        first_seen, last_seen = metric.acquired.min(), metric.acquired.max()
        site_preds = site_preds[site_preds.acquired.between(first_seen, last_seen)]

    true_changes = _state_changes(site_preds, "y_true", include_first=True)
    pred_changes = _state_changes(site_preds, "y_pred", include_first=True)
    true_cp = _attach_metric(true_changes)
    pred_cp = _attach_metric(pred_changes)

    # Gap bands are computed once, from the prediction timestamps, and shared
    # by both figures.
    gaps = _find_gaps(
        site_preds[["acquired"]].sort_values("acquired"),
        time_col="acquired",
        min_gap=pd.Timedelta(days=4.5),
        edge_buffer=pd.Timedelta(hours=12),
    )

    def _draw(change_points, title, out_path):
        fig, ax = plt.subplots(figsize=(11, 4))
        ax.plot(metric["acquired"], metric[col], lw=1.5, color="k")
        ax.set_title(title)
        ax.set_xlabel("Time")
        ax.set_ylabel(col)

        # Filled blue bands mark the gaps.
        for gap_start, gap_end in gaps:
            ax.axvspan(gap_start, gap_end, color="blue", alpha=0.15, linewidth=0)

        # Green marks a change into state 1, red a change into any other state.
        for _, change in change_points.iterrows():
            shade = "green" if int(change["new_state"]) == 1 else "red"
            ax.axvline(change["acquired"], linestyle="--", color=shade, alpha=0.8, linewidth=1.25)
            ax.scatter(change["acquired"], change[col], color=shade, s=35, zorder=3)

        locator = mdates.AutoDateLocator()
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
        fig.tight_layout()
        if save:
            # The figure created above is still the current one here.
            plt.savefig(out_path, dpi=200)

    # NOTE(review): output names always start with "malibu_", whatever `site` is.
    _draw(
        true_cp,
        f"{col} with LABEL State Changes (y_true)",
        f"/Users/kyledorman/data/estuary/display/malibu_{col}_label.png",
    )
    _draw(
        pred_cp,
        f"{col} with PREDICTED State Changes (y_pred)",
        f"/Users/kyledorman/data/estuary/display/malibu_{col}_pred.png",
    )

    plt.show()

In [None]:
# Raw logger channels that are converted to numeric below.
raw_metric_cols = [
    "raw_depth",
    "raw_pressure",
    "raw_h2otemp",
    "raw_ph",
    "raw_conductivity",
    "raw_turbidity",
    "raw_do",
    "raw_salinity",
]

# All raw channels for sensor X1683 at site SC-VEN, with the parsed timestamp
# exposed as "acquired".
aaa = (
    empa.filter((pl.col("siteid") == "SC-VEN") & (pl.col("sensorid") == "X1683"))[
        raw_metric_cols + ["sensorid", "sensortype", "samplecollectiontimestamp_parsed"]
    ]
    .to_pandas()
    .rename(columns={"samplecollectiontimestamp_parsed": "acquired"})
)

# Make sure the "acquired" column is datetime-typed, then order rows by time.
# NOTE(review): this is the unshifted "_parsed" timestamp, and plot_metric
# treats naive times as UTC. Confirm whether "_utc2" was intended here.
aaa["acquired"] = pd.to_datetime(aaa.acquired)
aaa = aaa.sort_values(["acquired"])

# aaa = aaa[aaa["acquired"] > pd.Timestamp("2021-09-15")].copy()

# Convert each channel as a whole column instead of calling pd.to_numeric per element.
for metric_col in raw_metric_cols:
    aaa[metric_col] = pd.to_numeric(aaa[metric_col])

# Clamp pH into the range [6, 100].
aaa["raw_ph"] = aaa["raw_ph"].clip(6, 100)

aaa

In [None]:
# Plot raw pressure for region 21 without writing the figures to disk.
plot_metric(sdf=aaa, col="raw_pressure", site=21, save=False)

In [None]:
# Sample counts for the SC-VEN sensors; show the 11 rows with the fewest samples.
cc = counts.filter(pl.col("siteid").eq("SC-VEN")).to_pandas()
cc.sort_values(by="n_samples").iloc[:11]