# Tutorial

## Computation of adjoint-based gradient with checkpoint_schedules package
The aim of this tutorial is to demonstrate the application of checkpointing schedules in a simplified scenario of adjoint-based gradient computation.

### Defining the application
Consider a one-dimensional problem in which we compute the gradient (sensitivity) of an objective functional with respect to a control parameter. The objective functional, denoted $I(u)$, is given by:

\begin{equation}
    I(u) = \int_{\Omega} \frac{1}{2} \frac{u(x, \tau)u(x, \tau)}{u_0 u_0} \, d x
    \tag{1}
\end{equation}

The variable $u = u(x, t)$ is governed by the viscous Burgers equation, a nonlinear equation for the advection and diffusion of momentum in one dimension:
\begin{equation}
\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} - \nu \frac{\partial^2 u}{\partial x^2} = 0,
\tag{2}
\end{equation}
which satisfies the boundary conditions $u(0, t) = u(L, t) = 0$. The initial condition is $u(x, 0) = u_0 = \sin(\pi x)$.
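
As a minimal sketch of this setup, the snippet below samples the initial condition on a uniform grid and checks the homogeneous boundary values. The defaults for `L` and `nx` mirror the values used later in this tutorial; the helper name `initial_condition` is an assumption made for illustration.

```python
import numpy as np


def initial_condition(L=1.0, nx=500):
    """Sample u_0(x) = sin(pi x) on a uniform grid over [0, L]."""
    x = np.linspace(0.0, L, nx)
    u0 = np.sin(np.pi * x)
    return x, u0


x, u0 = initial_condition()
# For L = 1 both end values vanish up to floating-point round-off.
print(u0[0], abs(u0[-1]) < 1e-12)
```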

The goal is to compute the sensitivity of the objective functional $I(u)$ with respect to the initial condition $u_0$.

To compute the adjoint-based gradient, we need to define the adjoint system, which is a function of the forward PDE (Partial Differential Equation) and the objective functional of interest. The adjoint system can be obtained in either a continuous or discrete space. In the continuous space approach, the adjoint PDE is derived from the continuous forward PDE. On the other hand, the discrete space approach utilizes the discrete forward PDE to obtain a discrete adjoint system. Alternatively, a discrete adjoint can be obtained through algorithmic differentiation (AD).

In this tutorial, we use the continuous approach, in which the adjoint-based gradient is given by:
\begin{equation}
\frac{\partial I}{\partial u_0} \delta u_0 = \int_{\Omega}  u^{\dagger}(x, 0) \delta u_0 \, dx,
\tag{3}
\end{equation}
where $u^{\dagger}(x, 0)$ is the solution of an adjoint system that in the current case reads:
\begin{equation}
-\frac{\partial u^{\dagger}}{\partial t} + u^{\dagger} \frac{\partial u}{\partial x} - u \frac{\partial u^{\dagger}}{\partial x} - \nu \frac{\partial^2 u^{\dagger}}{\partial x^2} = 0,
\tag{4}
\end{equation}
satisfying the boundary condition $u^{\dagger} (0, t) = u^{\dagger}(L, t) = 0$. For the current adjoint-based gradient problem, the initial condition is $u^{\dagger} (x, \tau) = u(x, \tau)$. 
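
Once the adjoint solution at $t = 0$ is available on a grid, expression (3) reduces to a one-dimensional quadrature. The sketch below uses the composite trapezoidal rule. The arrays passed in are stand-ins chosen so that the integral is known analytically, not output of the solvers in this tutorial, and the function name is hypothetical.

```python
import numpy as np


def directional_derivative(adjoint_t0, delta_u0, dx):
    """Approximate the integral of u_adj(x, 0) * delta_u0(x) over the domain."""
    f = adjoint_t0 * delta_u0
    # Composite trapezoidal rule on a uniform grid.
    return dx * (0.5 * f[0] + f[1:-1].sum() + 0.5 * f[-1])


x = np.linspace(0.0, 1.0, 501)
stand_in = np.sin(np.pi * x)  # placeholder for both u_adj(x, 0) and delta_u0
dI = directional_derivative(stand_in, stand_in, x[1] - x[0])
# The exact integral of sin(pi x)**2 over [0, 1] is 0.5.
```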

It is important to note that the adjoint condition is prescribed at the final time $\tau$, so the adjoint system is solved backwards in time. Since adjoint equation (4) depends on the forward solution $u$, computing the adjoint-based gradient requires storing (or recomputing) the forward solution at every time step. Additionally, the gradient expression (3) is a function of $u^{\dagger}(x, 0)$, the adjoint solution at the end of the reverse time integration, $t = 0$.
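
The time-reversed dependence can be illustrated on a much smaller problem. The sketch below replaces the Burgers discretisation by a scalar nonlinear recurrence (an assumption made purely for illustration), stores every forward state, and then sweeps backwards through them to obtain the gradient of $\frac{1}{2} u_N^2$ with respect to $u_0$.

```python
def forward_step(u, dt):
    """One explicit step of the toy model du/dt = -u**2."""
    return u - dt * u * u


def adjoint_step(lam, u, dt):
    """Propagate the adjoint one step back; needs the forward state u."""
    return lam * (1.0 - 2.0 * dt * u)


def gradient_store_all(u0, dt, n_steps):
    # Forward sweep: keep every state, since the adjoint needs all of them.
    states = [u0]
    for _ in range(n_steps):
        states.append(forward_step(states[-1], dt))
    # Terminal condition of the adjoint for J = 0.5 * u_N**2.
    lam = states[-1]
    # Reverse sweep: consume the stored states from last to first.
    for n in reversed(range(n_steps)):
        lam = adjoint_step(lam, states[n], dt)
    return 0.5 * states[-1] ** 2, lam


J, dJ_du0 = gradient_store_all(u0=1.0, dt=0.01, n_steps=5)
```

The memory footprint of this approach grows linearly with the number of steps, which is the cost that checkpointing schedules are designed to reduce.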

### Discretisation

We discretise both the forward and adjoint systems with the Finite Element Method (FEM). The forward discretisation is explained in detail in [1], and we adopt the same FEM basis and time discretisation for the adjoint system.
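
For orientation, the linear system assembled at each time step can be read off from the matrix entries in the implementation given in this tutorial. Writing $b = \nu / \Delta x^2$ and $u_i$ for the nodal values at the current step, each interior row $i$ of the system $A(u^n)\, u^{n+1} = B(u^n)\, u^n$ has the entries below. This is a transcription of the code, not a derivation taken from [1].

\begin{equation}
\begin{aligned}
A_{i,i-1} &= \tfrac{1}{6} - \tfrac{b\,\Delta t}{2} - \tfrac{\Delta t}{4 \Delta x}\, u_{i-1}, &
B_{i,i-1} &= \tfrac{1}{6} + \tfrac{b\,\Delta t}{2} + \tfrac{\Delta t}{4 \Delta x}\, u_{i-1}, \\
A_{i,i} &= \tfrac{2}{3} + b\,\Delta t + \tfrac{\Delta t}{4 \Delta x}\,(u_{i-1} - u_i), &
B_{i,i} &= \tfrac{2}{3} - b\,\Delta t - \tfrac{\Delta t}{4 \Delta x}\,(u_{i-1} - u_i), \\
A_{i,i+1} &= \tfrac{1}{6} - \tfrac{b\,\Delta t}{2} + \tfrac{\Delta t}{4 \Delta x}\, u_i, &
B_{i,i+1} &= \tfrac{1}{6} + \tfrac{b\,\Delta t}{2} - \tfrac{\Delta t}{4 \Delta x}\, u_i.
\end{aligned}
\end{equation}

The adjoint step in the implementation assembles the same matrices from the stored forward state.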

### Implementation

The forward and adjoint solutions are stored in NumPy arrays. We use the SciPy library to construct the sparse matrices of both systems and to solve the resulting linear systems.

In [1]:
import numpy as np
import functools
import matplotlib.pyplot as plt
from sympy import Symbol, diff, exp
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import spsolve

The *checkpoint_schedules* package prescribes how one original forward computation is combined with one adjoint computation. It does so through a set of actions that define the forward and adjoint executions, so we import these actions, together with the schedule used here, to run the adjoint-based gradient computation with the H-Revolve checkpointing method [2].

In [2]:
from checkpoint_schedules import Forward, EndForward, Reverse, Copy, EndReverse
from checkpoint_schedules import RevolveCheckpointSchedule, StorageLocation

The actions have specific roles in the computation process. Let us describe their functions in more detail:

- The Forward action is responsible for advancing the forward solver by a specified number of steps. It also configures the intermediate checkpoint data storage during the forward computation.

- The Reverse action executes the adjoint solver from step n1 back to step n0. Additionally, it indicates whether the forward data used by the adjoint solver should be cleared.

- The Copy action plays an essential role in the checkpointing process. It copies data from a storage level, which can be either RAM or disk, and transfers it to the forward tape. Additionally, it indicates whether this data can be safely deleted from the storage level, freeing up memory resources.

- The EndForward and EndReverse actions indicate the finalisation of the forward and reverse solvers, respectively.
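
One way to map each action type to its handler is type-based dispatch with `functools.singledispatch`. The sketch below uses small stand-in dataclasses (`FakeForward`, `FakeReverse`, `FakeEndReverse`) instead of the real *checkpoint_schedules* classes, so their fields are assumptions chosen for illustration only.

```python
import functools
from dataclasses import dataclass


@dataclass
class FakeForward:
    n0: int
    n1: int


@dataclass
class FakeReverse:
    n1: int
    n0: int


@dataclass
class FakeEndReverse:
    pass


@functools.singledispatch
def handle(action):
    # Fallback for any type without a registered handler.
    raise TypeError(f"Unexpected action: {action!r}")


@handle.register(FakeForward)
def _(action):
    return f"advance forward solver {action.n0} -> {action.n1}"


@handle.register(FakeReverse)
def _(action):
    return f"advance adjoint solver {action.n1} -> {action.n0}"


@handle.register(FakeEndReverse)
def _(action):
    return "done"


log = [handle(a) for a in (FakeForward(0, 3), FakeReverse(3, 2), FakeEndReverse())]
```

Dispatching on the action type keeps each handler small and makes an unhandled action fail loudly instead of being silently skipped.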

We also import two additional objects from the *checkpoint_schedules* package: *RevolveCheckpointSchedule* and *StorageLocation*.

The *RevolveCheckpointSchedule* object is an iterator over the sequence of checkpoint schedule actions, so the actions are executed in the order that the schedule prescribes.

The *StorageLocation* object is responsible for specifying the locations where the checkpoint data is stored and copied during the computation. The storage locations are: RAM, DISK, TAPE, and NONE.

- RAM is the first level of storage and holds the checkpoint data that is used to restart the forward solver.

- DISK is the second level of storage.

- TAPE refers to the local storage that holds the checkpoint data used as the initial condition for the forward solver. It represents a separate storage location specifically designated for storing this particular type of checkpoint data.

- NONE indicates that there is no specific storage location defined for the checkpoint data. 
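
The interaction between the storage levels and the Copy action can be sketched with plain dictionaries. The `Level` enum and the `copy_checkpoint` helper below are hypothetical stand-ins, not part of *checkpoint_schedules*; they only mimic the behaviour described above.

```python
from enum import Enum


class Level(Enum):
    RAM = "RAM"
    DISK = "DISK"
    TAPE = "TAPE"


def copy_checkpoint(storage, step, source, destination, delete):
    """Copy the data of `step` between levels; optionally free the source."""
    storage[destination][step] = storage[source][step]
    if delete:
        del storage[source][step]


storage = {level: {} for level in Level}
storage[Level.RAM][0] = "state at step 0"

# Keep the RAM copy: it will be needed again for a later restart.
copy_checkpoint(storage, 0, Level.RAM, Level.TAPE, delete=False)
kept = 0 in storage[Level.RAM]

# Last use of this checkpoint: release the RAM slot.
copy_checkpoint(storage, 0, Level.RAM, Level.TAPE, delete=True)
freed = 0 not in storage[Level.RAM]
```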

The code below implements the discretisation of the forward and adjoint equations. The constructor of the *burger_equation* class defines the spatial and temporal configuration of the problem. The forward system is implemented in the *forward* method and the adjoint system in the *backward* method.

In [3]:
class burger_equation():
    """Define the forward Burgers' equation and its adjoint system.

    Attributes
    ----------
    L : float
        The domain length.
    nx : int
        Number of nodes.
    dt : float
        Time step.
    T : float
        Final time.
    nu : float
        The viscosity.
    """
    def __init__(self, L, nx, dt, T, nu):
        self.nt = int(T/dt)
        self.nu = nu
        self.dx = L / (nx - 1)
        self.dt = dt
        self.nx = nx

    def forward(self, u0, t0, tf):
        """Execute the forward system in time.

        Parameters
        ----------
        u0 : numpy.ndarray
            Forward initial condition.
        t0 : int
            Initial step.
        tf : int
            Final step.
        """
        dx = self.dx
        nx = self.nx
        dt = self.dt
        b = self.nu/(dx*dx)
        u = u0
        u_new = np.zeros(nx)
        steps = int(tf - t0)
        t = 0
        while t < steps:
            # Assemble the system matrices.
            A = lil_matrix((nx, nx))
            B = lil_matrix((nx, nx))
            A[0, 0] = 2/3 + b * dt 
            A[0, 1] = 1/6 - b*dt/2 + dt/4*u[1]/dx
            A[self.nx - 1, nx - 1] = 2/3 + b * dt
            A[self.nx - 1, nx - 2] = 1/6 - b * dt/2 - dt/4 * u[nx - 2]/dx
            B[0, 0] = 2/3 - b*dt
            B[0, 1] = 1/6 + b*dt/2 - dt/4*u[1]/dx
            B[nx - 1, nx - 1] = 2/3 - b * dt
            B[nx - 1, nx - 2] = 1/6 + b * dt/2 + dt/4 * u[nx - 2]/dx
            for i in range(1, nx - 1):
                v_m = u[i]/dx
                v_mm1 = u[i - 1]/dx
                A[i, i - 1] = 1/6 - b * dt/2 - dt/4*v_mm1
                A[i, i] = 2/3 + b * dt + dt/4*(v_mm1 - v_m)
                A[i, i + 1] = 1/6 - b*dt/2 + dt/4*v_m
                B[i, i - 1] = 1/6 + b*dt/2 + dt/4*v_mm1
                B[i, i] = 2/3 - b*dt - dt/4*(v_mm1 - v_m)
                B[i, i + 1] = 1/6 + b*dt/2 - dt/4*v_m
                
            d = B.dot(u)
            u_new = spsolve(A, d)
            u = u_new.copy()
            t += 1

        return u_new

    def backward(self, u_fwd, p0):
        """Execute one time step of the adjoint system.

        Parameters
        ----------
        u_fwd : numpy.ndarray
            Forward solution on which the adjoint system depends.
        p0 : numpy.ndarray
            Adjoint solution used to initialise the adjoint solver.
        """
        dx = self.dx
        nx = self.nx
        dt = self.dt
        b = self.nu/(dx*dx)
        p_new = np.zeros(nx)

        # Assemble the system matrices.
        A = lil_matrix((nx, nx))
        B = lil_matrix((nx, nx))
        A[0, 0] = 2/3 + b * dt 
        A[0, 1] = 1/6 - b*dt/2 + dt/4*u_fwd[1]/dx
        A[self.nx - 1, nx - 1] = 2/3 + b * dt
        A[self.nx - 1, nx - 2] = 1/6 - b * dt/2 - dt/4 * u_fwd[nx - 2]/dx
        B[0, 0] = 2/3 - b*dt
        B[0, 1] = 1/6 + b*dt/2 - dt/4 * u_fwd[1]/dx
        B[nx - 1, nx - 1] = 2/3 - b * dt
        B[nx - 1, nx - 2] = 1/6 + b * dt/2 + dt/4 * u_fwd[nx - 2]/dx
        for i in range(1, nx - 1):
            v_m = u_fwd[i]/dx
            v_mm1 = u_fwd[i - 1]/dx
            A[i, i - 1] = 1/6 - b * dt/2 - dt/4*v_mm1
            A[i, i] = 2/3 + b * dt + dt/4*(v_mm1 - v_m)
            A[i, i + 1] = 1/6 - b*dt/2 + dt/4*v_m
            B[i, i - 1] = 1/6 + b*dt/2 + dt/4*v_mm1
            B[i, i] = 2/3 - b*dt - dt/4*(v_mm1 - v_m)
            B[i, i + 1] = 1/6 + b*dt/2 - dt/4*v_m
        
        d = B.dot(p0)
        p_new = spsolve(A, d)
        return p_new

    def physics_der(self, u_fwd):
        # The body of this method is not given in the tutorial.
        raise NotImplementedError

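    def forward_solution(self):
        """Evaluate the closed-form expression on the grid at the last step."""
        # Domain length recovered from the grid spacing (avoids a global L).
        x = np.linspace(0, self.dx*(self.nx - 1), self.nx)
        t0 = np.exp(1/(8*self.nu))
        u = np.zeros((self.nt, self.nx))
        for step in range(self.nt):
            t = 1 + step*self.dt
            e = np.exp(x*x/(4*self.nu*t))
            a = (t/t0)**(0.5)
            u[step] = (x/t)/(1 + (a*e))

        return u[-1]

    def dudx(self):
        """Evaluate the spatial derivative of the t = 1 profile on the grid."""
        x = Symbol('x')
        t0 = exp(1/(8*self.nu))
        e = exp(x*x/(4*self.nu))
        t = 1
        a = (t/t0)**(0.5)
        u0 = (x/(t))/(1 + (a*e))
        du_dx = diff(u0, x)
        # Domain length recovered from the grid spacing (avoids a global L).
        arr = np.linspace(0, self.dx*(self.nx - 1), self.nx)
        der = np.zeros(self.nx)
        for i in range(self.nx):
            der[i] = du_dx.subs(x, arr[i])
        return der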

The *CheckpointingManager* object executes the forward and adjoint solvers by iterating over the sequence of actions supplied by the *checkpoint_schedules* package. These actions implement the chosen checkpointing strategy.

The schedule is built from the total number of forward steps and from the number of checkpoints that may be stored in RAM and on disk. These two storage budgets are chosen according to the memory constraints of the computation.
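
To make the trade-off between storage and recomputation concrete, the sketch below checkpoints a toy scalar recurrence at a fixed interval and recomputes each segment during the reverse sweep. This is plain uniform checkpointing on an assumed toy model, not the H-Revolve schedule produced by *checkpoint_schedules*; all names are illustrative.

```python
def forward_step(u, dt):
    """One explicit step of the toy model du/dt = -u**2."""
    return u - dt * u * u


def adjoint_step(lam, u, dt):
    """Adjoint of forward_step, linearised about the forward state u."""
    return lam * (1.0 - 2.0 * dt * u)


def gradient_with_checkpoints(u0, dt, n_steps, interval):
    # Forward sweep: store only every `interval`-th state.
    checkpoints = {}
    u = u0
    for n in range(n_steps):
        if n % interval == 0:
            checkpoints[n] = u
        u = forward_step(u, dt)

    lam = u  # terminal adjoint for J = 0.5 * u_N**2
    recomputed = 0
    end = n_steps
    for start in sorted(checkpoints, reverse=True):
        # Restart from the checkpoint and rebuild states start..end-1.
        segment = [checkpoints[start]]
        for _ in range(start, end - 1):
            segment.append(forward_step(segment[-1], dt))
            recomputed += 1
        # Reverse over the rebuilt segment, then discard it.
        for u_n in reversed(segment):
            lam = adjoint_step(lam, u_n, dt)
        end = start
    return lam, recomputed, len(checkpoints)


grad, extra_steps, stored = gradient_with_checkpoints(1.0, 0.01, 5, interval=2)
```

For five steps and an interval of two, three checkpoints are stored and two forward steps are recomputed.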

In [4]:
class CheckpointingManager():
    """Manage the forward and backward solvers.

    Attributes
    ----------
    max_n : int
        Total steps used to execute the solvers.
    equation : object
        The object....
    backward : object
        The backward solver.
    save_ram : int
        Number of checkpoints stored in RAM.
    save_disk : int
        Number of checkpoints stored on disk.
    """
    def __init__(self, max_n, equation, save_ram, save_disk):
        self.max_n = max_n
        self.save_ram = save_ram
        self.save_disk = save_disk
        self.equation = equation
        self.list_actions = []
        

    def execute(self):
        """Execute the forward and adjoint solvers following the schedule.
        """
        @functools.singledispatch
        def action(cp_action):
            raise TypeError("Unexpected action")

        @action.register(Forward)
        def action_forward(cp_action):
            nonlocal model_n
#             u0 = self.equation.forward(u0, cp_action.n0, cp_action.n1)
            n1 = min(cp_action.n1, self.max_n)
            model_n = n1
            if cp_action.n1 == self.max_n:
                cp_schedule.finalize(n1)

        @action.register(Reverse)
        def action_reverse(cp_action):
            nonlocal model_r, p0
#             nonlocal p0
#             if model_r == 0:
#                 p0 = u
#             p0 = self.equation(u0, p0)
            model_r += cp_action.n1 - cp_action.n0
            if cp_action.clear_adj_deps:
                data.clear()

        @action.register(Copy)
        def action_copy(cp_action):
            pass

        @action.register(EndForward)
        def action_end_forward(cp_action):
            pass

        @action.register(EndReverse)
        def action_end_reverse(cp_action):
            pass

        model_n = 0
        model_r = 0
        p0 = None  # Initialise the reverse computation.
        ics = set()
        data = set()

        cp_schedule = RevolveCheckpointSchedule(self.max_n, self.save_ram,
                                                snap_on_disk=self.save_disk)
        snapshots = {StorageLocation(0).name: {}, StorageLocation(1).name: {}}
        
        while True:
            cp_action = next(cp_schedule)
            action(cp_action)
            self.list_actions.append([str(cp_action)])
            if isinstance(cp_action, EndReverse):  
#                 assert model_r == 0
                break


Initially, consider the following spatial setup.

In [5]:
L = 1  # Domain length.
nx = 500  # Number of nodes.
nu = 0.005  # Viscosity.

The first case considers only a few time steps, to illustrate how the forward and adjoint computations work with the *checkpoint_schedules* package.

In [6]:
dt = 0.01 # Time step.
T = 0.05 # Final time.
burger = burger_equation(L, nx, dt, T, nu) # Burgers' equation object.

With the time parameters set as above, the forward and adjoint solvers are executed for only five steps.
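
The implementation computes the step count as `int(T/dt)`, which gives 5 for the values above. As a cautionary sketch, truncation can lose a step for other floating-point combinations, and rounding is one possible safeguard; the helper name `n_steps` is hypothetical.

```python
def n_steps(T, dt):
    """Number of time steps, rounded to the nearest integer."""
    return int(round(T / dt))


print(int(0.05 / 0.01), n_steps(0.05, 0.01))  # 5 5
print(int(0.3 / 0.1), n_steps(0.3, 0.1))      # 2 3
```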

Next, let us define the parameters needed to obtain a sequence of actions. We choose to store only two steps of forward data in RAM and none on disk.

In [7]:
max_n = int(T/dt) # Total steps.
equation = burger # Equation object.
save_ram = 2 # Number of steps to save in RAM.
save_disk = 0 # Number of steps to save on disk.
chk_manager = CheckpointingManager(max_n, equation, save_ram, save_disk)

In [8]:
chk_manager.execute()

To illustrate how the solvers work with the *checkpoint_schedules* iterator, we store the actions, which are printed below.

In [9]:
from tabulate import tabulate
print(tabulate(chk_manager.list_actions, headers=["checkpoint_schedules actions"]))


checkpoint_schedules actions
-----------------------------------
Forward(0, 3, True, False, 'RAM')
Forward(3, 4, True, False, 'RAM')
Forward(4, 5, False, True, 'RAM')
EndForward()
Reverse(5, 4, True)
Copy(3, 'RAM', 'TAPE', True)
Forward(3, 4, False, True, 'RAM')
Reverse(4, 3, True)
Copy(0, 'RAM', 'TAPE', False)
Forward(0, 1, False, False, 'NONE')
Forward(1, 2, True, False, 'RAM')
Forward(2, 3, False, True, 'RAM')
Reverse(3, 2, True)
Copy(1, 'RAM', 'TAPE', True)
Forward(1, 2, False, True, 'RAM')
Reverse(2, 1, True)
Copy(0, 'RAM', 'TAPE', True)
Forward(0, 1, False, True, 'RAM')
Reverse(1, 0, True)
EndReverse(True,)



Some of the actions listed above are interpreted as follows:

* *Forward(0, 3, True, False, 'RAM')*:
    - Execute the forward solver from step 0 to step 3.
    - Write the forward restart data of step 0 to RAM (*write_ics* is *True* and the storage is 'RAM').
    - Do not store the forward data for the adjoint computation (*write_adj_deps* is *False*).

* *Forward(4, 5, False, True, 'RAM')*:
    - Execute the forward solver from step 4 to step 5.
    - Do not write the forward data (*write_ics*) of step 4.
    - Store the forward data for the adjoint computation (*write_adj_deps* is *True*) of step 5 in RAM.


* *Reverse(4, 3, True)*:
    - Execute the adjoint solver from step 4 to step 3.
    - Clear the adjoint dependencies used in the adjoint computation (*clear_adj_deps* is *True*).


* *Copy(0, 'RAM', 'TAPE', False)*:
    - Copy the forward data related to step 0 from RAM to TAPE.
    - Do not delete the copied data from RAM (*delete* is *False*) since it will be used again to restart the forward solver.


* *Copy(0, 'RAM', 'TAPE', True)*:
    - Copy the forward data related to step 0 from RAM to TAPE.
    - Delete the copied data from RAM (*delete* is *True*) as it is not needed anymore.
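
The cost of the schedule printed above can be quantified by counting solver steps. The sketch below parses the action strings exactly as printed and sums the forward and reverse steps; the parsing helper is an illustrative assumption, not a feature of the package.

```python
import re

printed_actions = [
    "Forward(0, 3, True, False, 'RAM')",
    "Forward(3, 4, True, False, 'RAM')",
    "Forward(4, 5, False, True, 'RAM')",
    "EndForward()",
    "Reverse(5, 4, True)",
    "Copy(3, 'RAM', 'TAPE', True)",
    "Forward(3, 4, False, True, 'RAM')",
    "Reverse(4, 3, True)",
    "Copy(0, 'RAM', 'TAPE', False)",
    "Forward(0, 1, False, False, 'NONE')",
    "Forward(1, 2, True, False, 'RAM')",
    "Forward(2, 3, False, True, 'RAM')",
    "Reverse(3, 2, True)",
    "Copy(1, 'RAM', 'TAPE', True)",
    "Forward(1, 2, False, True, 'RAM')",
    "Reverse(2, 1, True)",
    "Copy(0, 'RAM', 'TAPE', True)",
    "Forward(0, 1, False, True, 'RAM')",
    "Reverse(1, 0, True)",
    "EndReverse(True,)",
]


def count_steps(actions):
    """Return (forward steps, reverse steps) summed over all actions."""
    forward = reverse = 0
    for text in actions:
        match = re.match(r"(Forward|Reverse)\((\d+), (\d+)", text)
        if match is None:
            continue  # Copy, EndForward and EndReverse advance no solver
        kind, a, b = match.group(1), int(match.group(2)), int(match.group(3))
        if kind == "Forward":
            forward += b - a
        else:
            reverse += a - b
    return forward, reverse


forward_steps, reverse_steps = count_steps(printed_actions)
print(forward_steps, reverse_steps)  # 11 5
```

Of the 11 forward steps, 5 belong to the initial forward sweep, so 6 are recomputations incurred by keeping at most two checkpoints in RAM.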

## References
\[1\] Dogan, Abdulkadir. "A Galerkin finite element approach to Burgers' equation." Applied Mathematics and Computation 157.2 (2004): 331-346.

\[2\] Aupy, Guillaume, and Julien Herrmann. "H-Revolve: a framework for adjoint computation on synchronous hierarchical platforms." (2019). https://hal.inria.fr/hal-02080706/document