In [67]:
# we need to define our function primitives
# for now, we will define a tanh() function
# and a linear transformation function

import numpy as np

class Identity():
    def __init__(self):
        pass

    def forward_fct(self, input, weights):
        return input
    
    def vjp_fct(self, input, weights, vec):
        return vec

class Linear():
    def __init__(self):
        pass
    
    def forward_fct(self, input, weights):
        return np.dot(weights, input)

    def vjp_fct(self, input, weights, vec):
        return np.dot(vec, weights)

class Tanh():
    def __init__(self):
        pass

    def forward_fct(self, input, weights):
        return np.tanh(input)
    
    def vjp_fct(self, input, weights, vec):
        sech2 = 1./ np.cosh(input)**2.
        return vec * sech2


In [68]:
d = 13
w = np.random.normal(size=d**2).reshape((d, d))
x = np.random.normal(size=d)

assert (Linear().forward_fct(x, w) == np.dot(w, x)).all()
assert np.allclose(Tanh().forward_fct(np.pi, None), np.tanh(np.pi))

In [69]:
class CompGraphNode:
    def __init__(self, fct, name, weights):
        self.parents = [] # comp graph nodes that call this function
        self.children = [] # comp graph nodes that are called by this function
        self.fct = fct
        self.name = name
        self.input = None
        self.weights = weights

In [70]:
# let's build a simple example CG
# X -> Linear -> Tanh -> Linear -> Tanh
# X ->  A(X)  -> B(a) - > C(b)  -> D(c)

d_in = 11
d_out = 1
d_hid = 23
w_A = np.random.normal(size=(d_hid, d_in))
w_C = np.random.normal(size=(d_out, d_hid))

start_node = CompGraphNode(Identity(), 'start', None)
node_A = CompGraphNode(Linear(), 'A', w_A)
node_B = CompGraphNode(Tanh(), 'B', None)
node_C = CompGraphNode(Linear(), 'C', w_C)
node_D = CompGraphNode(Tanh(), 'D', None)

# this defines the topology
start_node.children.append(node_A)
node_A.parents.append(start_node)

node_A.children.append(node_B)
node_B.parents.append(node_A)

node_B.children.append(node_C)
node_C.parents.append(node_B)

node_C.children.append(node_D)
node_D.parents.append(node_C)

In [71]:
def topo_sort(node):
    topo = []
    visited = set()
    def build_topo(v):
        if v not in visited:
            visited.add(v)
            for child in v.children:
                build_topo(child)
            topo.append(v)
    build_topo(node)
    return topo

In [72]:
for n in topo_sort(start_node):
    print(f"node {n.name}")

node D
node C
node B
node A
node start


In [73]:
def prop_forward(head_node, x):
    node = head_node
    y = x.copy()
    while node.children:
        children = node.children
        for child in children:
            child.input = y.copy()
            y = child.fct.forward_fct(y, child.weights)
            node = child
    return y

def prop_backward(head_node, y):
    # get the topologically sorted nodes
    nodes = topo_sort(head_node)
    # init the VJP as y, the output of the function
    vjp = y.copy()
    # for each node, prop the VJP 
    for node in nodes:
        vjp = node.fct.vjp_fct(node.input, node.weights, vjp)
    return vjp

In [74]:
x = np.random.normal(size=d_in)
y_bg = prop_forward(start_node, x)
dydx_bg = prop_backward(start_node, y_bg)

In [75]:
import autograd as ag

In [76]:
def test_function(input):
    z = ag.numpy.dot(w_A, input)
    z = ag.numpy.tanh(z)
    z = ag.numpy.dot(w_C, z)
    z = ag.numpy.tanh(z)
    return z

In [77]:
y_ag = test_function(x)
dydx_ag = ag.grad(test_function)(x)

In [78]:
atol = 1e-2
assert np.allclose(y_bg, y_ag, atol=atol)
assert np.allclose(dydx_bg, dydx_ag, atol=atol)

AssertionError: 

In [79]:
dydx_bg

array([ 0.08591834,  0.26385867,  0.12385824,  0.20827776,  0.30192723,
       -0.18920137,  0.06440053, -0.29868297, -0.10433536,  0.21340459,
        0.03709308])

In [80]:
dydx_ag

array([ 0.20473063,  0.628736  ,  0.29513577,  0.49629494,  0.71944772,
       -0.45083874,  0.15345689, -0.71171712, -0.24861565,  0.50851141,
        0.08838729])