# Using *checkpoint_schedules*

This tutorial introduces the use of *checkpoint_schedules* for step-based incremental checkpointing of the adjoints to computer models. The examples are for illustrative purposes only. However, the code is fully functional and can be used as a starting point for more complex applications.

## Managing the forward and adjoint executions with schedules
We first write the `CheckpointingManager` class, which manages the execution of the forward and adjoint models using a checkpointing schedule. The `CheckpointingManager` constructor takes the argument `max_n`, the maximum number of steps for the model execution. The attributes `index_action` and `list_actions` are used only for illustration.

`CheckpointingManager` has the method `execute`, which manages the step executions of the forward and adjoint models. `execute` takes the `cp_schedule` parameter, which is expected to be a generator provided by the *checkpoint_schedules* package. Inside `execute`, a `for` loop iterates over the elements of `cp_schedule`. Iterating with `enumerate(cp_schedule)` yields tuples `(count, cp_action)`, where `count` is the index of the action and `cp_action` is a *checkpoint_schedules* action. The action is passed to the single-dispatch generic function `action`, which handles each type of action with a specific registered function. For instance, `action_forward` is registered to handle the `Forward` action. Hence, if `cp_action` is a *Forward* action, `action_forward` is called, and inside this function we can implement or call any code required to execute the forward model. The other *checkpoint_schedules* actions, *Reverse*, *Copy*, *Move*, *EndForward* and *EndReverse*, are handled analogously.

**Notes:**
* The *checkpoint_schedules* actions will be treated in more detail in the following sections of this tutorial.
* The code inside the specific functions is intended to be illustrative only. Hence, we only add symbolic print statements to illustrate the execution of the forward and adjoint models.
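
The single-dispatch pattern described above can be sketched without the library. This is a minimal illustration only: `ToyForward`, `ToyEndForward` and `toy_schedule` are hypothetical stand-ins invented here, not *checkpoint_schedules* classes.

```python
import functools
from dataclasses import dataclass


# Hypothetical stand-ins for schedule actions (illustration only).
@dataclass
class ToyForward:
    n0: int
    n1: int


@dataclass
class ToyEndForward:
    pass


@functools.singledispatch
def action(cp_action):
    # Fallback used when no handler is registered for the action type.
    raise TypeError(f"Unexpected action: {cp_action!r}")


@action.register(ToyForward)
def action_forward(cp_action):
    return f"advance forward from step {cp_action.n0} to step {cp_action.n1}"


@action.register(ToyEndForward)
def action_end_forward(cp_action):
    return "forward run finished"


def toy_schedule():
    # A generator that yields actions one at a time, like a schedule.
    yield ToyForward(0, 4)
    yield ToyEndForward()


# enumerate() pairs each yielded action with its index.
log = [(count, action(cp_action)) for count, cp_action in enumerate(toy_schedule())]
for count, message in log:
    print(count, message)
```

The handler is selected from the type of the first argument, so supporting a new action type only requires registering one more function.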


In [None]:
from checkpoint_schedules import *
import functools

class CheckpointingManager():
    """Manage the executions of the forward and adjoint solvers.

    Attributes
    ----------
    max_n : int
        Total steps used to execute the solvers.
    list_actions : list
        Store the actions. Only used for the illustration matter.
    index_action : int
        Index of the action. Only used for the illustration matter.
    """
    def __init__(self, max_n):
        self.max_n = max_n
        self.list_actions = []
        self.index_action = 0
        
    def execute(self, cp_schedule):
        """Execute forward and adjoint with a checkpointing schedule.

        Parameters
        ----------
        cp_schedule : CheckpointSchedule
            Checkpoint schedule object.

        Notes
        -----
        `cp_schedule` provides the schedule of the actions to be taken and also a
        generator that yields the *checkpoint_schedules* actions.
        """
        @functools.singledispatch
        def action(cp_action):
            raise TypeError("Unexpected action")

        @action.register(Forward)
        def action_forward(cp_action):
            nonlocal step_n
            def illustrate_runtime(a, b, singlestorage):
                # function used to illustrate the runtime of the forward execution   
                if singlestorage:
                    time_exec = ".   "*cp_action.n0 + (a + '--' + b)*(n1-cp_action.n0)
                else:
                    time_exec = ".   "*cp_action.n0 + (a + ('---' + b)*(n1-cp_action.n0))
                return time_exec
            
            n1 = min(cp_action.n1, self.max_n)

            # writing the symbols used in the illustrations
            if cp_action.write_ics and cp_action.write_adj_deps:
                singlestorage = True
                a = '\u002b' 
                b = '\u25b6'
            else:
                singlestorage = False
                if cp_action.write_ics and cp_action.storage == StorageType.DISK:
                    a = '+'
                elif cp_action.write_ics and cp_action.storage == StorageType.RAM:
                    a = '*'
                else:
                    a = ''
                if cp_action.write_adj_deps:
                    b = "\u25b6"
                else:
                    b = "\u25b7"
            # Illustration of the forward execution in time
            time_exec = illustrate_runtime(a, b, singlestorage)

            self.list_actions.append([self.index_action, time_exec, str(cp_action)])

            step_n = n1
            if n1 == self.max_n:
                cp_schedule.finalize(n1)

        @action.register(Reverse)
        def action_reverse(cp_action):
            nonlocal step_r
            # Illustration of the adjoint execution in time 
            steps  = (cp_action.n1-cp_action.n0)
            step_r += cp_action.n1 - cp_action.n0
            time_exec = ".   "*(self.max_n - step_r) + (('\u25c0' + '---')*steps)
                                
            self.list_actions.append([self.index_action, time_exec, str(cp_action)])
            
        @action.register(Copy)
        def action_copy(cp_action):
            self.list_actions.append([self.index_action, " ", str(cp_action)])

        @action.register(Move)
        def action_move(cp_action):
            self.list_actions.append([self.index_action, " ", str(cp_action)])

        @action.register(EndForward)
        def action_end_forward(cp_action):
            assert step_n == self.max_n
            act = "End Forward" # action
            self.list_actions.append([self.index_action, act, str(cp_action)])
            if cp_schedule._max_n is None:
                cp_schedule._max_n = self.max_n
            
        @action.register(EndReverse)
        def action_end_reverse(cp_action):
            nonlocal step_r, is_exhausted
            # verifying whether the adjoint execution reached the end
            assert step_r == self.max_n
            # Informing the schedule that the execution is exhausted
            is_exhausted = cp_schedule.is_exhausted
            act = "End Reverse"  # action
            self.list_actions.append([self.index_action, act, str(cp_action)])
            
        step_n = 0 # forward step
        step_r = 0 # adjoint step
        is_exhausted = False # flag to indicate whether the schedule is exhausted
        for count, cp_action in enumerate(cp_schedule):
            self.index_action = count
            action(cp_action)
            if isinstance(cp_action, EndReverse):
                break
        
        # Printing the illustration of the execution
        from tabulate import tabulate
        print(tabulate(self.list_actions, headers=['Action index:', 'Run-time illustration', 
                                                    'Action:']))
        self.list_actions = []
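
The run-time illustration strings built in `action_forward` can be reproduced in isolation. The sketch below mirrors only the branch where restart data and adjoint dependencies are not both stored; `forward_timeline` is a hypothetical helper name introduced for this example.

```python
def forward_timeline(n0, n1, restart_symbol="", store_adj_deps=False):
    """Build a timeline string for a forward run from step n0 to step n1."""
    # Filled triangle: adjoint dependencies stored; hollow triangle: not stored.
    arrow = "\u25b6" if store_adj_deps else "\u25b7"
    # Skipped steps are drawn as dots, each executed step as '---' plus an arrow.
    return ".   " * n0 + restart_symbol + ("---" + arrow) * (n1 - n0)


print(forward_timeline(0, 2, restart_symbol="*"))
print(forward_timeline(2, 4, store_adj_deps=True))
```

The first call draws two steps starting from a restart checkpoint marked `'*'`; the second draws steps 2 to 4 with adjoint dependencies stored, offset by two skipped steps.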


## A trivial schedule for forward computation

First, let us define the maximum number of solver time steps, `max_n = 4`. Next, we instantiate an object named `solver_manager` of the `CheckpointingManager` class, using the `max_n` value.


In [None]:
max_n = 4 # Total number of time steps.
solver_manager = CheckpointingManager(max_n) # manager object

The `NoneCheckpointSchedule` class provides a schedule object that is passed to the `execute` method. In this case, the schedule is built to execute the forward solver exclusively, without any data storage.


In [None]:
cp_schedule = NoneCheckpointSchedule() # Checkpoint schedule object
solver_manager.execute(cp_schedule) # Execute the forward solver by following the schedule.

When executing `solver_manager.execute(cp_schedule)`, the output shows three pieces of information:

1. An index linked to each action,

2. A visualisation demonstrating the advancing of time-steps,

3. Actions associated with each step.

Notice in the output that we have two actions: *Forward* and *EndForward()*. The fundamental structure of the *Forward* action is given by:
```python
Forward(n0, n1, write_ics, write_adj_deps, storage_type)
```
This action is read as:

    - Advance the forward solver from step `n0` to the start of any step `n1`.

    - `write_ics` and `write_adj_deps` are booleans that indicate whether the forward solver should store the forward restart data and the forward data required for the adjoint computation, respectively.

    - `storage_type` indicates the type of storage required for the forward restart data and the forward data required for the adjoint computation.

Therefore, for the current example, the `Forward` action indicates the following:

    - Advance the forward solver from step `n0 = 0` to the start of any step `n1`.

    - Both `write_ics` and `write_adj_deps` are set to `False`, indicating no storage of the forward restart data or of the forward data required for the adjoint computation.

    - The storage type is `StorageType.NONE`, indicating that no specific storage type is required. 

*This schedule is built without specifying a maximum step for the forward solver execution. Therefore, the `NoneCheckpointSchedule` schedule offers the flexibility to determine the number of steps while the forward solver is advancing in time.*

In the current example, the maximum step `max_n = 4` is an attribute of `CheckpointingManager`. We conclude the forward solver execution with the following Python statement:
```python
cp_schedule.finalize(n1)
```
where `n1 = max_n = 4`. This line appears in `action_forward`, the function registered with `singledispatch` for the `Forward` action inside `CheckpointingManager.execute`.

Another action provided by the current schedule is the `EndForward()`, which indicates the forward solver has reached the end of the time interval.
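
The interplay between an open-ended schedule and `finalize` can be sketched with a toy generator. `ToyOpenEndedSchedule` and `run` are hypothetical names for this illustration and say nothing about how the library implements its schedules.

```python
import sys


class ToyOpenEndedSchedule:
    """Toy schedule whose final step is unknown until finalize() is called."""

    def __init__(self):
        self.max_n = None

    def finalize(self, n):
        self.max_n = n

    def __iter__(self):
        # Ask the caller to advance as far as it likes.
        yield ("Forward", 0, sys.maxsize)
        # The caller must have fixed the final step before asking for more.
        if self.max_n is None:
            raise RuntimeError("finalize() must be called before continuing")
        yield ("EndForward",)


def run(schedule, max_n):
    seen = []
    for act in schedule:
        seen.append(act[0])
        if act[0] == "Forward":
            n1 = min(act[2], max_n)  # the caller decides where to stop
            if n1 == max_n:
                schedule.finalize(n1)
    return seen


trace = run(ToyOpenEndedSchedule(), 4)
print(trace)
```

Because a generator is suspended at each `yield`, the call to `finalize` made while handling the forward action is visible when the generator resumes.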


## Trivial schedules storing all forward data

We now present schedules that include the adjoint solver computation.

The following code is useful when the user intends to store the forward data for all time steps. This schedule is provided by the `SingleMemoryStorageSchedule` class.

With this schedule, storing the forward restart data is unnecessary, as there is no need to recompute the forward solver while advancing the adjoint solver.

*The `SingleMemoryStorageSchedule` schedule offers the flexibility to determine the number of steps while the forward solver is advancing in time.*

In [None]:
cp_schedule = SingleMemoryStorageSchedule()
solver_manager.execute(cp_schedule)


In this particular case, the *Forward* action reads:

    - Advance the forward solver from the step `n0 = 0` to the start of any step `n1`.

    - Do not store the forward restart data, since `write_ics` is `False`.

    - Store the forward data required for the adjoint computation, since `write_adj_deps` is `True`.

    - The storage type is `<StorageType.WORK: 3>`, which indicates storage for immediate use. In this case, the use is the adjoint computation.

When the adjoint computation is considered in the schedule, we have the *Reverse* action, which has the basic form:
```python
Reverse(n1, n0, clear_adj_deps)
```
This is interpreted as follows:

    - Advance the adjoint solver from the step `n1` to the start of the step `n0`.

    - Clear the adjoint dependency data if `clear_adj_deps` is `True`.

In the current example, the *Reverse* action reads:

    - Advance the adjoint solver from the step `4` to the start of the step `0`.

    - Clear the adjoint dependency (forward data), since `clear_adj_deps` is `True`.

When adjoint computations are included in a schedule, an additional action, `EndReverse(True)`, is required to indicate the end of the adjoint advance.
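
To make the store-everything strategy concrete, here is a minimal sketch with a toy scalar model. The model `x -> x * x`, the objective `J = x_N` and all names are assumptions chosen for this example; they are unrelated to any solver used with the library.

```python
def forward_step(x):
    return x * x


def adjoint_step(x_stored, adj):
    # The derivative of x * x with respect to x is 2 * x.
    return 2.0 * x_stored * adj


n_steps = 4
x = 2.0
adj_deps = {}
for n in range(n_steps):
    adj_deps[n] = x  # keep the forward data needed by the adjoint of step n
    x = forward_step(x)

adj = 1.0  # dJ/dx at the final step, for J equal to the final state
for n in reversed(range(n_steps)):
    # pop() removes the data after use, mimicking clear_adj_deps=True
    adj = adjoint_step(adj_deps.pop(n), adj)

print(x, adj)
```

Since `J = x_0 ** 16` for this model, the computed gradient can be checked against `16 * x_0 ** 15`. No forward step is ever recomputed, at the price of holding one stored value per step.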

The *checkpoint_schedules* package also allows users to execute forward and adjoint solvers while storing all adjoint dependencies on `'disk'`. The following code shows this schedule applied in the forward and adjoint executions with the object generated by the `SingleDiskStorageSchedule` class.

In [None]:
cp_schedule = SingleDiskStorageSchedule()
solver_manager.execute(cp_schedule)


In the case illustrated above, forward and adjoint executions with `SingleDiskStorageSchedule` have the *Copy* action (see the outputs associated with the indexes 2, 4, 6, 8) which indicates copying of the forward data from one storage type to another.  

The *Copy* action has the fundamental structure:
```python
Copy(n, from_storage, to_storage)
```
which reads:

    - Copy the data associated with step `n`.

    - `from_storage` denotes the storage type that currently holds the forward data at step `n`, while `to_storage` is the storage type the forward data is copied to.

Hence, considering the *Copy* action associated with the output `Action index: 4`, we have:

    - Copy the data associated with step `4`.

    - The forward data is copied from `'disk'` storage, and the target storage type (`StorageType.WORK`) indicates prompt use in the adjoint computation.

Now, let us consider the case where the objective is to move the data from one storage type to another instead of copying it. To achieve this, the optional `move_data` parameter of `SingleDiskStorageSchedule` needs to be set to `True`. This configuration is illustrated in the following code example:

In [None]:
cp_schedule = SingleDiskStorageSchedule(move_data=True)
solver_manager.execute(cp_schedule)

The *Move* action follows a basic structure:
```python
Move(n, from_storage, to_storage)
```

This can be understood as:

    - Move the data associated with step `n`.

    - The terms `from_storage` and `to_storage` hold the same significance as in the *Copy* action.

Now, considering one of the *Move* actions, associated with the output `Action index: 4`:

    - Move the data associated with the step `4`.
    
    - The forward data is moved from `'disk'` storage to a storage used for the adjoint computation.

**The *Move* action entails that the data, once moved, is no longer accessible in the original storage type, whereas the *Copy* action means that the copied data remains available in the original storage type.**
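
The difference between the two actions can be sketched with plain dictionaries standing in for the storage types. The names `copy_data`, `move_data`, `disk` and `work` are hypothetical and chosen only for this illustration.

```python
def copy_data(n, from_storage, to_storage):
    # After a copy, the data is available in both storages.
    to_storage[n] = from_storage[n]


def move_data(n, from_storage, to_storage):
    # After a move, the data is removed from the original storage.
    to_storage[n] = from_storage.pop(n)


disk = {3: "forward data at step 3", 4: "forward data at step 4"}
work = {}

copy_data(4, disk, work)
print(4 in disk, 4 in work)

move_data(3, disk, work)
print(3 in disk, 3 in work)
```

Moving frees space in the original storage, while copying keeps the data available there in case it is needed again.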

## Schedules given by checkpointing methods
### Revolve
Now, let us consider the schedules given by the checkpointing strategies. We begin with the Revolve approach, as introduced in reference [1].

The Revolve checkpointing strategy generates a schedule that only uses `'RAM'` storage type. 

The `Revolve` class gives a schedule according to two essential parameters: the total count of forward time steps (`max_n = 4`) and the number of checkpoints to store in `'RAM'` (`snaps_in_ram = 2`).
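
As background for the trade-off between checkpoints and recomputation, the binomial bound associated with Revolve in [1] can be evaluated directly: with `s` checkpoints and at most `r` repeated forward evaluations of any step, up to `C(s + r, s)` steps can be reversed. The sketch below only evaluates that formula. The function names are hypothetical, and the exact counting conventions used by the `Revolve` class may differ.

```python
from math import comb


def max_reversible_steps(snapshots, repetitions):
    """Binomial bound C(s + r, s) on the number of reversible steps."""
    return comb(snapshots + repetitions, snapshots)


def min_repetitions(n_steps, snapshots):
    """Smallest r for which the binomial bound covers n_steps."""
    r = 0
    while max_reversible_steps(snapshots, r) < n_steps:
        r += 1
    return r


print(max_reversible_steps(2, 2))
print(min_repetitions(4, 2))
```

For 4 steps and 2 checkpoints, the bound is first satisfied at `r = 2`, since `C(3, 2) = 3` is too small and `C(4, 2) = 6` is enough.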

The code below shows the execution of the forward and adjoint solvers with the `Revolve` schedule.

In [None]:
snaps_in_ram = 2 
solver_manager = CheckpointingManager(max_n) # manager object
cp_schedule = Revolve(max_n, snaps_in_ram) 
solver_manager.execute(cp_schedule)

Using checkpointing strategies in an adjoint-based gradient computation requires recomputation of the forward solver. As shown in the output above, the *Forward* action associated with `Action index: 0` reads as follows:

    - Advance from time step 0 to the start of the time step 2.

    - Store the forward data required to restart the forward solver from time step 0.

    - The storage of the forward restart data is done in RAM.

In the displayed time step illustration, `'*−−−▷−−−▷'` is associated with

```python
Forward(0, 2, True, False, <StorageType.RAM: 0>)
```
The symbol `'*'` indicates that the forward data necessary for restarting the forward computation from step 0 is stored in `'RAM'`. In the time illustrations, `'−−−▷'` indicates that the forward data used for the adjoint computation is **not** stored. On the other hand, `'−−−▶'` indicates that this forward data is stored.

To summarize:

    - `'*'`: Forward data for restarting the forward solver is stored in `'RAM'`.

    - `'−−−▷'`: Forward data used for adjoint computation is not stored.

    - `'−−−▶'`: Forward data used for adjoint computation is stored.
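
The cost of recomputation can be made concrete by counting forward-step evaluations in two extreme cases: keeping only the step-0 restart data, and storing the data of every step. This is a counting sketch under those two assumptions, with hypothetical function names. It is not the schedule that `Revolve` generates.

```python
def reverse_with_single_checkpoint(n_steps):
    """Count forward-step evaluations when only the step-0 state is kept."""
    forward_evaluations = 0
    reversed_steps = []
    for target in range(n_steps - 1, -1, -1):
        # Restart from step 0, recompute `target` steps without storing,
        # then take step `target` itself while keeping its adjoint data.
        forward_evaluations += target + 1
        reversed_steps.append(target)  # the adjoint of this step can now run
    return forward_evaluations, reversed_steps


def reverse_with_all_data_stored(n_steps):
    """Every step is evaluated once; nothing is recomputed."""
    return n_steps, list(range(n_steps - 1, -1, -1))


print(reverse_with_single_checkpoint(4))
print(reverse_with_all_data_stored(4))
```

For 4 steps the single-checkpoint case needs `4 + 3 + 2 + 1 = 10` forward-step evaluations instead of 4. Checkpointing schedules sit between these extremes by placing a limited number of restart checkpoints.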

### Multistage checkpointing

The schedule depicted below employs a *MultiStage* distribution of checkpoints between `'RAM'` and `'disk'`, as described in [2]. This checkpointing strategy allows storage exclusively in memory (`'RAM'`), exclusively on `'disk'`, or in both storage locations.

The following code uses two types of storage, `'RAM'` and `'disk'`.

The *MultiStage* checkpointing schedule is given by `MultistageCheckpointSchedule`, which requires the numbers of checkpoints stored in `'RAM'` and on `'disk'` as parameters.

See the forward and adjoint executions with `MultistageCheckpointSchedule` in the following example:

In [None]:
snaps_in_ram = 1  # number of checkpoints stored in RAM
snaps_on_disk = 1 # number of checkpoints stored on disk
cp_schedule = MultistageCheckpointSchedule(max_n, snaps_in_ram, snaps_on_disk)
solver_manager.execute(cp_schedule)

The symbol `'*'` indicates that the forward data necessary for restarting the forward computation is stored in `'RAM'`, while the symbol `'+'` indicates that it is stored on `'disk'`.

### Disk-Revolve
The following code shows the execution of a solver over time using the *Disk-Revolve* schedule, as described in reference [3]. This schedule considers two types of storage: memory (`'RAM'`) and `'disk'`.

The *Disk-Revolve* algorithm, available in *checkpoint_schedules*, requires the number of checkpoints stored in memory to be greater than 0 (`snaps_in_ram > 0`). Specifying the number of checkpoints stored on `'disk'` is not required, as the algorithm itself calculates this value.

The number of checkpoints stored on `'disk'` is determined according to the costs of advancing the backward and forward solvers by a single time step, and the costs of writing and reading the checkpoints saved on disk. Additional details on the definition of these parameters can be found in references [3], [4] and [5].
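
The four costs mentioned above can be combined into a simple linear cost model. In the sketch below the parameter names `uf`, `ub`, `wd` and `rd`, their default values, and the operation counts of the two plans are all assumptions made for illustration. They are not taken from the library or from the references.

```python
def schedule_cost(forward_steps, backward_steps, disk_writes, disk_reads,
                  uf=1.0, ub=1.0, wd=5.0, rd=5.0):
    """Total cost as a weighted sum of the four kinds of operation.

    uf, ub: cost of one forward / backward step (assumed values).
    wd, rd: cost of one disk write / read (assumed values).
    """
    return (forward_steps * uf + backward_steps * ub
            + disk_writes * wd + disk_reads * rd)


# Plan A (made-up counts): recompute more, never touch the disk.
plan_a = schedule_cost(10, 4, 0, 0)
# Plan B (made-up counts): recompute less, using one disk checkpoint.
plan_b = schedule_cost(6, 4, 1, 1)
# The same plan B when disk access is as cheap as a solver step.
plan_b_cheap_disk = schedule_cost(6, 4, 1, 1, wd=1.0, rd=1.0)

print(plan_a, plan_b, plan_b_cheap_disk)
```

With expensive disk access the recompute-only plan is cheaper, while with cheap disk access the plan using a disk checkpoint wins. This is the kind of balance that determines how many disk checkpoints are worthwhile.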

In [None]:
snaps_in_ram = 1 # number of checkpoints stored in RAM
cp_schedule = DiskRevolve(max_n, snapshots_in_ram=snaps_in_ram) # checkpointing schedule object
solver_manager.execute(cp_schedule)

### Periodic Disk Revolve

The schedule used in the following code was presented in reference [4]. It is a hierarchical schedule with two storage types, referred to here as *Periodic Disk Revolve*. Analogously to the *Disk-Revolve* schedule, this approach requires the specification of the maximum number of steps (`max_n`) and the number of checkpoints saved in memory (`snaps_in_ram`). *Periodic Disk Revolve* automatically computes the number of checkpoints stored on disk.

*It is essential for the number of checkpoints in `'RAM'` to be greater than zero (`snaps_in_ram > 0`).*

In [None]:
snaps_in_ram = 1
cp_schedule = PeriodicDiskRevolve(max_n, snaps_in_ram)
solver_manager.execute(cp_schedule)

### H-Revolve 
The following code illustrates the forward and adjoint computations using the checkpointing given by the H-Revolve strategy [5]. This checkpointing schedule is generated with the `HRevolve` class, which requires the following parameters: the maximum number of checkpoints stored in RAM (`snaps_in_ram`), the maximum number of checkpoints stored on disk (`snaps_on_disk`), and the number of time steps (`max_n`).

*It is essential for the number of checkpoints in `'RAM'` to be greater than zero (`snaps_in_ram > 0`).*

In [None]:
snaps_on_disk = 1
snaps_in_ram = 1
cp_schedule = HRevolve(max_n, snaps_in_ram, snaps_on_disk)  # checkpointing schedule
solver_manager.execute(cp_schedule) # execute forward and adjoint in time with the schedule

### Mixed checkpointing

The *Mixed* checkpointing strategy works under the assumption that the data required to restart the forward computation is of the same size as the data required to advance the adjoint computation by one step. Further details of the *Mixed* checkpointing schedule are discussed in reference [6].

This specific schedule provides the flexibility to store the forward restart data either in `'RAM'` or on `'disk'`, but not both simultaneously within the same schedule.

In [None]:
snaps_on_disk = 1
max_n = 4
cp_schedule = MixedCheckpointSchedule(max_n, snaps_on_disk)
solver_manager.execute(cp_schedule)

In the example above, the forward restart data is stored on `'disk'` by default. To change the storage type to `'RAM'`, the user can set the `MixedCheckpointSchedule` argument `storage=StorageType.RAM`, as displayed below.

In [None]:
snaps_in_ram = 1  # number of checkpoints, now stored in RAM
cp_schedule = MixedCheckpointSchedule(max_n, snaps_in_ram, storage=StorageType.RAM)
solver_manager.execute(cp_schedule)

### Two-level binomial 

The two-level binomial schedule was presented in reference [7], and it was applied in the work [8].

The two-level binomial checkpointing stores the forward restart data based on the user-defined `period`. In this schedule, the user also defines the limit for additional storage of forward restart data, used while advancing the adjoint between periodic storage checkpoints. The default storage type is `'disk'`.

Now, let us define the period of storage `period = 2` and the extra forward restart data storage `add_snaps = 1`. The code displayed below shows the execution in time illustration for this setup.

In [None]:
add_snaps = 1 # number of additional forward restart checkpoints
period = 2
revolver = TwoLevelCheckpointSchedule(period, add_snaps)
solver_manager.execute(revolver)

Now, let us change the storage type of the additional forward restart checkpoints to `'RAM'` by setting the optional `TwoLevelCheckpointSchedule` argument `binomial_storage=StorageType.RAM`. In the example above, one notices that the action associated with `Action index: 8` implies that the forward restart data should be stored on `'disk'`. In the example below, on the other hand, the action associated with `Action index: 8` indicates that the forward restart data should be stored in `'RAM'`.


In [None]:
revolver = TwoLevelCheckpointSchedule(3, binomial_snapshots=snaps_on_disk, 
                                      binomial_storage=StorageType.RAM)
solver_manager.execute(revolver)

### References

[1] Griewank, A., & Walther, A. (2000). Algorithm 799: revolve: an implementation of checkpointing for the reverse or adjoint mode of computational differentiation. ACM Transactions on Mathematical Software (TOMS), 26(1), 19-45. doi: https://doi.org/10.1145/347837.347846

[2] Stumm, P., & Walther, A. (2009). Multistage approaches for optimal offline checkpointing. SIAM Journal on Scientific Computing, 31(3), 1946-1967. doi: https://doi.org/10.1137/080718036

[3] Aupy, G., Herrmann, J., Hovland, P., & Robert, Y. (2016). Optimal multistage algorithm for adjoint computation. SIAM Journal on Scientific Computing, 38(3), C232-C255. doi: https://doi.org/10.1137/15M1019222

[4] Aupy, G., & Herrmann, J. (2017). Periodicity in optimal hierarchical checkpointing schemes for adjoint computations. Optimization Methods and Software, 32(3), 594-624. doi: https://doi.org/10.1080/10556788.2016.1230612

[5] Herrmann, J., & Pallez (Aupy), G. (2020). H-Revolve: a framework for adjoint computation on synchronous hierarchical platforms. ACM Transactions on Mathematical Software (TOMS), 46(2), 1-25. doi: https://doi.org/10.1145/3378672

[6] Maddison, J. R. (2023). On the implementation of checkpointing with high-level algorithmic differentiation. arXiv preprint arXiv:2305.09568. https://doi.org/10.48550/arXiv.2305.09568.

[7] Pringle, G. C., Jones, D. C., Goswami, S., Narayanan, S. H. K., and  Goldberg, D. (2016). Providing the ARCHER community with adjoint modelling tools for high-performance oceanographic and cryospheric computation. https://nora.nerc.ac.uk/id/eprint/516314.

[8] Goldberg, D. N., Smith, T. A., Narayanan, S. H., Heimbach, P., and Morlighem, M. (2020). Bathymetric Influences on Antarctic Ice‐Shelf Melt Rates. Journal of Geophysical Research: Oceans, 125(11), e2020JC016370. doi: https://doi.org/10.1029/2020JC016370.


