In [1]:
import numpy as np
import gradflow.comp_graph as cg

import autograd as ag

In [2]:
d_in = 7
d_out = 1
d_hid = 13

B = 17 # batch size
w_A_np = np.random.normal(size=(d_in, d_hid))
w_C_np = np.random.normal(size=(d_hid, d_out))
x_np = np.random.normal(size=(B, d_in))
omega_true = np.random.normal(size=(d_in, 1))
def true_fct(z):
    return np.cos(np.dot(z, omega_true))
y_true = np.array([true_fct(z) for z in x_np])

x = cg.Value("x", x_np) # input vector

w_A = cg.Value("w_A", w_A_np) # params in linear layer
linear1 = cg.Dot("linear1")

act1 = cg.Tanh("act1")

w_C = cg.Value("w_C", w_C_np) # params in a linear layer
linear2 = cg.Dot("linear2")

# loss function comput node
loss = cg.MSELoss("mse")

# we use calls to forward pass to define the function topology
x.forward()
w_A.forward()
w_C.forward()
y = linear1.forward(x, w_A)
y = act1.forward(linear1)
y = linear2.forward(act1, w_C)

# finally we pass inputs through a loss function that 
# returns a scalar
l = loss.forward(linear2, y_true)

In [3]:
def test_embed_function(input):
    z = ag.numpy.dot(input, w_A_np)
    z = ag.numpy.tanh(z)
    z = ag.numpy.dot(z, w_C_np)
    return z

def test_loss_function(input):
    z = test_embed_function(input)
    l = ag.numpy.mean( (z - y_true)**2. , keepdims=True)
    return l

In [4]:
atol = 1e-5

y_ag = test_embed_function(x_np)
l_ag = test_loss_function(x_np)
assert np.allclose(y, y_ag, atol=atol)
assert np.allclose(l, l_ag, atol=atol)

In [5]:
my_graph = cg.Graph("my_fun", [loss])
my_graph.backward()

In [7]:
ag_grad_fn = ag.elementwise_grad(test_loss_function)
dl_dx_ag = ag_grad_fn(x_np)

assert np.allclose(x.d_out, dl_dx_ag, atol=atol)

In [8]:
x.d_out.shape

(17, 7)

In [9]:
dl_dx_ag.shape

(17, 7)

In [10]:
x.d_out

array([[ 3.58767055e-02,  4.33668953e-01, -2.14073174e-01,
         3.39136650e-01,  1.38235952e-01,  2.24035787e-01,
         1.77872455e-01],
       [ 2.03627101e-01,  1.60989384e+00, -1.05404200e-01,
         4.95893915e-01,  5.20818440e-01,  4.93407717e-01,
         1.42524287e+00],
       [-1.91649967e-02, -1.69900280e-01, -3.61757251e-02,
        -1.03548954e-01,  1.57852559e-02, -2.03279658e-01,
        -3.10751027e-02],
       [ 2.32922242e-01, -1.29817930e+00,  4.04538472e-01,
        -6.10956506e-01, -5.77378780e-01, -7.17980264e-01,
        -1.30869977e+00],
       [-1.55031884e-01,  5.85983935e-01,  4.07419486e-01,
         1.54504465e-01,  5.20052895e-01,  1.93103915e-02,
         1.26545315e+00],
       [-2.61507373e-01, -5.65959511e-01,  4.60237749e-01,
        -4.96455128e-01, -2.20096398e-01, -3.19849182e-01,
        -9.36647949e-01],
       [-1.18756977e+00,  1.05902265e+00, -1.17751263e+00,
         1.16205000e+00,  7.94625309e-01,  2.23744876e-01,
         1.1687716

In [11]:
dl_dx_ag

array([[ 3.58767055e-02,  4.33668953e-01, -2.14073174e-01,
         3.39136650e-01,  1.38235952e-01,  2.24035787e-01,
         1.77872455e-01],
       [ 2.03627101e-01,  1.60989384e+00, -1.05404200e-01,
         4.95893915e-01,  5.20818440e-01,  4.93407717e-01,
         1.42524287e+00],
       [-1.91649967e-02, -1.69900280e-01, -3.61757251e-02,
        -1.03548954e-01,  1.57852559e-02, -2.03279658e-01,
        -3.10751027e-02],
       [ 2.32922242e-01, -1.29817930e+00,  4.04538472e-01,
        -6.10956506e-01, -5.77378780e-01, -7.17980264e-01,
        -1.30869977e+00],
       [-1.55031884e-01,  5.85983935e-01,  4.07419486e-01,
         1.54504465e-01,  5.20052895e-01,  1.93103915e-02,
         1.26545315e+00],
       [-2.61507373e-01, -5.65959511e-01,  4.60237749e-01,
        -4.96455128e-01, -2.20096398e-01, -3.19849182e-01,
        -9.36647949e-01],
       [-1.18756977e+00,  1.05902265e+00, -1.17751263e+00,
         1.16205000e+00,  7.94625309e-01,  2.23744876e-01,
         1.1687716