In [2]:
import numpy as np

def rnn_step_forward(x, prev_h, Wx, Wh, b):
    """
    Advance a vanilla (tanh) RNN by one timestep.

    Computes next_h = tanh(Wx . x + Wh . prev_h + b).

    Parameters:
    x: input for current timestep (input_dim, 1)
    prev_h: hidden state from previous timestep (hidden_dim, 1)
    Wx: input-to-hidden weights (hidden_dim, input_dim)
    Wh: hidden-to-hidden weights (hidden_dim, hidden_dim)
    b: bias term (hidden_dim, 1)

    Returns:
    next_h: next hidden state (hidden_dim, 1)
    """
    # Contribution of the current input to the pre-activation
    input_term = np.dot(Wx, x)

    # Contribution carried over from the previous hidden state
    recurrent_term = np.dot(Wh, prev_h)

    # Squash the biased pre-activation into (-1, 1) with tanh
    return np.tanh(input_term + recurrent_term + b)

def rnn_forward(x_seq, h0, Wx, Wh, b):
    """
    Unroll the RNN over a whole input sequence and collect every hidden state

    Parameters:
    x_seq: sequence of inputs (seq_length, input_dim, 1)
    h0: initial hidden state (hidden_dim, 1)
    Wx, Wh, b: RNN parameters (shared across all timesteps)

    Returns:
    h_seq: sequence of hidden states (seq_length, hidden_dim, 1);
           entry t is the state after consuming x_seq[t] (h0 itself is not included)
    """
    num_steps = len(x_seq)

    # One (hidden_dim, 1) slot per timestep, filled in as the recurrence runs
    states = np.zeros((num_steps, h0.shape[0], 1))

    state = h0
    for step, x_t in enumerate(x_seq):
        # Each step feeds on the state produced by the previous one
        state = rnn_step_forward(x_t, state, Wx, Wh, b)
        states[step] = state

    return states

# Problem sizes for the demo: 3 input features, 2 hidden units, 4 timesteps
input_dim, hidden_dim, seq_length = 3, 2, 4

# Randomly initialised parameters scaled down to small values
# (in a real model these would be learned during training)
Wx = 0.01 * np.random.randn(hidden_dim, input_dim)   # input-to-hidden
Wh = 0.01 * np.random.randn(hidden_dim, hidden_dim)  # hidden-to-hidden
b = np.zeros((hidden_dim, 1))                        # bias

# Start the recurrence from an all-zero hidden state
h0 = np.zeros((hidden_dim, 1))

# Random input sequence: one (input_dim, 1) column vector per timestep
x_seq = [np.random.randn(input_dim, 1) for _ in range(seq_length)]

# Run the full forward pass and keep the hidden state of every timestep
hidden_states = rnn_forward(x_seq, h0, Wx, Wh, b)

# Report each hidden state as a flat vector
for t, h in enumerate(hidden_states):
    print(f"Timestep {t} hidden state:\n{h.squeeze()}")

Timestep 0 hidden state:
[ 0.00090802 -0.01208704]
Timestep 1 hidden state:
[-0.00285381  0.04442604]
Timestep 2 hidden state:
[-0.00188001  0.01883054]
Timestep 3 hidden state:
[0.00302879 0.02105912]
