# Automatic Differentiation

The `autograd` package in Python automates the computation of backward passes for **automatic differentiation**. When using `autograd`, a **computational graph** is built under the hood. Any function written in the source code is represented as a computational graph: nodes in the graph are tensors, and edges are the functions that produce output tensors from input tensors. Backpropagating through this graph then lets you compute gradients easily.

It is the mathematical equivalent of journeying around the world with zero planning: I can just keep composing and stacking functions, always assured that `autograd` will be able to follow the breadcrumbs and compute a derivative for me.

**Source**: [Autodidact: A tutorial implementation of Autograd](https://github.com/mattjj/autodidact)

## Define the computational graph

In [None]:
"""Tracing utilities.
This library provides functions for constructing a computation graph. With this
library, one can,
- Build a computation graph. (trace)
- Register wrapper types for unwrapped values based on type(). (Box.register)
- Build functions that can deal with wrapped values. (primitive,
  notrace_primitive)
- Box values. (new_box)
"""
from collections import defaultdict
from contextlib import contextmanager

from .util import subvals, wraps

def trace(start_node, fun, x):
    """Evaluate ``fun(x)`` while recording its computation graph.

    Args:
      start_node: Node attached to the boxed input.
      fun: function to trace; it receives the boxed input.
      x: raw (unboxed) input value.
    Returns:
      ``(value, node)``: the unboxed result of ``fun(x)`` and the graph node of
      that result. If the result is not a box belonging to this trace, it is
      returned as-is together with ``None``.
    """
    with trace_stack.new_trace() as current_id:
        # Boxing x ties it to this trace level, so every primitive applied to
        # it along the way extends the graph.
        boxed_input = new_box(x, current_id, start_node)
        boxed_output = fun(boxed_input)

        produced_here = (isbox(boxed_output)
                         and boxed_output._trace_id == boxed_input._trace_id)
        if not produced_here:
            # The output appears to be independent of the input.
            return boxed_output, None
        return boxed_output._value, boxed_output._node

## Define the graph nodes. 

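Each node records the output value of a function applied to its parent nodes, together with a "recipe" (the function, its arguments, and the positions of the boxed parent arguments) describing how that value was obtained.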

In [None]:
class Node(object):
    """A node in a computation graph."""
    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Record how ``value`` was produced.

        Args:
          value: output of fun(*args, **kwargs).
          fun: wrapped numpy function that was applied.
          args: all (unboxed) positional arguments.
          kwargs: dict of additional keyword args.
          parent_argnums: integers giving the positional indices of the
            arguments that were boxed.
          parents: Node instances corresponding to parent_argnums.
        """
        self.parents = parents
        recipe = (fun, value, args, kwargs, parent_argnums)
        self.recipe = recipe

    def initialize_root(self):
        """Turn this node into a root: no parents and an identity recipe."""
        self.parents, self.recipe = [], (lambda x: x, None, (), {}, [])

    @classmethod
    def new_root(cls, *args, **kwargs):
        """Create a root node without running ``__init__``."""
        instance = cls.__new__(cls)
        instance.initialize_root(*args, **kwargs)
        return instance

In [None]:


def primitive(f_raw):
    """Wrap ``f_raw`` so that calls on boxed arguments are recorded.

    If no positional argument is boxed, ``f_raw`` is simply called. Otherwise
    the boxes with the largest trace_id are unwrapped one layer, the wrapper
    is re-invoked on the unwrapped arguments, and the result is boxed together
    with a new Node describing the call.
    """
    @wraps(f_raw)
    def f_wrapped(*args, **kwargs):
        # Only boxes from the innermost trace (largest trace_id) become parents,
        # so the graph being built contains nodes from a single trace() call.
        boxed_args, trace_id = find_top_boxed_args(args)
        if not boxed_args:
            return f_raw(*args, **kwargs)

        # Swap each top-level box for its wrapped value.
        argvals = subvals(args, [(pos, box._value) for pos, box in boxed_args])
        parents = tuple(box._node for _, box in boxed_args)
        argnums = tuple(pos for pos, _ in boxed_args)

        # Recursive call: the unwrapped values may still be boxes of outer
        # traces (lower trace_ids), which must be recorded too. See the
        # TraceStack docstring.
        ans = f_wrapped(*argvals, **kwargs)

        # NOTE(review): `Node` is rebound later in this notebook (the graph
        # marker class whose __init__ takes no arguments); if that cell has
        # run, this call fails -- confirm which Node is intended here.
        node = Node(ans, f_wrapped, argvals, kwargs, argnums, parents)
        return new_box(ans, trace_id, node)
    return f_wrapped

def notrace_primitive(f_raw):
    """Wrap a raw numpy function so that it ignores boxes.

    Every positional argument has all of its Box layers stripped (via
    getval) before ``f_raw`` runs, and the result is returned unboxed.
    Unboxing signals that ``f_raw`` is treated as non-differentiable with
    respect to its arguments. Consider the computation,
    ```
    x = 1.5
    y = np.floor(x) + x
    ```
    What is the derivative of y wrt x? Autograd says 1. as np.floor has zero
    derivative near x=1.5.
    """
    @wraps(f_raw)
    def f_wrapped(*args, **kwargs):
        # Positional arguments are unboxed; keyword arguments pass through as-is.
        plain_args = [getval(arg) for arg in args]
        return f_raw(*plain_args, **kwargs)
    return f_wrapped

def find_top_boxed_args(args):
    """Select the boxed arguments belonging to the innermost trace.

    Finds the largest trace_id among the boxed arguments, keeps every boxed
    argument with that trace_id, and drops all others.

    Args:
      args: Arguments to a function wrapped by primitive().
    Returns:
      top_boxes: List of (index, boxed argument) pairs, in argument order,
        all sharing the largest trace_id.
      top_trace_id: that trace_id, or -1 when no argument is boxed.
    """
    best_id = -1
    best = []
    for position, candidate in enumerate(args):
        if not isbox(candidate):
            continue
        candidate_id = candidate._trace_id
        if candidate_id > best_id:
            # A deeper trace was found: restart the selection.
            best_id, best = candidate_id, [(position, candidate)]
        elif candidate_id == best_id:
            best.append((position, candidate))
    return best, best_id

class TraceStack(object):
    """Tracks number of times trace() has been called.

    This is critical to ensure calling grad() on a function that also calls
    grad() resolves correctly. For example,
    ```
    def f(x):
      def g(y):
        return x * y
      return grad(g)(x)
    y = grad(f)(5.)
    ```
    First, grad(f)(5.) wraps 5. in a Box and calls f(Box(5)). Then, grad(g)(x)
    wraps Box(5) again and calls g(Box(Box(5))). When computing grad(g), we
    want to treat x=Box(5) as fixed -- it's not a direct argument to g(). How
    does Autograd know that x is fixed, when all it can see is
    np.multiply(Box(5.), Box(Box(5.)))? Because the second argument has a
    larger trace_id than the former!
    """
    def __init__(self):
        # -1 means "no trace active"; the first new_trace() yields 0.
        self.top = -1

    @contextmanager
    def new_trace(self):
        """Increase the trace depth for the duration of a ``with`` block.

        Yields:
          int: the trace_id of the new, innermost trace.
        """
        self.top += 1
        try:
            yield self.top
        finally:
            # Restore the depth even if the traced function raised. Otherwise
            # the stack stays one level too deep and every later trace_id
            # shifts.
            self.top -= 1


trace_stack = TraceStack()

class Box(object):
    """Wraps a value together with its graph node and trace id."""

    # Maps a value type -- or a Box subclass -- to the Box subclass that wraps
    # it. All subclasses must accept the same __init__ arguments.
    type_mappings = {}

    # Every Box subclass registered through register().
    types = set()

    def __init__(self, value, trace_id, node):
        self._value = value
        self._node = node
        self._trace_id = trace_id

    def __bool__(self):
        # Truthiness is delegated to the wrapped value.
        wrapped = self._value
        return bool(wrapped)

    # Python 2 spelling of __bool__.
    __nonzero__ = __bool__

    def __str__(self):
        kind = type(self).__name__
        shown = str(self._value)
        return "Autograd {0} with value {1}".format(kind, shown)

    @classmethod
    def register(cls, value_type):
        """Register ``cls`` as the Box used for values of ``value_type``.

        Should be called immediately after the subclass is declared.

        Args:
          cls: a subclass of Box; the type used to box 'value_type' values.
          value_type: type to be boxed.
        """
        mappings = Box.type_mappings
        Box.types.add(cls)
        mappings[value_type] = cls
        # A Box subclass boxes itself: under nested grad() calls an already
        # boxed value gets wrapped again, and each level keeps its own graph
        # separate from the outer one.
        mappings[cls] = cls


box_type_mappings = Box.type_mappings


def new_box(value, trace_id, node):
    """Box an unboxed value.

    Args:
      value: unboxed value; its exact type() must have been registered with
        Box.register().
      trace_id: int. Trace stack depth.
      node: Node corresponding to this boxed value.
    Returns:
      Boxed value.
    Raises:
      TypeError: if no Box subclass is registered for type(value).
    """
    # Keep only the lookup inside the try. A KeyError raised by a Box
    # constructor must not be misreported as an unsupported type.
    try:
        box_class = box_type_mappings[type(value)]
    except KeyError:
        raise TypeError(
            "Can't differentiate w.r.t. type {}".format(type(value))) from None
    return box_class(value, trace_id, node)


box_types = Box.types


def isbox(x):
    """Return True if x is an instance of a registered Box subclass.

    An exact type() membership test is used, so subclasses of registered Box
    types are not recognised.
    """
    return type(x) in box_types  # almost 3X faster than isinstance(x, Box)


def getval(x):
    """Return x with every layer of boxing removed."""
    # A loop instead of recursion: deeply nested boxes cannot overflow the stack.
    while isbox(x):
        x = x._value
    return x

In [1]:
import numpy as np

class Graph():
    """Computational graph class.

    Creating a Graph makes it the current graph by assigning it to the
    module-level global ``_g``; the node classes register themselves with
    ``_g`` when constructed. Each graph consists of a set of
        1. operators
        2. variables
        3. constants
        4. placeholders
    Used as a context manager, leaving the ``with`` block resets the
    per-class naming counters.
    """
    def __init__(self):
        self.operators = set()
        self.constants = set()
        self.variables = set()
        self.placeholders = set()
        global _g
        _g = self

    def reset_counts(self, root):
        """Reset ``count`` to 0 on ``root``.

        If ``root`` has no ``count`` attribute, recurse into each of its
        subclasses instead.
        """
        if hasattr(root, 'count'):
            root.count = 0
        else:
            for child in root.__subclasses__():
                self.reset_counts(child)

    def reset_session(self):
        """Reset the naming counters of every Node subclass.

        NOTE(review): this method previously ran ``del _g`` inside a bare
        try/except without a ``global _g`` declaration. It therefore always
        raised (and swallowed) UnboundLocalError and never removed the global.
        That dead code is gone. The global graph is intentionally left in
        place, because later code reads ``_g`` after a ``with Graph()`` block
        has exited (e.g. a module-level default argument ``graph=_g``).
        """
        self.reset_counts(Node)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.reset_session()

## Define the graph nodes. 

The operator node is virtual: it is never used directly. Only its subclasses, which implement actual operations, should ever be instantiated and evaluated.

### Make a basic Node class to inherit from

This class does nothing except let us check whether an object is a graph node. It is also where the operator overloads defined below are attached.

In [2]:
class Node:
    """Marker base class shared by every computational-graph node.

    It holds no state; ``isinstance(obj, Node)`` tells graph nodes apart
    from plain Python values.
    """

    def __init__(self):
        # Nothing to initialise: subclasses set up their own attributes.
        return None

### Define Variables, Constants, Placeholders, and Operators

#### Placeholders

In [3]:
class Placeholder(Node):
    """A placeholder node in the computational graph.

    It holds no value until computation time, when one is supplied for it
    (``forward_pass`` looks the value up in ``feed_dict`` by ``name``).

    Args:
        name: node name; defaults to "Plc/"+count when omitted or None.
        dtype: the type the node is expected to hold (float, int, etc.).
            Stored as ``self.dtype`` for reference only; no check or
            conversion is performed.
    """
    count = 0

    def __init__(self, name=None, dtype=float):
        _g.placeholders.add(self)
        self.value = None
        self.gradient = None
        self.dtype = dtype
        self.name = f"Plc/{Placeholder.count}" if name is None else name
        Placeholder.count += 1

    def __repr__(self):
        return f"Placeholder: name:{self.name}, value:{self.value}"

#### Constants

In [4]:
class Constant(Node):
    """A constant node in the computational graph.

    Args:
        value: the constant's value, exposed through a read-only ``value``
            property. Assigning to ``value`` raises ValueError.
        name: node name; defaults to "Const/"+count when None.
    """
    count = 0

    def __init__(self, value, name=None):
        _g.constants.add(self)
        self._value = value
        self.gradient = None
        self.name = f"Const/{Constant.count}" if name is None else name
        Constant.count += 1

    def __repr__(self):
        return f"Constant: name:{self.name}, value:{self.value}"

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        # A property setter receives the assigned value. Without this
        # parameter, `const.value = x` raised TypeError instead of the
        # intended ValueError.
        raise ValueError("Cannot reassign constant")

#### Variables

In [5]:
class Variable(Node):
    """A variable node in the computational graph.

    Variables hold a mutable ``value`` and are registered with the current
    graph (``_g.variables``) on creation.

    Args:
        value: the initial (mutable) value.
        name: node name; defaults to "Var/"+count when None.
    """
    count = 0

    def __init__(self, value, name=None):
        _g.variables.add(self)
        self.value = value
        self.gradient = None
        if name is None:
            name = f"Var/{Variable.count}"
        self.name = name
        Variable.count += 1

    def __repr__(self):
        return f"Variable: name:{self.name}, value:{self.value}"

#### Define Operators

This way, we can implement addition, multiplication and the other operations as node classes, then overload Python operators such as `+` and `*` through dunder methods.

In [6]:
class Operator(Node):
    """Base class for operator nodes in the computational graph.

    Registers itself with the current graph (``_g.operators``). Subclasses
    set ``inputs`` and implement ``forward``/``backward``.

    Args:
        name: node name; defaults to 'Operator'.
    """
    def __init__(self, name='Operator'):
        graph_ops = _g.operators
        graph_ops.add(self)
        self.value = None
        self.inputs = []
        self.gradient = None
        self.name = name

    def __repr__(self):
        return f"Operator: name:{self.name}"

### Create some actual operators that do things

In [7]:
class add(Operator):
    """Binary addition operation: forward(a, b) = a + b."""
    count = 0

    def __init__(self, a, b, name=None):
        super().__init__(name)
        self.inputs = [a, b]
        self.name = f'add/{add.count}' if name is None else name
        add.count += 1

    def forward(self, a, b):
        return a + b

    def backward(self, a, b, dout):
        """Return (dL/da, dL/db); both partial derivatives of a + b are 1.

        NOTE(review): dout is passed through unchanged, with no reduction over
        broadcast axes. This assumes a and b have the same shape.
        """
        return dout, dout

class multiply(Operator):
    """Binary (elementwise) multiplication operation: forward(a, b) = a * b."""
    count = 0

    def __init__(self, a, b, name=None):
        super().__init__(name)
        self.inputs = [a, b]
        self.name = f'mul/{multiply.count}' if name is None else name
        multiply.count += 1

    def forward(self, a, b):
        return a * b

    def backward(self, a, b, dout):
        """Return (dL/da, dL/db) = (dout * b, dout * a).

        NOTE(review): there is no reduction over broadcast axes. This assumes
        a and b have the same shape.
        """
        return dout * b, dout * a
    
class divide(Operator):
    """Binary division operation: forward(a, b) = a / b."""
    count = 0

    def __init__(self, a, b, name=None):
        super().__init__(name)
        self.inputs = [a, b]
        self.name = f'div/{divide.count}' if name is None else name
        divide.count += 1

    def forward(self, a, b):
        return a / b

    def backward(self, a, b, dout):
        """Return (dL/da, dL/db) for a / b.

        d(a/b)/da = 1/b and d(a/b)/db = -a/b**2; note the minus sign on the
        second term.
        """
        return dout / b, -dout * a / np.power(b, 2)
    
    
class power(Operator):
    """Binary exponentiation operation: forward(a, b) = a ** b."""
    count = 0

    def __init__(self, a, b, name=None):
        super().__init__(name)
        self.inputs = [a, b]
        self.name = f'pow/{power.count}' if name is None else name
        power.count += 1

    def forward(self, a, b):
        return np.power(a, b)

    def backward(self, a, b, dout):
        """Return (dL/da, dL/db) for a ** b.

        d(a**b)/da = b * a**(b-1) and d(a**b)/db = ln(a) * a**b.

        NOTE(review): np.log(a) yields nan or -inf for a <= 0, so the gradient
        wrt the exponent is only meaningful for positive bases.
        """
        return dout * b * np.power(a, (b - 1)), dout * np.log(a) * np.power(a, b)
    
class matmul(Operator):
    """Binary matrix multiplication operation: forward(a, b) = a @ b."""
    count = 0

    def __init__(self, a, b, name=None):
        super().__init__(name)
        self.inputs = [a, b]
        self.name = f'matmul/{matmul.count}' if name is None else name
        matmul.count += 1

    def forward(self, a, b):
        return a @ b

    def backward(self, a, b, dout):
        """Return (dL/da, dL/db) = (dout @ b.T, a.T @ dout).

        NOTE(review): these formulas assume 2-D operands; for 1-D or batched
        inputs ``.T`` does not give the right transpose -- confirm before use.
        """
        return dout @ b.T, a.T @ dout

## For convenience, overload all of these

In [8]:
def node_wrapper(func, self, other):
    """Build the operator node ``func(self, other)``.

    ``other`` may be a graph node, or a plain int/float, which is first
    wrapped in a new Constant node.

    Raises:
        TypeError: if ``other`` is neither a Node nor an int/float.
    """
    if isinstance(other, Node):
        return func(self, other)
    if isinstance(other, (int, float)):
        return func(self, Constant(other))
    raise TypeError("Incompatible types.")


Node.__add__ = lambda self, other: node_wrapper(add, self, other)
Node.__mul__ = lambda self, other: node_wrapper(multiply, self, other)
# Python 3 dispatches `/` to __truediv__; __div__ is only used by Python 2.
Node.__truediv__ = lambda self, other: node_wrapper(divide, self, other)
Node.__div__ = Node.__truediv__
Node.__neg__ = lambda self: node_wrapper(multiply, self, Constant(-1))
Node.__pow__ = lambda self, other: node_wrapper(power, self, other)
Node.__matmul__ = lambda self, other: node_wrapper(matmul, self, other)
# Addition and multiplication commute, so the reflected forms (`5 + x`,
# `2 * x`) can reuse the same argument order.
Node.__radd__ = lambda self, other: node_wrapper(add, self, other)
Node.__rmul__ = lambda self, other: node_wrapper(multiply, self, other)
# Subtraction is expressed with the existing nodes: a - b == a + (-b).
Node.__sub__ = lambda self, other: self + (-other)
Node.__rsub__ = lambda self, other: (-self) + other

## Test it out!

In [9]:
with Graph() as g:
    # Two variables and an int literal (wrapped into a Constant automatically).
    x, y = Variable(1.3), Variable(0.9)
    z = x * y + 5

In [10]:
for node_set in (g.constants, g.variables, g.operators):
    print(node_set)

{Constant: name:Const/0, value:5}
{Variable: name:Var/0, value:1.3, Variable: name:Var/1, value:0.9}
{Operator: name:mul/0, Operator: name:add/0}


# Autograd

In [11]:
def topological_sort(head_node=None, graph=None):
    """Return graph nodes in topological (inputs-before-outputs) order.

    Args:
        head_node: last node in the forward pass (the "result" of the graph).
            When given, only head_node and the nodes it depends on are
            returned.
        graph: the computational graph whose operators are sorted when
            head_node is None. Defaults to the current global graph ``_g``,
            looked up at call time. (The old default ``graph=_g`` was bound
            once at definition time, so it silently kept using the first
            graph ever created.)
    Returns:
        a list of graph nodes in which every node appears after its inputs.
    """
    visited = set()
    ordering = []

    def _children(node):
        # Only operators have inputs; every other node is a leaf.
        return iter(node.inputs) if isinstance(node, Operator) else iter(())

    def _dfs(root):
        # Iterative post-order DFS. It yields the same order as the recursive
        # version but is not limited by Python's recursion depth on long
        # chains.
        if root in visited:
            return
        visited.add(root)
        stack = [(root, _children(root))]
        while stack:
            node, pending = stack[-1]
            for child in pending:
                if child not in visited:
                    visited.add(child)
                    stack.append((child, _children(child)))
                    break
            else:
                # All inputs emitted: the node itself can follow.
                stack.pop()
                ordering.append(node)

    if head_node is None:
        if graph is None:
            graph = _g
        for node in graph.operators:
            _dfs(node)
    else:
        _dfs(head_node)

    return ordering

In [12]:
def forward_pass(order, feed_dict=None):
    """Perform the forward pass and return the output of the graph.

    Args:
        order: a topologically sorted list of nodes (non-empty).
        feed_dict: a dictionary mapping placeholder names to their values.
            Defaults to an empty dict.
    Returns:
        the value of the last node in ``order``. As a side effect, every
        placeholder and operator node in ``order`` has its ``value`` filled
        in.
    Raises:
        KeyError: if a placeholder in ``order`` has no entry in feed_dict.
        IndexError: if ``order`` is empty.
    """
    # None instead of a shared mutable `{}` default.
    if feed_dict is None:
        feed_dict = {}

    for node in order:
        if isinstance(node, Placeholder):
            try:
                node.value = feed_dict[node.name]
            except KeyError:
                raise KeyError(
                    f"no value fed for placeholder {node.name!r}") from None
        elif isinstance(node, Operator):
            # Inputs come earlier in `order`, so their values are already set.
            node.value = node.forward(*[prev_node.value for prev_node in node.inputs])

    return order[-1].value

In [13]:
def backward_pass(order, target_node=None):
    """Perform the backward pass, storing d(target)/d(node) on every node.

    Args:
        order: a topologically sorted list of graph nodes whose values have
            already been computed (e.g. by forward_pass).
        target_node: the node being differentiated; its gradient is seeded
            with 1. Defaults to the last node of ``order`` and must be an
            element of ``order`` (ValueError otherwise).
    Returns:
        gradients of nodes, listed in the same order as ``order``. Nodes the
        target does not depend on get None.
    """
    if target_node is None:
        target_node = order[-1]

    # Clear gradients from any previous pass so unreached nodes report None
    # rather than a stale value.
    for node in order:
        node.gradient = None
    target_node.gradient = 1

    # Nodes after the target in topological order cannot influence it.
    stop = order.index(target_node)
    for node in reversed(order[:stop + 1]):
        # Skip operators the target does not depend on (no upstream gradient).
        if not isinstance(node, Operator) or node.gradient is None:
            continue
        inputs = node.inputs
        grads = node.backward(*[x.value for x in inputs], dout=node.gradient)
        for inp, grad in zip(inputs, grads):
            if inp.gradient is None:
                inp.gradient = grad
            else:
                # Not `+=`: backward() may hand the same array object to
                # several nodes (e.g. add returns dout twice), and an in-place
                # add would corrupt their gradients.
                inp.gradient = inp.gradient + grad
    return [node.gradient for node in order]

In [14]:
val1, val2, val3 = 0.9, 0.4, 1.3

with Graph() as g:
    # Graph under test: z = (x*y + c)*c + x.
    x, y = Variable(val1, name='x'), Variable(val2, name='y')
    c = Constant(val3, name='c')
    z = (x * y + c) * c + x

    # Sort the nodes, evaluate them, then propagate gradients back from z.
    order = topological_sort(z)
    res = forward_pass(order)
    grads = backward_pass(order)

    print("Node ordering:")
    for node in order:
        print(node)

    print('-' * 10)
    print(f"Forward pass expected: {(val1*val2+val3)*val3+val1}")
    print(f"Forward pass computed: {res}")

Node ordering:
Variable: name:x, value:0.9
Variable: name:y, value:0.4
Operator: name:mul/0
Constant: name:c, value:1.3
Operator: name:add/0
Operator: name:mul/1
Operator: name:add/1
----------
Forward pass expected: 3.0580000000000003
Forward pass computed: 3.0580000000000003


In [15]:
def _node_named(name):
    # First node in `order` carrying `name` (IndexError if there is none).
    return [a for a in order if a.name == name][0]

dzdx_node, dzdy_node, dzdc_node = map(_node_named, ('x', 'y', 'c'))

print(f"dz/dx expected = {val3*val2+1}")
print(f"dz/dx computed = {dzdx_node.gradient}")

print(f"dz/dy expected = {val1*val3}")
print(f"dz/dy computed = {dzdy_node.gradient}")

print(f"dz/dc expected = {val1*val2+2*val3}")
print(f"dz/dc computed = {dzdc_node.gradient}")

dz/dx expected = 1.52
dz/dx computed = 1.52
dz/dy expected = 1.1700000000000002
dz/dy computed = 1.1700000000000002
dz/dc expected = 2.96
dz/dc computed = 2.96
