In [None]:
import torch
import torch.nn as nn

class TextRNN(nn.Module):
    """LSTM-based text classifier.

    Token ids are embedded, run through a (optionally bidirectional) stacked
    LSTM, and the top layer's final hidden state is projected to class logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Embedding dimension.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        bidirectional: If True, use a bidirectional LSTM and classify from the
            concatenation of the forward and backward final states.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, num_classes, bidirectional=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, hidden_size, num_layers, bidirectional=bidirectional, batch_first=True)
        # A bidirectional LSTM yields one hidden_size vector per direction, so the
        # classifier input is twice as wide.
        self.fc = nn.Linear(hidden_size * (2 if bidirectional else 1), num_classes)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        Args:
            x: Token ids of shape (batch_size, seq_len) (batch_first=True).

        Returns:
            Logits of shape (batch_size, num_classes).

        Note:
            No sequence lengths are taken, so with padded batches the final
            state may include padding positions.
        """
        x = self.embedding(x)  # (batch_size, seq_len) -> (batch_size, seq_len, embed_size)
        # h_n: (num_layers * num_directions, batch_size, hidden_size), ordered
        # layer by layer with the forward direction before the backward one.
        _, (h_n, _) = self.rnn(x)

        if self.rnn.bidirectional:
            # out[:, -1, :] would pair the forward state after the whole sequence
            # with the backward state after only the last token. The top layer's
            # true final states are h_n[-2] (forward, finished at the last step)
            # and h_n[-1] (backward, finished at the first step).
            features = torch.cat((h_n[-2], h_n[-1]), dim=1)  # (batch_size, hidden_size * 2)
        else:
            # Top layer's final state; identical to out[:, -1, :].
            features = h_n[-1]  # (batch_size, hidden_size)

        return self.fc(features)  # (batch_size, num_classes)

