# Mantle convection, again

Here I'll show a few more useful techniques, building on the mantle convection demo from [The FEniCS Book](https://fenicsproject.org/book/) and [van Keken et al (1997)](https://doi.org/10.1029/97JB01353).

The variables to be solved for are the mantle velocity $u$, pressure $p$, and temperature $T$.
The extensive quantity for heat transport is the internal energy density $G = \rho c_p T$, where $\rho$ and $c_p$ are the mass density and the specific heat at constant pressure, respectively.
When the medium moves with velocity $u$, the heat flux is $F = \rho c_p T u - k\nabla T$, where $k$ is the thermal conductivity.
The variational form for heat flow is
$$\int_\Omega\left\{\partial_t(\rho c_p T)\phi - \rho c_p Tu\cdot\nabla\phi + k\nabla T\cdot\nabla\phi - Q\phi\right\}dx = 0$$
for all test functions $\phi$, where $Q$ is a volumetric heat source. The code below takes $Q = 0$.
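For reference, here is how this weak form follows from the conservation law $\partial_t G + \nabla\cdot F = Q$ (my own write-up, assuming the boundary conditions imposed later in the notebook). Multiply by $\phi$, integrate, and integrate the divergence term by parts:

$$
\int_\Omega\left\{\partial_t G + \nabla\cdot F - Q\right\}\phi\,dx
= \int_\Omega\left\{\partial_t G\,\phi - F\cdot\nabla\phi - Q\phi\right\}dx + \int_{\partial\Omega}(F\cdot n)\,\phi\,ds = 0.
$$

Substituting $F = \rho c_p T u - k\nabla T$ gives $-F\cdot\nabla\phi = -\rho c_p T u\cdot\nabla\phi + k\nabla T\cdot\nabla\phi$, which is the volume integrand above. On the boundary, $F\cdot n = \rho c_p T\,u\cdot n - k\nabla T\cdot n$: the advective part vanishes because $u = 0$ on all of $\partial\Omega$ (no-slip), the test function vanishes on the top and bottom where $T$ is prescribed, and dropping the remaining term on the side walls imposes the natural insulating condition $k\nabla T\cdot n = 0$.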

### Initial condition

Here I'm wrapping the creation of the initial condition into a function so that we can call it repeatedly, possibly on different meshes.

In [None]:
import firedrake
from firedrake import Constant, sqrt, exp, min_value, max_value
import numpy as np
from numpy import pi as π

def clamp(z, zmin, zmax):
    return min_value(Constant(zmax), max_value(Constant(zmin), z))

def switch(z):
    return exp(z) / (exp(z) + exp(-z))

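def initial_temperature(x, Ra, ϵ):
    # Reads the domain length `Lx` from the global scope (set in the Geometry cell),
    # so that cell must run before this function is called.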
    q = Lx**(7 / 3) / (1 + Lx**4)**(2 / 3) * (Ra / (2 * np.sqrt(π)))**(2/3)
    Q = 2 * sqrt(Lx / (π * q))
    T_u = 0.5 * switch((1 - x[1]) / 2 * sqrt(q / (x[0] + ϵ)))
    T_l = 1 - 0.5 * switch(x[1] / 2 * sqrt(q / (Lx - x[0] + ϵ)))
    T_r = 0.5 + Q / (2 * np.sqrt(π)) * sqrt(q / (x[1] + 1)) * exp(-x[0]**2 * q / (4 * x[1] + 4))
    T_s = 0.5 - Q / (2 * np.sqrt(π)) * sqrt(q / (2 - x[1])) * exp(-(Lx - x[0])**2 * q / (8 - 4 * x[1]))
    return clamp(T_u + T_l + T_r + T_s - Constant(1.5), 0, 1)
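The `switch` helper equals $1/(1 + e^{-2z}) = (1 + \tanh z)/2$, but written as `exp(z) / (exp(z) + exp(-z))` the numerator can overflow. The NumPy sketch below (helper names are hypothetical, and it models plain double-precision arithmetic rather than Firedrake's generated kernels) checks the identity and estimates how large `nx` has to be before the largest argument passed to `switch`, which with `ϵ = 1 / nx` occurs at the corner $x = y = 0$, exceeds $\log(\text{DBL\_MAX}) \approx 709.78$.

```python
import numpy as np

def switch_naive(z):
    # Same formula as `switch` above: exp(z) / (exp(z) + exp(-z)).
    return np.exp(z) / (np.exp(z) + np.exp(-z))

def switch_tanh(z):
    # Dividing through by exp(z) gives 1 / (1 + exp(-2 z)) = (1 + tanh(z)) / 2,
    # and tanh saturates at ±1 instead of overflowing.
    return 0.5 * (1.0 + np.tanh(z))

def max_switch_argument(Lx, Ra, nx):
    # `q` as computed in initial_temperature, with plain floats instead of Constants.
    q = Lx**(7 / 3) / (1 + Lx**4)**(2 / 3) * (Ra / (2 * np.sqrt(np.pi)))**(2 / 3)
    # With ϵ = 1 / nx, the argument (1 - y) / 2 * sqrt(q / (x + ϵ)) peaks at x = y = 0.
    return 0.5 * np.sqrt(q * nx)

samples = np.linspace(-20.0, 20.0, 81)
print(np.allclose(switch_naive(samples), switch_tanh(samples)))  # True

with np.errstate(over="ignore", invalid="ignore"):
    print(switch_naive(1000.0))  # nan: exp(1000) overflows to inf, and inf / inf is nan
print(switch_tanh(1000.0))  # 1.0

# exp(z) overflows in double precision once z > log(DBL_MAX).
limit = np.log(np.finfo(float).max)
nx_values = np.arange(1, 2049)
overflowing = nx_values[max_switch_argument(2.0, 1e6, nx_values) > limit]
print(max_switch_argument(2.0, 1e6, 64), overflowing[0])
```

By this estimate the corner argument is about 229 on the 64 × 32 mesh used here, which is safe, but it crosses the overflow limit once `nx` goes a little past 600, which could matter if the function is reused on a fine level of a hierarchy. UFL has a `tanh` function, so `0.5 * (1 + tanh(z))` is one possible way around this, though I haven't tested it in this notebook.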

### Geometry

Now we'll do something new: create a *mesh hierarchy*, a sequence of meshes obtained by repeatedly refining a base mesh.
Rather than start with a relatively refined mesh, we'll instead start with a coarser one.

In [None]:
Lx, Ly = Constant(2.0), Constant(1.0)
ny = 32
nx = int(float(Lx / Ly)) * ny
mesh = firedrake.RectangleMesh(
    nx, ny, float(Lx), float(Ly), diagonal="crossed"
)
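Firedrake builds hierarchies with `firedrake.MeshHierarchy(mesh, refinement_levels)`. To get a feel for how fast one grows, here's a pure-Python count (hypothetical helpers, my own bookkeeping rather than a Firedrake call) of triangles and scalar degrees of freedom for the vector P2 × P1 × P1 space built in the next cell. It assumes the crossed-diagonal mesh above and uniform refinement that splits each triangle into four.

```python
def crossed_rectangle_counts(nx, ny):
    """Vertices, edges and triangles of an nx-by-ny rectangle mesh in which each
    rectangle is cut along both diagonals into four triangles."""
    vertices = (nx + 1) * (ny + 1) + nx * ny             # grid points + one center per rectangle
    edges = nx * (ny + 1) + (nx + 1) * ny + 4 * nx * ny  # horizontal + vertical + half-diagonals
    triangles = 4 * nx * ny
    return vertices, edges, triangles

def refine(vertices, edges, triangles):
    """Counts after one uniform refinement: each edge gets a midpoint and splits in
    two, and each triangle is cut into four by three new interior edges."""
    return vertices + edges, 2 * edges + 3 * triangles, 4 * triangles

def mixed_dofs(vertices, edges):
    """Scalar DOFs of vector P2 (velocity) x P1 (pressure) x P1 (temperature).
    On triangles, P2 has one node per vertex and one per edge midpoint."""
    velocity = 2 * (vertices + edges)
    pressure = temperature = vertices
    return velocity + pressure + temperature

counts = crossed_rectangle_counts(64, 32)
for level in range(4):
    V, E, C = counts
    assert V - E + C == 1  # Euler characteristic of a triangulated rectangle
    print(f"level {level}: {C:>7} triangles, {mixed_dofs(V, E):>8} DOFs")
    counts = refine(*counts)
```

By this count the base level has 8,192 triangles and 41,540 DOFs, which `Z.dim()` ought to match if the bookkeeping is right; each refinement multiplies both by roughly four.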

### Problem setup

Next we'll create the function spaces, interpolate the initial temperature, and build the variational forms.

In [None]:
pressure_space = firedrake.FunctionSpace(mesh, "CG", 1)
velocity_space = firedrake.VectorFunctionSpace(mesh, "CG", 2)
temperature_space = firedrake.FunctionSpace(mesh, "CG", 1)

Ra = Constant(1e6)
ϵ = Constant(1 / nx)
x = firedrake.SpatialCoordinate(mesh)
expr = initial_temperature(x, Ra, ϵ)
# `initial_temperature` already clamps to [0, 1], so no second clamp is needed
T_0 = firedrake.Function(temperature_space).interpolate(expr)

Z = velocity_space * pressure_space * temperature_space
z = firedrake.Function(Z)
z.sub(2).assign(T_0);

In [None]:
from firedrake import inner, sym, grad, div, dx, as_vector
from irksome import Dt

μ = Constant(1)
def ε(u):
    return sym(grad(u))

u, p, T = firedrake.split(z)
v, q, ϕ = firedrake.TestFunctions(z.function_space())

τ = 2 * μ * ε(u)
g = as_vector((0, -1))
f = -Ra * T * g
F_momentum = (inner(τ, ε(v)) - q * div(u) - p * div(v) - inner(f, v)) * dx

ρ, c, k = Constant(1), Constant(1), Constant(1)
F_energy = (ρ * c * Dt(T) * ϕ - ρ * c * T * inner(u, grad(ϕ)) + k * inner(grad(T), grad(ϕ))) * dx

F = F_momentum + F_energy
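Similarly, `F_momentum` is the weak form of the Stokes equations with a buoyancy force. Written out (my own summary, using the nondimensional values in the code):

$$
\nabla\cdot\tau - \nabla p + f = 0,\qquad \nabla\cdot u = 0,\qquad \tau = 2\mu\,\varepsilon(u),\qquad f = -\mathrm{Ra}\,T\,g.
$$

Multiplying the momentum equation by $v$ (which vanishes on $\partial\Omega$ because the velocity is prescribed everywhere there), integrating by parts, using the symmetry of $\tau$ so that $\tau:\nabla v = \tau:\varepsilon(v)$, and multiplying the continuity equation by $-q$ gives

$$
\int_\Omega\left\{\tau:\varepsilon(v) - p\,\nabla\cdot v - q\,\nabla\cdot u - f\cdot v\right\}dx = 0,
$$

which matches `inner(τ, ε(v)) - q * div(u) - p * div(v) - inner(f, v)` term by term; the minus sign on the continuity term keeps the system symmetric. Since $g = (0, -1)$, the force $f = \mathrm{Ra}\,T\,(0, 1)$ points upward, so hot material is buoyant.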

In [None]:
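# No-slip: zero velocity on every boundary
velocity_bc = firedrake.DirichletBC(Z.sub(0), as_vector((0, 0)), "on_boundary")

# RectangleMesh boundary ids: 1 is x = 0, 2 is x = Lx, 3 is y = 0, 4 is y = Ly.
# Hot bottom (T = 1) and cold top (T = 0); the side walls get the natural
# zero-flux condition.
lower_temp_bc = firedrake.DirichletBC(Z.sub(2), 1, [3])
upper_temp_bc = firedrake.DirichletBC(Z.sub(2), 0, [4])
bcs = [velocity_bc, lower_temp_bc, upper_temp_bc]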

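# The velocity is prescribed on the whole boundary, so pressure is only defined up to a constant
basis = firedrake.VectorSpaceBasis(constant=True, comm=firedrake.COMM_WORLD)
# One entry per subspace of Z (velocity, pressure, temperature). Not used below,
# since Irksome takes the nullspace as (index, basis) pairs instead.
nullspace = firedrake.MixedVectorSpaceBasis(Z, [Z.sub(0), basis, Z.sub(2)])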
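Why the constant pressure nullspace is needed: with the velocity fixed on the entire boundary, pressure enters only through its gradient, so the discrete saddle-point matrix is singular. A small NumPy sketch, using a 1D staggered finite-difference discretization as a stand-in for the finite element matrices (an illustrative assumption, not what Firedrake assembles), shows that the constant pressure vector is exactly the kernel.

```python
import numpy as np

n = 8          # pressure cells on [0, 1]
h = 1.0 / n
# Velocities live on the n - 1 interior faces; u = 0 on both end faces (no-slip).
# A: 1D viscous operator (negative second difference), symmetric positive definite.
A = (2 * np.eye(n - 1) - np.eye(n - 1, k=1) - np.eye(n - 1, k=-1)) / h**2
# B: divergence on cell i, (u[face i + 1] - u[face i]) / h; interior face f is column f - 1.
B = np.zeros((n, n - 1))
for i in range(n):
    if i + 1 <= n - 1:
        B[i, i] += 1.0 / h
    if i >= 1:
        B[i, i - 1] -= 1.0 / h

# Saddle-point matrix for (velocity, pressure): [[A, B^T], [B, 0]].
K = np.block([[A, B.T], [B, np.zeros((n, n))]])
constant_pressure = np.concatenate([np.zeros(n - 1), np.ones(n)])

print(np.allclose(K @ constant_pressure, 0.0))  # True: constant pressure is in the kernel
print(K.shape[0] - np.linalg.matrix_rank(K))     # 1: and it is the whole kernel
```

Passing the constant basis to the solver lets it work on the complement of this one-dimensional kernel rather than on a singular system.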

In [None]:
import irksome

method = irksome.BackwardEuler()

parameters = {
    "bcs": bcs,
    "nullspace": [(1, basis)],  # (subspace index, basis): pressure is component 1 of Z
    "solver_parameters": {
        "snes_type": "newtontr",
        "ksp_type": "gmres",
        "pc_type": "lu",
        "pc_factor_mat_solver_type": "mumps",
    },
}

t = Constant(0.0)
δt = Constant(1e-4)
solver = irksome.TimeStepper(F, method, t, δt, z, **parameters)
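`irksome.BackwardEuler()` is the one-stage Runge-Kutta method with Butcher tableau $A = [[1]]$, $b = [1]$, $c = [1]$. As a hedged illustration of what an implicit Runge-Kutta step means, here's a NumPy sketch for a linear ODE $y' = Ly$, with a made-up 1D diffusion matrix standing in for the mantle system. The helper `rk_step` is hypothetical; it solves the stacked stage equations for a general tableau, and with this tableau one step reduces to $(I - \delta t\,L)\,y_{\text{new}} = y_{\text{old}}$.

```python
import numpy as np

def rk_step(L, y, dt, A, b):
    """One step of an implicit Runge-Kutta method (Butcher matrix A, weights b)
    for the linear ODE y' = L y. The stage derivatives satisfy
    k_i = L (y + dt * sum_j A[i, j] k_j), i.e. (I - dt * kron(A, L)) k = [L y; ...; L y]."""
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    s, n = len(b), len(y)
    system = np.eye(s * n) - dt * np.kron(A, L)
    k = np.linalg.solve(system, np.tile(L @ y, s)).reshape(s, n)
    return y + dt * (b @ k)

# Illustrative 1D diffusion operator with y = 0 at both ends (a stand-in, not the Firedrake system).
m = 20
spacing = 1.0 / (m + 1)
L = (np.eye(m, k=1) - 2 * np.eye(m) + np.eye(m, k=-1)) / spacing**2
y0 = np.sin(np.pi * spacing * np.arange(1, m + 1))

# Backward Euler tableau: A = [[1]], b = [1] (c = [1] only enters for explicitly time-dependent terms).
dt = 1e-2
y_rk = rk_step(L, y0, dt, A=[[1.0]], b=[1.0])
y_be = np.linalg.solve(np.eye(m) - dt * L, y0)
print(np.allclose(y_rk, y_be))  # True

# Unconditionally stable: even a huge step damps this mode instead of amplifying it.
y_big = rk_step(L, y0, 1.0, A=[[1.0]], b=[1.0])
print(np.abs(y_big).max() < np.abs(y0).max())  # True
```

For a multi-stage tableau the same stacked system grows by a factor of the number of stages, which is the larger problem Irksome assembles and hands to the nonlinear solver.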

In [None]:
from tqdm.notebook import trange

final_time = 0.25
# round, not int: int() would truncate a quotient like 2499.9999... down to 2499
num_steps = round(final_time / float(δt))
zs = [z.copy(deepcopy=True)]

for step in trange(num_steps):
    solver.advance()
    zs.append(z.copy(deepcopy=True))
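Keeping a deep copy of `z` at every step is simple but memory-hungry. A rough estimate, assuming 41,540 scalar DOFs for the mixed space on the 64 × 32 crossed mesh (my own count for vector P2 velocity plus P1 pressure and temperature), 8-byte floats on a single process, and no per-object overhead; `snapshot_memory_bytes` and `save_every` are hypothetical names:

```python
def snapshot_memory_bytes(num_dofs, num_steps, save_every=1, bytes_per_value=8):
    """Rough size of a list holding the initial state plus a copy every
    `save_every` steps; ignores Python and PETSc object overhead."""
    num_snapshots = 1 + num_steps // save_every
    return num_snapshots * num_dofs * bytes_per_value

# Assumed sizes: 41,540 DOFs for Z on the 64 x 32 crossed mesh, 0.25 / 1e-4 = 2,500 steps.
every_step = snapshot_memory_bytes(41_540, 2_500)
every_tenth = snapshot_memory_bytes(41_540, 2_500, save_every=10)
print(f"{every_step / 1e9:.2f} GB vs {every_tenth / 1e9:.3f} GB")

# int() truncates toward zero, so a quotient just below an integer loses a step.
print(int(0.3 / 0.1), round(0.3 / 0.1))  # 2 3
```

That is roughly 0.83 GB for every step versus about 0.08 GB when keeping every tenth. In the loop above, thinning would amount to appending only when `(step + 1) % save_every == 0`.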