RNN的基本原理是通过循环来传递隐藏状态信息，从而实现对序列数据的建模。

In [2]:
sentence = "我 喜欢 学习 深度 学习"
words = sentence.split()
word2idx = {word: idx for idx, word in enumerate(set(words))}
idx2word = {idx: word for word, idx in word2idx.items()}
vocab_size = len(word2idx)
input_seq = [] # 输入词索引
target_seq = [] # 目标词索引

for i in range(len(words) - 1):
    input_seq.append(word2idx[words[i]])
    target_seq.append(word2idx[words[i + 1]])

print(idx2word, input_seq, target_seq)

{0: '喜欢', 1: '我', 2: '学习', 3: '深度'} [1, 0, 2, 3] [0, 2, 3, 2]


In [3]:
import torch

input_tensor = torch.tensor(input_seq)     # shape: (seq_len,)
target_tensor = torch.tensor(target_seq)   # shape: (seq_len,)
print(input_tensor, target_tensor)

tensor([1, 0, 2, 3]) tensor([0, 2, 3, 2])


In [5]:
import torch.nn as nn
class RNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(input_size = embedding_dim, hidden_size=hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size) # 输出词表中每个词的概率分布

    def forward(self, x):
        x = self.embedding(x)                # (batch, seq_len) → (batch, seq_len, embed_dim)
        output, hidden = self.rnn(x)         # output: (batch, seq_len, hidden_dim)
        logits = self.fc(output)             # (batch, seq_len, vocab_size)
        return logits

In [6]:
model = RNN(vocab_size=vocab_size, embedding_dim=16, hidden_dim=32)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):  # 假设训练 100 轮
    optimizer.zero_grad()             # 清空梯度
    output = model(input_tensor)      # 前向传播，输出 shape: (1, seq_len, vocab_size)

    # 注意：loss 函数要的输入是 (N, C)，目标是 (N,)
    output = output.view(-1, vocab_size)        # reshape 成 (seq_len, vocab_size)
    target = target_tensor.view(-1)             # reshape 成 (seq_len,)

    loss = loss_fn(output, target)    # 计算交叉熵损失
    loss.backward()                   # 反向传播
    optimizer.step()                  # 更新参数

    if epoch % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))


Epoch: 0001 cost = 1.371842
Epoch: 0011 cost = 0.064620
Epoch: 0021 cost = 0.006817
Epoch: 0031 cost = 0.002382
Epoch: 0041 cost = 0.001422
Epoch: 0051 cost = 0.001089
Epoch: 0061 cost = 0.000930
Epoch: 0071 cost = 0.000834
Epoch: 0081 cost = 0.000763
Epoch: 0091 cost = 0.000705
