This notebook shows that the continuous and discrete adjoint methods give the same result if no stabilisation is applied.

### TO DO:
* What's going on with Crank-Nicolson?
* Fix trapezium

In [1]:
from firedrake import *
from firedrake_adjoint import *
from firedrake.adjoint.blocks import GenericSolveBlock, ProjectBlock

In [2]:
import matplotlib.pyplot as plt
import scipy.interpolate as si

In [3]:
from adapt_utils.case_studies.tohoku.options.options import TohokuInversionOptions
from adapt_utils.misc import gaussian, ellipse

Specify 'optimum' and initial guess

In [4]:
m_opt = 5.0
m_prior = 10.0

Create `Options` object and setup gauges.

In [5]:
level = 2
op = TohokuInversionOptions(level=level)
gauges = list(op.gauges.keys())
for gauge in gauges:
    if gauge[:2] not in ('P0', '80'):
        op.gauges.pop(gauge)
gauges = list(op.gauges.keys())

  knl = loopy.make_function(kernel_domains, instructions, kargs, seq_dependencies=True,


Setup timestepping parameters

In [6]:
num_timesteps = 30
dt = Constant(op.dt)
theta = Constant(1.0)

Create function spaces

In [7]:
mesh = op.default_mesh
P2 = VectorFunctionSpace(mesh, "CG", 1)
P1 = FunctionSpace(mesh, "CG", 1)
V = P2*P1
R = FunctionSpace(mesh, "R", 0)

Create `Function`s

In [8]:
u, eta = TrialFunctions(V)
z, zeta = TestFunctions(V)
m = Function(R).assign(m_prior)
c = Control(m)
eta0 = Function(P1)
eta0.interpolate(gaussian([(0.7e+06, 4.2e+06, 48e+03, 96e+03)], mesh, rotation=pi/12))
q_ = Function(V)
u_, eta_ = q_.split();

In [9]:
# fig, axes = plt.subplots(figsize=(7, 6))
# tc = tricontourf(eta0, axes=axes, cmap='coolwarm')
# cb = plt.colorbar(tc, ax=axes)
# axes.axis(False);

Set physical parameters

In [10]:
b = Function(P1).assign(op.set_bathymetry(P1))
g = Constant(9.81)

Setup forward variational problem

In [11]:
a = inner(u, z)*dx + eta*zeta*dx + theta*dt*g*inner(grad(eta), z)*dx - theta*dt*b*inner(u, grad(zeta))*dx
L = inner(u_, z)*dx + eta_*zeta*dx - (1-theta)*dt*g*inner(grad(eta_), z)*dx + (1-theta)*dt*b*inner(u_, grad(zeta))*dx
q = Function(V)
u, eta = q.split()
bc = DirichletBC(V.sub(1), 0, 100)
problem = LinearVariationalProblem(a, L, q, bcs=bc)

Setup forward variational solver

In [12]:
sp = {
    "snes_type": "ksponly",
    "ksp_type": "gmres",
    "pc_type": "fieldsplit",
    "pc_fieldsplit_type": "multiplicative",
}
solver = LinearVariationalSolver(problem, solver_parameters=sp)

Setup QoI

In [13]:
P0 = FunctionSpace(mesh, "DG", 0)
J_form = 0
radius = 20e+03*2**level
for gauge in gauges:
    op.gauges[gauge]["data"] = []
    k = Function(P0*P0)
    ku, keta = k.split()
    keta.interpolate(ellipse([op.gauges[gauge]["coords"] + (radius,)], mesh))
    keta.assign(keta/assemble(keta*dx))
#     qd = Function(V)
#     ud, op.gauges[gauge]["obs"] = qd.split()
    op.gauges[gauge]["obs"] = Function(R)
    op.gauges[gauge]["obs_old"] = Function(R)
    J_form += 0.5*dt*keta*(eta - op.gauges[gauge]["obs"])**2*dx
    op.gauges[gauge]["indicator"] = keta

In [14]:
# fig, axes = plt.subplots(figsize=(6, 6))
# tricontourf(keta, axes=axes, cmap='coolwarm')
# axes.axis(False);

Solve forward to generate 'data'

In [15]:
with stop_annotating():
    u_.assign(0.0)
    eta_.project(m_opt*eta0)
#     for gauge in gauges:
#         op.gauges[gauge]["data"].append(float(eta_.at(op.gauges[gauge]["coords"])))
    for i in range(num_timesteps):
        print(f"t = {i*op.dt/60:5.1f} min  ||u|| = {norm(u):.4e}  ||eta|| = {norm(eta):.4e}")
        solver.solve()
        q_.assign(q)
        for gauge in gauges:
            op.gauges[gauge]["data"].append(float(eta.at(op.gauges[gauge]["coords"])))
    print(f"t = {(i+1)*op.dt/60:5.1f} min  ||u|| = {norm(u):.4e}  ||eta|| = {norm(eta):.4e}")

t =   0.0 min  ||u|| = 0.0000e+00  ||eta|| = 0.0000e+00
t =   1.0 min  ||u|| = 4.8480e+03  ||eta|| = 3.9131e+05
t =   2.0 min  ||u|| = 8.7409e+03  ||eta|| = 3.5397e+05


In [16]:
# fig, axes = plt.subplots(figsize=(7, 6))
# tc = tricontourf(eta, axes=axes, cmap='coolwarm')
# cb = plt.colorbar(tc, ax=axes)
# axes.axis(False);

In [17]:
adj_free_form = 0
for gauge in gauges:
    adj_free_form += dt*op.gauges[gauge]["indicator"]*(eta - op.gauges[gauge]["obs"])*eta*dx

Solve forward to annotate to tape

In [18]:
J = 0
adj_free = 0
# solutions = [q.copy(deepcopy=True)]
solutions = []
u_.assign(0.0)
eta_.project(m*eta0)
for i in range(num_timesteps):
    print(f"t = {i*op.dt/60:5.1f} min  ||u|| = {norm(u):.4e}  ||eta|| = {norm(eta):.4e}")
#     J += assemble(theta*J_form)
#     adj_free += assemble(theta*adj_free_form)
    solver.solve()
    q_.assign(q)
    solutions.append(q.copy(deepcopy=True))
    for gauge in gauges:
        op.gauges[gauge]["obs"].assign(op.gauges[gauge]["data"][i])
#     J += assemble((1-theta)*J_form)
#     adj_free += assemble((1-theta)*adj_free_form)
    J += assemble(J_form)
    adj_free += assemble(adj_free_form)
J10 = float(J)
gaf10 = float(adj_free)/10
print(f"t = {(i+1)*op.dt/60:5.1f} min  ||u|| = {norm(u):.4e}  ||eta|| = {norm(eta):.4e}")
stop_annotating();

t =   0.0 min  ||u|| = 8.7409e+03  ||eta|| = 3.5397e+05
t =   1.0 min  ||u|| = 9.6961e+03  ||eta|| = 7.8262e+05
t =   2.0 min  ||u|| = 1.7482e+04  ||eta|| = 7.0795e+05


In [19]:
# fig, axes = plt.subplots(figsize=(7, 6))
# tricontourf(eta, axes=axes, cmap='coolwarm')
# cb = plt.colorbar(tc, ax=axes)
# axes.axis(False);

Create `ReducedFunctional` and compute gradient using discrete adjoint

In [20]:
Jhat = ReducedFunctional(J, c)
gd10 = Jhat.derivative().dat.data[0]

Setup continuous adjoint variational problem and solver

In [21]:
phi = TestFunction(V)
cont_adj = Function(V)
cont_adj_ = Function(V)
z_, zeta_ = cont_adj_.split()
a_star = adjoint(a)
# L_star = inner(phi, cont_adj_)*dx + derivative(J_form, eta, zeta)
L_star = replace(L, {z: z_, zeta: zeta_})
L_star = replace(L_star, {u_: z, eta_: zeta})
for gauge in gauges:
    L_star += theta*dt*op.gauges[gauge]["indicator"]*(eta - op.gauges[gauge]["obs"])*zeta*dx
#     L_star += (1-theta)*dt*op.gauges[gauge]["indicator"]*(eta_ - op.gauges[gauge]["obs_old"])*zeta*dx
adj_problem = LinearVariationalProblem(a_star, L_star, cont_adj, bcs=bc)
adj_solver = LinearVariationalSolver(adj_problem, solver_parameters=sp)

Solve continuous adjoint problem

In [23]:
z, zeta = cont_adj.split()
for i in range(num_timesteps-1, -1, -1):
    print(f"t = {(i+1)*op.dt/60:5.1f} min  ||z|| = {norm(z):.4e}  ||zeta|| = {norm(zeta):.4e}")
#     discrete_adj = solve_blocks[i].adj_sol
    cont_adj_.assign(cont_adj)
    for gauge in gauges:
        op.gauges[gauge]["obs"].assign(op.gauges[gauge]["data"][i])
#         op.gauges[gauge]["obs_old"].assign(op.gauges[gauge]["data"][i-1])
    q.assign(solutions[i])
#     q_.assign(solutions[i-1])
    adj_solver.solve()
print(f"t = {0*op.dt/60:5.1f} min  ||z|| = {norm(z):.4e}  ||zeta|| = {norm(zeta):.4e}")

t =   1.0 min  ||z|| = 0.0000e+00  ||zeta|| = 0.0000e+00
t =   0.0 min  ||z|| = 4.7255e-03  ||zeta|| = 1.0057e-03
t =   0.0 min  ||z|| = 1.1506e-02  ||zeta|| = 1.5677e-03


Compute gradient using continuous adjoint method and check that it matches

In [24]:
gc10 = assemble(eta0*cont_adj.split()[1]*dx)
assert np.isclose(gc10, gd10), f"{gc10} vs. {gd10}"

Sample control space at two more points and compute the discrete adjoint gradient at 5

In [25]:
J2 = Jhat(m.assign(2.0))
J5 = Jhat(m.assign(5.0))
gd5 = Jhat.derivative().dat.data[0]

Rerun the forward solver at 5

In [26]:
J = 0
adj_free = 0
solutions = []
u_.assign(0.0)
eta_.project(5*eta0)
for i in range(num_timesteps):
    solver.solve()
    q_.assign(q)
    solutions.append(q.copy(deepcopy=True))
    for gauge in gauges:
        op.gauges[gauge]["obs"].assign(op.gauges[gauge]["data"][i])
    J += assemble(J_form)
    adj_free += assemble(adj_free_form)
assert np.isclose(float(J), J5)
gaf5 = float(adj_free)/5

Compute the continuous adjoint gradient

In [27]:
cont_adj.assign(0.0)
for i in range(num_timesteps-1, -1, -1):
    cont_adj_.assign(cont_adj)
    for gauge in gauges:
        op.gauges[gauge]["obs"].assign(op.gauges[gauge]["data"][i])
    q.assign(solutions[i])
    adj_solver.solve()
gc5 = assemble(eta0*cont_adj.split()[1]*dx)

In [28]:
L = si.lagrange([2, 5, 10], [J2, J5, J10])
dL = L.deriv()
L_min = -dL.coefficients[1]/dL.coefficients[0]
print("Minimiser of quadratic: {:.4f}".format(L_min))

Minimiser of quadratic: 3.5982


In [29]:
print(f"Exact gradient at 10              = {dL(10.0)}")
print(f"Discrete adjoint gradient at 10   = {gd10}")
print(f"Continuous adjoint gradient at 10 = {gc10}")
print(f"Adjoint-free gradient at 10       = {gaf10}")

Exact gradient at 10              = 10.464451194048001
Discrete adjoint gradient at 10   = 10.464451115366826
Continuous adjoint gradient at 10 = 10.464451119270135
Adjoint-free gradient at 10       = 10.464451079461266


In [30]:
print(f"Exact gradient at  5              = {dL(5.0)}")
print(f"Discrete adjoint gradient at 5    = {gd5}")
print(f"Continuous adjoint gradient at 5  = {gc5}")
print(f"Adjoint-free gradient at 10       = {gaf5}")

Exact gradient at  5              = 2.2913393130917274
Discrete adjoint gradient at 5    = 2.291339302240442
Continuous adjoint gradient at 5  = 2.2913398250282166
Adjoint-free gradient at 10       = 2.2913396500067327


### $m = 10$

| Control | Exact | Discrete adjoint gradient | Continuous adjoint gradient | Adjoint-free gradient |
| --- | --- | --- | --- | --- |
| $\mathcal H_0$ | 996.9466319771284 | 996.9466416950728 | 996.9466387449290 | 996.9466477902282 |
| $\mathcal H_1$ | 609.2850012285952 | 609.2849884854936 | 609.2849774146049 | 609.2849984312206 |
| $\mathcal H_2$ | ... | ... | ... | ... |

### $m=5$

| Control | Exact | Discrete adjoint gradient | Continuous adjoint gradient | Adjoint-free gradient |
| --- | --- | --- | --- | --- |
| $\mathcal H_0$ | 84.64718438015461 | 84.64719747549468 | 84.64720660517209 | 84.64720964339078 |
| $\mathcal H_1$ | 155.87189673017906 | 155.87189515509976 | 155.87189898928300 | 155.87191175728236 |
| $\mathcal H_2$ | ... | ... | ... | ... |

Plot the parameter space

In [None]:
fig, axes = plt.subplots(figsize=(8, 8))
xx = np.linspace(2, 10, 100)
axes.plot(xx, L(xx), ':', color='C0')
axes.plot(L_min, L(L_min), '*', markersize=14, color='C0', label=r"$m^\star={:.4f}$".format(L_min))
axes.set_xlabel("Control parameter")
axes.set_ylabel("Quantity of Interest")
axes.grid(True)
axes.legend();