## Using the *checkpoint_schedules* package

The *checkpoint_schedules* package provides schedules of actions that drive adjoint-based gradient computations under a checkpointing strategy. A schedule coordinates the selective storage of forward data (both the data used to restart the forward solver and the data used in the adjoint computation), the forward and backward advancement in time, and the retrieval of the stored data.
A checkpoint schedule is built from a sequence of actions referred to as *Forward, EndForward, Reverse, Copy, EndReverse*. These actions cover storing the forward checkpoint data used to restart the forward solver, storing the forward data needed for the adjoint computation, and retrieving the stored data for either purpose.

The following code implements the `CheckpointingManager` class, which manages the forward and adjoint executions in time. The `CheckpointingManager.execute` method iterates over the sequence of actions given by the schedule `cp_schedule`. The actions are handled by single-dispatch functions: `action` is the generic function, decorated with `functools.singledispatch`, and a specific handler for each type of *checkpoint_schedules* action is attached through the `register` method of `action`.

In [1]:
from checkpoint_schedules import Forward, EndForward, Reverse, Copy, EndReverse
import functools

class CheckpointingManager():
    """Manage the forward and adjoint solvers.

    Attributes
    ----------
    save_ram : int
        Number of checkpoints that will be stored in RAM.
    save_disk : int
        Number of checkpoints that will be stored on disk.
    list_actions : list
        Store the list of actions.
    max_n : int
        Total steps used to execute the solvers.
    """
    def __init__(self, max_n, save_ram, save_disk=0):
        self.max_n = max_n
        self.save_ram = save_ram
        self.save_disk = save_disk
        self.list_actions = []
        
    def execute(self, cp_schedule):
        """Execute forward and adjoint with a checkpointing strategy.

        Parameters
        ----------
        cp_schedule : object
            Checkpointing schedule.
        """
        @functools.singledispatch
        def action(cp_action):
            raise TypeError("Unexpected action")

        @action.register(Forward)
        def action_forward(cp_action):
            nonlocal model_n
            if cp_action.write_ics:
                print(("*").rjust(cp_action.n0*4))

            print(("|" + "--->"*(cp_action.n1 - cp_action.n0)).rjust(cp_action.n1*4)
                  + "   "*(self.max_n - cp_action.n1 + 4)
                  + self.list_actions[-1])

            n1 = min(cp_action.n1, self.max_n)
            model_n = n1
            if cp_action.n1 == self.max_n:
                cp_schedule.finalize(n1)

        @action.register(Reverse)
        def action_reverse(cp_action):
            nonlocal model_r
            print(("<---"*(cp_action.n1 - cp_action.n0) + "|").rjust(cp_action.n1*4)
                  + "   "*(self.max_n - cp_action.n1 + 4)
                  + self.list_actions[-1])

            model_r += cp_action.n1 - cp_action.n0
            
        @action.register(Copy)
        def action_copy(cp_action):
            print("+".rjust(cp_action.n*4)
                  + "   "*(self.max_n + 4)
                  + self.list_actions[-1])

    
        @action.register(EndForward)
        def action_end_forward(cp_action):
            assert model_n == self.max_n
            print("End Forward")
            
        @action.register(EndReverse)
        def action_end_reverse(cp_action):
            nonlocal model_r
            assert model_r == self.max_n
            print("End Reverse")

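        model_n = 0  # step reached by the forward solver
        model_r = 0  # number of adjoint steps taken so far

        # Pull actions from the schedule until the adjoint run is complete.
        while True:
            cp_action = next(cp_schedule)
            self.list_actions.append(str(cp_action))
            action(cp_action)
            if isinstance(cp_action, EndReverse):
                break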
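
The dispatch pattern used inside `execute` can be seen in isolation in the minimal sketch below. The `Forward` and `Reverse` classes here are hypothetical stand-ins defined locally for illustration (they are not the *checkpoint_schedules* classes and carry only `n0` and `n1`), and `describe` is a name chosen for this sketch. The point is only how `functools.singledispatch` selects a handler from the type of its first argument.

```python
import functools
from dataclasses import dataclass


# Hypothetical stand-ins for schedule actions (illustration only; these are
# not the checkpoint_schedules classes).
@dataclass
class Forward:
    n0: int
    n1: int


@dataclass
class Reverse:
    n1: int
    n0: int


@functools.singledispatch
def describe(cp_action):
    # Fallback used when no handler is registered for the argument type.
    raise TypeError(f"Unexpected action: {cp_action!r}")


@describe.register(Forward)
def _(cp_action):
    return f"advance forward solver from step {cp_action.n0} to {cp_action.n1}"


@describe.register(Reverse)
def _(cp_action):
    return f"advance adjoint solver from step {cp_action.n1} to {cp_action.n0}"


messages = [describe(a) for a in (Forward(0, 2), Reverse(2, 1))]
for message in messages:
    print(message)
```

Passing any object whose type has no registered handler falls through to the generic function and raises `TypeError`, which mirrors the "Unexpected action" guard in the manager.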

First, we define the total number of time steps used in the computations and the number of forward checkpoints that can be stored in RAM.

In [2]:
max_n = 4 # Total number of time steps.
save_ram = 2 # Number of checkpoints to store in RAM.

Next, we create the `CheckpointingManager` object with the values defined above.

In [3]:
chk_manager = CheckpointingManager(max_n, save_ram) # manager object

The *checkpoint_schedules* package provides a range of iterators for the following checkpointing strategies:
* Revolve
* Multistage
* Two-level mixed periodic/binomial
* H-Revolve
* Mixing Storage
* Periodic disk storage

We begin by building a schedule iterator (named `rev_ite`) with the Revolve strategy. The schedule is built by calling `rev_ite.sequence()`.

In [4]:
from checkpoint_schedules import Revolve
rev_ite = Revolve(max_n, save_ram)
rev_ite.sequence()

`chk_manager.execute(rev_ite)` then steps through the forward and adjoint executions according to this schedule.

In [5]:
chk_manager.execute(rev_ite)

*
|--->--->                  Forward(0, 2, True, False, 'RAM')
       *
       |--->               Forward(2, 3, True, False, 'RAM')
           |--->            Forward(3, 4, False, True, 'RAM')
End Forward
           <---|            Reverse(4, 3, True)
       +                        Copy(2, 'RAM', True)
       |--->               Forward(2, 3, False, True, 'RAM')
       <---|               Reverse(3, 2, True)
+                        Copy(0, 'RAM', False)
|--->                     Forward(0, 1, False, False, <StorageType.NONE: None>)
   |--->                  Forward(1, 2, False, True, 'RAM')
   <---|                  Reverse(2, 1, True)
+                        Copy(0, 'RAM', True)
|--->                     Forward(0, 1, False, True, 'RAM')
<---|                     Reverse(1, 0, True)
End Reverse


The output above illustrates how the forward and adjoint executions advance in time under a *checkpoint_schedules* schedule. The symbol `|` marks the step at which the solver starts. The symbol `*` above a `|` indicates that the data used to restart the forward solver from that step is stored. The symbol `+` indicates that stored data is copied to serve as the initial condition for a forward recomputation.
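
The horizontal offsets in the diagram follow from the `rjust` calls in `CheckpointingManager.execute`: each step occupies four characters, and a segment is right-justified to a width of `4 * n1` characters, so later steps appear further to the right. The sketch below reproduces just that formatting. The helper names `forward_bar` and `reverse_bar` are introduced only for this illustration.

```python
def forward_bar(n0, n1):
    """Diagram segment for a forward advance from step n0 to step n1."""
    return ("|" + "--->" * (n1 - n0)).rjust(n1 * 4)


def reverse_bar(n1, n0):
    """Diagram segment for an adjoint advance from step n1 back to step n0."""
    return ("<---" * (n1 - n0) + "|").rjust(n1 * 4)


# repr() makes the leading padding visible.
print(repr(forward_bar(0, 2)))
print(repr(forward_bar(2, 3)))
print(repr(reverse_bar(4, 3)))
```

A segment that is already longer than the requested width, such as the one for `Forward(0, 2, ...)`, is returned without padding.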

To complement the illustration above, some of the actions are explained as follows:

* Forward action
    - General form: *Forward(n0, n1, write_ics, write_adj_deps, 'storage')*
    - Particular form:
        * *Forward(0, 2, True, False, 'RAM')*:
            - Execute the forward solver from step 0 until it reaches the forward solution at step 2.
            - Write the forward data required to restart the forward solver from step 0 (*write_ics* is *True*). The storage is in RAM.
            - Do not store the forward data for the adjoint computation (*write_adj_deps* is *False*).

        * *Forward(3, 4, False, True, 'RAM')*:
            - Execute the forward solver from step 3 until it reaches the forward solution at step 4.
            - Do not store the forward data used to restart the forward solver (*write_ics* is *False*).
            - Write the forward data required for the adjoint computation (*write_adj_deps* is *True*). The storage is in RAM.

* Reverse action
    - General form: *Reverse(n1, n0, clear_adj_deps)*
    - Particular form:
        * *Reverse(4, 3, True)*:
            - Execute the adjoint solver from step 4 until it reaches the adjoint solution at step 3.
            - Clear the adjoint dependencies used in the adjoint computation (*clear_adj_deps* is *True*).

* Copy action
    - General form: *Copy(n, from_storage, delete)*
    - Particular form:
        * *Copy(2, 'RAM', True)*:
            - Copy the forward data stored in RAM that is required to restart the forward solver from step 2.
            - Delete the stored data from RAM (*delete* is *True*), as it is no longer needed to restart the forward solver.

        * *Copy(0, 'RAM', False)*:
            - Copy the forward data stored in RAM that is required to restart the forward solver from step 0.
            - Do not delete the stored data from RAM (*delete* is *False*).
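
As a cross-check of this reading, the Revolve schedule printed earlier can be replayed by hand. The sketch below transcribes that output into plain tuples (an ad hoc encoding chosen for this illustration, not a package API) and tracks which restart checkpoints are held, how many forward steps are taken in total, and how many adjoint steps are taken.

```python
# Actions transcribed from the Revolve output above (max_n = 4, save_ram = 2).
# Tuple layouts are an ad hoc encoding for this sketch:
#   ("Forward", n0, n1, write_ics, write_adj_deps)
#   ("Reverse", n1, n0)
#   ("Copy", n, delete)
schedule = [
    ("Forward", 0, 2, True, False),
    ("Forward", 2, 3, True, False),
    ("Forward", 3, 4, False, True),
    ("Reverse", 4, 3),
    ("Copy", 2, True),
    ("Forward", 2, 3, False, True),
    ("Reverse", 3, 2),
    ("Copy", 0, False),
    ("Forward", 0, 1, False, False),
    ("Forward", 1, 2, False, True),
    ("Reverse", 2, 1),
    ("Copy", 0, True),
    ("Forward", 0, 1, False, True),
    ("Reverse", 1, 0),
]

checkpoints = set()    # steps whose restart data is currently stored
peak_checkpoints = 0   # largest number of restart checkpoints held at once
forward_steps = 0      # forward steps taken, including recomputation
adjoint_steps = 0      # adjoint steps taken

for kind, *args in schedule:
    if kind == "Forward":
        n0, n1, write_ics, _write_adj_deps = args
        if write_ics:
            checkpoints.add(n0)
        forward_steps += n1 - n0
    elif kind == "Reverse":
        n1, n0 = args
        adjoint_steps += n1 - n0
    elif kind == "Copy":
        n, delete = args
        assert n in checkpoints  # a checkpoint must exist before it is read
        if delete:
            checkpoints.remove(n)
    peak_checkpoints = max(peak_checkpoints, len(checkpoints))

print(forward_steps, adjoint_steps, peak_checkpoints, checkpoints)
```

For this schedule the replay gives 8 forward steps for 4 adjoint steps, so half of the forward work is recomputation. At most 2 restart checkpoints are held at any time, which matches `save_ram = 2`, and none remain once the adjoint run is complete.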



In [6]:
from checkpoint_schedules import HRevolve
save_disk = 1
save_ram = 1
chk_manager = CheckpointingManager(max_n, save_ram, save_disk=save_disk) # manager object
revolver = HRevolve(max_n, save_ram, snap_on_disk=save_disk)
revolver.sequence(w_cost=(0, 0.5), r_cost=(0, 0.5))
chk_manager.execute(revolver)

*
|--->--->                  Forward(0, 2, True, False, 'DISK')
       *
       |--->               Forward(2, 3, True, False, 'RAM')
           |--->            Forward(3, 4, False, True, 'RAM')
End Forward
           <---|            Reverse(4, 3, True)
       +                        Copy(2, 'RAM', True)
       |--->               Forward(2, 3, False, True, 'RAM')
       <---|               Reverse(3, 2, True)
+                        Copy(0, 'DISK', False)
*
|--->                     Forward(0, 1, True, False, 'RAM')
   |--->                  Forward(1, 2, False, True, 'RAM')
   <---|                  Reverse(2, 1, True)
+                        Copy(0, 'RAM', True)
|--->                     Forward(0, 1, False, True, 'RAM')
<---|                     Reverse(1, 0, True)
End Reverse


Below are the schedules obtained with the Disk-Revolve and Periodic-Disk-Revolve checkpointing strategies.

In [7]:
from checkpoint_schedules import DiskRevolve
revolver = DiskRevolve(max_n, save_ram, snap_on_disk=save_disk)
print(revolver._schedule)
revolver.sequence(w_cost=0.5, r_cost=0.5)
chk_manager.execute(revolver)

None
*
|--->--->                  Forward(0, 2, True, False, 'DISK')
       *
       |--->               Forward(2, 3, True, False, 'RAM')
           |--->            Forward(3, 4, False, True, 'RAM')
End Forward
           <---|            Reverse(4, 3, True)
       +                        Copy(2, 'RAM', True)
       |--->               Forward(2, 3, False, True, 'RAM')
       <---|               Reverse(3, 2, True)
+                        Copy(0, 'DISK', False)
*
|--->                     Forward(0, 1, True, False, 'RAM')
   |--->                  Forward(1, 2, False, True, 'RAM')
   <---|                  Reverse(2, 1, True)
+                        Copy(0, 'RAM', True)
|--->                     Forward(0, 1, False, True, 'RAM')
<---|                     Reverse(1, 0, True)
End Reverse


In [8]:
from checkpoint_schedules import PeriodicDiskRevolve
revolver = PeriodicDiskRevolve(max_n, save_ram, snap_on_disk=save_disk)
revolver.sequence(period=2)
chk_manager.execute(revolver)

We use periods of size  2
*
|--->--->                  Forward(0, 2, True, False, 'DISK')
       *
       |--->               Forward(2, 3, True, False, 'RAM')
           |--->            Forward(3, 4, False, True, 'RAM')
End Forward
           <---|            Reverse(4, 3, True)
       +                        Copy(2, 'RAM', True)
       |--->               Forward(2, 3, False, True, 'RAM')
       <---|               Reverse(3, 2, True)
+                        Copy(0, 'DISK', False)
*
|--->                     Forward(0, 1, True, False, 'RAM')
   |--->                  Forward(1, 2, False, True, 'RAM')
   <---|                  Reverse(2, 1, True)
+                        Copy(0, 'RAM', True)
|--->                     Forward(0, 1, False, True, 'RAM')
<---|                     Reverse(1, 0, True)
End Reverse
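
For this small problem the schedules printed by H-Revolve, Disk-Revolve and Periodic-Disk-Revolve are identical. To see how the two storage levels are used, the action strings can be tallied per storage type. The sketch below copies the printed strings from the output above and parses them with regular expressions. The patterns assume the exact text format shown in that output, and the variable names are chosen for this illustration.

```python
import re
from collections import Counter

# Action strings copied from the schedule printed above.
printed_actions = [
    "Forward(0, 2, True, False, 'DISK')",
    "Forward(2, 3, True, False, 'RAM')",
    "Forward(3, 4, False, True, 'RAM')",
    "Reverse(4, 3, True)",
    "Copy(2, 'RAM', True)",
    "Forward(2, 3, False, True, 'RAM')",
    "Reverse(3, 2, True)",
    "Copy(0, 'DISK', False)",
    "Forward(0, 1, True, False, 'RAM')",
    "Forward(1, 2, False, True, 'RAM')",
    "Reverse(2, 1, True)",
    "Copy(0, 'RAM', True)",
    "Forward(0, 1, False, True, 'RAM')",
    "Reverse(1, 0, True)",
]

# Patterns follow the printed format: Forward(n0, n1, write_ics,
# write_adj_deps, 'storage') and Copy(n, 'from_storage', delete).
forward_re = re.compile(
    r"Forward\((\d+), (\d+), (True|False), (True|False), '(\w+)'\)")
copy_re = re.compile(r"Copy\((\d+), '(\w+)', (True|False)\)")

checkpoint_writes = Counter()  # restart checkpoints written, per storage type
checkpoint_reads = Counter()   # restart checkpoints read back, per storage type

for text in printed_actions:
    forward_match = forward_re.fullmatch(text)
    copy_match = copy_re.fullmatch(text)
    if forward_match and forward_match.group(3) == "True":  # write_ics
        checkpoint_writes[forward_match.group(5)] += 1
    elif copy_match:
        checkpoint_reads[copy_match.group(2)] += 1

print(dict(checkpoint_writes))
print(dict(checkpoint_reads))
```

The tally shows one restart checkpoint written to disk and two written to RAM, with each of them read back exactly once.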


In [9]:
# from checkpoint_schedules import MixedCheckpointSchedule
# snapshots = 2
# revolver = MixedCheckpointSchedule(max_n, snapshots)
# chk_manager.execute(revolver)

In [10]:
from checkpoint_schedules import MultistageCheckpointSchedule
snapshots = 2  # Number of checkpoints; the output below shows them stored on disk.
revolver = MultistageCheckpointSchedule(max_n, 0, snapshots)
chk_manager.execute(revolver)

*
|--->--->                  Forward(0, 2, True, False, 'DISK')
       *
       |--->               Forward(2, 3, True, False, 'DISK')
           |--->            Forward(3, 4, False, True, 'RAM')
End Forward
           <---|            Reverse(4, 3, True)
       +                        Copy(2, 'DISK', True)
       |--->               Forward(2, 3, False, True, 'RAM')
       <---|               Reverse(3, 2, True)
+                        Copy(0, 'DISK', False)
|--->                     Forward(0, 1, False, False, 'NONE')
   |--->                  Forward(1, 2, False, True, 'RAM')
   <---|                  Reverse(2, 1, True)
+                        Copy(0, 'DISK', True)
|--->                     Forward(0, 1, False, True, 'RAM')
<---|                     Reverse(1, 0, True)
End Reverse


In [11]:
# from checkpoint_schedules import TwoLevelCheckpointSchedule
# revolver = TwoLevelCheckpointSchedule(2, 10)
# chk_manager.execute(revolver)


This first example covers the basics of driving an adjoint-based gradient computation with the *checkpoint_schedules* package. The next section shows an application to an adjoint-based gradient problem.