# Computational Graphs - Symbolic Computation
*Christoph Heindl 2017, https://github.com/cheind/py-cgraph/*

This is part two in a series about computational graphs and their applications. The first part covered theoretical foundations of computational graphs and associated algorithms to perform forward function evaluation and backward derivative computations.

This part will focus on developing Python code that allows numeric and symbolic differentiation of arbitrary (real-valued) functions.


## CGraph 

CGraph is the name of the Python library to be developed in the remainder of this notebook. While the code inside the notebook is functional, a separate, self-contained and enhanced implementation of CGraph is available in [cgraph.py](../cgraph.py).

CGraph performs numeric and symbolic differentiation using backpropagation.

```Python
import cgraph as cg

x = cg.Symbol('x')
y = cg.Symbol('y')
z = cg.Symbol('z')

f = (x * y + 3) / (z - 2)

# Evaluate function
cg.value(f, {x:2, y:3, z:3}) # 9.0

# Partial derivatives (numerically)
d = cg.numeric_gradient(f, {x:2, y:3, z:3})
d[x] # df/dx 3.0
d[z] # df/dz -9.0

# Partial derivatives (symbolically)
d = cg.symbolic_gradient(f)
cg.simplify(d[x]) # (y*(1/(z - 2)))
cg.value(d[x], {x:2, y:3, z:3}) # 3.0

# Higher order derivatives
ddx = cg.symbolic_gradient(d[x])
cg.simplify(ddx[y]) # ddf/dxdy
# (1/(z - 2))
```

Python 3.5 will be used for development. The reader is assumed to be familiar with its concepts, including generators and decorators. We will also use a technique called monkey patching to iteratively refine classes that have already been introduced.
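Monkey patching simply means assigning a new attribute (usually a function) to a class after the class has been defined. A minimal sketch, using a hypothetical `Vec` class that is not part of CGraph:

```Python
class Vec:
    """Hypothetical example class, unrelated to CGraph."""
    def __init__(self, x, y):
        self.x, self.y = x, y

v = Vec(3, 4)

# Later (e.g. in another notebook cell), attach a new method to the class
def norm(self):
    return (self.x ** 2 + self.y ** 2) ** 0.5

Vec.norm = norm

print(v.norm())  # 5.0, although v was created before the patch
```

Because Python looks up methods on the class at call time, instances created before the patch see the new method too. This is what allows adding behaviour to the node classes cell by cell.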

### Expression trees

Before diving into code, we need to cover the concept of [expression trees](https://en.wikipedia.org/wiki/Binary_expression_tree). Expression trees will be used to represent function decompositions in the Python code to be developed. While they are not a fundamentally new concept, they deserve a few words at this point.

An expression tree is similar to the computational graphs introduced in part one, but its arrows point backwards by default. It turns out that constructing function expressions in a tree-like manner (the top node is the function itself, the function parameters are leaves) simplifies development dramatically.

Take the computational graph (CG) of the toy example $f(x,y)=(x+y)x$ used previously:

<img src="intro_0.png" width="400">

The following expression tree represents the same function:

<img src="exp_tree.png" width="400">

Notice that we now have a tree-like structure. The root node is the final operation to be executed to obtain the result of $f(x,y)$. Also notice that $x$ shows up twice. Finding the value of an expression tree requires computing the values of nodes in lower layers first and bubbling information up towards the root node. Backpropagation, on the other hand, is just a matter of following the tree's edges from the root downwards. Again, when computing derivatives, we sum over all paths from the root that lead to a given node.
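As a worked example of the path summation, write $a = x + y$ for the addition node (a label introduced here only for illustration), so that $f = a\,x$. The leaf $x$ is reached from the root along two paths, via $a$ and directly, while $y$ is reached only via $a$:

$$
\frac{\partial f}{\partial x} = \frac{\partial f}{\partial a}\frac{\partial a}{\partial x} + \frac{\partial f}{\partial x}\bigg|_{\text{direct}} = x \cdot 1 + a = 2x + y,
\qquad
\frac{\partial f}{\partial y} = \frac{\partial f}{\partial a}\frac{\partial a}{\partial y} = x \cdot 1 = x
$$

Expanding $f = x^2 + xy$ and differentiating directly gives the same result.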

### Representing expression trees
First, we need a way to represent expression trees in Python code. Naturally, we will have a base class `Node` that manages child references, and derived classes that implement the actual operations, symbols and constants.

```Python
class Node:

    def __init__(self, nary=0):
        self.children = [None]*nary

    def __repr__(self):
        return self.__str__()
```

For now, `Node` just tracks references to its children. Note that operations can be binary (e.g. addition) or unary (e.g. cosine), while leaf nodes (e.g. symbols) have no children at all. We could also think of n-ary functions such as summation. Next, we'll define the leaf nodes `Symbol` and `Constant`.

```Python
class Symbol(Node):

    def __init__(self, name):
        super(Symbol, self).__init__(nary=0)
        self.name = name

    def __str__(self):
        return self.name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.name == other.name
        else:
            return False
```

Symbols are identified by their name, such as $x$. They don't have any children, and printing a symbol prints its name.
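Defining `__eq__` and `__hash__` in terms of the name matters once symbols are used as dictionary keys, as in the argument dictionaries passed to the evaluation functions later on: two distinct `Symbol` objects named `x` then refer to the same entry. Both methods are needed, because Python 3 sets `__hash__` to `None` for a class that overrides `__eq__` without defining `__hash__`. The sketch below uses a stripped-down stand-in, hypothetically named `Sym`, with the same comparison logic as `Symbol`, so it runs without the rest of the notebook:

```Python
class Sym:
    """Hypothetical stand-in mirroring Symbol's name-based __eq__/__hash__."""
    def __init__(self, name):
        self.name = name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.name == other.name
        return False

a, b = Sym('x'), Sym('x')
print(a is b)   # False: two distinct objects
print(a == b)   # True: equal because the names match

env = {a: 2}
print(env[b])   # 2: b finds the entry stored under a
```

Other node types such as `Add`, `Mul` and `Constant` do not override `__eq__`, so they keep Python's default identity-based equality and hashing.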

```Python
class Constant(Node):

    def __init__(self, value):
        super(Constant, self).__init__(nary=0)
        self.value = value

    def __str__(self):
        return str(self.value)
```

Constants are 'immutable' values. Next, we start adding operations. In this notebook we will only provide addition and multiplication; cgraph.py defines more operations, and once you know how these two are implemented, adding new ones is easy.

```Python
class Add(Node):

    def __init__(self):
        super(Add, self).__init__(nary=2)

    def __str__(self):
        return '({}+{})'.format(str(self.children[0]), str(self.children[1]))

class Mul(Node):

    def __init__(self):
        super(Mul, self).__init__(nary=2)

    def __str__(self):
        return '({}*{})'.format(str(self.children[0]), str(self.children[1]))
```

`Add` and `Mul` don't do much yet, except declare that they are binary functions and provide some pretty printing (recursively calling `__str__` on their children). Next, we'll write a helper function that builds our toy function $f(x,y)=(x+y)x$. This looks a bit clumsy right now, but we'll improve the syntax as we go.

```Python
def gen_f(x, y):
    a = Add()
    a.children[0] = x
    a.children[1] = y

    m = Mul()
    m.children[0] = a
    m.children[1] = x

    return m

x = Symbol('x')
y = Symbol('y')
f = gen_f(x, y)
f  # ((x+y)*x)
```

### Computing function values

Next, we'll turn our attention to computing the values of functions represented as expression trees. As mentioned earlier, computing the value requires bubbling information up from the lower layers of the hierarchy to the root. Expression trees can be traversed in multiple ways; what we are looking for is [depth-first search](https://en.wikipedia.org/wiki/Depth-first_search) in [post-order](https://en.wikipedia.org/wiki/Tree_traversal), which visits all children of a node before the node itself. There are many ways to implement this traversal; I've chosen the recursive generator approach for its brevity.

```Python
def postorder(node):
    for c in node.children:
        yield from postorder(c)
    yield node
```
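The recursive generator nests one generator per tree level, so very deep trees can run into Python's recursion limit. For comparison, here is a sketch of an iterative post-order traversal that uses an explicit stack instead. The name `postorder_iter` is hypothetical, and the toy tree is built from `types.SimpleNamespace` objects that only carry a `label` and a `children` list:

```Python
from types import SimpleNamespace

def postorder_iter(node):
    """Iterative post-order DFS over any object exposing a `children` list."""
    stack = [(node, False)]
    while stack:
        n, children_done = stack.pop()
        if children_done:
            # All children were emitted already, so the parent comes next
            yield n
        else:
            # Revisit n after its children; push children reversed so the
            # leftmost child is popped (and thus emitted) first
            stack.append((n, True))
            for c in reversed(n.children):
                stack.append((c, False))

# Toy tree for (x+y)*x; the shared leaf is visited twice, as with postorder
nx = SimpleNamespace(label='x', children=[])
ny = SimpleNamespace(label='y', children=[])
add_node = SimpleNamespace(label='+', children=[nx, ny])
mul_node = SimpleNamespace(label='*', children=[add_node, nx])

print([n.label for n in postorder_iter(mul_node)])  # ['x', 'y', '+', 'x', '*']
```

The recursive version remains the better fit for this notebook because it is shorter and the trees built here are shallow.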

```Python
[n for n in postorder(f)]  # [x, y, (x+y), x, ((x+y)*x)]
```

As you can see, children are visited before their parents, which is exactly what's needed for computing function values. Next, we define the function that computes the forward pass, i.e. the value of the function.

```Python
def values(f, fargs):
    """Returns a dictionary of computed values for each node in the expression tree including `f`."""
    v = {}
    v.update(fargs)
    for n in postorder(f):
        if n not in v:
            v[n] = n.compute_value(v)
    return v
```

This function calls `compute_value(values)` on each node that does not have a value yet and expects the node to return its value. Since we haven't defined this method for our nodes yet, it's time to do so. Note also that `fargs` is assumed to contain the values of all symbols in the expression tree.

```Python
# Monkey patching for compute_value
Symbol.compute_value = lambda self, values: values[self]
Constant.compute_value = lambda self, values: self.value
Add.compute_value = lambda self, values: values[self.children[0]] + values[self.children[1]]
Mul.compute_value = lambda self, values: values[self.children[0]] * values[self.children[1]]
```

After monkey patching `compute_value` into all node classes, we can evaluate `f` with:

```Python
values(f, {x:2, y:3})[f]  # 10
```
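Tracing the post-order visit by hand for $x=2$, $y=3$ shows which entries end up in the dictionary returned by `values` (the notation $v_n$ for the value stored under node $n$ is used only for this illustration). Both symbols are already supplied by `fargs`, so only the two operation nodes are actually computed, and the second visit of $x$ is skipped because its value is already present:

$$
v_x = 2,\quad v_y = 3,\quad v_{x+y} = v_x + v_y = 5,\quad v_f = v_{x+y}\, v_x = 5 \cdot 2 = 10
$$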

Since `values` computes the values of all nodes, including intermediate ones, we need to index its result with `[f]`. Accessing just the value of `f` is such a common task, however, that we provide a shortcut for it named `value`.

```Python
def value(f, fargs):
    return values(f, fargs)[f]

value(f, {x:2, y:3})  # 10
```

### Syntactic sugar

Before continuing, it makes sense to use Python's special methods to overload the `+` and `*` operators for nodes. First, we'll define a decorator that wraps plain numbers in `Constant` objects.

```Python
from numbers import Number

def wrap_args(func):
    """Wraps function arguments that are numbers as Constant objects."""
    def wrapped(*args, **kwargs):
        new_args = []
        for a in args:
            if isinstance(a, Number):
                a = Constant(a)
            new_args.append(a)
        return func(*new_args, **kwargs)
    return wrapped
```
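To see the decorator in isolation, here is a sketch that applies the same idea to a toy function. `Const`, `wrap_numbers` and `describe` are hypothetical stand-ins (for `Constant` and `wrap_args`) so that the block runs on its own. Unlike `wrap_args`, it also uses `functools.wraps`, so the decorated function keeps its name and docstring:

```Python
import functools
from numbers import Number

class Const:
    """Hypothetical stand-in for the notebook's Constant node."""
    def __init__(self, value):
        self.value = value

def wrap_numbers(func):
    """Replace positional number arguments by Const objects before calling func."""
    @functools.wraps(func)  # keep func.__name__ and __doc__ on the wrapper
    def wrapped(*args, **kwargs):
        new_args = [Const(a) if isinstance(a, Number) else a for a in args]
        return func(*new_args, **kwargs)
    return wrapped

@wrap_numbers
def describe(a, b):
    """Report the argument types that actually reach the function."""
    return type(a).__name__, type(b).__name__

print(describe(3, 'x'))     # ('Const', 'str')
print(describe(2.5, True))  # ('Const', 'Const'): bool is a Number subclass too
print(describe.__name__)    # describe
```

As in `wrap_args`, keyword arguments are passed through unchanged.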

Next, we'll define free functions that perform the 'lengthy' construction of addition and multiplication nodes. By convention, these free functions start with the prefix `sym_` (for symbolic). When adding new functionality, you should always provide such a function (e.g. `sym_pow`, `sym_cos`).

```Python
@wrap_args
def sym_add(x, y):
    n = Add()
    n.children[0] = x
    n.children[1] = y
    return n

@wrap_args
def sym_mul(x, y):
    n = Mul()
    n.children[0] = x
    n.children[1] = y
    return n
```

Finally, we monkey patch `Node` to support the `+` and `*` operators:

```Python
Node.__add__ = lambda self, other: sym_add(self, other)
Node.__radd__ = lambda self, other: sym_add(other, self)
Node.__mul__ = lambda self, other: sym_mul(self, other)
Node.__rmul__ = lambda self, other: sym_mul(other, self)
```

Note that the reflected `__r*` methods are also overloaded, so that expressions of the form `n*3` and `3*n` work equally well. With that, we can rewrite `gen_f` simply as

```Python
f = (x + y)*x
f  # ((x+y)*x)
```
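The reflected methods are needed because, for `3 + n`, Python first calls `int.__add__(3, n)`, which returns `NotImplemented` since `int` knows nothing about nodes, and only then falls back to `n.__radd__(3)`. A minimal sketch with a hypothetical `Expr` class that only records the operand order as text:

```Python
class Expr:
    """Hypothetical toy expression that only records operand order as text."""
    def __init__(self, text):
        self.text = text

    def __add__(self, other):
        # Handles `self + other`
        return Expr('({}+{})'.format(self, other))

    def __radd__(self, other):
        # Handles `other + self` once other.__add__ has returned NotImplemented;
        # `other` is written first to preserve the original operand order
        return Expr('({}+{})'.format(other, self))

    def __str__(self):
        return self.text

e = Expr('n')
print(int.__add__(3, e))  # NotImplemented
print(e + 3)              # (n+3)
print(3 + e)              # (3+n)
```

Writing `other` first in `__radd__` mirrors `sym_add(other, self)` above, so the resulting tree keeps the operands in the order they were written.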