In [None]:
#| hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# recurse
> A class that provides a decorator for visualizing recursion trees and caching results

In [None]:
#| default_exp recurse

In [None]:
#| export
import functools

from recursion_visualizer.node import Node
from recursion_visualizer.animate import get_edges

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
class RecursionVisualizer:
  """A class that provides a decorator for visualizing recursion trees and caching results.

  Decorate a recursive function with an instance (``@RecursionVisualizer()``). Every call
  that goes through the wrapper, including the recursive ones, is recorded as a `Node`.
  When the outermost call returns, the recorded tree is handed to `_animate`.
  Results are stored in `self.cache` (when the arguments are hashable) but the cache is
  not consulted, so every call is executed and shows up in the tree.
  """

  def __init__(self,
               verbose: bool = False, # if true, print all nodes
               animate: bool = True, # if true, create an animation of the recursion tree
               save: bool = False, # if true, save the animation to a html file
               path: str = '', # path to save the animation to
               ):
    # NOTE(review): `save` and `path` are only read by the animation code that is
    # currently commented out in `_animate`.
    self.verbose = verbose
    self.animate = animate
    self.save = save
    self.path = path
    self._reset()

  def _reset(self):
    """Clear all state recorded for the previous recursion tree.

    self.nodes       -- dict: node id -> `Node`; ids are assigned in discovery (preorder) order
    self.edge_labels -- dict: node id -> the `edge_label` keyword passed to that call ('' if none)
    self.history     -- list of node ids; entry i is the node discovered or finished at time i
    self.cache       -- dict: positional-argument tuple -> result, for the current tree only
    self.time        -- next discovery/finish timestamp
    self.depth       -- current recursion depth (0 means no call is in progress)
    """
    self.nodes, self.edge_labels, self.history = {}, {}, []
    # NOTE(review): `self.id` and `self.func_name` are not read by this class;
    # kept in case code elsewhere relies on them.
    self.id, self.time, self.depth = 0, 0, 0
    self.cache = {}
    self.func_name = ''

  def _animate(self, nodes, history, edge_labels, func_name):
    """Build animation data for a finished recursion tree.

    Currently this only computes the labelled edges with `get_edges` and returns them;
    rendering and saving are disabled (see the commented-out code below). `func_name`
    is only used by that disabled code.
    """
    edge_to_label = get_edges(nodes, history, edge_labels)

    # # create recursion tree animation
    # fig = animate(history, nodes, func_name)
    # fig.show()

    # # save figure
    # if self.save:
    #   if self.path == '':
    #     input = ','.join(list(map(str, nodes[0]['input'])))
    #     self.path = './{}_{}.html'.format(func_name, input)
    #   fig.write_html(self.path)

    return edge_to_label

  def __call__(self,
               func: callable # function to be visualized or cached via decorator
               ):
    """Wrap `func` so that each call is recorded as a node of the recursion tree.

    Each call through the wrapper records the node id, function input, function output,
    depth, discovery time, and finish time. When the outermost call (depth 0) finishes,
    the tree is passed to `_animate` if `animate` is true. This is the main workhorse of
    the `RecursionVisualizer` class.

    At a high level the wrapper looks like:

    ```
    def __call__(self, func):
      def memoized_func(*args, **kwargs):
        # record discovery time, function input, and depth
        node.discovery = time
        node.input = args
        node.depth = depth

        # always call the function; the cache is written but not read
        result = func(*args, **kwargs)
        self.cache[args] = result

        # record finish time and function output
        node.output = result
        node.finish = time

        if depth == 0:
          animate()

        return result
      return memoized_func
    ```
    """

    @functools.wraps(func)
    def memoized_func(*args, **kwargs):
      # A call at depth 0 starts a new tree: forget everything from the previous one.
      if self.depth == 0:
        self._reset()

      # record node's depth, discovery time, and input arguments
      id_ = len(self.nodes)
      node = Node(id_=id_, input=args, depth=self.depth, discovery=self.time)
      self.history.append(node.id_)
      self.nodes[node.id_] = node
      self.time += 1

      # Call `func` one level deeper. `finally` restores the depth even if `func`
      # raises, so the next top-level call still resets state and animates.
      self.depth += 1
      try:
        result = func(*args, **kwargs)
      finally:
        self.depth -= 1

      # Best-effort cache write: unhashable arguments (e.g. lists) cannot be dict keys,
      # but that must not prevent the function from being visualized.
      try:
        self.cache[args] = result
      except TypeError:
        pass

      # record node's output, finish time, history, and edge_label
      node.output = result
      node.finish = self.time
      self.edge_labels[id_] = kwargs.get('edge_label', '')
      self.history.append(node.id_)
      self.time += 1

      if self.verbose:
        print(node)

      # animate after done traversing through the entire tree
      if self.animate and self.depth == 0:
        self._animate(self.nodes, self.history, self.edge_labels,
                      getattr(func, '__name__', repr(func)))

      return result
    return memoized_func

In [None]:
@RecursionVisualizer()
def fib(n):
  """Naive recursive Fibonacci: returns 1 for every n <= 2, else fib(n-1) + fib(n-2).

  Left un-memoized on purpose so every subcall appears in the recursion tree.
  """
  if n > 2:
    # left branch is explored first, then the right branch
    return fib(n - 1) + fib(n - 2)
  return 1

In [None]:
# Top-level call (depth 0): records the full recursion tree of fib(5),
# hands it to the animation hook, and evaluates to 5.
fib(5)

5

In [None]:
#show_doc(RecursionVisualizer.__call__)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

JSONDecodeError: Expecting value: line 3 column 1 (char 2)