# Computing an adjoint-based gradient with the checkpoint_schedules package

Checkpointing is a method used to reduce the peak memory usage of adjoint-based gradient computations, which are at the core of many scientific applications, such as sensitivity analyses in fluids, inverse problems, and topology optimization.

In summary, the checkpoint_schedules package prescribes how an original forward computation is combined with an adjoint computation, i.e., when checkpoint data should be stored and loaded and when the forward or adjoint solver should advance. This tutorial applies checkpointing schedules to a simplified adjoint-based gradient computation.

## Defining the application
Let us consider a one-dimensional problem in which the aim is to compute the gradient (sensitivity) of the objective functional below with respect to a control parameter:
\begin{equation}
    I(u) = \int_{\Omega} \frac{1}{2} \frac{u(x, \tau)u(x, \tau)}{u_0 u_0} \, d x
    \tag{1}
\end{equation}

The variable $u = u(x, t)$ is governed by the viscous Burgers equation, a nonlinear equation for the advection and diffusion of momentum in one dimension:
\begin{equation}
\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} - \nu \frac{\partial^2 u}{\partial x^2} = 0,
\tag{2}
\end{equation}
which satisfies the boundary conditions $u(0, t) = u(L, t) = 0$. The initial condition is $u(x, 0) = u_0 = \sin(\pi x)$.

Here, the control parameter is the initial condition $u_0$, so the goal is to compute the sensitivity of the objective functional $I(u)$ with respect to $u_0$.

Computing an adjoint-based gradient requires an adjoint system, which is derived from the forward PDE (partial differential equation) and the objective functional of interest. The adjoint system can be obtained in a continuous or a discrete setting: the former derives an adjoint PDE from the continuous forward PDE, whereas the latter derives a discrete adjoint system from the discretised forward PDE. The discrete adjoint can also be obtained by algorithmic differentiation (AD). In this tutorial, we use the continuous adjoint, for which the gradient is given by the following expression:
$$\frac{\partial I}{\partial u_0} \delta u_0 = \int_{\Omega}  u^{\dagger}(x, 0) \delta u_0 \, dx,$$
where $u^{\dagger}(x, 0)$ is the solution of an adjoint system that in the current case reads:
\begin{equation}
-\frac{\partial u^{\dagger}}{\partial t} - u \frac{\partial u^{\dagger}}{\partial x} - \nu \frac{\partial^2 u^{\dagger}}{\partial x^2} = 0,
 \tag{3}
\end{equation}
satisfying the boundary conditions $u^{\dagger}(0, t) = u^{\dagger}(L, t) = 0$. For the current adjoint-based gradient problem, the adjoint initial condition, imposed at the final time $\tau$, is $u^{\dagger}(x, \tau) = u(x, \tau)$.

Notice that the initial condition of the adjoint equation is given at the final time $\tau$, so the adjoint system is solved backward in time. Since the adjoint equation depends on the forward solution, computing the adjoint-based gradient requires the forward solution at every time step, which, without checkpointing, means storing all of them. In addition, the gradient expression depends on $u^{\dagger}(x, 0)$, the adjoint solution at the final adjoint time $t = 0$.
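To see why the forward states must be kept, consider a toy problem in place of Burgers: the scalar recurrence $u_{n+1} = u_n - \Delta t\, u_n^2$ (explicit Euler for $du/dt = -u^2$) with objective $J = \tfrac{1}{2} u_N^2$. Its discrete adjoint starts from $\lambda_N = u_N$, mirroring $u^{\dagger}(x, \tau) = u(x, \tau)$, and runs backward with $\lambda_n = \lambda_{n+1}(1 - 2\Delta t\, u_n)$, so every backward step needs the forward state $u_n$. The sketch below is a hypothetical illustration (the model, `DT` and the function names are our own choices): it stores every state and checks $\lambda_0 = dJ/du_0$ against a central finite difference.

```python
DT = 0.1  # time step of the hypothetical toy model


def step(u):
    """One explicit Euler step of the toy ODE du/dt = -u**2."""
    return u - DT * u * u


def objective(u0, n_steps):
    """J = u_N**2 / 2 after n_steps forward steps."""
    u = u0
    for _ in range(n_steps):
        u = step(u)
    return 0.5 * u * u


def adjoint_gradient(u0, n_steps):
    """dJ/du0 by a discrete adjoint that stores every forward state."""
    states = [u0]
    for _ in range(n_steps):
        states.append(step(states[-1]))  # memory grows linearly with n_steps
    lam = states[-1]  # adjoint initial condition at the final time: dJ/du_N = u_N
    for u_n in reversed(states[:-1]):
        # Backward in time: each adjoint step needs the stored forward state u_n.
        lam *= 1.0 - 2.0 * DT * u_n
    return lam


u0, n_steps, h = 0.8, 5, 1e-6
fd = (objective(u0 + h, n_steps) - objective(u0 - h, n_steps)) / (2 * h)
print(adjoint_gradient(u0, n_steps), fd)  # the two values agree closely
```

Checkpointing replaces the `states` list with a few stored states plus recomputation, which is what the schedules below prescribe.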

## Discretisation

We discretise the forward and adjoint systems with the finite element method (FEM), following the discretisation employed in [1], which describes the forward discretisation methodology in detail. The adjoint system is discretised with the same FEM basis and time discretisation as the forward system.
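With the advection terms dropped ($u = 0$), the matrix coefficients used in the implementation below reduce to $A = M + \tfrac{\Delta t}{2} K$ and $B = M - \tfrac{\Delta t}{2} K$, where, after dividing by $\Delta x$, the linear-element mass matrix $M$ has rows $(1/6, 2/3, 1/6)$ and the diffusion matrix $K$ has rows $b(-1, 2, -1)$ with $b = \nu/\Delta x^2$. Reading this as a Crank-Nicolson scheme is our interpretation of the coefficients, not a statement taken from [1]. The sketch assembles these diffusion-only matrices with `scipy.sparse.diags`; the helper name `crank_nicolson_diffusion` is hypothetical, and like the tutorial code it applies the interior stencil in the first and last rows.

```python
import numpy as np
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve


def crank_nicolson_diffusion(nx, dx, dt, nu):
    """Return (A, B) for the diffusion-only step A u^{n+1} = B u^n (hypothetical helper).

    Linear elements in space, Crank-Nicolson in time, both matrices divided by dx.
    """
    b = nu / (dx * dx)
    off = np.ones(nx - 1)
    mass = diags([off / 6, np.full(nx, 2 / 3), off / 6], [-1, 0, 1])
    stiffness = b * diags([-off, np.full(nx, 2.0), -off], [-1, 0, 1])
    A = (mass + 0.5 * dt * stiffness).tocsr()
    B = (mass - 0.5 * dt * stiffness).tocsr()
    return A, B


nx, L, dt, nu = 500, 1.0, 0.01, 0.005
dx = L / (nx - 1)
A, B = crank_nicolson_diffusion(nx, dx, dt, nu)
u = np.sin(np.pi * np.linspace(0.0, L, nx))
u_next = spsolve(A, B @ u)  # one diffusion-only time step
```

Comparing `A[i, i] = 2/3 + b*dt` and `A[i, i + 1] = 1/6 - b*dt/2` with the tutorial's assembly at `u = 0` is a quick sanity check of the coefficients.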

## Implementation

The forward and adjoint solutions are stored as NumPy arrays. The forward and adjoint sparse matrices are built with SciPy, which is also used to solve the resulting linear systems.

In [1]:
import numpy as np
import functools
import matplotlib.pyplot as plt
from sympy import Symbol, diff, exp  # Used to differentiate the analytical expression.
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import spsolve

As mentioned above, the checkpoint_schedules package prescribes how the forward computation is combined with the adjoint computation. It does so through actions, which tell the forward and adjoint solvers what to execute. We therefore import these actions from checkpoint_schedules and use them to run the current adjoint-based gradient computation with the H-Revolve checkpointing method [2].

In [6]:
from checkpoint_schedules import Forward, EndForward, Reverse, Copy, EndReverse

The role of the actions imported above is described as follows. 

- The Forward action advances the forward solver over a range of steps and configures what is stored along the way (the restart checkpoint and/or the forward data needed by the adjoint).
- The Reverse action executes the adjoint solver from step n1 back to step n0 and indicates whether the forward data used by the adjoint can then be cleared.
- The Copy action copies the data of a given step from a storage level (either RAM or disk) to the forward tape and indicates whether this data can be deleted from the storage level.
- The EndForward and EndReverse actions indicate that the forward and reverse computations, respectively, have finished.

From checkpoint_schedules we also import the RevolveCheckpointSchedule and StorageLocation objects. The former provides an iterator over the sequence of actions, whereas the latter provides the locations where forward data are stored and copied. The storage locations are RAM, DISK, TAPE and NONE. RAM and DISK are the first and second storage levels, respectively; they hold the checkpoint data used to restart the forward solver. TAPE refers to the location that holds the data used as the initial condition of the forward solver. NONE...
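The two storage levels can be pictured with a minimal sketch, assuming RAM is a Python dict and DISK is one `.npy` file per checkpoint in a directory. `TwoLevelStorage`, `store` and `copy_to_tape` are hypothetical names, not part of checkpoint_schedules; `copy_to_tape` returns the state that would become the forward solver's initial condition, mirroring a Copy action and its delete flag.

```python
import os
import tempfile

import numpy as np


class TwoLevelStorage:
    """Hypothetical checkpoint store: RAM is a dict, DISK is one .npy file per step."""

    def __init__(self, directory):
        self.ram = {}               # first level: step -> array held in memory
        self.directory = directory  # second level: files on disk

    def _path(self, n):
        return os.path.join(self.directory, f"checkpoint_{n}.npy")

    def store(self, n, u, level):
        """Write the forward state of step n to the "RAM" or "DISK" level."""
        if level == "RAM":
            self.ram[n] = np.array(u, copy=True)
        elif level == "DISK":
            np.save(self._path(n), u)
        else:
            raise ValueError(f"unknown storage level {level!r}")

    def copy_to_tape(self, n, level, delete):
        """Return the stored state of step n; optionally delete it from its level."""
        if level == "RAM":
            u = self.ram[n].copy()
            if delete:
                del self.ram[n]
        elif level == "DISK":
            u = np.load(self._path(n))
            if delete:
                os.remove(self._path(n))
        else:
            raise ValueError(f"unknown storage level {level!r}")
        return u


with tempfile.TemporaryDirectory() as tmp:
    checkpoints = TwoLevelStorage(tmp)
    u0 = np.sin(np.pi * np.linspace(0.0, 1.0, 5))
    checkpoints.store(0, u0, "DISK")  # like a Forward action with write_ics and 'DISK'
    tape = checkpoints.copy_to_tape(0, "DISK", delete=False)  # like Copy(0, 'DISK', 'TAPE', False)
    print(np.allclose(tape, u0))
```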


In [2]:

from checkpoint_schedules import RevolveCheckpointSchedule, StorageLocation

The code below implements the discretisations of the forward and adjoint equations. The burger_equation constructor defines the spatial and time configuration. The forward Burgers' system is implemented in the forward method, whereas the adjoint system is implemented in the backward method.

In [3]:
class burger_equation():
    """Forward viscous Burgers' equation and its adjoint system.

    Attributes
    ----------
    L : float
        The domain length.
    nx : int
        Number of nodes.
    dt : float
        Time step.
    T : float
        Final time.
    nu : float
        The viscosity.
    """
    def __init__(self, L, nx, dt, T, nu):
        self.L = L
        self.nt = int(T/dt)
        self.nu = nu
        self.dx = L / (nx - 1)
        self.dt = dt
        self.nx = nx

    def _assemble(self, u):
        """Assemble the time-stepping matrices A and B evaluated at the state u."""
        nx, dx, dt = self.nx, self.dx, self.dt
        b = self.nu/(dx*dx)
        A = lil_matrix((nx, nx))
        B = lil_matrix((nx, nx))
        # First and last rows.
        A[0, 0] = 2/3 + b*dt
        A[0, 1] = 1/6 - b*dt/2 + dt/4*u[1]/dx
        A[nx - 1, nx - 1] = 2/3 + b*dt
        A[nx - 1, nx - 2] = 1/6 - b*dt/2 - dt/4*u[nx - 2]/dx
        B[0, 0] = 2/3 - b*dt
        B[0, 1] = 1/6 + b*dt/2 - dt/4*u[1]/dx
        B[nx - 1, nx - 1] = 2/3 - b*dt
        B[nx - 1, nx - 2] = 1/6 + b*dt/2 + dt/4*u[nx - 2]/dx
        # Interior rows.
        for i in range(1, nx - 1):
            v_m = u[i]/dx
            v_mm1 = u[i - 1]/dx
            A[i, i - 1] = 1/6 - b*dt/2 - dt/4*v_mm1
            A[i, i] = 2/3 + b*dt + dt/4*(v_mm1 - v_m)
            A[i, i + 1] = 1/6 - b*dt/2 + dt/4*v_m
            B[i, i - 1] = 1/6 + b*dt/2 + dt/4*v_mm1
            B[i, i] = 2/3 - b*dt - dt/4*(v_mm1 - v_m)
            B[i, i + 1] = 1/6 + b*dt/2 - dt/4*v_m
        # CSR format is efficient for spsolve and matrix-vector products.
        return A.tocsr(), B.tocsr()

    def forward(self, u0, t0, tf):
        """Advance the forward system from step t0 to step tf.

        Parameters
        ----------
        u0 : numpy.ndarray
            Forward initial condition.
        t0 : int
            Initial step.
        tf : int
            Final step.

        Returns
        -------
        numpy.ndarray
            Forward solution at step tf.
        """
        u = u0
        for _ in range(int(tf - t0)):
            # The matrices depend on the current state, so reassemble every step.
            A, B = self._assemble(u)
            u = spsolve(A, B.dot(u))
        return u

    def backward(self, u_fwd, p0):
        """Advance the adjoint system by one time step (backward in time).

        Parameters
        ----------
        u_fwd : numpy.ndarray
            Forward solution that the adjoint step depends on.
        p0 : numpy.ndarray
            Adjoint solution used to initialise the adjoint step.
        """
        # Same coefficients as the forward step, evaluated at the stored forward state.
        A, B = self._assemble(u_fwd)
        return spsolve(A, B.dot(p0))

    def forward_solution(self):
        """Evaluate the analytical expression on the grid at t = 1 + (nt - 1)*dt."""
        u = np.zeros((self.nt, self.nx))
        x = np.linspace(0, self.L, self.nx)
        t0 = np.exp(1/(8*self.nu))
        for step in range(self.nt):
            t = 1 + step*self.dt
            e = np.exp(x*x/(4*self.nu*t))
            a = (t/t0)**0.5
            u[step] = (x/t)/(1 + a*e)
        return u[-1]

    def dudx(self):
        """Nodal values of the x-derivative of the analytical expression at t = 1."""
        x = Symbol('x')
        t0 = exp(1/(8*self.nu))
        e = exp(x*x/(4*self.nu))
        t = 1
        a = (t/t0)**0.5
        u0 = (x/t)/(1 + a*e)
        du_dx = diff(u0, x)
        arr = np.linspace(0, self.L, self.nx)
        der = np.zeros(self.nx)
        for i in range(self.nx):
            der[i] = float(du_dx.subs(x, arr[i]))
        return der



The CheckpointingManager class iterates over the sequence of actions provided by the checkpoint_schedules package and dispatches each action to a handler; in this tutorial the handlers only track the step counters, with the solver calls left commented out. The schedule is built from the total number of forward steps, the number of checkpoints that may be stored in RAM, and the number that may be stored on disk.

In [4]:
class CheckpointingManager():
    """Manage the forward and backward solvers.

    Attributes
    ----------
    max_n : int
        Total steps used to execute the solvers.
    equation : object
        The object....
    backward : object
        The backward solver.
    save_ram : int
        Number of checkpoints that may be stored in RAM.
    save_disk : int
        Number of checkpoints that may be stored on disk.
    """
    def __init__(self, max_n, equation, save_ram, save_disk):
        self.max_n = max_n
        self.save_ram = save_ram
        self.save_disk = save_disk
        self.equation = equation
        self.list_actions = []
        

    def execute(self):
        """Execute the forward and adjoint solvers following the checkpointing schedule.
        """
        @functools.singledispatch
        def action(cp_action):
            raise TypeError("Unexpected action")

        @action.register(Forward)
        def action_forward(cp_action):
            nonlocal model_n
#             u0 = self.equation.forward(u0, cp_action.n0, cp_action.n1)
            n1 = min(cp_action.n1, self.max_n)
            model_n = n1
            if cp_action.n1 == self.max_n:
                cp_schedule.finalize(n1)

        @action.register(Reverse)
        def action_reverse(cp_action):
            nonlocal model_r, p0
#             nonlocal p0
#             if model_r == 0:
#                 p0 = u
#             p0 = self.equation(u0, p0)
            model_r += cp_action.n1 - cp_action.n0
            if cp_action.clear_adj_deps:
                data.clear()

        @action.register(Copy)
        def action_copy(cp_action):
            pass

        @action.register(EndForward)
        def action_end_forward(cp_action):
            pass

        @action.register(EndReverse)
        def action_end_reverse(cp_action):
            pass

        model_n = 0
        model_r = 0
        p0 = None  # Initialiase the reverse computation.
        ics = set()
        data = set()

        snapshots = {StorageLocation(0).name: {}, StorageLocation(1).name: {}}
        cp_schedule = RevolveCheckpointSchedule(self.max_n, self.save_ram,
                                                snap_on_disk=self.save_disk)
        
        while True:
            cp_action = next(cp_schedule)
            action(cp_action)
            self.list_actions.append([str(cp_action)])
            if isinstance(cp_action, EndReverse):  
                # Every step has been reversed exactly once.
                assert model_r == self.max_n
                break
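The `execute` method routes each action to a handler with `functools.singledispatch`, which selects the function registered for the argument's type. The sketch below isolates that pattern using hypothetical stand-in dataclasses (`Fwd`, `Rev`, `End` are not the library's classes) and counts forward and reverse steps over the Forward and Reverse spans of the first schedule printed further below (Copy actions omitted).

```python
import functools
from dataclasses import dataclass


# Hypothetical stand-ins for the schedule's action types (not the library classes).
@dataclass
class Fwd:
    n0: int
    n1: int


@dataclass
class Rev:
    n1: int  # the adjoint starts at step n1 ...
    n0: int  # ... and runs back to step n0


@dataclass
class End:
    pass


@functools.singledispatch
def handle(action, counts):
    raise TypeError(f"Unexpected action: {action!r}")


@handle.register
def _(action: Fwd, counts):
    counts["forward"] += action.n1 - action.n0


@handle.register
def _(action: Rev, counts):
    counts["reverse"] += action.n1 - action.n0


@handle.register
def _(action: End, counts):
    pass


def run(schedule):
    counts = {"forward": 0, "reverse": 0}
    for action in schedule:
        handle(action, counts)  # dispatch on type(action)
        if isinstance(action, End):
            break
    return counts


SCHEDULE = [Fwd(0, 3), Fwd(3, 4), Fwd(4, 5), Rev(5, 4), Fwd(3, 4), Rev(4, 3),
            Fwd(0, 1), Fwd(1, 2), Fwd(2, 3), Rev(3, 2), Fwd(1, 2), Rev(2, 1),
            Fwd(0, 1), Rev(1, 0), End()]
print(run(SCHEDULE))  # -> {'forward': 11, 'reverse': 5}
```

The 11 forward steps, against 5 when every state is stored, are the recomputation cost paid for holding at most two checkpoints in RAM.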


Initially, consider the following spatial setup.

In [5]:
L = 1  # Domain length.
nx = 500 # Number of nodes.
nu = 0.005 # Viscosity

The first case considers only a few time steps, to illustrate how the forward and adjoint computations work with the checkpoint_schedules package.

In [6]:
dt = 0.01  # Time step.
T = 0.05  # Final time.
burger = burger_equation(L, nx, dt, T, nu)  # Burgers' equation object.

With the time parameters above, the forward and adjoint solvers run for only five steps.
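Both the constructor (`nt`) and `max_n` in the next cell compute the number of steps as `int(T/dt)`. Truncation gives 5 for `T = 0.05` and `dt = 0.01`, but for other choices the floating-point quotient can land just below an integer and silently drop a step. A hypothetical helper that rounds first avoids this:

```python
def n_steps(T, dt):
    """Number of time steps for final time T and step dt (hypothetical helper)."""
    # Round before converting: T/dt may be slightly below the intended integer.
    return int(round(T / dt))


print(0.3 / 0.1)                          # 2.9999999999999996
print(int(0.3 / 0.1), n_steps(0.3, 0.1))  # 2 3
print(n_steps(0.05, 0.01))                # 5
```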

Next, let us define the parameters needed to obtain a sequence of actions. Here, at most two forward checkpoints are stored in RAM and none on disk.

In [7]:
max_n = int(T/dt) # Total steps.
equation = burger # Equation object.
save_ram = 2  # Maximum number of checkpoints stored in RAM.
save_disk = 0  # Maximum number of checkpoints stored on disk.
chk_manager = CheckpointingManager(max_n, equation, save_ram, save_disk)

In [8]:
chk_manager.execute()

To illustrate how the solvers follow the checkpoint_schedules iterator, we stored the sequence of actions, which is printed below.

In [9]:
from tabulate import tabulate
print(tabulate(chk_manager.list_actions, headers=["checkpoint_schedules actions"]))


checkpoint_schedules actions
-----------------------------------
Forward(0, 3, True, False, 'RAM')
Forward(3, 4, True, False, 'RAM')
Forward(4, 5, False, True, 'RAM')
EndForward()
Reverse(5, 4, True)
Copy(3, 'RAM', 'TAPE', True)
Forward(3, 4, False, True, 'RAM')
Reverse(4, 3, True)
Copy(0, 'RAM', 'TAPE', False)
Forward(0, 1, False, False, 'NONE')
Forward(1, 2, True, False, 'RAM')
Forward(2, 3, False, True, 'RAM')
Reverse(3, 2, True)
Copy(1, 'RAM', 'TAPE', True)
Forward(1, 2, False, True, 'RAM')
Reverse(2, 1, True)
Copy(0, 'RAM', 'TAPE', True)
Forward(0, 1, False, True, 'RAM')
Reverse(1, 0, True)
EndReverse(True,)


Recall that the checkpoint_schedules actions are Forward, Reverse, Copy, EndForward and EndReverse, printed as Forward(n0, n1, write_ics, write_adj_deps, storage), Reverse(n1, n0, clear_adj_deps), Copy(n, from_storage, to_storage, delete), EndForward() and EndReverse().
 
For example, the action
Forward(0, 3, True, False, 'RAM'): execute the forward solver from step 0 to step 3, write the forward data of step 0 (write_ics) to RAM (storage), and do not store forward data for the adjoint computation.

Another case with the Forward action:
Forward(4, 5, False, True, 'RAM'): execute the forward solver from step 4 to step 5, do not write the forward data of step 4 (write_ics), and store in RAM the forward data required by the adjoint computation of this step (write_adj_deps).

Reverse(4, 3, True): execute the adjoint solver from step 4 back to step 3, and clear the forward data (adj_deps) used in this adjoint computation.

Copy(0, 'RAM', 'TAPE', False): copy the forward data of step 0 from the storage level RAM to the TAPE, and do not delete it from RAM (since it will be used again to restart the forward solver).

Copy(0, 'RAM', 'TAPE', True): copy the forward data of step 0 from the storage level RAM to the TAPE, and delete it from RAM (since it is not needed again).
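These semantics can be exercised without the Burgers solver. The hypothetical sketch below replays the first printed schedule on the toy recurrence $u_{n+1} = u_n - \Delta t\, u_n^2$ with $J = \tfrac{1}{2} u_5^2$. Our assumptions: actions are plain tuples copied from the printed output (not checkpoint_schedules objects); Reverse tuples are read in printed order as (n1, n0, clear_adj_deps); write_adj_deps keeps the state at the start of each step; and Copy places a stored state on the tape, i.e., makes it the current forward state. If the schedule is consistent, the replayed gradient equals the one obtained by storing every state.

```python
DT = 0.1  # toy model time step (hypothetical)


def step(u):
    """Toy forward step: explicit Euler for du/dt = -u**2."""
    return u - DT * u * u


def adjoint_step(lam, u_n):
    """Toy adjoint step: chain rule through `step` at the stored state u_n."""
    return lam * (1.0 - 2.0 * DT * u_n)


# First printed schedule (max_n = 5, save_ram = 2, save_disk = 0) as plain tuples.
SCHEDULE = [
    ("Forward", 0, 3, True, False, "RAM"),
    ("Forward", 3, 4, True, False, "RAM"),
    ("Forward", 4, 5, False, True, "RAM"),
    ("EndForward",),
    ("Reverse", 5, 4, True),
    ("Copy", 3, "RAM", "TAPE", True),
    ("Forward", 3, 4, False, True, "RAM"),
    ("Reverse", 4, 3, True),
    ("Copy", 0, "RAM", "TAPE", False),
    ("Forward", 0, 1, False, False, "NONE"),
    ("Forward", 1, 2, True, False, "RAM"),
    ("Forward", 2, 3, False, True, "RAM"),
    ("Reverse", 3, 2, True),
    ("Copy", 1, "RAM", "TAPE", True),
    ("Forward", 1, 2, False, True, "RAM"),
    ("Reverse", 2, 1, True),
    ("Copy", 0, "RAM", "TAPE", True),
    ("Forward", 0, 1, False, True, "RAM"),
    ("Reverse", 1, 0, True),
    ("EndReverse", True),
]


def replay(schedule, u0):
    """Run the toy forward/adjoint following `schedule`; return dJ/du0."""
    storage = {"RAM": {}, "DISK": {}}  # restart checkpoints: step -> state
    adj_deps = {}                      # states kept for the adjoint: step -> state
    n, u = 0, u0                       # the tape: current step and forward state
    lam = None
    for kind, *args in schedule:
        if kind == "Forward":
            n0, n1, write_ics, write_adj_deps, where = args
            assert n == n0, "a Forward action must start from the tape state"
            if write_ics:
                storage[where][n0] = u
            for k in range(n0, n1):
                if write_adj_deps:
                    adj_deps[k] = u  # state at the start of step k
                u = step(u)
            n = n1
        elif kind == "EndForward":
            lam = u  # like u†(x, τ) = u(x, τ): dJ/du_N = u_N for J = u_N**2 / 2
        elif kind == "Reverse":
            n1, n0, clear_adj_deps = args
            for k in range(n1 - 1, n0 - 1, -1):
                lam = adjoint_step(lam, adj_deps[k])
            if clear_adj_deps:
                adj_deps.clear()
        elif kind == "Copy":
            k, from_storage, _to_storage, delete = args
            n, u = k, storage[from_storage][k]
            if delete:
                del storage[from_storage][k]
    return lam


def store_all_gradient(u0, n_steps):
    """Reference: keep every forward state, then sweep the adjoint backward."""
    states = [u0]
    for _ in range(n_steps):
        states.append(step(states[-1]))
    lam = states[-1]
    for u_n in reversed(states[:-1]):
        lam = adjoint_step(lam, u_n)
    return lam


print(replay(SCHEDULE, 0.8), store_all_gradient(0.8, 5))  # the two values agree
```

The assertion inside `replay` checks that every Forward action starts from the state currently on the tape, which is exactly what the Copy actions guarantee.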


In [10]:
save_ram = 1  # Maximum number of checkpoints stored in RAM.
save_disk = 1  # Maximum number of checkpoints stored on disk.
chk_manager = CheckpointingManager(max_n, equation, save_ram, save_disk)
chk_manager.execute()

In [11]:
from tabulate import tabulate
print(tabulate(chk_manager.list_actions, headers=["checkpoint_schedules actions"]))


checkpoint_schedules actions
-----------------------------------
Forward(0, 2, True, False, 'DISK')
Forward(2, 4, True, False, 'RAM')
Forward(4, 5, False, True, 'RAM')
EndForward()
Reverse(5, 4, True)
Copy(2, 'RAM', 'TAPE', False)
Forward(2, 3, False, False, 'NONE')
Forward(3, 4, False, True, 'RAM')
Reverse(4, 3, True)
Copy(2, 'RAM', 'TAPE', True)
Forward(2, 3, False, True, 'RAM')
Reverse(3, 2, True)
Copy(0, 'DISK', 'TAPE', False)
Forward(0, 1, True, False, 'RAM')
Forward(1, 2, False, True, 'RAM')
Reverse(2, 1, True)
Copy(0, 'RAM', 'TAPE', True)
Forward(0, 1, False, True, 'RAM')
Reverse(1, 0, True)
EndReverse(True,)


Forward(0, 2, True, False, 'DISK'): execute the forward solver from step 0 to step 2, write the forward data of step 0 (write_ics) to DISK (storage), and do not store forward data for the adjoint computation.

Copy(0, 'DISK', 'TAPE', False): copy the forward data of step 0 from the storage level DISK to the TAPE, and do not delete it from DISK.

Forward(0, 1, True, False, 'RAM'): execute the forward solver from step 0 to step 1, write the forward data of step 0 (write_ics) to RAM (storage), and do not store forward data for the adjoint computation.

Copy(0, 'RAM', 'TAPE', True): copy the forward data of step 0 from the storage level RAM to the TAPE, and delete it from RAM.
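The checkpoint budget can be checked directly against the printed actions. The hypothetical sketch below parses action strings with `ast.literal_eval` and tracks which restart checkpoints (written when write_ics is True, removed by a Copy whose delete flag is True) are held at each storage level; forward data kept for the adjoint (write_adj_deps) are not counted. For the second schedule, the peaks match `save_ram = 1` and `save_disk = 1`.

```python
import ast


def parse_action(text):
    """Split a printed action, e.g. "Copy(0, 'RAM', 'TAPE', True)", into (name, args)."""
    name, _, rest = text.partition("(")
    args = ast.literal_eval("(" + rest)
    if not isinstance(args, tuple):  # a single argument without a trailing comma
        args = (args,)
    return name, args


def peak_checkpoints(lines):
    """Peak number of restart checkpoints held at each storage level."""
    held = {"RAM": set(), "DISK": set()}
    peak = {"RAM": 0, "DISK": 0}
    for text in lines:
        name, args = parse_action(text.strip())
        if name == "Forward":
            n0, _n1, write_ics, _write_adj_deps, storage = args
            if write_ics and storage in held:
                held[storage].add(n0)
        elif name == "Copy":
            n, from_storage, _to_storage, delete = args
            if delete and from_storage in held:
                held[from_storage].discard(n)
        for level in held:
            peak[level] = max(peak[level], len(held[level]))
    return peak


SECOND_SCHEDULE = """\
Forward(0, 2, True, False, 'DISK')
Forward(2, 4, True, False, 'RAM')
Forward(4, 5, False, True, 'RAM')
EndForward()
Reverse(5, 4, True)
Copy(2, 'RAM', 'TAPE', False)
Forward(2, 3, False, False, 'NONE')
Forward(3, 4, False, True, 'RAM')
Reverse(4, 3, True)
Copy(2, 'RAM', 'TAPE', True)
Forward(2, 3, False, True, 'RAM')
Reverse(3, 2, True)
Copy(0, 'DISK', 'TAPE', False)
Forward(0, 1, True, False, 'RAM')
Forward(1, 2, False, True, 'RAM')
Reverse(2, 1, True)
Copy(0, 'RAM', 'TAPE', True)
Forward(0, 1, False, True, 'RAM')
Reverse(1, 0, True)
EndReverse(True,)"""

print(peak_checkpoints(SECOND_SCHEDULE.splitlines()))  # -> {'RAM': 1, 'DISK': 1}
```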

## References
\[1\] Dogan, Abdulkadir. "A Galerkin finite element approach to Burgers' equation." Applied mathematics and computation 157.2 (2004): 331-346.

\[2\] Aupy, Guillaume, and Julien Herrmann. "H-Revolve: a framework for adjoint computation on synchrone hierarchical platforms." Working paper or preprint, 2019. https://hal.inria.fr/hal-02080706/document