In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from tqdm.notebook import trange, tqdm
import firedrake
from firedrake import (
    Constant, inner, min_value, max_value, jump, avg, grad, dx, ds, dS
)
import irksome
from irksome import Dt

In [None]:
# Structured mesh of the unit square with the same number of cells in x and y,
# each grid square split along both diagonals ("crossed").
nx = ny = 32
mesh = firedrake.UnitSquareMesh(nx, ny, diagonal="crossed")
# Nominal mesh spacing (side of one grid square); used below to size the time step.
δ = 1 / nx

In [None]:
# Discontinuous Galerkin space for the advected field; degree 0 makes it
# piecewise constant on each triangle.
degree = 0
element = firedrake.FiniteElement(family="DG", cell="triangle", degree=degree)
Q = firedrake.FunctionSpace(mesh, element)

In [None]:
# Spatially uniform advecting velocity pointing in +x with speed u_max.
u_max = 1.0
u = Constant((u_max, 0.0))

In [None]:
# Value carried in through inflow boundaries.
h_in = Constant(1.0)

# Initial state: a step profile equal to 1 for x < L and 0 elsewhere,
# projected onto the DG space Q.
L = Constant(0.25)
x = firedrake.SpatialCoordinate(mesh)
expr = firedrake.conditional(x[0] < L, 1.0, 0.0)
h_0 = firedrake.Function(Q)
h_0.project(expr)

# Working copy that the time stepper advances in place; h_0 stays untouched.
h = h_0.copy(deepcopy=True)

In [None]:
ϕ = firedrake.TestFunction(Q)

# Cell terms: time derivative plus the flux divergence after integration by
# parts, i.e. ∫ ∂h/∂t ϕ dx − ∫ h u · ∇ϕ dx.
F_cells = (Dt(h) * ϕ - inner(h * u, grad(ϕ))) * dx

# Upwind numerical flux on interior facets.
# f = h · max(0, u·ν) evaluates to h⁺ max(0, u·ν⁺) and h⁻ max(0, −u·ν⁺) on the
# two sides, so jump(f) = h_upwind (u·ν⁺) for either sign of u·ν, and
# jump(f) * jump(ϕ) is the standard upwind facet term.
# Keeping h outside max_value (rather than max_value(0, h * u·ν)) makes the
# flux correct for negative h as well and keeps the form linear in h, so
# the Jacobian is exact: the old form's derivative was wrong at h == 0 on
# facets where u·ν < 0, which is the initial state downstream of the step.
ν = firedrake.FacetNormal(mesh)
f = h * max_value(0, inner(u, ν))
F_facets = jump(f) * jump(ϕ) * dS

# Boundary fluxes: on inflow parts (u·ν < 0) the incoming state is the
# prescribed h_in; on outflow parts (u·ν > 0) the interior state h leaves.
# Where u·ν == 0 both terms vanish.
F_inflow = h_in * ϕ * min_value(0, inner(u, ν)) * ds
F_outflow = h * ϕ * max_value(0, inner(u, ν)) * ds

F = F_cells + F_facets + F_inflow + F_outflow

In [None]:
# Backward Euler time integration via irksome, starting at t = 0.
method = irksome.BackwardEuler()
t = Constant(0.0)
# Courant number 0.5 with respect to the nominal mesh spacing δ.
dt = Constant(0.5 * δ / u_max)

# Newton with line search for each step's solve; the Krylov solver is
# preconditioned with a MUMPS LU factorization.
solver_parameters = dict(
    snes_type="newtonls",
    ksp_type="gmres",
    pc_type="lu",
    pc_factor_mat_solver_type="mumps",
)
params = dict(solver_parameters=solver_parameters)

# The stepper advances the Function h (and the Constant t) in place.
solver = irksome.TimeStepper(F, method, t, dt, h, **params)

In [None]:
# Restart from the initial condition so this cell is idempotent: the stepper
# mutates h and t in place, so without this reset a re-run would continue
# from the previous run's final state instead of t = 0.
h.assign(h_0)
t.assign(0.0)

# Snapshot of the solution at every step (including t = 0) for animation.
hs = [h.copy(deepcopy=True)]

final_time = 1.0
# final_time / dt is meant to be a whole number of steps; round instead of
# truncating so a quotient like 59.999... (float round-off) does not silently
# drop the last step. If final_time is not a multiple of dt, the run ends
# within half a step of final_time.
num_steps = round(final_time / float(dt))
for _ in trange(num_steps):
    solver.advance()
    hs.append(h.copy(deepcopy=True))

In [None]:
%%capture

fig, ax = plt.subplots()
ax.set_aspect("equal")
colors = firedrake.tripcolor(
    hs[0], axes=ax, num_sample_points=1, shading="gouraud"
);

fn_plotter = firedrake.FunctionPlotter(mesh, num_sample_points=1)
def animate(h):
    colors.set_array(fn_plotter(h))

In [None]:
# One frame per stored snapshot, 10 frames per second (100 ms between frames);
# tqdm reports progress as frames are pulled.
frames_per_second = 10
animation = FuncAnimation(
    fig, animate, frames=tqdm(hs), interval=1e3 / frames_per_second
)

In [None]:
# Encode the animation as an HTML5 <video> tag and display it inline.
# NOTE(review): presumably needs a movie writer (ffmpeg by default) installed.
video = animation.to_html5_video()
HTML(video)