## Using *checkpoint_schedules* package

The *checkpoint_schedules* package offers schedules of actions that drive adjoint-based gradient computations through a checkpointing strategy. That is, the schedules coordinate the selective storage of the forward data: the data used to restart the forward solver and the data needed by the adjoint computation. They also coordinate the forward and backward advancement in time and the retrieval of the stored data.
This is carried out by a checkpoint schedule built from a sequence of actions named *Forward, EndForward, Reverse, Copy, Move, EndReverse*. The actions provide functionality such as:
- storing the forward checkpoint data used to restart the forward solver;
- storing the forward checkpoint data needed for the adjoint computation;
- retrieving the stored data, both to restart the forward solver and to compute the adjoint.

In the following code, we implement the `CheckpointingManager` class, which manages the forward and adjoint executions in time. The `CheckpointingManager.execute` method iterates over the sequence of actions given by the schedule `cp_schedule`. The actions are handled by single-dispatch functions: `action` is the generic function, decorated with `functools.singledispatch`. The specific handler for each type of *checkpoint_schedules* action is attached through the `register` method of `action`.

In [67]:
from checkpoint_schedules import Forward, EndForward, Reverse, Copy, Move, EndReverse
import functools

class CheckpointingManager():
    """Manage the forward and adjoint solvers.

    No real solver is run: the manager walks a checkpointing schedule and
    prints an ASCII diagram of every action it receives (forward advances,
    reverse steps, copies and moves of checkpoint data).

    Attributes
    ----------
    save_ram : int
        Number of checkpoints that will be stored in RAM. Stored for
        reference only; :meth:`execute` does not read it.
    save_disk : int
        Number of checkpoints that will be stored on disk. Stored for
        reference only; :meth:`execute` does not read it.
    list_actions : list
        String representation of every action executed so far. The list is
        not cleared between calls to :meth:`execute`.
    max_n : int
        Total steps used to execute the solvers.
    """
    def __init__(self, max_n, save_ram, save_disk=0):
        self.max_n = max_n
        self.save_ram = save_ram
        self.save_disk = save_disk
        self.list_actions = []

    def execute(self, cp_schedule):
        """Execute forward and adjoint with a checkpointing strategy.

        Each action yielded by ``cp_schedule`` is recorded (as a string) in
        :attr:`list_actions` and dispatched to a handler that prints one
        line of the diagram. Iteration stops after an ``EndReverse``
        action.

        Parameters
        ----------
        cp_schedule : object
            Checkpointing schedule: an iterable of *checkpoint_schedules*
            actions that also provides ``finalize`` (called once the
            forward reaches ``max_n``).

        Raises
        ------
        TypeError
            If the schedule yields an action with no registered handler.
        AssertionError
            If the forward does not end exactly at ``max_n``, or if the
            reverse steps do not add up to ``max_n``.
        """
        @functools.singledispatch
        def action(cp_action):
            # Fallback for any action type without a registered handler.
            raise TypeError("Unexpected action")

        @action.register(Forward)
        def action_forward(cp_action):
            nonlocal model_n
            # Some schedules emit an effectively unbounded end step (see the
            # NoneCheckpointSchedule output), so clip it to this run.
            n1 = min(cp_action.n1, self.max_n)
            if cp_action.write_ics:
                # Marker above the start step n0 (write_ics flag set).
                print((" +").rjust(cp_action.n0*4))
            if cp_action.write_adj_deps:
                # Marker above the end step n1 (write_adj_deps flag set).
                print(("+").rjust(n1*4))

            print(("|" + "--->"*(n1-cp_action.n0)).rjust(n1*4) +
                  "   "*(self.max_n - n1 + 4) +
                  self.list_actions[-1])

            model_n = n1
            if n1 == self.max_n:
                # NOTE(review): presumably informs the schedule that the
                # forward has reached its final step; confirm with the API.
                cp_schedule.finalize(n1)

        @action.register(Reverse)
        def action_reverse(cp_action):
            nonlocal model_r
            print(("<---"*(cp_action.n1-cp_action.n0) + "|").rjust(cp_action.n1*4)
                  + "   "*(self.max_n - cp_action.n1 + 4) +
                  self.list_actions[-1])

            # Accumulate adjoint steps so EndReverse can check the total.
            model_r += cp_action.n1 - cp_action.n0

        @action.register(Copy)
        def action_copy(cp_action):
            print(("c").rjust(cp_action.n*4)
                  + "   "*(self.max_n + 4) +
                  self.list_actions[-1])

        @action.register(Move)
        def action_move(cp_action):
            print(("-").rjust(cp_action.n*4)
                  + "   "*(self.max_n + 3) +
                  self.list_actions[-1])

        @action.register(EndForward)
        def action_end_forward(cp_action):
            # The forward must have advanced exactly to the final step.
            assert model_n == self.max_n
            # NOTE(review): writes a private attribute of the schedule,
            # presumably so schedules created without a length learn it.
            if cp_schedule._max_n is None:
                cp_schedule._max_n = self.max_n

        @action.register(EndReverse)
        def action_end_reverse(cp_action):
            # The correct number of adjoint steps has been taken.
            assert model_r == self.max_n
            print("End Reverse")

        # Forward position reached and number of adjoint steps taken; the
        # handlers above read/update these through their closures.
        model_n = 0
        model_r = 0

        for cp_action in cp_schedule:
            self.list_actions.append(str(cp_action))
            action(cp_action)
            if isinstance(cp_action, EndReverse):
                break

In [68]:
# Problem size shared by the examples below:
# 4 time steps in total, 2 checkpoints kept in RAM.
max_n, save_ram = 4, 2

In [69]:
# Manager for max_n steps with save_ram checkpoints in RAM (save_disk keeps its default of 0).
chk_manager = CheckpointingManager(max_n, save_ram) # manager object

In [57]:
from checkpoint_schedules import NoneCheckpointSchedule
# Schedule without adjoint storage: the output below shows a single forward
# run and no reverse actions.
rev_ite = NoneCheckpointSchedule()
chk_manager.execute(rev_ite)

|--->--->--->--->            Forward(0, 9223372036854775807, False, False, <StorageType.NONE: None>)


In [58]:
from checkpoint_schedules import SingleStorageSchedule
# The output below shows one forward run that stores its data in working
# memory, followed by one reverse sweep over all steps.
rev_ite = SingleStorageSchedule()
chk_manager.execute(rev_ite)

 +
               +
|--->--->--->--->            Forward(0, 9223372036854775807, True, True, <StorageType.WORKING_MEMORY: 4>)
<---<---<---<---|            Reverse(4, 0, True)
End Reverse


In [70]:
from checkpoint_schedules import Revolve
# Revolve schedule for max_n steps; save_ram is passed as the number of
# checkpoints held in RAM.
rev_ite = Revolve(max_n, save_ram)

`chk_manager.execute(rev_ite)` runs the forward and adjoint solvers following the Revolve checkpointing schedule.

In [60]:
# Print the diagram of the Revolve schedule built above.
chk_manager.execute(rev_ite)

 +
|--->--->                  Forward(0, 2, True, False, <StorageType.RAM: 0>)
       +
       |--->               Forward(2, 3, True, False, <StorageType.RAM: 0>)
               +
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
           <---|            Reverse(4, 3, True)
       -                     Move(2, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
           +
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       +
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-                     Move(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
   +
|--->                     Forward(0, 1, False, True, <StorageType.AD

In [71]:
from checkpoint_schedules import HRevolve
# One checkpoint slot in RAM and one on disk; a fresh manager is built so it
# records these storage limits.
save_ram = save_disk = 1
chk_manager = CheckpointingManager(
    max_n,
    save_ram,
    save_disk=save_disk,
)  # manager object
revolver = HRevolve(max_n, save_ram, snap_on_disk=save_disk)
chk_manager.execute(revolver)

 +
|--->--->--->               Forward(0, 3, True, False, <StorageType.RAM: 0>)
               +
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
           <---|            Reverse(4, 3, True)
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->--->                  Forward(0, 2, False, False, <StorageType.FWD_RESTART: 2>)
           +
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       +
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-                     Move(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
   +
|--->                     Forward(0, 1, False, True, <StorageType.ADJ_D

In [62]:
from checkpoint_schedules import DiskRevolve
# save_ram is 1 here (reassigned in the HRevolve cell).
revolver = DiskRevolve(max_n, save_ram)
chk_manager.execute(revolver)

 +
|--->--->--->               Forward(0, 3, True, False, <StorageType.RAM: 0>)
               +
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
           <---|            Reverse(4, 3, True)
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->--->                  Forward(0, 2, False, False, <StorageType.FWD_RESTART: 2>)
           +
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       +
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-                     Move(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
   +
|--->                     Forward(0, 1, False, True, <StorageType.ADJ_D

In [63]:
from checkpoint_schedules import PeriodicDiskRevolve
# The schedule reports the period size it picked (see the first output line).
revolver = PeriodicDiskRevolve(max_n, save_ram)
chk_manager.execute(revolver)

We use periods of size  3
 +
|--->--->--->               Forward(0, 3, True, False, <StorageType.RAM: 0>)
               +
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
           <---|            Reverse(4, 3, True)
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->--->                  Forward(0, 2, False, False, <StorageType.FWD_RESTART: 2>)
           +
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
c                        Copy(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       +
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-                     Move(0, <StorageType.RAM: 0>, <StorageType.FWD_RESTART: 2>)
   +
|--->                     Forward(0, 1, False

In [64]:
from checkpoint_schedules import MixedCheckpointSchedule
# Number of checkpoints passed to the schedule; also reused by later cells.
snapshots = 2
# The output below shows the checkpoints of this schedule stored on DISK.
revolver = MixedCheckpointSchedule(max_n, snapshots)
chk_manager.execute(revolver)

   +
|--->                     Forward(0, 1, False, True, <StorageType.DISK: 1>)
   +
   |--->--->               Forward(1, 3, True, False, <StorageType.DISK: 1>)
               +
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
           <---|            Reverse(4, 3, True)
   -                     Move(1, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
       +
   |--->                  Forward(1, 2, False, True, <StorageType.DISK: 1>)
           +
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
   -                     Move(1, <StorageType.DISK: 1>, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-                     Move(0, <StorageType.DISK: 1>, <StorageType.ADJ_DEPS: 3>)
<---|                     Reverse(1, 0, True)
End Reverse


In [65]:
from checkpoint_schedules import MultistageCheckpointSchedule
snapshots = 2
# NOTE(review): the 0 is presumably the number of RAM checkpoints and
# `snapshots` the number on disk (the output shows only DISK storage);
# confirm against the MultistageCheckpointSchedule signature.
revolver = MultistageCheckpointSchedule(max_n, 0, snapshots)
chk_manager.execute(revolver)

 +
|--->--->                  Forward(0, 2, True, False, <StorageType.DISK: 1>)
       +
       |--->               Forward(2, 3, True, False, <StorageType.DISK: 1>)
               +
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
           <---|            Reverse(4, 3, True)
       -                     Move(2, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
           +
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
c                        Copy(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       +
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|                  Reverse(2, 1, True)
-                     Move(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
   +
|--->                     Forward(0, 1, False, True, <StorageTy

In [66]:
from checkpoint_schedules import TwoLevelCheckpointSchedule
# NOTE(review): 2 is presumably the period of the stored checkpoints (the
# output shows checkpoints at steps 0 and 2); `snapshots` is still 2 from
# the previous cell. Confirm against the TwoLevelCheckpointSchedule API.
revolver = TwoLevelCheckpointSchedule(2, snapshots)
chk_manager.execute(revolver)


 +
|--->--->                  Forward(0, 2, True, False, <StorageType.DISK: 1>)
       +
       |--->--->            Forward(2, 4, True, False, <StorageType.DISK: 1>)
       c                        Copy(2, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
       |--->               Forward(2, 3, False, False, <StorageType.FWD_RESTART: 2>)
               +
           |--->            Forward(3, 4, False, True, <StorageType.ADJ_DEPS: 3>)
           <---|            Reverse(4, 3, True)
       c                        Copy(2, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
           +
       |--->               Forward(2, 3, False, True, <StorageType.ADJ_DEPS: 3>)
       <---|               Reverse(3, 2, True)
c                        Copy(0, <StorageType.DISK: 1>, <StorageType.FWD_RESTART: 2>)
|--->                     Forward(0, 1, False, False, <StorageType.FWD_RESTART: 2>)
       +
   |--->                  Forward(1, 2, False, True, <StorageType.ADJ_DEPS: 3>)
   <---|       