In [None]:
import firedrake

# Rectangular domain of size Lx by Ly, divided into an nx-by-ny grid of squares;
# "crossed" cuts every square along both diagonals.
Lx, Ly = 20.0, 20.0
nx, ny = 24, 24
mesh = firedrake.RectangleMesh(
    nx, ny, Lx, Ly, diagonal="crossed"
)

In [None]:
# Discontinuous Galerkin spaces sharing one polynomial degree: a scalar space
# (thickness and salinity) and a vector space (velocity).
degree = 1
Q = firedrake.FunctionSpace(mesh, "DG", degree=degree)
V = firedrake.VectorFunctionSpace(mesh, "DG", degree=degree)

# Mixed space ordered (thickness, velocity, salinity); z_0 is filled in later
# with the initial state.
Z = Q * V * Q
z_0 = firedrake.Function(Z)

In [None]:
from firedrake import max_value, min_value, inner, as_vector, Constant, ds

x = firedrake.SpatialCoordinate(mesh)

# Bed elevation: equal to b_0 along x = 0 and decreasing linearly in x,
# dropping by δb over the length Lx of the domain.
δb = Constant(0.2)
b_0 = Constant(0.0)
b = b_0 - δb * x[0] / Lx

In [None]:
# Inflow thickness, inflow speed (along x) and inflow salinity. The same values
# also serve as the uniform initial state below.
H = Constant(1.0)
u_in = Constant(2.5)
S_in = Constant(1.0)

# Fill the mixed initial state one component at a time.
# NOTE(review): newer Firedrake releases deprecate `Function.split()` in favour
# of `Function.subfunctions`; kept as-is for the Firedrake version used here.
h_0, u_0, S_0 = z_0.split()
h_0.project(H);
u_0.project(as_vector((u_in, 0)));
# Trailing `;` so the cell does not echo the projected Function as its output.
S_0.project(S_in);

In [None]:
import numpy as np
from plumes.coefficients import gravity
# Estimate of the fastest signal speed: inflow speed plus sqrt(gravity * H).
C = abs(float(u_in)) + np.sqrt(gravity * float(H))
# The smallest cell size and the DG degree set the stability-limited step;
# 1 / (2p + 1) is a common CFL reduction factor for degree-p DG.
δx = mesh.cell_sizes.dat.data_ro[:].min()
timestep = δx / C / (2 * degree + 1)

# Run long enough for a signal moving at speed C to cross the domain 12 times.
final_time = 12 * Lx / C
# Round the step count UP so that dt never exceeds the CFL-limited `timestep`.
# Truncating with int() would make dt slightly larger than that bound, and
# would give zero steps if final_time < timestep.
num_steps = int(np.ceil(final_time / timestep))
dt = final_time / num_steps

# Save a snapshot roughly every `output_time` units of model time (at most once per step).
output_time = 1 / 30
output_freq = max(int(output_time / dt), 1)

In [None]:
from plumes.models import shallow_water

# Boundary markers: fluid enters through boundary 1 and leaves through boundary 2.
inflow_ids = (1,)
outflow_ids = (2,)

g = Constant(gravity)
# Inflow boundary data. The inflow velocity is built from plain floats instead of
# placing the `u_in` Constant object inside the tuple, so constructing this
# vector Constant does not rely on NumPy coercing a UFL object.
bcs = {
    "thickness_in": H,
    "velocity_in": Constant((float(u_in), 0.0)),
    "inflow_ids": inflow_ids,
    "outflow_ids": outflow_ids,
}
wave_equation = shallow_water.make_equation(g, b, form="velocity", **bcs)

In [None]:
from firedrake import sqrt, dx

# Friction coefficient: a paraboloid-shaped patch of peak value k_0 centred at ξ,
# clipped to zero outside the disc |x - ξ| < R.
ξ = Constant((Lx / 2, Ly / 2))  # centre of the domain in both x and y
R = Constant(Lx / 8)
k_0 = Constant(1.0)
k = k_0 * max_value(0, 1 - inner(x - ξ, x - ξ) / R**2)

def friction_equation(z):
    r"""Return the quadratic drag term acting on the velocity equation.

    The form is :math:`-\int k\,|u|\,u\cdot v\,dx`, where ``k`` is the
    module-level friction coefficient field, ``u`` is the velocity component
    of the mixed state ``z`` and ``v`` is the matching test function.

    NOTE(review): the symbolic derivative of ``|u|`` is singular at ``u = 0``,
    so a linearisation of this form may produce NaNs wherever the velocity
    vanishes inside the friction patch.
    """
    # Only the velocity slot of the mixed space is involved; discard the thickness slot.
    _, v = firedrake.TestFunctions(z.function_space())[:2]
    _, u = firedrake.split(z)[:2]

    speed = sqrt(inner(u, u))
    return -k * speed * inner(u, v) * dx

In [None]:
from plumes.models import forms

def salt_equation(z):
    r"""Return the DG discretisation of salt transport for the quantity ``h * S``.

    The salt content ``h * S`` is advected by the velocity with flux
    ``f_S = h * S * u``. The returned form is the negative of the sum of:

    - cell integrals of the flux;
    - a central flux on interior facets, plus Lax-Friedrichs dissipation
      scaled by ``|u . n|``;
    - upwind fluxes on the outflow boundary (interior state) and on the inflow
      boundary (exterior state: thickness ``H``, salinity ``S_in``).

    Parameters
    ----------
    z : firedrake.Function
        Mixed state ordered (thickness ``h``, velocity ``u``, salinity ``S``).
    """
    Z = z.function_space()
    h, u, S = firedrake.split(z)
    η = firedrake.TestFunctions(Z)[2]

    f_S = h * S * u
    n = firedrake.FacetNormal(mesh)
    u_n = inner(u, n)
    c = abs(u_n)

    salt_fluxes = (
        forms.cell_flux(f_S, η) +
        forms.central_facet_flux(f_S, η) +
        # Dissipate jumps in the conserved quantity h * S, the same quantity
        # reported by `conserved_variables`, not in the salinity alone.
        forms.lax_friedrichs_facet_flux(h * S, c, η) +
        # Boundary fluxes must equal f_S . n = h * S * (u . n). Interior
        # thickness and salinity leave at the outflow; the prescribed inflow
        # thickness H and salinity S_in enter at the inflow.
        h * S * max_value(0, u_n) * η * ds(outflow_ids) +
        H * S_in * min_value(0, u_n) * η * ds(inflow_ids)
    )

    return -salt_fluxes

In [None]:
def equation(z):
    """Full model: shallow-water dynamics, friction drag and salt transport."""
    dynamics = wave_equation(z)
    drag = friction_equation(z)
    salt = salt_equation(z)
    return dynamics + drag + salt

In [None]:
def conserved_variables(z):
    """Return the conserved quantities ``(h, h u_0, ..., h u_{d-1}, h S)``.

    ``z`` is the mixed state ordered (thickness ``h``, velocity ``u``,
    salinity ``S``). Every velocity component is weighted by the thickness, so
    the result has ``d + 2`` entries for a ``d``-component velocity (four for
    the 2D setup in this notebook).
    """
    h, u, S = firedrake.split(z)
    momentum = [h * u[i] for i in range(u.ufl_shape[0])]
    return firedrake.as_vector([h, *momentum, h * S])

In [None]:
from plumes import numerics

# Assemble with a degree-4 quadrature rule into an AIJ sparse matrix. PETSc's
# "ksponly" SNES type performs a single linear solve instead of a full
# Newton iteration.
params = dict(
    form_compiler_parameters=dict(quadrature_degree=4),
    solver_parameters=dict(mat_type="aij", snes_type="ksponly"),
)

integrator = numerics.ImplicitEuler(
    equation,
    z_0,
    conserved_variables=conserved_variables,
    **params,
)

In [None]:
import tqdm

# Snapshots of thickness, velocity and salinity, taken every `output_freq` steps.
hs = []
us = []
Ss = []


def _record_snapshot(state):
    """Deep-copy the components of the mixed ``state`` into ``hs``, ``us``, ``Ss``.

    Returns the minimum and maximum thickness DOF values, for progress display.
    """
    h, u, S = state.split()
    hs.append(h.copy(deepcopy=True))
    us.append(u.copy(deepcopy=True))
    Ss.append(S.copy(deepcopy=True))
    h_values = h.dat.data_ro
    return h_values.min(), h_values.max()


progress_bar = tqdm.trange(num_steps)
for step in progress_bar:
    if step % output_freq == 0:
        hmin, hmax = _record_snapshot(integrator.state)
        progress_bar.set_description(f'{hmin:5.3f}, {hmax:5.3f}')

    integrator.step(dt)

# Snapshots are taken BEFORE each step, so without this the state reached after
# the last step (at final_time) would never be stored.
_record_snapshot(integrator.state)

In [None]:
%%capture
# Piecewise-constant space for plotting: a single value per triangle.
Q0 = firedrake.FunctionSpace(mesh, 'DG', 0)
# Surface elevation h + b of the first snapshot; `animate` re-projects into η.
η = firedrake.project(hs[0] + b, Q0)

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

fig, axes = plt.subplots()
axes.set(aspect='equal')
axes.xaxis.set_visible(False)
axes.yaxis.set_visible(False)
axes.set(xlim=(0, Lx), ylim=(0, Ly))
colors = firedrake.tripcolor(
    η, num_sample_points=1, vmin=0.5, vmax=1.2, axes=axes
)

def animate(h):
    """Redraw the plot for one saved thickness snapshot ``h``.

    Projects the surface elevation ``h + b`` into the plotting function ``η``
    and pushes its values into the existing tripcolor artist.
    """
    η.project(h + b)
    surface = η.dat.data_ro[:]
    colors.set_array(surface)

# FuncAnimation takes the frame interval in milliseconds. Showing each snapshot
# for 1000 * (model time between snapshots) ms plays one unit of model time
# per second of wall time.
interval = 1e3 * output_freq * dt
animation = FuncAnimation(fig, func=animate, frames=hs, interval=interval)

In [None]:
from IPython.display import HTML
# Encode the animation as an HTML5 <video> element and render it inline
# (encoding goes through matplotlib's configured movie writer).
video_html = animation.to_html5_video()
HTML(video_html)