In [63]:
"""
An Implementation of the method Neural Ordinary Differential 
Equation presented in: https://arxiv.org/abs/1806.07366


TODO: 
- implement a residual neural network
# - add the training loop
# - add the backpropagation

- implement a neural ODE


NOTES:
- residual structure doesn't make sense? inputs and outputs in the 
  residual block are being broadcasted as they don't have the same 
  dimensions. Also specifiying different depths has no effect on 
  the model predictions.

"""

import numpy as np

"""
Initialse the model parameters.
"""
def init_weights(layers, scale=1.0, seed=0):
    rng = np.random.RandomState(seed)
    return [(scale * rng.randn(m, n), scale * rng.randn(n))
            for m, n in zip(layers[:-1], layers[1:])]

"""
A basic residual neural network model set up so that 
skips are performed between layers of equal dimensions.
"""
class residual_NN:    
    def __init__(self, layers, skip_start, skip_end):
        
        # intialise the parameters
        self.weights = init_weights(layers)
        self.skip_start = skip_start
        self.skip_end = skip_end
        
        # check the chosen settings
        skip_warning = "Skip layers must be equal dimensions." 
        for idx, l in enumerate(skip_start):            
            assert layers[l] == layers[skip_end[idx]], skip_warning
    
    """
    Get the forward prediction of shape (batch_size, state_dim)
    """
    def __call__(self, inputs):      
        
        prev_inputs = None
        for l, (w, b) in enumerate(self.weights):
            
            # keep track of skipping 
            skip_value = 0
            if l in self.skip_start: prev_inputs = inputs
            if l in self.skip_end: skip_value = prev_inputs 
                        
            # linear + activation
            outputs = np.dot(inputs, w) + b            
            inputs = np.tanh(outputs) + skip_value
                        
        return outputs    
            

model = residual_NN(
    layers=[2, 20, 20, 20, 1],
    skip_start=[1],
    skip_end=[3]
)


model(inputs=np.ones((10, 2)))

(10, 20)
------------
(10, 20)
------------
(10, 20)
------------
(10, 1)
------------


array([[-0.44257146],
       [-0.44257146],
       [-0.44257146],
       [-0.44257146],
       [-0.44257146],
       [-0.44257146],
       [-0.44257146],
       [-0.44257146],
       [-0.44257146],
       [-0.44257146]])