In [None]:
import torch
import torch.nn as nn

class BiGRUModel(nn.Module):
    """Minimal wrapper around a bidirectional, batch-first ``nn.GRU``.

    Args:
        input_size: Number of features per time step (e.g. embedding dim).
        hidden_size: Hidden state size of the GRU, per direction.
        num_layers: Number of stacked GRU layers (defaults to 1).
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()

        # Collect the recurrent layer's configuration in one place.
        gru_options = {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "batch_first": True,    # tensors are laid out as [batch, seq_len, feat]
            "bidirectional": True,  # run forward and backward passes (BiGRU)
        }
        self.bigru = nn.GRU(**gru_options)

    def forward(self, x):
        """Run ``x`` through the BiGRU.

        Args:
            x: Input of shape ``[batch, seq_len, input_size]``.

        Returns:
            The ``(output, hidden)`` pair produced by the GRU:
            ``output`` is ``[batch, seq_len, hidden_size * 2]`` (both
            directions concatenated on the last axis) and ``hidden`` is
            ``[num_layers * 2, batch, hidden_size]``.
        """
        # nn.GRU already yields the (output, hidden) tuple; hand it back as-is.
        return self.bigru(x)

# -------------------------
# Demo: push one random batch through the model and inspect tensor shapes
batch_size, seq_len = 2, 5        # 2 sentences, 5 tokens each
input_size, hidden_size = 10, 8   # token embedding dim, GRU hidden size

# Random stand-in for a batch of embedded sentences: [batch, seq_len, input_size]
x = torch.randn((batch_size, seq_len, input_size))

# num_layers is left at its default of 1
model = BiGRUModel(input_size, hidden_size)
output, hidden = model(x)

# Expected sizes for this configuration:
#   input  -> [2, 5, 10]
#   output -> [2, 5, 16]  (hidden_size * 2 because the GRU is bidirectional)
#   hidden -> [2, 2, 8]   (1 layer * 2 directions, batch=2, hidden=8)
print("Input shape:", x.size())
print("Output shape:", output.size())
print("Hidden shape:", hidden.size())
