In [23]:
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe
keras = tf.keras
initializers = tf.keras.initializers

tf.__version__

'1.12.0'

In [2]:
sess = tf.InteractiveSession()

In [3]:
# Seeded NumPy random generator.
# NOTE(review): `random` is not referenced by any other visible cell, and the
# layer kernel below is drawn by a TensorFlow initializer, so this seed does
# not appear to make the notebook's numbers reproducible — confirm intent.
random = np.random.RandomState(34121)

In [4]:
# Size of the state vector.
N = 2
# Initial state: one row of N ones, converted straight to a float32 tensor.
h_0 = tf.to_float(np.ones([1, N]))

In [5]:
# Bias-free linear layer that defines the ODE dynamics (state -> state @ kernel).
# The kernel is drawn from a normal distribution with stddev 0.01.
layer = keras.layers.Dense(N, kernel_initializer=initializers.random_normal(stddev=0.01), use_bias=False)

# A bare `tf.stop_gradient(layer.weights)` statement used to follow here. It was
# dropped because it had no effect: `tf.stop_gradient` returns a new tensor
# (which was discarded) instead of modifying its argument, and the layer has
# not been built at this point, so `layer.weights` is still an empty list.

def module(state):
    """Apply the shared linear `layer` to `state` and return the result."""
    return layer(state)

In [6]:
def ODEFunc(state, t):
    """Right-hand side of the forward ODE: returns `module(state)`.

    `t` is part of the odeint callback signature but is not used, so the
    dynamics are time-independent.
    """
    return module(state)


def ODEFuncBackward(state, t):
    """Right-hand side of the time-reversed ODE: returns `-module(state)`.

    `t` is part of the odeint callback signature but is not used.
    """
    return -module(state)

In [7]:
# Forward solve: integrate ODEFunc from t=0 to t=100 starting at h_0.
# With full_output=True odeint returns the solution tensor (one slice per
# requested time point) together with a dict of solver diagnostics.
output, info = tf.contrib.integrate.odeint(
    func=ODEFunc,
    y0=h_0,
    t=[0.0, 100.0],
    rtol=1e-6,
    atol=1e-6,
    full_output=True,
)

In [8]:
# Initialize the layer's variables in the interactive session.
# This has to run after the layer has been called at least once (here: inside
# the odeint graph construction), because a Keras layer creates its weights on
# first call.
sess.run(tf.variables_initializer(layer.weights))

In [9]:
# sess.run(output)

In [10]:
# Closed-form reference for the linear ODE: the state at t = 100 is
# h_0 @ expm(100 * W), with W the layer kernel.
exp_Wt = tf.linalg.expm(layer.weights[0] * 100)

tf.matmul(h_0, exp_Wt).eval()

array([[-2.0738602,  3.1465704]], dtype=float32)

In [11]:
# sess.run(layer(tf.to_float(h_0)))

In [12]:
output

<tf.Tensor 'odeint:0' shape=(2, 1, 2) dtype=float32>

In [13]:
# Final state: the last time slice of the forward odeint solution.
h_N = output[-1, ...]
h_N

<tf.Tensor 'strided_slice:0' shape=(1, 2) dtype=float32>

In [14]:
# Scalar loss: sum of squares of the final state.
loss = tf.reduce_sum(h_N**2)
loss

<tf.Tensor 'Sum:0' shape=() dtype=float32>

In [15]:
loss.eval()

14.2017975

In [16]:
# Reference gradient of the loss w.r.t. the initial state, obtained by
# differentiating through the forward odeint graph.
dfLdh0 = tf.gradients(loss, h_0)[0]

In [17]:
dfLdh0

<tf.Tensor 'gradients/AddN:0' shape=(1, 2) dtype=float32>

In [18]:
dfLdh0.eval()

array([[12.035124, 16.368471]], dtype=float32)

In [26]:
# Backward solve: integrate the negated dynamics for the same duration,
# starting from the final state h_N, to recover the initial state.
# NOTE(review): atol=1e-12 differs from the 1e-6 used by the other odeint
# calls and is far below float32 resolution for values of order 1 — confirm
# that this is intentional.
inv_output, inv_info = tf.contrib.integrate.odeint(
    func=ODEFuncBackward,
    y0=h_N,
    t=[0.0, 100.0],
    rtol=1e-6,
    atol=1e-12,
    full_output=True,
)

In [27]:
inv_output[-1, ...].eval()

array([[1.0000006, 1.       ]], dtype=float32)

In [28]:
# Gradient of the loss w.r.t. the final state h_N.
dfLdhN = tf.gradients(loss, h_N)[0]
dfLdhN

<tf.Tensor 'gradients_2/pow_grad/Reshape:0' shape=(1, 2) dtype=float32>

In [29]:
dfLdhN.eval()

array([[-4.1477146,  6.2931433]], dtype=float32)

In [82]:
def dynamics(ht, dL):
    """Return the reversed dynamics at `ht` and the derivative of `dL`.

    Both inputs are first detached with `tf.stop_gradient`. The function then
    builds:

    * `-module(ht)`, the negated layer output, and
    * `-tf.gradients(-module(ht), ht, grad_ys=dL)[0]`, i.e. the negated
      vector-Jacobian product of `dL` with the Jacobian of the first result
      with respect to `ht`.

    NOTE(review): the ops are built while a freshly created `tf.Graph` is the
    default graph; why that is needed is not evident from this cell — confirm.

    Returns:
        A `(reversed_flow, adjoint_flow)` tuple of tensors.
    """
    state = tf.stop_gradient(ht)
    adjoint = tf.stop_gradient(dL)

    with tf.Graph().as_default():
        reversed_flow = -module(state)
        vjp = tf.gradients(reversed_flow, state, grad_ys=adjoint)[0]
        adjoint_flow = -vjp

    return reversed_flow, adjoint_flow


def ODEFuncBackwardWithGrad(state, t):
    """odeint right-hand side for the stacked (state, loss-gradient) system.

    `state` is unstacked along its first axis into the current hidden state
    and the current loss gradient; their time derivatives come from
    `dynamics` and are stacked back in the same order. `t` is unused.
    """
    h_t, a_t = tf.unstack(state)
    dh_dt, da_dt = dynamics(h_t, a_t)
    return tf.stack([dh_dt, da_dt])

In [83]:
# Wrap `dynamics` in an eager py_func with two float32 outputs.
# NOTE(review): neither result is consumed by any later visible cell (the
# odeint call below invokes `dynamics` directly) — this looks like a leftover
# experiment; confirm before relying on it.
h_output, jvp_h = tfe.py_func(dynamics, [h_N, dfLdhN], [tf.float32, tf.float32])

jvp_h

<tf.Tensor 'EagerPyFunc_14:1' shape=<unknown> dtype=float32>

In [84]:
# Backward solve of the stacked system: start from the final state h_N and the
# loss gradient at h_N, and integrate ODEFuncBackwardWithGrad over the same
# duration as the forward solve. This rebinds `inv_output` / `inv_info`.
inv_output, inv_info = tf.contrib.integrate.odeint(
    func=ODEFuncBackwardWithGrad,
    y0=[h_N, dfLdhN],
    t=[0.0, 100.0],
    rtol=1e-6,
    atol=1e-6,
    full_output=True,
)

In [85]:
# Split the last time slice of the backward solution into the recovered
# initial state and the recovered gradient w.r.t. the initial state.
h0_rec, dfLdh0_rec = tf.unstack(inv_output[-1, ...])

In [86]:
h0_rec.eval()

array([[1.0000011 , 0.99999994]], dtype=float32)

In [87]:
dfLdh0_rec.eval()

array([[12.035125, 16.368467]], dtype=float32)

In [88]:
# Same loss, but computed from the closed-form matrix-exponential solution,
# and its gradient w.r.t. h_0 as an independent reference.
loss_exp = tf.reduce_sum(tf.matmul(h_0, exp_Wt)**2)
dfLdh0_exp = tf.gradients(loss_exp, h_0)[0]

In [89]:
dfLdh0_exp.eval() - dfLdh0_rec.eval()

array([[-1.9073486e-06,  1.3351440e-05]], dtype=float32)

In [90]:
h_N, dfLdhN

(<tf.Tensor 'strided_slice:0' shape=(1, 2) dtype=float32>,
 <tf.Tensor 'gradients_2/pow_grad/Reshape:0' shape=(1, 2) dtype=float32>)

In [124]:
def pack_variables(ht, at, theta):
    """Flatten `ht`, `at` and `theta` and concatenate them into one 1-D tensor.

    The pieces appear in the result in argument order: `ht`, then `at`,
    then `theta`.
    """
    flat_parts = [tf.reshape(part, [-1]) for part in (ht, at, theta)]
    return tf.concat(flat_parts, 0)


def unpack_variables(state, n=None):
    """Split a flat state vector into `(ht, at, theta)`.

    This undoes the layout produced by `pack_variables`: the first `n`
    entries are `ht`, the next `n` are `at`, and the remaining entries are
    `theta`.

    Args:
        state: 1-D tensor laid out as `[ht, at, theta]`.
        n: size of the state vector. When omitted, the notebook-level `N` is
            used, so the slicing stays in sync with the layer size instead of
            being hard-coded to 2.

    Returns:
        A tuple `(ht, at, theta)` with shapes `[1, n]`, `[1, n]` and `[n, n]`.
    """
    if n is None:
        # Looked up at call time so that the function definition itself does
        # not depend on `N` already existing.
        n = N
    ht = tf.reshape(state[:n], [1, n])
    at = tf.reshape(state[n:2 * n], [1, n])
    theta = tf.reshape(state[2 * n:], [n, n])
    return ht, at, theta


def share_variables(func):
    """Decorator: wrap `func` in a `tf.make_template` named after `func`.

    The template's scope is created immediately (`create_scope_now_=True`).
    """
    return tf.make_template(func.__name__, func, create_scope_now_=True)


@share_variables
def module_backprop_template(ht, at):
    """Evaluate the reverse-time dynamics and both gradient terms at `ht`.

    `ht` and the layer kernel are detached with `tf.stop_gradient` before use.

    Returns:
        A 3-tuple of tensors:
        * the negated layer dynamics, `-(ht @ kernel)`;
        * the negated `tf.gradients` of that output w.r.t. `ht`, weighted by
          `at` (a vector-Jacobian product);
        * the negated `tf.gradients` of that output w.r.t. the kernel,
          weighted by `at`.
    """
    state = tf.stop_gradient(ht)
    kernel = tf.stop_gradient(layer.weights[0])
    # The matmul is spelled out by hand; the original author noted that
    # calling `module(state)` here did not work.
    reversed_flow = -tf.matmul(state, kernel)
    state_term = -tf.gradients(reversed_flow, state, grad_ys=at)[0]
    kernel_term = -tf.gradients(reversed_flow, kernel, grad_ys=at)[0]
    return reversed_flow, state_term, kernel_term


def ODEFuncBackwardFullWithGrad(state, t):
    """odeint right-hand side for the packed (state, gradient, theta) system.

    The flat `state` is unpacked, the three time derivatives are obtained from
    `module_backprop_template`, and the result is packed back into a flat
    vector. The unpacked `theta` slot and `t` are not used by the dynamics.
    """
    h_t, a_t, _ = unpack_variables(state)
    derivatives = module_backprop_template(h_t, a_t)
    return pack_variables(*derivatives)

In [125]:
pack_variables(h_N, dfLdhN, tf.zeros_like(layer.weights[0]))

<tf.Tensor 'concat_4:0' shape=(8,) dtype=float32>

In [126]:
# Evaluate the final state and its loss gradient to NumPy values and wrap
# them in fresh float32 tensors, so the tensors below are constants rather
# than outputs of the forward odeint graph.
# NOTE(review): presumably done so the backward solve does not depend on the
# forward graph — confirm.
h_last = tf.to_float(h_N.eval())
dfLdhN_last = tf.to_float(dfLdhN.eval())

In [127]:
# Backward solve of the full packed system: state, loss gradient w.r.t. the
# state, and an accumulator for the kernel gradient that starts at zero.
# With full_output=False only the solution tensor is returned.
# NOTE(review): the recorded output of this cell is a ValueError naming
# 'odeint/dense/kernel', while the cells after it show results and carry
# lower execution counts — the notebook was run out of order; re-run from a
# fresh kernel to confirm this cell works with the current definitions.
with tf.name_scope("asd"):
    inv_output= tf.contrib.integrate.odeint(
        func=ODEFuncBackwardFullWithGrad,
        y0=pack_variables(h_last, dfLdhN_last, tf.zeros_like(layer.weights[0])),
        t=[0.0, 100.0],
        rtol=1e-6,
        atol=1e-6,
        full_output=False,
    )

ValueError: Cannot compute gradient inside while loop with respect to op 'odeint/dense/kernel'. We do not support taking the gradient wrt or through the initial value of a loop variable. Gradients can be computed through loop invariants or wrt the input parameters to the loop body.

In [107]:
# Unpack the last time slice of the backward solution into the recovered
# initial state, the gradient w.r.t. the initial state, and the accumulated
# gradient w.r.t. the layer kernel.
h0_rec, dfLdh0_rec, dLdTheta_rec = unpack_variables(inv_output[-1, ...])

In [108]:
h0_rec.eval()

array([[1.000001 , 0.9999995]], dtype=float32)

In [109]:
dfLdh0_rec.eval()

array([[12.0351305, 16.368464 ]], dtype=float32)

In [110]:
dLdTheta_rec.eval()

array([[ 369.10162 ,  256.72568 ],
       [  26.418175, 2471.258   ]], dtype=float32)

In [111]:
# sess.run(module_backprop_template(h_last, dfLdhN_last))

In [112]:
# Reference kernel gradient from the closed-form matrix-exponential solution.
# NOTE(review): `loss_exp` is built here a second time with the same
# expression as in an earlier cell; the earlier tensor could be reused.
loss_exp = tf.reduce_sum(tf.matmul(h_0, exp_Wt)**2)
dfLdTheta_exp = tf.gradients(loss_exp, layer.weights[0])[0]

In [115]:
dfLdTheta_exp.eval() - dLdTheta_rec.eval()

array([[ 3.0517578e-05, -1.8920898e-03],
       [-1.3008118e-03,  4.8828125e-04]], dtype=float32)

In [114]:
# Second reference: kernel gradient obtained by differentiating the loss
# through the forward odeint graph.
tf.gradients(loss, layer.weights[0])[0].eval()

array([[ 369.09937 ,  256.72662 ],
       [  26.416153, 2471.255   ]], dtype=float32)