In the previous demo, we looked at solving the nonlinear shallow water equations using various schemes.
Here we'll make things yet more interesting by solving them on a sphere.
We'll largely follow the treatment in [Bernard et al. (2009)](https://doi.org/10.1016/j.jcp.2009.05.046).

### Model

Most of this is copied from the previous demos, so I'll give only cursory explanations.

In [None]:
import firedrake
from firedrake import Constant
# Gravitational acceleration, in SI units (m/s²).
g = Constant(9.81)
# Identity tensor used in the pressure part (0.5 g h² I) of the momentum flux.
# NOTE(review): this is 2x2, but on the sphere mesh used below velocity fields
# are 3-vectors (cf. the ``cross`` in ``coriolis``), so ``outer(u, u)`` is 3x3
# and adding this identity to it would be a shape mismatch — confirm.
I = firedrake.Identity(2)

The following functions compute symbolic representations of the various shallow water fluxes.

In [None]:
from firedrake import inner, grad, dx

def cell_flux(z):
    """Return the cell-interior part of the weak-form shallow water fluxes.

    Parameters
    ----------
    z : firedrake.Function
        State in the mixed space of thickness ``h`` and velocity ``u``.

    Returns
    -------
    ufl.Form
        ``-(h u, ∇ϕ) dx - (F, ∇v) dx`` with ``F = h u ⊗ u + ½ g h² I``, where
        ``ϕ`` and ``v`` are the thickness and velocity test functions.
    """
    # ``outer`` is otherwise only imported by a later notebook cell; import it
    # here so the function does not depend on cell execution order.
    from firedrake import outer

    Z = z.function_space()
    h, u = firedrake.split(z)
    ϕ, v = firedrake.TestFunctions(Z)

    f_h = -inner(h * u, grad(ϕ)) * dx

    # Size the identity from the velocity instead of hard-coding 2. On a mesh
    # embedded in 3D (like the sphere) the velocity has three components, and
    # the identity must match the shape of ``outer(u, u)``.
    identity = firedrake.Identity(u.ufl_shape[0])
    F = h * outer(u, u) + 0.5 * g * h**2 * identity
    f_u = -inner(F, grad(v)) * dx

    return f_h + f_u

For reference when writing the facet fluxes below, see [this code](https://github.com/firedrakeproject/gusto/blob/8fb6c67307727d77da784f0f9bff6f75fa3f55c6/gusto/transport_equation.py#L266) from Gusto.

In [None]:
from firedrake import avg, outer, dS

def central_facet_flux(z):
    """Return the central (averaged) flux terms on interior facets.

    Parameters
    ----------
    z : firedrake.Function
        State in the mixed space of thickness ``h`` and velocity ``u``.

    Returns
    -------
    ufl.Form
        Interior-facet integrals of the averaged mass flux ``avg(h u)`` and
        momentum flux ``avg(F)``. These are tested against ``ϕ n`` and
        ``v ⊗ n`` on both sides of each facet.
    """
    Z = z.function_space()
    h, u = firedrake.split(z)
    ϕ, v = firedrake.TestFunctions(Z)

    mesh = Z.mesh()
    n = firedrake.FacetNormal(mesh)

    # Each side's own normal is used explicitly, rather than assuming
    # n("-") == -n("+").
    f_h = inner(avg(h * u), ϕ("+") * n("+") + ϕ("-") * n("-")) * dS

    # Identity sized to the velocity so the pressure term matches
    # ``outer(u, u)`` on both planar (2-vector) and embedded-sphere
    # (3-vector) meshes.
    identity = firedrake.Identity(u.ufl_shape[0])
    F = h * outer(u, u) + 0.5 * g * h**2 * identity
    f_u = inner(avg(F), outer(v("+"), n("+")) + outer(v("-"), n("-"))) * dS

    return f_h + f_u

In [None]:
from firedrake import sqrt, max_value

def lax_friedrichs_facet_flux(z):
    """Return the Lax–Friedrichs dissipation terms on interior facets.

    Parameters
    ----------
    z : firedrake.Function
        State in the mixed space of thickness ``h`` and velocity ``u``.

    Returns
    -------
    ufl.Form
        ``-α [[h]] [[ϕ]] dS - α [[u]]·[[v]] dS``, where ``[[·]]`` is the
        difference between the "+" and "-" sides of a facet. ``α`` is the
        facet average of the wave speed ``|u · n| + sqrt(g h)``.
    """
    Z = z.function_space()
    h, u = firedrake.split(z)
    ϕ, v = firedrake.TestFunctions(Z)

    # Take the mesh from the function space rather than from the split
    # component ``h``: calling ``ufl_domain()`` on a general UFL expression is
    # deprecated in recent UFL releases.
    mesh = Z.mesh()
    n = firedrake.FacetNormal(mesh)

    # Local wave speed: normal advection speed plus gravity-wave speed.
    c = abs(inner(u, n)) + sqrt(g * h)
    # The penalty uses the average, not the maximum, of the two one-sided speeds.
    α = avg(c)

    f_h = -α * (h("+") - h("-")) * (ϕ("+") - ϕ("-")) * dS
    f_u = -α * inner(u("+") - u("-"), v("+") - v("-")) * dS

    return f_h + f_u

In [None]:
def conserved_variables(z):
    """Return the conserved quantities ``(h, h u_0, ..., h u_{d-1})`` of a state.

    Parameters
    ----------
    z : firedrake.Function
        State in the mixed space of thickness ``h`` and velocity ``u``.

    Returns
    -------
    ufl vector expression
        Thickness followed by every component of the momentum ``h u``.
    """
    h, u = firedrake.split(z)
    # Include every velocity component, not just the first two. On the sphere
    # the velocity is a 3-vector, and this vector is paired via ``inner`` with
    # a test function of the whole mixed space, so the lengths must agree.
    momentum = [h * u[i] for i in range(u.ufl_shape[0])]
    return firedrake.as_vector([h, *momentum])

In [None]:
def topographic_forcing(z, b):
    """Return the weak form of the bed-slope source term ``-g h ∇b``.

    Parameters
    ----------
    z : firedrake.Function
        State in the mixed space of thickness ``h`` and velocity ``u``.
    b : UFL expression
        Bed elevation.

    Returns
    -------
    ufl.Form
        ``-(g h ∇b, v) dx``, where ``v`` is the velocity test function.
    """
    test_functions = firedrake.TestFunctions(z.function_space())
    velocity_test = test_functions[1]
    thickness = firedrake.split(z)[0]
    return -g * thickness * inner(grad(b), velocity_test) * dx

We'll add one more bit of physics to this problem that wasn't included in previous demos: rotation.

In [None]:
def coriolis(z, f=None):
    """Return the weak form of the Coriolis term for the momentum equation.

    Parameters
    ----------
    z : firedrake.Function
        State in the mixed space of thickness ``h`` and velocity ``u``.
    f : UFL vector expression, optional
        Rotation vector. It defaults to the zero vector, i.e. rotation
        switched off (the value this demo uses for debugging). For Earth,
        rotating about the z-axis, this would be
        ``as_vector((0, 0, 2 * Ω))`` with ``Ω ≈ 7.292e-5`` s⁻¹.

    Returns
    -------
    ufl.Form
        ``(h (f × u), v) dx``, where ``v`` is the velocity test function.
    """
    Z = z.function_space()
    # ``split`` returns the pair (h, u); both parts are needed below.
    h, u = firedrake.split(z)
    v = firedrake.TestFunctions(Z)[1]

    if f is None:
        f = firedrake.as_vector((0, 0, 0))

    # The evolved quantity is the momentum ``h u`` (see
    # ``conserved_variables``), so the force scales with ``h``, just as the
    # bed-slope term in ``topographic_forcing`` does.
    # NOTE(review): the sign of this term relative to the others depends on
    # how the caller's ``equation`` combines them, which isn't visible
    # here — confirm.
    return inner(h * firedrake.cross(f, u), v) * dx

### Solver

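We'll take a shortcut to implementing the Rosenbrock implicit midpoint scheme here.
Rather than explicitly including the functional derivative terms, we can simply implement the ordinary implicit midpoint scheme and take only one Newton step per timestep.
Passing the PETSc option `"snes_type": "ksponly"` in the solver parameters makes the solver take just a single step of Newton's method. Starting from the current state as the initial guess, that single step is exactly the linearly implicit (Rosenbrock) form of the scheme.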

In [None]:
from firedrake import (
    NonlinearVariationalProblem as Problem,
    NonlinearVariationalSolver as Solver,
)

class ImplicitMidpoint:
    """Implicit midpoint timestepper for equations written in conservation form.

    Every call to :meth:`step` solves for the new state ``z⁺`` in

        ⟨Q(z⁺) - Q(z), w⟩ = δt · E((z + z⁺) / 2; w)   for all test functions w,

    where ``Q`` is ``conserved_variables`` and ``E`` is the form built by
    ``equation``. Passing ``{"snes_type": "ksponly"}`` as the solver
    parameters limits each solve to one Newton step.

    Parameters
    ----------
    state : firedrake.Function
        Initial state. It is deep-copied, so the caller's function is not
        modified.
    equation : callable
        Maps a state function to the UFL form of the right-hand side.
    conserved_variables : callable
        Maps a state function to a UFL vector of conserved quantities.
    solver_parameters : dict, optional
        Forwarded to the nonlinear variational solver.
    """

    def __init__(self, state, equation, conserved_variables, solver_parameters=None):
        current = state.copy(deepcopy=True)
        dt = firedrake.Constant(1.0)
        upcoming = current.copy(deepcopy=True)
        test = firedrake.TestFunction(current.function_space())

        # Evaluate the right-hand side at the midpoint of the old and new states.
        midpoint = (current + upcoming) / 2
        rhs = firedrake.replace(equation(current), {current: midpoint})

        change = conserved_variables(upcoming) - conserved_variables(current)
        residual = inner(change, test) * dx - dt * rhs

        self.state = current
        self.next_state = upcoming
        self.timestep = dt
        self.solver = Solver(
            Problem(residual, upcoming), solver_parameters=solver_parameters
        )

    def step(self, timestep):
        """Advance ``self.state`` in place by one step of length ``timestep``."""
        self.timestep.assign(timestep)
        self.solver.solve()
        # ``next_state`` is not reset afterwards, so it also serves as the
        # initial guess for the following solve.
        self.state.assign(self.next_state)

### Demonstration

We'll use the same function spaces and timestepping scheme as before: BDFM(2) for the velocity, DG(1) for the thickness, and a Rosenbrock form of the implicit midpoint rule.

In [None]:
# Sphere radius in meters (approximately Earth's radius).
radius = 6.370e6
# Number of times the base icosahedron is refined.
level = 2
# Polynomial degree of the mesh coordinate field (1 = flat triangles).
mesh_degree = 1

mesh = firedrake.IcosahedralSphereMesh(radius, level, mesh_degree)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

# Draw the mesh edges on 3D axes; the trailing ';' suppresses notebook output.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
firedrake.triplot(mesh, axes=ax);

In [None]:
# Thickness space: discontinuous piecewise-linear, DG(1).
Q = firedrake.FunctionSpace(mesh, "DG", 1)
# Velocity space: Brezzi–Douglas–Fortin–Marini of degree 2, BDFM(2).
V = firedrake.FunctionSpace(mesh, "BDFM", 2)
# Mixed space for the state z = (h, u).
Z = Q * V

In [None]:
# Initial state: the new Function starts at zero (so zero velocity); the
# thickness is then set to a uniform value of 1e3.
z0 = firedrake.Function(Z)
z0.sub(0).project(firedrake.Constant(1e3));

In [None]:
# Plot the initial thickness on 3D axes. The color range is ±1 around the
# uniform value 1e3, so small deviations would be visible.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
colors = firedrake.trisurf(z0.sub(0), vmin=1e3-1, vmax=1e3+1, axes=ax)
fig.colorbar(colors);