# Realistic Bay Groundwater Dynamics with Tides & Sea Level Rise

**Author:** Marine Denolle  
**Date:** 1/30/2026  
**Description:** 
Synthetic estuary/bay groundwater model with realistic physics, coupled coastal-inland dynamics, and animated visualization.

## Model Improvements

### 1. Realistic Bay Domain
- **Geometry**: 10 km x 5 km coastal zone (100 m grid spacing, 100 x 50 cells)
- **Topography**: Gentle landward slope (coast: ~0 m, east edge: ~30-35 m elevation)
- **Tidal flats**: Shallow intertidal zone (elevation below 2.5 m, ~0.5-1.5 km from the shoreline)
- **River valley**: Sinuous north-south drainage channel (~3.5 m below the surrounding terrain)

### 2. Physical Parameters (Calibrated for 5-20m WT fluctuations)
- **Specific Yield (S_y)**: 0.06-0.16 → amplified head response to recharge
- **Coastal coupling**: τ_coast = 5 days, L_coast = 7 km → tide and sea-level signals reach well inland
- **Baseline WT depth**: 2.5 m below surface → water table stays mostly above sea level
- **Baseflow timescale**: τ_gw = 8-150 days → realistic seasonal memory
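In `groundwater_step`, each cell is pulled toward the sea-level target with weight exp(-d/L_coast) at rate 1/τ_coast. A minimal sketch of the fraction of that gap closed per day versus distance from the coast, assuming the `GroundwaterConfig` defaults (τ_coast = 5 d, L_coast = 7 km) and a 1-day step; `coastal_relaxation_fraction` is a hypothetical helper, not part of the model:

```python
import numpy as np

def coastal_relaxation_fraction(dist_m, tau_coast_d=5.0, L_coast_m=7000.0, dt_d=1.0):
    """Fraction of the gap to the sea-level target closed in one step.

    Mirrors the coast_pull term in groundwater_step:
    (target - depth) * exp(-d / L) / tau * dt.
    """
    weight = np.exp(-np.asarray(dist_m, dtype=float) / L_coast_m)
    return weight * dt_d / tau_coast_d

dist = np.array([0.0, 1000.0, 3000.0, 7000.0])
frac = coastal_relaxation_fraction(dist)
for d, f in zip(dist, frac):
    print(f"{d / 1000:4.1f} km: {100 * f:5.2f} % of the gap closed per day")
```

With these defaults, a cell 7 km inland still feels about 37% (e^-1) of the shoreline coupling strength.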

### 3. Realistic Forcing
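- **Sea level rise**: 3.5 mm/yr (PNW-realistic)
- **Spring/neap tides**: ±0.9 m, 14.77-day period
- **King tides**: ±0.25 m seasonal modulation (largest in winter) plus episodic +0.35 m events
- **Precipitation**: seasonal baseline from ~1 mm/day (summer) to ~6 mm/day (winter), plus daily noise and atmospheric-river pulses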
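The per-day event rates in `ForcingConfig` translate into expected counts with simple arithmetic. A sketch, assuming the October-March AR window used in `make_atmospheric_river_series` (zero-based day index < 90 or > 274 in a 365-day year) and one independent trial per day, as in that function:

```python
# Expected event counts implied by the default ForcingConfig rates
ar_lambda_per_winter_day = 0.035       # ForcingConfig.ar_lambda_per_winter_day
king_tide_event_prob_per_day = 0.005   # ForcingConfig.king_tide_event_prob_per_day

# Days that fall inside the AR window used by the model
winter_days = sum(1 for d in range(365) if d < 90 or d > 274)

expected_ar_per_winter = ar_lambda_per_winter_day * winter_days
expected_king_tides_per_year = king_tide_event_prob_per_day * 365

print(f"AR window: {winter_days} days")
print(f"Expected ARs per winter: {expected_ar_per_winter:.1f}")
print(f"Expected king-tide events per year: {expected_king_tides_per_year:.2f}")
```

Each AR lasts `int(Exponential(2.5) + 1)` days, about 3 days on average, and overlapping events add, so a wet winter can stack several pulses.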

### 4. Interactive Visualization
- **4-panel layout**: DEM geometry | WT change | Saturation | Coastal forcing time series
- **Animated playback**: Weekly snapshots with 0.3 sec sleep between frames
- **Real-time marker**: Shows current time on forcing plot; watch tides propagate inland!

## Expected Behavior

**Coastal zone (0-2 km):**
- Oscillates with the spring-neap cycle (~14.8-day period, ±0.5 m)
- King tides add occasional larger pulses (about 2 episodic events per year at the default rate, on top of a winter seasonal high)
- Winter peaks ~50% higher than summer

**Inland zone (2-10 km):**
- Slower seasonal envelope (3-6 month timescale)
- High in winter (wet season + coastal head rise)
- Low in summer (ET drawdown + coastal head drop)
- Shows a clear seasonal range of 10-20 m

**River valley:**
- Responds quickly to atmospheric river events
- Flood pulses propagate seaward
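To compare a run against the behavior described above, one option is to average the water-table change within distance bands from the coast. A sketch, assuming a Dataset with the `water_table_elev_change` and `dist_coast_m` variables that `run_simulation` produces; the helper `zone_mean_wt_change` and the 2 km band edge are illustrative choices:

```python
import numpy as np
import xarray as xr

def zone_mean_wt_change(ds: xr.Dataset, edges_km=(0.0, 2.0, 10.0)) -> xr.Dataset:
    """Spatial mean of water_table_elev_change within distance-from-coast bands."""
    dist_km = ds["dist_coast_m"] / 1000.0
    zones = {}
    for lo, hi in zip(edges_km[:-1], edges_km[1:]):
        in_band = (dist_km >= lo) & (dist_km < hi)          # (y, x) boolean mask
        wt = ds["water_table_elev_change"].where(in_band)   # NaN outside the band
        zones[f"{lo:g}-{hi:g} km"] = wt.mean(dim=("y", "x"))
    return xr.Dataset(zones)

# Tiny synthetic example: 3 days on a 2 x 4 grid with 1 km spacing
x = np.arange(4) * 1000.0
y = np.arange(2) * 1000.0
dist = np.tile(x, (2, 1))                      # distance from the west coast [m]
wt = np.stack([np.full((2, 4), v) for v in (0.0, 1.0, 2.0)])
wt[:, :, 2:] *= 10.0                           # inland cells swing 10x more
demo = xr.Dataset(
    {"water_table_elev_change": (("time", "y", "x"), wt),
     "dist_coast_m": (("y", "x"), dist)},
    coords={"time": np.arange(3), "y": y, "x": x},
)
print(zone_mean_wt_change(demo).to_dataframe())
```

For a full run, `zone_mean_wt_change(ds).to_dataframe().plot()` puts the coastal and inland envelopes side by side.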


## 1. Import Libraries

Import all necessary libraries for data analysis and visualization.

In [11]:
from __future__ import annotations  # must be the first statement in the cell

# Standard library imports
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple

# Data manipulation
import numpy as np
import pandas as pd
import xarray as xr

# Scientific computing
from scipy import signal, stats

try:
    from scipy.ndimage import distance_transform_edt
except ImportError:
    distance_transform_edt = None  # dist_from_mask falls back to a brute-force search

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Configure plotting style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_context('notebook')
%matplotlib inline

In [None]:

@dataclass
class DomainConfig:
    nx: int = 100
    ny: int = 50
    dx_m: float = 100.0      # grid spacing in meters (synthetic)
    dy_m: float = 100.0
    seed: int = 7



@dataclass
class TimeConfig:
    start: str = "2015-01-01"
    years: int = 3       # simulated 365-day years
    freq: str = "D"      # model time step: forcing and state are daily

@dataclass
class ForcingConfig:
    # Seasonal precipitation baseline (PNW-like): wetter winters, drier summers
    precip_mean_winter_mm_d: float = 6.0
    precip_mean_summer_mm_d: float = 1.0
    precip_seasonality_sharpness: float = 1.7  # >1 makes winters sharper

    # Atmospheric rivers (AR): one random trial per day in the Oct–Mar window, multi-day pulses
    ar_lambda_per_winter_day: float = 0.035  # expected events/day in winter window
    ar_duration_mean_d: float = 2.5
    ar_intensity_mean_mm_d: float = 30.0
    ar_intensity_sigma_ln: float = 0.55

    # ET0 seasonality (dry summer drawdown)
    et0_winter_mm_d: float = 0.4
    et0_summer_mm_d: float = 3.5

    # Sea level rise + tides
    sea_level_rise_m_per_yr: float = 0.0035  # ~3.5 mm/yr
    tide_spring_neap_period_d: float = 14.77
    tide_spring_neap_amp_m: float = 0.9      # effective coastal head oscillation amplitude
    king_tide_annual_amp_m: float = 0.25     # seasonal modulation (bigger in winter)
    king_tide_event_prob_per_day: float = 0.005
    king_tide_event_amp_m: float = 0.35
    king_tide_event_decay_d: float = 1.5


@dataclass
class SoilConfig:
    zr_m: float = 1.2  # root-zone depth [m]
    # Percolation (gravity drainage) time scale at field capacity
    perc_tau_d_default: float = 6.0


@dataclass
class GroundwaterConfig:
    depth0_m: float = 2.5     # baseline depth to water table (below ground) [m]
    Sy_default: float = 0.08  # specific yield
    tau_gw_d_default: float = 40.0

    # Coastal coupling (how strongly tide/sea level affects inland head)
    tau_coast_d: float = 5.0
    Lcoast_m: float = 7000.0

    # River-flood influence (optional)
    enable_river_flooding: bool = True
    Lriver_m: float = 3500.0
    flood_gain_m_per_mm: float = 0.0025   # converts AR/flood rainfall to head forcing near river
    flood_recession_d: float = 3.0


@dataclass
class SubsidenceConfig:
    enable: bool = True
    reclaimed_fraction: float = 0.08
    subsidence_m_per_yr_min: float = 0.001
    subsidence_m_per_yr_max: float = 0.006
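The configs are plain dataclasses, so scenario variants can be derived with `dataclasses.replace` instead of editing the defaults, for example a faster, shorter-range coastal coupling (τ_coast = 3 d, L_coast = 3 km). A sketch using a trimmed, hypothetical stand-in class `GroundwaterConfigSketch` so it does not shadow the real `GroundwaterConfig`; note that `run_simulation` as written builds its own default configs, so it would need to accept config arguments for such a variant to take effect:

```python
from dataclasses import dataclass, replace

# Trimmed stand-in for GroundwaterConfig (coastal fields only), so the snippet runs standalone
@dataclass
class GroundwaterConfigSketch:
    depth0_m: float = 2.5
    tau_coast_d: float = 5.0
    Lcoast_m: float = 7000.0

base = GroundwaterConfigSketch()
fast_coast = replace(base, tau_coast_d=3.0, Lcoast_m=3000.0)  # new instance; base is unchanged
print(base)
print(fast_coast)
```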




In [13]:

def _day_of_year_index(time: pd.DatetimeIndex) -> np.ndarray:
    # 1..366
    return time.dayofyear.values.astype(float)

def _seasonal_weight_wet(doy: np.ndarray) -> np.ndarray:
    """
    Wetness seasonal index in [0,1] peaking in winter.
    Uses a cosine with peak around Jan 15.
    """
    phase = 2.0 * np.pi * (doy - 15.0) / 365.25
    w = 0.5 * (1.0 + np.cos(phase))  # 1 in mid-winter, 0 in mid-summer
    return np.clip(w, 0.0, 1.0)

def _smoothstep(x: np.ndarray, sharpness: float) -> np.ndarray:
    # Raises seasonal contrast: x^sharp / (x^sharp + (1-x)^sharp)
    x = np.clip(x, 0.0, 1.0)
    a = np.power(x, sharpness)
    b = np.power(1.0 - x, sharpness)
    return a / (a + b + 1e-12)

def _mm_to_m(x_mm: np.ndarray) -> np.ndarray:
    return x_mm / 1000.0



In [14]:

# Spatial fields (synthetic DEM, soils, masks)
# ----------------------------

def make_synthetic_dem_and_masks(dom: DomainConfig) -> Dict[str, np.ndarray]:
    """
    Creates a realistic bay/estuary domain:
    - West edge (x=0): ocean/coast at sea level
    - Shallow tidal flats (x: ~0.5-1.5 km) below 2.5 m elevation, intertidal zone
    - Gentle upland slope (1.5-10 km): gradual rise to ~30-35 m
    - Sinuous north-south river valley near x ≈ 3.5 km, ~3.5 m deep
    
    Returns:
      dem_m [ny,nx], coast_mask, reclaimed_mask, river_mask,
      dist_coast_m, dist_river_m
    """
    rng = np.random.default_rng(dom.seed)

    ny, nx = dom.ny, dom.nx
    y = np.linspace(0, 1, ny)[:, None]
    x = np.linspace(0, 1, nx)[None, :]

    # Gentle coastal-to-inland slope (topography rising eastward)
    # ~-1 m at the coast, rising to ~34 m at the east edge
    slope = -1.0 + 35.0 * x**1.5
    
    # Tidal flats: a low (~2 m) platform centred at x = 0.05 that fades out by x ≈ 0.15
    tidal_flat = 2.0 * np.exp(-((x - 0.05) ** 2) / (2 * 0.05**2))
    
    # River valley: sinuous north-south channel, ~3m below surrounding terrain
    river_x_norm = 0.35 + 0.08 * np.sin(2 * np.pi * y)
    river_valley_strength = np.exp(-((x - river_x_norm) ** 2) / (2 * 0.06**2))
    river_valley = -3.5 * river_valley_strength
    
    # Smooth random roughness (low amplitude, subtle)
    roughness = np.zeros((ny, nx), dtype=float)
    for k in range(3):
        fx = rng.uniform(0.5, 2.0)
        fy = rng.uniform(0.5, 2.0)
        amp = rng.uniform(0.5, 1.5)
        roughness += amp * np.sin(2*np.pi*fx*x) * np.sin(2*np.pi*fy*y)

    dem = slope + tidal_flat + river_valley + 0.3 * roughness
    dem = np.clip(dem, -2.0, 40.0)  # Allow shallow submarine areas, cap at ~40m inland

    # Coast: define cells with x < 0.05 as water/coastal boundary
    coast_mask = (x < 0.05).astype(bool).repeat(ny, axis=0)
    
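    # Tidal flats: intertidal zone (x: 0.05-0.15, dem < 2.5 m)
    # The comparison with `dem` already broadcasts to (ny, nx); adding .repeat(ny, axis=0)
    # here produced a (ny*ny, nx) mask and broke the Dataset construction.
    tidal_zone = (x >= 0.05) & (x < 0.15) & (dem < 2.5)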

    # Reclaimed land: small fraction in lower elevation areas (disabled for this scenario)
    reclaimed_mask = np.zeros_like(coast_mask, dtype=bool)

    # River corridor: sinuous channel
    river_mask = river_valley_strength > 0.7

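    # Distances [m] from every cell to the nearest cell of a mask
    def dist_from_mask(mask_bool: np.ndarray) -> np.ndarray:
        if not mask_bool.any():
            return np.full(mask_bool.shape, 1e9, dtype=float)
        if distance_transform_edt is not None:
            # EDT measures distance to the nearest zero, so pass the inverted mask
            return distance_transform_edt(~mask_bool, sampling=(dom.dy_m, dom.dx_m))
        # Brute-force fallback without SciPy: mask points are thinned to bound memory,
        # so the distances are approximate
        pts = np.argwhere(mask_bool)
        pts = pts[:: max(1, pts.shape[0] // 4000)]
        yy_, xx_ = np.indices(mask_bool.shape)
        d2 = np.min(((yy_[..., None] - pts[:, 0]) * dom.dy_m) ** 2
                    + ((xx_[..., None] - pts[:, 1]) * dom.dx_m) ** 2, axis=-1)
        return np.sqrt(d2)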

    dist_coast = dist_from_mask(coast_mask)
    dist_river = dist_from_mask(river_mask)

    return dict(
        dem_m=dem,
        coast_mask=coast_mask,
        tidal_zone=tidal_zone,
        reclaimed_mask=reclaimed_mask,
        river_mask=river_mask,
        dist_coast_m=dist_coast,
        dist_river_m=dist_river,
    )


def make_synthetic_soils(dom: DomainConfig, dem_m: np.ndarray) -> Dict[str, np.ndarray]:
    """
    Bay/estuary soils: sandy lowlands, more clay-silt in uplands.
    """
    rng = np.random.default_rng(dom.seed + 1)
    ny, nx = dem_m.shape

    # Normalize elevation to 0..1
    z = (dem_m - np.nanmin(dem_m)) / (np.nanmax(dem_m) - np.nanmin(dem_m) + 1e-12)

    # Texture: lowlands (near coast/tidal flats) more sandy; uplands more silty
    texture = np.clip(0.60*(1.0 - z) + 0.40*rng.normal(0, 0.20, size=(ny, nx)), 0.0, 1.0)

    # Porosity: higher in sandy lowlands
    porosity = np.clip(0.40 + 0.08*(1.0 - z) + 0.02*rng.normal(size=(ny, nx)), 0.35, 0.55)

    # Ksat [m/day]: higher in sandy areas
    log10_ksat = -1.8 + 1.0*texture + 0.2*rng.normal(size=(ny, nx))
    Ksat = np.power(10.0, log10_ksat)
    Ksat = np.clip(Ksat, 5e-4, 1.5)

    # Field capacity and wilting point
    theta_fc = np.clip(0.20 + 0.15*(1.0 - texture) + 0.02*rng.normal(size=(ny, nx)), 0.13, 0.40)
    theta_wp = np.clip(0.08 + 0.10*(1.0 - texture) + 0.01*rng.normal(size=(ny, nx)), 0.05, 0.22)

    # Ensure wp < fc < porosity
    theta_fc = np.minimum(theta_fc, porosity - 0.04)
    theta_wp = np.minimum(theta_wp, theta_fc - 0.03)
    theta_wp = np.clip(theta_wp, 0.04, None)

    # Percolation time scale: faster in higher Ksat
    perc_tau_d = np.clip(8.0 / np.sqrt(Ksat + 1e-6), 1.5, 15.0)

    # Groundwater params: Sy (sensitive areas have lower Sy for amplified response)
    Sy = np.clip(0.08 + 0.05*texture + 0.01*rng.normal(size=(ny, nx)), 0.06, 0.16)
    tau_gw_d = np.clip(20.0 + 50.0*(1.0 - texture) + 6.0*rng.normal(size=(ny, nx)), 8.0, 150.0)

    return dict(
        porosity=porosity,
        Ksat_m_d=Ksat,
        theta_fc=theta_fc,
        theta_wp=theta_wp,
        perc_tau_d=perc_tau_d,
        Sy=Sy,
        tau_gw_d=tau_gw_d,
        texture_proxy=texture,
    )


def make_subsidence_rate(dom: DomainConfig, reclaimed_mask: np.ndarray, sub: SubsidenceConfig) -> np.ndarray:
    rng = np.random.default_rng(dom.seed + 2)
    rate = np.zeros_like(reclaimed_mask, dtype=float)
    if not sub.enable:
        return rate
    low = sub.subsidence_m_per_yr_min
    high = sub.subsidence_m_per_yr_max
    rate[reclaimed_mask] = rng.uniform(low, high, size=int(reclaimed_mask.sum()))
    return rate
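A mask built with `.repeat(ny, axis=0)` on an array that is already `(ny, nx)` silently becomes `(ny*ny, nx)` and only fails later, when xarray assembles the Dataset. A cheap guard is to check every 2-D field against the grid shape as soon as it is created. A sketch; `check_grid_shapes` is a hypothetical helper:

```python
import numpy as np

def check_grid_shapes(fields: dict, shape: tuple) -> None:
    """Raise ValueError naming every field whose shape differs from the model grid."""
    shape = tuple(shape)
    bad = {name: np.shape(arr) for name, arr in fields.items() if np.shape(arr) != shape}
    if bad:
        raise ValueError(f"fields with wrong shape (expected {shape}): {bad}")

ny, nx = 50, 100
x = np.linspace(0, 1, nx)[None, :]
dem = np.zeros((ny, nx))
good = (x >= 0.05) & (x < 0.15) & (dem < 2.5)   # already broadcast to (ny, nx)
bad = good.repeat(ny, axis=0)                    # (ny*ny, nx): the shape bug

check_grid_shapes({"tidal_zone": good}, (ny, nx))  # passes silently
try:
    check_grid_shapes({"tidal_zone": bad}, (ny, nx))
except ValueError as err:
    print(err)
```

Inside `run_simulation`, calling `check_grid_shapes(dem_dict, (ny, nx))` and `check_grid_shapes(soils, (ny, nx))` right after those dictionaries are built would surface such a mismatch before the time loop starts.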


In [15]:
# Simulation functions
# ----------------------------

def make_precip_series(tconf: TimeConfig, fconf: ForcingConfig) -> np.ndarray:
    """Generate daily precipitation time series [mm/day]"""
    rng = np.random.default_rng(42)
    n_days = len(pd.date_range(start=tconf.start, periods=tconf.years*365, freq="D"))
    
    # Base precipitation with seasonal variation
    days_of_year = np.arange(n_days) % 365
    seasonal_weight = _seasonal_weight_wet(days_of_year)
    base = fconf.precip_mean_winter_mm_d * seasonal_weight + fconf.precip_mean_summer_mm_d * (1 - seasonal_weight)
    
    # Add random variability
    precip = base + rng.exponential(2.0, size=n_days)
    precip = np.clip(precip, 0.3, 15.0)
    
    return precip


def make_atmospheric_river_series(tconf: TimeConfig, fconf: ForcingConfig, precip: np.ndarray) -> np.ndarray:
    """Add atmospheric river events to precipitation"""
    rng = np.random.default_rng(43)
    n_days = len(precip)
    precip_total = precip.copy()
    
    # Generate AR events (winter-focused)
    for day in range(n_days):
        day_of_year = day % 365
        # ARs only occur Oct-March (zero-based day of year > 274 or < 90)
        if day_of_year < 90 or day_of_year > 274:
            if rng.random() < fconf.ar_lambda_per_winter_day:
                duration = int(rng.exponential(fconf.ar_duration_mean_d) + 1)
                intensity = rng.lognormal(np.log(fconf.ar_intensity_mean_mm_d), fconf.ar_intensity_sigma_ln)
                for d in range(duration):
                    if day + d < n_days:
                        precip_total[day + d] += intensity * np.exp(-d / 2.0)  # Decay over duration
    
    return precip_total


def make_et0_series(tconf: TimeConfig, fconf: ForcingConfig) -> np.ndarray:
    """Generate potential evapotranspiration [mm/day]"""
    n_days = len(pd.date_range(start=tconf.start, periods=tconf.years*365, freq="D"))
    days_of_year = np.arange(n_days) % 365
    
    # ET higher in summer, lower in winter
    seasonal_cycle = 0.5 * (1 + np.cos(2*np.pi * (days_of_year - 182) / 365))  # Peak in summer (day 182)
    et0 = fconf.et0_winter_mm_d + (fconf.et0_summer_mm_d - fconf.et0_winter_mm_d) * seasonal_cycle
    
    return et0


def make_sea_level_and_tides(tconf: TimeConfig, fconf: ForcingConfig) -> np.ndarray:
    """Generate sea level boundary condition: long-term rise + tides + king tides"""
    rng = np.random.default_rng(44)
    n_days = len(pd.date_range(start=tconf.start, periods=tconf.years*365, freq="D"))
    time_years = np.arange(n_days) / 365.0
    
    # Long-term sea level rise
    slr = fconf.sea_level_rise_m_per_yr * time_years
    
    # Spring-neap tidal cycle
    tides = fconf.tide_spring_neap_amp_m * np.sin(2*np.pi * np.arange(n_days) / fconf.tide_spring_neap_period_d)
    
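    # King tides: seasonal modulation peaking at the start of the year (winter) + occasional events
    # (cos peaks on Jan 1; a sin term would instead peak in early April)
    king_seasonal = fconf.king_tide_annual_amp_m * np.cos(2*np.pi * time_years)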
    king_events = np.zeros(n_days)
    for day in range(n_days):
        if rng.random() < fconf.king_tide_event_prob_per_day:
            # Add king tide event with exponential decay
            for d in range(int(fconf.king_tide_event_decay_d * 5)):
                if day + d < n_days:
                    king_events[day + d] += fconf.king_tide_event_amp_m * np.exp(-d / fconf.king_tide_event_decay_d)
    
    sea_level = slr + tides + king_seasonal + king_events
    
    return sea_level


def river_flood_pulse_series(tconf: TimeConfig, gconf: GroundwaterConfig, precip_total: np.ndarray) -> np.ndarray:
    """River stage responds to precipitation with lag"""
    from scipy.signal import lfilter
    
    if not gconf.enable_river_flooding:
        return np.zeros_like(precip_total)
    
    # Simple exponential filter (lag response)
    tau_days = gconf.flood_recession_d
    b = [1.0 / tau_days]
    a = [1.0, -np.exp(-1.0 / tau_days)]
    
    river_stage = lfilter(b, a, precip_total) * gconf.flood_gain_m_per_mm
    river_stage = np.clip(river_stage, 0, 3.0)
    
    return river_stage


def infiltration_from_precip(precip_mm_d: np.ndarray, theta: np.ndarray, porosity: np.ndarray,
                             Ksat: np.ndarray, zr_m: float = 1.0) -> np.ndarray:
    """Infiltration rate [mm/day] for a 1-day step.

    Limited by rainfall, by Ksat, and by the unfilled pore space of the root zone.
    Works element-wise on scalars or NumPy arrays.
    """
    deficit_mm = np.maximum(0.0, porosity - theta) * zr_m * 1000.0  # storage deficit [mm]
    inf_capacity = np.minimum(Ksat * 1000.0, precip_mm_d)           # Ksat [m/d] -> [mm/d]
    return np.minimum(inf_capacity, deficit_mm)


def evapotranspiration(et0: np.ndarray, theta: np.ndarray, theta_fc: np.ndarray,
                       theta_wp: np.ndarray) -> np.ndarray:
    """Actual ET [mm/day]: zero below wilting point, linear up to field capacity, ET0 above."""
    stress = np.clip((theta - theta_wp) / (theta_fc - theta_wp), 0.0, 1.0)
    return et0 * stress


def percolation(theta: np.ndarray, theta_fc: np.ndarray, tau_d: np.ndarray,
                zr_m: float = 1.0) -> np.ndarray:
    """Gravity drainage [mm/day] of root-zone water held above field capacity."""
    return np.maximum(0.0, theta - theta_fc) * zr_m * 1000.0 / tau_d


def groundwater_step(wt_depth: np.ndarray, dem: np.ndarray, perc_mm_d: np.ndarray, 
                     Sy: np.ndarray, tau_gw: np.ndarray, dist_coast: np.ndarray,
                     dist_river: np.ndarray, sea_level: float, river_stage: float,
                     tau_coast: float, L_coast: float, L_river: float, dt_d: float,
                     tau_river_d: float = 5.0) -> np.ndarray:
    """Advance the depth to water table [m below ground] by one step."""
    
    # Recharge raises the head by perc / Sy
    head_rise_m_d = perc_mm_d / 1000.0 / Sy
    
    # Lateral drainage proxy: deepens the water table in proportion to the squared
    # head gradient (per grid cell). This is not a true diffusion (Laplacian) term.
    wt_elev = dem - wt_depth
    grad_y, grad_x = np.gradient(wt_elev)
    lateral_drainage = (grad_y**2 + grad_x**2) / tau_gw * dt_d
    
    # Coastal boundary: relax water-table elevation toward sea level, decaying inland
    coast_weight = np.exp(-dist_coast / L_coast)
    coast_target_depth = dem - sea_level
    coast_pull = (coast_target_depth - wt_depth) * coast_weight / tau_coast * dt_d
    
    # River boundary: relax water-table elevation toward river stage near the channel
    river_weight = np.exp(-dist_river / L_river)
    river_target_depth = dem - river_stage
    river_pull = (river_target_depth - wt_depth) * river_weight / tau_river_d * dt_d
    
    # Update water table depth
    wt_depth_new = wt_depth - head_rise_m_d * dt_d + lateral_drainage + coast_pull + river_pull
    # Negative depth = water table above ground (saturation excess / ponding), capped at an
    # assumed 0.5 m. A positive lower bound would make the ponding diagnostic identically zero.
    return np.clip(wt_depth_new, -0.5, 50.0)


def run_simulation() -> xr.Dataset:
    """Execute the full simulation"""
    
    # Configuration
    dom = DomainConfig()
    tconf = TimeConfig()
    fconf = ForcingConfig()
    sconf = SoilConfig()
    gconf = GroundwaterConfig()
    subconf = SubsidenceConfig()
    
    # Generate spatial fields
    dem_dict = make_synthetic_dem_and_masks(dom)
    dem_m = dem_dict["dem_m"]
    dist_coast_m = dem_dict["dist_coast_m"]
    dist_river_m = dem_dict["dist_river_m"]
    coast_mask = dem_dict["coast_mask"]
    tidal_zone = dem_dict["tidal_zone"]
    river_mask = dem_dict["river_mask"]
    
    soils = make_synthetic_soils(dom, dem_m)
    
    # Generate forcing time series
    time_index = pd.date_range(start=tconf.start, periods=tconf.years*365, freq="D")
    n_days = len(time_index)
    
    precip_base = make_precip_series(tconf, fconf)
    precip_total = make_atmospheric_river_series(tconf, fconf, precip_base)
    et0 = make_et0_series(tconf, fconf)
    sea_level = make_sea_level_and_tides(tconf, fconf)
    river_stage = river_flood_pulse_series(tconf, gconf, precip_total)
    
    # Initialize state variables
    ny, nx = dom.ny, dom.nx
    theta = soils["theta_fc"].copy()              # root zone starts at field capacity
    wt_depth = np.full((ny, nx), gconf.depth0_m)
    wt_elev_init = dem_m - gconf.depth0_m
    
    # Storage arrays
    wt_change = np.zeros((n_days, ny, nx), dtype=np.float32)
    above_sat_change = np.zeros((n_days, ny, nx), dtype=np.float32)  # ponding depth [m]
    
    # Time-stepping loop
    dt_d = 1.0  # 1-day timestep
    root_zone_mm = sconf.zr_m * 1000.0
    
    print(f"Starting simulation: {n_days} days, {ny}x{nx} grid...")
    
    for it in range(n_days):
        if it % 365 == 0:
            print(f"  Year {it//365 + 1}/{tconf.years}...")
        
        # Root-zone water balance, vectorized over the whole grid
        inf = infiltration_from_precip(precip_total[it], theta, soils["porosity"],
                                       soils["Ksat_m_d"], sconf.zr_m)
        et = evapotranspiration(et0[it], theta, soils["theta_fc"], soils["theta_wp"])
        theta = np.clip(theta + (inf - et) / root_zone_mm * dt_d, 0.0, soils["porosity"])
        
        # Percolation leaves the root zone (keeping the bucket in mass balance)
        # and becomes groundwater recharge
        perc = percolation(theta, soils["theta_fc"], soils["perc_tau_d"], sconf.zr_m)
        theta = theta - perc / root_zone_mm * dt_d
        
        # Groundwater update
        wt_depth = groundwater_step(
            wt_depth, dem_m, perc, soils["Sy"], soils["tau_gw_d"],
            dist_coast_m, dist_river_m, sea_level[it], river_stage[it],
            gconf.tau_coast_d, gconf.Lcoast_m, gconf.Lriver_m, dt_d
        )
        
        # Water table elevation change relative to the initial state
        wt_elev = dem_m - wt_depth
        wt_change[it] = wt_elev - wt_elev_init
        
        # Above-ground saturation (ponding depth)
        above_sat_change[it] = np.maximum(0.0, wt_elev - dem_m)
    print("  Simulation complete!")
    
    # Create xarray Dataset
    x = np.arange(nx) * dom.dx_m
    y = np.arange(ny) * dom.dy_m
    
    ds = xr.Dataset(
        {
            "water_table_elev_change": (["time", "y", "x"], wt_change),
            "above_saturation_change": (["time", "y", "x"], above_sat_change),
            "dem_m": (["y", "x"], dem_m),
            "dist_coast_m": (["y", "x"], dist_coast_m),
            "dist_river_m": (["y", "x"], dist_river_m),
            "coast_mask": (["y", "x"], coast_mask.astype(np.int8)),
            "tidal_zone": (["y", "x"], tidal_zone.astype(np.int8)),
            "river_mask": (["y", "x"], river_mask.astype(np.int8)),
            "precip_total_mm_d": (["time"], precip_total),
            "et0_mm_d": (["time"], et0),
            "sea_level_boundary_m": (["time"], sea_level),
            "river_stage_m": (["time"], river_stage),
        },
        coords={
            "time": time_index,
            "y": y,
            "x": x,
        },
    )
    
    return ds


def save_dataset(ds: xr.Dataset, path: str) -> None:
    """Save dataset to NetCDF"""
    ds.to_netcdf(path)
    print(f"Dataset saved to {path}")
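A useful check on the root-zone bucket is mass balance: over any period, the change in stored water should equal infiltration minus ET minus percolation, which only holds if percolation is removed from the store after infiltration and ET. A single-cell sketch of that bookkeeping, using assumed soil values rather than the model's random fields:

```python
import numpy as np

# Illustrative single-cell parameters (assumed values, not the model's random fields)
zr_m, porosity, theta_fc, theta_wp = 1.2, 0.45, 0.30, 0.12
ksat_m_d, perc_tau_d = 0.5, 6.0
root_zone_mm = zr_m * 1000.0

rng = np.random.default_rng(0)
precip = rng.exponential(4.0, size=365)   # mm/day
et0 = np.full(365, 2.0)                   # mm/day

theta = theta_fc
storage0 = theta * root_zone_mm           # mm
tot_inf = tot_et = tot_perc = 0.0
for p, e0 in zip(precip, et0):
    # Infiltration limited by rain, Ksat, and the unfilled pore space
    deficit_mm = max(0.0, porosity - theta) * root_zone_mm
    inf = min(ksat_m_d * 1000.0, p, deficit_mm)
    # ET scaled linearly between wilting point and field capacity
    et = e0 * min(max((theta - theta_wp) / (theta_fc - theta_wp), 0.0), 1.0)
    theta = min(theta + (inf - et) / root_zone_mm, porosity)
    # Percolation is taken out of the store, so the bucket stays in balance
    perc = max(0.0, theta - theta_fc) * root_zone_mm / perc_tau_d
    theta -= perc / root_zone_mm
    tot_inf += inf
    tot_et += et
    tot_perc += perc

residual_mm = theta * root_zone_mm - storage0 - (tot_inf - tot_et - tot_perc)
print(f"infiltration {tot_inf:.0f} mm, ET {tot_et:.0f} mm, percolation {tot_perc:.0f} mm")
print(f"mass-balance residual {residual_mm:.2e} mm")
```

If the percolation line that drains `theta` is dropped, the residual grows by the full percolation total, which is the signature of a bucket that recharges groundwater without losing water.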

In [None]:
import time
from pathlib import Path

def animate_simulation(
    ds: xr.Dataset,
    frame_step: int = 7,
    sleep_time: float = 0.3,
    year_subset: Optional[int] = None,
    save_pngs: bool = True,
    output_dir: str = "groundwater_animation",
) -> None:
    """
    Animate groundwater dynamics with slow-motion playback.
    Optionally saves frames as PNG and creates an animated GIF.
    
    Args:
        ds: xarray Dataset from run_simulation
        frame_step: Plot every N days (7=weekly, 30=monthly)
        sleep_time: Pause between frames (seconds)
        year_subset: Restrict to specific year (1=first year, etc.)
        save_pngs: Whether to save individual PNG frames
        output_dir: Directory to store outputs
    """
    import imageio.v2 as imageio  # v2 API: imread/mimsave without deprecation warnings
    
    # Create output directories
    out_path = Path(output_dir)
    png_dir = out_path / "frames"
    png_dir.mkdir(parents=True, exist_ok=True)
    
    time_vals = ds.time.values
    nt = len(time_vals)
    
    # Optional: subset to a specific (365-day) year
    if year_subset is not None:
        start_idx = (year_subset - 1) * 365
        end_idx = min(year_subset * 365, nt)
        time_indices = range(start_idx, end_idx, frame_step)
    else:
        time_indices = range(0, nt, frame_step)
    if len(time_indices) == 0:
        raise ValueError(f"No frames to draw: year_subset={year_subset} is outside the {nt}-day record")
    
    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)
    
    ax_dem = fig.add_subplot(gs[0, 0])
    ax_wt = fig.add_subplot(gs[0, 1])
    ax_sat = fig.add_subplot(gs[0, 2])
    ax_coastal = fig.add_subplot(gs[1, :])
    ax_precip = ax_coastal.twinx()  # create once; a new twinx() per frame stacks extra axes
    
    # Static maps (DEM, coast, river)
    dem = ds["dem_m"].values
    coast_mask = (ds["dist_coast_m"].values < 500).astype(int)
    river_mask = (ds["dist_river_m"].values < 500).astype(int)
    
    # Fixed color limits so frames are comparable
    vmin_wt, vmax_wt = -15, 15
    vmin_sat, vmax_sat = 0, 0.15
    
    wt_all = ds["water_table_elev_change"]
    sat_all = ds["above_saturation_change"]
    first = time_indices[0]
    
    # Panel 1: DEM + overlays (drawn once)
    im_dem = ax_dem.imshow(dem, origin='lower', cmap='terrain', alpha=0.7)
    ax_dem.imshow(coast_mask, origin='lower', cmap='Blues', alpha=0.3)
    ax_dem.imshow(river_mask, origin='lower', cmap='Purples', alpha=0.3)
    ax_dem.set_title('DEM & Bay Geometry')
    fig.colorbar(im_dem, ax=ax_dem, fraction=0.046).set_label('Elevation (m)')
    
    # Panels 2-3: image artists and colorbars are created once and updated with set_data();
    # calling plt.colorbar() every frame would add a new colorbar and shrink the panels
    im_wt = ax_wt.imshow(wt_all.isel(time=first).values, origin='lower', cmap='RdBu_r',
                         vmin=vmin_wt, vmax=vmax_wt)
    fig.colorbar(im_wt, ax=ax_wt, fraction=0.046).set_label('ΔWT (m)')
    
    im_sat = ax_sat.imshow(sat_all.isel(time=first).values, origin='lower', cmap='Greens',
                           vmin=vmin_sat, vmax=vmax_sat)
    ax_sat.set_title('Saturation Excess / Ponding (m)')
    fig.colorbar(im_sat, ax=ax_sat, fraction=0.046).set_label('Excess (m)')
    
    # Spatial tick labels in km
    ny_pix, nx_pix = dem.shape
    for ax in (ax_dem, ax_wt, ax_sat):
        ax.set_xticks(np.linspace(0, nx_pix - 1, 3))
        ax.set_xticklabels([f'{v:.1f}' for v in np.linspace(0, ds.x.values[-1] / 1000, 3)])
        ax.set_yticks(np.linspace(0, ny_pix - 1, 3))
        ax.set_yticklabels([f'{v:.1f}' for v in np.linspace(0, ds.y.values[-1] / 1000, 3)])
        ax.set_xlabel('East (km)')
        ax.set_ylabel('North (km)')
    
    # Panel 4: full forcing record drawn once; a dashed marker tracks the current day
    sea_level = ds["sea_level_boundary_m"].values
    precip_total = ds["precip_total_mm_d"].values
    time_days = np.arange(nt)
    
    ax_precip.bar(time_days, precip_total, width=1, alpha=0.3, color='gray', label='Precip (mm/day)')
    ax_precip.set_ylabel('Precipitation (mm/day)', color='gray')
    ax_precip.tick_params(axis='y', labelcolor='gray')
    
    ax_coastal.plot(time_days, sea_level, 'b-', linewidth=1.5, label='Sea Level (m)')
    marker = ax_coastal.axvline(first, color='red', linestyle='--', alpha=0.5, label='Current time')
    ax_coastal.set_ylabel('Sea Level Boundary (m)', color='b')
    ax_coastal.tick_params(axis='y', labelcolor='b')
    ax_coastal.set_xlim(0, nt)
    ax_coastal.set_xlabel('Days since start')
    ax_coastal.grid(True, alpha=0.3)
    ax_coastal.legend(loc='upper left', fontsize=9)
    ax_precip.legend(loc='upper right', fontsize=9)
    
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    
    png_files = []
    
    # Iterate through time, updating only the artists that change
    for idx, tidx in enumerate(time_indices):
        t_str = pd.Timestamp(time_vals[tidx]).strftime('%Y-%m-%d')
        
        im_wt.set_data(wt_all.isel(time=tidx).values)
        im_sat.set_data(sat_all.isel(time=tidx).values)
        ax_wt.set_title(f'Water Table Change (m)\n{t_str}')
        marker.set_xdata([tidx, tidx])
        
        # Overall title with frame info (suptitle replaces the previous one)
        fig.suptitle(f'Bay Groundwater Dynamics | {t_str} | Frame {idx+1}/{len(time_indices)}',
                     fontsize=14, fontweight='bold')
        
        # Save PNG if requested
        if save_pngs:
            png_path = png_dir / f"frame_{idx:04d}_{t_str}.png"
            # No bbox_inches='tight': GIF frames must all have the same pixel size
            fig.savefig(png_path, dpi=100, facecolor='white')
            png_files.append(str(png_path))
            print(f"  Frame {idx+1:4d}/{len(time_indices)} saved: {png_path.name}")
        
        plt.pause(sleep_time)
    
    # Create animated GIF
    if save_pngs and png_files:
        print(f"\nCreating animated GIF from {len(png_files)} frames...")
        gif_path = out_path / "groundwater_dynamics.gif"
        images = [imageio.imread(png_file) for png_file in png_files]
        
        # Recent imageio releases read GIF `duration` as milliseconds per frame
        # (older releases used seconds)
        imageio.mimsave(gif_path, images, duration=300)  # 300 ms per frame
        print(f"✓ GIF created: {gif_path}")
        print(f"  Total duration: ~{len(images) * 0.3:.1f} seconds")
    
    print("\nAnimation complete!")
    print(f"  Total frames rendered: {len(time_indices)}")
    print(f"  Output directory: {out_path.absolute()}")
    plt.show()



In [17]:

# ---
# RUN SIMULATION
# ---

print(f"Running {TimeConfig().years}-year bay groundwater simulation...")
print("=" * 70)

ds = run_simulation()

# Summary and NetCDF export (these need `ds`, so they run after the simulation)
dx_m = float(ds.x[1] - ds.x[0])
dy_m = float(ds.y[1] - ds.y[0])
print("\nDataset created successfully!")
print(f"Time steps: {len(ds.time)} days")
print(f"Domain: {ds.x.size} x {ds.y.size} cells ({ds.x.size*dx_m:.0f} m x {ds.y.size*dy_m:.0f} m)")
print("\nVariable ranges:")
print(f"  DEM: {ds['dem_m'].min().values:.1f} to {ds['dem_m'].max().values:.1f} m")
print(f"  WT change: {ds['water_table_elev_change'].min().values:.2f} to {ds['water_table_elev_change'].max().values:.2f} m")
print(f"  Sea level range: {ds['sea_level_boundary_m'].min().values:.3f} to {ds['sea_level_boundary_m'].max().values:.3f} m")
print("\nSaving to 'pugetsound_synth_wt_sat.nc'...")
save_dataset(ds, "pugetsound_synth_wt_sat.nc")
print("Done!\n")

Running 3-year bay groundwater simulation...
Starting simulation: 1095 days, 50x100 grid...
  Year 1/3...
  Year 2/3...
  Year 3/3...
  Simulation complete!


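The animation invites you to watch tides propagate inland; a quantitative counterpart is the lag at which a cell's water-table series best correlates with the sea-level boundary. A sketch using plain NumPy correlation, assuming daily series such as `ds['sea_level_boundary_m'].values` and `ds['water_table_elev_change'].isel(y=..., x=...).values`; `best_lag_days` is a hypothetical helper and `max_lag` is an arbitrary choice:

```python
import numpy as np

def best_lag_days(forcing: np.ndarray, response: np.ndarray, max_lag: int = 10) -> int:
    """Lag in [0, max_lag] days that maximizes the Pearson correlation
    between forcing[t] and response[t + lag]."""
    best_lag, best_corr = 0, -np.inf
    for lag in range(max_lag + 1):
        a = forcing[: len(forcing) - lag]
        b = response[lag:]
        corr = np.corrcoef(a, b)[0, 1]
        if corr > best_corr:
            best_lag, best_corr = lag, corr
    return best_lag

# Synthetic check: a response that trails a 14.77-day oscillation by 4 days
t = np.arange(365)
sea = 0.9 * np.sin(2 * np.pi * t / 14.77)
wt = np.roll(sea, 4) + 0.05 * np.random.default_rng(1).normal(size=t.size)
print(best_lag_days(sea, wt, max_lag=10))
```

Because the spring-neap signal is periodic, keep `max_lag` below the ~15-day period; otherwise a lag one full cycle later can score almost as well and the answer becomes ambiguous. Repeating the calculation for cells at increasing `dist_coast_m` should show the lag growing inland if the coastal signal is indeed propagating.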

In [18]:

# ---
# ANIMATE WITH SLEEP MODE & SAVE FRAMES
# ---

print("=" * 70)
print("Starting animated visualization...")
print("Saving frames as PNG and creating GIF...")
print("=" * 70)

animate_simulation(
    ds, 
    frame_step=7,  # Weekly
    sleep_time=0.2,  # Faster playback
    year_subset=None,
    save_pngs=True,
    output_dir="groundwater_animation"
)

Starting animated visualization...
Saving frames as PNG and creating GIF...


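`animate_simulation` renders each frame to PNG and stitches the files together with imageio. Matplotlib can also write a GIF directly with `matplotlib.animation.FuncAnimation` and `PillowWriter`, which skips the intermediate files. A minimal sketch on a random stand-in field (substitute `ds['water_table_elev_change'].values` for a real run); the output filename `wt_sketch.gif` is arbitrary:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

frames = np.random.default_rng(0).normal(size=(10, 50, 100))  # stand-in (time, y, x) field

fig, ax = plt.subplots(figsize=(6, 3))
im = ax.imshow(frames[0], origin="lower", cmap="RdBu_r", vmin=-3, vmax=3)
fig.colorbar(im, ax=ax, label="ΔWT (m)")  # created once, outside the update function
title = ax.set_title("frame 0")

def update(i):
    im.set_data(frames[i])          # reuse the same image artist instead of re-plotting
    title.set_text(f"frame {i}")
    return im, title

anim = FuncAnimation(fig, update, frames=len(frames), interval=300, blit=False)
anim.save("wt_sketch.gif", writer=PillowWriter(fps=3))
plt.close(fig)
```

Updating one image artist per frame is the same design choice that keeps the colorbars from multiplying in a frame-by-frame loop.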