In [1]:
sentence = "我 喜欢 学习 深度 学习"
words = sentence.split()
word2idx = {word: idx for idx, word in enumerate(set(words))}
idx2word = {idx: word for word, idx in word2idx.items()}
vocab_size = len(word2idx)
input_seq = [] # 输入词索引
target_seq = [] # 目标词索引

for i in range(len(words) - 1):
    input_seq.append(word2idx[words[i]])
    target_seq.append(word2idx[words[i + 1]])

print(idx2word, input_seq, target_seq)

{0: '深度', 1: '我', 2: '喜欢', 3: '学习'} [1, 2, 3, 0] [2, 3, 0, 3]


In [3]:
import torch

input_tensor = torch.tensor(input_seq)     # shape: (seq_len,)
target_tensor = torch.tensor(target_seq)   # shape: (seq_len,)

In [4]:
import torch
import torch.nn as nn

class SimpleLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(SimpleLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)                  # shape: (batch, seq_len, embedding_dim)
        output, (hn, cn) = self.lstm(embedded)        # output: (batch, seq_len, hidden_dim)
        logits = self.fc(output)                      # shape: (batch, seq_len, vocab_size)
        return logits

In [6]:
input_tensor = input_tensor.unsqueeze(0)     # shape: (1, seq_len)
target_tensor = target_tensor.unsqueeze(0)   # shape: (1, seq_len)
model = SimpleLSTM(vocab_size=vocab_size, embedding_dim=16, hidden_dim=32)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    optimizer.zero_grad()
    output = model(input_tensor)

    output = output.view(-1, vocab_size)
    target = target_tensor.view(-1)

    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

Epoch: 0001 cost = 1.472496
Epoch: 0011 cost = 0.349428
Epoch: 0021 cost = 0.033909
Epoch: 0031 cost = 0.007014
Epoch: 0041 cost = 0.003042
Epoch: 0051 cost = 0.001956
Epoch: 0061 cost = 0.001527
Epoch: 0071 cost = 0.001302
Epoch: 0081 cost = 0.001155
Epoch: 0091 cost = 0.001044
