In [1]:
from data import AttnDataSet
from network import attention_network
import jieba, jieba.posseg
import tensorflow as tf
from config import *
import numpy as np

  from ._conv import register_converters as _register_converters


In [3]:
# Load a custom dictionary for jieba word segmentation.
# Kept ahead of the dataset construction below in case AttnDataSet segments text when it is
# built (not verifiable from this notebook).
jieba.set_dictionary("./data/new_dict.txt")

# Define the dataset. Later cells use its sentence2seq() (token list -> model input) and
# word_bag (vocabulary index -> word).
dataset = AttnDataSet(batch_size, max_length, vocab_size)

### 网络结构

In [4]:
# Graph inputs. Every placeholder except dropout_prob has a fixed leading dimension of batch_size.
# NOTE(review): creation order determines TF's auto-generated op names; keep it stable.
# Embedded input sequences for encoder and decoder: (batch_size, max_length, vector_size) float32.
X_encoder = tf.placeholder(tf.float32, [batch_size, max_length, vector_size])
X_decoder = tf.placeholder(tf.float32, [batch_size, max_length, vector_size])
# Per-example sequence lengths, (batch_size,) int32.
encoder_length = tf.placeholder(tf.int32, [batch_size])
# NOTE(review): decoder_length is not passed to attention_network and not fed by generator();
# it looks unused in this notebook — confirm before relying on it.
decoder_length = tf.placeholder(tf.int32, [batch_size])
# Presumably target token ids for training; not fed anywhere in this notebook.
y = tf.placeholder(tf.int32, [batch_size, max_length])
# Per-position integer flags: generator() feeds Z_generator output for the decoder and zeros for the encoder.
z_encoder = tf.placeholder(tf.int32, [batch_size, max_length])
z_decoder = tf.placeholder(tf.int32, [batch_size, max_length])
# Scalar dropout control; inference feeds 1. (presumably a keep probability — confirm in network.py).
dropout_prob = tf.placeholder(tf.float32)

# Build the attention encoder/decoder. `output` is presumably the per-step vocabulary scores
# (generator() argsorts its last axis); `decoder_state` is the final decoder state.
output, decoder_state = attention_network(
    X_encoder, X_decoder, encoder_length, z_encoder, z_decoder, dropout_prob)

# NOTE(review): indices 3 and 4 match the field order of tf.contrib.seq2seq.AttentionWrapperState
# (cell_state, attention, time, alignments, alignment_history); this assumes decoder_state is such
# a state — confirm in network.attention_network.
alignments, alignment_history = decoder_state[3], decoder_state[4]

Instructions for updating:
seq_dim is deprecated, use seq_axis instead
Instructions for updating:
batch_dim is deprecated, use batch_axis instead


### 加载模型

In [6]:
# Open a session and restore the trained weights from the latest checkpoint in model_dir.
sess = tf.Session()
# Initialize everything first; saver.restore then overwrites the variables stored in the checkpoint.
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
model_dir = "./model"
checkPoint = tf.train.get_checkpoint_state(model_dir)
# get_checkpoint_state returns None when model_dir holds no checkpoint index; fail with a clear
# message instead of an AttributeError on `None.model_checkpoint_path`.
if checkPoint is None or not checkPoint.model_checkpoint_path:
    raise FileNotFoundError("No TensorFlow checkpoint found in %r" % model_dir)
saver.restore(sess, checkPoint.model_checkpoint_path)

INFO:tensorflow:Restoring parameters from ./model/attention_model-500000


In [7]:
def Z_generator(length, decode=False, size=None):
    """Build the per-position flag vector for a line of `length` tokens.

    The last (up to) three positions of the line are tagged 2, 3, 4, with 4 on the final
    token. Lines of four or more tokens also get a 1 on their first position; any positions
    in between stay 0. With ``decode=True`` every flag is shifted one slot to the right
    (presumably to leave room for the leading "<GO>" token of the decoder input — confirm
    against the training pipeline).

    Args:
        length: number of tokens in the line; ``length <= 0`` yields all zeros.
        decode: shift all flags one position to the right when true.
        size: length of the returned vector; defaults to the global ``max_length``.

    Returns:
        np.ndarray of shape ``(size,)`` and dtype int32.

    Raises:
        IndexError: if the shifted flags would run past the end of the vector.
    """
    if size is None:
        size = max_length
    z = np.zeros(size, dtype=np.int32)
    if length <= 0:
        return z
    offset = int(decode)
    if offset + length > size:
        raise IndexError(
            "line of length %d (decode=%s) does not fit in %d positions"
            % (length, decode, size))
    if length >= 4:
        z[offset] = 1  # the head flag is only used when it cannot collide with the 2/3/4 tail
    tail = (2, 3, 4)[-min(length, 3):]  # short lines keep only the last flags: (4,), (3, 4), ...
    start = offset + length - len(tail)
    z[start:start + len(tail)] = tail
    return z

In [8]:
# Cache of attention read ops, keyed by decoder step, tied to the alignment_history they read.
_alignment_read_cache = {"source": None, "ops": {}}


def _alignment_read_op(step):
    """Return the graph op reading the attention weights at decoder `step`, building it at most once.

    Calling ``alignment_history.read(step)`` inside every generator() call adds a new node to the
    default graph each time, so repeated generation keeps growing the graph. The cache is reset
    when ``alignment_history`` is a different object (e.g. the network cell was re-run).
    """
    cache = _alignment_read_cache
    if cache["source"] is not alignment_history:
        cache["source"], cache["ops"] = alignment_history, {}
    ops = cache["ops"]
    if step not in ops:
        ops[step] = alignment_history.read(step)
    return ops[step]


def generator(sentence1, sentence2, z_flags2):
    """Rank every vocabulary word as the next token of `sentence2`, using `sentence1` as context.

    Only batch row 0 of the length and flag feeds is populated; the other rows stay zero.

    Args:
        sentence1: tokenized previous line, fed to the encoder.
        sentence2: tokenized current line so far (the demos start it with "<GO>"), fed to the decoder.
        z_flags2: decoder flag vector, e.g. ``Z_generator(n, decode=True)``. Encoder flags are fed as zeros.

    Returns:
        (words, align): `words` holds every index of the last axis of `output`, ordered by
        descending score at decoder step ``len(sentence2) - 1`` and mapped through
        ``dataset.word_bag``; `align` holds the attention weights read from
        `alignment_history` at that same step.

    Raises:
        ValueError: if `sentence2` is empty (there is no decoder step to read).
    """
    if not sentence2:
        raise ValueError("sentence2 must contain at least one token (e.g. '<GO>')")
    step = len(sentence2) - 1  # decoder position whose output scores the next word

    _encoder_length = np.zeros(batch_size, dtype=np.int32)
    _encoder_length[0] = len(sentence1)
    _X_encoder = dataset.sentence2seq(sentence1)
    _X_decoder = dataset.sentence2seq(sentence2)
    _z_encoder = np.zeros((batch_size, max_length), dtype=np.int32)
    _z_decoder = np.zeros((batch_size, max_length), dtype=np.int32)
    _z_decoder[0] = z_flags2

    feed = {
        X_encoder: _X_encoder,
        X_decoder: _X_decoder,
        encoder_length: _encoder_length,
        z_encoder: _z_encoder,
        z_decoder: _z_decoder,
        dropout_prob: 1.,  # presumably a keep probability, so 1. disables dropout at inference
    }
    logits, align = sess.run([output, _alignment_read_op(step)], feed_dict=feed)

    # Sort only the row that is used (descending), not the whole (batch, time, vocab) tensor.
    ranked = np.argsort(logits[0, step, :])[::-1].tolist()
    words = [dataset.word_bag[idx] for idx in ranked]  # vocabulary index -> word
    return words, align

### 上下文相关、情感
下句（`<GO> 我 的`）保持不变，只改变上句：上句的语境和情感会影响当前的生成结果和注意力分布。

In [15]:
# Previous (context) line, pre-segmented; the current line is seeded with "<GO> 我 的".
s1 = "迎接 早晨 灿烂 的 阳光".split()
s2 = "<GO> 我 的".split()
s2_length = 4  # line length handed to Z_generator for the decoder flags
z2 = Z_generator(length=s2_length, decode=True)
result, align = generator(sentence1=s1, sentence2=s2, z_flags2=z2)
print("注意力分布:", align[:, :len(s1)])  # first len(s1) columns: one weight per context token
print("生成结果:", result[:10])  # ten highest-scoring words

注意力分布: [[0.20360762 0.3191696  0.14018247 0.09243581 0.24460447]]
生成结果: ['祖国', '阳光', '太阳', '新疆', '家庭', '惊叹', '大道', '蓝天', '成绩', '春天']


In [16]:
# Previous (context) line, pre-segmented; the current line is seeded with "<GO> 我 的".
s1 = "眼角 留 着 你 给 的 泪水".split()
s2 = "<GO> 我 的".split()
s2_length = 4  # line length handed to Z_generator for the decoder flags
z2 = Z_generator(length=s2_length, decode=True)
result, align = generator(sentence1=s1, sentence2=s2, z_flags2=z2)
print("注意力分布:", align[:, :len(s1)])  # first len(s1) columns: one weight per context token
print("生成结果:", result[:10])  # ten highest-scoring words

注意力分布: [[0.29617822 0.0401211  0.09877114 0.15365875 0.07016596 0.12643187
  0.21467307]]
生成结果: ['眼角', '眼泪', '心里', '心碎', '泪水', '爱', '泪', '伤心', '泪光', '心理']


In [17]:
# Previous (context) line, pre-segmented; the current line is seeded with "<GO> 我 的".
s1 = "我 在 黑夜 之中 寻找 出口".split()
s2 = "<GO> 我 的".split()
s2_length = 4  # line length handed to Z_generator for the decoder flags
z2 = Z_generator(length=s2_length, decode=True)
result, align = generator(sentence1=s1, sentence2=s2, z_flags2=z2)
print("注意力分布:", align[:, :len(s1)])  # first len(s1) columns: one weight per context token
print("生成结果:", result[:10])  # ten highest-scoring words

注意力分布: [[0.129262   0.08518599 0.46715733 0.09963258 0.09444849 0.12431359]]
生成结果: ['迷惘', '孤独', '寂寞', '黑夜', '星光', '孤单', '彷徨', '放逐', '无垠', '心']


In [19]:
# Previous (context) line, pre-segmented; the current line is seeded with "<GO> 我 的".
s1 = "给 我 甜蜜 笑容".split()
s2 = "<GO> 我 的".split()
s2_length = 4  # line length handed to Z_generator for the decoder flags
z2 = Z_generator(length=s2_length, decode=True)
result, align = generator(sentence1=s1, sentence2=s2, z_flags2=z2)
print("注意力分布:", align[:, :len(s1)])  # first len(s1) columns: one weight per context token
print("生成结果:", result[:10])  # ten highest-scoring words

注意力分布: [[0.15871677 0.16507187 0.32744813 0.34876317]]
生成结果: ['爱人', '人生', '快乐', '美梦', '小', '生命', '爱情', '热情', '健壮', '心']
