<a href="https://colab.research.google.com/github/hissain/mlworks/blob/main/codes/GRU_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import
import torch
import torch.nn as nn
import math

In [None]:
# Seed the global PyTorch RNG so that parameter initialisation and the
# random example inputs below are reproducible across runs.
torch.random.manual_seed(0)

<torch._C.Generator at 0x7da1403a5ad0>

In [None]:
class GRUCell(nn.Module):
    """A single gated recurrent unit (GRU) cell implemented from scratch.

    One call performs one recurrence step:

        r  = sigmoid(x @ W_xr + h @ W_hr + b_r)        # reset gate
        z  = sigmoid(x @ W_xz + h @ W_hz + b_z)        # update gate
        n  = tanh(x @ W_xh + (r * h) @ W_hh + b_h)     # candidate state
        h' = z * h + (1 - z) * n                       # new hidden state

    Weight matrices are stored as (in_features, hidden_size) so that inputs
    are right-multiplied (``x @ W``).

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Reset gate weights
        self.W_xr = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_hr = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_r = nn.Parameter(torch.empty(hidden_size))

        # Update gate weights
        self.W_xz = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_hz = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_z = nn.Parameter(torch.empty(hidden_size))

        # Candidate hidden state weights
        self.W_xh = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_hh = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_h = nn.Parameter(torch.empty(hidden_size))

        # torch.empty leaves memory uninitialised; this fills every parameter.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        # In-place init must not be recorded by autograd.
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    def forward(self, input, hidden_state=None):
        """
        Performs a single time-step forward pass through the GRU cell.

        Args:
            input: A tensor of shape (batch_size, input_size) holding the input
                for one time step.
            hidden_state: A tensor of shape (batch_size, hidden_size) holding the
                previous hidden state (optional). Defaults to zeros created on
                the same device and with the same dtype as ``input``.

        Returns:
            hidden_state: A tensor of shape (batch_size, hidden_size) holding the
            updated hidden state.
        """
        batch_size, _ = input.size()

        if hidden_state is None:
            # new_zeros inherits input's device and dtype, so the default state
            # works for CUDA and non-float32 models as well.
            hidden_state = input.new_zeros(batch_size, self.hidden_size)

        # Reset gate
        reset_gate = torch.sigmoid(input @ self.W_xr + hidden_state @ self.W_hr + self.b_r)
        # Update gate
        update_gate = torch.sigmoid(input @ self.W_xz + hidden_state @ self.W_hz + self.b_z)
        # Candidate activation: the reset gate scales the previous state before
        # it is mixed in.
        candidate_activation = torch.tanh(
            input @ self.W_xh + (reset_gate * hidden_state) @ self.W_hh + self.b_h
        )

        # Interpolate between the previous state and the candidate using the
        # update gate.
        hidden_state = update_gate * hidden_state + (1 - update_gate) * candidate_activation

        return hidden_state


In [None]:
class GRU(nn.Module):
    """Single-layer GRU unrolled over a sequence, followed by a linear read-out.

    Args:
        input_size: Number of features per time step.
        hidden_size: Number of features in the recurrent hidden state.
        output_size: Number of output features produced from the final state.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.gru_cell = GRUCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, inputs, hidden_state=None):
        """
        Performs a forward pass through the GRU model.

        Args:
            inputs: A tensor of shape (batch_size, seq_len, input_size) holding
                the input sequence.
            hidden_state: Optional initial hidden state of shape
                (batch_size, hidden_size). Defaults to zeros on the same device
                and with the same dtype as ``inputs``.

        Returns:
            A tuple ``(prediction, hidden_state)``:
              prediction: tensor of shape (batch_size, output_size), the linear
                read-out of the final hidden state.
              hidden_state: tensor of shape (batch_size, hidden_size), the final
                hidden state.
        """
        batch_size, seq_len, _ = inputs.size()

        if hidden_state is None:
            # Initialise explicitly so an empty sequence (seq_len == 0) still
            # yields a valid state instead of passing None to the read-out.
            hidden_state = inputs.new_zeros(batch_size, self.hidden_size)

        for t in range(seq_len):
            hidden_state = self.gru_cell(inputs[:, t, :], hidden_state)

        prediction = self.fc(hidden_state)

        return prediction, hidden_state

In [None]:
# Example usage: push one random batch through the model and report the
# shapes of what comes back.
input_size, output_size, hidden_size = 10, 10, 20
seq_length, batch_size = 5, 2

# The model is built before the input is sampled, so its parameters are
# drawn from the RNG first.
gru = GRU(input_size, hidden_size, output_size)

# Random batch laid out as (batch, seq, features).
input_data = torch.randn(batch_size, seq_length, input_size)

# The model returns the read-out and the final hidden state.
output, hidden_state_last = gru(input_data)

print("Output shape:", output.shape)
print("Last hidden state shape:", hidden_state_last.shape)


Output shape: torch.Size([2, 10])
Last hidden state shape: torch.Size([2, 20])
