## 导包

In [3]:
# 导入python库
import numpy as np
import pandas as pd
import math
import re
import tensorflow as tf
from collections import Counter

In [4]:
# 装载云端硬盘，读取训练数据
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 数据预处理

In [5]:
# 数据路径
DATA_PATH = '/content/drive/MyDrive/poetry.txt'
# 单行诗最大长度
MAX_LEN = 64
# 禁用的字符，拥有以下符号的诗将被忽略
DISALLOWED_WORDS = ['（', '）', '(', ')', '__', '《', '》', '【', '】', '[', ']']

# 一首诗（一行）对应一个列表的元素
poetry = []

# 按行读取数据 poetry.txt
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
# 遍历处理每一条数据
for line in lines:
    # 利用正则表达式拆分标题和内容
    fields = re.split(r"[:：]", line)
    # 跳过异常数据
    if len(fields) != 2:
        continue
    # 得到诗词内容（后面不需要标题）
    content = fields[1]
    # 跳过内容过长的诗词
    if len(content) > MAX_LEN - 2:
        continue
    # 跳过存在禁用符的诗词
    if any(word in content for word in DISALLOWED_WORDS):
        continue

    poetry.append(content.replace('\n', '')) # 最后要记得删除换行符

In [6]:
# 查看处理后的数据
for i in range(0, 30):
    print(poetry[i])

寒随穷律变，春逐鸟声开。初风飘带柳，晚雪间花梅。碧林青旧竹，绿沼翠新苔。芝田初雁去，绮树巧莺来。
晚霞聊自怡，初晴弥可喜。日晃百花色，风动千林翠。池鱼跃不同，园鸟声还异。寄言博通者，知予物外志。
夏律昨留灰，秋箭今移晷。峨嵋岫初出，洞庭波渐起。桂白发幽岩，菊黄开灞涘。运流方可叹，含毫属微理。
寒惊蓟门叶，秋发小山枝。松阴背日转，竹影避风移。提壶菊花岸，高兴芙蓉池。欲知凉气早，巢空燕不窥。
山亭秋色满，岩牖凉风度。疏兰尚染烟，残菊犹承露。古石衣新苔，新巢封古树。历览情无极，咫尺轮光暮。
慨然抚长剑，济世岂邀名。星旗纷电举，日羽肃天行。遍野屯万骑，临原驻五营。登山麾武节，背水纵神兵。在昔戎戈动，今来宇宙平。
翠野驻戎轩，卢龙转征旆。遥山丽如绮，长流萦似带。海气百重楼，岩松千丈盖。兹焉可游赏，何必襄城外。
玄兔月初明，澄辉照辽碣。映云光暂隐，隔树花如缀。魄满桂枝圆，轮亏镜彩缺。临城却影散，带晕重围结。驻跸俯九都，停观妖氛灭。
碧原开雾隰，绮岭峻霞城。烟峰高下翠，日浪浅深明。斑红妆蕊树，圆青压溜荆。迹岩劳傅想，窥野访莘情。巨川何以济，舟楫伫时英。
春蒐驰骏骨，总辔俯长河。霞处流萦锦，风前漾卷罗。水花翻照树，堤兰倒插波。岂必汾阴曲，秋云发棹歌。
重峦俯渭水，碧嶂插遥天。出红扶岭日，入翠贮岩烟。叠松朝若夜，复岫阙疑全。对此恬千虑，无劳访九仙。
朝光浮烧野，霜华净碧空。结浪冰初镜，在径菊方丛。约岭烟深翠，分旗霞散红。抽思滋泉侧，飞想傅岩中。已获千箱庆，何以继熏风。
岭衔宵月桂，珠穿晓露丛。蝉啼觉树冷，萤火不温风。花生圆菊蕊，荷尽戏鱼通。晨浦鸣飞雁，夕渚集栖鸿。飒飒高天吹，氛澄下炽空。
萧条起关塞，摇飏下蓬瀛。拂林花乱彩，响谷鸟分声。披云罗影散，泛水织文生。劳歌大风曲，威加四海清。
罩云飘远岫，喷雨泛长河。低飞昏岭腹，斜足洒岩阿。泫丛珠缔叶，起溜镜图波。濛柳添丝密，含吹织空罗。
洁野凝晨曜，装墀带夕晖。集条分树玉，拂浪影泉玑。色洒妆台粉，花飘绮席衣。入扇萦离匣，点素皎残机。
北阙三春晚，南荣九夏初。黄莺弄渐变，翠林花落余。瀑流还响谷，猿啼自应虚。早荷向心卷，长杨就影舒。此时欢不极，调轸坐相于。
红轮不暂驻，乌飞岂复停。岑霞渐渐落，溪阴寸寸生。藿叶随光转，葵心逐照倾。晚烟含树色，栖鸟杂流声。
高轩临碧渚，飞檐迥架空。余花攒镂槛，残柳散雕栊。岸菊初含蕊，园梨始带红。莫虑昆山暗，还共尽杯

In [7]:
# 最小词频
MIN_WORD_FREQUENCY = 8

# 统计词频，利用Counter可以直接按单个字符进行统计词频
counter = Counter()
for line in poetry:
    counter.update(line)
# 过滤掉低词频的词
tokens = [token for token, count in counter.items() if count >= MIN_WORD_FREQUENCY]

In [8]:
# 查看出现频率最高的五个字符
i = 0
for token, count in counter.items():
    if i >= 5:
        break;
    print(token, "->",count)
    i += 1

寒 -> 2627
随 -> 1039
穷 -> 487
律 -> 119
变 -> 286


In [9]:
# 补上特殊词标记：填充字符标记、未知词标记、开始标记、结束标记
tokens = ["[PAD]", "[NONE]", "[START]", "[END]"] + tokens

In [10]:
# 映射: 词 -> 编号
word_idx = {}
# 映射: 编号 -> 词
idx_word = {}
for idx, word in enumerate(tokens):
    word_idx[word] = idx
    idx_word[idx] = word

In [11]:
class Tokenizer:
    """
    分词器
    """

    def __init__(self, tokens):
        # 词汇表大小
        self.dict_size = len(tokens)
        # 生成映射关系
        self.token_id = {} # 映射: 词 -> 编号
        self.id_token = {} # 映射: 编号 -> 词
        for idx, word in enumerate(tokens):
            self.token_id[word] = idx
            self.id_token[idx] = word

        # 各个特殊标记的编号id，方便其他地方使用
        self.start_id = self.token_id["[START]"]
        self.end_id = self.token_id["[END]"]
        self.none_id = self.token_id["[NONE]"]
        self.pad_id = self.token_id["[PAD]"]

    def id_to_token(self, token_id):
        """
        编号 -> 词
        """
        return self.id_token.get(token_id)

    def token_to_id(self, token):
        """
        词 -> 编号
        """
        return self.token_id.get(token, self.none_id)

    def encode(self, tokens):
        """
        词列表 -> [START]编号 + 编号列表 + [END]编号
        """
        token_ids = [self.start_id, ] # 起始标记
        # 遍历，词转编号
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        token_ids.append(self.end_id) # 结束标记
        return token_ids

    def decode(self, token_ids):
        """
        编号列表 -> 词列表(去掉起始、结束标记)
        """
        # 起始、结束标记
        flag_tokens = {"[START]", "[END]"}

        tokens = []
        for idx in token_ids:
            token = self.id_to_token(idx)
            # 跳过起始、结束标记
            if token not in flag_tokens:
                tokens.append(token)
        return tokens

In [12]:
tokenizer = Tokenizer(tokens)

In [13]:
class PoetryDataSet:
    """
    古诗数据集生成器
    """

    def __init__(self, data, tokenizer, batch_size):
        # 数据集
        self.data = data
        self.total_size = len(self.data)
        # 分词器，用于词转编号
        self.tokenizer = tokenizer
        # 每批数据量
        self.batch_size = BATCH_SIZE
        # 每个epoch迭代的步数
        self.steps = int(math.floor(len(self.data) / self.batch_size))

    def pad_line(self, line, length, padding=None):
        """
        对齐单行数据
        """
        if padding is None:
            padding = self.tokenizer.pad_id

        padding_length = length - len(line)
        if padding_length > 0:
            return line + [padding] * padding_length
        else:
            return line[:length]

    def __len__(self):
        return self.steps

    def __iter__(self):
        # 打乱数据
        np.random.shuffle(self.data)
        # 迭代一个epoch，每次yield一个batch
        for start in range(0, self.total_size, self.batch_size):
            end = min(start + self.batch_size, self.total_size)
            data = self.data[start:end]

            max_length = max(map(len, data))

            batch_data = []
            for str_line in data:
                # 对每一行诗词进行编码、并补齐padding
                encode_line = self.tokenizer.encode(str_line)
                pad_encode_line = self.pad_line(encode_line, max_length + 2) # 加2是因为tokenizer.encode会添加START和END
                batch_data.append(pad_encode_line)

            batch_data = np.array(batch_data)
            # yield 特征、标签
            yield batch_data[:, :-1], batch_data[:, 1:]

    def generator(self):
        while True:
            yield from self.__iter__()

In [14]:
BATCH_SIZE = 32
dataset = PoetryDataSet(poetry, tokenizer, BATCH_SIZE)

In [15]:
model = tf.keras.Sequential([
    # 词嵌入层
    tf.keras.layers.Embedding(input_dim=tokenizer.dict_size, output_dim=150),
    # 第一个LSTM层
    tf.keras.layers.LSTM(150, dropout=0.5, return_sequences=True),
    # 第二个LSTM层
    tf.keras.layers.LSTM(150, dropout=0.5, return_sequences=True),
    # 利用TimeDistributed对每个时间步的输出都做Dense操作(softmax激活)
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.dict_size, activation='softmax')),
])

In [23]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.sparse_categorical_crossentropy
)

In [24]:
model.fit(
    dataset.generator(),
    steps_per_epoch=dataset.steps,
    epochs=10
)

Epoch 1/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m498s[0m 641ms/step - loss: 5.3537
Epoch 2/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m470s[0m 614ms/step - loss: 4.3445
Epoch 3/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m474s[0m 618ms/step - loss: 4.1726
Epoch 4/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m472s[0m 615ms/step - loss: 4.0492
Epoch 5/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m463s[0m 604ms/step - loss: 3.9063
Epoch 6/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m461s[0m 602ms/step - loss: 3.8113
Epoch 7/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m466s[0m 607ms/step - loss: 3.7643
Epoch 8/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m464s[0m 605ms/step - loss: 3.7096
Epoch 9/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m470s[0m 613ms/step - loss: 3.6561
Epoch 10/10
[1m767/767[0m [32m━━━━━━━━━━━━━━━━━━━━[

<keras.src.callbacks.history.History at 0x7db556613400>

## 预测

In [25]:
# 需要先将词转为编号
token_ids = [tokenizer.token_to_id(word) for word in ["月", "光", "静", "谧"]]

# 将列表转换为二维NumPy数组
token_ids_array = np.array([token_ids])

# 进行预测
result = model.predict(token_ids_array)
print(result)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 726ms/step
[[[4.7380320e-04 5.1363725e-03 1.9682902e-06 ... 2.3627650e-05
   1.0689867e-05 9.8074124e-06]
  [8.7223598e-05 4.6725199e-03 3.7058129e-08 ... 2.1165872e-06
   2.6165992e-06 5.2159038e-07]
  [1.1590040e-04 4.7531524e-03 5.7887835e-08 ... 4.9152854e-06
   3.5719447e-06 5.9783656e-07]
  [1.8908936e-04 4.4616945e-03 2.1600696e-08 ... 1.6199649e-06
   7.8337126e-07 5.7985267e-07]]]


In [26]:
def predict(model, token_ids):
    """
    在概率值为前100的词中选取一个词(按概率分布的方式)
    :return: 一个词的编号(不包含[PAD][NONE][START])
    """
    # 预测各个词的概率分布
    # -1 表示只要对最新的词的预测
    # 3: 表示不要前面几个标记符
    _probas = model.predict([token_ids, ])[0, -1, 3:]
    # 按概率降序，取前100
    p_args = _probas.argsort()[-100:][::-1] # 此时拿到的是索引
    p = _probas[p_args] # 根据索引找到具体的概率值
    p = p / sum(p) # 归一
    # 按概率抽取一个
    target_index = np.random.choice(len(p), p=p)
    # 前面预测时删除了前几个标记符，因此编号要补上3位，才是实际在tokenizer词典中的编号
    return p_args[target_index] + 3

In [27]:
token_ids = tokenizer.encode("清风明月")[:-1]
while len(token_ids) < 13:
    # 将token_ids转换为二维NumPy数组
    token_ids_array = np.array([token_ids])
    # 使用模型进行预测
    target = predict(model, token_ids_array)
    # 保存结果
    token_ids.append(target)
    # 到达END
    if target == tokenizer.end_id:
        break

print("".join(tokenizer.decode(token_ids)))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 445ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 761ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step
清风明月夜，幽夕色吟吟。


In [28]:
def generate_random_poem(tokenizer, model, text=""):
    """
    随机生成一首诗
    :param tokenizer: 分词器
    :param model: 古诗模型
    :param text: 古诗的起始字符串，默认为空
    :return: 一首古诗的字符串
    """
    # 将初始字符串转成token_ids，并去掉结束标记[END]
    token_ids = tokenizer.encode(text)[:-1]
    while len(token_ids) < MAX_LEN:
        # 将 token_ids转换为二维NumPy数组
        token_ids_array = np.array([token_ids])
        # 预测词的编号
        target = predict(model, token_ids_array)
        # 保存结果
        token_ids.append(target)
        # 到达END
        if target == tokenizer.end_id:
            break

    return "".join(tokenizer.decode(token_ids))

In [29]:
for i in range(5):
    print(generate_random_poem(tokenizer, model))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30

In [30]:
print(generate_random_poem(tokenizer, model, "春眠不觉晓，"))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33

In [31]:
def generate_acrostic_poem(tokenizer, model, heads):
    """
    生成一首藏头诗
    :param tokenizer: 分词器
    :param model: 古诗模型
    :param heads: 藏头诗的头
    :return: 一首古诗的字符串
    """
    # token_ids，只包含[START]编号
    token_ids = [tokenizer.start_id, ]
    # 逗号和句号标记编号
    punctuation_ids = {tokenizer.token_to_id("，"), tokenizer.token_to_id("。")}
    content = []
    # 为每一个head生成一句诗
    for head in heads:
        content.append(head)
        # head转为编号id，放入列表，用于预测
        token_ids.append(tokenizer.token_to_id(head))
        # 开始生成一句诗
        target = -1

        while target not in punctuation_ids: # 遇到逗号、句号，说明本句结束，开始下一句
            input_ids = np.array([token_ids])
            # 预测词的编号
            target = predict(model, input_ids)
            # 因为可能预测到END，所以加个判断
            if target > 3:
                # 保存结果到token_ids中，下一次预测还要用
                token_ids.append(target)
                content.append(tokenizer.id_to_token(target))

    return "".join(content)

In [32]:
print(generate_acrostic_poem(tokenizer, model, heads="上善若水"))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36