In [8]:
import torch
import torch.nn as nn
import torch.optim as optim

# Assuming vocab_size is the size of the vocabulary
# embedding_dim, hidden_dim, and output_dim are hyperparameters to be defined

class Seq2SeqBiLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, output_dim):
        super(Seq2SeqBiLSTM, self).__init__()

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Encoder LSTM
        self.encoder = nn.LSTM(embedding_dim, encoder_hidden_dim, batch_first=True)

        # Decoder LSTM
        self.decoder = nn.LSTM(embedding_dim, decoder_hidden_dim, batch_first=True)

        # BiLSTM layer
        self.bilstm = nn.LSTM(decoder_hidden_dim, decoder_hidden_dim, batch_first=True, bidirectional=True)

        # Dense layer and Softmax for output
        self.fc = nn.Linear(decoder_hidden_dim * 2, output_dim)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_seq):
        # Input to embedding layer
        embedded = self.embedding(input_seq)

        # Encoder LSTM
        _, (hidden, cell) = self.encoder(embedded)

        # Decoder LSTM with encoder's final hidden state as initial hidden state
        decoder_output, _ = self.decoder(embedded, (hidden, cell))

        # BiLSTM layer
        bilstm_output, _ = self.bilstm(decoder_output)

        # Dense layer and softmax
        output = self.softmax(self.fc(bilstm_output))

        return output

# Example usage
model = Seq2SeqBiLSTM(vocab_size=10000, embedding_dim=256, encoder_hidden_dim=128, decoder_hidden_dim=128, output_dim=10000)
input_seq = torch.randint(0, 10000, (1, 10)) # Example input sequence of length 10
output = model(input_seq)


In [9]:
print(output)

tensor([[[-2.2818, -2.3091, -2.3095,  ..., -2.3024, -2.3050, -2.3151],
         [-2.2804, -2.2994, -2.3000,  ..., -2.2999, -2.3327, -2.2951],
         [-2.2882, -2.3082, -2.3193,  ..., -2.2965, -2.3286, -2.2981],
         ...,
         [-2.3025, -2.3005, -2.2908,  ..., -2.3014, -2.2856, -2.2946],
         [-2.3041, -2.3021, -2.2943,  ..., -2.3050, -2.3008, -2.3110],
         [-2.2867, -2.3122, -2.2962,  ..., -2.3263, -2.2943, -2.3123]]],
       grad_fn=<LogSoftmaxBackward0>)
