In [None]:
import numpy as np
%matplotlib inline

## RNN

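### Backward step computation for one timestep

The forward step is $$ s_t = \tanh(Ux_t + Ws_{t-1} + b) $$ and the backward pass applies the chain rule through the following local derivatives.

I) $$ \frac{\partial (\tanh(x))}{\partial (x)} = 1 - \tanh^2(x)$$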

II) $$ \frac{\partial (Ux_t + Ws_{t-1} + b)}{\partial (Ux_t)} = 1$$
$$ \frac{\partial (Ux_t + Ws_{t-1} + b)}{\partial (Ws_{t-1})} = 1$$
$$ \frac{\partial (Ux_t + Ws_{t-1} + b)}{\partial (b)} = 1$$


III.a) $$ \frac{\partial (Ws_{t-1})}{\partial (W)} = s_{t-1}$$
$$ \frac{\partial (Ws_{t-1})}{\partial (s_{t-1})} = W$$

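III.b) $$ \frac{\partial (Ux_t)}{\partial (U)} = x_t$$
$$ \frac{\partial (Ux_t)}{\partial (x_t)} = U$$

In the batched code below, inputs are row vectors and the step is computed as $x_t U + s_{t-1} W + b$. Let $\delta_t$ be the gradient of the loss with respect to the pre-activation. The rules above then become $\partial L/\partial W = s_{t-1}^\top \delta_t$, $\partial L/\partial s_{t-1} = \delta_t W^\top$, $\partial L/\partial U = x_t^\top \delta_t$, $\partial L/\partial x_t = \delta_t U^\top$ and $\partial L/\partial b = \sum_{\text{batch}} \delta_t$.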


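## Multiple timesteps

Backpropagation through time walks the steps in reverse. At each step it adds the gradient that flowed back into $s_{t-1}$ to the upstream gradient for step $t-1$. The parameter gradients are summed over all steps.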

In [None]:
class RNN:
    """Vanilla tanh RNN unrolled over a fixed number of timesteps.

    One step computes (row-vector convention, one row per batch element)::

        s_t = tanh(x_t @ U + s_{t-1} @ W + b.T)

    Args:
        time_steps: number of timesteps to unroll.
        feature_length: size of each input vector ``x_t``.
        batch_size: number of sequences processed together.
        hidden_length: size of the hidden state ``s_t``.
        s_0: initial hidden state, broadcastable to
            ``(batch_size, hidden_length)``.
    """

    def __init__(self, time_steps, feature_length, batch_size, hidden_length, s_0):
        self.time_steps = time_steps
        self.feature_length = feature_length
        self.batch_size = batch_size
        # Needed by backward() to size its gradient buffers.
        self.hidden_length = hidden_length

        self.s = np.zeros((batch_size, time_steps, hidden_length))
        # Private copy of the initial state. self.s[:, -1, :] is seeded with s_0
        # too, but forward() overwrites that slot with the final state, so
        # forward() reads the initial state from here instead.
        self.s_0 = np.zeros((batch_size, hidden_length))
        self.s_0[...] = s_0
        if time_steps > 0:
            self.s[:, -1, :] = s_0
        self.cache = []

    def _rnn_step_forward(self, x, prev_s, U, W, b):
        """Run a single RNN step.

        Args:
            x: inputs for this step, ``(batch_size, feature_length)``.
            prev_s: previous hidden state, ``(batch_size, hidden_length)``.
            U: input-to-hidden weights, ``(feature_length, hidden_length)``.
            W: hidden-to-hidden weights, ``(hidden_length, hidden_length)``.
            b: bias; ``b.T`` must broadcast against
                ``(batch_size, hidden_length)``.

        Returns:
            dict with ``'next_s'`` (the new hidden state) and ``'cache'``
            (the tuple ``(x, prev_s, U, W, next_s, linear_transform)`` consumed
            by ``_rnn_step_backward``).
        """
        linear_transform = np.dot(x, U) + np.dot(prev_s, W) + b.T
        next_s = np.tanh(linear_transform)
        # prev_s may be a view into self.s, which forward() rewrites; copy it so
        # the cached value stays valid.
        _cache = (x,
                  prev_s.copy(),
                  U,
                  W,
                  next_s,
                  linear_transform)
        return {'next_s': next_s, 'cache': _cache}

    def _rnn_step_backward(self, d_next_s, _cache):
        """Backpropagate through a single RNN step.

        Args:
            d_next_s: gradient of the loss w.r.t. this step's output state,
                ``(batch_size, hidden_length)``.
            _cache: tuple produced by ``_rnn_step_forward`` for this step.

        Returns:
            dict with this step's gradients ``'d_x'``, ``'d_prev_s'``,
            ``'d_U'``, ``'d_W'`` and ``'d_b'``.
        """
        x, prev_s, U, W, next_s, _ = _cache

        # I) Through tanh: d tanh(z)/dz = 1 - tanh(z)^2, and tanh(z) is next_s.
        d_linear_transform = (1 - np.square(next_s)) * d_next_s

        # II) The pre-activation is a plain sum, so every term receives the same
        # gradient; the bias gradient is accumulated over the batch axis.
        d_Ux = d_linear_transform
        d_Ws = d_linear_transform
        d_b = np.sum(d_linear_transform, axis=0)

        # III.a) Through prev_s @ W.
        d_W = prev_s.T.dot(d_Ws)
        d_prev_s = d_Ws.dot(W.T)

        # III.b) Through x @ U.
        d_x = d_Ux.dot(U.T)
        d_U = x.T.dot(d_Ux)

        return {'d_x': d_x,
                'd_prev_s': d_prev_s,
                'd_U': d_U,
                'd_W': d_W,
                'd_b': d_b}

    def forward(self, x, U, W, b):
        """Run the RNN over all timesteps.

        Args:
            x: input sequence, indexed as ``x[:, t, :]``, i.e.
                ``(batch_size, time_steps, feature_length)``.
            U, W, b: parameters, as described in ``_rnn_step_forward``.

        Returns:
            ``(self.s, self.cache)``: the hidden state of every step,
            ``(batch_size, time_steps, hidden_length)``, and the list of
            per-step caches.
        """
        # Start from a fresh cache and from s_0 on every call. Otherwise a repeated
        # call would append to the old cache and continue from the previous
        # call's final state.
        self.cache = []
        prev_s = self.s_0
        for t in range(self.time_steps):
            step = self._rnn_step_forward(x[:, t, :], prev_s, U, W, b)
            self.s[:, t, :] = step['next_s']
            self.cache.append(step['cache'])
            prev_s = self.s[:, t, :]
        return self.s, self.cache

    def backward(self, d_s, cache=None):
        """Backpropagation through time.

        Args:
            d_s: upstream gradient of the loss w.r.t. every hidden state,
                ``(batch_size, time_steps, hidden_length)``. Not modified.
            cache: per-step caches as returned by ``forward``; defaults to
                ``self.cache``.

        Returns:
            ``(d_x, d_s_0, d_U, d_W, d_b)``: gradients w.r.t. the inputs, the
            initial hidden state, and the parameters (summed over timesteps).
        """
        if cache is None:
            cache = self.cache
        d_s = np.asarray(d_s)

        d_x = np.zeros((self.batch_size, self.time_steps, self.feature_length))
        d_prev_s = np.zeros((self.batch_size, self.hidden_length))
        d_U = np.zeros((self.feature_length, self.hidden_length))
        d_W = np.zeros((self.hidden_length, self.hidden_length))
        d_b = np.zeros((self.hidden_length,))

        for t in reversed(range(self.time_steps)):
            # Total gradient reaching s_t: the direct upstream term plus what
            # flowed back from step t+1 through s_t @ W.
            step = self._rnn_step_backward(d_s[:, t, :] + d_prev_s, cache[t])
            d_x[:, t, :] = step['d_x']
            d_U += step['d_U']
            d_W += step['d_W']
            d_b += step['d_b']
            d_prev_s = step['d_prev_s']

        # After the loop, d_prev_s is the gradient w.r.t. s_0 (zeros if there
        # are no timesteps).
        d_s_0 = d_prev_s
        return d_x, d_s_0, d_U, d_W, d_b