In [2]:
import random
import numpy as np
from graphviz import Digraph

In [70]:
"""
Assumptions:
    1. Graph is acyclic
"""

"""
TODO:
    1. Forward computation when leaf nodes change.
"""


class Node:
    """A scalar in a computation graph with reverse-mode differentiation.

    Leaf nodes carry a user-supplied ``label`` and ``op == ''``. Interior
    nodes are produced by ``+`` / ``*``; they record their two operands in
    ``children`` and the operator symbol in ``op``. ``backward()`` fills in
    ``grad`` for every node reachable from the node it is called on.
    ``forward()`` recomputes ``value`` bottom-up after leaf values change.
    """

    def __init__(self, value, label='', children=None, op=''):
        """Create a node.

        Args:
            value: Numeric value held by the node.
            label: Display name. Required for leaves; when empty, a name
                is derived from the two children and ``op``.
            children: Operand nodes, either none (leaf) or exactly two.
                Defaults to a fresh empty list per node.
            op: ``'+'`` or ``'*'`` for interior nodes, ``''`` for leaves.
        """
        # A fresh list per node: the previous ``children=[]`` default was
        # one list object shared by every leaf ever created.
        children = [] if children is None else children
        assert label or len(children) == 2
        self.value = value
        self.children = children
        self.label = label if label else f"({children[0].getName()} {op} {children[1].getName()})"
        self.grad = 0.0
        self.op = op

    def getOp(self):
        return self.op

    def getValue(self):
        return self.value

    def getGrad(self):
        return self.grad

    def __str__(self):
        return self.getName()

    def __repr__(self):
        return self.getName()

    def getLabel(self):
        """Return the record label used by ``draw``: name, value and gradient."""
        return  "{} | data: {:.4f} | grad:  {:.4f}".format(self.label, self.getValue(), self.getGrad())

    def getName(self):
        return self.label

    def draw(self):
        """Return a graphviz ``Digraph`` of this node's computation graph.

        Every reachable node is drawn as a record (label, value, gradient),
        with an edge from each operand to the node it feeds.
        """
        dot = Digraph(self.getLabel(), graph_attr={'rankdir': 'LR'})
        for node in self.top_sort():
            # Key graphviz nodes by object identity, not label: distinct
            # nodes can share a label (e.g. two constants ``2.0``), and
            # keying by label would merge them into one drawn node.
            node_id = str(id(node))
            dot.node(name=node_id, label=node.getLabel(), shape='record')
            for child in node.children:
                dot.edge(str(id(child)), node_id)
        return dot

    def top_sort(self):
        """Return all nodes reachable from ``self``, each after its children."""
        nodes = []
        visited = set()
        self.dfs(nodes, visited)
        return nodes

    def dfs(self, nodes=None, visited=None):
        """Append nodes reachable from ``self`` to ``nodes`` in post-order.

        Nodes already in ``visited`` are skipped. Both containers are
        created fresh when omitted. The traversal uses an explicit stack, so
        deep graphs (long chained sums) do not hit Python's recursion limit.
        The visiting order is the same as the natural recursive post-order.

        Returns:
            The ``nodes`` list (also mutated in place).
        """
        nodes = [] if nodes is None else nodes
        visited = set() if visited is None else visited
        visited.add(self)
        # Each frame is (node, iterator over its not-yet-examined children).
        stack = [(self, iter(self.children))]
        while stack:
            node, pending = stack[-1]
            for child in pending:
                if child not in visited:
                    visited.add(child)
                    stack.append((child, iter(child.children)))
                    break
            else:
                # All children finished: emit the node after its subtree.
                stack.pop()
                nodes.append(node)
        return nodes

    def backward(self):
        """Compute d(self)/d(node) into ``grad`` for every reachable node.

        Gradients are reset to 0 first, so repeated calls do not accumulate.
        """
        nodes = self.top_sort()
        for node in nodes:
            node.grad = 0.0
        self.grad = 1.0
        # Reverse topological order: a node's grad is complete before it is
        # propagated to its operands.
        for node in reversed(nodes):
            if len(node.children) == 2:
                [left, right] = node.children
                left.setGrad(right, node)
                right.setGrad(left, node)

    def forward(self):
        """Recompute ``value`` for every reachable node, children first."""
        nodes = self.top_sort()
        for node in nodes:
            node.compute()

    def compute(self):
        """Recompute ``value`` from the children's values; leaves are no-ops.

        Raises:
            ValueError: if ``op`` is not ``''``, ``'+'`` or ``'*'``.
        """
        if self.op == '':
            return
        assert len(self.children) == 2
        left, right = self.children
        if self.op == '+':
            self.value = left.getValue() + right.getValue()
        elif self.op == '*':
            self.value = left.getValue() * right.getValue()
        else:
            raise ValueError(f"unsupported op {self.op!r}")

    def setGrad(self, other=None, parent=None):
        """Accumulate this node's share of ``parent``'s gradient.

        Args:
            other: The sibling operand of ``parent`` (used for ``*``).
            parent: The node this one feeds. When ``None``, this node is
                treated as the output and its grad is set to 1.0.
        """
        if parent is None:
            self.grad = 1.0
            return
        if parent.getOp() == '+':  # Plus acts like a pass-through.
            self.grad += parent.getGrad()
        elif parent.getOp() == '*':
            # d(x*y)/dx = y; ``+=`` so a node used twice (x*x) sums both paths.
            self.grad += other.getValue() * parent.getGrad()

    @staticmethod
    def _as_node(x):
        """Return ``x`` if it is a Node, else wrap it as a constant leaf."""
        return x if isinstance(x, Node) else Node(x, str(x))

    def __add__(self, _other):
        other = Node._as_node(_other)
        return Node(self.getValue() + other.getValue(), label='', children=[self, other], op='+')

    def __radd__(self, other):
        return self + other

    def __mul__(self, other):
        # Wrap plain numbers, as ``__add__`` does, so ``x * 2`` works.
        other = Node._as_node(other)
        return Node(self.getValue() * other.getValue(), label='', children=[self, other], op='*')

    def __rmul__(self, other):
        return self * other


    
# a = Node(2.0, 'a')
# b = Node(-3.0, 'b')
# c = Node(10.0, 'c') 
# d = a*b+c
# d.backward()
# d.draw()

# a = Node(3.0, 'a')
# b = Node(3.0, 'b')
# c = Node(4.0, 'c')
# d = a*b
# e = c*d
# e.backward()
# e.draw()


# Build L = (a*b + c) * f from four leaf parameters.
a = Node(2.0, 'a')
b = Node(-3.0, 'b')
c = Node(10.0, 'c')
e = a*b
d = e+c
f = Node(-2.0, 'f')
L = d*f

# Populate values and gradients once before taking any steps.
L.forward()
L.backward()

# Gradient *ascent*: step each leaf along its gradient so L increases,
# printing the value before and after every step.
alpha = 0.01
for _ in range(50):
    old_val = L.value
    for param in (a, b, c, f):
        param.value += alpha * param.grad
    L.forward()
    L.backward()
    print(old_val, L.value)

# L.draw()
# print(L.getValue())

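# Hand-derived gradients for the second commented-out example above
# (a = 3, b = 3, c = 4, d = a*b, e = c*d):
#         de/de = 1
#         de/dd = c = 4
#         de/dc = d = 9
#         dd/da = b
#         dd/db = a
#         de/da = de/dd * dd/da = c * b = 12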


-8.0 -7.286496
-7.286496 -6.597755834428276
-6.597755834428276 -5.9314458386892275
-5.9314458386892275 -5.285081843134573
-5.285081843134573 -4.6560150657269235
-4.6560150657269235 -4.041415395804313
-4.041415395804313 -3.438251441813934
-3.438251441813934 -2.8432666409313563
-2.8432666409313563 -2.252950632362309
-2.252950632362309 -1.6635049738718164
-1.6635049738718164 -1.0708021264853518
-1.0708021264853518 -0.47033643753822213
-0.47033643753822213 0.14283439258544067
0.14283439258544067 0.7741591821893543
0.7741591821893543 1.4296766869686521
1.4296766869686521 2.116115260126322
2.116115260126322 2.8410120608999754
2.8410120608999754 3.6128558927978274
3.6128558927978274 4.4412587349741175
4.4412587349741175 5.337162280893457
5.337162280893457 6.313087398325539
6.313087398325539 7.383436481315852
7.383436481315852 8.564861322520992
8.564861322520992 9.876712588436257
9.876712588436257 11.341591496396777
11.341591496396777 12.98603023532342
12.98603023532342 14.841335545721124
14.8