In [None]:
import firedrake

# Computational domain: the unit square, divided into a 32 x 32 grid of cells.
mesh = firedrake.UnitSquareMesh(nx=32, ny=32)

# Temperature is approximated by continuous Lagrange ('CG') polynomials of this degree.
degree = 1
temperature_space = firedrake.FunctionSpace(mesh, family='CG', degree=degree)

In [None]:
from firedrake import exp, Constant


def Θ(x):
    """Smooth step from 0 (x -> -inf) to 1 (x -> +inf), equal to 1/2 at x = 0.

    Algebraically identical to exp(x) / (exp(x) + exp(-x)): divide numerator
    and denominator by exp(x). This form evaluates a single exponential and
    never forms inf / inf when |x| is large.
    """
    return 1 / (1 + exp(-2 * x))


# Boundary temperature: a smooth transition from T_0 + δT (for x < x_0)
# down to T_0 (for x > x_0). α controls how sharp the transition is.
T_0 = Constant(0.0)
δT = Constant(1.0)
x = firedrake.SpatialCoordinate(mesh)
x_0 = Constant(0.25)
α = Constant(8.0)
# Use T_0 and δT so that changing them actually changes the boundary data.
T_Γ = T_0 + δT * Θ(α * (x_0 - x[0]))

In [None]:
from firedrake import grad, as_vector

# Prescribed steady velocity field derived from a stream function Ψ:
#     u = (-∂Ψ/∂y, ∂Ψ/∂x),
# which is divergence-free by construction. Ψ carries the factors x²(1 - x)²
# and y²(1 - y)², so Ψ and ∇Ψ vanish on the boundary of the unit square and
# u = 0 there. Ψ is a high-degree polynomial, so its CG2 interpolant is only
# approximately divergence-free.
velocity_space = firedrake.VectorFunctionSpace(mesh, 'CG', 2)
U = Constant(50)
Ψ = U * x[0]**2 * (1 - x[0])**2 * x[1]**2 * (1 - x[1])**2
grad_Ψ = grad(Ψ)

# Function.interpolate fills and returns a concrete Function in every
# Firedrake version. Newer releases make `firedrake.interpolate(expr, V)`
# return an unassembled symbolic object instead.
u = firedrake.Function(velocity_space).interpolate(
    as_vector((-grad_Ψ[1], grad_Ψ[0]))
)

In [None]:
import matplotlib.pyplot as plt

# Show the prescribed velocity field as streamlines. The fixed seed is passed
# so that streamline placement is repeatable between runs.
fig, axes = plt.subplots()
axes.set_aspect(aspect='equal')
streamlines = firedrake.streamplot(u, axes=axes, seed=1729)
fig.colorbar(mappable=streamlines);

In [None]:
from firedrake import inner, dx, ds

# Physical parameters: density ρ, specific heat c, thermal conductivity k,
# and heat-transfer coefficient h for exchange across the boundary.
ρ = Constant(1)
c = Constant(1)
k = Constant(1e-3)
# Build h from k's numeric value explicitly, rather than passing the UFL
# expression `10 * k` to Constant and relying on it being evaluated to a
# float. Either way, h is fixed at construction; a later k.assign() does
# not update it.
h = Constant(10 * float(k))

# T is the unknown at the new time level. T_n holds the previous time level.
T = firedrake.Function(temperature_space)
T_n = T.copy(deepcopy=True)

# Energy functional whose stationarity condition is one backward-Euler step
# of the diffusion part of the problem:
#   J_mass     - penalizes the change from the previous time level;
#   J_cells    - conductive energy 0.5 k |∇T|²;
#   J_boundary - Robin-type exchange with temperature T_Γ on boundary id 3,
#                and with zero temperature on ids 1, 2 and 4 (UnitSquareMesh
#                labels the edges x = 0, x = 1, y = 0, y = 1 as 1, 2, 3, 4).
J_mass = 0.5 * ρ * c * (T - T_n)**2 * dx
J_cells = 0.5 * k * inner(grad(T), grad(T)) * dx
J_boundary = 0.5 * h * (T - T_Γ)**2 * ds((3,)) + 0.5 * h * T**2 * ds((1, 2, 4))

δt = Constant(1e-1)
J = J_mass + δt * (J_cells + J_boundary)
# Residual of the time step: the first variation of J with respect to T.
F_diffusive = firedrake.derivative(J, T)

In [None]:
# Advective flux in weak form. Integrating ∇·(ρ c u T) against ϕ by parts
# gives -∫ ρ c T u·∇ϕ dx plus a boundary integral of ρ c T (u·n) ϕ. That
# boundary term is left out: the velocity built from Ψ vanishes on the boundary.
ϕ = firedrake.TestFunction(temperature_space)
F_advective = (
    -δt * ρ * c * T
    * inner(u, grad(ϕ))
    * dx
)

In [None]:
from firedrake import (
    NonlinearVariationalProblem as Problem,
    NonlinearVariationalSolver as Solver,
)

# Complete residual for one time step: diffusion and boundary exchange (from
# the energy functional) plus the advective flux. T is the unknown that each
# solve updates, so it also serves as the initial guess for the next solve.
F = F_diffusive + F_advective
problem = Problem(F=F, u=T)
solver = Solver(problem)

In [None]:
import tqdm

final_time = 1e2
# Round rather than truncate. The quotient of two floats can land just below
# an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would then
# silently drop the last time step.
num_steps = round(final_time / float(δt))

# Keep the initial state plus a snapshot every `output_freq` steps for the
# animation below.
Ts = [T.copy(deepcopy=True)]
output_freq = 10
for step in tqdm.trange(num_steps):
    solver.solve()
    # The new solution becomes the previous time level for the next step.
    T_n.assign(T)

    if (step + 1) % output_freq == 0:
        Ts.append(T.copy(deepcopy=True))

In [None]:
%%capture
fig, axes = plt.subplots()
axes.set_aspect('equal')
colors = firedrake.tripcolor(
    Ts[0], num_sample_points=4, vmin=0.0, vmax=0.05, axes=axes
)
fig.colorbar(colors);

In [None]:
from matplotlib.animation import FuncAnimation

# Samples a temperature Function on the mesh. It uses the same
# num_sample_points as the tripcolor call, so the value arrays presumably
# line up with the plotted triangles.
fn_plotter = firedrake.FunctionPlotter(mesh, num_sample_points=4)


def animate(snapshot):
    """Recolor the existing tripcolor plot with one saved temperature snapshot."""
    colors.set_array(fn_plotter(snapshot))


# Delay between frames in milliseconds. Consecutive frames are
# output_freq * δt apart in simulated time; dividing by 10 plays the
# simulation back ten times faster than real time.
interval = 1e3 * output_freq * float(δt) / 10
animation = FuncAnimation(fig, animate, frames=Ts, interval=interval)

In [None]:
from IPython.display import HTML

# Embed the animation in the cell output as a self-contained JavaScript player.
HTML(data=animation.to_jshtml())