In [14]:
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# 定义特殊标记
SPECIAL_TOKENS = ['<pad>', '<bos>', '<eos>', '<unk>']

# 读取数据
def read_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]

def build_vocab(data, min_freq=1):
    vocab = {token: idx for idx, token in enumerate(SPECIAL_TOKENS)}
    word_freq = {}
    for line in data:
        for char in line:
            word_freq[char] = word_freq.get(char, 0) + 1
    
    idx = len(vocab)
    for word, freq in word_freq.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab

##### 数据集：https://github.com/chinese-poetry/chinese-poetry

In [4]:
# 读取数据
data = read_data('poetry.txt')

# 构建词典
vocab = build_vocab(data, min_freq=2)
for token, idx in vocab.items():
    print(f'Token: {token} Index: {idx}')

print('Vocab size:', len(vocab))

Token: <pad> Index: 0
Token: <bos> Index: 1
Token: <eos> Index: 2
Token: <unk> Index: 3
Token: 圣 Index: 4
Token: 朝 Index: 5
Token: 兴 Index: 6
Token: 运 Index: 7
Token: 自 Index: 8
Token: 天 Index: 9
Token: 开 Index: 10
Token: ， Index: 11
Token: 又 Index: 12
Token: 直 Index: 13
Token: 临 Index: 14
Token: 轩 Index: 15
Token: 策 Index: 16
Token: 草 Index: 17
Token: 莱 Index: 18
Token: 。 Index: 19
Token: 廷 Index: 20
Token: 对 Index: 21
Token: 惭 Index: 22
Token: 无 Index: 23
Token: 宿 Index: 24
Token: 构 Index: 25
Token: 胪 Index: 26
Token: 传 Index: 27
Token: 何 Index: 28
Token: 意 Index: 29
Token: 冠 Index: 30
Token: 群 Index: 31
Token: 魁 Index: 32
Token: 幸 Index: 33
Token: 瞻 Index: 34
Token: 北 Index: 35
Token: 阙 Index: 36
Token: 承 Index: 37
Token: 殊 Index: 38
Token: 宠 Index: 39
Token: 忍 Index: 40
Token: 负 Index: 41
Token: 南 Index: 42
Token: 山 Index: 43
Token: 咏 Index: 44
Token: 有 Index: 45
Token: 台 Index: 46
Token: 稽 Index: 47
Token: 首 Index: 48
Token: 君 Index: 49
Token: 恩 Index: 50
Token: 难 Index: 51
Token:

In [5]:
class PoetryDataset(Dataset):
    def __init__(self, data, vocab):
        self.data = data
        self.vocab = vocab
        self.unk_idx = vocab['<unk>']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        line = self.data[idx]
        input_seq, output_seq = line.split('，', 1)
        
        input_seq = [self.vocab['<bos>']] + [self.vocab.get(char, self.unk_idx) for char in input_seq] + [self.vocab['<eos>']]
        output_seq = [self.vocab['<bos>']] + [self.vocab.get(char, self.unk_idx) for char in output_seq] + [self.vocab['<eos>']]
        
        return torch.tensor(input_seq), torch.tensor(output_seq)

# 动态padding
def collate_fn(batch):
    input_seqs, output_seqs = zip(*batch)
    input_seqs = pad_sequence(input_seqs, batch_first=True, padding_value=0)
    output_seqs = pad_sequence(output_seqs, batch_first=True, padding_value=0)
    return input_seqs, output_seqs

In [6]:
# 创建数据集和数据加载器
dataset = PoetryDataset(data, vocab)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True)

for id, (src, tgt) in enumerate(dataloader):
    print('Source shape:', src.shape)
    print('Target shape:', tgt.shape)
    print('Source sequence:', src[0])
    print('Target sequence:', tgt[0])

    if id == 1:
        break

Source shape: torch.Size([32, 9])
Target shape: torch.Size([32, 260])
Source sequence: tensor([   1,  244,  291,  466, 1899, 1501,  437, 1381,    2])
Target sequence: tensor([   1,  812, 1190, 1149, 1213,  135, 1180,  154,   19, 2300,  224,    3,
        2786,  285,  224,  383,   11, 2257,  978,  130,  814,  536,  712,  258,
          19,    2,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,  

In [None]:
# 定义Transformer模型
class PoetryTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout):
        super(PoetryTransformer, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.fc = nn.Linear(d_model, vocab_size)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)

        src_mask = None
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

         # 将 key_padding_mask 转换为浮点类型
        if src_key_padding_mask is not None:
            src_key_padding_mask = src_key_padding_mask.float().masked_fill(
                src_key_padding_mask, float('-inf'))
        if tgt_key_padding_mask is not None:
            tgt_key_padding_mask = tgt_key_padding_mask.float().masked_fill(
                tgt_key_padding_mask, float('-inf'))

        output = self.transformer(src, tgt, src_mask, tgt_mask, 
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)
        return self.fc(output)

# 位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [8]:
model = PoetryTransformer(
    vocab_size=len(vocab),
    d_model=256,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=512,
    dropout=0.1
).to(device)

model

PoetryTransformer(
  (embedding): Embedding(6238, 256)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): Transform

In [9]:
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)
        
        src_key_padding_mask = (src == 0).to(device)
        tgt_key_padding_mask = (tgt == 0).to(device)
        
        optimizer.zero_grad()
        
        output = model(src, tgt[:, :-1], 
                       src_key_padding_mask=src_key_padding_mask,
                       tgt_key_padding_mask=tgt_key_padding_mask[:, :-1])
        
        loss = criterion(output.reshape(-1, output.shape[-1]), tgt[:, 1:].reshape(-1))
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

In [10]:
# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])

# 训练模型
num_epochs = 120
best_loss = float('inf')
for epoch in range(num_epochs):
    loss = train(model, dataloader, optimizer, criterion, device)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}')
    
    # 保存最佳模型
    if loss < best_loss:
        best_loss = loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, 'best_model.pth')
        print(f"Best model saved with loss: {loss:.4f}")

Epoch 1/120, Loss: 6.6828
Best model saved with loss: 6.6828
Epoch 2/120, Loss: 6.3197
Best model saved with loss: 6.3197
Epoch 3/120, Loss: 6.2526
Best model saved with loss: 6.2526
Epoch 4/120, Loss: 6.2097
Best model saved with loss: 6.2097
Epoch 5/120, Loss: 6.1722
Best model saved with loss: 6.1722
Epoch 6/120, Loss: 6.1380
Best model saved with loss: 6.1380
Epoch 7/120, Loss: 6.1094
Best model saved with loss: 6.1094
Epoch 8/120, Loss: 6.0890
Best model saved with loss: 6.0890
Epoch 9/120, Loss: 6.0642
Best model saved with loss: 6.0642
Epoch 10/120, Loss: 6.0459
Best model saved with loss: 6.0459
Epoch 11/120, Loss: 6.0228
Best model saved with loss: 6.0228
Epoch 12/120, Loss: 6.0033
Best model saved with loss: 6.0033
Epoch 13/120, Loss: 5.9745
Best model saved with loss: 5.9745
Epoch 14/120, Loss: 5.9479
Best model saved with loss: 5.9479
Epoch 15/120, Loss: 5.9218
Best model saved with loss: 5.9218
Epoch 16/120, Loss: 5.9023
Best model saved with loss: 5.9023
Epoch 17/120, Los

In [11]:
import torch.nn.functional as F

def generate_poetry(model, vocab, start_sequence, max_length=50, device='cuda'):
    model.eval()
    idx2word = {idx: word for word, idx in vocab.items()}
    input_seq = torch.tensor([[vocab.get(char, vocab['<unk>']) for char in start_sequence]], device=device)
    
    for _ in range(max_length):
        with torch.no_grad():
            output = model(input_seq, input_seq)
            next_token_logits = output[0, -1, :]
            next_token_probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(next_token_probs, 1)
        
        input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
        
        if next_token.item() == vocab['<eos>']:
            break
    
    generated_seq = input_seq.squeeze().tolist()
    poem = ''.join([idx2word[idx] for idx in generated_seq if idx not in [vocab['<pad>'], vocab['<bos>'], vocab['<eos>']]])
    return poem

In [29]:
# 初始化模型（使用与训练时相同的参数）
model = PoetryTransformer(
    vocab_size=len(vocab),
    d_model=256,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=512,
    dropout=0.1
).to(device)

# 加载最佳模型
checkpoint = torch.load('best_model.pth', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# 生成诗词
start_sequence = "古寺重来兴转添"
generated_poem = generate_poetry(model, vocab, start_sequence + "，", max_length=50, device=device)

print("生成的诗词：")
print(generated_poem)

生成的诗词：
古寺重来兴转添，到石泉幽浄简，舟切受峰，相误俗，琢与鹦醴，甘然供。赐缁纷挥三曾释昆轮，岩有邾大岷掀渠
