In [None]:
# Input locations, resolved against the kernel's current working directory.
PLACEHOLDER_MODEL_PATH = "../models/lstm_poetry.h5"  # model file passed to keras load_model
PLACEHOLDER_CORPUS_FILE_PATH = "../../datasets/poetry.txt"  # corpus the vocabulary is rebuilt from

In [None]:
## Init

import itertools
import jieba
import numpy as np
from collections import Counter
from keras.models import load_model

# Build the vocabulary: assign each distinct token a unique index.
def build_vocab(text, vocab_lim):
    """Build a vocabulary from the ``vocab_lim`` most frequent tokens of ``text``.

    Each element of ``text`` is counted as one token (a single character in
    char mode, a whole word in word mode), which is the same unit
    ``build_matrix`` and the generation loop look up.

    Args:
        text: iterable of tokens, e.g. the list returned by ``process_file``.
        vocab_lim: maximum vocabulary size; ``None`` keeps every token.
            Tokens tied in count at the cut-off are kept in order of first
            appearance (``Counter.most_common`` semantics).

    Returns:
        (vocab, vocab_inv): ``vocab_inv`` is the sorted list of kept tokens and
        ``vocab`` maps each of them to its position in ``vocab_inv``
        (indices ``0 .. len(vocab) - 1``).
    """
    counts = Counter(text)
    vocab_inv = sorted(token for token, _ in counts.most_common(vocab_lim))
    vocab = {token: index for index, token in enumerate(vocab_inv)}
    return vocab, vocab_inv

# Read the input corpus and split it into tokens.
def process_file(file_name, use_char_based_model, encoding="utf-8"):
    """Read a text corpus and split it into a flat list of tokens.

    Args:
        file_name: path of the corpus file.
        use_char_based_model: if True, every character of every line
            (including the line's trailing newline) becomes a token;
            otherwise each line is segmented into words with ``jieba.cut``.
        encoding: text encoding of the file. Defaults to UTF-8 so the Chinese
            corpus decodes identically on every platform instead of depending
            on the locale's default encoding; pass ``None`` to use the locale
            default.

    Returns:
        list[str]: all tokens of the file, in order.
    """
    tokens = []
    with open(file_name, "r", encoding=encoding) as f:
        for line in f:
            if use_char_based_model:
                tokens.extend(line)  # a str iterates as its characters
            else:
                tokens.extend(jieba.cut(line))
    return tokens

# Encode the text and build the (window, next-token) matrices.
def build_matrix(text, vocab, length, step):
    """Turn a token sequence into (input window, next token) training pairs.

    Every token is mapped to its index in ``vocab``; tokens missing from the
    vocabulary get the extra index ``len(vocab)``. The id sequence is first
    truncated to a whole multiple of ``length``. Then a window of ``length``
    ids is taken every ``step`` positions, and each window is paired with the
    id that follows it.

    Args:
        text: sequence of tokens.
        vocab: mapping token -> index, e.g. from ``build_vocab``.
        length: window (context) length; must be positive.
        step: stride between consecutive window starts; must be positive.

    Returns:
        (X, Y): ``X`` is an integer array of shape ``(num_windows, length)``
        and ``Y`` an integer array of shape ``(num_windows,)``. When the text
        is too short to form any window, both are empty but keep these
        shapes, so ``X.shape[1] == length`` always holds.
    """
    oov_index = len(vocab)
    ids = np.array([vocab.get(token, oov_index) for token in text], dtype=int)
    # Keep only whole chunks of `length` ids; the trailing remainder is dropped.
    ids = ids[: (len(ids) // length) * length]

    # Stop before len(ids) - length so every window has a following id.
    starts = range(0, len(ids) - length, step)
    X = np.array([ids[i : i + length] for i in starts], dtype=int).reshape(-1, length)
    Y = np.array([ids[i + length] for i in starts], dtype=int)
    return X, Y


# Load the trained model and rebuild the token vocabulary from the corpus.
# NOTE(review): the rebuilt indices presumably have to match those used when
# the model was trained, so these two settings should equal the training
# notebook's values — confirm.
USE_CHAR_BASED_MODEL = True  # True: one token per character; False: jieba word segmentation
VOCAB_LIM = 4000  # number of most frequent tokens kept in the vocabulary

model = load_model(PLACEHOLDER_MODEL_PATH)
raw_text = process_file(PLACEHOLDER_CORPUS_FILE_PATH, USE_CHAR_BASED_MODEL)
vocab, vocab_inv = build_vocab(raw_text, VOCAB_LIM)

In [None]:
# Generation settings.
PLACEHOLDER_START_TEXT = "明月松间照"  # seed text; keep it to at most 5 characters (the Run cell's seq_length)
PLACEHOLDER_GEN_LEN = 20  # how many tokens to generate after the seed
PLACEHOLDER_TOPN = 4  # each next token is chosen among the model's top-N predictions

In [None]:
## Run

# Generate PLACEHOLDER_GEN_LEN tokens after the seed text and print them as
# they are produced. At each step the model's PLACEHOLDER_TOPN highest-scoring
# ids are the candidates. If the best one is '，', '。' or a newline, it is
# emitted as-is (keeps the verse/line structure). Otherwise a candidate is
# drawn uniformly at random.

seq_length = 5  # context window fed to the model; NOTE(review): presumably the training window length — confirm
assert PLACEHOLDER_START_TEXT, "PLACEHOLDER_START_TEXT must not be empty"

# Id -> token lookup with an extra slot for id len(vocab), which is the id
# used below for out-of-vocabulary tokens. Built as a new list so that
# re-running this cell does not keep appending to the global vocab_inv.
id_to_token = vocab_inv + [' ']
greedy_tokens = ('，', '。', '\n')

# Only the last seq_length tokens are ever fed to the model.
st = PLACEHOLDER_START_TEXT[-seq_length:]
print(PLACEHOLDER_START_TEXT, end='')
for _ in range(PLACEHOLDER_GEN_LEN):
    X_sample = np.array([[vocab.get(x, len(vocab)) for x in st]])
    # verbose=0: some Keras versions print a progress bar on every predict
    # call, which would interleave with the generated text.
    scores = model.predict(X_sample, verbose=0)[0]
    pdt = np.argsort(-scores)[:PLACEHOLDER_TOPN]  # candidate ids, best first
    best = id_to_token[pdt[0]]
    ch = best if best in greedy_tokens else id_to_token[np.random.choice(pdt)]
    print(ch, end='')
    st = (st + ch)[-seq_length:]