In [37]:
import numpy as np

# [Problem 1] Implementing forward propagation for SimpleRNN

In [38]:
class Tanh:
    """Hyperbolic-tangent activation with the derivative needed for backprop."""

    def forward(self, A):
        """Apply tanh element-wise.

        Parameters
        ----------
        A : array_like
            Pre-activation values.

        Returns
        -------
        ndarray
            Z = tanh(A)
        """
        Z = np.tanh(A)
        return Z

    def backward(self, dZ, Z):
        """Propagate the upstream gradient through tanh.

        Since d/dA tanh(A) = 1 - tanh(A)**2, the local gradient is computed
        from the cached forward output Z; the input A is not needed.

        Parameters
        ----------
        dZ : array_like
            Gradient of the loss w.r.t. the activation output.
        Z : array_like
            Output previously returned by ``forward``.

        Returns
        -------
        dA = dZ * (1 - Z**2)
        """
        local_grad = 1 - Z ** 2
        return dZ * local_grad

In [39]:
class ReLU:
    """Rectified linear unit activation: Z = max(A, 0)."""

    def forward(self, A):
        """Clamp negative pre-activations to zero, element-wise."""
        Z = np.maximum(A, 0)
        return Z

    def backward(self, dZ, Z):
        """Pass the gradient through only where the unit was active.

        ``Z > 0`` on the forward output selects exactly the positions where
        the input was positive, so the cached output is sufficient.
        """
        active = Z > 0
        return dZ * active

In [40]:
class SimpleRNN:
    """Minimal recurrent layer (forward pass only).

    At every time step t the hidden state is updated as

        h_t = activation(x_t @ w_x + h_{t-1} @ w_h + b)

    and the state after the last step is returned.

    Parameters
    ----------
    activation : object
        Must provide ``forward(A)`` returning the activated array
        (e.g. ``Tanh()`` or ``ReLU()``).
    w_x : ndarray, shape (n_features, n_nodes)
        Input-to-hidden weights.
    w_h : ndarray, shape (n_nodes, n_nodes)
        Hidden-to-hidden weights.
    b : ndarray, shape (n_nodes,)
        Bias, broadcast over the batch.
    """

    def __init__(self, activation, w_x, w_h, b):
        self.activation = activation
        self.w_x = w_x
        self.w_h = w_h
        self.b = b

    def forward(self, X, h0=None):
        """Run the recurrence over every time step of ``X``.

        Parameters
        ----------
        X : ndarray, shape (batch_size, n_sequences, n_features)
            Input sequences.
        h0 : array_like, shape (batch_size, n_nodes), optional
            Initial hidden state. Defaults to zeros.

        Returns
        -------
        ndarray, shape (batch_size, n_nodes), dtype float64
            Hidden state after the final time step (a copy of the initial
            state if ``n_sequences`` is 0).
        """
        # Take the state shape from the inputs themselves. Reading the
        # notebook globals `batch_size`/`n_nodes` made this class depend on
        # a later cell and on whatever batch happened to be defined there.
        batch_size = X.shape[0]
        n_nodes = self.w_h.shape[0]
        if h0 is None:
            h = np.zeros((batch_size, n_nodes))
        else:
            # Copy so the caller's array is never returned or aliased.
            h = np.array(h0, dtype=np.float64)

        for t in range(X.shape[1]):
            A = X[:, t, :] @ self.w_x + h @ self.w_h + self.b
            # Force float64 even if the activation preserves an integer dtype
            # (np.maximum in ReLU does, for example).
            h = self.activation.forward(A).astype(np.float64)
        return h

    def backward(self):
        """Placeholder: backpropagation through time is not implemented.

        Returns None.
        """
        # TODO: implement BPTT.
        return

# [Problem 2] Forward-propagation experiment on a small sequence

In [41]:
# Toy problem: one sequence of three 2-feature time steps feeding four hidden nodes.
x = np.array([[[1, 2], [2, 3], [3, 4]]]) / 100  # (batch_size, n_sequences, n_features)
w_x = np.array([[1, 3, 5, 7],
                [3, 5, 7, 8]]) / 100  # (n_features, n_nodes)
w_h = np.array([[1, 3, 5, 7],
                [2, 4, 6, 8],
                [3, 5, 7, 8],
                [4, 6, 8, 10]]) / 100  # (n_nodes, n_nodes)

# Read the dimensions off the arrays instead of typing them by hand.
batch_size, n_sequences, n_features = x.shape  # 1, 3, 2
n_nodes = w_x.shape[1]  # 4

h = np.zeros((batch_size, n_nodes))  # (batch_size, n_nodes)
b = np.ones(n_nodes, dtype=int)  # (n_nodes,), every entry 1

print("batch_size: {}".format(batch_size))
print("n_sequences: {}".format(n_sequences))
print("n_features: {}".format(n_features))
print("n_nodes: {}".format(n_nodes))

batch_size: 1
n_sequences: 3
n_features: 2
n_nodes: 4


In [42]:
# Tanh-activated RNN: step through all time steps of `x` and show the final hidden state.
rnn = SimpleRNN(activation=Tanh(), w_x=w_x, w_h=w_h, b=b)
new_h = rnn.forward(X=x)
print(new_h)

[[0.79494228 0.81839002 0.83939649 0.85584174]]


<h4>Expected output</h4>

In [43]:
# Expected final hidden state of the tanh experiment above.
h = np.array([[0.79494228, 0.81839002, 0.83939649, 0.85584174]]) # (batch_size, n_nodes)
# Verify the forward pass actually reproduces it instead of only displaying it
# (the reference values carry 8 decimals, hence the absolute tolerance).
np.testing.assert_allclose(new_h, h, rtol=0, atol=1e-7)

<h3>Try with ReLU</h3>

In [44]:
# Same weights and input, ReLU instead of tanh. ReLU is unbounded above,
# so hidden values can exceed 1 (compare with the tanh result).
rnn = SimpleRNN(activation=ReLU(), w_x=w_x, w_h=w_h, b=b)
new_h = rnn.forward(X=x)
print(new_h)

[[1.12744024 1.2264713  1.32550236 1.41149812]]
