# Tidal Analysis

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import folium
from pathlib import Path
from matplotlib.cm import ScalarMappable
from shapely.geometry import Polygon, MultiPolygon
import json
from random import shuffle

import warnings
# Silence every warning for the rest of the notebook session.
# NOTE(review): this also hides warnings that may point at real problems
# (e.g. pandas chained-assignment warnings, or geopandas warnings about
# calling `.area` on geometries in a geographic CRS) — consider narrowing
# the filter to specific categories/modules.
warnings.filterwarnings("ignore")           # hide every warning

import cartopy.crs as ccrs
from matplotlib import cm, colors, ticker
from typing import Optional, Tuple

from src.gen_points_map import compute_step, make_equal_area_hex_grid

In [None]:
# Root directory holding the 30 km point-grid inputs (grids, CSVs).
BASE = Path("/Users/kyledorman/data/planet_coverage/points_30km/")
# Figure output directory next to BASE; created (with parents) if missing.
FIG_DIR = BASE.parent.joinpath("figs", "simulated_tidal")
FIG_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
def fill_small_holes(geom, area_thresh):
    """Return ``geom`` with interior rings (holes) smaller than ``area_thresh``
    removed, i.e. filled in.

    Parameters
    ----------
    geom : shapely geometry or None
        Only ``Polygon`` and ``MultiPolygon`` inputs are modified. Any other
        geometry type, an empty geometry, or a missing value (``None``) is
        returned unchanged.
    area_thresh : float
        Holes whose planar ``.area`` is strictly below this value are dropped.
        The area is measured in the (squared) coordinate units of ``geom``.

    Returns
    -------
    The cleaned geometry, or ``geom`` itself when there is nothing to do.
    """
    # A GeoSeries may contain missing geometries as None; pass them through
    # instead of failing on ``None.is_empty``.
    if geom is None or geom.is_empty or geom.geom_type not in {"Polygon", "MultiPolygon"}:
        return geom  # nothing to do

    def _fill(poly: Polygon) -> Polygon:
        """Rebuild ``poly`` keeping its shell and only holes >= the threshold."""
        holes_to_keep = [ring for ring in poly.interiors
                         if Polygon(ring).area >= area_thresh]
        return Polygon(poly.exterior, holes_to_keep)

    if geom.geom_type == "Polygon":
        return _fill(geom)
    # MultiPolygon: clean each part independently.
    return MultiPolygon([_fill(p) for p in geom.geoms])

def clean_and_simplify(gdf: gpd.GeoDataFrame,
                       hole_area_thresh: float,
                       simplify_tolerance: float) -> gpd.GeoDataFrame:
    """Return a copy of ``gdf`` with small holes filled, then simplified.

    Holes whose area is below ``hole_area_thresh`` are removed via
    ``fill_small_holes``. The resulting geometries are then simplified with
    tolerance ``simplify_tolerance`` and ``preserve_topology=True``. Both
    values are expressed in the units of ``gdf``'s CRS (squared for the area).
    The input frame is not modified.
    """
    out = gdf.copy()
    filled = out.geometry.apply(fill_small_holes, area_thresh=hole_area_thresh)
    out.geometry = filled.simplify(simplify_tolerance, preserve_topology=True)
    return out

def assign_intersection_id(gdf, other_gdf, left_key, right_key, inter_crs, sinu_crs):
    """Tag each row of ``gdf`` with the ``right_key`` of the ``other_gdf``
    polygon that covers the largest fraction of it.

    Parameters
    ----------
    gdf : GeoDataFrame
        Polygons to label. Must contain a ``left_key`` column; it is used as a
        join index, so it should uniquely identify rows.
    other_gdf : GeoDataFrame
        Candidate regions; must contain a ``right_key`` column.
    left_key, right_key : str
        Identifier column names in ``gdf`` / ``other_gdf``.
    inter_crs :
        CRS in which the polygon overlay (intersection) is computed.
    sinu_crs :
        CRS in which all areas are measured (presumably equal-area; confirm).

    Returns
    -------
    GeoDataFrame
        Copy of ``gdf`` in its original CRS. It gains an integer ``right_key``
        column (``-1`` where no ``other_gdf`` polygon overlaps) and a
        ``poly_area`` column (area in ``sinu_crs`` units).
    """
    gdf = gdf.copy()
    orig_crs = gdf.crs
    # Measure each polygon in the same CRS used for the overlap areas below,
    # so overlap_pct is a true fraction regardless of gdf's own CRS.
    gdf["poly_area"] = gdf.to_crs(sinu_crs).geometry.area
    gdf = gdf.to_crs(inter_crs)
    other_gdf = other_gdf.to_crs(inter_crs)

    # Intersect every gdf polygon with every other_gdf polygon it touches.
    inter_df = gdf[[left_key, "geometry"]].overlay(
        other_gdf[[right_key, "geometry"]], how="intersection"
    )
    inter_df = inter_df.set_index(left_key).join(
        gdf.set_index(left_key)[["poly_area"]], how="left"
    )
    inter_df = inter_df.to_crs(sinu_crs)
    inter_df["overlap_pct"] = inter_df.geometry.area / inter_df.poly_area

    # For each left_key keep only the pairing with the largest overlap.
    inter_df = (
        inter_df.reset_index()
        .sort_values(by=[left_key, "overlap_pct"], ascending=False)
        .drop_duplicates(subset=left_key)
        .set_index(left_key)
    )

    gdf = gdf.set_index(left_key).join(inter_df[[right_key]], how="left")
    # Rows without any overlap get the sentinel id -1.
    gdf[right_key] = gdf[right_key].fillna(-1).astype(int)
    gdf = gdf.to_crs(orig_crs)

    return gdf.reset_index()
    

In [None]:
display_crs = "EPSG:4326"

# Ocean grid cells. Their CRS (named sinu_crs, presumably sinusoidal /
# equal-area) is used for every area measurement below. Read the file once.
query_df = gpd.read_file(BASE / "ocean_grids.gpkg")
sinu_crs = query_df.crs

# Marine ecoregions: fill holes < 1e10 and simplify with tolerance 1e4, both in
# sinu_crs units. Areas are compared in sinu_crs: `.area` on lon/lat degrees
# (EPSG:4326) is not a real area.
ecoregions = gpd.read_file(BASE.parent / "shorelines" / "marine_ecoregions").to_crs(display_crs)
ecoregions["eco_id"] = ecoregions.index
ecoregions = ecoregions.to_crs(sinu_crs)
orig_area = ecoregions.geometry.area
ecoregions = clean_and_simplify(ecoregions, hole_area_thresh=1e10, simplify_tolerance=1e4)
new_area = ecoregions.geometry.area
ecoregions = ecoregions.to_crs(display_crs)
# Keep ecoregions whose cleaned area is less than twice the original.
# NOTE(review): intent of the 0.5 ratio is inferred from the expression — confirm.
ecoregions = ecoregions[orig_area / new_area > 0.5]

# Hex grid, re-expressed in lon/lat.
# NOTE(review): ESRI:54030 is World Robinson, which is not an equal-area
# projection — confirm it is the intended CRS for make_equal_area_hex_grid.
cell_size_m = compute_step(1.5)
_, hex_grid = make_equal_area_hex_grid(cell_size_m, "ESRI:54030")
hex_grid = hex_grid.to_crs(display_crs)
hex_grid["hex_id"] = hex_grid.cell_id

# Label each ocean grid cell with the ecoregion / hex it overlaps most (-1 = none).
query_df = assign_intersection_id(query_df, ecoregions, "cell_id", "eco_id", display_crs, sinu_crs)
query_df = assign_intersection_id(query_df, hex_grid, "cell_id", "hex_id", display_crs, sinu_crs)

# Set plot crs
query_df = query_df.to_crs(display_crs)

# Load tidal data
tide_df = pd.read_csv(BASE / "simulated_tidal_coverage.csv").set_index("cell_id")
heuristics_df = pd.read_csv(BASE / "simulated_tidal_coverage_heuristics.csv").set_index("cell_id")

# Mark null values as full year (365.0)
tide_df = tide_df.fillna(365.0)

# Merge all dataframes; the inner join keeps only cells present in both
# the tide table and the ocean grid.
query_df = query_df.set_index("cell_id")
tide_heuristics_grid_df = tide_df.join(heuristics_df).join(
    query_df[["geometry", "eco_id", "hex_id"]], how="inner"
)
tide_heuristics_grid_df = gpd.GeoDataFrame(tide_heuristics_grid_df, geometry="geometry", crs=display_crs)
tide_heuristics_grid_df["plot_id"] = tide_heuristics_grid_df.index.astype(str)

In [None]:
# Preview the first three merged rows.
tide_heuristics_grid_df.iloc[:3]

In [None]:
# Preview the first five heuristics rows.
heuristics_df.iloc[:5]

In [None]:
# Draw every merged tide grid cell in a random colour.
gdf = tide_heuristics_grid_df[["geometry"]].copy()

# Shuffled ids, presumably so neighbouring cells rarely share a colour.
ids = list(range(len(gdf)))
shuffle(ids)
gdf["id"] = ids

n_ids  = gdf["id"].nunique()
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap`
# works on both old and new versions. tab20 has only 20 distinct colours,
# so they repeat when n_ids > 20.
base_cmap = plt.get_cmap("tab20", n_ids)
cmap      = colors.ListedColormap(base_cmap(range(n_ids)))
norm      = colors.BoundaryNorm(range(n_ids + 1), n_ids)

# Pick any Cartopy projection
proj = ccrs.Robinson()           # or ccrs.Mollweide(), ccrs.PlateCarree(), …

fig = plt.figure(figsize=(12, 6))
ax  = plt.axes(projection=proj)
ax.set_global()

# Re-project your data on the fly with `transform`
gdf.plot(
    column="id",
    ax=ax,
    cmap=cmap,
    norm=norm,
    linewidth=0.15,
    edgecolor="black",
    transform=ccrs.PlateCarree(),   # <- incoming lon/lat coords
)

plt.title("Coastal Tide Grids", pad=12)
plt.tight_layout()
plt.savefig("/Users/kyledorman/Desktop/tidal_area.png")
plt.show()

In [None]:
# Draw every (cleaned) marine ecoregion in a random colour.
gdf = ecoregions.copy()

# Shuffled ids, presumably so neighbouring regions rarely share a colour.
ids = list(range(len(gdf)))
shuffle(ids)
gdf["id"] = ids

n_ids  = gdf["id"].nunique()
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap`
# works on both old and new versions. tab20 has only 20 distinct colours,
# so they repeat when n_ids > 20.
base_cmap = plt.get_cmap("tab20", n_ids)
cmap      = colors.ListedColormap(base_cmap(range(n_ids)))
norm      = colors.BoundaryNorm(range(n_ids + 1), n_ids)

# Pick any Cartopy projection
proj = ccrs.Robinson()           # or ccrs.Mollweide(), ccrs.PlateCarree(), …

fig = plt.figure(figsize=(12, 6))
ax  = plt.axes(projection=proj)
ax.set_global()

# Re-project your data on the fly with `transform`
gdf.plot(
    column="id",
    ax=ax,
    cmap=cmap,
    norm=norm,
    linewidth=0.15,
    edgecolor="black",
    transform=ccrs.PlateCarree(),   # <- incoming lon/lat coords
)

plt.title("Global Ecoregions", pad=12)
plt.tight_layout()
plt.savefig("/Users/kyledorman/Desktop/eco_regions.png")
plt.show()

In [None]:
# Hex cells that contain at least one merged tide grid cell. `.copy()` so the
# "id" column below is added to an independent frame rather than to a
# filtered slice of hex_grid (avoids SettingWithCopyWarning).
gdf = hex_grid[hex_grid.hex_id.isin(tide_heuristics_grid_df.hex_id.unique())].copy()

# Shuffled ids, presumably so neighbouring hexes rarely share a colour.
ids = list(range(len(gdf)))
shuffle(ids)
gdf["id"] = ids

n_ids  = gdf["id"].nunique()
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap`
# works on both old and new versions. tab20 has only 20 distinct colours,
# so they repeat when n_ids > 20.
base_cmap = plt.get_cmap("tab20", n_ids)
cmap      = colors.ListedColormap(base_cmap(range(n_ids)))
norm      = colors.BoundaryNorm(range(n_ids + 1), n_ids)

# Pick any Cartopy projection
proj = ccrs.Robinson()           # or ccrs.Mollweide(), ccrs.PlateCarree(), …

fig = plt.figure(figsize=(12, 6))
ax  = plt.axes(projection=proj)
ax.set_global()

# Re-project your data on the fly with `transform`
gdf.plot(
    column="id",
    ax=ax,
    cmap=cmap,
    norm=norm,
    linewidth=0.15,
    edgecolor="black",
    transform=ccrs.PlateCarree(),   # <- incoming lon/lat coords
)

plt.title("Hex Coastal Area", pad=12)
plt.tight_layout()
plt.savefig("/Users/kyledorman/Desktop/hex_coastal_area.png")
plt.show()

In [None]:
# import polars as pl
# from src.query_udms import DataFrameRow

# all_lazy = pl.scan_parquet(
#     str(BASE / "*/results/*/*/*/*/data.parquet"),
#     schema=DataFrameRow.polars_schema(),
# )
# valid_cell_ids = all_lazy.select(pl.col("cell_id").unique().sort()).collect().to_series().to_list()
# print(len(valid_cell_ids))
# gdf = gpd.read_file(BASE / "ocean_grids.gpkg") # tide_heuristics_grid_df.reset_index()[["geometry"]].copy()
# gdf = gdf.set_index("cell_id").loc[valid_cell_ids].reset_index()

# ids = list(range(len(gdf)))
# shuffle(ids)
# gdf["id"] = ids

# n_ids  = gdf.id.nunique()
# base_cmap = cm.get_cmap("tab20", n_ids)  # up to 20 unique colours
# cmap      = colors.ListedColormap(base_cmap(range(n_ids)))
# norm      = colors.BoundaryNorm(range(n_ids + 1), n_ids)

# # Pick any Cartopy projection
# proj = ccrs.Robinson()           # or ccrs.Mollweide(), ccrs.Robinson(), …

# fig = plt.figure(figsize=(12, 6))
# ax  = plt.axes(projection=proj)
# ax.set_global()

# # Re-project your data on the fly with `transform`
# gdf.plot(
#     column="id",
#     ax=ax,
#     cmap=cmap,
#     norm=norm,
#     linewidth=0.15,
#     edgecolor="black",
#     transform=ccrs.Sinusoidal(),   # <- incoming lon/lat coords
# )

# plt.title("Observed Ocean Grids", pad=12)
# plt.tight_layout()
# plt.savefig("/Users/kyledorman/Desktop/observed_ocean_grids.png")
# plt.show()

In [None]:
# Draw every coastal strip polygon in a random colour.
gdf = gpd.read_file(BASE / "coastal_strips.gpkg")

# Shuffled ids, presumably so neighbouring strips rarely share a colour.
ids = list(range(len(gdf)))
shuffle(ids)
gdf["id"] = ids

n_ids  = gdf["id"].nunique()
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap`
# works on both old and new versions. tab20 has only 20 distinct colours,
# so they repeat when n_ids > 20.
base_cmap = plt.get_cmap("tab20", n_ids)
cmap      = colors.ListedColormap(base_cmap(range(n_ids)))
norm      = colors.BoundaryNorm(range(n_ids + 1), n_ids)

# Pick any Cartopy projection
proj = ccrs.Robinson()           # or ccrs.Mollweide(), ccrs.PlateCarree(), …

fig = plt.figure(figsize=(12, 6))
ax  = plt.axes(projection=proj)
ax.set_global()

# Re-project your data on the fly with `transform`
gdf.plot(
    column="id",
    ax=ax,
    cmap=cmap,
    norm=norm,
    linewidth=0.15,
    edgecolor="black",
    # NOTE(review): coordinates are assumed to be in cartopy's default
    # Sinusoidal projection (not lon/lat) — confirm against gdf.crs.
    transform=ccrs.Sinusoidal(),
)

plt.title("Coastal Area", pad=12)
plt.tight_layout()
plt.savefig("/Users/kyledorman/Desktop/coastal_area.png")
plt.show()

In [None]:
def plot_gdf_column(
    gdf: gpd.GeoDataFrame,
    column: str,
    *,
    projection: ccrs.CRS = ccrs.Robinson(),
    cmap: str = "viridis",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    scale: str = "linear",               # "linear"  or  "log"
    figsize: Tuple[int, int] = (12, 6),
    edgecolor: str = "black",
    linewidth: float = 0.15,
    show_coastlines: bool = False,
    show_grid: bool = False,
    title: Optional[str] = None,
    save_path: str | None = None,
) -> None:
    """
    Plot a numeric column from a GeoDataFrame on a Cartopy map with a colour bar.

    Parameters
    ----------
    gdf : GeoDataFrame
        Geometries in EPSG:4326 (lon/lat degrees).
    column : str
        Name of the column to colour by; it is cast to float.
    projection : cartopy CRS
        Projection of the map axes (default Robinson).
    cmap : str
        Matplotlib colormap name.
    vmin, vmax : float, optional
        Colour limits. They default to the data min/max. For ``scale="log"``,
        the default ``vmin`` is the smallest positive value.
    scale : {"linear", "log"}
        Colour normalisation. "log" uses a base-10 LogNorm and
        *requires positive values*.
    figsize : (int, int)
        Matplotlib figure size.
    edgecolor : str
        Accepted for API compatibility but currently NOT forwarded to the plot.
    linewidth : float
        Polygon outline width.
    show_coastlines, show_grid : bool
        Overlay 110m coastlines / gridlines.
    title : str, optional
        Axes title.
    save_path : str, optional
        If given, the figure is saved here before being shown.

    Raises
    ------
    ValueError
        If ``gdf`` is not in EPSG:4326, ``scale`` is not "linear"/"log", or
        ``scale="log"`` and the column contains non-positive values.
    KeyError
        If ``column`` is not in ``gdf``.
    """
    # ------------------------------------------------------------------
    # Basic checks
    # ------------------------------------------------------------------
    if gdf.crs is None or gdf.crs.to_epsg() != 4326:
        raise ValueError("GeoDataFrame must be in EPSG:4326 (lon/lat degrees)")
    if column not in gdf.columns:
        raise KeyError(f"{column!r} not found in GeoDataFrame")
    if scale not in ("linear", "log"):
        raise ValueError(f"scale must be 'linear' or 'log', got {scale!r}")

    data = gdf[column].astype(float)

    # ------------------------------------------------------------------
    # Determine colour range & normalisation
    # ------------------------------------------------------------------
    if scale == "log":
        # Validate first: LogNorm cannot represent values <= 0.
        if (data <= 0).any():
            raise ValueError("Log scale selected but column contains non-positive values.")
        if vmin is None:
            vmin = data[data > 0].min()
        if vmax is None:
            vmax = data.max()
        norm = colors.LogNorm(vmin=vmin, vmax=vmax)
        # Tick locator/formatter for clean, linear-value ticks
        formatter = ticker.FuncFormatter(lambda y, _: f"{y:g}")
        locator = ticker.LogLocator(base=10, numticks=10)
    else:  # linear
        if vmin is None:
            vmin = data.min()
        if vmax is None:
            vmax = data.max()
        norm = colors.Normalize(vmin=vmin, vmax=vmax)
        formatter = ticker.ScalarFormatter()
        locator = ticker.MaxNLocator(nbins=6)

    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap`
    # works on both old and new versions.
    cmap = plt.get_cmap(cmap)

    # ------------------------------------------------------------------
    # Plot
    # ------------------------------------------------------------------
    fig = plt.figure(figsize=figsize)
    ax = plt.axes(projection=projection)
    ax.set_global()

    gdf.plot(
        column=column,
        cmap=cmap,
        norm=norm,
        ax=ax,
        transform=ccrs.PlateCarree(),  # geometries are lon/lat (checked above)
        # edgecolor is intentionally not forwarded (see docstring).
        linewidth=linewidth,
    )

    if show_coastlines:
        ax.coastlines(resolution="110m", linewidth=0.3)
    if show_grid:
        ax.gridlines(draw_labels=False, linewidth=0.2)

    # Colour bar with human-readable ticks
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation="vertical",
                        shrink=0.65, pad=0.02, format=formatter)
    cbar.locator = locator
    cbar.update_ticks()
    cbar.set_label(column)

    if title:
        ax.set_title(title, pad=12)

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()

In [None]:
# Number of grid cells per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
# NOTE(review): ecoregion counts use query_df (all ocean grid cells), while the
# hex counts below use tide_heuristics_grid_df (only cells in the tide CSV)
# — confirm this asymmetry is intended.
df = (
    query_df.groupby("eco_id")
    .size()
    .rename("grid_count")
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(gdf, "grid_count", title="Grid Count Ecoregion")

# Number of merged tide grid cells per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")
    .size()
    .rename("grid_count")
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(gdf, "grid_count", title="Grid Count Hex")


In [None]:
key = "planet_observed_high_tide_offset"

# Per-grid-cell values.
plot_gdf_column(tide_heuristics_grid_df, key, title=key)

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    # save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    # save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
key = "planet_observed_low_tide_offset"

# Per-grid-cell values.
plot_gdf_column(
    tide_heuristics_grid_df,
    key,
    title=key,
    save_path=f"/Users/kyledorman/Desktop/{key}.png",
)

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
key = "planet_observed_spread"

# Per-grid-cell values.
plot_gdf_column(tide_heuristics_grid_df, key, title=key)

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    # save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    # save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
key = "planet_high_days_between_p95"

# Per-grid-cell values on a log colour scale (values must be > 0).
plot_gdf_column(tide_heuristics_grid_df, key, title=key, scale="log")

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    scale="log",
    # save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    scale="log",
    # save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
key = "planet_low_days_between_p95"

# Per-grid-cell values on a log colour scale (values must be > 0).
plot_gdf_column(tide_heuristics_grid_df, key, title=key, scale="log")

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    scale="log",
    # save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    scale="log",
    # save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
key = "planet_observed_low_tide_offset_rel"

# Per-grid-cell values.
plot_gdf_column(tide_heuristics_grid_df, key, title=key)

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    # save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    # save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
key = "planet_observed_high_tide_offset_rel"

# Per-grid-cell values.
plot_gdf_column(tide_heuristics_grid_df, key, title=key)

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    # save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    # save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
key = "planet_low_count"

# Per-grid-cell values.
plot_gdf_column(tide_heuristics_grid_df, key, title=key)

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    # save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    # save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
key = "planet_high_count"

# Per-grid-cell values.
plot_gdf_column(tide_heuristics_grid_df, key, title=key)

# Median per ecoregion (eco_id -1 = no overlapping ecoregion, dropped).
df = (
    tide_heuristics_grid_df.groupby("eco_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(ecoregions[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} ecoregions",
    # save_path=f"/Users/kyledorman/Desktop/{key}_eco.png",
)

# Median per hex cell.
df = (
    tide_heuristics_grid_df.groupby("hex_id")[key]
    .median()
    .to_frame()
    .loc[lambda frame: frame.index >= 0]
    .join(hex_grid[["geometry"]])
)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
plot_gdf_column(
    gdf,
    key,
    title=f"{key} hex",
    # save_path=f"/Users/kyledorman/Desktop/{key}_hex.png",
)

In [None]:
# Names of all "planet*" metric columns available for plotting.
list(filter(lambda col: col.startswith("planet"), tide_heuristics_grid_df.columns))