In [None]:
# File locations used by this notebook. Relative paths are resolved against
# the kernel's current working directory.
# NOTE(review): the PLACEHOLDER_ prefix suggests these values are meant to be
# substituted per run/environment — confirm before renaming or restructuring.

# Poetry text corpus (not read by any of the cells visible here).
PLACEHOLDER_FILE_PATH = "../../datasets/poetry.txt"
# Full seq2seq model; only referenced by a commented-out load in the Init cell.
PLACEHOLDER_MODEL_PATH = "../models/seq2seq_poetry_model.h5"
# Encoder-only inference model, loaded in the Init cell.
PLACEHOLDER_ENCODER_MODEL_PATH = "../models/seq2seq_poetry_encoder.h5"
# Decoder-only inference model, loaded in the Init cell.
PLACEHOLDER_DECODER_MODEL_PATH = "../models/seq2seq_poetry_decoder.h5"
# Pickle holding the vocabularies, reverse indices and sequence lengths.
PLACEHOLDER_DICT_PATH = "../models/seq2seq_poetry_dicts.pkl"

In [None]:
## Init

from keras.models import load_model
import numpy as np
import random
import pickle

# Read the dictionary bundle that accompanies the saved models.
# SECURITY: pickle.load can execute arbitrary code while deserializing, so
# PLACEHOLDER_DICT_PATH must only ever point at a file you trust.
with open(PLACEHOLDER_DICT_PATH, 'rb') as f:
    tmp = pickle.load(f)

# Character -> id lookup tables, and the number of entries on each side.
input_vocab = tmp["input_vocab"]
target_vocab = tmp["target_vocab"]
encoder_vocab_size = len(input_vocab)  # input-side vocabulary size
decoder_vocab_size = len(target_vocab)  # output-side vocabulary size

# Id -> character lookup tables.
reverse_input_char_index = tmp["reverse_input_char_index"]
reverse_target_char_index = tmp["reverse_target_char_index"]

# Sequence lengths stored alongside the vocabularies.
encoder_len = tmp["encoder_len"]
decoder_len = tmp["decoder_len"]

# Only the separate encoder and decoder models are loaded for inference;
# the combined model at PLACEHOLDER_MODEL_PATH is intentionally not loaded.
encoder_model = load_model(PLACEHOLDER_ENCODER_MODEL_PATH)
decoder_model = load_model(PLACEHOLDER_DECODER_MODEL_PATH)

def decode_sequence(input_seq):
    """Greedily decode an output sentence for one encoded input sequence.

    The input is run through ``encoder_model`` once; the result is used as the
    decoder's initial state. The decoder is then stepped one token at a time,
    starting from the start symbol ``'\\t'``, always picking the most probable
    next token and feeding it back in as the following input.

    Args:
        input_seq: Passed unchanged to ``encoder_model.predict``.
            Presumably a ``(1, encoder_len)`` array of input token ids, as
            built by the Run cell — confirm against the model definition.

    Returns:
        The generated text as a ``str``. Generation stops right after the end
        symbol ``'\\n'`` is produced (it is included in the result) or as soon
        as the text grows longer than ``decoder_len``.
    """
    # encoder_model.predict must return a list of state arrays: it is
    # concatenated with the decoder input list below.
    decoder_state = encoder_model.predict(input_seq)

    # First decoder input: a 1x1 array holding the id of the start symbol.
    next_input = np.zeros((1, 1))
    next_input[0, 0] = target_vocab['\t']

    generated = ''
    while True:
        # One decoder step: distribution over the next token plus new states.
        token_probs, state_h, state_c = decoder_model.predict([next_input] + decoder_state)

        # Greedy choice: take the highest-scoring token of the last time step.
        best_index = np.argmax(token_probs[0, -1, :])
        best_char = reverse_target_char_index[best_index]
        generated += best_char

        # Finished on the end symbol or once the length limit is exceeded.
        if best_char == '\n' or len(generated) > decoder_len:
            return generated

        # Otherwise feed the chosen token and the updated states back in.
        next_input = np.zeros((1, 1))
        next_input[0, 0] = best_index
        decoder_state = [state_h, state_c]

In [None]:
# Opening line that the Run cell encodes and passes to decode_sequence.
PLACEHOLDER_START_TEXT = "千山鸟飞绝"

In [None]:
## Run

st = PLACEHOLDER_START_TEXT
input_texts = [st]

# One row per input text; positions past the end of a text keep the value 0.
encoder_input_data = np.zeros((len(input_texts), encoder_len), dtype='int')

for i, input_text in enumerate(input_texts):
    # Only the first encoder_len characters fit into a row. Truncating here
    # avoids an IndexError when the start text is longer than encoder_len.
    for t, char in enumerate(input_text[:encoder_len]):
        if char in input_vocab:
            encoder_input_data[i, t] = input_vocab[char]
        else:
            # Character not in the input vocabulary: substitute a random id.
            # The random draw happens only for unknown characters, so known
            # text does not consume random state. Output is therefore not
            # reproducible for such inputs unless `random` is seeded.
            encoder_input_data[i, t] = random.randint(0, encoder_vocab_size - 1)

# Decode the first (and only) text; the 0:1 slice keeps the batch dimension.
decoded_sentence = decode_sequence(encoder_input_data[0:1])
decoded_sentence