In [1]:
#!/Users/jbm/miniforge3/envs/dedalus3/bin/python3
"""
Shallow Water Equations on a Double-Tanh Plane

February 2026
Authors: Leopold Li, Brad Marston

---------------------------------------------------------------------------
    This script solves the IVP for the shallow water equations 
    initialized with an eigenmode of a chosen frequency and horizontal 
    wavenumber on a double-tanh plane: 

        ∂ₜu + ∂ₓh - fv = 0
        ∂ₜv + ∂ᵧh + fu = 0
        ∂ₜh + ∂ₓu + ∂ᵧv = 0

        f = tanh(⍺(y-y₀)) - tanh(⍺(y+y₀)) + 1
---------------------------------------------------------------------------
"""

# Import Packages
import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["NUMEXPR_MAX_THREADS"] = "1"

import pathlib
import time
import h5py
import numpy as np
import matplotlib
matplotlib.use('Agg') 
import matplotlib.pyplot as plt
import dedalus.public as d3
import dedalus.core as dec
from dedalus.tools import post 
import logging
logger = logging.getLogger(__name__)
plt.rcParams['text.usetex'] = False
from dedalus.core.operators import GeneralFunction
from dedalus.extras import flow_tools
import shutil
from mpi4py import MPI

In [2]:
path = '[Your Path]'

# Domain Parameters 
Nx = 128  # x resolution
Ny = 128  # y resolution
Lx = 10   # meridional dimension
Ly = 20*np.pi # zonal dimension
y0 = Ly/4 # Offset for tanh equator 
f0 = 1.0 # f₀ for beta-plane

"""
Select parameters 
    * Select target frequency + horizontal wavenumber of desired eigenmode for initialization
    * Set ⍺ parameter for double-tanh plane
    * Set simulation time / max iterations for IVP
"""

target_omega = [2.0] # Choose ⍵ for desired eigenmode to seed IVP
horizontal_wavenumber = [2.0] # Choose horizontal wavenumber for desired eigenmode to seed IVP
alphas = [10] # alpha parameter for 'sharpness' double tanh equator
sim_time = 30 
max_iterations = 2000

h_g = None
u_g = None
v_g = None
eig_sel = None
kx  = None
h_gs = []
u_gs = []
v_gs = []
y = None

In [3]:
### Solve EVP for given horizontal wavenumber, frequency, alpha, f0
def EVP_solve(k_x, target_omega, alph, f0):
    alpha = alph
    y = d3.Coordinate('y')
    dist = d3.Distributor(y, dtype=np.complex128)
    Y = d3.ComplexFourier(y, size=Ny, bounds=(-Ly/2, Ly/2))

    _dy = lambda A: d3.Differentiate(A, y)
    _dx = lambda A: (-1j*k_x)*A           
    _dt = lambda A: ( 1j)*omega*A         

    u = dist.Field(name='u', bases=Y)
    v = dist.Field(name='v', bases=Y)
    h = dist.Field(name='h', bases=Y)
    omega = dist.Field(name='omega')        

    yy = dist.local_grids(Y)[0]
    f  = dist.Field(name='f', bases=Y)

    f['g'] = f0*(np.tanh(alph*(yy-y0)) -np.tanh(alph*(yy+y0)) +1) # Double-tanh equator

    problem = d3.EVP([u, v, h], eigenvalue=omega, namespace=locals())
    problem.add_equation("_dt(u) + _dx(h) - f*v = 0")
    problem.add_equation("_dt(v) + _dy(h) + f*u = 0")
    problem.add_equation("_dt(h) + _dx(u) + _dy(v) = 0")

    solver = problem.build_solver()

    solver.solve_dense(solver.subproblems[0])

    order = np.argsort(solver.eigenvalues.real) # Indices ordered by eigenvalues 
    eigs_ordered  = solver.eigenvalues[order] # Eigenmodes in order
    print("First few eigenvalues:\n", eigs_ordered[:5]) 

    # Pick eigenvalues - finds eigenmode nearest to selected ⍵ 
    omega_idx = np.argmin(np.abs(eigs_ordered.real - target_omega)) 
    omega_idx_unsorted = np.argmin(np.abs(solver.eigenvalues.real - target_omega))
    eig_sel = eigs_ordered[omega_idx]

    print(f"Alpha = {alph} Omega:", eig_sel.real)

    solver.set_state(omega_idx_unsorted)
    y1d = dist.local_grids(Y, scales=1)[0]
    h_g = h['g'].copy()
    u_g = u['g'].copy()
    v_g = v['g'].copy()

    # Plot selected eigenmode
    plt.figure(figsize=(9, 6))
    plt.plot(y1d, np.real(h_g), color = 'black', lw = 2)
    plt.tick_params('both', size = 8, width = 1.5, direction = 'in')
    plt.xlabel('$y$', fontsize = 25, color = 'dimgray')
    plt.ylabel('$h$', fontsize = 25, color = 'dimgray')

    ax = plt.gca()
    for spine in ax.spines.values():
        spine.set_linewidth(2)
        spine.set_color('dimgray')

    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.title(f'$\\alpha={alph}, k_x = {k_x:.3f},  \\omega = {eig_sel.real:.3f}, f_0 = {f0}$',  fontsize = 25, color = 'dimgray')
    ax.tick_params(axis='both', colors='dimgray') 
    # plt.savefig(path + 'selected_eigenmode.png', dpi = 200) # Uncomment to save plot

    # Store grids for initializing IVP
    h_gs.append(np.real(h_g))
    u_gs.append(np.real(u_g))
    v_gs.append(np.real(v_g))



In [4]:
# Run EVP
for i in range(len(alphas)): 
    EVP_solve(horizontal_wavenumber[i], target_omega[i], alphas[i], f0)

# Set desired kx: (Set to zero here)
kx = horizontal_wavenumber[0]
alpha = alphas[0] 


"""
IVP 
"""
# Pick initial conditions from collected eigenmodes
eigenmode_h = h_gs[0]  # selects height field of indexed eigenmode
eigenmode_u =  u_gs[0] # selects meriodional velocity field of indexed eigenmode
eigenmode_v = v_gs[0] # selects zonal velocity field of indexed eigenmode

# Dedalus Domain parameters 
coords = d3.CartesianCoordinates('x','y')
dist = d3.Distributor(coords, dtype=np.float64)

x_basis = d3.RealFourier(coords['x'], Nx, bounds=[-Lx/2, Lx/2], dealias=1.0)
y_basis = d3.RealFourier(coords['y'], Ny, bounds=[-Ly/2, Ly/2], dealias = 1.0)

# Dedalus fields
u = dist.Field(name='u', bases=[x_basis, y_basis])
v = dist.Field(name='v', bases=[x_basis, y_basis])
h = dist.Field(name='h', bases=[x_basis, y_basis])
f = dist.Field(name='f', bases=[y_basis])

dx = lambda A: d3.Differentiate(A, coords['x'])
dy = lambda A: d3.Differentiate(A, coords['y'])

x, y = dist.local_grids(x_basis, y_basis)
f['g'] = np.tanh(alpha*(y - y0)) - np.tanh(alpha*(y + y0)) + 1.0 # Double-tanh plane
# f['g'] =np.sin(2*np.pi*y/Ly) # sin coriolis parameter 

# Initial conditions
u['g'] = eigenmode_u.real*np.cos(kx*2*np.pi/Lx*x)
v['g'] = eigenmode_v.real*np.cos(kx*2*np.pi/Lx*x)
h['g'] = eigenmode_h.real*np.cos(kx*2*np.pi/Lx*x)


# IVP Equations 
problem = d3.IVP([u, v, h], namespace=locals())
problem.add_equation("dt(u) + dx(h) - f*v = 0")
problem.add_equation("dt(v) + dy(h) + f*u  = 0")
problem.add_equation("dt(h) + dx(u) + dy(v) = 0")

solver = problem.build_solver('RK222') 


2026-02-21 07:51:44,101 subsystems 0/1 INFO :: Building subproblem matrices 1/1 (~100%) Elapsed: 0s, Remaining: 0s, Rate: 2.7e+01/s
First few eigenvalues:
 [-6.74399468-6.40687597e-18j -6.67773352-1.10344869e-17j
 -6.64921026-6.51261074e-18j -6.58348057-1.17059954e-17j
 -6.58203893-1.08874162e-17j]
Alpha = 10 Omega: 2.0000000001093103
2026-02-21 07:51:44,648 subsystems 0/1 INFO :: Building subproblem matrices 1/64 (~2%) Elapsed: 0s, Remaining: 2s, Rate: 3.8e+01/s
2026-02-21 07:51:44,777 subsystems 0/1 INFO :: Building subproblem matrices 7/64 (~11%) Elapsed: 0s, Remaining: 1s, Rate: 4.5e+01/s
2026-02-21 07:51:44,926 subsystems 0/1 INFO :: Building subproblem matrices 14/64 (~22%) Elapsed: 0s, Remaining: 1s, Rate: 4.6e+01/s
2026-02-21 07:51:45,077 subsystems 0/1 INFO :: Building subproblem matrices 21/64 (~33%) Elapsed: 0s, Remaining: 1s, Rate: 4.6e+01/s
2026-02-21 07:51:45,227 subsystems 0/1 INFO :: Building subproblem matrices 28/64 (~44%) Elapsed: 1s, Remaining: 1s, Rate: 4.6e+01/s
2

In [5]:
# Run IVP

solver.stop_sim_time = sim_time
solver.stop_wall_time = np.inf
solver.stop_iteration = max_iterations

# Set up CFL 
vel = d3.VectorField(dist,coordsys =coords, bases=(x_basis, y_basis), name='vel')
init_dt = 0.001
CFL = flow_tools.CFL(solver, initial_dt=init_dt, cadence=10, safety=0.3, max_change=1.5)
CFL.add_velocity(vel)

# Lists for accumulating grids 
u_max= []
u_list = []
h_list = []
t_list = []

logger.info('Starting loop')
start_time = time.time()
dt = 0.005 # Initial dt

while solver.proceed:
    solver.step(dt)

    vel['g'][0] = u['g'].real
    vel['g'][1] = v['g'].real

    dt = CFL.compute_timestep()
    t_list.append(solver.sim_time)
    u_list.append(np.copy(u['g']))
    h_list.append(np.copy(h['g']))
    
    
    if solver.iteration % 10 == 0:
        print('Completed iteration {}, time {}, dt {}'.format(solver.iteration, t_list[-1], dt))

end_time = time.time()


logger.info('Run time: %f' %(end_time-start_time))
logger.info('Iterations: %i' %solver.iteration)


2026-02-21 07:51:52,214 __main__ 0/1 INFO :: Starting loop
Completed iteration 10, time 0.014000000000000005, dt 0.001
Completed iteration 20, time 0.02850000000000002, dt 0.0015
Completed iteration 30, time 0.05025000000000004, dt 0.0022500000000000003
Completed iteration 40, time 0.08287500000000006, dt 0.0033750000000000004
Completed iteration 50, time 0.13181250000000005, dt 0.005062500000000001
Completed iteration 60, time 0.1865106957773567, dt 0.005515077308595187
Completed iteration 70, time 0.24167402694984622, dt 0.005516472651543816
Completed iteration 80, time 0.2968050280012764, dt 0.005512725377765165
Completed iteration 90, time 0.3519399055133756, dt 0.005513572459370466
Completed iteration 100, time 0.4071246647185199, dt 0.005519020749530449
Completed iteration 110, time 0.46226690375289525, dt 0.005513690920538345
Completed iteration 120, time 0.5173946528168605, dt 0.005512673127047435
Completed iteration 130, time 0.5725535567669499, dt 0.005516247869226841
Complet

In [6]:
# Produce animation 
import matplotlib.colors as mcolors
from matplotlib import animation 

xm, ym = np.meshgrid(x,y)

fig, axis = plt.subplots(figsize=(10,5),num="Selected eigenmode")

lim  = np.nanmax(np.abs(h_list))
norm = mcolors.TwoSlopeNorm(vmin=-lim, vcenter=0.0, vmax=lim)
p = axis.pcolormesh(xm, ym, np.array(h_list[0]).T,norm = norm, cmap='RdBu_r', shading='gouraud')
axis.set_title(rf'$\alpha = {alpha},\ t = {t_list[0]:6.2f}$', fontsize = 20)
axis.set_xlabel('x',fontsize=20)
axis.set_ylabel('y',fontsize=20)
cbar = fig.colorbar(p,ax=axis)
cbar.set_label('h', fontsize = 20)
cbar.ax.tick_params(labelsize=20)
u_all = np.array(u_list)
h_all = np.array(h_list)

def init():
    p.set_array(np.ravel(np.array(h_list[0]).T))
    return p

def animate(i): 
    if i % 10 == 0:
        print(f"Rendering frame {i}...")
    p.set_array(np.ravel(np.array(h_list[(i+1)*10]).T))
    axis.set_title(fr'$\alpha = {alpha},\; t = {t_list[(i+1)*10]:6.2f}$')

    return p

ani = animation.FuncAnimation(fig, animate, frames=int(len(t_list)/10-1))
print("Saving animation ...")
ani.save(path + 'swe_ivp.gif', writer='pillow', fps=10)
print(f"\n \n Animation ('swe_ivp.gif') saved to: {path}")


Saving animation ...
2026-02-21 07:53:35,033 matplotlib.animation 0/1 INFO :: Animation.save using <class 'matplotlib.animation.PillowWriter'>
Rendering frame 0...
Rendering frame 0...
Rendering frame 10...
Rendering frame 20...
Rendering frame 30...
Rendering frame 40...
Rendering frame 50...
Rendering frame 60...
Rendering frame 70...
Rendering frame 80...
Rendering frame 90...
Rendering frame 100...
Rendering frame 110...
Rendering frame 120...
Rendering frame 130...
Rendering frame 140...
Rendering frame 150...
Rendering frame 160...
Rendering frame 170...
Rendering frame 180...
Rendering frame 190...

 
 Animation ('swe_ivp.gif') saved to: [Your Path]
