The linear shallow water equations are
$$\begin{align}
\partial_t s + \nabla\cdot q & = 0 \\
\partial_t q + f\hat k \times q + gH\nabla s & = 0
\end{align}$$
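An exact steady-state solution, for any constant $R \neq 0$, is
$$s = R^{-1}(x^2 + y^2) / 2, \quad q = \frac{gH}{fR}\left(\begin{matrix}-y \\ x\end{matrix}\right),$$
which is divergence-free and in geostrophic balance, $f\hat k \times q = -gH\nabla s$.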

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from tqdm.notebook import trange, tqdm
import firedrake
from firedrake import Constant, inner, perp, div, dx, ds
import irksome
from irksome import Dt

# Physical parameters: gravitational acceleration g, Coriolis parameter f,
# and mean depth H.
g, f, H = Constant(9.81), Constant(1.0), Constant(0.1)
# Length scale R of the exact solution s = (x^2 + y^2) / (2 R); here R = H.
R = Constant(H)

Create the mesh.

In [None]:
# Unstructured triangulation of the unit disk (argument: refinement level).
mesh = firedrake.UnitDiskMesh(4)

# Plot the mesh with equal aspect ratio and no axes decoration.
fig, ax = plt.subplots(subplot_kw={"aspect": "equal"})
ax.axis("off")
firedrake.triplot(mesh, axes=ax);

Create the input data: the exact solution $(q, s)$ and a Gaussian bump $\delta s$ that is added to the initial surface.

In [None]:
x = firedrake.SpatialCoordinate(mesh)

# Exact steady solution. perp(x) is the 2D rotation (-x[1], x[0]).
s_expr = (x[0]**2 + x[1]**2) / (2 * R)
q_expr = g * H / (f * R) * perp(x)

# Gaussian perturbation of the surface: amplitude s_0, width r, centre ξ.
ξ = Constant((0.25, 0.25))
r = Constant(0.25)
s_0 = Constant(H / 5)
δs = s_0 * firedrake.exp(-inner(x - ξ, x - ξ) / r**2)

Create a mixed function space (BDM for the flux $q$, DG for the surface $s$), a state variable to store the solution, and initialize it with the exact solution plus the surface bump.

In [None]:
degree = 0

# Flux q lives in BDM(degree + 1); surface s lives in DG(degree).
q_element = firedrake.FiniteElement("BDM", "triangle", degree + 1)
V = firedrake.FunctionSpace(mesh, q_element)

s_element = firedrake.FiniteElement("DG", "triangle", degree)
Q = firedrake.FunctionSpace(mesh, s_element)

# Mixed space ordered (q, s), matching the split in form_problem.
Z = V * Q

# Initial state: exact flux, and exact surface perturbed by δs.
z = firedrake.Function(Z)
z.sub(0).project(q_expr)
z.sub(1).project(s_expr + δs);

Form the weak form of the linear rotating shallow water equations.

In [None]:
def form_problem(z, s_boundary):
    """Return the residual form of the linear rotating shallow water equations.

    Parameters
    ----------
    z : firedrake.Function
        State on a mixed space whose first component is the flux ``q`` and
        whose second component is the surface ``s``.
    s_boundary : UFL expression
        Surface value used in the boundary term that arises from integrating
        the pressure gradient ``g ∇s`` by parts.

    Returns
    -------
    UFL form
        Residual containing ``Dt`` terms, suitable for ``irksome.TimeStepper``.
    """
    q, s = firedrake.split(z)
    Z = z.function_space()
    v, ϕ = firedrake.TestFunctions(Z)

    # The mass balance equation: ∂t s + div q = 0.
    G_mass = (Dt(s) + div(q)) * ϕ * dx

    # The momentum balance equation divided by H. perp(q) plays the role of
    # k̂ × q, and g ∇s has been integrated by parts onto the test function.
    G_momentum = (inner(Dt(q) + f * perp(q), v) / H - g * s * div(v)) * dx

    # NOTE: external forcing F and linear frictional drag γ are not included;
    # a candidate term is (inner(F, v) + γ / H * inner(q, v)) * dx.

    # Boundary term from the integration by parts. Take the mesh from z itself
    # rather than a global so the form is built on the mesh z actually lives on.
    n = firedrake.FacetNormal(Z.mesh())
    G_boundary = g * s_boundary * inner(v, n) * ds

    return G_mass + G_momentum + G_boundary

Run the simulation forward in time.

In [None]:
def run_simulation(z, timestep, final_time, degree):
    """Advance the shallow water state ``z`` from t = 0 to ``final_time``.

    Parameters
    ----------
    z : firedrake.Function
        Mixed ``(q, s)`` state; it is updated in place by the time stepper.
    timestep : float
        Requested (maximum) time step. The step actually taken is
        ``final_time / num_steps``, which never exceeds this value, so the
        run ends exactly at ``final_time``.
    final_time : float
        Time to integrate to.
    degree : int
        Argument passed to ``irksome.GaussLegendre`` (the number of stages);
        unrelated to the spatial element degree.

    Returns
    -------
    list of firedrake.Function
        Deep copies of the state: the initial state followed by one copy per
        time step.
    """
    # Pick a whole number of steps that lands exactly on final_time. Plain
    # int(final_time / timestep) truncates, so the run could stop up to one
    # step early -- even for commensurate inputs such as 0.3 / 0.1, where
    # floating-point rounding gives 2.9999999999999996.
    ratio = final_time / timestep
    nearest = round(ratio)
    if np.isclose(ratio, nearest, rtol=1e-9):
        num_steps = int(nearest)
    else:
        # Round up so the adjusted step is never larger than requested.
        num_steps = int(np.ceil(ratio))
    num_steps = max(num_steps, 0)
    if num_steps > 0:
        timestep = final_time / num_steps

    t = Constant(0.)
    δt = Constant(timestep)
    F = form_problem(z, s_expr)
    method = irksome.GaussLegendre(degree)
    solver = irksome.TimeStepper(F, method, t, δt, z)

    zs = [z.copy(deepcopy=True)]
    for _ in trange(num_steps):
        solver.advance()
        # advance() does not update t; keep the model time in sync.
        t.assign(t + δt)
        zs.append(z.copy(deepcopy=True))

    return zs

In [None]:
# CFL-style step: smallest cell size divided by the gravity-wave speed √(gH).
δx = mesh.cell_sizes.dat.data_ro.min()
C = np.sqrt(float(g * H))
timestep = δx / C

final_time = 10.0
zs = run_simulation(z, timestep, final_time, degree=1)

In [None]:
# Deviation from the exact steady solution for every state after the initial one.
later_states = zs[1:]
dqs = [firedrake.Function(V).project(state.sub(0) - q_expr) for state in later_states]
dss = [firedrake.Function(Q).project(state.sub(1) - s_expr) for state in later_states]

In [None]:
# Symmetric colour-scale limit: the largest |surface error| over all frames.
ds_min = np.min([err.dat.data_ro.min() for err in dss])
ds_max = np.max([err.dat.data_ro.max() for err in dss])
dsm = max(-ds_min, ds_max)
print(dsm)

In [None]:
%%capture

fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
for ax in axes:
    ax.set_aspect("equal")
    ax.set_axis_off()

# Left panel: surface-elevation error on a colour scale symmetric about zero.
kw = {
    "num_sample_points": 4,
    "vmin": -dsm,
    "vmax": +dsm,
    "cmap": "managua",
    "shading": "gouraud",
}
colors = firedrake.tripcolor(dss[0], axes=axes[0], **kw)

# Right panel: flux error as arrows. The error is interpolated into the mesh
# coordinate space for quiver. Use a distinct name so the global BDM space
# ``V`` is not clobbered; re-running the error cell afterwards would
# otherwise project into the wrong space.
coords_space = mesh.coordinates.function_space()
u_t = dqs[0].copy(deepcopy=True)
interpolator = firedrake.Interpolate(u_t, coords_space)
u_X = firedrake.assemble(interpolator)
arrows = firedrake.quiver(u_X, axes=axes[1], cmap="Blues")

# Sample with the same density as the initial tripcolor so frames match.
fn_plotter = firedrake.FunctionPlotter(
    mesh, num_sample_points=kw["num_sample_points"]
)


def animate(dz):
    """Update both panels with one frame ``(dq, ds)`` of error fields."""
    # Avoid the name ``ds``, which elsewhere is the boundary measure.
    dq, ds_err = dz
    colors.set_array(fn_plotter(ds_err))
    u_t.assign(dq)
    u_values = firedrake.assemble(interpolator).dat.data_ro
    arrows.set_UVC(*u_values.T)

In [None]:
# One frame per saved error pair; tqdm shows progress while frames render.
frames = tqdm(list(zip(dqs, dss)))
animation = FuncAnimation(fig, animate, frames, interval=1000 / 30)

In [None]:
# Render the animation to an HTML5 video and display it inline.
video = animation.to_html5_video()
HTML(video)