# CausticWaves - Pool Caustics Simulator

This notebook simulates the beautiful "dancing light" patterns (caustics) seen at the bottom of swimming pools.

## What This Notebook Does

This is the **main working version** that generated the visualization assets in the `Output/` directory:

- `caustics_simulation.mp4` - The video shown in the README
- `caustics_simulation.gif` - Animated preview
- `caustics_preview.png` - Static frame

## How It Works

The simulation combines **spectral synthesis** (FFT-based Tessendorf waves) with **geometric optics** to:

1. **Generate realistic water-surface ripples** with a Fast Fourier Transform and the finite-depth wave dispersion relation
2. **Trace light rays** from the sun through the water surface using Snell's law
3. **Compute caustic patterns** by accumulating the refracted rays on the pool bottom
4. **Apply artistic effects**, including logarithmic tonemapping and a blue water colormap

## Key Parameters (Optimized for Realism)

- **Wavelength**: `lambda0 = 0.45 m` - Dominant pool-scale wavelength (45 cm)
- **Wave height**: `wave_rms = 0.008 m` - RMS surface height of the ripples (~8 mm)
- **Pool depth**: `depth = 1.2 m` - Typical pool depth
- **Sun elevation**: `sun_elev_deg = 62°` - Mid-afternoon lighting
- **Color scheme**: Blue gradient for natural water appearance

## Running This Notebook

Execute the cell below. It will:

1. Build two animations: one with axes and a title (for inspection) and a clean one with no axes
2. Export the clean version to `pool_caustics_clean.mp4` in the current working directory
3. Display the animation with axes (inline in Jupyter, or in a matplotlib window when run as a script)

Inside a notebook the animation runs for 180 frames (6 seconds at 30 fps) at 520×520 px; run as a standalone script, it renders 300 frames (10 seconds at 30 fps) at 700×700 px.

**Note:** The animation is rendered with `matplotlib` and may take a few minutes. Video export requires `ffmpeg` on the `PATH`. `scipy` is optional: it provides the Gaussian blur, and a periodic FFT blur is used if it is missing.
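As a quick check on the key parameters, the sketch below evaluates the finite-depth dispersion relation ω² = g·k·tanh(k·H), which `SpectralRipples` uses, at the dominant wavelength only. It is a standalone calculation (the variable names are illustrative, not taken from the notebook) that gives the approximate period and crest speed of the dominant ripple component.

```python
import numpy as np

g = 9.81          # gravitational acceleration [m/s^2]
depth = 1.2       # pool depth H [m]
lambda0 = 0.45    # dominant wavelength [m]

# Dominant wavenumber and finite-depth dispersion: omega^2 = g k tanh(k H)
k0 = 2.0 * np.pi / lambda0
omega0 = np.sqrt(g * k0 * np.tanh(k0 * depth))

period = 2.0 * np.pi / omega0   # temporal period of the dominant component [s]
phase_speed = omega0 / k0       # crest (phase) speed [m/s]
kH = k0 * depth                 # kH >> 1 means the water is effectively deep

print(f"k0 = {k0:.2f} rad/m, kH = {kH:.1f}")
print(f"period = {period:.3f} s, phase speed = {phase_speed:.3f} m/s")
```

With kH ≈ 17, tanh(kH) is essentially 1, so these ripples behave as deep-water waves, and one period of the dominant component (about 0.54 s) spans roughly 16 frames at 30 fps. Surface tension is not part of this dispersion relation; it only becomes important for centimetre-scale ripples.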

#!/usr/bin/env python3
"""
Pool caustics simulator (geometric optics) — pool-like wavelengths and colors.

This version:
  - Uses lambda0 = 0.45 m (pool-scale waves).
  - Creates TWO animations:
      1) With axes and title (for debugging/inspection).
      2) Clean: no axes, no frame, just the caustic image.
  - Exports the clean animation to pool_caustics_clean.mp4 (requires ffmpeg).

Deps: numpy, matplotlib
Optional: scipy (for gaussian_filter). If missing, periodic FFT blur is used instead.
"""

from __future__ import annotations

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.colors import LinearSegmentedColormap


# ----------------------------
# Environment detection
# ----------------------------

def running_in_notebook() -> bool:
    try:
        from IPython import get_ipython  # type: ignore
        ip = get_ipython()
        if ip is None:
            return False
        return ip.__class__.__name__ != "TerminalInteractiveShell"
    except Exception:
        return False


# ----------------------------
# Optics helpers
# ----------------------------

def sun_direction(elev_deg: float, az_deg: float) -> np.ndarray:
    """
    Unit direction vector of sunlight rays traveling *downward* (from sun to surface).
    elev_deg: sun elevation above horizon (90 = zenith).
    az_deg: azimuth in x-y plane (0 along +x, 90 along +y).
    """
    elev = np.deg2rad(elev_deg)
    az = np.deg2rad(az_deg)
    sx = np.cos(elev) * np.cos(az)
    sy = np.cos(elev) * np.sin(az)
    sz = -np.sin(elev)  # downward
    v = np.array([sx, sy, sz], dtype=np.float64)
    return v / np.linalg.norm(v)


def refract_air_to_water(I: np.ndarray,
                         nx: np.ndarray, ny: np.ndarray, nz: np.ndarray,
                         n_air: float = 1.0, n_water: float = 1.333
                         ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Vector Snell refraction for many surface normals at once.

    I: 3-vector incident direction (constant), rays pointing downward.
    N: (nx,ny,nz) unit normals (upward).

    Returns (tx,ty,tz,cosi), where cosi = cos(theta_i) = -dot(N,I) clipped to [0,1].
    """
    sx, sy, sz = I
    cosi = -(nx * sx + ny * sy + nz * sz)
    cosi = np.clip(cosi, 0.0, 1.0)

    eta = n_air / n_water
    k = 1.0 - eta * eta * (1.0 - cosi * cosi)
    sqrtk = np.sqrt(np.maximum(k, 0.0))

    c2 = eta * cosi - sqrtk
    tx = eta * sx + c2 * nx
    ty = eta * sy + c2 * ny
    tz = eta * sz + c2 * nz
    return tx, ty, tz, cosi


def fresnel_transmission_schlick(cosi: np.ndarray, n1: float, n2: float) -> np.ndarray:
    """
    Unpolarized Fresnel transmission approximation: T ≈ 1 - R, using Schlick for R.
    """
    r0 = ((n1 - n2) / (n1 + n2)) ** 2
    R = r0 + (1.0 - r0) * (1.0 - cosi) ** 5
    return 1.0 - R


# ----------------------------
# Spectral ripple surface (band-limited, pool-like)
# ----------------------------

class SpectralRipples:
    """
    Band-limited ripple field using Tessendorf-style time evolution in Fourier space:

        H(k,t) = h0(k) e^{ i ω t } + h0*(−k) e^{ −i ω t }

    Real for all t. Spectrum is a narrow Gaussian around k0 = 2π/lambda0.
    """

    def __init__(self,
                 Nx: int = 256, Ny: int = 256,
                 Lx: float = 2.0, Ly: float = 2.0,
                 depth: float = 1.2,
                 target_rms: float = 0.008,
                 lambda0: float = 0.45,
                 k_sigma_frac: float = 0.16,
                 direction_deg: float | None = None,
                 dir_spread_deg: float = 90.0,
                 seed: int = 0):
        self.Nx, self.Ny = Nx, Ny
        self.Lx, self.Ly = Lx, Ly
        self.depth = depth
        self.dx = Lx / Nx
        self.dy = Ly / Ny
        self.g = 9.81

        rng = np.random.default_rng(seed)

        # Wavevectors on the FFT grid
        kx = 2.0 * np.pi * np.fft.fftfreq(Nx, d=self.dx)
        ky = 2.0 * np.pi * np.fft.fftfreq(Ny, d=self.dy)
        KX, KY = np.meshgrid(kx, ky)  # (Ny,Nx)
        k = np.sqrt(KX**2 + KY**2)

        # Dispersion (finite depth): ω^2 = g k tanh(kH)
        omega = np.sqrt(self.g * k * np.tanh(k * depth))
        omega[k == 0.0] = 0.0
        self.omega = omega

        # Dominant wavenumber and spectral width in k-space
        k0 = 2.0 * np.pi / max(lambda0, 1e-6)
        sigma_k = max(k_sigma_frac, 1e-6) * k0

        # Narrow band around k0
        S = np.exp(-0.5 * ((k - k0) / sigma_k) ** 2)

        # High-k suppression
        k_cut = 3.0 * k0
        S *= np.exp(-(k / max(k_cut, 1e-9)) ** 4)

        # Optional directionality (wind / push)
        if direction_deg is not None:
            theta = np.arctan2(KY, KX)
            theta0 = np.deg2rad(direction_deg)
            dtheta = np.angle(np.exp(1j * (theta - theta0)))  # wrapped diff
            sigma_th = np.deg2rad(max(dir_spread_deg, 1e-3))
            D = np.exp(-0.5 * (dtheta / sigma_th) ** 2)
            S *= D

        S[k == 0.0] = 0.0

        # Random complex Gaussian h0(k) with variance proportional to S(k)
        h0 = (rng.standard_normal((Ny, Nx)) + 1j * rng.standard_normal((Ny, Nx))) * np.sqrt(S / 2.0)

        # Build h0*(−k) lookup via index map
        iy_neg = (-np.arange(Ny)) % Ny
        ix_neg = (-np.arange(Nx)) % Nx
        h0_neg = h0[iy_neg[:, None], ix_neg[None, :]]
        h0_star_neg = np.conj(h0_neg)

        # Scale to requested RMS height at t=0
        H_hat_0 = h0 + h0_star_neg
        h_spatial_0 = np.fft.ifft2(H_hat_0).real
        rms = h_spatial_0.std()
        scale = (target_rms / rms) if rms > 0 else 1.0

        self.h0 = h0 * scale
        self.h0_star_neg = h0_star_neg * scale

        # Spatial grid
        x = np.linspace(0.0, Lx, Nx, endpoint=False)
        y = np.linspace(0.0, Ly, Ny, endpoint=False)
        self.X, self.Y = np.meshgrid(x, y)

    def height(self, t: float) -> np.ndarray:
        exp_pos = np.exp(1j * self.omega * t)
        exp_neg = np.conj(exp_pos)
        H_hat = self.h0 * exp_pos + self.h0_star_neg * exp_neg
        return np.fft.ifft2(H_hat).real


# ----------------------------
# Rendering utilities
# ----------------------------

def gaussian_blur_fft(img: np.ndarray, sigma_px: float) -> np.ndarray:
    """Periodic Gaussian blur via FFT (fallback when SciPy is unavailable)."""
    H, W = img.shape
    ky = np.fft.fftfreq(H) * 2.0 * np.pi
    kx = np.fft.fftfreq(W) * 2.0 * np.pi
    KX, KY = np.meshgrid(kx, ky)
    G = np.exp(-0.5 * (sigma_px**2) * (KX*KX + KY*KY))
    return np.fft.ifft2(np.fft.fft2(img) * G).real


def splat_bilinear(img: np.ndarray,
                   x: np.ndarray, y: np.ndarray, w: np.ndarray,
                   Lx: float, Ly: float,
                   wrap: bool) -> None:
    """
    Deposit weighted samples (x,y,w) into img using bilinear splatting.
    img shape: (H,W) corresponds to (y,x).

    wrap=True implements periodic boundary conditions.
    """
    H, W = img.shape

    # Map physical coords to pixel coords in [0,W) and [0,H)
    u = (x / Lx) * W
    v = (y / Ly) * H

    i0 = np.floor(u).astype(np.int64)
    j0 = np.floor(v).astype(np.int64)
    fu = u - i0
    fv = v - j0

    i1 = i0 + 1
    j1 = j0 + 1

    if wrap:
        i0 %= W; i1 %= W
        j0 %= H; j1 %= H
        ww = w
        fuu = fu
        fvv = fv
    else:
        m = (i0 >= 0) & (i1 < W) & (j0 >= 0) & (j1 < H)
        if not np.any(m):
            return
        i0 = i0[m]; i1 = i1[m]
        j0 = j0[m]; j1 = j1[m]
        fuu = fu[m]; fvv = fv[m]
        ww = w[m]

    w00 = ww * (1.0 - fuu) * (1.0 - fvv)
    w10 = ww * (fuu) * (1.0 - fvv)
    w01 = ww * (1.0 - fuu) * (fvv)
    w11 = ww * (fuu) * (fvv)

    np.add.at(img, (j0, i0), w00)
    np.add.at(img, (j0, i1), w10)
    np.add.at(img, (j1, i0), w01)
    np.add.at(img, (j1, i1), w11)


def caustics_frame(surface: SpectralRipples,
                   t: float,
                   bottom_res: int = 520,
                   sun_elev_deg: float = 62.0,
                   sun_az_deg: float = 25.0,
                   n_air: float = 1.0,
                   n_water: float = 1.333,
                   exposure: float = 70.0,
                   gamma: float = 0.80,
                   wrap: bool = True,
                   blur_sigma_px: float | None = 1.5
                   ) -> np.ndarray:
    """
    Compute a tonemapped caustic image on the pool floor for time t.
    Returns float image in [0,1].
    """
    h = surface.height(t)
    X, Y = surface.X, surface.Y
    dx, dy = surface.dx, surface.dy
    Lx, Ly = surface.Lx, surface.Ly
    Hdepth = surface.depth

    # Surface slopes
    hy, hx = np.gradient(h, dy, dx, edge_order=2)

    # Unit normal (upward)
    denom = np.sqrt(1.0 + hx*hx + hy*hy)
    nx = -hx / denom
    ny = -hy / denom
    nz =  1.0 / denom

    I = sun_direction(sun_elev_deg, sun_az_deg)

    # Refract
    tx, ty, tz, cosi = refract_air_to_water(I, nx, ny, nz, n_air=n_air, n_water=n_water)

    # Downward-going rays in water
    mask = tz < -1e-8
    if not np.any(mask):
        return np.zeros((bottom_res, bottom_res), dtype=np.float64)

    # Intersect bottom plane z = -Hdepth
    s = (-Hdepth - h) / tz
    bx = X + s * tx
    by = Y + s * ty

    if wrap:
        bx = np.mod(bx, Lx)
        by = np.mod(by, Ly)

    # Ray power weights: Fresnel transmission times the flux intercepted by each
    # surface cell (cos(theta_i) / n_z per unit horizontal area)
    T = fresnel_transmission_schlick(cosi, n_air, n_water)
    w = T * cosi / np.maximum(nz, 1e-8)

    bx1 = bx[mask].ravel()
    by1 = by[mask].ravel()
    w1  = w[mask].ravel()

    # Accumulate on bottom with bilinear splatting
    img = np.zeros((bottom_res, bottom_res), dtype=np.float64)
    splat_bilinear(img, bx1, by1, w1, Lx=Lx, Ly=Ly, wrap=wrap)

    # Optional blur: finite sun disk / regularization
    if blur_sigma_px is not None and blur_sigma_px > 0:
        try:
            from scipy.ndimage import gaussian_filter  # type: ignore
            img = gaussian_filter(img, blur_sigma_px, mode="wrap" if wrap else "nearest")
        except ImportError:
            # SciPy unavailable: fall back to the periodic FFT blur
            img = gaussian_blur_fft(img, blur_sigma_px)

    # Tonemap to [0,1]
    img = img / (img.mean() + 1e-12)
    img = np.log1p(exposure * img)
    img = img / (img.max() + 1e-12)
    img = img ** gamma
    return img


def make_water_cmap() -> LinearSegmentedColormap:
    """
    Create a simple water-like colormap:
      dark blue -> deep aqua -> turquoise -> pale cyan -> white.
    """
    colors = [
        (0.0,  (0.01, 0.05, 0.12)),   # very dark blue
        (0.25, (0.02, 0.20, 0.40)),   # deep blue
        (0.50, (0.00, 0.45, 0.70)),   # turquoise
        (0.75, (0.60, 0.90, 0.95)),   # pale cyan
        (1.0,  (1.00, 1.00, 1.00)),   # white highlights
    ]
    return LinearSegmentedColormap.from_list("pool_water", colors)


def display_animation(fig, ani) -> None:
    if running_in_notebook():
        import matplotlib as mpl
        mpl.rcParams["animation.embed_limit"] = max(
            float(mpl.rcParams.get("animation.embed_limit", 20)), 120.0
        )
        from IPython.display import HTML, display  # type: ignore
        plt.close(fig)
        try:
            html = ani.to_html5_video()
        except Exception:
            html = ani.to_jshtml()
        display(HTML(html))
    else:
        plt.show()


# ----------------------------
# Demo / animation
# ----------------------------

def main():
    # Geometry / discretization
    Nx = Ny = 256
    Lx = Ly = 2.0
    depth = 1.2

    # Output resolution
    if running_in_notebook():
        bottom_res = 520
        frames = 180
    else:
        bottom_res = 700
        frames = 300

    # Sun / optics
    sun_elev_deg = 62.0
    sun_az_deg = 25.0
    n_air = 1.0
    n_water = 1.333

    # Surface: pool-scale ripple spectrum
    lambda0 = 0.45        # dominant wavelength [m]
    k_sigma_frac = 0.16   # spectral width in k-space
    wave_rms = 0.008      # RMS height [m] ~ 8 mm
    direction_deg = None  # or e.g. 0.0 to bias along +x
    dir_spread_deg = 90.0

    # Caustics tonemap / regularization
    exposure = 70.0
    gamma = 0.80
    blur_sigma_px = 1.5

    # Animation
    fps = 30
    dt = 1.0 / fps

    surface = SpectralRipples(
        Nx=Nx, Ny=Ny, Lx=Lx, Ly=Ly, depth=depth,
        target_rms=wave_rms,
        lambda0=lambda0,
        k_sigma_frac=k_sigma_frac,
        direction_deg=direction_deg,
        dir_spread_deg=dir_spread_deg,
        seed=3
    )

    water_cmap = make_water_cmap()

    # ----------------- Animation 1: with axes / title -----------------
    fig1, ax1 = plt.subplots()
    ax1.set_xlabel("x [m]")
    ax1.set_ylabel("y [m]")

    img0 = caustics_frame(
        surface, 0.0,
        bottom_res=bottom_res,
        sun_elev_deg=sun_elev_deg, sun_az_deg=sun_az_deg,
        n_air=n_air, n_water=n_water,
        exposure=exposure, gamma=gamma,
        wrap=True,
        blur_sigma_px=blur_sigma_px
    )

    im1 = ax1.imshow(
        img0,
        origin="lower",
        extent=(0.0, Lx, 0.0, Ly),
        interpolation="nearest",
        cmap=water_cmap,
        vmin=0.0, vmax=1.0
    )
    ax1.set_title("Pool caustics (with axes)")

    def update1(k: int):
        t = k * dt
        img = caustics_frame(
            surface, t,
            bottom_res=bottom_res,
            sun_elev_deg=sun_elev_deg, sun_az_deg=sun_az_deg,
            n_air=n_air, n_water=n_water,
            exposure=exposure, gamma=gamma,
            wrap=True,
            blur_sigma_px=blur_sigma_px
        )
        im1.set_data(img)
        ax1.set_title(
            f"Pool caustics (with axes) | t = {t:7.3f} s  |  λ0={lambda0:.2f} m"
        )
        return (im1,)

    ani1 = FuncAnimation(fig1, update1, frames=frames, interval=1000/fps, blit=False)

    # ----------------- Animation 2: clean image only -----------------
    fig2, ax2 = plt.subplots(figsize=(4, 4), dpi=200)
    ax2.set_axis_off()
    # remove margins, so only image area
    fig2.subplots_adjust(left=0, right=1, bottom=0, top=1)

    im2 = ax2.imshow(
        img0,
        origin="lower",
        extent=(0.0, Lx, 0.0, Ly),
        interpolation="nearest",
        cmap=water_cmap,
        vmin=0.0, vmax=1.0
    )

    def update2(k: int):
        t = k * dt
        img = caustics_frame(
            surface, t,
            bottom_res=bottom_res,
            sun_elev_deg=sun_elev_deg, sun_az_deg=sun_az_deg,
            n_air=n_air, n_water=n_water,
            exposure=exposure, gamma=gamma,
            wrap=True,
            blur_sigma_px=blur_sigma_px
        )
        im2.set_data(img)
        return (im2,)

    ani2 = FuncAnimation(fig2, update2, frames=frames, interval=1000/fps, blit=False)

    # ---- Export clean animation without axes ----
    # Requires ffmpeg to be installed and on PATH.
    from matplotlib.animation import FFMpegWriter
    writer = FFMpegWriter(fps=fps, bitrate=4000)
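    ani2.save("pool_caustics_clean.mp4", writer=writer, dpi=200)
    # Close the export-only figure so it is not also rendered as a static image
    # at the end of the notebook cell (or opened as a window by plt.show()).
    plt.close(fig2)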

    # For interactive use: show the "with axes" version
    display_animation(fig1, ani1)


if __name__ == "__main__":
    main()
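The vector refraction in `refract_air_to_water()` can be sanity-checked in isolation. The sketch below re-implements the same formula for a single ray (the helpers `sun_dir` and `refract` are hypothetical, written only for this check), refracts the default sun (62° elevation, 25° azimuth) through a flat surface, and confirms that n₁ sin θᵢ = n₂ sin θₜ holds. It also computes how far the light drifts sideways on its way to the 1.2 m floor, which is roughly how far the caustic pattern is displaced from the ripples that create it.

```python
import numpy as np


def sun_dir(elev_deg: float, az_deg: float) -> np.ndarray:
    """Hypothetical helper: unit vector of downward sunlight (same convention as sun_direction())."""
    elev, az = np.deg2rad(elev_deg), np.deg2rad(az_deg)
    return np.array([np.cos(elev) * np.cos(az),
                     np.cos(elev) * np.sin(az),
                     -np.sin(elev)])


def refract(I: np.ndarray, N: np.ndarray, n1: float = 1.0, n2: float = 1.333) -> np.ndarray:
    """Hypothetical helper: vector Snell refraction for one ray, mirroring refract_air_to_water()."""
    cosi = -np.dot(N, I)
    eta = n1 / n2
    k = 1.0 - eta**2 * (1.0 - cosi**2)  # always >= 0 when going from air into water
    return eta * I + (eta * cosi - np.sqrt(k)) * N


n1, n2 = 1.0, 1.333
I = sun_dir(62.0, 25.0)
N = np.array([0.0, 0.0, 1.0])   # flat surface, upward normal
T = refract(I, N, n1, n2)

# Snell's law with angles measured from the normal: n1 sin(theta_i) = n2 sin(theta_t)
sin_i = np.sqrt(1.0 - np.dot(N, I) ** 2)
sin_t = np.sqrt(1.0 - np.dot(N, T) ** 2)
snell_residual = n1 * sin_i - n2 * sin_t

# Horizontal drift of the refracted ray between the surface (z = 0) and the floor
depth = 1.2
s = -depth / T[2]                  # path length to z = -depth (T[2] < 0)
drift = s * np.hypot(T[0], T[1])   # [m]

print(f"theta_t = {np.degrees(np.arcsin(sin_t)):.2f} deg, drift = {drift:.3f} m")
```

For a flat surface the refracted ray leaves the vertical at about 20.6°, so the light lands about 0.45 m downstream of where it entered. Because `caustics_frame()` wraps floor hits onto the periodic 2 m tile, this shift does not lose any rays.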

---

## Output Files

After running the cell above, you'll find:

- **Video**: `pool_caustics_clean.mp4` (clean version without axes), written to the current working directory
- **Preview animation**: the version with axes, shown inline in Jupyter (or in a matplotlib window when run as a script)

## Understanding the Code

The code above contains:

1. **Optics helpers**:
   - `sun_direction()` - Computes the unit direction vector of the incoming (downward) sunlight
   - `refract_air_to_water()` - Applies the vector form of Snell's law to every surface normal at once
   - `fresnel_transmission_schlick()` - Approximates the transmitted fraction of light with Schlick's approximation
2. **`SpectralRipples` class**:
   - Generates the water surface by FFT synthesis
   - Implements the finite-depth dispersion relation ω² = g·k·tanh(k·H)
   - Ensures real-valued height fields via Hermitian symmetry
3. **Caustics rendering**:
   - `caustics_frame()` - Main pipeline for a single frame
   - Computes surface normals, refracts rays, and accumulates their floor hits with bilinear splatting (`splat_bilinear()`)
   - Applies an optional Gaussian blur and logarithmic tonemapping; the blue palette is applied at display time via `make_water_cmap()`
4. **Animation setup** (`main()`):
   - Creates two separate figures, one with axes and one without
   - Exports the clean animation with `FFMpegWriter`
   - Animates at 30 fps

## Next Steps

- To experiment with different parameters, edit the values at the top of `main()`
- To change the wavelength, adjust `lambda0` (try values between 0.2 and 0.6 m)
- To change colors, edit the color stops in `make_water_cmap()` (a custom `pool_water` colormap) or pass a different `cmap` to `imshow`
- See `simulation.py` for a standalone version with detailed comments

## Scientific Background

This simulation implements:

- **Tessendorf FFT waves** (SIGGRAPH 2001) for realistic water surfaces
- **Geometric optics** with Snell's law and Schlick's approximation to the Fresnel equations
- **Accumulation-buffer rendering** for caustic intensity computation
- **Logarithmic tonemapping** for high-dynamic-range visualization
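The Hermitian-symmetry point can be demonstrated on a tiny grid. This minimal sketch builds `H(k,t) = h0(k) e^{iωt} + h0*(−k) e^{−iωt}` the same way `SpectralRipples` does, including the `(-np.arange(N)) % N` index map for −k, and checks that its inverse FFT is real. The 8×8 grid, seed, and evaluation time are arbitrary choices for illustration, and `h0` is left unshaped (white) because the spectrum's shape does not affect the symmetry.

```python
import numpy as np

# Tiny periodic grid; sizes, seed, and time are arbitrary illustration values.
N = 8
L = 2.0        # tile size [m]
depth = 1.2    # water depth [m]
t = 0.37       # evaluation time [s]

k1d = 2.0 * np.pi * np.fft.fftfreq(N, d=L / N)
KX, KY = np.meshgrid(k1d, k1d)
kmag = np.hypot(KX, KY)
omega = np.sqrt(9.81 * kmag * np.tanh(kmag * depth))

rng = np.random.default_rng(0)
h0 = rng.standard_normal((N, N)) + 1j * rng.standard_normal((N, N))

# Index map i -> (-i) mod N gives the FFT bin of -k.
neg = (-np.arange(N)) % N
h0_star_neg = np.conj(h0[neg[:, None], neg[None, :]])

# Tessendorf-style spectrum at time t.
H = h0 * np.exp(1j * omega * t) + h0_star_neg * np.exp(-1j * omega * t)

# Hermitian symmetry H(-k) = conj(H(k)) is what makes the inverse FFT real.
herm_err = np.abs(H[neg[:, None], neg[None, :]] - np.conj(H)).max()
max_imag = np.abs(np.fft.ifft2(H).imag).max()

# Without the conjugate partner term, the height field comes out complex.
H_one_sided = h0 * np.exp(1j * omega * t)
max_imag_one_sided = np.abs(np.fft.ifft2(H_one_sided).imag).max()

print(f"Hermitian error: {herm_err:.1e}")
print(f"max |imag| with conjugate term:    {max_imag:.1e}")
print(f"max |imag| without conjugate term: {max_imag_one_sided:.1e}")
```

With the conjugate term the imaginary part is pure round-off, which is why `SpectralRipples.height()` can safely keep only `.real`; the one-sided spectrum shows what that `.real` would otherwise be discarding.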