In [1]:
import torch
import torch.nn as nn

In [None]:
class BiRNN(nn.Module):
    """Bidirectional GRU built from two independent unidirectional GRU stacks.

    One GRU reads the sequence in its original order. A second GRU reads a
    time-reversed copy, and its outputs are flipped back into the original
    time order. At every time step the two outputs are concatenated and
    projected by a linear layer.

    Note: this is not identical to ``nn.GRU(..., bidirectional=True)`` when
    ``num_layers > 1``. Here each direction is a separate ``num_layers``-deep
    stack, and the directions never exchange information between layers.

    Args:
        input_size: Number of features per time step in the input.
        hidden_size: Hidden size of each directional GRU.
        num_layers: Number of stacked layers in each directional GRU.
        output_size: Number of features produced per time step by ``fc``.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_forward = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.rnn_backward = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        # Input width is 2 * hidden_size because the forward and backward
        # outputs of each time step are concatenated.
        self.fc = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        """Run both directions over ``x`` and project every time step.

        Args:
            x: Batched input of shape ``(batch, seq_len, input_size)``, or an
               unbatched sequence of shape ``(seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, output_size)``, or
            ``(seq_len, output_size)`` for unbatched input.

        NOTE(review): the whole time axis is reversed, padding included. For
        right-padded variable-length batches, the padding is therefore read
        first by the backward GRU. Handle packing/masking in the caller if
        that matters.
        """
        # With batch_first=True, time is the second-to-last axis for both
        # batched (batch, seq, feat) and unbatched (seq, feat) input.
        time_dim = -2

        # No explicit initial hidden state is passed. nn.GRU then uses zeros
        # with the same dtype and device as x. A hand-built float32 zeros
        # tensor would break models that run in float64/float16.
        out_forward, _ = self.rnn_forward(x)

        # Backward direction: reverse time, run the GRU, then restore the
        # original order so that step t of both outputs lines up.
        out_backward, _ = self.rnn_backward(torch.flip(x, [time_dim]))
        out_backward = torch.flip(out_backward, [time_dim])

        # Concatenate along the feature axis, then project each time step.
        out = torch.cat((out_forward, out_backward), dim=-1)
        return self.fc(out)