# In [92]:
"""
An Implementation of the method Neural Ordinary Differential 
Equation presented in: https://arxiv.org/abs/1806.07366


TODO: 
- implement a residual neural network
# - add the training loop
# - add the backpropagation

- implement a neural ODE


NOTES:
- residual structure doesn't make sense? Inputs and outputs in the 
  residual block are being broadcast because they don't have the same 
  dimensions. Also, specifying different depths has no effect on 
  the model predictions.

"""

import numpy as np

"""
Initialise the model parameters.
"""
def init_weights(layers, scale=1.0, seed=0):
    """Create Gaussian-initialised (weight, bias) pairs for a stack of layers.

    Args:
        layers: sequence of layer widths, input width first.
        scale: multiplier applied to every standard-normal draw.
        seed: seed for the ``np.random.RandomState`` used for the draws.

    Returns:
        A list with one ``(W, b)`` tuple per consecutive pair of widths
        ``(fan_in, fan_out)``: ``W`` has shape ``(fan_in, fan_out)`` and ``b``
        has shape ``(fan_out,)``. A single width yields an empty list.
    """
    rng = np.random.RandomState(seed)

    # TODO: add in additional weights layer

    params = []
    for fan_in, fan_out in zip(layers, layers[1:]):
        # Draw order (weight, then bias) matters for reproducibility per seed.
        weight = scale * rng.randn(fan_in, fan_out)
        bias = scale * rng.randn(fan_out)
        params.append((weight, bias))
    return params

"""
A basic residual neural network model, intended to perform skips 
between layers of equal dimensions (skip connections are not yet applied).
"""
class residual_NN:
    """Small fully-connected tanh network trained by plain gradient descent.

    Despite the name, no residual (skip) connections are applied yet:
    ``skips`` is stored on the instance but neither the forward pass nor
    ``step`` reads it.

    Attributes:
        weights: list of ``(W, b)`` pairs, one per layer, as returned by
            ``init_weights`` (``W`` is ``(fan_in, fan_out)``, ``b`` is
            ``(fan_out,)``).
        A: cache written by ``__call__``. ``A[i]`` is the input that fed
            layer ``i``, and ``A[-1]`` is the final linear output.
        skips: stored as given; currently unused.
        lr: learning rate used by ``step``.
    """

    def __init__(self, layers, skips):
        # initialise the parameters
        self.weights = init_weights(layers)
        self.A = []
        # NOTE(review): skip connections are not implemented yet; the argument
        # is kept so the constructor signature stays stable.
        self.skips = skips

        # hyperparams
        self.lr = 1e-3

    def __call__(self, X):
        """Run the forward pass and cache the activations needed by ``step``.

        Each layer computes ``Z = X @ W + b``. Hidden layers pass ``tanh(Z)``
        on to the next layer. The last layer's ``Z`` is returned as-is, so the
        output is linear.

        Args:
            X: input batch of shape ``(batch_size, layers[0])``.

        Returns:
            The last layer's output, shape ``(batch_size, layers[-1])``. With
            no layers the input is returned unchanged.
        """
        activations = [X]
        Z = X  # defined even when there are no layers
        for w, b in self.weights:
            # linear + activation
            Z = np.dot(X, w) + b
            X = np.tanh(Z)
            activations.append(X)

        # The prediction is the raw linear output, so replace the final tanh
        # with Z; the loss gradient in ``step`` is taken against this value.
        self.A = activations[:-1] + [Z]
        return Z

    def step(self, Y):
        """Apply one gradient-descent update using the last forward pass.

        Backpropagates the gradient of ``mse_loss`` through the tanh layers.
        That loss is the squared error summed over output features and averaged
        over the batch. Every weight and bias then moves by ``-lr * gradient``.
        Call this after ``__call__`` on the batch that ``Y`` belongs to,
        because it reads the activations cached in ``self.A``.

        Args:
            Y: targets of shape ``(batch_size, layers[-1])``.
        """
        batch_size = Y.shape[0]
        # d/dZ_out of mean_over_batch(sum_over_features((Z_out - Y)^2)).
        dz = (2.0 / batch_size) * (self.A[-1] - Y)

        for idx in reversed(range(len(self.weights))):
            w, b = self.weights[idx]
            a_in = self.A[idx]  # input that fed layer idx

            dw = np.dot(a_in.T, dz)
            # The bias is broadcast over the batch, so its gradient sums over
            # axis 0 and keeps the bias shape (fan_out,).
            db = np.sum(dz, axis=0)

            if idx > 0:
                # Propagate through this layer's pre-update weights, then through
                # the tanh that produced a_in: d tanh(z)/dz = 1 - a_in**2.
                dz = np.dot(dz, w.T) * (1.0 - np.square(a_in))

            self.weights[idx] = (w - self.lr * dw, b - self.lr * db)
    
    
"""
Simple mean-squared error loss: squared error summed over features, averaged over the batch.
"""
def mse_loss(true, pred):
    """Return the per-sample sum of squared errors, averaged over samples.

    Args:
        true: target array, at least 2-D (rows are samples, axis 1 features).
        pred: prediction array broadcast-compatible with ``true``.

    Returns:
        The mean over rows of the row-wise sum of squared differences.
    """
    residual = true - pred
    per_sample = (residual * residual).sum(axis=1)
    return per_sample.mean()

"""
Run the training loop for the residual model.
"""
def train_model(model, dataset, loss_func, epochs=10, batch_size=32):
    """Run mini-batch training of ``model`` and print the loss of every batch.

    For each batch the model is called on the inputs, the loss is computed,
    and ``model.step`` is called with the targets.

    Args:
        model: callable returning predictions for an input batch and exposing
            ``step(Y)``, which updates the weights from the latest forward pass.
        dataset: pair ``(y, x)``, targets first and inputs second. Both are
            arrays that share their first (sample) dimension.
        loss_func: ``loss_func(true, pred)`` returning the batch loss.
        epochs: number of full passes over the data.
        batch_size: samples per batch; the final batch may be smaller.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))

    y, x = dataset[0], dataset[1]
    n_samples = len(x)
    # Ceiling division keeps a trailing partial batch. It also avoids an empty
    # batch when n_samples is an exact multiple of batch_size.
    n_batches = -(-n_samples // batch_size)

    for ep in range(epochs):
        for it in range(n_batches):

            # get a batch of data (slicing clips at the end of the arrays)
            start = it * batch_size
            stop = start + batch_size
            x_tr = x[start:stop]
            y_tr = y[start:stop]

            # get the prediction
            y_pred = model(x_tr)
            loss = loss_func(y_tr, y_pred)

            # update the weights
            model.step(y_tr)

            # display loss
            print('Ep: {} - Loss: {}'.format(ep, loss))
            
        
        
        
        
    
    


if __name__ == "__main__":
    # Smoke test: one forward pass and one weight update on a constant batch.
    # The guard keeps importing this module free of side effects. It still
    # runs in a notebook, where __name__ is "__main__".
    # NOTE(review): skips=[(1, 3)] pairs a 20-wide layer with the 1-wide output
    # (layers[1] vs layers[3]). residual_NN currently stores skips without
    # applying them, so this argument has no effect.
    model = residual_NN(
        layers=[2, 20, 20, 1],
        skips=[(1, 3)],
    )

    model(X=np.ones((10, 2)))
    model.step(Y=np.ones((10, 1)))