# Tutorial

## *checkpoint_schedules* application: adjoint-based gradient
This tutorial describes an adjoint-based gradient computation that uses the checkpointing schedules provided by the *checkpoint_schedules* package. We first define the adjoint-based gradient, and then the forward and adjoint solvers driven by the *checkpoint_schedules* package.

### Defining the application
Let us consider a one-dimensional (1D) problem whose goal is to compute the gradient (sensitivity) of an objective functional $I$ with respect to a control parameter. The objective functional is given by the expression:
$$
I(u) = \int_{\Omega} \frac{1}{2} u(x, \tau)u(x, \tau) \, d x
\tag{1}
$$
which measures the energy of the 1D velocity $u = u(x, t)$ at the final time $t = \tau$. The velocity is governed by the 1D viscous Burgers equation, a nonlinear equation for the advection and diffusion of momentum:
$$
\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} - \nu \frac{\partial^2 u}{\partial x^2} = 0,
\tag{2}
$$
where $x \in \Omega = [0, L]$ is the space variable, $t \in [0, \tau]$ is the time variable, and $\nu$ is the viscosity. The boundary conditions are $u(0, t) = u(L, t) = 0$, where $L$ is the length of the 1D domain. The initial condition is $u(x, 0) = u_0 = \sin(\pi x)$, which satisfies the boundary conditions for $L = 1$.
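As a quick check of the setup, the sketch below evaluates the initial condition on a uniform grid and approximates the objective functional (1) with the trapezoidal rule. The choice $L = 1$ matches the `lx` value used later in this tutorial; the 101-node grid and the helper name `objective` are illustrative assumptions. For $u = \sin(\pi x)$ on $[0, 1]$ the exact value is $1/4$.

```python
import numpy as np
from scipy.integrate import trapezoid

L = 1.0                        # domain length (the tutorial uses lx = 1)
x = np.linspace(0.0, L, 101)   # 101 nodes: an assumed resolution
u0 = np.sin(np.pi * x)         # initial condition u(x, 0) = sin(pi x)

# For L = 1 the boundary conditions u(0) = u(L) = 0 hold up to round-off.
assert abs(u0[0]) < 1e-12 and abs(u0[-1]) < 1e-12


def objective(u, x):
    """Trapezoidal approximation of I(u) = int_0^L 0.5 * u**2 dx (hypothetical helper)."""
    return trapezoid(0.5 * u * u, x=x)


print(objective(u0, x))  # approximately 0.25, the exact value for sin(pi x)
```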

The goal is to compute the adjoint-based gradient of the objective functional $I(u)$ with respect to the initial condition $u_0$. Hence, we need to define the adjoint system, which depends on the forward PDE (partial differential equation) and on the objective functional of interest. The adjoint system can be derived in either a continuous or a discrete setting. In the continuous approach, the adjoint PDE is derived from the continuous forward PDE. In the discrete approach, the discretised forward PDE is used to obtain a discrete adjoint system. Alternatively, a discrete adjoint can be obtained through algorithmic differentiation (AD).
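To make the discrete approach concrete, here is a minimal sketch on a scalar toy problem rather than the Burgers solver. It assumes the ODE $du/dt = -u^2$, explicit Euler time stepping, and the objective $I = \frac{1}{2} u_N^2$. All function names are hypothetical. The adjoint is obtained by differentiating the discrete scheme step by step, and a central finite difference checks the result.

```python
def forward(u0, dt, n_steps):
    """Explicit Euler for the toy ODE du/dt = -u**2; keeps every state."""
    states = [u0]
    for _ in range(n_steps):
        u = states[-1]
        states.append(u - dt * u * u)
    return states


def objective(u0, dt, n_steps):
    """Toy objective I = 0.5 * u_N**2 at the final step."""
    return 0.5 * forward(u0, dt, n_steps)[-1] ** 2


def discrete_adjoint_gradient(u0, dt, n_steps):
    """dI/du0 obtained by differentiating the discrete scheme itself."""
    states = forward(u0, dt, n_steps)
    lam = states[-1]  # adjoint "initial" value dI/du_N = u_N
    for n in range(n_steps - 1, -1, -1):
        # Multiply by the step derivative du_{n+1}/du_n = 1 - 2 dt u_n,
        # which needs the stored forward state u_n.
        lam = lam * (1.0 - 2.0 * dt * states[n])
    return lam


u0, dt, n_steps = 1.0, 0.01, 10
g_adj = discrete_adjoint_gradient(u0, dt, n_steps)
eps = 1e-6
g_fd = (objective(u0 + eps, dt, n_steps) - objective(u0 - eps, dt, n_steps)) / (2 * eps)
print(g_adj, g_fd)  # the two estimates agree closely
```

The reverse loop runs from the last step to the first and reads one stored forward state per step. This is the same dependency that makes storage, or checkpointing, necessary for the Burgers adjoint.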

In this tutorial, the adjoint-based gradient is defined in the continuous space. Thus, the adjoint-based gradient is given by the expression:
$$
\frac{\partial I}{\partial u_0} \delta u_0 = \int_{\Omega}  u^{\dagger}(x, 0) \delta u_0 \, dx,
\tag{3}
$$
where $u^{\dagger} = u^{\dagger}(x, t)$ is the adjoint variable, evaluated at $t = 0$ in (3), and governed by the adjoint system:
$$
-\frac{\partial u^{\dagger}}{\partial t} + u^{\dagger} \frac{\partial u}{\partial x} - u \frac{\partial u^{\dagger}}{\partial x} - \nu \frac{\partial^2 u^{\dagger}}{\partial x^2} = 0,
\tag{4}
$$
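satisfying the boundary conditions $u^{\dagger} (0, t) = u^{\dagger}(L, t) = 0$. Because the adjoint system is solved backward in time, its initial condition is imposed at the final time: $u^{\dagger} (x, \tau) = u(x, \tau)$, which is the derivative of the integrand of (1) with respect to $u(x, \tau)$.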

The adjoint system is solved backward in time, from $t = \tau$ to $t = 0$. Computing the adjoint-based gradient therefore requires the forward solution at every time step, because the adjoint equation (4) depends on the forward solution $u$. In addition, the gradient expression (3) depends on $u^{\dagger} (x, 0)$, which is the adjoint solution at the end of the backward sweep ($t = 0$).
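The storage requirement can be illustrated with two extreme strategies on the same kind of scalar toy problem. The ODE $du/dt = -u^2$ with explicit Euler steps is an assumption, not the Burgers solver. The first strategy stores every forward state. The second stores only $u_0$ and recomputes each $u_n$ from it during the reverse sweep. Both give the same gradient, but they trade memory for recomputation. Checkpointing schedules, such as those produced by *checkpoint_schedules*, sit between these extremes.

```python
def step(u, dt):
    """One explicit Euler step of the toy ODE du/dt = -u**2."""
    return u - dt * u * u


def grad_store_all(u0, dt, n_steps):
    """Store every forward state; the reverse sweep reads them back."""
    states = [u0]
    for _ in range(n_steps):
        states.append(step(states[-1], dt))
    lam = states[-1]                      # dI/du_N for I = 0.5 * u_N**2
    for n in range(n_steps - 1, -1, -1):
        lam *= 1.0 - 2.0 * dt * states[n]
    return lam, len(states), n_steps      # gradient, stored states, forward steps


def grad_recompute(u0, dt, n_steps):
    """Store only u0; recompute u_n from it at every reverse step."""
    u, fwd_steps = u0, 0
    for _ in range(n_steps):
        u = step(u, dt)
        fwd_steps += 1
    lam = u
    for n in range(n_steps - 1, -1, -1):
        u_n = u0
        for _ in range(n):
            u_n = step(u_n, dt)
            fwd_steps += 1
        lam *= 1.0 - 2.0 * dt * u_n
    return lam, 1, fwd_steps


print(grad_store_all(1.0, 0.01, 10)[1:])   # (11, 10)
print(grad_recompute(1.0, 0.01, 10)[1:])   # (1, 55)
```

With $N$ steps, storing everything keeps $N + 1$ states and performs $N$ forward steps. Recomputing from the start keeps a single state but performs $N + N(N-1)/2$ forward steps.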

#### Discretisation
Both the forward and adjoint systems are discretised using the Finite Element Method (FEM), employing a discretisation methodology detailed in [1]. This methodology uses the Galerkin method with linear trial basis functions to obtain an approximate solution. The backward finite difference method is employed to discretise the equations in time.
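The sketch below shows the linear finite element building blocks on a uniform grid: the mass matrix $M = \frac{h}{6}\,\mathrm{tridiag}(1, 4, 1)$ and the stiffness matrix $K = \frac{1}{h}\,\mathrm{tridiag}(-1, 2, -1)$ at the interior nodes, combined with one backward (implicit) Euler step. To keep it short, it assumes the linear heat equation $u_t = \nu u_{xx}$ instead of the nonlinear Burgers equation, so it illustrates the discretisation pattern only. The `Burger` class below appears to use the same $(1/6, 2/3, 1/6)$ mass pattern, scaled by $1/\Delta x$. The function names are hypothetical.

```python
import numpy as np
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve


def p1_matrices(n_nodes, length):
    """Mass (M) and stiffness (K) matrices of linear elements on a uniform
    grid, restricted to interior nodes (homogeneous Dirichlet conditions)."""
    h = length / (n_nodes - 1)
    m = n_nodes - 2                      # number of interior nodes
    off = np.ones(m - 1)
    M = diags([h / 6 * off, 2 * h / 3 * np.ones(m), h / 6 * off], [-1, 0, 1], format="csr")
    K = diags([-off / h, 2 / h * np.ones(m), -off / h], [-1, 0, 1], format="csr")
    return M, K


def backward_euler_diffusion(u, M, K, nu, dt):
    """One implicit step of u_t = nu * u_xx: (M + dt*nu*K) u_new = M u."""
    return spsolve((M + dt * nu * K).tocsc(), M @ u)


x = np.linspace(0.0, 1.0, 101)
u_int = np.sin(np.pi * x)[1:-1]          # initial condition at interior nodes
M, K = p1_matrices(101, 1.0)
nu, dt = 0.01, 0.01
u1 = backward_euler_diffusion(u_int, M, K, nu, dt)
# sin(pi x) decays approximately like exp(-nu * pi**2 * dt) over one step.
print(u1.max() / u_int.max())
```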

#### Coding
The `Burger` class implements the functionality for solving Burgers' equation (2) and its corresponding adjoint equation (4), and it inherits the schedule-driven execution loop from the `CheckpointingManager` class. The `Burger` constructor takes a `model` dictionary that defines the spatial and temporal configuration of the problem. The `copy_fwd_data` method copies forward data from RAM or disk into working memory, where it is used as the initial condition when the forward solver restarts. The methods `store_in_ram` and `store_on_disk` store the forward data for restarting purposes. The `forward` method keeps the forward data required by the adjoint computation (the adjoint dependencies) in `adj_deps` or, in the default `"trivial"` mode, in the `StorageType.WORK` snapshots. Finally, `compute_grad` evaluates the gradient (3) from the adjoint solution at $t = 0$.
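The `execute` method of `CheckpointingManager` below routes each schedule action to a handler based on the action's type, using `functools.singledispatch`. The following minimal sketch isolates that pattern. The `Forward` and `Reverse` dataclasses here are hypothetical stand-ins, not the package's classes, which carry more attributes such as `write_ics` and `storage`.

```python
import functools
from dataclasses import dataclass


# Hypothetical stand-ins for the package's action classes.
@dataclass
class Forward:
    n0: int
    n1: int


@dataclass
class Reverse:
    n0: int
    n1: int


@functools.singledispatch
def action(cp_action):
    # Fallback for any type without a registered handler.
    raise TypeError(f"Unexpected action: {cp_action!r}")


@action.register
def _(cp_action: Forward):
    return f"forward {cp_action.n0}->{cp_action.n1}"


@action.register
def _(cp_action: Reverse):
    return f"reverse {cp_action.n1}->{cp_action.n0}"


log = [action(a) for a in (Forward(0, 10), Reverse(0, 10))]
print(log)  # ['forward 0->10', 'reverse 10->0']
```

Dispatching on type keeps the loop over the schedule free of `if`/`elif` chains. Supporting a new action then only requires registering another handler.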

```python
import functools
import os
import pickle

import numpy as np
from scipy.integrate import trapezoid
from scipy.optimize import newton_krylov
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import spsolve
from checkpoint_schedules import (Forward, EndForward, Reverse, Copy, Move,
                                  EndReverse, StorageType)


class CheckpointingManager():
    """Manage the forward and backward solvers.

    Attributes
    ----------
    max_n : int
        Total number of steps used to execute the solvers.
    save_ram : int
        Number of checkpoints that will be stored in RAM.
    save_disk : int
        Number of checkpoints that will be stored on disk.
    list_actions : list
        Store the list of actions.
    forward_solver : callable
        The forward solver, called with each `Forward` action.
    backward_solver : callable
        The adjoint solver, called with each `Reverse` action.

    Subclasses provide the `copy_fwd_data` and `delete_fwd_data` methods.
    Assigning them to None here would shadow those methods.
    """
    def __init__(self, max_n, save_ram, save_disk):
        self.max_n = max_n
        self.save_ram = save_ram
        self.save_disk = save_disk
        self.list_actions = []
        self.forward_solver = None
        self.backward_solver = None

    def execute(self, cp_schedule):
        """Execute the forward and adjoint solvers following the actions
        of a checkpointing schedule.
        """
        @functools.singledispatch
        def action(cp_action):
            raise TypeError(f"Unexpected action: {cp_action!r}")

        @action.register(Forward)
        def action_forward(cp_action):
            n1 = min(cp_action.n1, self.max_n)
            self.forward_solver(cp_action)
            if n1 == self.max_n:
                # The total number of steps is now known: inform the schedule.
                cp_schedule.finalize(n1)

        @action.register(Reverse)
        def action_reverse(cp_action):
            nonlocal model_r
            self.backward_solver(cp_action)
            model_r += cp_action.n1 - cp_action.n0

        @action.register(Copy)
        def action_copy(cp_action):
            self.copy_fwd_data(cp_action)

        @action.register(Move)
        def action_move(cp_action):
            if cp_action.to_storage == StorageType.NONE:
                self.delete_fwd_data(cp_action)

        @action.register(EndForward)
        def action_end_forward(cp_action):
            pass

        @action.register(EndReverse)
        def action_end_reverse(cp_action):
            pass

        model_r = 0  # number of adjoint steps executed so far

        for cp_action in cp_schedule:
            action(cp_action)
            if isinstance(cp_action, EndReverse):
                break


class Burger(CheckpointingManager):
    """Solve the time-dependent forward and adjoint Burgers' equations.

    Attributes
    ----------
    model : dict
        The model parameters.
    checkpointing : str, optional
        Checkpointing mode. The default is "trivial", which means that the
        non-linear dependency data of every step are stored in
        StorageType.WORK.
    snapshots : dict
        Store the forward restart data.
    fwd_working_mem : dict
        Store the forward restart data in StorageType.WORK.
    bwd_working_mem : dict
        Store the adjoint restart data in StorageType.WORK.
    adj_deps : dict
        Store the adjoint dependency data.
    """
    def __init__(self, model, checkpointing="trivial"):
        super().__init__(model["max_n"], model["chk_ram"], model["chk_disk"])
        self.model = model
        self.snapshots = {StorageType.RAM: {}, StorageType.DISK: {}, StorageType.WORK: {}}
        self.fwd_working_mem = {}
        self.bwd_working_mem = {}
        self.adj_deps = {}
        self.checkpointing = checkpointing

    def forward(self, cp_action):
        """Solve the non-linear forward Burgers' equation in time.

        Parameters
        ----------
        cp_action : object
            The forward checkpoint action.

        Notes
        -----
        This action has the attributes n0, n1, write_ics, write_adj_deps and
        storage: n0 is the initial time step, n1 is the final time step,
        write_ics indicates whether the forward restart data are stored,
        write_adj_deps whether the adjoint dependency data are stored, and
        storage is the storage type.
        """
        nx = self.model["nx"]
        dx = self.model["lx"] / (nx - 1)  # nx nodes -> nx - 1 intervals
        dt = self.model["dt"]
        nu = self.model["nu"]
        b = nu / (dx * dx)
        # The schedule may not know max_n in advance, so n1 can be very large.
        n1 = min(cp_action.n1, self.max_n)

        u = self.fwd_working_mem[cp_action.n0].copy()
        if cp_action.write_ics:
            # Store the restart data at step n0.
            if cp_action.storage == StorageType.DISK:
                self.store_on_disk(u.copy(), cp_action.n0)
            else:
                self.snapshots[cp_action.storage][cp_action.n0] = u.copy()
        self.fwd_working_mem.clear()

        # Assemble the time-independent part of the matrix system.
        A = lil_matrix((nx, nx))
        B = lil_matrix((nx, nx))
        B[0, 0] = -1 / 3
        B[0, 1] = -1 / 6
        B[nx - 1, nx - 1] = -1 / 3
        B[nx - 1, nx - 2] = -1 / 6
        for i in range(1, nx - 1):
            B[i, i] = -2 / 3
            B[i, i + 1] = B[i, i - 1] = -1 / 6

        def non_linear(u_new):
            """Residual of one implicit time step; `u` is the previous solution."""
            u[0] = u[nx - 1] = 0  # homogeneous Dirichlet boundary conditions
            A[0, 0] = 1 / 3 - dt * (1 / 2 * u_new[0] / dx + b)
            A[0, 1] = 1 / 6 + dt * (1 / 2 * u_new[0] / dx - b)
            A[nx - 1, nx - 1] = 1 / 3 - dt * (- u_new[nx - 1] / dx + b)
            A[nx - 1, nx - 2] = 1 / 6 + dt * (1 / 2 * u_new[nx - 2] / dx - b)
            for i in range(1, nx - 1):
                A[i, i - 1] = 1 / 6 - dt * (1 / 2 * u_new[i - 1] / dx + b)
                A[i, i] = 2 / 3 + dt * (1 / 2 * (u_new[i - 1] - u_new[i]) / dx + 2 * b)
                A[i, i + 1] = 1 / 6 + dt * (1 / 2 * u_new[i] / dx - b)
            return A @ u_new + B @ u

        for t in range(cp_action.n0 + 1, n1 + 1):
            u = newton_krylov(non_linear, u)
            if self.checkpointing == "trivial":
                assert cp_action.write_adj_deps
                # Keep every step; the adjoint reads them from WORK.
                self.snapshots[StorageType.WORK][t] = u

        # Forward data kept in memory for subsequent Forward actions.
        self.fwd_working_mem = {n1: u}
        if self.checkpointing != "trivial" and cp_action.write_adj_deps:
            self.adj_deps = {n1: u}

        if n1 == self.max_n:
            # Adjoint initial condition: u_dagger(x, tau) = u(x, tau).
            self.adj_initcondition(u.copy(), n1)

    def backward(self, cp_action):
        """Solve the adjoint equation backward in time, from n1 to n0.

        Parameters
        ----------
        cp_action : object
            The adjoint checkpoint action.

        Notes
        -----
        This action has the attributes n0, n1 and clear_adj_deps: n0 is the
        initial time step, n1 is the final time step, and clear_adj_deps
        indicates whether the adjoint dependency data will be cleared.
        """
        nx = self.model["nx"]
        dx = self.model["lx"] / (nx - 1)
        dt = self.model["dt"]
        nu = self.model["nu"]
        b = nu / (dx * dx)
        p = self.bwd_working_mem[cp_action.n1].copy()
        self.bwd_working_mem.clear()

        # Mass matrix (time-independent); spsolve expects CSC or CSR format.
        A = lil_matrix((nx, nx))
        A[0, 0] = 1 / 3
        A[0, 1] = 1 / 6
        A[nx - 1, nx - 1] = 1 / 3
        A[nx - 1, nx - 2] = 1 / 6
        for i in range(1, nx - 1):
            A[i, i - 1] = 1 / 6
            A[i, i] = 2 / 3
            A[i, i + 1] = 1 / 6
        A = A.tocsc()
        B = lil_matrix((nx, nx))

        for t in range(cp_action.n1, cp_action.n0, -1):
            p[0] = p[nx - 1] = 0  # homogeneous Dirichlet boundary conditions
            if self.checkpointing == "trivial":
                uf = self.snapshots[StorageType.WORK][t]
            else:
                uf = self.adj_deps[t]

            B[0, 0] = 1 / 3 - dt * (uf[0] / dx - b - 1 / 3 * (uf[1] - uf[0]) / dx)
            B[0, 1] = 1 / 6 + dt * (1 / 2 * uf[0] / dx + b - 1 / 6 * (uf[2] - uf[1]) / dx)
            B[nx - 1, nx - 1] = (1 / 3 + dt * (uf[nx - 1] / dx - b
                                 - 1 / 3 * (uf[nx - 1] - uf[nx - 2]) / dx))
            B[nx - 1, nx - 2] = (1 / 6 + dt * (1 / 2 * uf[nx - 2] / dx
                                 + b - 1 / 6 * (uf[nx - 1] - uf[nx - 2]) / dx))
            for i in range(1, nx - 1):
                v_m = uf[i] / dx
                v_mm1 = uf[i - 1] / dx
                deri = (uf[i] - uf[i - 1]) / dx
                derip = (uf[i + 1] - uf[i]) / dx
                B[i, i] = 2 / 3 + dt * (1 / 2 * (v_mm1 - v_m) - 2 * b - 2 / 3 * (deri - derip))
                B[i, i - 1] = 1 / 6 - dt * (1 / 2 * v_mm1 - b - 1 / 6 * deri)
                B[i, i + 1] = 1 / 6 + dt * (1 / 2 * v_m + b - 1 / 6 * derip)

            p = spsolve(A, B @ p)
            if cp_action.clear_adj_deps and self.checkpointing == "trivial":
                del self.snapshots[StorageType.WORK][t]

        if cp_action.clear_adj_deps and self.checkpointing != "trivial":
            self.adj_deps.clear()
        # Adjoint data kept in memory for the next Reverse action.
        self.bwd_working_mem = {cp_action.n0: p}

    def copy_fwd_data(self, cp_action):
        """Copy the forward restart data at step n into working memory."""
        if cp_action.from_storage == StorageType.DISK:
            file_name = self.snapshots[StorageType.DISK][cp_action.n]
            with open(file_name, "rb") as f:
                u0 = np.asarray(pickle.load(f), dtype=float)
        else:
            u0 = self.snapshots[cp_action.from_storage][cp_action.n].copy()
        self.fwd_working_mem = {cp_action.n: u0}

    def delete_fwd_data(self, cp_action):
        """Delete the forward restart data at step n."""
        del self.snapshots[cp_action.from_storage][cp_action.n]

    def compute_grad(self):
        """Evaluate the gradient (3) in the direction delta_u0 = 1.01 sin(pi x),
        using the adjoint solution at t = 0."""
        x = np.linspace(0, self.model["lx"], self.model["nx"])
        p0 = self.bwd_working_mem[0]
        sens = trapezoid(p0 * 1.01 * np.sin(np.pi * x), x=x)
        print("Sensitivity:", sens)
        return sens

    def store_in_ram(self, data, step):
        """Store the forward data in RAM."""
        self.snapshots[StorageType.RAM][step] = data

    def store_on_disk(self, data, step):
        """Store the forward data on disk."""
        os.makedirs("fwd_data", exist_ok=True)
        file_name = os.path.join("fwd_data", f"ufwd_{step}.dat")
        with open(file_name, "wb") as f:
            pickle.dump(data, f)
        self.snapshots[StorageType.DISK][step] = file_name

    def fwd_initcondition(self, u0, step):
        """Define the forward initial condition.

        Parameters
        ----------
        u0 : array
            The forward initial condition.
        step : int
            The initial time step.
        """
        self.fwd_working_mem = {step: u0}

    def adj_initcondition(self, p0, step):
        """Define the adjoint initial condition.

        Parameters
        ----------
        p0 : array
            The adjoint initial condition.
        step : int
            The initial time step.
        """
        self.bwd_working_mem = {step: p0}

    def execute_burger(self, cp_schedule):
        """Register the solvers and execute the schedule."""
        self.forward_solver = self.forward
        self.backward_solver = self.backward
        self.execute(cp_schedule)
```
    

### Using *checkpoint_schedules* package
As in the previous example, the *CheckpointingManager* class implemented above manages the forward and adjoint solvers, coordinating them through the sequence of actions given by the *checkpoint_schedules* package.

```python
model = {"lx": 1,        # length of the domain
         "nx": 100,      # number of nodes
         "dt": 0.01,     # time step
         "T": 1,         # final time (not used by the solvers, which run max_n steps)
         "nu": 0.01,     # viscosity
         "max_n": 10,    # total number of time steps
         "chk_ram": 10,  # number of checkpoints in RAM
         "chk_disk": 0,  # number of checkpoints on disk
         }
```

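The dictionary above sets the domain length, the number of nodes, the time step, the viscosity, the total number of time steps, and the number of checkpoints allowed in RAM and on disk. Below, we create the `Burger` object, impose the initial condition $u_0 = \sin(\pi x)$ at step 0, and execute the forward and adjoint solvers with a *checkpoint_schedules* schedule.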

```python
from checkpoint_schedules import SingleStorageSchedule

burger = Burger(model)                        # create the Burgers' equation object
x = np.linspace(0, model["lx"], model["nx"])  # spatial grid (nx nodes)
u0 = np.sin(np.pi * x)                        # initial condition
burger.fwd_initcondition(u0, 0)               # set the initial condition at step 0
schedule = SingleStorageSchedule()            # create the checkpointing schedule
burger.execute_burger(schedule)               # execute the forward and adjoint solvers
```






The manager object executes the forward and adjoint equations by following the *checkpoint_schedules* actions. To obtain a sequence of actions, a schedule needs the total number of time steps and, for schedules that use checkpoints, the number of checkpoints that may be stored in RAM and on disk (here `chk_ram` and `chk_disk` in the `model` dictionary).

In this first example, the schedule is a `SingleStorageSchedule`, which keeps the forward data for every step in a single storage level, so no forward steps are recomputed. Because it is created without a number of steps, the `Forward` handler in `execute` calls `finalize` once the forward solver reaches `max_n`.

The `execute_burger` method registers the `forward` and `backward` methods as the solvers and calls `execute`. This method iterates over the schedule and dispatches each action (`Forward`, `Reverse`, `Copy`, `Move`, `EndForward`, `EndReverse`) to its handler. Schedules that recompute forward steps, such as those built with the H-Revolve checkpointing method, are meant to be used with `checkpointing` set to a value other than `"trivial"`, so that `Burger` keeps only the data requested by each action.
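The trade-off that the RAM and disk checkpoint counts control can be sketched with a simple two-level strategy. This strategy is not H-Revolve and not what *checkpoint_schedules* computes. It assumes the scalar toy ODE $du/dt = -u^2$ with explicit Euler steps and the objective $I = \frac{1}{2} u_N^2$, and all names are hypothetical. The first forward sweep stores only uniformly spaced checkpoints. The reverse sweep then recomputes one segment at a time from its checkpoint before running the adjoint over it.

```python
import math


def step(u, dt):
    """One explicit Euler step of the toy ODE du/dt = -u**2 (assumed model)."""
    return u - dt * u * u


def grad_uniform_checkpoints(u0, dt, n_steps, n_checkpoints):
    """dI/du0 for I = 0.5 * u_N**2 using uniformly spaced checkpoints.

    Returns (gradient, total forward steps, peak number of stored states).
    """
    seg = math.ceil(n_steps / n_checkpoints)   # steps per segment
    starts = list(range(0, n_steps, seg))       # checkpointed steps
    checkpoints, u, fwd_steps = {}, u0, 0
    # First forward sweep: keep only the states at segment starts.
    for n in range(n_steps):
        if n % seg == 0:
            checkpoints[n] = u
        u = step(u, dt)
        fwd_steps += 1
    lam, peak = u, len(checkpoints)
    # Reverse sweep, one segment at a time, last segment first.
    for s in reversed(starts):
        end = min(s + seg, n_steps)
        states = [checkpoints.pop(s)]           # restart from the checkpoint
        for _ in range(s, end - 1):             # recompute states s+1 .. end-1
            states.append(step(states[-1], dt))
            fwd_steps += 1
        peak = max(peak, len(checkpoints) + len(states))
        for n in range(end - 1, s - 1, -1):     # adjoint steps end-1 .. s
            lam *= 1.0 - 2.0 * dt * states[n - s]
    return lam, fwd_steps, peak


g_all, fwd_all, peak_all = grad_uniform_checkpoints(1.0, 0.01, 10, 10)
g_cp, fwd_cp, peak_cp = grad_uniform_checkpoints(1.0, 0.01, 10, 3)
print(fwd_all, peak_all)  # 10 10
print(fwd_cp, peak_cp)    # 17 5
```

With a checkpoint at every step, nothing is recomputed. With three checkpoints, the peak number of stored forward states halves, at the cost of seven extra forward steps, and the gradient is unchanged. Optimal schedules such as H-Revolve choose the checkpoint positions, and their storage levels, to minimise this recomputation.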

### References
\[1\] Aksan, E. N. "A numerical solution of Burgers’ equation by finite element method constructed on the method of discretization in time." *Applied Mathematics and Computation* 170.2 (2005): 895-904.