## Using *checkpoint_schedules* package

This first example introduces the usage of *checkpoint_schedules*. It gives an initial illustration of how the package works and of how to read the time execution of forward and adjoint solvers that follow the schedules.

We first implement a `CheckpointingManager` class that handles both the forward and the adjoint executions. It does so by iterating over the sequence of actions produced by a schedule when `CheckpointingManager.execute(cp_schedule)` is called. Here `cp_schedule` must be an iterable *checkpoint_schedules* object whose generator yields the actions that drive the forward and adjoint solvers: *Forward, EndForward, Reverse, Copy, Move, EndReverse.*

Within the `CheckpointingManager.execute` method, the actions are handled by single-dispatch functions: `action` is the generic function, decorated with `functools.singledispatch`, and a specific handler for each *checkpoint_schedules* action is registered through the `action.register` method.
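The iterate-and-dispatch pattern described above can be tried without the package. The sketch below is only an illustration: `ToyForward`, `ToyReverse`, `ToyEndReverse`, `toy_schedule` and `run` are hypothetical stand-ins for the schedule object and its actions, and only `functools.singledispatch` is real library API.

```python
import functools
from dataclasses import dataclass


# Hypothetical stand-ins for schedule actions (not the checkpoint_schedules classes).
@dataclass
class ToyForward:
    n0: int
    n1: int


@dataclass
class ToyReverse:
    n1: int
    n0: int


@dataclass
class ToyEndReverse:
    pass


def toy_schedule(max_n):
    """Hypothetical generator yielding actions, mimicking a schedule object."""
    yield ToyForward(0, max_n)
    yield ToyReverse(max_n, 0)
    yield ToyEndReverse()


def run(schedule):
    log = []

    @functools.singledispatch
    def action(cp_action):
        # Fallback for any action type without a registered handler.
        raise TypeError(f"Unexpected action {cp_action!r}")

    @action.register(ToyForward)
    def _(cp_action):
        log.append(f"forward {cp_action.n0}->{cp_action.n1}")

    @action.register(ToyReverse)
    def _(cp_action):
        log.append(f"adjoint {cp_action.n1}->{cp_action.n0}")

    @action.register(ToyEndReverse)
    def _(cp_action):
        log.append("end reverse")

    for cp_action in schedule:
        action(cp_action)  # the handler is chosen from type(cp_action)
        if isinstance(cp_action, ToyEndReverse):
            break
    return log


print(run(toy_schedule(4)))
# ['forward 0->4', 'adjoint 4->0', 'end reverse']
```

Registering one small handler per action type keeps each branch isolated, which is the same design choice made in `CheckpointingManager.execute`.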

In [45]:
from checkpoint_schedules import Forward, EndForward, Reverse, Copy, Move, EndReverse, StorageType
import functools

class CheckpointingManager():
    """Manage the forward and adjoint solvers.

    Attributes
    ----------
    save_ram : int
        Number of checkpoints to be stored in RAM.
    save_disk : int
        Number of checkpoints to be stored on disk.
    list_actions : list
        String representations of the actions executed so far.
    max_n : int
        Total number of time steps.
    """
    def __init__(self, max_n, chk_ram=0, chk_disk=0):
        self.max_n = max_n
        self.save_ram = chk_ram
        self.save_disk = chk_disk
        self.list_actions = []
        
    def execute(self, cp_schedule):
        """Execute forward and adjoint with a checkpointing strategy.

        Parameters
        ----------
        cp_schedule : CheckpointSchedule
            Checkpointing schedule object.
        """
        @functools.singledispatch
        def action(cp_action):
            raise TypeError("Unexpected action")

        @action.register(Forward)
        def action_forward(cp_action):
            nonlocal model_n
            n1 = min(cp_action.n1, self.max_n)
            if cp_action.write_ics and cp_action.write_adj_deps:
                # Each step occupies a column of width 4 in the diagram.
                print(("+   "*(n1-cp_action.n0 + 1)).rjust(n1*4))
                print(("x   "*(n1-cp_action.n0 + 1)).rjust(n1*4))
            else:
                if cp_action.write_ics:
                    print(("+").rjust(cp_action.n0*4))
                if cp_action.write_adj_deps:
                    print(("x").rjust(n1*4))

            print(("|" + "--->"*(n1-cp_action.n0)).rjust(n1*4)
                  + "   "*(self.max_n - n1 + 4)
                  + self.list_actions[-1])

            model_n = n1
            if n1 == self.max_n:
                cp_schedule.finalize(n1)

        @action.register(Reverse)
        def action_reverse(cp_action):
            nonlocal model_r
            print(("<---"*(cp_action.n1-cp_action.n0) + "|").rjust(cp_action.n1*4)
                  + "   "*(self.max_n - cp_action.n1 + 4)
                  + self.list_actions[-1])
            if cp_action.clear_adj_deps:
                print(("-   "*(cp_action.n1-cp_action.n0 + 1)).rjust(cp_action.n1*4))
            model_r += cp_action.n1 - cp_action.n0
            
        @action.register(Copy)
        def action_copy(cp_action):
            print(("c").rjust(cp_action.n*4)
                  + "   "*(self.max_n + 4)
                  + self.list_actions[-1])

        @action.register(Move)
        def action_move(cp_action):
            print(("-").rjust(cp_action.n*4)
                  + "   "*(self.max_n + 3)
                  + self.list_actions[-1])

        @action.register(EndForward)
        def action_end_forward(cp_action):
            assert model_n == self.max_n
            # The correct number of forward steps has been taken.
            print("End Forward" + "   "*(self.max_n + 2)
                  + self.list_actions[-1])
            # Record the final step on schedules whose `_max_n` is still undetermined.
            if cp_schedule._max_n is None:
                cp_schedule._max_n = self.max_n
            
        @action.register(EndReverse)
        def action_end_reverse(cp_action):
            nonlocal model_r
            assert model_r == self.max_n
            # The correct number of adjoint steps has been taken.
            print("End Reverse" + "   "*(self.max_n + 2)
                  + self.list_actions[-1])

        model_n = 0
        model_r = 0

        for cp_action in cp_schedule:
            self.list_actions.append(str(cp_action))
            action(cp_action)
            if isinstance(cp_action, EndReverse):
                break

Defining the number of checkpoints to store in RAM (`chk_ram`) and on disk (`chk_disk`) in the `CheckpointingManager` constructor is optional. This is because *checkpoint_schedules* also provides schedules for the case where no adjoint calculation is performed, so that no forward checkpoint data needs to be stored, and for the case where all the forward restart data and all the forward data used by the adjoint computation are stored, either in memory or on disk. Therefore, this package can be integrated with an adjoint-based gradient solver even when the solver does not use a revolve-type checkpointing strategy.

To exemplify the case with no adjoint computation, let us set the total number of steps `max_n` and create the object that drives the forward and adjoint solvers (`solver_manager`).

In [46]:
max_n = 4 # Total number of time steps.
solver_manager = CheckpointingManager(max_n) # manager object

Now, let us consider the first schedule, in which no adjoint calculation is necessary. This schedule is given by `NoneCheckpointSchedule`. Below we illustrate the use of this schedule for the number of time steps (`max_n`) defined above.

In [47]:
from checkpoint_schedules import NoneCheckpointSchedule
cp_schedule = NoneCheckpointSchedule() # checkpoint schedule object
solver_manager.execute(cp_schedule)

|--->--->--->--->            Forward(0, 9223372036854775807, False, False, <StorageType.NONE: None>)
End Forward                  EndForward()


The output of `solver_manager.execute(cp_schedule)` illustrates the execution in time on the left side and prints the corresponding *checkpoint_schedules* action on the right side.

To clarify how the `Forward` action works, first consider its general form:

* *Forward(n0, n1, write_ics, write_adj_deps, storage)*, which reads:
    - Advance the forward solver from step `n0` to the start of step `n1`.
    - If `write_ics` is `True`, store the forward data required to restart the forward solver from step `n0`.
    - If `write_adj_deps` is `True`, store the forward data required by the adjoint computation over the steps `n0` to `n1`.
    - `storage` is the storage type in which the forward data is saved.
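To make this reading concrete, here is a sketch of how a solver might interpret the fields of a `Forward` action. Everything in it is hypothetical: `ForwardSpec` only mirrors the four fields listed above (storage is omitted), `step` is a toy linear model, and plain dictionaries stand in for the restart and adjoint-dependency storage. It illustrates the semantics, not the package's implementation.

```python
from dataclasses import dataclass


@dataclass
class ForwardSpec:
    """Hypothetical mirror of the Forward fields n0, n1, write_ics, write_adj_deps."""
    n0: int
    n1: int
    write_ics: bool
    write_adj_deps: bool


def step(x):
    """Hypothetical one-step forward model: x_{n+1} = 0.5 * x_n + 1."""
    return 0.5 * x + 1.0


def apply_forward(spec, x, restart_store, adj_deps_store):
    """Advance x from step spec.n0 to the start of step spec.n1."""
    if spec.write_ics:
        # Restart data: the state from which the forward can be re-run at step n0.
        restart_store[spec.n0] = x
    for n in range(spec.n0, spec.n1):
        if spec.write_adj_deps:
            # Adjoint dependency: the input state of step n, needed by its adjoint.
            adj_deps_store[n] = x
        x = step(x)
    return x


restart, adj_deps = {}, {}
x4 = apply_forward(ForwardSpec(0, 4, True, True), 0.0, restart, adj_deps)
print(x4, restart, adj_deps)
# 1.875 {0: 0.0} {0: 0.0, 1: 1.0, 2: 1.5, 3: 1.75}
```

With `write_ics=False` and `write_adj_deps=False` the same call only advances the state and leaves both stores empty.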

For this particular case, we have:

* *Forward(0, 9223372036854775807, False, False, <StorageType.NONE: None>)*

  Notice that `n1` is undetermined: 9223372036854775807 is only a very large placeholder value. This schedule is referred to as *online* because it does not need the maximal number of steps to build the schedule (`NoneCheckpointSchedule._max_n` is `None`), so the user can stop the forward solver at any step. This action therefore reads as follows:
    - Advance the forward solver from step `n0` to the start of any step `n1`.
    - Do not store forward restart data, since `write_ics` is `False`.
    - Do not store forward data for the adjoint computation, since `write_adj_deps` is `False`.
    - The storage type is `<StorageType.NONE: None>`, meaning that no storage is needed.
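The placeholder `n1` printed above equals 2^63 - 1, the largest signed 64-bit integer, which is also the value of `sys.maxsize` on 64-bit CPython builds. The manager handles it with `min(cp_action.n1, self.max_n)` and calls `cp_schedule.finalize(n1)` once the final step is reached. The sketch below isolates only the clamping step; `clamp_online_n1` is a hypothetical helper name.

```python
import sys

# Value of n1 printed by the online Forward actions above.
ONLINE_N1 = 9223372036854775807


def clamp_online_n1(n1, max_n):
    """Hypothetical helper mirroring `min(cp_action.n1, self.max_n)` in the manager."""
    return min(n1, max_n)


print(ONLINE_N1 == 2**63 - 1)  # True: largest signed 64-bit integer
print(ONLINE_N1 == sys.maxsize)  # True on 64-bit CPython builds only
print(clamp_online_n1(ONLINE_N1, 4))  # 4
```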

* *EndForward()* indicates that the forward solver has reached the end of the time interval.

The next schedule covers the case where all forward restart data and all forward data used for the adjoint computation are stored in memory. This schedule is provided by `SingleMemoryStorageSchedule`, which is also an online schedule.

In [48]:
from checkpoint_schedules import SingleMemoryStorageSchedule

cp_schedule = SingleMemoryStorageSchedule()
solver_manager.execute(cp_schedule)

+   +   +   +   +   
x   x   x   x   x   
|--->--->--->--->            Forward(0, 9223372036854775807, True, True, <StorageType.RAM: 0>)
End Forward                  EndForward()
<---<---<---<---|            Reverse(4, 0, True)
-   -   -   -   -   
End Reverse                  EndReverse(False,)


In this particular case, we have:

* *Forward(0, 9223372036854775807, True, True, <StorageType.RAM: 0>)*, which reads:
    - Advance the forward solver from step `n0` to the start of any step `n1`.
    - Store the forward restart data, since `write_ics` is `True`.
    - Store the forward data required for the adjoint computation, since `write_adj_deps` is `True`.
    - The storage type is `<StorageType.RAM: 0>`, i.e. the data is kept in memory.

In general form, the `Reverse` action reads:

* *Reverse(n1, n0, clear_adj_deps)*
    - Advance the adjoint solver from the start of step `n1` to the start of step `n0`.
    - Clear the forward data used by the adjoint (adjoint dependencies) if `clear_adj_deps` is `True`.

For this particular case:

* *Reverse(4, 0, True)*
    - Advance the adjoint solver from the start of step `4` to the start of step `0`.
    - Clear the adjoint dependency data, since `clear_adj_deps` is `True`.
* *EndReverse(False)* indicates that the reverse (adjoint) computation has reached the end of its time interval, i.e. step `0`.
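The pair `Forward(0, ..., True, True, <StorageType.RAM: 0>)` followed by `Reverse(4, 0, True)` can be mimicked by a store-everything strategy: the forward run keeps the input state of every step, and the reverse sweep walks back from step 4 to step 0, consuming and clearing each stored state. The model x_{n+1} = sin(x_n), the objective J = x_N and the function names are hypothetical choices made for illustration; the adjoint result is checked against a finite difference.

```python
import math


def forward_store_all(x0, max_n):
    """Hypothetical model x_{n+1} = sin(x_n); keep the input state of every step."""
    adj_deps = {}
    x = x0
    for n in range(max_n):
        adj_deps[n] = x  # forward data needed by the adjoint of step n
        x = math.sin(x)
    return x, adj_deps


def reverse(adj_deps, n1, n0, lam=1.0, clear_adj_deps=True):
    """Advance the adjoint from the start of step n1 to the start of step n0.

    lam is the adjoint variable at step n1 (1.0 when n1 is the final step and
    J = x_N); the return value is the adjoint variable at step n0.
    """
    for n in range(n1 - 1, n0 - 1, -1):
        x_n = adj_deps.pop(n) if clear_adj_deps else adj_deps[n]
        lam *= math.cos(x_n)  # derivative of sin at the stored state
    return lam


max_n, x0 = 4, 0.7
x_final, deps = forward_store_all(x0, max_n)
grad = reverse(deps, max_n, 0)  # dJ/dx0; deps is emptied along the way

# Central finite-difference check of dJ/dx0.
h = 1e-6
fd = (forward_store_all(x0 + h, max_n)[0] - forward_store_all(x0 - h, max_n)[0]) / (2 * h)
print(grad, fd, deps)
```

Because `reverse` takes the incoming adjoint `lam`, a sweep can also be split into pieces such as `Reverse(4, 2, ...)` followed by `Reverse(2, 0, ...)`, which is how the checkpointing schedules below use it.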

In the illustration of the time execution printed on the left side, the symbols `+` and `x` indicate that the forward restart data and the forward data used for the adjoint computation are stored, respectively. The symbol `-` printed below a reverse arrow indicates that the adjoint dependency data is cleared. In the schedules that follow, `c` marks a `Copy` action, and a lone `-` on the line of a `Move` action marks the move.
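The diagram column is produced by right-justifying short strings to a width of four characters per step, as in the `print` calls of `CheckpointingManager`. The sketch below isolates that formatting in two hypothetical helpers, `render_forward` and `render_reverse`, so the alignment can be checked on its own.

```python
def render_forward(n0, n1):
    """Arrow for Forward(n0, n1, ...): one '--->' per step, right-aligned at column 4*n1."""
    return ("|" + "--->" * (n1 - n0)).rjust(n1 * 4)


def render_reverse(n1, n0):
    """Arrow for Reverse(n1, n0, ...): one '<---' per step, right-aligned at column 4*n1."""
    return ("<---" * (n1 - n0) + "|").rjust(n1 * 4)


for line in (render_forward(0, 4), render_forward(2, 3), render_reverse(4, 3)):
    print(repr(line))
```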


Now we present the checkpoint schedules for the cases where checkpointing strategies are employed.
First, let us consider the `Revolve` approach as presented by


In [49]:
from checkpoint_schedules import Revolve
chk_ram = 2
solver_manager = CheckpointingManager(max_n, chk_ram=chk_ram) # manager object
cp_schedule = Revolve(max_n, chk_ram)
solver_manager.execute(cp_schedule)

+
|--->--->                  Forward(0, 2, True, False, <StorageType.RAM: 0>)
       +
       |--->               Forward(2, 3, True, False, <StorageType.RAM: 0>)
               x
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
End Forward                  EndForward()
           <---|            Reverse(4, 3, True)
        -   -   
       -                     Move(2, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
           x
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
    -   -   
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       x
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-   -   
-                     Move(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTAR
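A standard way to relate the number of checkpoints to the recomputation cost of revolve-type schedules is the binomial bound: with `s` checkpoints and a repetition number `t` (the maximum number of times any step is advanced, in the convention where the forward evaluation performed together with a step's adjoint is not counted), at most C(s + t, s) steps can be reversed. The sketch computes the smallest `t` for a given `max_n`; `max_reversible_steps` and `repetition_number` are hypothetical helpers, not part of *checkpoint_schedules*. Under this bound, `max_n = 4` with two checkpoints (the `+` markers at steps 0 and 2 above) requires a repetition number of 2.

```python
from math import comb


def max_reversible_steps(s, t):
    """Binomial bound beta(s, t) = C(s + t, s)."""
    return comb(s + t, s)


def repetition_number(max_n, s):
    """Hypothetical helper: smallest t with C(s + t, s) >= max_n."""
    if s < 1:
        raise ValueError("at least one checkpoint is required")
    t = 0
    while max_reversible_steps(s, t) < max_n:
        t += 1
    return t


for s in (1, 2, 3):
    print(s, repetition_number(4, s))
# 1 3
# 2 2
# 3 1
```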

In [50]:
from checkpoint_schedules import HRevolve
chk_disk = 1
chk_ram = 1
solver_manager = CheckpointingManager(max_n, chk_ram=chk_ram, chk_disk=chk_disk) # manager object
cp_schedule = HRevolve(max_n, chk_ram, snap_on_disk=chk_disk)
solver_manager.execute(cp_schedule)

+
|--->--->--->               Forward(0, 3, True, False, <StorageType.RAM: 0>)
               x
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
End Forward                  EndForward()
           <---|            Reverse(4, 3, True)
        -   -   
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->--->                  Forward(0, 2, False, False, <StorageType.FWD_RESTART: 2>)
           x
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
    -   -   
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       x
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-   -   
-                     Move(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 

In [51]:
from checkpoint_schedules import DiskRevolve
solver_manager = CheckpointingManager(max_n, chk_ram=chk_ram) # manager object
cp_schedule = DiskRevolve(max_n, chk_ram)
solver_manager.execute(cp_schedule)

+
|--->--->--->               Forward(0, 3, True, False, <StorageType.RAM: 0>)
               x
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
End Forward                  EndForward()
           <---|            Reverse(4, 3, True)
        -   -   
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->--->                  Forward(0, 2, False, False, <StorageType.FWD_RESTART: 2>)
           x
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
    -   -   
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       x
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-   -   
-                     Move(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 

In [52]:
from checkpoint_schedules import PeriodicDiskRevolve
cp_schedule = PeriodicDiskRevolve(max_n, chk_ram)
solver_manager.execute(cp_schedule)

We use periods of size  3
+
|--->--->--->               Forward(0, 3, True, False, <StorageType.RAM: 0>)
               x
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
End Forward                  EndForward()
           <---|            Reverse(4, 3, True)
        -   -   
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->--->                  Forward(0, 2, False, False, <StorageType.FWD_RESTART: 2>)
           x
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
    -   -   
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       x
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-   -   
-                     Move(0, <StorageType.RAM: 0>, 

In [53]:
from checkpoint_schedules import MixedCheckpointSchedule

cp_schedule = MixedCheckpointSchedule(max_n, chk_disk)
solver_manager.execute(cp_schedule)

+
|--->--->--->               Forward(0, 3, True, False, <StorageType.DISK: 1>)
               x
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
End Forward                  EndForward()
           <---|            Reverse(4, 3, True)
        -   -   
c                        Copy(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
|--->--->                  Forward(0, 2, False, False, <StorageType.FWD_RESTART: 2>)
           x
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
    -   -   
-                     Move(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
   x
|--->                     Forward(0, 1, False, True, <StorageType.DISK: 1>)
       x
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-   -   
-                     Move(0, <StorageType.DISK: 1>, <StorageType.ADJ_DEPS: 3>)
<

In [54]:
from checkpoint_schedules import MultistageCheckpointSchedule

cp_schedule = MultistageCheckpointSchedule(max_n, 0, chk_disk)
solver_manager.execute(cp_schedule)

+
|--->--->--->               Forward(0, 3, True, False, <StorageType.DISK: 1>)
               x
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
End Forward                  EndForward()
           <---|            Reverse(4, 3, True)
        -   -   
c                        Copy(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
|--->--->                  Forward(0, 2, False, False, <StorageType.FWD_RESTART: 2>)
           x
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
    -   -   
c                        Copy(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       x
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-   -   
-                     Move(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTA

In [55]:
from checkpoint_schedules import TwoLevelCheckpointSchedule
cp_schedule = TwoLevelCheckpointSchedule(2, chk_disk)
solver_manager.execute(cp_schedule)


+
|--->--->                  Forward(0, 2, True, False, <StorageType.DISK: 1>)
       +
       |--->--->            Forward(2, 4, True, False, <StorageType.DISK: 1>)
End Forward                  EndForward()
       c                        Copy(2, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
       |--->               Forward(2, 3, False, False, <StorageType.FWD_RESTART: 2>)
               x
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
           <---|            Reverse(4, 3, True)
        -   -   
       c                        Copy(2, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
           x
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
    -   -   
c                        Copy(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       x
   |--->                