In [1]:
sentence = "我 喜欢 学习 深度 学习"
words = sentence.split()
word2idx = {word: idx for idx, word in enumerate(set(words))}
idx2word = {idx: word for word, idx in word2idx.items()}
vocab_size = len(word2idx)
input_seq = [] # 输入词索引
target_seq = [] # 目标词索引

for i in range(len(words) - 1):
    input_seq.append(word2idx[words[i]])
    target_seq.append(word2idx[words[i + 1]])

print(idx2word, input_seq, target_seq)

{0: '我', 1: '学习', 2: '喜欢', 3: '深度'} [0, 2, 1, 3] [2, 1, 3, 1]


In [2]:
import torch

input_tensor = torch.tensor(input_seq)     # shape: (seq_len,)
target_tensor = torch.tensor(target_seq)   # shape: (seq_len,)

In [3]:
import torch
import torch.nn as nn

class SimpleGRU(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(SimpleGRU, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)               # shape: (batch, seq_len, embed_dim)
        output, h_n = self.gru(x)           # output: (batch, seq_len, hidden_dim)
        logits = self.fc(output)            # shape: (batch, seq_len, vocab_size)
        return logits

input_tensor = input_tensor.unsqueeze(0)     # (1, seq_len)
target_tensor = target_tensor.unsqueeze(0)   # (1, seq_len)
model = SimpleGRU(vocab_size=vocab_size, embedding_dim=16, hidden_dim=32)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    optimizer.zero_grad()
    output = model(input_tensor)

    output = output.view(-1, vocab_size)
    target = target_tensor.view(-1)

    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))


Epoch: 0001 cost = 1.467389
Epoch: 0011 cost = 0.194275
Epoch: 0021 cost = 0.009332
Epoch: 0031 cost = 0.001860
Epoch: 0041 cost = 0.000920
Epoch: 0051 cost = 0.000670
Epoch: 0061 cost = 0.000570
Epoch: 0071 cost = 0.000516
Epoch: 0081 cost = 0.000480
Epoch: 0091 cost = 0.000451
