# Generate Documentation Figures

This notebook generates all SVG figures for the janssen documentation guides.

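Run all cells (Restart & Run All) to regenerate every figure. To regenerate only one section, run the setup cell first and then that section's cells in order, since some cells reuse names defined earlier in their section. Figures are written as SVG files to `SAVE_DIR` (the current working directory by default).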

In [None]:
import os
from pathlib import Path

# Must be set before JAX is imported so figure generation never tries to
# initialise a GPU/TPU backend.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Figure settings
FIGSIZE = (8, 6)
FIGSIZE_WIDE = (10, 6)
FIGSIZE_TALL = (8, 8)
SAVE_DIR = Path(".")  # Save to current directory

# Seed the legacy global RNG so any cell drawing from ``np.random`` produces
# the same figure on every Restart & Run All.
SEED = 42
np.random.seed(SEED)

plt.style.use("default")
plt.rcParams["figure.dpi"] = 100
plt.rcParams["savefig.dpi"] = 150


def save_fig(name: str, fig=None) -> None:
    """Save a figure as SVG under ``SAVE_DIR`` and close it.

    Parameters
    ----------
    name : str
        Output file name, relative to ``SAVE_DIR``. A name already ending in
        ``.svg`` is used as-is; otherwise its suffix is replaced by ``.svg``
        (``"foo.png"`` -> ``"foo.svg"``, ``"foo"`` -> ``"foo.svg"``). Only the
        final path component's suffix is touched, so dots in directory names
        are preserved.
    fig : matplotlib.figure.Figure, optional
        Figure to save. Defaults to the current pyplot figure.
    """
    out_name = name if name.endswith(".svg") else str(Path(name).with_suffix(".svg"))
    out_path = SAVE_DIR / out_name
    # Create the target directory if SAVE_DIR (or a sub-path in ``name``)
    # does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    target = plt.gcf() if fig is None else fig
    target.savefig(
        out_path,
        format="svg",
        bbox_inches="tight",
        facecolor="white",
        edgecolor="none",
    )
    plt.close(target)
    print(f"Saved: {out_name}")

## 1. Coherence Figures

Figures for the `coherence.md` guide.

In [None]:
# Figure 1.1: Spatial coherence kernel
from janssen.coherence import gaussian_coherence_kernel, jinc_coherence_kernel

dx = 1e-6
grid_size = (128, 128)
wavelength = 532e-9

# Jinc (circular-source) kernel parameters. For a uniform circular source of
# diameter D seen from distance z, the van Cittert-Zernike theorem puts the
# first zero of |mu| at about 1.22 * wavelength * z / D. The previous
# z = 0.1 m gives ~1.3 mm -- far beyond this +/-64 um grid -- so the jinc
# curve plotted as an almost flat line near 1. Instead, choose z so the first
# zero lands at JINC_FIRST_ZERO, comparable to the Gaussian coherence width.
# NOTE(review): assumes jinc_coherence_kernel implements the standard
# van Cittert-Zernike jinc; confirm against janssen.coherence.
SOURCE_DIAMETER = 50e-6
JINC_FIRST_ZERO = 25e-6
jinc_distance = JINC_FIRST_ZERO * SOURCE_DIAMETER / (1.22 * wavelength)  # ~1.9 mm

# Generate kernels
gaussian_kernel = gaussian_coherence_kernel(
    coherence_width=20e-6, dx=dx, grid_size=grid_size
)
jinc_kernel = jinc_coherence_kernel(
    source_diameter=SOURCE_DIAMETER,
    propagation_distance=jinc_distance,
    wavelength=wavelength,
    dx=dx,
    grid_size=grid_size,
)

# Plot 1D cross-section through the central row (offsets in micrometres)
fig, ax = plt.subplots(figsize=FIGSIZE)
x = (np.arange(grid_size[1]) - grid_size[1] // 2) * dx * 1e6
center = grid_size[0] // 2

ax.plot(x, np.abs(gaussian_kernel[center, :]), "b-", lw=2, label="Gaussian")
ax.plot(x, np.abs(jinc_kernel[center, :]), "r--", lw=2, label="Jinc (circular)")
ax.set_xlabel(r"$\Delta r$ ($\mu$m)", fontsize=12)
ax.set_ylabel(r"$|\mu(\Delta r)|$", fontsize=12)
ax.set_title("Spatial Coherence Kernels", fontsize=14)
ax.legend(fontsize=11)
ax.set_xlim(-60, 60)
ax.grid(True, alpha=0.3)
save_fig("spatial_coherence_kernel.svg")

In [None]:
# Figure 1.2: Temporal coherence and spectra
from janssen.coherence import gaussian_spectrum, lorentzian_spectrum, coherence_length

center_wl = 532e-9
bandwidths = [1e-9, 5e-9, 20e-9]  # 1nm, 5nm, 20nm FWHM
colors = ["blue", "green", "red"]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=FIGSIZE_WIDE)

# Left panel: Gaussian spectra of increasing bandwidth, each scaled to unit peak.
for bw, color in zip(bandwidths, colors):
    wls, weights = gaussian_spectrum(
        center_wavelength=center_wl, bandwidth_fwhm=bw, num_wavelengths=101
    )
    wl_nm = np.array(wls) * 1e9
    relative_weight = np.array(weights) / np.max(weights)
    ax1.plot(wl_nm, relative_weight, color=color, lw=2, label=f"{bw*1e9:.0f} nm FWHM")

ax1.set_xlabel("Wavelength (nm)", fontsize=12)
ax1.set_ylabel("Normalized S(λ)", fontsize=12)
ax1.set_title("Spectral Distributions", fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Right panel: coherence length (in micrometres) against FWHM bandwidth.
bw_range = np.linspace(0.5e-9, 50e-9, 100)
lc_values = [1e6 * float(coherence_length(center_wl, bw)) for bw in bw_range]

ax2.semilogy(bw_range * 1e9, lc_values, "b-", lw=2)
ax2.set_xlabel("Bandwidth FWHM (nm)", fontsize=12)
ax2.set_ylabel(r"Coherence Length $L_c$ ($\mu$m)", fontsize=12)
ax2.set_title("Coherence Length vs Bandwidth", fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
save_fig("temporal_coherence_spectrum.svg")

In [None]:
# Figure 1.3: Coherent mode decomposition
from janssen.coherence import gaussian_schell_model_modes

# Gaussian Schell-model source whose beam is wider than its coherence width,
# so the decomposition spreads over several modes.
gsm_params = dict(
    wavelength=532e-9,
    dx=1e-6,
    grid_size=(128, 128),
    beam_width=50e-6,
    coherence_width=15e-6,
    num_modes=6,
)
mode_set = gaussian_schell_model_modes(**gsm_params)

fig, axes = plt.subplots(2, 3, figsize=FIGSIZE_WIDE)

# One panel per mode (row-major order); the title shows that mode's weight.
for i, ax in enumerate(axes.ravel()):
    intensity = np.abs(mode_set.modes[i]) ** 2
    ax.imshow(intensity, cmap="viridis")
    weight = float(mode_set.weights[i])
    ax.set_title(f"Mode {i+1}, w={weight:.3f}", fontsize=10)
    ax.axis("off")

plt.suptitle("Coherent Mode Decomposition (GSM Source)", fontsize=14)
plt.tight_layout()
save_fig("coherent_mode_decomposition.svg")

In [None]:
# Figure 1.4: LED vs Laser modes comparison
# Import both factories here so this cell also runs on its own; it previously
# relied on gaussian_schell_model_modes having been imported by Figure 1.3.
from janssen.coherence import gaussian_schell_model_modes, laser_with_mode_noise

NUM_MODES = 5

# Laser with high purity (nearly coherent)
laser_modes = laser_with_mode_noise(
    wavelength=632.8e-9,
    dx=1e-6,
    grid_size=(128, 128),
    beam_waist=30e-6,
    mode_purity=0.95,
    num_modes=NUM_MODES,
)

# LED-like (low purity)
led_modes = gaussian_schell_model_modes(
    wavelength=530e-9,
    dx=1e-6,
    grid_size=(128, 128),
    beam_width=100e-6,
    coherence_width=10e-6,
    num_modes=NUM_MODES,
)

fig, axes = plt.subplots(1, 2, figsize=FIGSIZE_WIDE)

# Bar plot of mode weights, one panel per source. The x positions follow the
# number of weights actually returned, so a shorter weight vector cannot
# cause a shape mismatch with a hard-coded arange(5).
panels = [
    (laser_modes, "Laser", "red"),
    (led_modes, "LED", "green"),
]
for ax, (source, label, color) in zip(axes, panels):
    mode_weights = np.asarray(source.weights[:NUM_MODES])
    ax.bar(np.arange(len(mode_weights)), mode_weights, color=color, alpha=0.7)
    ax.set_xlabel("Mode Index", fontsize=12)
    ax.set_ylabel("Weight", fontsize=12)
    ax.set_title(
        f"{label} (N_eff={float(source.effective_mode_count):.2f})", fontsize=12
    )
    ax.set_ylim(0, 1)

plt.suptitle("Mode Weight Distributions", fontsize=14)
plt.tight_layout()
save_fig("led_vs_laser_modes.svg")

## 2. Propagation Figures

Figures for the `propagation.md` guide.

In [None]:
# Figure 2.1: Propagation regimes diagram
# Regime bands along the Fresnel-number axis: (N_F start, N_F end, colour, label).
REGIMES = [
    (0.01, 0.1, "blue", "Far-field (Fraunhofer)"),
    (0.1, 1, "green", "Intermediate"),
    (1, 100, "red", "Near-field (Fresnel)"),
]

fig, ax = plt.subplots(figsize=FIGSIZE)

for nf_start, nf_end, color, label in REGIMES:
    ax.axvspan(nf_start, nf_end, alpha=0.3, color=color, label=label)

# Mark the N_F = 1 boundary; the label's x is in data coordinates and its
# y in axes fraction (get_xaxis_transform).
ax.axvline(x=1, color="black", ls="--", lw=2)
ax.text(1.1, 0.5, r"$N_F = 1$", fontsize=12, transform=ax.get_xaxis_transform())

ax.set_xscale("log")
ax.set_xlim(0.01, 100)
ax.set_xlabel(r"Fresnel Number $N_F = a^2/(\lambda z)$", fontsize=12)
ax.set_yticks([])
ax.set_title("Propagation Regimes", fontsize=14)
ax.legend(loc="upper right", fontsize=10)
save_fig("propagation_regimes.svg")

In [None]:
# Figure 2.2: Angular spectrum method diagram
# Schematic in dimensionless units with the wavelength set to 1: x and y are
# in wavelengths, spatial frequencies in cycles per wavelength and the
# propagation distance Z_PROP in wavelengths. Plane-wave components with
# fx^2 + fy^2 < 1 propagate; everything outside that circle is evanescent.
N = 64
Z_PROP = 1.0

fig, axes = plt.subplots(1, 4, figsize=(12, 3))

# Create sample field: Gaussian beam with a linear phase tilt along x
x = np.linspace(-1, 1, N)
X, Y = np.meshgrid(x, x)
field = np.exp(-(X**2 + Y**2) / 0.2) * np.exp(1j * 2 * np.pi * X)

# Step 1: Input field
axes[0].imshow(np.abs(field), cmap="hot")
axes[0].set_title("1. Input E(x,y)", fontsize=10)
axes[0].axis("off")

# Step 2: FFT (zero frequency shifted to the centre)
spectrum = np.fft.fftshift(np.fft.fft2(field))
axes[1].imshow(np.log10(np.abs(spectrum) + 1e-6), cmap="viridis")
axes[1].set_title(r"2. $\mathcal{F}\{E\} = A(k_x, k_y)$", fontsize=10)
axes[1].axis("off")

# Step 3: Multiply by H = exp(i 2 pi z sqrt(1 - fx^2 - fy^2)).
# The sample spacing is x[1] - x[0] = 2/(N-1) because linspace includes both
# end points, not 2/N. A complex square root makes the evanescent region decay
# as exp(-2 pi z sqrt(fx^2 + fy^2 - 1)); clipping the radicand at 0 would
# instead set H = 1 there and pass evanescent waves through unattenuated.
fx = np.fft.fftshift(np.fft.fftfreq(N, d=x[1] - x[0]))
FX, FY = np.meshgrid(fx, fx)
kz = np.sqrt((1 - FX**2 - FY**2).astype(complex))
H = np.exp(1j * 2 * np.pi * Z_PROP * kz)
propagated_spectrum = spectrum * H
axes[2].imshow(np.angle(H), cmap="twilight")
axes[2].set_title(r"3. Multiply by $H(k_x, k_y)$", fontsize=10)
axes[2].axis("off")

# Step 4: IFFT (undo the centring shift first)
output = np.fft.ifft2(np.fft.ifftshift(propagated_spectrum))
axes[3].imshow(np.abs(output), cmap="hot")
axes[3].set_title(r"4. Output $\mathcal{F}^{-1}$", fontsize=10)
axes[3].axis("off")

plt.suptitle("Angular Spectrum Propagation", fontsize=12)
plt.tight_layout()
save_fig("angular_spectrum_method.svg")

In [None]:
# Figure 2.3: Fresnel number validity
wavelength = 500e-9  # 500 nm
apertures = [10e-6, 50e-6, 100e-6, 500e-6]  # Various aperture sizes
z = np.logspace(-5, -1, 100)  # 10 um to 10 cm
z_mm = z * 1e3

fig, ax = plt.subplots(figsize=FIGSIZE)

# One N_F(z) = a^2 / (lambda z) curve per aperture size
for a in apertures:
    ax.loglog(z_mm, a**2 / (wavelength * z), lw=2, label=f"a = {a*1e6:.0f} μm")

# Reference levels separating the regimes
ax.axhline(y=1, color="black", ls="--", lw=1.5, label=r"$N_F = 1$")
ax.axhline(y=0.1, color="gray", ls=":", lw=1.5, label=r"$N_F = 0.1$")

# Shade and label the far-field band (N_F below 0.1)
ax.fill_between([1e-2, 1e2], 0.1, 1e-3, alpha=0.2, color="blue")
ax.text(1, 0.02, "Far-field", fontsize=10)

ax.set_xlabel("Propagation Distance z (mm)", fontsize=12)
ax.set_ylabel(r"Fresnel Number $N_F$", fontsize=12)
ax.set_title("Fresnel Number vs Distance", fontsize=14)
ax.legend(loc="upper right", fontsize=9)
ax.set_xlim(1e-2, 1e2)
ax.set_ylim(1e-3, 1e4)
ax.grid(True, alpha=0.3, which="both")
save_fig("fresnel_number_diagram.svg")

In [None]:
# Figure 2.4: Lens propagation model
lens_x = 0.5
lens_height = 0.3
lens_tick = dict(arrowstyle="<->", color="blue", lw=2)
span_arrow = dict(arrowstyle="<->", color="black")

fig, ax = plt.subplots(figsize=FIGSIZE_WIDE)

# Optical axis
ax.axhline(y=0, color="gray", ls="-", lw=0.5)

# Thin lens: a vertical bar with short double-headed ticks at both ends
ax.plot([lens_x, lens_x], [-lens_height, lens_height], "b-", lw=3)
for tip_y in (lens_height, -lens_height):
    ax.annotate("", xy=(lens_x - 0.02, tip_y), xytext=(lens_x + 0.02, tip_y),
                arrowprops=dict(lens_tick))

# Object (x = 0) and image (x = 1) planes
ax.axvline(x=0, color="green", ls="--", lw=1.5)
ax.text(0.02, 0.35, "Object\nplane", fontsize=10, color="green")
ax.axvline(x=1, color="red", ls="--", lw=1.5)
ax.text(0.92, 0.35, "Image\nplane", fontsize=10, color="red")

# Schematic rays: each starts at object height y0, crosses the lens at an
# illustrative height 0.3*y0 and ends at the inverted image point -y0.
for y0 in (0.2, 0.1, 0):
    y_at_lens = y0 * 0.3
    ax.plot([0, lens_x], [y0, y_at_lens], "r-", lw=1, alpha=0.7)
    ax.plot([lens_x, 1], [y_at_lens, -y0], "r-", lw=1, alpha=0.7)

# Object-to-lens (d1) and lens-to-image (d2) distance markers
for start, end, label_x, label in ((0, lens_x, 0.25, r"$d_1$"),
                                   (lens_x, 1, 0.75, r"$d_2$")):
    ax.annotate("", xy=(start, -0.35), xytext=(end, -0.35),
                arrowprops=dict(span_arrow))
    ax.text(label_x, -0.4, label, fontsize=12, ha="center")

ax.set_xlim(-0.1, 1.1)
ax.set_ylim(-0.5, 0.5)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Thin Lens Propagation Model", fontsize=14)
save_fig("lens_propagation.svg")

## 3. Ptychography Figures

Figures for the `ptychography.md` guide.

In [None]:
# Figure 3.1: Ptychography scanning geometry
SCAN_COORDS = (0.2, 0.4, 0.6, 0.8)  # raster positions along each axis
SCAN_STEP = 0.2                      # spacing between neighbouring positions
OVERLAP = 0.6                        # linear overlap = 1 - step / probe diameter
# Size the probe so the drawn circles really have the overlap quoted in the
# title. The previous radius of 0.15 gave only 1 - 0.2/0.3 ~ 33 % linear
# overlap (about 22 % by area), not 60 %.
probe_radius = SCAN_STEP / (2 * (1 - OVERLAP))

fig, ax = plt.subplots(figsize=FIGSIZE)

# Sample (gray rectangle)
sample = plt.Rectangle((0, 0), 1, 1, fill=True, facecolor="lightgray",
                       edgecolor="black", lw=2)
ax.add_patch(sample)

# Row-major raster of scan positions, each with its illuminated probe area.
# A lower alpha keeps the heavily overlapping circles from saturating.
scan_positions = [(x, y) for y in SCAN_COORDS for x in SCAN_COORDS]
for x, y in scan_positions:
    circle = plt.Circle((x, y), probe_radius, fill=True, facecolor="red",
                        alpha=0.15, edgecolor="red", lw=1)
    ax.add_patch(circle)
    ax.plot(x, y, "r.", markersize=5)

ax.set_xlim(-0.1, 1.1)
ax.set_ylim(-0.1, 1.1)
ax.set_aspect("equal")
ax.set_xlabel("x position", fontsize=12)
ax.set_ylabel("y position", fontsize=12)
ax.set_title(f"Ptychography Scan Pattern ({OVERLAP:.0%} overlap)", fontsize=14)
save_fig("ptychography_geometry.svg")

In [None]:
# Figure 3.2: ePIE algorithm diagram
# Flow-chart layout in data coordinates (axes limits are 0..1).
ROW_Y = 0.5
BOX_HALF_W, BOX_HALF_H = 0.08, 0.15
step_centers = [0.1, 0.3, 0.5, 0.7, 0.9]
step_labels = [
    "Exit Wave\n$\\psi = P \\cdot O$",
    "Propagate\n$\\Psi = \\mathcal{F}\\{\\psi\\}$",
    "Constraint\n$|\\Psi'| = \\sqrt{I}$",
    "Back-prop\n$\\psi' = \\mathcal{F}^{-1}$",
    "Update\nO, P",
]

fig, ax = plt.subplots(figsize=FIGSIZE_WIDE)

# One labelled box per algorithm step
for cx, label in zip(step_centers, step_labels):
    ax.add_patch(plt.Rectangle((cx - BOX_HALF_W, ROW_Y - BOX_HALF_H),
                               2 * BOX_HALF_W, 2 * BOX_HALF_H, fill=True,
                               facecolor="lightblue", edgecolor="blue", lw=2))
    ax.text(cx, ROW_Y, label, ha="center", va="center", fontsize=9)

# Arrows from the right edge of each box to the left edge of the next one
for left_cx, right_cx in zip(step_centers, step_centers[1:]):
    ax.annotate("", xy=(right_cx - BOX_HALF_W, ROW_Y),
                xytext=(left_cx + BOX_HALF_W, ROW_Y),
                arrowprops=dict(arrowstyle="->", color="black", lw=1.5))

# Curved return arrow: move on to the next scan position
ax.annotate("", xy=(0.1, 0.25), xytext=(0.9, 0.25),
            arrowprops=dict(arrowstyle="->", color="gray", lw=1.5,
                            connectionstyle="arc3,rad=-0.3"))
ax.text(0.5, 0.1, "Next position", ha="center", fontsize=10, color="gray")

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis("off")
ax.set_title("ePIE Algorithm Workflow", fontsize=14)
save_fig("epie_algorithm.svg")

In [None]:
# Figure 3.3: Convergence comparison
# Illustrative, synthetic curves (exponential decay plus decaying noise), not
# output from an actual reconstruction. A dedicated seeded generator makes the
# figure identical on every run, whatever other cells did to the global NumPy
# RNG. The previous unseeded np.random.randn changed the figure on every run.
rng = np.random.default_rng(0)
n_iter = 100
iterations = np.arange(n_iter)

epie_loss = np.exp(-iterations / 15) + 0.01 * rng.standard_normal(n_iter) * np.exp(-iterations / 30)
epie_loss = np.maximum(epie_loss, 0.01)  # floor keeps the log-scale plot positive

grad_loss = np.exp(-iterations / 25) + 0.005 * rng.standard_normal(n_iter) * np.exp(-iterations / 40)
grad_loss = np.maximum(grad_loss, 0.008)

fig, ax = plt.subplots(figsize=FIGSIZE)
ax.semilogy(iterations, epie_loss, "b-", lw=2, label="ePIE", alpha=0.8)
ax.semilogy(iterations, grad_loss, "r-", lw=2, label="Gradient descent", alpha=0.8)

ax.set_xlabel("Iteration", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.set_title("Convergence Comparison", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xlim(0, n_iter)
save_fig("gradient_descent_comparison.svg")

In [None]:
# Figure 3.4: Loss function convergence curves
# Synthetic, fully deterministic curves. No randomness is used, so this cell
# no longer reseeds the global NumPy RNG. That reseed had no effect here and
# silently changed the state seen by any random cell re-run afterwards.
iterations = np.arange(200)

amplitude_loss = 1.0 * np.exp(-iterations / 30) + 0.005
intensity_loss = 1.0 * np.exp(-iterations / 50) * (1 + 0.1 * np.sin(iterations / 5)) + 0.01
poisson_loss = 1.0 * np.exp(-iterations / 40) + 0.008

curves = [
    (amplitude_loss, "b", "Amplitude loss"),
    (intensity_loss, "orange", "Intensity loss"),
    (poisson_loss, "g", "Poisson loss"),
]

fig, ax = plt.subplots(figsize=FIGSIZE)
for loss, color, label in curves:
    ax.semilogy(iterations, loss, color=color, ls="-", lw=2, label=label)

ax.set_xlabel("Iteration", fontsize=12)
ax.set_ylabel("Normalized Loss", fontsize=12)
ax.set_title("Loss Function Comparison", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
save_fig("reconstruction_convergence.svg")

## 4. Zernike Figures

Figures for the `zernike.md` guide.

In [None]:
# Figure 4.1: Zernike polynomial pyramid
from janssen.optics import zernike_polynomial, noll_to_nm

# Polar coordinates on a square grid covering the unit disk
N = 128
x = np.linspace(-1, 1, N)
X, Y = np.meshgrid(x, x)
rho = np.sqrt(X**2 + Y**2)
theta = np.arctan2(Y, X)
mask = rho <= 1
rho_j, theta_j = jnp.array(rho), jnp.array(theta)  # convert once, reuse per panel

# Display names for Noll j = 1..15 (Noll's convention: even j -> cos(m theta),
# odd j -> sin(m theta); X/0° labels the cos term, Y/45° the sin term).
# j = 12/13 are secondary astigmatism (n=4, |m|=2) and j = 14/15 are
# quadrafoil (n=4, |m|=4); secondary coma only starts at j = 16, so the old
# "2nd Coma" labels for j = 14/15 were wrong.
# NOTE(review): assumes janssen follows Noll's cos/sin parity convention.
ZERNIKE_NAMES = {
    1: "Piston", 2: "Tilt X", 3: "Tilt Y", 4: "Defocus",
    5: "Astig 45°", 6: "Astig 0°", 7: "Coma Y", 8: "Coma X",
    9: "Trefoil Y", 10: "Trefoil X", 11: "Spherical",
    12: "2nd Astig 0°", 13: "2nd Astig 45°",
    14: "Quadrafoil X", 15: "Quadrafoil Y",
}

fig, axes = plt.subplots(3, 5, figsize=(12, 7))

for ax, (j, name) in zip(axes.flat, ZERNIKE_NAMES.items()):
    n, m = noll_to_nm(j)
    Z = np.array(zernike_polynomial(j, rho_j, theta_j))
    Z[~mask] = np.nan
    # Symmetric per-panel colour limits: normalised Zernike terms can exceed
    # |1| (e.g. sqrt(3) for Noll-normalised defocus), which a fixed +/-1 range
    # would clip. For unit-peak terms this reduces to the old +/-1.
    vlim = np.nanmax(np.abs(Z)) or 1.0
    ax.imshow(Z, cmap="RdBu_r", extent=[-1, 1, -1, 1], vmin=-vlim, vmax=vlim)
    ax.set_title(f"Z{j}: {name}\n(n={int(n)}, m={int(m)})", fontsize=8)
    ax.axis("off")

plt.suptitle("Zernike Polynomials (Noll Indexing)", fontsize=14)
plt.tight_layout()
save_fig("zernike_pyramid.svg")

In [None]:
# Figure 4.2: Common aberrations
from janssen.optics import defocus, astigmatism, coma, spherical_aberration

# Polar coordinates on a square grid covering the unit disk
N = 128
x = np.linspace(-1, 1, N)
X, Y = np.meshgrid(x, x)
rho = np.sqrt(X**2 + Y**2)
theta = np.arctan2(Y, X)
mask = rho <= 1
rho_j, theta_j = jnp.array(rho), jnp.array(theta)

# Aberration maps with unit coefficient.
# NOTE(review): the suptitle says "radians"; confirm these helpers return
# phase in radians (not waves) for a coefficient of 1.0.
aberrations = [
    (defocus(rho_j, 1.0), "Defocus"),
    (astigmatism(rho_j, theta_j, 1.0, 0.0), "Astigmatism"),
    (coma(rho_j, theta_j, 1.0, 0.0), "Coma"),
    (spherical_aberration(rho_j, 1.0), "Spherical"),
]

fig, axes = plt.subplots(1, 4, figsize=(12, 3))

for ax, (phase, title) in zip(axes, aberrations):
    phase_np = np.array(phase)
    phase_np[~mask] = np.nan
    # Centre the diverging colormap on zero so red/blue encode the sign of the
    # phase. A map such as a primary spherical term (6 rho^4 - 6 rho^2 + 1)
    # need not be symmetric about zero, and autoscaling would then move the
    # white midpoint away from 0.
    vlim = np.nanmax(np.abs(phase_np)) or 1.0
    im = ax.imshow(phase_np, cmap="RdBu_r", extent=[-1, 1, -1, 1],
                   vmin=-vlim, vmax=vlim)
    ax.set_title(title, fontsize=12)
    ax.axis("off")
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.suptitle("Common Optical Aberrations (Phase in radians)", fontsize=14)
plt.tight_layout()
save_fig("common_aberrations.svg")

In [None]:
# Figure 4.3: Noll indexing diagram
from janssen.optics import nm_to_noll

MAX_N = 5

fig, ax = plt.subplots(figsize=FIGSIZE)

# One bubble per (n, m) pair: radial order n on the y axis, azimuthal index
# m = -n, -n+2, ..., n on the x axis, labelled with its Noll index j.
nm_pairs = [(n, m) for n in range(MAX_N + 1) for m in range(-n, n + 1, 2)]
for n, m in nm_pairs:
    j = int(nm_to_noll(n, m))
    ax.scatter(m, n, s=800, c="lightblue", edgecolor="blue", linewidth=2, zorder=2)
    ax.text(m, n, str(j), ha="center", va="center", fontsize=10, fontweight="bold")

ax.set_xlabel("Azimuthal index m", fontsize=12)
ax.set_ylabel("Radial index n", fontsize=12)
ax.set_title("Noll Index Ordering", fontsize=14)
ax.set_xlim(-(MAX_N + 1), MAX_N + 1)
ax.set_ylim(-0.5, MAX_N + 0.5)
ax.set_xticks(range(-MAX_N, MAX_N + 1))
ax.set_yticks(range(MAX_N + 1))
ax.grid(True, alpha=0.3)
ax.set_aspect("equal")
save_fig("noll_indexing.svg")

In [None]:
# Figure 4.4: Aberration effects on PSF
from janssen.optics import generate_aberration_noll, circular_aperture

wavelength = 632.8e-9
dx = 5e-6
N = 256

# Zero-pad the pupil before the FFT. The 1 mm pupil covers only ~200 of the
# 256 samples (1 mm / 5 um), so an unpadded FFT puts the first Airy zero about
# 1.22 * 256 / 200 ~ 1.5 px from the centre, and the aberrated PSFs are
# barely distinguishable. Padding by PAD refines the focal-plane sampling by
# the same factor.
PAD = 4
N_FFT = PAD * N
HALF_WIDTH = 40  # half-size (px) of the displayed crop around the PSF centre

# Create pupil
pupil = circular_aperture(grid_size=(N, N), dx=dx, diameter=1e-3)

# (Noll coefficients, panel title); an empty dict means an unaberrated pupil.
aberration_cases = [
    ({}, "Perfect"),
    ({4: 0.5}, "Defocus"),
    ({6: 0.5}, "Astigmatism"),
    ({8: 0.5}, "Coma"),
    ({11: 0.3}, "Spherical"),
]

fig, axes = plt.subplots(1, 5, figsize=(14, 3))
c = N_FFT // 2

for ax, (coeffs, title) in zip(axes, aberration_cases):
    if coeffs:
        # NOTE(review): dx is not passed here; confirm generate_aberration_noll
        # maps pupil_radius onto the same sampling as circular_aperture(dx=dx).
        phase = generate_aberration_noll(
            coefficients=coeffs,
            grid_size=(N, N),
            pupil_radius=0.5e-3,
            wavelength=wavelength,
        )
    else:
        phase = 0.0
    # Multiplying by exp(1j * phase) also promotes the real pupil to JAX's
    # default complex dtype; astype(jnp.complex128) would be truncated to
    # complex64 (with a warning) unless jax_enable_x64 is set.
    field = pupil * jnp.exp(1j * phase)

    # PSF = |FFT(pupil field)|^2. fft2(..., s=...) pads at the end of each
    # axis, which only adds a linear phase to the spectrum, so |.|^2 is the
    # same as for a centred pad.
    spectrum = jnp.fft.fftshift(jnp.fft.fft2(field, s=(N_FFT, N_FFT)))
    psf = jnp.abs(spectrum) ** 2
    psf = psf / jnp.max(psf)

    crop = np.array(psf[c - HALF_WIDTH:c + HALF_WIDTH, c - HALF_WIDTH:c + HALF_WIDTH])
    ax.imshow(crop**0.3, cmap="hot")  # gamma 0.3 lifts the faint side lobes
    ax.set_title(title, fontsize=11)
    ax.axis("off")

plt.suptitle("Effect of Aberrations on PSF", fontsize=14)
plt.tight_layout()
save_fig("aberration_effects.svg")

## 5. Vector Optics Figures

Figures for the `vector-optics.md` guide.

In [None]:
# Figure 5.1: Richards-Wolf geometry
theta_max = np.pi / 3  # 60 degrees (NA ~ 0.87)
focal_length = 1.5     # radius of the Gaussian reference sphere (drawing units)
focus = np.array([0.0, 1.5])

fig, ax = plt.subplots(figsize=FIGSIZE)

# Gaussian reference sphere: an arc of radius f centred on the focus, so every
# ray leaving it runs along a radius and reaches the focus at its own angle.
# Previously the arc was centred 0.5 units below the focus.
lens_angles = np.linspace(-theta_max, theta_max, 50)
lens_x = focus[0] + focal_length * np.sin(lens_angles)
lens_y = focus[1] - focal_length * np.cos(lens_angles)
ax.plot(lens_x, lens_y, "b-", lw=3)

# Converging rays; the outermost pair are the marginal rays at +/- theta_max.
for angle in np.linspace(-theta_max, theta_max, 7):
    x_start = focus[0] + focal_length * np.sin(angle)
    y_start = focus[1] - focal_length * np.cos(angle)
    ax.plot([x_start, focus[0]], [y_start, focus[1]], "r-", lw=1, alpha=0.6)

# Optical axis from the sphere's vertex up through the focus
ax.plot([focus[0], focus[0]], [focus[1] - focal_length, focus[1] + 0.3],
        "k--", lw=0.8, alpha=0.6)

# Focus point
ax.scatter([focus[0]], [focus[1]], s=100, c="red", zorder=5)
ax.text(focus[0] + 0.1, focus[1], "Focus", fontsize=10)

# theta_max is measured at the focus between the optical axis (pointing back
# toward the sphere, -y) and the marginal ray, so the arc sits below the focus.
# The previous arc opened upward, on the side where no rays are drawn.
arc_radius = 0.4
arc = np.linspace(-np.pi / 2, -np.pi / 2 + theta_max, 20)
ax.plot(focus[0] + arc_radius * np.cos(arc), focus[1] + arc_radius * np.sin(arc), "k-", lw=1)
label_angle = -np.pi / 2 + theta_max / 2
ax.text(focus[0] + 1.5 * arc_radius * np.cos(label_angle),
        focus[1] + 1.5 * arc_radius * np.sin(label_angle),
        r"$\theta_{max}$", fontsize=11, ha="center", va="center")

ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-0.5, 2)
ax.set_aspect("equal")
ax.axis("off")
ax.set_title("Richards-Wolf Focusing Geometry", fontsize=14)
save_fig("debye_wolf_geometry.svg")

In [None]:
# Figure 5.2: Focal field components (simulated)
# Schematic only: analytic stand-ins for the focal intensity components of an
# x-polarized beam at high NA. A real figure would use vector_focusing.
N = 64
x = np.linspace(-1, 1, N)
X, Y = np.meshgrid(x, x)
R = np.sqrt(X**2 + Y**2)
envelope = np.exp(-R**2 / 0.15)  # shared Gaussian envelope

Ex = envelope * (1 + 0.1 * (X**2 - Y**2))  # |Ex|^2: main, slightly elongated
Ey = 0.1 * envelope * (X * Y)**2           # |Ey|^2: weak four-fold pattern
Ez = 0.3 * envelope * X**2                 # |Ez|^2: two lobes along x

components = [(Ex, r"$|E_x|^2$"), (Ey, r"$|E_y|^2$"), (Ez, r"$|E_z|^2$")]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (component, title) in zip(axes, components):
    im = ax.imshow(component, cmap="hot", extent=[-1, 1, -1, 1])
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(r"x ($\lambda$/NA)", fontsize=10)
    ax.set_ylabel(r"y ($\lambda$/NA)", fontsize=10)
    plt.colorbar(im, ax=ax, fraction=0.046)

plt.suptitle("Focal Field Components (NA=0.9, x-polarized)", fontsize=14)
plt.tight_layout()
save_fig("focal_field_components.svg")

In [None]:
# Figure 5.3: Polarization modes at focus
N = 64
x = np.linspace(-1, 1, N)
X, Y = np.meshgrid(x, x)
R = np.sqrt(X**2 + Y**2)
Phi = np.arctan2(Y, X)

# Schematic focal intensity patterns (analytic stand-ins, not computed fields),
# keyed by the panel title in display order.
focal_patterns = {
    "Linear": np.exp(-R**2 / 0.12) * (1 + 0.15 * np.cos(2*Phi)),       # slightly elongated
    "Circular": np.exp(-R**2 / 0.1),                                     # rotationally symmetric
    "Radial": np.exp(-(R-0.3)**2 / 0.05) + 0.5*np.exp(-R**2 / 0.02),   # ring plus central spot
    "Azimuthal": np.exp(-(R-0.35)**2 / 0.04),                            # pure ring (donut)
}

fig, axes = plt.subplots(1, 4, figsize=(14, 3.5))
for ax, (title, pattern) in zip(axes, focal_patterns.items()):
    ax.imshow(pattern, cmap="hot", extent=[-1, 1, -1, 1])
    ax.set_title(title, fontsize=12)
    ax.axis("off")

plt.suptitle("Focal Intensity for Different Polarizations (NA=0.9)", fontsize=14)
plt.tight_layout()
save_fig("polarization_modes.svg")

In [None]:
# Figure 5.4: Apodization effects
N = 64
NA = 0.9          # sin(theta) at the pupil edge; matches the NA quoted in Figs 5.2/5.3
PAD = 8           # zero-padding factor for the pupil FFT
HALF_WIDTH = 32   # half-size (px) of the displayed crop around the focus

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

x = np.linspace(-1, 1, N)
X, Y = np.meshgrid(x, x)
R = np.sqrt(X**2 + Y**2)  # normalised pupil radius: sin(theta) = NA * R
inside = R <= 1

# Pupil amplitude profiles
uniform = inside.astype(float)
# Aplanatic (sine-condition) apodization sqrt(cos(theta)), with
# cos(theta) = sqrt(1 - (NA * R)^2). The previous sqrt(1 - R^2) was
# cos(theta) itself (at NA = 1), not the sqrt(cos(theta)) named in the title.
cos_theta = np.sqrt(np.maximum(1 - (NA * R) ** 2, 0))
aplanatic = np.sqrt(cos_theta) * inside
gaussian = np.exp(-R**2 / 0.5) * inside


def simple_psf(pupil, pad=PAD):
    """Return the scalar focal intensity |FFT(pupil)|^2, normalised to peak 1.

    The pupil is zero-padded to ``pad`` times its size along each axis so the
    focal spot is sampled finely enough to show the apodization. Unpadded,
    this pupil (which spans the whole 64-px grid) puts the first Airy zero
    only ~1.2 px from the centre.
    """
    padded_shape = (pad * pupil.shape[0], pad * pupil.shape[1])
    ft = np.fft.fftshift(np.fft.fft2(pupil, s=padded_shape))
    psf = np.abs(ft) ** 2
    return psf / np.max(psf)


psfs = [simple_psf(p) for p in [uniform, aplanatic, gaussian]]
titles = ["Uniform", r"Aplanatic $\sqrt{\cos\theta}$", "Gaussian"]

c = PAD * N // 2
w = HALF_WIDTH

for ax, psf, title in zip(axes, psfs, titles):
    ax.imshow(psf[c-w:c+w, c-w:c+w]**0.3, cmap="hot")  # gamma 0.3 lifts side lobes
    ax.set_title(title, fontsize=12)
    ax.axis("off")

plt.suptitle("Effect of Apodization on Focal Spot", fontsize=14)
plt.tight_layout()
save_fig("apodization_effects.svg")

## 6. PyTree Architecture Figures

Figures for the `pytree-architecture.md` guide.

In [None]:
# Figure 6.1: PyTree hierarchy diagram
# Box (width, height) per tier, keyed by the tier's y centre. Anything not
# listed (the field-list row) uses LEAF_SIZE.
TIER_SIZES = {0.85: (0.3, 0.1), 0.6: (0.18, 0.1)}
LEAF_SIZE = (0.16, 0.15)

# (x centre, y centre, label, fill colour)
boxes = [
    (0.5, 0.85, "janssen PyTrees", "lightgray"),
    (0.2, 0.6, "OpticalWavefront", "lightblue"),
    (0.5, 0.6, "CoherentModeSet", "lightgreen"),
    (0.8, 0.6, "LensParameters", "lightyellow"),
    (0.2, 0.35, "field\nwavelength\ndx", "white"),
    (0.5, 0.35, "modes\nweights\nwavelength", "white"),
    (0.8, 0.35, "focal_length\nNA\naperture", "white"),
]

fig, ax = plt.subplots(figsize=FIGSIZE_WIDE)

for bx, by, label, face in boxes:
    w, h = TIER_SIZES.get(by, LEAF_SIZE)
    ax.add_patch(plt.Rectangle((bx - w/2, by - h/2), w, h, fill=True,
                               facecolor=face, edgecolor="black", lw=1.5))
    ax.text(bx, by, label, ha="center", va="center", fontsize=9)

# Connectors: root -> each type, then each type -> its field list
for col_x in (0.2, 0.5, 0.8):
    ax.annotate("", xy=(col_x, 0.65), xytext=(0.5, 0.8),
                arrowprops=dict(arrowstyle="->", color="gray"))
    ax.annotate("", xy=(col_x, 0.425), xytext=(col_x, 0.55),
                arrowprops=dict(arrowstyle="->", color="gray"))

ax.set_xlim(0, 1)
ax.set_ylim(0.15, 1)
ax.axis("off")
ax.set_title("Janssen PyTree Type Hierarchy", fontsize=14)
save_fig("pytree_hierarchy.svg")

In [None]:
# Figure 6.2: Wavefront types
N = 64
x = np.linspace(-1, 1, N)
X, Y = np.meshgrid(x, x)
R = np.sqrt(X**2 + Y**2)
gauss = np.exp(-R**2 / 0.3)  # common Gaussian envelope for both panels

# Scalar wavefront: Gaussian with a linear phase tilt (only |E| is shown)
scalar_field = gauss * np.exp(1j * 2 * np.pi * X)

# Polarized wavefront: two components, Ey weaker and a quarter-wave behind
Ex = gauss
Ey = 0.5 * gauss * np.exp(1j * np.pi/2)

fig, axes = plt.subplots(1, 2, figsize=FIGSIZE_WIDE)
scalar_ax, polar_ax = axes

scalar_ax.imshow(np.abs(scalar_field), cmap="viridis")
scalar_ax.set_title("Scalar Wavefront\nComplex[H, W]", fontsize=11)
scalar_ax.axis("off")

# Overlay both component magnitudes in contrasting colormaps
polar_ax.imshow(np.abs(Ex), cmap="Reds", alpha=0.7)
polar_ax.imshow(np.abs(Ey), cmap="Blues", alpha=0.5)
polar_ax.set_title("Polarized Wavefront\nComplex[H, W, 2]", fontsize=11)
polar_ax.axis("off")

plt.suptitle("OpticalWavefront Types", fontsize=14)
plt.tight_layout()
save_fig("wavefront_types.svg")

In [None]:
# Figure 6.3: Coherence types
N_MODE_TILES = 10   # mode thumbnails drawn for CoherentModeSet
TILE_COLS = 3       # thumbnails per row
N_SLICES = 5        # wavelength slices drawn for PolychromaticWavefront

fig, axes = plt.subplots(1, 2, figsize=FIGSIZE_WIDE)
ax1, ax2 = axes

# Left: CoherentModeSet as a grid of mode thumbnails, filled row by row
for idx in range(N_MODE_TILES):
    row, col = divmod(idx, TILE_COLS)
    x_pos = 0.1 + col * 0.3
    y_pos = 0.8 - row * 0.22
    ax1.add_patch(plt.Rectangle((x_pos, y_pos), 0.2, 0.18, fill=True,
                                facecolor="lightblue", edgecolor="blue", lw=1))
    ax1.text(x_pos + 0.1, y_pos + 0.09, f"Mode {idx+1}",
             ha="center", va="center", fontsize=8)

ax1.text(0.5, 0.05, "modes: Complex[M, H, W]\nweights: Float[M]",
         ha="center", fontsize=10, family="monospace")
ax1.set_xlim(0, 1)
ax1.set_ylim(0, 1)
ax1.axis("off")
ax1.set_title("CoherentModeSet", fontsize=12)

# Right: PolychromaticWavefront as rainbow-coloured wavelength slices
slice_colors = plt.cm.rainbow(np.linspace(0, 1, N_SLICES))
for k, face in enumerate(slice_colors):
    x_pos = 0.15 + k * 0.14
    ax2.add_patch(plt.Rectangle((x_pos, 0.3), 0.12, 0.4, fill=True,
                                facecolor=face, edgecolor="black", lw=1, alpha=0.7))
    ax2.text(x_pos + 0.06, 0.25, f"λ{k+1}", ha="center", fontsize=9)

ax2.text(0.5, 0.05, "fields: Complex[Nλ, H, W]\nwavelengths: Float[Nλ]",
         ha="center", fontsize=10, family="monospace")
ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)
ax2.axis("off")
ax2.set_title("PolychromaticWavefront", fontsize=12)

plt.suptitle("Coherence Data Types", fontsize=14)
plt.tight_layout()
save_fig("coherence_types.svg")

In [None]:
# Figure 6.4: Factory pattern workflow
# Pipeline boxes laid out in axes-fraction coordinates along y = ROW_Y.
ROW_Y = 0.5
BOX_HALF_W, BOX_HALF_H = 0.08, 0.15
stage_centers = [0.1, 0.3, 0.5, 0.7, 0.9]
stages = [
    ("Raw\nInputs", "lightyellow"),
    ("Validate\nShapes", "lightcoral"),
    ("Convert\nto JAX", "lightgreen"),
    ("Normalize\nParams", "lightblue"),
    ("PyTree\nInstance", "plum"),
]

fig, ax = plt.subplots(figsize=FIGSIZE_WIDE)

for cx, (label, face) in zip(stage_centers, stages):
    ax.add_patch(plt.Rectangle((cx - BOX_HALF_W, ROW_Y - BOX_HALF_H),
                               2 * BOX_HALF_W, 2 * BOX_HALF_H, fill=True,
                               facecolor=face, edgecolor="black", lw=2,
                               transform=ax.transAxes))
    ax.text(cx, ROW_Y, label, ha="center", va="center", fontsize=10,
            transform=ax.transAxes)

# Arrow from each stage's right edge to the next stage's left edge
for left_cx, right_cx in zip(stage_centers, stage_centers[1:]):
    ax.annotate("", xy=(right_cx - BOX_HALF_W, ROW_Y),
                xytext=(left_cx + BOX_HALF_W, ROW_Y),
                arrowprops=dict(arrowstyle="->", color="black", lw=2),
                xycoords="axes fraction", textcoords="axes fraction")

ax.text(0.5, 0.9, "make_optical_wavefront(field, wavelength, dx)",
        ha="center", fontsize=12, family="monospace", transform=ax.transAxes)

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis("off")
ax.set_title("Factory Function Workflow", fontsize=14)
save_fig("factory_pattern.svg")

## Summary

All figures have been generated. The SVG files are written to `SAVE_DIR` (the current working directory by default).

In [None]:
# List generated figures: the SVGs currently in SAVE_DIR, which is where
# save_fig writes. This previously listed "." and so ignored a customised
# SAVE_DIR. Any pre-existing SVGs in that directory are listed too.
svg_files = sorted(p.name for p in SAVE_DIR.glob("*.svg"))
print(f"Generated {len(svg_files)} figures:")
for f in svg_files:
    print(f"  - {f}")