In [1]:
# 1. 环境准备
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_from_disk
import json
import os
import numpy as np


In [2]:
# 2. Load data
# Paths: dataset location (passed to load_from_disk), vocabulary JSON file,
# and the file the model weights are saved to / loaded from
data_path = "./data/data"
vocab_path = "./data/vocab.json"
model_save_path = "./model/Poetry_Transformer_model.pt"

# Device selection: use CUDA when it is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the dataset stored at data_path and take its 'poetry' entry,
# which is iterated below as a collection of poem strings
data = load_from_disk(data_path)
poetry = data['poetry']

# Load the vocabulary
def load_vocab(vocab_path):
    """Read the UTF-8 encoded JSON file at ``vocab_path`` and return the parsed object."""
    with open(vocab_path, mode='r', encoding='utf-8') as vocab_file:
        raw_text = vocab_file.read()
    return json.loads(raw_text)

# Parse the vocabulary file and pull out both lookup tables
vocab = load_vocab(vocab_path)
char_to_id = vocab['char_to_id']
id_to_char = vocab['id_to_char']
# Vocabulary size is the number of entries in char_to_id
vocab_size = len(char_to_id)

# Length of the longest poem in the dataset
max_len = max(len(s) for s in poetry)
# Report vocabulary size and longest poem length (the message text is Chinese)
print(f"词汇表大小: {vocab_size}, 最长诗词长度: {max_len}")


词汇表大小: 9050, 最长诗词长度: 48


In [3]:
# 3. Prepare the data
class PoetryDataset(Dataset):
    """Dataset of fixed-length (input, label) ID sequences, one pair per poem.

    For every text the input is ``<SOS>`` followed by the character IDs and the
    label is the character IDs followed by ``<EOS>``. Both are padded with
    ``<PAD>`` or cut so that they contain exactly ``max_length`` entries.
    """

    def __init__(self, dataset, vocab, max_length=48):
        self.max_length = max_length
        self.char_to_id = vocab['char_to_id']

        # First pass: model inputs (<SOS> + text)
        raw_inputs = []
        for text in dataset:
            raw_inputs.append([self.char_to_id['<SOS>']] + self.text_to_id(text))

        # Second pass: training targets (text + <EOS>)
        raw_labels = []
        for text in dataset:
            raw_labels.append(self.text_to_id(text) + [self.char_to_id['<EOS>']])

        # Bring every sequence to the same length
        self.inputs = self.pad_sequences(raw_inputs)
        self.labels = self.pad_sequences(raw_labels)

    def text_to_id(self, text):
        """Convert text to a list of IDs; characters not in the vocabulary map to ``<UNK>``."""
        ids = []
        for ch in text:
            ids.append(self.char_to_id.get(ch, self.char_to_id['<UNK>']))
        return ids

    def pad_sequences(self, sequences):
        """Right-pad short sequences with ``<PAD>`` and cut long ones to ``max_length``."""
        fixed = []
        for seq in sequences:
            shortfall = self.max_length - len(seq)
            if shortfall > 0:
                fixed.append(seq + [self.char_to_id['<PAD>']] * shortfall)
            else:
                fixed.append(seq[:self.max_length])
        return fixed

    def __len__(self):
        """Number of samples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of two ``torch.long`` tensors."""
        input_ids = torch.tensor(self.inputs[idx], dtype=torch.long)
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        return {'input_ids': input_ids, 'labels': labels}

# Build the training set. Each sequence is a poem plus one marker token
# (<SOS> or <EOS>), so max_length = max_len + 2 means nothing gets truncated.
train_dataset = PoetryDataset(poetry, vocab, max_length=max_len+2)
# Shuffled mini-batches of 16 samples
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)


In [4]:
# 4. Transformer poetry-generation model
class TransformerPoetry(nn.Module):
    """Causal (left-to-right) Transformer language model over characters.

    Pipeline: token embedding + fixed sinusoidal positional encoding ->
    stack of ``nn.TransformerEncoderLayer`` with a causal attention mask ->
    linear projection to vocabulary logits.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        embed_dim: Embedding / model width.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer.
        num_layers: Number of encoder layers.
        max_seq_len: Longest sequence the positional encoding supports.
            Defaults to the notebook-level ``max_len + 2``.
        pad_idx: Index of the padding token. Defaults to the notebook-level
            ``char_to_id['<PAD>']``.

    Notes:
        * The encoder layers are built with ``batch_first=True`` because
          ``forward`` receives ``(batch, seq)`` token IDs. Without it the
          encoder would treat the batch axis as the sequence axis.
        * A causal mask is applied so that position ``t`` only attends to
          positions ``<= t``, which next-token training requires.
        * The positional encoding is a non-persistent buffer: it follows
          ``model.to(device)`` but is not stored in ``state_dict``, so the set
          of checkpoint keys is the same as before. Weights trained without
          ``batch_first`` / the causal mask still load but should be retrained.
    """

    def __init__(self, vocab_size, embed_dim=256, num_heads=8, ff_dim=512, num_layers=6,
                 max_seq_len=None, pad_idx=None):
        super().__init__()
        # Fall back to the notebook-level globals when not given explicitly
        if pad_idx is None:
            pad_idx = char_to_id['<PAD>']
        if max_seq_len is None:
            max_seq_len = max_len + 2

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.register_buffer(
            'positional_encoding',
            self.create_positional_encoding(embed_dim, max_seq_len),
            persistent=False,
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            batch_first=True,  # inputs are (batch, seq, embed_dim)
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc = nn.Linear(embed_dim, vocab_size)

    def create_positional_encoding(self, embed_dim, max_len):
        """Build the fixed sinusoidal positional encoding.

        Returns:
            Tensor of shape ``(1, max_len, embed_dim)``; even columns hold
            sines and odd columns hold cosines.
        """
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2, dtype=torch.float32) * -(np.log(10000.0) / embed_dim)
        )
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd embed_dim there is one fewer cosine column than sine columns
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: ``(batch, seq)`` tensor of token IDs.

        Returns:
            ``(batch, seq, vocab_size)`` tensor of logits.

        Raises:
            ValueError: If ``seq`` exceeds the positional-encoding length.
        """
        seq_len = x.size(1)
        max_supported = self.positional_encoding.size(1)
        if seq_len > max_supported:
            raise ValueError(
                f"sequence length {seq_len} exceeds the supported maximum {max_supported}"
            )

        x = self.embedding(x) + self.positional_encoding[:, :seq_len, :]
        # -inf strictly above the diagonal: a position may not attend to later ones
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), dtype=x.dtype, device=x.device),
            diagonal=1,
        )
        x = self.encoder(x, mask=causal_mask)
        return self.fc(x)


In [5]:
# 5. Train the model
def train_transformer():
    """Train ``TransformerPoetry`` on ``train_loader`` and checkpoint the weights.

    Uses the notebook-level ``vocab_size``, ``device``, ``char_to_id``,
    ``train_loader`` and ``model_save_path``. If a checkpoint already exists at
    ``model_save_path`` training resumes from it. The weights are written to
    ``model_save_path`` at the end of every epoch, so interrupting a run does
    not discard the epochs already completed.

    Loss is cross-entropy with ``<PAD>`` targets ignored; the learning rate is
    halved by ``ReduceLROnPlateau`` when the mean epoch loss stops improving.
    """
    model = TransformerPoetry(vocab_size).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    loss_fn = nn.CrossEntropyLoss(ignore_index=char_to_id['<PAD>'])
    num_epochs = 50

    if os.path.isfile(model_save_path):
        print(f"加载已训练模型 {model_save_path}")
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine
        model.load_state_dict(torch.load(model_save_path, map_location=device))

    # Create the checkpoint directory up front so saving cannot fail after training
    os.makedirs(os.path.dirname(model_save_path) or ".", exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            output = model(input_ids)

            # Flatten (batch, seq, vocab) logits and (batch, seq) targets for the loss
            loss = loss_fn(output.view(-1, vocab_size), labels.view(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        scheduler.step(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

        # Checkpoint every epoch so an interrupted run keeps its progress
        torch.save(model.state_dict(), model_save_path)

    print("模型训练完成并保存!")

# Run the training loop defined above (50 epochs as configured in train_transformer)
train_transformer()


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


KeyboardInterrupt: 

In [None]:
# 6. Poetry generation
def generate_poetry(model, vocab, start_text="", max_length=50, temperature=0.8, top_k=10):
    """Sample a poem from ``model`` one character at a time.

    Args:
        model: Module mapping ``(1, seq)`` token IDs to ``(1, seq, vocab)`` logits.
        vocab: Dict with ``'char_to_id'`` (str -> int) and ``'id_to_char'``
            (stringified int -> str) tables.
        start_text: Prompt the poem starts with; unknown characters map to ``<UNK>``.
        max_length: Upper bound on the length of the returned text. The model
            input (``<SOS>`` + text so far) never exceeds ``max_length`` tokens,
            so it should not be larger than what the model supports.
        temperature: Logit divisor; must be positive. Lower values sharpen the
            distribution.
        top_k: Sample only among the ``top_k`` most probable tokens (clamped to
            the vocabulary size).

    Returns:
        ``start_text`` followed by the generated characters. Generation stops
        at ``<EOS>`` or once ``max_length`` characters are reached.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()
    char_to_id = vocab['char_to_id']
    id_to_char = vocab['id_to_char']

    # Run on whatever device the model's weights live on (CPU if it has none)
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    input_ids = [char_to_id['<SOS>']] + [char_to_id.get(c, char_to_id['<UNK>']) for c in start_text]
    generated_text = start_text

    with torch.no_grad():
        for _ in range(max_length - len(start_text)):
            input_tensor = torch.tensor([input_ids], dtype=torch.long, device=run_device)
            logits = model(input_tensor)[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)

            # torch.topk fails when k exceeds the vocabulary size
            k = min(top_k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
            # multinomial accepts unnormalised weights, so no renormalisation is needed
            choice = torch.multinomial(topk_probs.squeeze(0), 1).item()
            predicted_id = topk_indices.squeeze(0)[choice].item()

            if predicted_id == char_to_id['<EOS>']:
                break
            generated_text += id_to_char[str(predicted_id)]
            # Extend the running context so the next step sees every generated token
            input_ids.append(predicted_id)

    return generated_text

# Generation example: rebuild the model and load the saved weights
model = TransformerPoetry(vocab_size).to(device)
# map_location=device remaps the checkpoint tensors onto the current device
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.eval()

# Continue the two-character prompt given as start_text, with generation
# capped at max_len + 2, and print the result with a Chinese label
output_poetry = generate_poetry(model, vocab, start_text="春风", max_length=max_len+2)
print("生成的诗歌:", output_poetry)