<a href="https://colab.research.google.com/github/hissain/mlworks/blob/main/codes/Stacked_RNN_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import
import torch
import torch.nn as nn
import math

In [None]:
# Seed the global PyTorch RNG so that parameter initialization and the random
# example inputs drawn later in this notebook are reproducible across runs.
torch.manual_seed(seed=0)

<torch._C.Generator at 0x795bec181af0>

In [None]:
class RNNLayer(torch.nn.Module):
    """A single tanh RNN cell: ``h_t = tanh(x_t @ W_xh + h_{t-1} @ W_hh + b_h)``.

    Each call advances the cell by exactly one time step; looping over a
    sequence is the caller's responsibility.

    Args:
        input_size: Number of features in each per-step input ``x_t``.
        hidden_size: Number of features in the hidden state ``h_t``.

    Raises:
        ValueError: If ``hidden_size`` is not positive.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        if hidden_size <= 0:
            raise ValueError(f"hidden_size must be positive, got {hidden_size}")
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Same scheme as torch.nn.RNN: U(-1/sqrt(H), 1/sqrt(H)). Unscaled
        # randn weights drive tanh into saturation (vanishing gradients),
        # which compounds when many layers are stacked.
        bound = 1.0 / math.sqrt(hidden_size)
        # Input-to-hidden and hidden-to-hidden weight matrices.
        self.W_xh = torch.nn.Parameter(
            torch.empty(input_size, hidden_size).uniform_(-bound, bound)
        )
        self.W_hh = torch.nn.Parameter(
            torch.empty(hidden_size, hidden_size).uniform_(-bound, bound)
        )
        # Hidden-layer bias, broadcast over the batch dimension.
        self.b_h = torch.nn.Parameter(torch.zeros(hidden_size))

    def forward(self, input_data, hidden_state=None):
        """Advance the cell by one time step.

        Args:
            input_data: Tensor of shape (batch_size, input_size) holding the
                input for a single time step.
            hidden_state: Optional tensor of shape (batch_size, hidden_size)
                with the previous hidden state. Defaults to zeros with the
                same dtype and device as ``input_data``.

        Returns:
            Tensor of shape (batch_size, hidden_size): the new hidden state.
        """
        batch_size, _ = input_data.size()

        if hidden_state is None:
            # new_zeros inherits dtype/device from input_data, so the layer
            # keeps working after .double() / .to(device) instead of mixing
            # in a CPU float32 state.
            hidden_state = input_data.new_zeros(batch_size, self.hidden_size)

        # (B, in) @ (in, H) + (B, H) @ (H, H) + (H,)  ->  (B, H)
        return torch.tanh(
            torch.mm(input_data, self.W_xh)
            + torch.mm(hidden_state, self.W_hh)
            + self.b_h
        )

In [None]:
class StackedRNN(nn.Module):
    """Stack of ``RNNLayer`` cells followed by a linear read-out.

    Layer 0 maps ``input_size -> hidden_size``. Every deeper layer consumes
    the hidden state of the layer below, so it maps
    ``hidden_size -> hidden_size``. After the last time step, ``fc`` is
    applied to the top layer's hidden state.

    Args:
        input_size: Number of features per time step of the input sequence.
        hidden_size: Hidden-state width of every layer.
        output_size: Width of the final prediction.
        num_layers: Number of stacked RNN layers (must be >= 1).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size

        # Only the first layer sees the raw input; the others receive the
        # previous layer's hidden state, whose width is hidden_size.
        self.layers = nn.ModuleList(
            [
                RNNLayer(input_size if idx == 0 else hidden_size, hidden_size)
                for idx in range(num_layers)
            ]
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, inputs, hidden_states=None):
        """Run the whole sequence through the stack and predict from the top layer.

        Args:
            inputs: Tensor of shape (batch_size, seq_len, input_size).
            hidden_states: Optional initial hidden states, one per layer, each
                of shape (batch_size, hidden_size). This can be a list or the
                stacked tensor returned by a previous call. It is never
                modified in place.

        Returns:
            prediction: Tensor of shape (batch_size, output_size).
            hidden_states: Tensor of shape (num_layers, batch_size,
                hidden_size) holding every layer's final hidden state.

        Raises:
            ValueError: If ``hidden_states`` does not have one entry per layer,
                or if ``inputs`` has no time steps and some layer has no
                initial state.
        """
        if hidden_states is None:
            hidden_states = [None] * self.num_layers
        else:
            # Work on a private copy: assigning into the caller's list (or,
            # worse, in place into a stacked tensor) would clobber their data
            # and can break autograd.
            hidden_states = list(hidden_states)
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"expected {self.num_layers} hidden states, got {len(hidden_states)}"
                )

        _, seq_len, _ = inputs.size()
        if seq_len == 0 and any(state is None for state in hidden_states):
            raise ValueError("inputs must contain at least one time step")

        for t in range(seq_len):
            layer_input = inputs[:, t, :]
            for idx, layer in enumerate(self.layers):
                hidden_states[idx] = layer(layer_input, hidden_states[idx])
                # The output of this layer feeds the next one at the same step.
                layer_input = hidden_states[idx]

        prediction = self.fc(hidden_states[-1])

        return prediction, torch.stack(hidden_states)

In [None]:
# Example usage: one forward pass through a deep stacked RNN.

# Feature widths of the per-step input, the prediction, and each hidden state.
input_size = 10
output_size = 10
# hidden_size is kept equal to input_size so the example also runs with a
# StackedRNN whose layers are all built with `input_size` inputs.
hidden_size = 10
# HW: implement StackedRNN that supports different input size and hidden size.

# Batch geometry and model depth.
seq_length, batch_size, num_layers = 5, 2, 15

# Stacked tanh-RNN with a linear head (this is not an LSTM).
rnn = StackedRNN(input_size, hidden_size, output_size, num_layers)

# Random input batch of shape (batch_size, seq_length, input_size). It is drawn
# after the model is built, so it depends on the RNG state left by weight init.
inputs = torch.randn(batch_size, seq_length, input_size)

# Forward pass: the model returns its prediction plus every layer's final
# hidden state stacked along dim 0.
output, hidden_states_last = rnn(inputs)
print("Output shape:", output.shape)
print("Last hidden state shape:", hidden_states_last.shape)


Output shape: torch.Size([2, 10])
Last hidden state shape: torch.Size([15, 2, 10])
