In [1]:
import numpy as np
import tensorflow as tf
keras = tf.keras
initializers = tf.keras.initializers

tf.__version__

'1.12.0'

In [2]:
sess = tf.InteractiveSession()

In [3]:
random = np.random.RandomState(34121)

In [4]:
N = 2
h_0 = np.ones([1, N])
h_0 = tf.to_float(h_0)

In [5]:
layer1 = keras.layers.Dense(N, activation="selu")
layer2 = keras.layers.Dense(N, activation="selu")

layers = [layer1, layer2]

def module(state):
    output = state
    for layer in layers:
        output =  layer(state)
    return output

In [6]:
def ODEFunc(state, t):
    h_t = state
    h_output = module(h_t)
    return h_output


def ODEFuncBackward(state, t):    
    h_t = state
    h_output = -module(h_t)
    return h_output

In [38]:
output, info = tf.contrib.integrate.odeint(
    func=ODEFunc,
    y0=h_0,
    t=[0.0, 1.0],
    rtol=1e-7,
    atol=1e-7,
    full_output=True,
)

In [39]:
sess.run(([tf.variables_initializer(l.weights) for l in layers]))

[None, None]

In [40]:
sess.run(output)

array([[[1.        , 1.        ]],

       [[0.7827728 , 0.07143269]]], dtype=float32)

In [41]:
output

<tf.Tensor 'odeint_7:0' shape=(2, 1, 2) dtype=float32>

In [42]:
h_N = output[-1, ...]
h_N

<tf.Tensor 'strided_slice_7:0' shape=(1, 2) dtype=float32>

In [43]:
loss = tf.reduce_sum(h_N**2)
loss

<tf.Tensor 'Sum_1:0' shape=() dtype=float32>

In [44]:
loss.eval()

0.6178359

In [45]:
dfLdh0 = tf.gradients(loss, h_0)[0]

In [46]:
dfLdh0

<tf.Tensor 'gradients_2/AddN_4:0' shape=(1, 2) dtype=float32>

In [47]:
dfLdh0.eval()

array([[1.106395 , 0.2125081]], dtype=float32)

In [48]:
inv_output, inv_info = tf.contrib.integrate.odeint(
    func=ODEFuncBackward,
    y0=h_N,
    t=[0.0, 1.0],
    rtol=1e-7,
    atol=1e-7,
    full_output=True,
)

In [49]:
inv_output[-1, ...].eval()

array([[0.9999998, 0.9999999]], dtype=float32)

In [50]:
dfLdhN = tf.gradients(loss, h_N)[0]
dfLdhN

<tf.Tensor 'gradients_3/pow_1_grad/Reshape:0' shape=(1, 2) dtype=float32>

In [51]:
dfLdhN.eval()

array([[1.5655456 , 0.14286537]], dtype=float32)

In [52]:
def fwd_gradients(ys, xs, d_xs):        
    g = tf.gradients(ys, xs, grad_ys=d_xs)
    return g


def ODEFuncBackwardWithGrad(state, t):    
    hN, dfLdhN = tf.unstack(state)
    h_output = - module(hN)
    jvp_z = - fwd_gradients(h_output, hN, dfLdhN)[0]
    return tf.stack([h_output, jvp_z])

In [57]:
inv_output, inv_info = tf.contrib.integrate.odeint(
    func=ODEFuncBackwardWithGrad,
    y0=tf.stack([h_N, dfLdhN]),
    t=[0.0, 1.0],
    rtol=1e-7,
    atol=1e-7,
    full_output=True,
)

In [58]:
h0_rec, dfLdh0_rec = tf.unstack(inv_output[-1, ...])

In [59]:
h0_rec.eval()

array([[0.9999999, 1.0000004]], dtype=float32)

In [62]:
%timeit dfLdh0_rec.eval()

10.6 ms ± 41.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [78]:
share_variables = lambda func: tf.make_template(
    func.__name__, func, create_scope_now_=True)


@share_variables
def module_backprop_template2(ht, at):
    ht_output = - module(ht)
    jvp_z = - tf.gradients(ht_output, ht, grad_ys=at)[0]
    return ht_output, jvp_z


def ODEFuncBackwardWithGrad(state, t):    
    hN, dfLdhN = tf.unstack(state)
    h_output, jvp_z = module_backprop_template2(hN, dfLdhN)    
    return tf.stack([h_output, jvp_z])

In [83]:
inv_output, inv_info = tf.contrib.integrate.odeint(
    func=ODEFuncBackwardWithGrad,
    y0=tf.stack([h_N, dfLdhN]),
    t=[0.0, 1.0],
    rtol=1e-5,
    atol=1e-5,
    full_output=True,
)

In [84]:
h0_rec, dfLdh0_rec = tf.unstack(inv_output[-1, ...])

In [85]:
h0_rec.eval(), dfLdh0_rec.eval()

(array([[0.9999964, 1.0000768]], dtype=float32),
 array([[1.1063935 , 0.21249634]], dtype=float32))

In [86]:
%timeit dfLdh0_rec.eval()

7.56 ms ± 46.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
