# Weierstrass ℘ Playground — Two-Panel, Coupled Second-Order Trajectories

This notebook visualizes complex fields derived from the Weierstrass ℘ function on a rectangular lattice and overlays shared second-order trajectories across two panels:
- **Left panel**: ℘(z) background/field
- **Right panel**: ℘′(z) background/field
- **Trajectories**: Follow the second-order ODE z''(t) = -℘(z(t)) · z(t), integrated with RK4 and overlaid identically on both panels

## Important: How to Use This Notebook

**To avoid errors, please run all cells in order from top to bottom.** 

You can either:
1. Use **Cell → Run All** from the menu, or
2. Run each cell individually using **Shift+Enter**

The interactive controls will appear at the bottom after all cells are executed.

## Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Rectangle
import ipywidgets as widgets
from IPython.display import display, clear_output
import warnings
# Silence every warning category for the whole session; the numerical code
# below routinely divides by zero at lattice poles.
warnings.filterwarnings('ignore')

# Set matplotlib to inline mode
%matplotlib inline

# Configure matplotlib for high quality plots
# (on-screen figures at 100 dpi, saved figures at 300 dpi, 10 pt base font)
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['font.size'] = 10

## Core Mathematical Functions

In [None]:
def wp_rect(z, p, q, N):
    """
    Evaluate the truncated Weierstrass ℘ sum on the lattice Λ = Zp + Ziq.

    The value is 1/z^2 plus, for every non-zero period ω = m*p + i*n*q with
    -N <= m, n <= N, the term 1/(z-ω)^2 - 1/ω^2.

    Args:
        z: complex number or array of evaluation points
        p, q: real lattice parameters (horizontal / vertical period)
        N: truncation parameter (indices run from -N to N)

    Returns:
        Complex array with the shape of z. Points that coincide with a
        pole of the truncated sum come back non-finite (no exception).
    """
    z = np.asarray(z, dtype=complex)
    total = np.zeros_like(z, dtype=complex)

    index_range = range(-N, N + 1)
    # Every lattice period except the origin, in row-major (m, n) order.
    periods = [m * p + n * 1j * q
               for m in index_range
               for n in index_range
               if m or n]

    with np.errstate(divide='ignore', invalid='ignore'):
        # Principal part at the origin.
        total += 1.0 / (z**2)

        # Regularised contribution of each remaining period.
        for omega in periods:
            total += 1.0 / (z - omega)**2 - 1.0 / omega**2

    return total

def wp_deriv(z, p, q, N):
    """
    Evaluate the truncated derivative ℘'(z) = -2 * sum(1/(z-ω)^3).

    The sum runs over the origin and every period ω = m*p + i*n*q with
    -N <= m, n <= N.

    Args:
        z: complex number or array of evaluation points
        p, q: real lattice parameters
        N: truncation parameter

    Returns:
        Complex array with the shape of z. Points that coincide with a
        pole of the truncated sum come back non-finite (no exception).
    """
    z = np.asarray(z, dtype=complex)
    total = np.zeros_like(z, dtype=complex)

    index_range = range(-N, N + 1)
    # Every lattice period except the origin, in row-major (m, n) order.
    periods = [m * p + n * 1j * q
               for m in index_range
               for n in index_range
               if m or n]

    with np.errstate(divide='ignore', invalid='ignore'):
        # Contribution of the pole at the origin.
        total += -2.0 / (z**3)

        # Contribution of each remaining period.
        for omega in periods:
            total += -2.0 / (z - omega)**3

    return total

def wp_and_deriv(z, p, q, N):
    """
    Convenience wrapper returning ℘(z) and ℘'(z) together.

    The two values come from independent calls to wp_rect and wp_deriv
    with the same arguments.

    Returns:
        (wp_val, wp_deriv_val)
    """
    wp_val = wp_rect(z, p, q, N)
    wp_deriv_val = wp_deriv(z, p, q, N)
    return wp_val, wp_deriv_val

## Field Sampling and Visualization Functions

In [None]:
def field_grid(p, q, which, N, nx, ny, pole_eps=1e-6):
    """
    Sample ℘ or ℘' on a regular grid covering the cell [0, p] × [0, q].

    Grid points within pole_eps of a pole of the truncated sum, and points
    where the field evaluates to a non-finite value, are flagged invalid.
    The returned field is NaN at every invalid point, so callers never see
    infinities.

    Args:
        p, q: lattice parameters
        which: 'wp' for ℘(z) or 'wp_deriv' for ℘'(z)
        N: lattice truncation
        nx, ny: grid resolution
        pole_eps: pole detection threshold

    Returns:
        X, Y, F, M where F is the complex field (NaN where invalid) and M
        is the boolean validity mask.

    Raises:
        ValueError: if which is neither 'wp' nor 'wp_deriv'.
    """
    # Resolve the evaluator first so a bad selector fails before any work.
    if which == 'wp':
        evaluate = wp_rect
    elif which == 'wp_deriv':
        evaluate = wp_deriv
    else:
        raise ValueError("which must be 'wp' or 'wp_deriv'")

    x = np.linspace(0, p, nx)
    y = np.linspace(0, q, ny)
    X, Y = np.meshgrid(x, y)
    Z = X + 1j * Y

    # The only lattice points inside the closed cell are its four corners
    # m*p + i*n*q with m, n in {0, 1}; a corner is a pole of the truncated
    # sum only when its index is within the truncation range.  (Wrapping
    # the periods with % would fold all of them onto the origin and miss
    # the corners at Re z = p and Im z = q.)
    mask = np.ones(Z.shape, dtype=bool)
    corner_indices = range(0, min(N, 1) + 1)
    for m in corner_indices:
        for n in corner_indices:
            corner = m * p + n * 1j * q
            mask &= np.abs(Z - corner) > pole_eps

    F = evaluate(Z, p, q, N)

    # Fold non-finite samples into the mask, then blank F everywhere the
    # final mask is False so F and the mask always agree.
    mask &= np.isfinite(F)
    F = np.where(mask, F, np.nan)

    return X, Y, F, mask

def soft_background(F, M, sat=0.3, mag_scale=1.0, value_floor=0.3):
    """
    Render a complex field as a pastel HSV-derived RGB image.

    Hue encodes arg(F), brightness encodes an arctan-compressed |F| lifted
    above value_floor, and saturation is constant.

    Args:
        F: complex field values
        M: valid mask (False pixels are painted white)
        sat: saturation level
        mag_scale: factor applied to |F| before compression
        value_floor: minimum brightness

    Returns:
        RGB image array clipped to [0, 1]
    """
    # arg(F) in (-pi, pi] -> hue in [0, 1)
    hue = (np.angle(F) / (2 * np.pi) + 0.5) % 1.0

    # |F| in [0, inf) -> [0, 1) via arctan, then raised to the floor
    compressed = np.arctan(np.abs(F) * mag_scale) / (np.pi / 2)
    value = value_floor + (1 - value_floor) * compressed

    saturation = np.full_like(hue, sat)

    rgb = mcolors.hsv_to_rgb(np.stack([hue, saturation, value], axis=-1))

    # Invalid pixels become white.
    rgb = np.where(M[..., np.newaxis], rgb, 1.0)

    return np.clip(rgb, 0, 1)

def add_topo_contours(ax, X, Y, F, M, n_contours=10):
    """
    Draw thin black iso-lines of log10|F| on ax.

    Nothing is drawn when n_contours is not positive, when no valid
    magnitude exists, or when the log-magnitude is constant.
    """
    if n_contours <= 0:
        return

    # Magnitude restricted to valid grid points.
    magnitude = np.where(M, np.abs(F), np.nan)
    if np.all(np.isnan(magnitude)):
        return

    # Log scale spreads the levels more evenly near poles and zeros.
    with np.errstate(divide='ignore', invalid='ignore'):
        log_mag = np.log10(magnitude + 1e-10)

    if not np.any(np.isfinite(log_mag)):
        return

    lowest = np.nanmin(log_mag)
    highest = np.nanmax(log_mag)
    if lowest == highest:
        return

    ax.contour(
        X, Y, log_mag,
        levels=np.linspace(lowest, highest, n_contours),
        colors='black',
        alpha=0.3,
        linewidths=0.5,
    )

def vector_overlay(ax, X, Y, F, M, density=20, width=0.002, max_len=0.5):
    """
    Overlay a subsampled quiver plot of the complex field F on ax.

    Each arrow points along F (Re -> x, Im -> y); its length is
    tanh(2*|F| / peak), where peak is the largest finite magnitude among
    the valid subsampled points.  Arrows longer than max_len are omitted.

    Args:
        ax: axes-like object providing quiver()
        X, Y: 2-D coordinate grids
        F: complex field sampled on the grid
        M: boolean validity mask for the grid
        density: approximate number of arrows per axis (<= 0 draws nothing)
        width: arrow shaft width passed to quiver
        max_len: arrows longer than this (in data units) are not drawn
    """
    if density <= 0:
        return

    # Subsample grid
    ny, nx = X.shape
    step_x = max(1, nx // density)
    step_y = max(1, ny // density)
    stride = (slice(None, None, step_y), slice(None, None, step_x))
    X_sub, Y_sub, F_sub, M_sub = X[stride], Y[stride], F[stride], M[stride]

    mag = np.abs(F_sub)

    # Normalise by the largest magnitude that is both valid and finite.
    # Taking the maximum over every sample would let a single infinite
    # value (a pole on the grid) shrink all arrows to zero length.
    usable = M_sub & np.isfinite(mag)
    if not np.any(usable):
        return
    peak = mag[usable].max()

    with np.errstate(divide='ignore', invalid='ignore'):
        compressed_mag = np.tanh(mag / peak * 2) if peak > 0 else mag
        # Zero or non-finite magnitudes give NaN directions, filtered below.
        direction = F_sub / mag
        U = np.real(direction * compressed_mag)
        V = np.imag(direction * compressed_mag)

    # Apply mask and length filter
    arrow_len = np.sqrt(U**2 + V**2)
    valid = usable & np.isfinite(U) & np.isfinite(V) & (arrow_len <= max_len)

    if np.any(valid):
        ax.quiver(X_sub[valid], Y_sub[valid], U[valid], V[valid],
                  scale_units='xy', scale=1, width=width, alpha=0.7, color='darkblue')

## Trajectory Integration Functions

In [None]:
def wrap_point(z, p, q):
    """
    Reduce z modulo the lattice into the half-open cell [0, p) × [0, q).

    The real part is taken modulo p and the imaginary part modulo q.
    """
    wrapped_re = z.real % p
    wrapped_im = z.imag % q
    return wrapped_re + 1j * wrapped_im

def wrap_with_breaks(zs, p, q, wrap_threshold=0.5):
    """
    Fold a trajectory into the fundamental cell, marking wrap-around jumps.

    Args:
        zs: array of complex trajectory points
        p, q: lattice parameters
        wrap_threshold: a step larger than this fraction of the cell width
            (or height) between consecutive wrapped points counts as a wrap

    Returns:
        Array of wrapped points with a NaN+NaNj entry inserted before each
        point that follows a wrap; an empty array for empty input.
    """
    if len(zs) == 0:
        return np.array([])

    wrapped = np.array([wrap_point(z, p, q) for z in zs])

    max_dx = wrap_threshold * p
    max_dy = wrap_threshold * q
    gap = np.nan + 1j * np.nan

    pieces = [wrapped[0]]
    for previous, current in zip(wrapped[:-1], wrapped[1:]):
        step = current - previous
        jumped = abs(step.real) > max_dx or abs(step.imag) > max_dy
        if jumped:
            # Separate the two sides of the jump so no line is drawn across.
            pieces.append(gap)
        pieces.append(current)

    return np.array(pieces)

def integrate_second_order_with_blowup(z0, v0, dt, T, p, q, N, blow_thresh=10.0, pole_eps=1e-6):
    """
    Integrate the second-order ODE z''(t) = -℘(z(t)) * z(t) with classic RK4.

    The equation is advanced as the first-order system (z, v)' = (v, force(z)).
    Integration stops early ("blow-up") when the position comes within
    pole_eps of any lattice point m*p + i*n*q, when a single step moves the
    position by more than blow_thresh, or when the state becomes non-finite.

    Args:
        z0, v0: initial position and velocity (complex)
        dt: time step
        T: total time
        p, q, N: lattice parameters
        blow_thresh: blow-up threshold for |Δz| of one step
        pole_eps: pole proximity threshold

    Returns:
        (trajectory, blow_up_point): trajectory is the array of accepted
        positions starting with z0; blow_up_point is the last accepted
        position if integration stopped early, otherwise None.
    """
    def force(z):
        """Compute force: -℘(z) * z"""
        return -wp_rect(z, p, q, N) * z

    def rhs(state):
        """Right-hand side: [z', v'] = [v, force(z)] (autonomous system)."""
        return np.array([state[1], force(state[0])])

    def near_lattice_point(z):
        """True when z lies within pole_eps of the nearest lattice point."""
        # Offset from the nearest multiple of each period.  Comparing the
        # wrapped position against the origin only would miss approaches
        # from negative Re/Im, which wrap to just below p or q.
        dx = z.real - p * np.rint(z.real / p)
        dy = z.imag - q * np.rint(z.imag / q)
        return np.hypot(dx, dy) < pole_eps

    # Small tolerance so that e.g. 0.3 / 0.1 = 2.999... still yields 3 steps.
    steps = int(np.floor(T / dt + 1e-9))

    state = np.array([z0, v0], dtype=complex)
    trajectory = [state[0]]  # Store position

    for _ in range(steps):
        if near_lattice_point(state[0]):
            return np.array(trajectory), state[0]  # Blow-up at pole

        # RK4 step
        try:
            k1 = dt * rhs(state)
            k2 = dt * rhs(state + k1 / 2)
            k3 = dt * rhs(state + k2 / 2)
            k4 = dt * rhs(state + k3)
            new_state = state + (k1 + 2 * k2 + 2 * k3 + k4) / 6
        except (ArithmeticError, ValueError):
            # Numerical failure inside the step counts as a blow-up; other
            # exception types are programming errors and are not hidden.
            return np.array(trajectory), state[0]

        # Check for blow-up
        step_size = abs(new_state[0] - state[0])
        if step_size > blow_thresh or not np.all(np.isfinite(new_state)):
            return np.array(trajectory), state[0]  # Blow-up detected

        state = new_state
        trajectory.append(state[0])

    return np.array(trajectory), None  # No blow-up

## Interactive UI Setup

In [None]:
# Handle to the most recently rendered figure, and the output area that
# rendering draws into
current_fig = None
output_widget = widgets.Output()


def _float_slider(description, value, lo, hi, step):
    """Build a FloatSlider from its label, start value, range and step."""
    return widgets.FloatSlider(value=value, min=lo, max=hi, step=step, description=description)


def _int_slider(description, value, lo, hi):
    """Build an IntSlider from its label, start value and range."""
    return widgets.IntSlider(value=value, min=lo, max=hi, description=description)


# Lattice parameters
p_slider = _float_slider('p', 11.0, 1.0, 20.0, 0.1)
q_slider = _float_slider('q', 5.0, 1.0, 20.0, 0.1)
N_slider = _int_slider('N (truncation)', 3, 0, 6)

# Rendering parameters
grid_x_slider = _int_slider('Grid X', 100, 50, 300)
grid_y_slider = _int_slider('Grid Y', 100, 50, 300)
contours_slider = _int_slider('# Contours', 10, 0, 30)
vec_density_slider = _int_slider('Vec density', 20, 0, 50)
vec_width_slider = _float_slider('Vec width', 0.002, 0.001, 0.01, 0.001)
vec_max_len_slider = _float_slider('Vec max len', 0.5, 0.1, 2.0, 0.1)
show_vectors_checkbox = widgets.Checkbox(value=True, description='Show vectors')

# Palette parameters
saturation_slider = _float_slider('Saturation', 0.3, 0.0, 1.0, 0.05)
value_floor_slider = _float_slider('Value floor', 0.3, 0.0, 1.0, 0.05)
mag_scale_slider = _float_slider('Mag scale', 1.0, 0.1, 5.0, 0.1)

# Integration parameters
dt_slider = _float_slider('dt', 0.01, 0.001, 0.1, 0.001)
T_slider = _float_slider('T (duration)', 10.0, 1.0, 50.0, 1.0)
blowup_thresh_slider = _float_slider('Blow-up |Δz|', 10.0, 1.0, 50.0, 1.0)
emoji_size_slider = _int_slider('Emoji size', 20, 10, 50)

# Particle rows: each entry is (z0_text, v0_text, remove_btn, row)
particle_list = []
particles_container = widgets.VBox()

def create_particle_row(idx=0):
    """
    Build the widgets for one particle: z0 field, v0 field, Remove button.

    Returns:
        (z0_text, v0_text, remove_btn, row) where row is the HBox holding
        the other three widgets.
    """
    z0_text = widgets.Text(value='2+1j', description=f'z0 #{idx}')
    v0_text = widgets.Text(value='0+1j', description=f'v0 #{idx}')
    remove_btn = widgets.Button(
        description='Remove',
        button_style='danger',
        layout=widgets.Layout(width='80px'),
    )
    row = widgets.HBox([z0_text, v0_text, remove_btn])
    entry = (z0_text, v0_text, remove_btn, row)

    def remove_particle(b):
        # The last remaining particle cannot be removed.
        if len(particle_list) <= 1:
            return
        particle_list.remove(entry)
        update_particles_display()

    remove_btn.on_click(remove_particle)

    return entry

def add_particle(b=None):
    """Append a fresh particle row and refresh the list display.

    The argument is the clicked button when used as a callback; it is
    not used.
    """
    particle_list.append(create_particle_row(len(particle_list)))
    update_particles_display()

def update_particles_display():
    """Renumber the particle labels and show the current rows."""
    rows = []
    for position, (z0_text, v0_text, _remove_btn, row) in enumerate(particle_list):
        # Keep the labels in sync with each row's position in the list.
        z0_text.description = f'z0 #{position}'
        v0_text.description = f'v0 #{position}'
        rows.append(row)

    particles_container.children = rows

def _parse_complex(text):
    """Parse a complex literal, tolerating spaces around '+' and '-'."""
    cleaned = text.strip()
    # complex() rejects "2 + 1j"; close up whitespace next to each sign.
    for sign in '+-':
        cleaned = sign.join(part.strip() for part in cleaned.split(sign))
    return complex(cleaned)


def get_particles():
    """
    Read the initial conditions entered in the particle rows.

    Returns:
        List of (z0, v0) complex pairs, one per row whose two text fields
        both parse as complex numbers. Rows that fail to parse are skipped
        and a message naming the row is printed.
    """
    particles = []
    for idx, (z0_text, v0_text, _, _) in enumerate(particle_list):
        try:
            z0 = _parse_complex(z0_text.value)
            v0 = _parse_complex(v0_text.value)
        except ValueError:
            # Report the skipped row instead of dropping it silently.
            print(f"Skipping particle #{idx}: cannot parse "
                  f"z0={z0_text.value!r}, v0={v0_text.value!r} as complex numbers")
            continue
        particles.append((z0, v0))
    return particles

# Initialize with one particle
add_particle()

add_particle_btn = widgets.Button(description='Add Particle', button_style='success')
add_particle_btn.on_click(add_particle)

# Control buttons
render_btn = widgets.Button(description='Render', button_style='primary')
save_btn = widgets.Button(description='Save PNG', button_style='info')


def _on_render_clicked(b):
    """Render button callback; looks up render_playground at click time."""
    render_playground()


def _on_save_clicked(b):
    """Save button callback; looks up save_figure at click time."""
    save_figure(b)


# Connect the buttons here, where they are created.  The handlers resolve
# render_playground / save_figure only when clicked, so it does not matter
# that those functions are defined in a later cell.  Without this the
# buttons have no callbacks when the cells are run in order.
render_btn.on_click(_on_render_clicked)
save_btn.on_click(_on_save_clicked)

## Main Rendering Function

In [None]:
def _split_at_breaks(points):
    """Split a complex path into arrays of consecutive non-NaN points."""
    segments = []
    current_segment = []
    for z in points:
        if np.isnan(z):
            if current_segment:
                segments.append(np.array(current_segment))
                current_segment = []
        else:
            current_segment.append(z)
    if current_segment:
        segments.append(np.array(current_segment))
    return segments


def render_playground():
    """
    Main rendering function for the Weierstrass playground.

    Reads every control widget, draws the ℘ (left) and ℘′ (right) panels
    with background, contours and optional vector field, integrates each
    particle and overlays its wrapped trajectory on both panels.  The
    figure is shown inside output_widget and stored in current_fig.
    """
    global current_fig

    # Check if required functions are defined
    try:
        field_grid  # Check if function exists
    except NameError:
        print("Error: Required functions not defined. Please run all cells from the beginning.")
        return

    with output_widget:
        clear_output(wait=True)

        # Release the previously rendered figure so repeated renders do not
        # accumulate open figures in pyplot.
        previous_fig = globals().get('current_fig')
        if previous_fig is not None:
            plt.close(previous_fig)

        # Get parameters
        p, q, N = p_slider.value, q_slider.value, N_slider.value
        nx, ny = grid_x_slider.value, grid_y_slider.value
        n_contours = contours_slider.value
        vec_density = vec_density_slider.value if show_vectors_checkbox.value else 0
        vec_width = vec_width_slider.value
        vec_max_len = vec_max_len_slider.value

        saturation = saturation_slider.value
        value_floor = value_floor_slider.value
        mag_scale = mag_scale_slider.value

        dt = dt_slider.value
        T = T_slider.value
        blow_thresh = blowup_thresh_slider.value
        emoji_size = emoji_size_slider.value

        particles = get_particles()

        print(f"Rendering with p={p}, q={q}, N={N}, particles={len(particles)}")
        print(f"Grid: {nx}×{ny}, dt={dt}, T={T}")

        # Create figure with two panels
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
        fig.subplots_adjust(wspace=0.0)  # No space between panels
        panels = (ax1, ax2)

        # Compute fields
        X1, Y1, F1, M1 = field_grid(p, q, 'wp', N, nx, ny)
        X2, Y2, F2, M2 = field_grid(p, q, 'wp_deriv', N, nx, ny)

        # Create backgrounds
        bg1 = soft_background(F1, M1, saturation, mag_scale, value_floor)
        bg2 = soft_background(F2, M2, saturation, mag_scale, value_floor)

        # Display backgrounds
        ax1.imshow(bg1, extent=[0, p, 0, q], origin='lower', aspect='equal')
        ax2.imshow(bg2, extent=[0, p, 0, q], origin='lower', aspect='equal')

        # Add contours
        add_topo_contours(ax1, X1, Y1, F1, M1, n_contours)
        add_topo_contours(ax2, X2, Y2, F2, M2, n_contours)

        # Add vector fields
        if vec_density > 0:
            vector_overlay(ax1, X1, Y1, F1, M1, vec_density, vec_width, vec_max_len)
            vector_overlay(ax2, X2, Y2, F2, M2, vec_density, vec_width, vec_max_len)

        # Integrate and plot trajectories
        colors = plt.cm.tab10(np.linspace(0, 1, len(particles)))

        for i, (z0, v0) in enumerate(particles):
            try:
                trajectory, blowup_point = integrate_second_order_with_blowup(
                    z0, v0, dt, T, p, q, N, blow_thresh
                )

                # Wrap the path into the cell and draw each unbroken piece
                # on both panels; single-point pieces have no line to draw.
                wrapped_traj = wrap_with_breaks(trajectory, p, q)
                for segment in _split_at_breaks(wrapped_traj):
                    if len(segment) > 1:
                        for ax in panels:
                            ax.plot(segment.real, segment.imag, color=colors[i], linewidth=2, alpha=0.8)

                # Mark starting point.  This and the blow-up marker are
                # drawn even when integration stopped before the first
                # step, so such particles are still visible.
                z0_wrapped = wrap_point(z0, p, q)
                for ax in panels:
                    ax.plot(z0_wrapped.real, z0_wrapped.imag, 'o', color=colors[i], markersize=8, markeredgecolor='white')

                # Mark blow-up point if exists
                if blowup_point is not None:
                    bp_wrapped = wrap_point(blowup_point, p, q)
                    for ax in panels:
                        ax.text(bp_wrapped.real, bp_wrapped.imag, '💥', fontsize=emoji_size, ha='center', va='center')

            except Exception as e:
                # One bad particle must not abort the whole render.
                print(f"Error integrating particle {i}: {e}")

        # Set labels and titles
        ax1.set_title('℘(z)', fontsize=16)
        ax2.set_title("℘'(z)", fontsize=16)
        ax1.set_xlabel('Re(z)')
        ax1.set_ylabel('Im(z)')
        ax2.set_xlabel('Re(z)')
        ax2.set_ylabel('')  # No y-label on right panel

        # Set axis limits
        ax1.set_xlim(0, p)
        ax1.set_ylim(0, q)
        ax2.set_xlim(0, p)
        ax2.set_ylim(0, q)

        # Remove ticks on right panel y-axis
        ax2.set_yticks([])

        plt.tight_layout()
        current_fig = fig
        plt.show()

def save_figure(b=None):
    """
    Write the most recently rendered figure to weierstrass_playground.png.

    Prints a hint instead when nothing has been rendered yet.  The argument
    is the clicked button when used as a callback; it is not used.
    """
    global current_fig
    if current_fig is None:
        print("No figure to save. Please render first.")
        return

    filename = 'weierstrass_playground.png'
    current_fig.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Figure saved as {filename}")

## Interactive Interface

In [None]:
# Ensure all UI widgets are defined (in case previous cells weren't run)
# The probe below only checks p_slider; if that name is missing, the whole
# widget set, the particle helpers and the buttons are created here.
try:
    p_slider
except NameError:
    print("Creating UI widgets...")
    
    # Global variables for current figure
    current_fig = None
    output_widget = widgets.Output()
    
    # Lattice parameters
    p_slider = widgets.FloatSlider(value=11.0, min=1.0, max=20.0, step=0.1, description='p')
    q_slider = widgets.FloatSlider(value=5.0, min=1.0, max=20.0, step=0.1, description='q')
    N_slider = widgets.IntSlider(value=3, min=0, max=6, description='N (truncation)')
    
    # Rendering parameters
    grid_x_slider = widgets.IntSlider(value=100, min=50, max=300, description='Grid X')
    grid_y_slider = widgets.IntSlider(value=100, min=50, max=300, description='Grid Y')
    contours_slider = widgets.IntSlider(value=10, min=0, max=30, description='# Contours')
    vec_density_slider = widgets.IntSlider(value=20, min=0, max=50, description='Vec density')
    vec_width_slider = widgets.FloatSlider(value=0.002, min=0.001, max=0.01, step=0.001, description='Vec width')
    vec_max_len_slider = widgets.FloatSlider(value=0.5, min=0.1, max=2.0, step=0.1, description='Vec max len')
    show_vectors_checkbox = widgets.Checkbox(value=True, description='Show vectors')
    
    # Palette parameters
    saturation_slider = widgets.FloatSlider(value=0.3, min=0.0, max=1.0, step=0.05, description='Saturation')
    value_floor_slider = widgets.FloatSlider(value=0.3, min=0.0, max=1.0, step=0.05, description='Value floor')
    mag_scale_slider = widgets.FloatSlider(value=1.0, min=0.1, max=5.0, step=0.1, description='Mag scale')
    
    # Integration parameters
    dt_slider = widgets.FloatSlider(value=0.01, min=0.001, max=0.1, step=0.001, description='dt')
    T_slider = widgets.FloatSlider(value=10.0, min=1.0, max=50.0, step=1.0, description='T (duration)')
    blowup_thresh_slider = widgets.FloatSlider(value=10.0, min=1.0, max=50.0, step=1.0, description='Blow-up |Δz|')
    emoji_size_slider = widgets.IntSlider(value=20, min=10, max=50, description='Emoji size')
    
    # Particle list management
    # Each entry of particle_list is a (z0_text, v0_text, remove_btn, row) tuple.
    particle_list = []
    particles_container = widgets.VBox()
    
    def create_particle_row(idx=0):
        """Create a particle input row.

        Returns the tuple (z0_text, v0_text, remove_btn, row), where row is
        the HBox containing the two text fields and the Remove button.
        """
        z0_text = widgets.Text(value='2+1j', description=f'z0 #{idx}')
        v0_text = widgets.Text(value='0+1j', description=f'v0 #{idx}')
        remove_btn = widgets.Button(description='Remove', button_style='danger', layout=widgets.Layout(width='80px'))
        
        def remove_particle(b):
            # Keep at least one particle; `row` is bound below, before any
            # click can occur.
            if len(particle_list) > 1:
                particle_list.remove((z0_text, v0_text, remove_btn, row))
                update_particles_display()
        
        remove_btn.on_click(remove_particle)
        row = widgets.HBox([z0_text, v0_text, remove_btn])
        
        return z0_text, v0_text, remove_btn, row
    
    def add_particle(b=None):
        """Add a new particle row and refresh the displayed list."""
        idx = len(particle_list)
        particle_row = create_particle_row(idx)
        particle_list.append(particle_row)
        update_particles_display()
    
    def update_particles_display():
        """Renumber the row labels and show the current rows."""
        for i, (z0_text, v0_text, remove_btn, row) in enumerate(particle_list):
            z0_text.description = f'z0 #{i}'
            v0_text.description = f'v0 #{i}'
        
        particles_container.children = [row for _, _, _, row in particle_list]
    
    def get_particles():
        """Get current particle initial conditions.

        Returns a list of (z0, v0) complex pairs; rows whose text does not
        parse with complex() are skipped.
        """
        particles = []
        for z0_text, v0_text, _, _ in particle_list:
            try:
                z0 = complex(z0_text.value)
                v0 = complex(v0_text.value)
                particles.append((z0, v0))
            except ValueError:
                continue
        return particles
    
    # Initialize with one particle
    add_particle()
    
    add_particle_btn = widgets.Button(description='Add Particle', button_style='success')
    add_particle_btn.on_click(add_particle)
    
    # Control buttons
    render_btn = widgets.Button(description='Render', button_style='primary')
    save_btn = widgets.Button(description='Save PNG', button_style='info')
    
    # Connect button callbacks (need to be defined after render_playground function)
    # The lambda looks up render_playground when clicked; save_figure is
    # referenced directly, so it must already exist when this line runs.
    render_btn.on_click(lambda b: render_playground())
    save_btn.on_click(save_figure)

# Assemble the control panel: one titled section per group of widgets
def _titled_box(title, *children):
    """Stack an <h3> heading above the given widgets."""
    return widgets.VBox([widgets.HTML(f"<h3>{title}</h3>"), *children])


lattice_box = _titled_box("Lattice Parameters", p_slider, q_slider, N_slider)

rendering_box = _titled_box(
    "Rendering",
    grid_x_slider, grid_y_slider, contours_slider,
    show_vectors_checkbox, vec_density_slider, vec_width_slider, vec_max_len_slider,
)

palette_box = _titled_box("Palette", saturation_slider, value_floor_slider, mag_scale_slider)

integration_box = _titled_box(
    "Integration",
    dt_slider, T_slider, blowup_thresh_slider, emoji_size_slider,
)

particles_box = _titled_box("Particles", particles_container, add_particle_btn)

controls_box = _titled_box("Controls", render_btn, save_btn)

# Two columns side by side
left_column = widgets.VBox([lattice_box, rendering_box, palette_box])
right_column = widgets.VBox([integration_box, particles_box, controls_box])

ui = widgets.HBox([left_column, right_column])

# Show the controls, then the area the figure is rendered into
display(ui)
display(output_widget)

# Usage notes, printed one line at a time
for _line in (
    "Weierstrass ℘ Playground loaded!",
    "Click 'Render' to generate the visualization.",
    "\nTips:",
    "- Try p=11, q=5, N=3 for a good starting point",
    "- Higher N values give more accurate ℘ function but slower computation",
    "- Trajectories follow z''(t) = -℘(z(t)) * z(t)",
    "- 💥 marks indicate trajectory blow-ups near poles",
    "\nNote: If you get NameError, please run all cells from the beginning in order.",
):
    print(_line)