In [1]:
import numpy as np
import numpy.random as npr
import tensorflow as tf
keras = tf.keras

In [2]:
# Switch TensorFlow (1.x API) to eager execution for the cells below.
# NOTE(review): `tfe` is not referenced in any of the visible cells.
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()

In [3]:
class Module(tf.keras.Model):
  """Two stacked sigmoid ``Dense`` layers, used below as the ODE dynamics.

  Args:
    units: width of both dense layers, and therefore the size of the last
      output dimension. Defaults to 3, the feature size used in this
      notebook.
  """

  def __init__(self, units: int = 3):
    super(Module, self).__init__(name='Module')
    # Stored so compute_output_shape reports the same width the layers use.
    self._units = units
    self.dense_1 = keras.layers.Dense(units, activation='sigmoid')
    self.dense_2 = keras.layers.Dense(units, activation='sigmoid')

  def call(self, inputs):
    """Apply ``dense_1`` then ``dense_2`` to ``inputs`` and return the result."""
    hidden = self.dense_1(inputs)
    return self.dense_2(hidden)

  def compute_output_shape(self, input_shape):
    """Return ``input_shape`` with its last dimension replaced by ``units``."""
    shape = tf.TensorShape(input_shape).as_list()
    shape[-1] = self._units
    return tf.TensorShape(shape)

In [277]:
def EulerSolver(fun, state, dt):
    """Advance ``state`` by one explicit Euler step of size ``dt``.

    Args:
        fun: callable taking the state list and returning a list of
            derivatives, one per state component.
        state: list of state components.
        dt: step size (may be negative to step backwards).

    Returns:
        A new list holding ``component + dt * derivative`` for each pair.
    """
    derivatives = fun(state)
    advanced = []
    for value, slope in zip(state, derivatives):
        advanced.append(value + dt * slope)
    return advanced


def RK2Solver(fun, state, dt):
    """Advance ``state`` by one two-stage Runge-Kutta step of size ``dt``.

    The derivative is evaluated at the current state and at the state reached
    by a full Euler step; the update uses the average of the two.

    Args:
        fun: callable taking the state list and returning a list of
            derivatives, one per state component.
        state: list of state components.
        dt: step size (may be negative to step backwards).

    Returns:
        A new list with one updated component per input component.
    """
    slopes_start = fun(state)

    # Provisional full Euler step, used only to sample the second slope.
    predicted = []
    for value, slope in zip(state, slopes_start):
        predicted.append(value + dt * slope)
    slopes_end = fun(predicted)

    advanced = []
    for value, first, second in zip(state, slopes_start, slopes_end):
        advanced.append(value + dt * (first + second) / 2)
    return advanced


class NeuralODE:
    """Fixed-step integrator for ``dh/dt = module_fn(h)`` over the unit interval.

    ``forward`` steps a state from t=0 to t=1 in ``num_steps`` equal steps.
    ``backward`` steps an augmented state (hidden state, gradient w.r.t. the
    hidden state, accumulated weight gradients) from t=1 back to t=0 using the
    same solver with a negated step size.
    """

    def __init__(self, module_fn, solver, num_steps: int = 40):
        """Store the dynamics, the solver and the step configuration.

        Args:
            module_fn: callable mapping a state to its time derivative.
                ``backward`` additionally reads its ``weights`` attribute.
            solver: callable ``solver(fun, state, dt)`` returning the next
                state list, e.g. ``EulerSolver`` or ``RK2Solver``.
            num_steps: number of solver steps; the step size is
                ``1 / num_steps``.
        """
        self._num_steps = num_steps
        self._dt = 1 / num_steps
        self._module_fn = module_fn
        self._solver = solver

    def forward(self, h_input):
        """Integrate ``h_input`` forward through ``num_steps`` solver steps.

        Returns:
            The state after the final step.
        """

        def _forward_dynamics(h):
            # Solvers work on lists of components; the forward state has one.
            return [self._module_fn(h[0])]

        state = [h_input]
        for _ in range(self._num_steps):
            state = self._solver(_forward_dynamics, state, self._dt)
        return state[0]

    def _backward_dynamics(self, state):
        """Derivative of the augmented state used by ``backward``.

        Args:
            state: list ``[h, a, *weight_grads]``; only ``h`` and ``a`` are
                read here.

        Returns:
            ``[module_fn(h), d/dh, *d/dweights]`` where the gradient entries
            are gradients of ``module_fn(h)`` weighted by ``-a``.
        """
        ht = state[0]
        at = -state[1]

        with tf.GradientTape() as g:
            g.watch(ht)
            ht_new = self._module_fn(ht)

        sources = [ht] + list(self._module_fn.weights)
        gradients = g.gradient(
            target=ht_new,
            sources=sources,
            output_gradients=at,
        )
        # The tape yields None for sources the output does not depend on
        # (or that it did not watch); treat those as zero so the solver's
        # arithmetic on the state list does not fail on None.
        gradients = [
            tf.zeros_like(source) if grad is None else grad
            for source, grad in zip(sources, gradients)
        ]
        return [ht_new, *gradients]

    def backward(self, h_output, grad_h_output):
        """Integrate the augmented state backwards through ``num_steps`` steps.

        Args:
            h_output: state at the end of the forward pass.
            grad_h_output: gradient of the loss w.r.t. ``h_output``.

        Returns:
            Tuple ``(h_start, grad_h_start, weight_grads)``: the reconstructed
            initial state, the gradient w.r.t. it, and a list with one
            gradient per entry of ``module_fn.weights``.
        """
        # Weight gradients start at zero and accumulate along the trajectory.
        dWeights = [tf.zeros_like(w) for w in self._module_fn.weights]

        state = [h_output, grad_h_output, *dWeights]
        for _ in range(self._num_steps):
            state = self._solver(self._backward_dynamics, state, -self._dt)

        return state[0], state[1], state[2:]

In [278]:
# One Module instance, shared by every NeuralODE constructed below.
mdl = Module()

In [279]:
# Random (1, 3) float32 starting state.
# NOTE(review): no random seed is set, so this value changes on every run.
input_h = tf.cast(np.random.randn(1, 3), tf.float32)

In [280]:
# Display the sampled starting state.
input_h

<tf.Tensor: id=61087, shape=(1, 3), dtype=float32, numpy=array([[ 0.39672357, -0.17706469,  0.5449612 ]], dtype=float32)>

In [303]:
# Euler-based integrator over 10 steps (step size 1/10).
neural_ode = NeuralODE(mdl, EulerSolver, num_steps=10)

In [304]:
# Forward pass: integrate the starting state to the final state.
h_end = neural_ode.forward(input_h)

In [305]:
# Gradient of the scalar loss sum(h_end**2) with respect to the final state.
with tf.GradientTape() as g:
    # h_end is a plain tensor, so it has to be watched explicitly.
    g.watch(h_end)
    loss = tf.reduce_sum(h_end**2)
    
dLoss_dh_end = g.gradient(loss, h_end)
dLoss_dh_end

<tf.Tensor: id=90728, shape=(1, 3), dtype=float32, numpy=array([[1.644217  , 0.36412543, 2.3599567 ]], dtype=float32)>

In [306]:
# Reference gradients: record the whole forward pass on the tape, then
# backpropagate dLoss_dh_end through it to the input and the module weights.
with tf.GradientTape() as g:
    g.watch(input_h)
    h_end = neural_ode.forward(input_h)    

auto_gradients = g.gradient(h_end, [input_h, *mdl.weights], dLoss_dh_end)

In [307]:
# Same reference gradients, this time differentiating the loss directly
# instead of seeding the tape with dLoss_dh_end.
with tf.GradientTape() as g:
    g.watch(input_h)
    h_end = neural_ode.forward(input_h)
    loss = tf.reduce_sum(h_end**2)

auto_gradients_loss = g.gradient(loss, [input_h, *mdl.weights])

In [308]:
# Total absolute difference between the two reference gradient computations.
# abs() keeps opposite-signed discrepancies from cancelling in the sum.
[tf.reduce_sum(tf.abs(ag - ag_loss)).numpy() for ag, ag_loss in zip(auto_gradients, auto_gradients_loss)]

[0.0, 0.0, 0.0, 0.0, 0.0]

In [309]:
# Backward pass: reconstructed start state, gradient w.r.t. it, and weight gradients.
h_start, dfdh0, dWeights = neural_ode.backward(h_end, dLoss_dh_end)

In [312]:
# Absolute relative error of the reconstructed start state against the true input.
tf.abs((h_start - input_h) / h_start)

<tf.Tensor: id=91823, shape=(1, 3), dtype=float32, numpy=array([[0.00130136, 0.00374753, 0.00073708]], dtype=float32)>

In [313]:
# Signed relative error of the backward-pass input gradient against the tape gradient.
(dfdh0 - auto_gradients[0]) / auto_gradients[0]

<tf.Tensor: id=91826, shape=(1, 3), dtype=float32, numpy=array([[ 1.3487248e-05, -2.1883039e-05,  2.4345043e-06]], dtype=float32)>

In [314]:
# Total absolute error per weight gradient: backward pass vs. tape reference.
[tf.reduce_sum(tf.abs(dw - auto_wd)).numpy() for dw, auto_wd in zip(dWeights, auto_gradients[1:])]

[0.017087633, 0.00011493638, 0.0060084797, 0.00045181066]

In [315]:
# Same module and step count as before, but stepping with RK2Solver.
neural_ode_rk2 = NeuralODE(mdl, RK2Solver, num_steps=10)

In [319]:
# Forward pass with the RK2 integrator.
h_end_rk2 = neural_ode_rk2.forward(input_h)

In [320]:
# Gradient of sum(h_end_rk2**2) with respect to the RK2 final state.
# This rebinds `loss` and `dLoss_dh_end` from the Euler cells above.
with tf.GradientTape() as g:
    g.watch(h_end_rk2)
    loss = tf.reduce_sum(h_end_rk2**2)
    
dLoss_dh_end = g.gradient(loss, h_end_rk2)
dLoss_dh_end

<tf.Tensor: id=93405, shape=(1, 3), dtype=float32, numpy=array([[1.6447359 , 0.36479747, 2.3595555 ]], dtype=float32)>

In [328]:
# Reference gradients for the RK2 run: tape the forward pass and
# backpropagate dLoss_dh_end to the input and the module weights.
# This rebinds `auto_gradients` from the Euler cells above.
with tf.GradientTape() as g:
    g.watch(input_h)
    h_end_rk2 = neural_ode_rk2.forward(input_h)    

auto_gradients = g.gradient(h_end_rk2, [input_h, *mdl.weights], dLoss_dh_end)

In [329]:
# Backward pass with the RK2 integrator; rebinds h_start, dfdh0 and dWeights.
h_start, dfdh0, dWeights = neural_ode_rk2.backward(h_end_rk2, dLoss_dh_end)

In [330]:
# Signed relative error of the RK2-reconstructed start state against the true input.
(h_start - input_h) / h_start

<tf.Tensor: id=96376, shape=(1, 3), dtype=float32, numpy=array([[-2.2536344e-07, -0.0000000e+00,  0.0000000e+00]], dtype=float32)>

In [331]:
# Display the input gradient produced by the RK2 backward pass.
dfdh0

<tf.Tensor: id=96350, shape=(1, 3), dtype=float32, numpy=array([[1.671073 , 0.3547482, 2.3500018]], dtype=float32)>

In [332]:
# Signed relative error of the RK2 backward-pass input gradient against the tape gradient.
(dfdh0 - auto_gradients[0]) / auto_gradients[0]

<tf.Tensor: id=96380, shape=(1, 3), dtype=float32, numpy=array([[ 1.4267397e-07, -3.3603905e-07, -2.0290923e-07]], dtype=float32)>

In [333]:
# Total absolute error per weight gradient for the RK2 run: backward pass vs. tape reference.
[tf.reduce_sum(tf.abs(dw - auto_wd)).numpy() for dw, auto_wd in zip(dWeights, auto_gradients[1:])]

[5.296897e-09, 7.450581e-09, 1.3411045e-07, 8.940697e-08]

In [335]:
# Batched check: 32 random float32 states of 3 features each.
# NOTE(review): no random seed is set, so this value changes on every run.
input_h = tf.cast(np.random.randn(32, 3), tf.float32)
# Rebinds `neural_ode` (previously Euler-based) to an RK2 integrator.
neural_ode = NeuralODE(mdl, RK2Solver, num_steps=10)

In [336]:
# Forward pass on the batch, using the integrator built for this section
# (the same object the backward and tape cells below call).
h_end = neural_ode.forward(input_h)

In [340]:
# Gradient of sum(h_end**2) with respect to the batched final state;
# only its shape is displayed.
with tf.GradientTape() as g:
    g.watch(h_end)
    loss = tf.reduce_sum(h_end**2)
    
dLoss_dh_end = g.gradient(loss, h_end)
dLoss_dh_end.shape

TensorShape([Dimension(32), Dimension(3)])

In [341]:
# Backward pass on the batch; rebinds h_start, dfdh0 and dWeights.
h_start, dfdh0, dWeights = neural_ode.backward(h_end, dLoss_dh_end)

In [343]:
# Reference gradients for the batch: differentiate the loss through the
# taped forward pass with respect to the input and the module weights.
with tf.GradientTape() as g:
    g.watch(input_h)
    h_end = neural_ode.forward(input_h)    
    loss = tf.reduce_sum(h_end**2)

auto_gradients = g.gradient(loss, [input_h, *mdl.weights])

In [346]:
# Total absolute error for the input gradient and each weight gradient:
# backward pass vs. tape reference.
[tf.reduce_sum(tf.abs(dw - auto_wd)).numpy() for dw, auto_wd in zip([dfdh0, *dWeights], auto_gradients)]

[2.1945685e-05, 4.656613e-07, 1.1920929e-07, 2.5331974e-06, 9.536743e-07]