In [1]:
# 加载数据
import torch
from datasets import load_from_disk
# Configuration: dataset location, batch size and checkpoint path.
data_path = "./data/data"
batch_size = 8
final_model_path = "./model/Poetry_LSTM_model.pt"

# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the saved dataset and take its 'poetry' field (poetry[0] is a single poem string).
data = load_from_disk(data_path)
poetry = data['poetry']
poetry[0]

'欲出未出光辣达，千山万山如火发。须臾走向天上来，逐却残星赶却月。'

In [2]:
import json
# Load the vocabulary file
def load_vocab(vocab_path):
    """Read a UTF-8 JSON vocabulary file and return the decoded object.

    Later cells index the result with 'char_to_id' and 'id_to_char'.
    """
    with open(vocab_path, encoding='utf-8') as handle:
        return json.load(handle)

# Load the character vocabulary ('char_to_id' / 'id_to_char') used by the later cells.
vocab_path = "./data/vocab.json"
vocab = load_vocab(vocab_path)
print("词汇表长度:{}".format(len(vocab["char_to_id"])))

词汇表长度:9050


In [3]:
# Length (in characters) of the longest poem; used below to size the padded sequences.
max_len = max(map(len, poetry))
print(max_len)

48


In [4]:
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    """Character-level language-modelling dataset built from a collection of poems.

    Every poem is converted to vocabulary ids and turned into a shifted
    (teacher-forcing) pair::

        input_ids = [<SOS>] + ids
        labels    = ids + [<EOS>]

    Both sequences are right-padded with the <PAD> id, or truncated, to exactly
    ``max_length`` entries.

    Args:
        dataset: iterable of poem strings.
        vocab: dict holding 'char_to_id' and 'id_to_char' mappings; 'char_to_id' is
            expected to contain '<SOS>', '<EOS>', '<PAD>' and '<UNK>'.
        max_length: fixed length of every stored input/label sequence.
    """

    def __init__(self, dataset, vocab, max_length=48):
        self.max_length = max_length
        self.char_to_id = vocab['char_to_id']
        self.id_to_char = vocab['id_to_char']

        encoded = [self.text_to_id(poem) for poem in dataset]
        # Shift by one position: the model reads <SOS>+poem and must predict poem+<EOS>.
        inputs = [[self.char_to_id['<SOS>'], *ids] for ids in encoded]
        labels = [[*ids, self.char_to_id['<EOS>']] for ids in encoded]

        # Inputs and labels are both padded with the <PAD> id (not a literal 0);
        # the training loss is configured to ignore that id.
        pad_id = self.char_to_id['<PAD>']
        self.inputs = self.pad_sequence(inputs, pad_value=pad_id)
        self.labels = self.pad_sequence(labels, pad_value=pad_id)

    def text_to_id(self, text):
        """Map each character of ``text`` to its id, substituting the <UNK> id for unknown ones."""
        ids = []
        for ch in text:
            ids.append(self.char_to_id.get(ch, self.char_to_id['<UNK>']))
        return ids

    def pad_sequence(self, sequences, pad_value=0):
        """Right-pad each sequence with ``pad_value`` up to ``max_length``; truncate longer ones."""
        padded = []
        for seq in sequences:
            shortfall = self.max_length - len(seq)
            if shortfall > 0:
                padded.append(seq + [pad_value] * shortfall)
            else:
                padded.append(seq[:self.max_length])
        return padded

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return one example as {'input_ids': tensor, 'labels': tensor} of integer ids."""
        return {
            'input_ids': torch.tensor(self.inputs[idx]),
            'labels': torch.tensor(self.labels[idx]),
        }


# Pad every sequence to max_len + 2 positions. Inputs/labels are at most max_len + 1 long
# (the poem plus <SOS> or <EOS>), so no poem gets truncated.
train_dataset = CustomDataset(poetry, vocab, max_length=max_len + 2)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


In [5]:
import torch.nn as nn
import torch.optim as optim
import numpy as np

# LSTM language model
class LSTMTextGenerator(nn.Module):
    """Character-level LSTM language model.

    The pipeline is embedding -> stacked LSTM (batch_first) -> linear projection to vocab logits.

    Args:
        vocab_size: number of tokens; rows of the embedding table and width of the logits.
        embed_size: embedding dimension.
        hidden_size: size of the LSTM hidden and cell states.
        num_layers: number of stacked LSTM layers (default 2, the previously hard-coded value).
        padding_idx: embedding row kept at zero and excluded from gradient updates.
            When None, it falls back to the notebook-global ``vocab['char_to_id']['<PAD>']``
            (the original behaviour).
    """

    def __init__(self, vocab_size, embed_size=128, hidden_size=256, num_layers=2, padding_idx=None):
        super().__init__()
        if padding_idx is None:
            # Backward-compatible default: read the <PAD> id from the notebook-level vocab.
            padding_idx = vocab['char_to_id']['<PAD>']
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Run the model over a batch of token ids.

        Args:
            x: integer token ids, shape (batch, seq_len) since the LSTM is batch_first.
            hidden: optional (h0, c0) pair, each of shape (num_layers, batch, hidden_size).
                A zero state is used when it is None.

        Returns:
            (logits, (h_n, c_n)) where logits has shape (batch, seq_len, vocab_size).
        """
        if hidden is None:
            # Build the zero state directly on the input's device, in the model's dtype.
            # The old code built it on the CPU with numpy and moved it to the global
            # ``device``. That failed whenever the model lived on a different device,
            # and it cost a host->device copy on every call.
            state_shape = (self.num_layers, x.shape[0], self.hidden_size)
            h0 = torch.zeros(state_shape, device=x.device, dtype=self.embedding.weight.dtype)
            c0 = torch.zeros_like(h0)
        else:
            h0, c0 = hidden
        embedded = self.embedding(x)
        output, hidden = self.lstm(embedded, (h0, c0))
        logits = self.fc(output)
        return logits, hidden


In [6]:
# Training routine (text generation is not part of this cell)
import os


def train_lstm_model(epoches=20, lr=5e-4, log_every=500):
    """Train, or resume training of, the poetry LSTM and checkpoint it to ``final_model_path``.

    Relies on notebook globals from earlier cells: ``vocab``, ``device``, ``train_loader``,
    ``final_model_path``, ``LSTMTextGenerator``, ``optim`` and ``nn``.

    If a checkpoint already exists at ``final_model_path``, its weights are loaded first,
    so repeated runs continue from the previous result. The optimizer and LR schedule
    always start fresh.

    Args:
        epoches: number of passes over ``train_loader`` (default 20).
        lr: initial Adam learning rate; StepLR halves it every 5 epochs.
        log_every: print the current batch loss every ``log_every`` batches.

    Returns:
        The trained model. It is also saved to ``final_model_path`` after every epoch.
    """
    vocab_size = len(vocab['char_to_id'])
    pad_id = vocab['char_to_id']['<PAD>']

    model = LSTMTextGenerator(vocab_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)  # padded label positions are ignored

    # Resume from an existing checkpoint. map_location lets a checkpoint written on a GPU
    # machine load on a CPU-only one; without it torch.load restores tensors to the
    # device they were saved from and fails.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    if os.path.isfile(final_model_path):
        print(f"发现模型文件 {final_model_path}，正在加载...")
        model.load_state_dict(torch.load(final_model_path, map_location=device))
        print("模型加载完成")

    # torch.save does not create missing directories. Create it now rather than
    # failing only after the whole training run.
    checkpoint_dir = os.path.dirname(final_model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    num_batches = len(train_loader)
    for epoch in range(epoches):
        model.train()
        total_loss = 0.0
        for num, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            logits, _ = model(input_ids)
            # Flatten (batch, seq_len, vocab) -> (batch*seq_len, vocab) for CrossEntropyLoss.
            loss = loss_fn(logits.reshape(-1, vocab_size), labels.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            if num % log_every == 0:
                print(f"\tBatch {num}/{num_batches}, Loss: {loss.item():.4f}")
        scheduler.step()
        print(f"Epoch {epoch + 1}/{epoches}, Loss: {total_loss / max(num_batches, 1):.4f}")
        # Checkpoint after every epoch so an interrupted run keeps its progress.
        torch.save(model.state_dict(), final_model_path)
    return model


trained_model = train_lstm_model()

发现模型文件 ./model/Poetry_LSTM_model.pt，正在加载...
模型加载完成
	Batch 0/22849, Loss: 4.8245
	Batch 500/22849, Loss: 4.7323
	Batch 1000/22849, Loss: 5.0995
	Batch 1500/22849, Loss: 4.8943
	Batch 2000/22849, Loss: 4.7676
	Batch 2500/22849, Loss: 5.1930
	Batch 3000/22849, Loss: 4.7066
	Batch 3500/22849, Loss: 5.0375
	Batch 4000/22849, Loss: 4.9793
	Batch 4500/22849, Loss: 5.0903
	Batch 5000/22849, Loss: 4.6616
	Batch 5500/22849, Loss: 4.5518
	Batch 6000/22849, Loss: 4.6911
	Batch 6500/22849, Loss: 4.9674
	Batch 7000/22849, Loss: 4.5144
	Batch 7500/22849, Loss: 4.8321
	Batch 8000/22849, Loss: 4.4702
	Batch 8500/22849, Loss: 4.6591
	Batch 9000/22849, Loss: 4.9835
	Batch 9500/22849, Loss: 4.6525
	Batch 10000/22849, Loss: 4.6228
	Batch 10500/22849, Loss: 4.8135
	Batch 11000/22849, Loss: 4.9328
	Batch 11500/22849, Loss: 4.9268
	Batch 12000/22849, Loss: 4.9881
	Batch 12500/22849, Loss: 4.9302
	Batch 13000/22849, Loss: 4.5632
	Batch 13500/22849, Loss: 4.7830
	Batch 14000/22849, Loss: 4.8030
	Batch 14500/228