In [664]:
import numpy as np

class Identity:
    """Pass-through node function: values and cotangents go through untouched."""

    def __init__(self):
        """Nothing to set up; the function carries no state."""

    def forward_fct(self, input, weights):
        """Return ``input`` itself; ``weights`` is accepted but ignored."""
        output = input
        return output

    def vjp_fct(self, input, weights, vec):
        """Return ``vec`` itself; ``input`` and ``weights`` are ignored."""
        pulled_back = vec
        return pulled_back

class Linear:
    """Linear map node: the output is ``np.dot(weights, input)``."""

    def __init__(self):
        """Nothing to set up; the weights are handed to every call."""

    def forward_fct(self, input, weights):
        """Apply the linear map, computed as ``np.dot(weights, input)``."""
        output = np.dot(weights, input)
        return output

    def vjp_fct(self, input, weights, vec):
        """Pull ``vec`` back through the map, computed as ``np.dot(vec, weights)``.

        ``input`` is not used by this computation.
        """
        pulled_back = np.dot(vec, weights)
        return pulled_back

class Tanh:
    """Element-wise hyperbolic tangent node."""

    def __init__(self):
        """Nothing to set up; the function carries no state."""

    def forward_fct(self, input, weights):
        """Return ``np.tanh(input)``; ``weights`` is accepted but ignored."""
        return np.tanh(input)

    def vjp_fct(self, input, weights, vec):
        """Pull ``vec`` back through tanh: ``vec * sech(input)**2``.

        The derivative of tanh is sech^2.  Evaluating it as
        ``1 / cosh(input)**2`` overflows once ``cosh(input)**2`` leaves the
        floating-point range (a RuntimeWarning, or an exception under
        ``np.errstate(over='raise')``), even though the true value is just
        very close to zero.  The form used here only ever exponentiates a
        number with non-positive real part, so it cannot overflow.

        Args:
            input: Value that was fed to ``forward_fct``.
            weights: Ignored.
            vec: Cotangent to pull back; multiplied element-wise.

        Returns:
            ``vec`` scaled element-wise by ``sech(input)**2``.
        """
        x = np.asarray(input)
        # sech is even, so evaluate it at the mirror image with a
        # non-negative real part; exp(-folded) then has magnitude <= 1.
        folded = np.where(np.real(x) < 0, -x, x)
        decay = np.exp(-folded)
        # sech(z) = 2 / (e^z + e^-z) = 2 e^-z / (1 + e^-2z)
        sech = 2. * decay / (1. + decay * decay)
        sech2 = sech * sech
        return vec * sech2


In [665]:
d = 13
w = np.random.normal(size=(d, d))
x = np.random.normal(size=(d,))

# Quick sanity checks of the forward functions against plain numpy.
assert np.array_equal(Linear().forward_fct(x, w), np.dot(w, x))
assert np.allclose(Tanh().forward_fct(np.pi, None), np.tanh(np.pi))

In [666]:
class CompGraphNode:
    """One vertex of a computational graph.

    Bundles a function object (``fct``) with its ``weights`` and a ``name``,
    plus the links to neighbouring nodes.  Both link lists start empty and
    are filled in by whoever wires the graph together.
    """

    def __init__(self, fct, name, weights):
        self.name = name
        self.fct = fct
        self.weights = weights
        # No input recorded yet; prop_forward in this file stores the value
        # fed to the node here before evaluating it.
        self.input = None
        # children: nodes that are called by this one;
        # parents: nodes that call this one.
        self.children, self.parents = [], []

In [667]:
# Example graph: a two-layer tanh network laid out as a chain.
# X -> Linear -> Tanh -> Linear -> Tanh
# X ->  A(X)  -> B(a) -> C(b)   -> D(c)

d_in, d_hid, d_out = 7, 13, 1
w_A = np.random.normal(size=(d_hid, d_in))
w_C = np.random.normal(size=(d_out, d_hid))

# Swap Tanh() for Identity() on B and/or D to drop the non-linearity.
start_node = CompGraphNode(Identity(), 'start', None)
node_A = CompGraphNode(Linear(), 'A', w_A)
node_B = CompGraphNode(Tanh(), 'B', None)
node_C = CompGraphNode(Linear(), 'C', w_C)
node_D = CompGraphNode(Tanh(), 'D', None)

# Topology: link every node to its successor, in both directions.
chain = [start_node, node_A, node_B, node_C, node_D]
for upstream, downstream in zip(chain, chain[1:]):
    upstream.children.append(downstream)
    downstream.parents.append(upstream)

In [668]:
def topo_sort(node):
    """Return every node reachable from ``node`` in depth-first post-order.

    A node is appended only after all of its ``children`` have been
    appended, so ``node`` itself is always the last element.  Children are
    explored in list order and each node appears exactly once.

    The traversal keeps an explicit stack instead of recursing, so the depth
    of the graph is not bounded by Python's recursion limit.

    Args:
        node: Start node; it and everything reachable from it must be
            hashable and expose a ``children`` iterable.

    Returns:
        List of nodes in post-order.
    """
    topo = []
    visited = {node}
    # Each stack entry pairs a node with an iterator over the children that
    # still have to be explored for it.
    stack = [(node, iter(node.children))]
    while stack:
        current, pending = stack[-1]
        for child in pending:
            if child not in visited:
                visited.add(child)
                stack.append((child, iter(child.children)))
                # Descend first; the partially consumed iterator resumes
                # where it left off once this child is finished.
                break
        else:
            # All children handled: emit the node and go back up.
            topo.append(current)
            stack.pop()
    return topo

In [669]:
# for n in topo_sort(start_node):
#     print(f"node {n.name}")

In [670]:
def prop_forward(head_node, x):
    """Evaluate the graph starting at ``head_node`` on input ``x``.

    Every visited node gets a copy of the value it was fed stored in its
    ``input`` attribute, so it is available to a later backward pass.

    The head node is evaluated and recorded like every other node.  Skipping
    it (as an earlier version did) only gives consistent results when the
    head is an identity, because prop_backward pulls the VJP through the
    head node as well.

    NOTE(review): when a node has several children the value is threaded
    through them one after another and the walk continues from the last
    one, so only chain-shaped graphs look meaningful here.

    Args:
        head_node: First node of the graph.
        x: Input value; it is copied, never modified.

    Returns:
        The output of the last node that was evaluated.
    """
    y = np.copy(x)
    # Evaluate the head itself and remember what it was fed.
    head_node.input = np.copy(y)
    y = head_node.fct.forward_fct(y, head_node.weights)

    node = head_node
    while node.children:
        for child in node.children:
            child.input = np.copy(y)
            y = child.fct.forward_fct(y, child.weights)
            node = child
    return y

def prop_backward(head_node, y):
    """Chain the vector-Jacobian products of all nodes reachable from ``head_node``.

    The nodes are visited in the order returned by ``topo_sort`` and each
    one maps the running cotangent through its ``vjp_fct``, using the
    ``input`` stored on the node and the node's ``weights``.

    The cotangent is always seeded with ``np.array([1.])``; the ``y``
    argument is currently not used.

    Returns:
        The cotangent after the last node has been processed.
    """
    cotangent = np.array([1.])
    for current in topo_sort(head_node):
        cotangent = current.fct.vjp_fct(current.input, current.weights, cotangent)
    return cotangent

In [671]:
# Draw a random input and push it through the hand-built graph.
x = np.random.normal(size=(d_in,))
y_bg = prop_forward(head_node=start_node, x=x)

In [672]:
# Gradient of the graph output with respect to x, from the hand-built VJPs.
dydx_bg = prop_backward(head_node=start_node, y=y_bg)

In [673]:
import autograd as ag

In [674]:
def test_function(input):
    """Reference network written with autograd's numpy wrapper.

    Computes ``tanh(w_C . tanh(w_A . input))`` using the module-level
    weights ``w_A`` and ``w_C``.
    """
    anp = ag.numpy
    hidden = anp.tanh(anp.dot(w_A, input))
    return anp.tanh(anp.dot(w_C, hidden))

In [675]:
# Reference value and gradient computed by autograd.
y_ag = test_function(x)
reference_grad = ag.grad(test_function)
dydx_ag = reference_grad(x)

In [676]:
# Forward values of the two implementations must agree.
atol = 1e-5
values_match = np.allclose(y_bg, y_ag, atol=atol)
assert values_match

In [677]:
# Gradients of the two implementations must agree as well.
grads_match = np.allclose(dydx_bg, dydx_ag, atol=atol)
assert grads_match

In [678]:
dydx_bg

array([ 0.0052105 ,  0.01403135, -0.00236915, -0.00395498, -0.00347078,
       -0.0040526 ,  0.00373018])

In [679]:
dydx_ag

array([ 0.0052105 ,  0.01403135, -0.00236915, -0.00395498, -0.00347078,
       -0.0040526 ,  0.00373018])