In [None]:
import json
import os
from pathlib import Path
from random import shuffle

import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import folium
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from branca.colormap import linear
from matplotlib import colors, ticker
from matplotlib.cm import ScalarMappable
from PIL import Image
from shapely.geometry import MultiPolygon, Polygon

from src.gen_points_map import compute_step, make_equal_area_hex_grid
from src.geo_util import assign_intersection_id

import warnings
warnings.filterwarnings("ignore")           # hide every warning

In [None]:
# Root of the planet-coverage data tree. Set the PLANET_COVERAGE_DIR environment
# variable to run this notebook on another machine instead of editing the path.
DATA_ROOT = Path(
    os.environ.get("PLANET_COVERAGE_DIR", "/Users/kyledorman/data/planet_coverage")
).expanduser()

BASE = DATA_ROOT / "points_30km"              # point set analysed here
SHORELINES = BASE.parent / "shorelines"       # grid / ecoregion layers
FIG_DIR = BASE.parent / "figs" / "displays"   # where the static figures are saved
FIG_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
# Coordinate reference systems.
display_crs = "EPSG:4326"    # lon/lat: every layer is converted to this for plotting
robinson_crs = "ESRI:54030"  # CRS the equal-area hex grid is generated in
sinus_crs = "ESRI:54008"     # CRS the hex grid and ecoregions are moved to for the joins

# --- Load layers -------------------------------------------------------------
query_df = gpd.read_file(SHORELINES / "ocean_grids.gpkg")
grids_df = gpd.read_file(SHORELINES / "coastal_grids.gpkg")
# Rename so the "cell_id" attached from query_df below does not collide.
grids_df = grids_df.rename(columns={"cell_id": "grid_id"})
eco_df = gpd.read_file(SHORELINES / "marine_ecoregions")
eco_df["eco_id"] = list(range(len(eco_df)))  # positional id, one per ecoregion row

# --- Hex grid ----------------------------------------------------------------
# NOTE(review): 1.5 matches the "(1.5 Degree)" hex figure further down;
# presumably compute_step turns it into a cell size in metres - confirm.
cell_size_m = compute_step(1.5)
_, hex_grid = make_equal_area_hex_grid(cell_size_m, robinson_crs)
hex_grid = hex_grid.rename(columns={"cell_id": "hex_id"}).to_crs(sinus_crs)

# --- Cross-reference ids between layers --------------------------------------
# assign_intersection_id(left, right, left_id, right_id) is assumed to attach
# `right_id` of the intersecting `right` feature to each `left` row (later cells
# treat negative ids as "no match") - TODO confirm against src.geo_util.
grids_df = assign_intersection_id(grids_df, hex_grid, "grid_id", "hex_id")
query_df = assign_intersection_id(query_df, hex_grid, "cell_id", "hex_id")
grids_df = assign_intersection_id(grids_df, query_df, "grid_id", "cell_id")
query_df = assign_intersection_id(query_df, eco_df.to_crs(sinus_crs), "cell_id", "eco_id")

# Each grid inherits eco_id from its query cell. Both joins are left joins, so a
# grid whose cell has no eco_id ends up with NaN.
ids = (
    grids_df.set_index("cell_id")[["grid_id"]]
    .join(query_df.set_index("cell_id")[["eco_id"]])
    .reset_index()
)
grids_df = (
    grids_df.set_index("grid_id")
    .join(ids.set_index("grid_id")[["eco_id"]])
    .reset_index()
)

# --- Plotting CRS and id indexes ---------------------------------------------
query_df = query_df.to_crs(display_crs).set_index("cell_id")
grids_df = grids_df.to_crs(display_crs).set_index("grid_id")
hex_grid = hex_grid.to_crs(display_crs).set_index("hex_id")
eco_df = eco_df.to_crs(display_crs).set_index("eco_id")

In [None]:
# Region of interest (LA). Reproject explicitly so the intersects() tests below
# compare geometries in the same CRS as query_df / grids_df / hex_grid.
local_grid = gpd.read_file(BASE / "la.geojson").to_crs(display_crs)

# Keep only features touching the region. Each union is computed once and reused
# (the query-cell union was previously rebuilt for every filter).
local_area = local_grid.union_all()
query_local = query_df[query_df.geometry.intersects(local_area)]

query_area = query_local.union_all()
grids_local = grids_df[grids_df.geometry.intersects(query_area)]
hex_grid_local = hex_grid[hex_grid.geometry.intersects(query_area)]

In [None]:
# Interactive check of how the three grid layers line up around the LA region.
centroid = local_grid.geometry[0].centroid  # centre on the first region polygon
base_map = folium.Map(location=[centroid.y, centroid.x], zoom_start=7, width=1000, height=800)

# Set to True to also outline the region-of-interest polygon(s) from la.geojson.
SHOW_LOCAL_OUTLINE = False


def _add_outlines(fmap, frame, color, weight, with_popup=True):
    """Add one outlined ``folium.GeoJson`` per row of ``frame`` to ``fmap``.

    Args:
        fmap: ``folium.Map`` to draw on.
        frame: GeoDataFrame whose geometries are drawn; its index labels popups.
        color: Outline colour.
        weight: Outline stroke weight.
        with_popup: When True, clicking a shape shows its index label.
    """
    # `color` / `weight` are this call's own locals, so each layer keeps its own
    # style even if folium evaluates style_function lazily at render time.
    def style(feature):
        return {"color": color, "weight": weight}

    for label, row in frame.iterrows():
        folium.GeoJson(
            row.geometry,
            popup=str(label) if with_popup else None,
            style_function=style,
        ).add_to(fmap)


if SHOW_LOCAL_OUTLINE:
    _add_outlines(base_map, local_grid, "purple", 4, with_popup=False)

# Same drawing order as before: query cells, coastal grids, hex cells.
for layer, color, weight in (
    (query_local, "blue", 2),
    (grids_local, "green", 1),
    (hex_grid_local, "yellow", 1),
):
    _add_outlines(base_map, layer, color, weight)

# Display the map
base_map

In [None]:
# Static (cartopy) version of the LA grid overlay, saved to FIG_DIR.
save_path = FIG_DIR / "la_grids.png"
pad_fraction = 0.05  # margin around the data, as a fraction of its span

# total_bounds of an empty frame is all-NaN, which would only surface later as an
# obscure set_extent failure.
if hex_grid_local.empty:
    raise ValueError("hex_grid_local is empty: no hex cells intersect the local region")

fig, ax = plt.subplots(figsize=(11, 6), subplot_kw={"projection": ccrs.Robinson()})

# Extent from the data bounds (EPSG:4326) plus a small margin.
xmin, ymin, xmax, ymax = hex_grid_local.total_bounds
dx, dy = xmax - xmin, ymax - ymin
if dx == 0 or dy == 0:            # degenerate case (single point / line)
    dx = dy = max(dx, dy) or 1.0  # give it a 1 degree span to avoid zero width
pad_x, pad_y = dx * pad_fraction, dy * pad_fraction
ax.set_extent(
    [
        xmin - pad_x,
        xmax + pad_x,
        max(ymin - pad_y, -90.0),  # keep padded latitudes on the globe
        min(ymax + pad_y, 90.0),
    ],
    crs=ccrs.PlateCarree(),
)

# Outline-only layers, drawn in the order hex -> query -> coastal grid.
for layer, edge, width in (
    (hex_grid_local, "yellow", 1.0),
    (query_local, "blue", 1.0),
    (grids_local, "green", 0.5),
):
    layer.plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        edgecolor=edge,
        linewidth=width,
        facecolor="none",
    )

ax.coastlines(resolution="110m", linewidth=0.3)
ax.gridlines(draw_labels=False, linewidth=0.2)
# Ocean / land fills underneath the outlines (zorder=0).
ax.add_feature(cfeature.OCEAN, zorder=0)
ax.add_feature(cfeature.LAND, zorder=0)

fig.tight_layout()
if save_path is not None:
    fig.savefig(save_path)
plt.show()

In [None]:
# Global map of every hex cell that contains at least one coastal grid.
gdf = hex_grid.loc[grids_df[grids_df.hex_id >= 0].hex_id.unique()][["geometry"]].copy()

# Random colour ids so neighbouring cells rarely share a colour. Seeded so that a
# re-run reproduces the same figure.
gdf["id"] = np.random.default_rng(0).permutation(len(gdf))

n_ids = len(gdf)                                     # ids are a permutation: all unique
cmap = plt.get_cmap("tab20", n_ids)                  # tab20 resampled to n_ids entries (20 distinct colours)
norm = colors.BoundaryNorm(range(n_ids + 1), n_ids)  # id i -> colour entry i

fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={"projection": ccrs.Robinson()})
ax.set_global()

gdf.plot(
    column="id",
    ax=ax,
    cmap=cmap,
    norm=norm,
    linewidth=0.15,
    edgecolor="black",
    transform=ccrs.PlateCarree(),  # geometries are lon/lat (EPSG:4326)
)
ax.stock_img()
ax.set_title("Coastal Hex Grids (1.5 Degree)", pad=12)
fig.tight_layout()
fig.savefig(FIG_DIR / "hex_grids_1_5.png")
plt.show()

In [None]:
# Global map of every query cell that contains at least one coastal grid.
gdf = query_df.loc[grids_df[grids_df.cell_id >= 0].cell_id.unique()][["geometry"]].copy()

# Random colour ids so neighbouring cells rarely share a colour. Seeded so that a
# re-run reproduces the same figure.
gdf["id"] = np.random.default_rng(0).permutation(len(gdf))

n_ids = len(gdf)                                     # ids are a permutation: all unique
cmap = plt.get_cmap("tab20", n_ids)                  # tab20 resampled to n_ids entries (20 distinct colours)
norm = colors.BoundaryNorm(range(n_ids + 1), n_ids)  # id i -> colour entry i

fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={"projection": ccrs.Robinson()})
ax.set_global()

gdf.plot(
    column="id",
    ax=ax,
    cmap=cmap,
    norm=norm,
    linewidth=0.15,
    edgecolor="black",
    transform=ccrs.PlateCarree(),  # geometries are lon/lat (EPSG:4326)
)
ax.stock_img()
ax.set_title("Coastal Query Grids (1.0 Degree)", pad=12)
fig.tight_layout()
fig.savefig(FIG_DIR / "query_grids.png")
plt.show()

In [None]:
# Global map of all marine ecoregions (only the geometry is needed for plotting).
gdf = eco_df[["geometry"]].copy()

# Random colour ids so neighbouring regions rarely share a colour. Seeded so that
# a re-run reproduces the same figure.
gdf["id"] = np.random.default_rng(0).permutation(len(gdf))

n_ids = len(gdf)                                     # ids are a permutation: all unique
cmap = plt.get_cmap("tab20", n_ids)                  # tab20 resampled to n_ids entries (20 distinct colours)
norm = colors.BoundaryNorm(range(n_ids + 1), n_ids)  # id i -> colour entry i

fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={"projection": ccrs.Robinson()})
ax.set_global()

gdf.plot(
    column="id",
    ax=ax,
    cmap=cmap,
    norm=norm,
    linewidth=0.15,
    edgecolor="black",
    transform=ccrs.PlateCarree(),  # geometries are lon/lat (EPSG:4326)
)

ax.stock_img()
ax.set_title("Global Ecoregions", pad=12)
fig.tight_layout()
fig.savefig(FIG_DIR / "eco_regions.png")
plt.show()

In [None]:
# Global map of the coastal strip polygons.
gdf = gpd.read_file(BASE / "coastal_strips.gpkg")

# The strips are drawn with a Sinusoidal transform below. Reproject to the matching
# World Sinusoidal CRS (effectively a no-op if the file is already stored that
# way) so the coordinates really are what the transform assumes.
if gdf.crs is not None:
    gdf = gdf.to_crs(sinus_crs)

# Random colour ids so neighbouring strips rarely share a colour. Seeded so that a
# re-run reproduces the same figure.
gdf["id"] = np.random.default_rng(0).permutation(len(gdf))

n_ids = len(gdf)                                     # ids are a permutation: all unique
cmap = plt.get_cmap("tab20", n_ids)                  # tab20 resampled to n_ids entries (20 distinct colours)
norm = colors.BoundaryNorm(range(n_ids + 1), n_ids)  # id i -> colour entry i

fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={"projection": ccrs.Robinson()})
ax.set_global()

gdf.plot(
    column="id",
    ax=ax,
    cmap=cmap,
    norm=norm,
    linewidth=0.15,
    edgecolor="black",
    transform=ccrs.Sinusoidal(),  # projected Sinusoidal metres (ESRI:54008), not lon/lat
)

ax.set_title("Coastal Area", pad=12)
fig.tight_layout()
fig.savefig(FIG_DIR / "coastal_strips.png")
plt.show()

In [None]:
# NOTE(review): this whole cell is disabled (every line is commented out).
# If it is re-enabled:
#   * `cm` is not imported in this notebook, and `matplotlib.cm.get_cmap` was
#     removed in Matplotlib 3.9 - use `plt.get_cmap` as the active cells do.
#   * It reads ocean_grids.gpkg from BASE, while the loading cell reads it from
#     SHORELINES - confirm which location is intended.
#   * The figure is saved to a personal Desktop path rather than FIG_DIR.
#   * The "incoming lon/lat coords" comment on `transform=ccrs.Sinusoidal()` is
#     inaccurate: Sinusoidal is a projected CRS.
#   * `shuffle(ids)` is unseeded, so colours change on every run.
# import polars as pl
# from src.query_udms import DataFrameRow

# all_lazy = pl.scan_parquet(
#     str(BASE / "*/results/*/*/*/*/data.parquet"),
#     schema=DataFrameRow.polars_schema(),
# )
# valid_cell_ids = all_lazy.select(pl.col("cell_id").unique().sort()).collect().to_series().to_list()
# print(len(valid_cell_ids))
# gdf = gpd.read_file(BASE / "ocean_grids.gpkg") # tide_heuristics_grid_df.reset_index()[["geometry"]].copy()
# gdf = gdf.set_index("cell_id").loc[valid_cell_ids].reset_index()

# ids = list(range(len(gdf)))
# shuffle(ids)
# gdf["id"] = ids

# n_ids  = gdf.id.nunique()
# base_cmap = cm.get_cmap("tab20", n_ids)  # up to 20 unique colours
# cmap      = colors.ListedColormap(base_cmap(range(n_ids)))
# norm      = colors.BoundaryNorm(range(n_ids + 1), n_ids)

# # Pick any Cartopy projection
# proj = ccrs.Robinson()           # or ccrs.Mollweide(), ccrs.Robinson(), …

# fig = plt.figure(figsize=(12, 6))
# ax  = plt.axes(projection=proj)
# ax.set_global()

# # Re-project your data on the fly with `transform`
# gdf.plot(
#     column="id",
#     ax=ax,
#     cmap=cmap,
#     norm=norm,
#     linewidth=0.15,
#     edgecolor="black",
#     transform=ccrs.Sinusoidal(),   # <- incoming lon/lat coords
# )

# plt.title("Observed Ocean Grids", pad=12)
# plt.tight_layout()
# plt.savefig("/Users/kyledorman/Desktop/observed_ocean_grids.png")
# plt.show()

In [None]:
# Grid centroids for the point markers below. Centroids taken directly in a
# geographic CRS (EPSG:4326) are inaccurate, and GeoPandas' warning about that is
# hidden by the global warnings filter. Compute them in the equal-area sinusoidal
# CRS instead, then convert back to display lon/lat.
grids_df["cent"] = grids_df.geometry.to_crs(sinus_crs).centroid.to_crs(display_crs)

In [None]:
# Ecoregions ranked by how many of their grids lie more than DIST_THRESHOLD_KM
# away (NaN distances never pass the comparison, so no separate dropna is needed).
DIST_THRESHOLD_KM = 45

eco_grid_dist = (
    grids_df.loc[grids_df["dist_km"] > DIST_THRESHOLD_KM]
    .reset_index()
    .groupby("eco_id")["grid_id"]
    .count()
)
# Negative eco ids are excluded.
eco_grid_dist = eco_grid_dist.loc[eco_grid_dist.index >= 0].sort_values(ascending=False)

eco_grid_dist

In [None]:
# Ecoregions ranked by how many grids with a known dist_km they contain.
eco_grid_counts = (
    grids_df.dropna(subset=["dist_km"])
    .reset_index()
    .groupby("eco_id")["grid_id"]
    .count()
)
eco_grid_counts = eco_grid_counts.loc[eco_grid_counts.index >= 0].sort_values(ascending=False)

# Stateful cursor of (position, eco_id) pairs: the next cell advances it by one
# ecoregion per run. Re-run this cell to start again from the largest region.
eco_iter = enumerate(eco_grid_counts.index)

In [None]:
# Step through ecoregions one per run (largest first), mapping each grid's dist_km.
radius = 3               # CircleMarker radius
NO_DATA_COLOR = "#808080"  # marker colour for grids without a dist_km value

step = next(eco_iter, None)
if step is None:
    raise RuntimeError("All ecoregions have been shown; re-run the eco_grid_counts cell to restart.")
index, eco_id = step
eco_id = int(eco_id)

print(eco_id)
print(index, "/", len(eco_grid_counts))

eco_grids = grids_df[(grids_df.eco_id == eco_id) & ~grids_df.is_land]
print(len(eco_grids))

eco_region = eco_df.loc[eco_id]

region_centroid = eco_region.geometry.centroid
m = folium.Map(location=[region_centroid.y, region_centroid.x], zoom_start=4, width=1000, height=800)

color_scale = linear.viridis.scale(0.0, 50.0)  # type: ignore  # dist_km colour ramp

# Ecoregion outline.
folium.GeoJson(
    eco_region.geometry,
    style_function=lambda feature: {
        "color": "red",
        "weight": 4,
    }
).add_to(m)

for grid_id, value, cent in zip(eco_grids.index, eco_grids["dist_km"], eco_grids["cent"]):
    # eco_grids is not filtered on dist_km, so NaN can occur here; color_scale
    # cannot map NaN, so such grids get a neutral colour instead of crashing.
    fill = NO_DATA_COLOR if pd.isna(value) else color_scale(value)
    folium.CircleMarker(
        location=[cent.y, cent.x],
        radius=radius,
        fill=True,
        fill_opacity=0.6,
        color=None,
        fill_color=fill,
        popup=f"{grid_id}<br>{value:.1f}",
    ).add_to(m)

color_scale.add_to(m)

# Display the map
m

In [None]:
# Frames sit under the per-point-set figure folder, i.e.
# <data root>/figs/points_30km/days_with_sample/ for the default BASE.
DAYS_WITH_SAMPLE_DIR = BASE.parent / "figs" / BASE.name / "days_with_sample"
GIF_PATH = DAYS_WITH_SAMPLE_DIR / "max_all.gif"

# NOTE(review): the last frame (in sorted order) is deliberately dropped, as before -
# confirm why. Sorting is lexicographic, so unpadded numbers in names sort as strings.
image_paths = sorted(DAYS_WITH_SAMPLE_DIR.glob("max_*all*.png"))[:-1]


def make_gif(image_paths, out_path=GIF_PATH, duration_ms=700):
    """Write ``image_paths`` (in the given order) as a looping animated GIF.

    Args:
        image_paths: Sequence of image files, one per frame.
        out_path: Destination ``.gif``; an existing file is overwritten.
        duration_ms: Display time of each frame in milliseconds.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    if not image_paths:
        raise ValueError("make_gif needs at least one frame")

    frames = []
    for path in image_paths:
        # Copy each frame into memory and close its file instead of leaking handles.
        with Image.open(path) as img:
            frames.append(img.copy())

    first, rest = frames[0], frames[1:]
    # append_images must hold only the frames *after* the first one; passing all
    # frames (as before) showed the first image twice.
    first.save(
        out_path,
        format="GIF",
        append_images=rest,
        save_all=True,
        duration=duration_ms,
        loop=0,
    )


make_gif(image_paths)