In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from tqdm.notebook import trange, tqdm
import firedrake
from firedrake import Constant, max_value, inner, grad, jump, dx, ds, dS
import irksome
from irksome import Dt
from icepack2 import model

In [None]:
# Structured quadrilateral mesh of the unit square, same resolution both ways.
nx = ny = 64
mesh = firedrake.UnitSquareMesh(nx, ny, quadrilateral=True)
# Width of one cell along x; the timestep is later chosen relative to this.
δ = 1 / nx

In [None]:
# Function space for the thickness field: the "DQ" element family at degree 0
# on quadrilateral cells.
degree = 0
element = firedrake.FiniteElement("DQ", "quadrilateral", degree)
Q = firedrake.FunctionSpace(mesh, element)

In [None]:
# Spatially constant velocity of magnitude `u_max` pointing along +x; this is
# the `velocity` handed to the mass balance form.
u_max = 1.0
u = Constant((u_max, 0.0))

In [None]:
# Thickness supplied to the mass balance form as `thickness_inflow`.
h_in = Constant(1.0)
x = firedrake.SpatialCoordinate(mesh)
# Initial thickness: 1 where x < L and 0 elsewhere, projected into Q.
L = Constant(0.25)
expr = firedrake.conditional(x[0] < L, 1.0, 0.0)
h_0 = firedrake.Function(Q).project(expr)
# Independent deep copy: `h` is the working field given to the time stepper,
# while `h_0` retains the initial state.
h = h_0.copy(deepcopy=True)

In [None]:
# Source terms: the sum `a + m` is passed to the mass balance as `accumulation`.
a = Constant(0.0)
# In the visible cells `m_0` and `δm_δx` are referenced only by the
# commented-out definition of `m` below (a linear ramp in x clipped at zero);
# the active definition sets `m` to zero.
m_0 = Constant(-1.0)
δm_δx = Constant(4.0)
#m = max_value(0, δm_δx * x[0] + m_0)
m = Constant(0.0)

In [None]:
# Additional velocity `u_c`: `U_c` along x wherever x >= x_c, zero elsewhere.
# NOTE(review): the `_c` suffix presumably stands for "calving" (the solver log
# is named frontal-ablation.log) -- confirm.
U_c = Constant(1.0)
x_c = Constant(0.5)

u_c = firedrake.as_vector(
    (firedrake.conditional(x[0] >= x_c, U_c, 0), 0)
)

# Additional flux: `u_c` scaled by the constant `h_c`. Note that it is built
# from `h_c`, not from the evolving thickness `h`.
h_c = Constant(1.5)
f_c = h_c * u_c

In [None]:
# Weak form of the thickness equation, as assembled by icepack2's
# `model.mass_balance` from the fields defined in earlier cells.
ϕ = firedrake.TestFunction(Q)
F = model.mass_balance(
    thickness=h,
    velocity=u,
    accumulation=a + m,
    thickness_inflow=h_in,
    test_function=ϕ,
)

# Add the prescribed flux `f_c` to the form: a cell integral against grad(ϕ)
# plus facet integrals of the positive part of the normal flux,
# max(0, f_c · ν), over interior (dS) and exterior (ds) facets.
# NOTE(review): this looks like an upwind-style treatment of `f_c`; confirm
# the sign conventions match those used inside `model.mass_balance`.
ν = firedrake.FacetNormal(mesh)
F_c = max_value(0, inner(f_c, ν))
F += (
    -inner(f_c, grad(ϕ)) * dx
    + jump(F_c) * jump(ϕ) * dS
    + F_c * ϕ * ds
)

In [None]:
# Time variable and step size; the step is half a cell width divided by the
# peak speed `u_max`.
t = Constant(0.0)
dt = Constant(0.5 * δ / u_max)
method = irksome.BackwardEuler()

# Bounds applied to the stage values: `lower` is left as a freshly created
# Function on Q, `upper` is filled with +inf (i.e. no upper limit).
lower = firedrake.Function(Q)
upper = firedrake.Function(Q)
upper.assign(np.inf)
bounds = ("stage", lower, upper)

# Options for the time stepper; the nonlinear solver is a bound-constrained
# (variational inequality) Newton method whose monitor output goes to a log file.
params = dict(
    solver_parameters=dict(
        snes_monitor=":frontal-ablation.log",
        snes_atol=1e-12,
        snes_type="vinewtonrsls",
    ),
    stage_type="value",
    basis_type="Bernstein",
    bounds=bounds,
)

solver = irksome.TimeStepper(F, method, t, dt, h, **params)

In [None]:
# Number of steps needed to reach `final_time` with the fixed step `dt`.
final_time = 2.0
num_steps = int(final_time / float(dt))

# Snapshot history: starts with the initial state and gains one deep copy of
# `h` after every step, so later in-place updates do not alter stored frames.
hs = [h.copy(deepcopy=True)]
for step in trange(num_steps):
    solver.advance()
    hs += [h.copy(deepcopy=True)]

In [None]:
%%capture

# Draw the first snapshot once; subsequent frames only swap the colour data of
# the object returned by `tripcolor`.
fig, ax = plt.subplots()
ax.set_aspect("equal")
colors = firedrake.tripcolor(
    hs[0], axes=ax, num_sample_points=1, shading="gouraud"
);

# Converts a field into the array passed to `colors.set_array`; built with the
# same `num_sample_points` as the `tripcolor` call above.
fn_plotter = firedrake.FunctionPlotter(mesh, num_sample_points=1)
def animate(h):
    """Update the plotted colours to show the field `h`.

    `h` is the frame passed in by the animation driver; inside this function
    it shadows the notebook-level variable of the same name.
    """
    colors.set_array(fn_plotter(h))

In [None]:
# One frame per stored snapshot in `hs` (wrapped in `tqdm` for a progress bar);
# `interval` is 1e3/10 = 100 milliseconds between frames.
animation = FuncAnimation(fig, animate, tqdm(hs), interval=1e3/10)

In [None]:
# Encode the animation as an HTML5 video and display it inline.
HTML(animation.to_html5_video())

In [None]:
# Sample the final snapshot at the point (x, y) = (0.75, 0.5).
hs[-1].at((0.75, 0.5))

In [None]:
# Relative L1 discrepancy between the final snapshot and a reference profile
# equal to `h_in` where x <= x_c and `h_in / 2` elsewhere. The difference is
# normalised by the L1 norm of the computed field itself.
# `dx` comes from the import cell at the top of the notebook, so it is not
# re-imported here.
expr = firedrake.conditional(x[0] <= x_c, h_in, h_in / 2)
firedrake.assemble(abs(hs[-1] - expr) * dx) / firedrake.assemble(abs(hs[-1]) * dx)