In [None]:
from tensorflow.python.layers import core as core_layers
import tensorflow as tf
import numpy as np
import time
import os
import Util.myResidualCell
from Util.bleu import BLEU
from tensorflow.python.ops import array_ops
import random
import pickle as cPickle
import matplotlib.pyplot as plt


class VRAE:
    """Sequence-to-sequence variational recurrent auto-encoder (TensorFlow 1.x graph mode).

    The graph consists of a bidirectional multi-layer RNN encoder, a dense
    Gaussian latent layer (``z_mu`` / ``z_lgs2``), dense projections from the
    latent sample to the decoder's initial state, and three decoders sharing
    the ``decode`` variable scope: a teacher-forced training decoder, a
    beam-search decoder, and a beam-search decoder that continues from the
    final state of the training decoder (prefix inference).

    ``dp`` must expose ``X_w2id``, ``Y_w2id``, ``Y_id2w`` and ``num_steps``;
    both vocabularies must contain ``<GO>``, ``<EOS>``, ``<PAD>`` and ``<UNK>``.

    NOTE(review): ``jieba`` is used when ``is_jieba`` is true but is not among
    the imports visible at the top of this notebook; presumably it is imported
    in another cell -- confirm.
    """

    def __init__(self, dp, rnn_size, n_layers, latent_dim, var, encoder_embedding_dim, decoder_embedding_dim, max_infer_length,
                 sess=None, lr=0.001, grad_clip=5.0, is_jieba=False, beam_width=5, force_teaching_ratio=1.0, beam_penalty=0.0,
                 residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False,
                 latent_weight=0.1, beta_decay_period=10, beta_decay_offset=5, decay_scheme='luong234', is_save=True):
        """Store hyper-parameters, build the graph and initialise all variables.

        Args:
            dp: data provider (vocabularies and ``num_steps``).
            rnn_size: number of units of every RNN cell.
            n_layers: encoder layers per direction; the decoder uses ``2 * n_layers``.
            latent_dim: size of the latent vector ``z``.
            var: threshold used by the auxiliary loss on the spread of ``z_mu``.
            encoder_embedding_dim, decoder_embedding_dim: embedding sizes.
            max_infer_length: maximum number of beam-search decoding steps.
            sess: session to use; when ``None`` a new ``tf.Session()`` is created
                for this instance (a session created in the signature would be
                built once at class-definition time and shared by every instance).
            lr: initial learning rate.
            grad_clip: global-norm clipping threshold for the gradients.
            is_jieba: tokenize raw input with ``jieba.cut`` in the inference helpers.
            beam_width, beam_penalty: beam-search width and length-penalty weight.
            force_teaching_ratio: stored only; not read anywhere in this class.
            residual: wrap cells in ``Util.myResidualCell.ResidualWrapper``;
                requires both embedding sizes to equal ``rnn_size``.
            output_keep_prob, input_keep_prob: dropout keep probabilities that the
                training loop feeds to the graph placeholders.
            cell_type: ``'lstm'`` for layer-normalised LSTM cells, anything else for GRU.
            reverse: reverse input tokens before encoding and decoded ids after decoding.
            latent_weight: constant multiplier of the KL term in the loss.
            beta_decay_period, beta_decay_offset: stored for the KL warm-up schedule.
            decay_scheme: ``'luong10'`` or any other value (treated as ``'luong234'``).
            is_save: stored flag read by the training utilities.
        """
        self.rnn_size = rnn_size
        self.latent_dim = latent_dim
        self.n_layers = n_layers
        self.grad_clip = grad_clip
        self.is_jieba = is_jieba
        self.var = var
        self.dp = dp
        self.step = 0
        self.encoder_embedding_dim = encoder_embedding_dim
        self.decoder_embedding_dim = decoder_embedding_dim
        self.beam_width = beam_width
        self.latent_weight = latent_weight
        self.beam_penalty = beam_penalty
        self.max_infer_length = max_infer_length
        self.residual = residual
        self.is_save = is_save
        self.decay_scheme = decay_scheme
        if self.residual:
            assert encoder_embedding_dim == rnn_size
            assert decoder_embedding_dim == rnn_size
        self.reverse = reverse
        self.cell_type = cell_type
        self.force_teaching_ratio = force_teaching_ratio
        self._output_keep_prob = output_keep_prob
        self._input_keep_prob = input_keep_prob
        self.beta_decay_period = beta_decay_period
        self.beta_decay_offset = beta_decay_offset
        # Create the session lazily so every instance without an explicit
        # session gets its own, bound to the graph that is current right now.
        self.sess = sess if sess is not None else tf.Session()
        self.lr = lr
        self.build_graph()
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        # The summary variables are created after the initializer ran and after
        # the saver captured its variable list: they only ever receive values
        # through ``update_ops`` and are not part of checkpoints.
        self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary()
    # end constructor

    def build_graph(self):
        """Assemble the whole graph; the three decoders share the 'decode' scope."""
        self.register_symbols()
        self.add_input_layer()
        self.add_encoder_layer()
        self.add_stochastic_layer()
        self.add_decoder_hidden()
        with tf.variable_scope('decode'):
            self.add_decoder_for_training()
        with tf.variable_scope('decode', reuse=True):
            self.add_decoder_for_inference()
        with tf.variable_scope('decode', reuse=True):
            self.add_decoder_for_prefix_inference()
        self.add_backward_path()
    # end method

    def add_input_layer(self):
        """Create the input placeholders, the KL-weight placeholder and the global step."""
        self.X = tf.placeholder(tf.int32, [None, None], name="X")
        self.Y = tf.placeholder(tf.int32, [None, None], name="Y")
        self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len")
        self.Y_seq_len = tf.placeholder(tf.int32, [None], name="Y_seq_len")
        self.input_keep_prob = tf.placeholder(tf.float32, name="input_keep_prob")
        self.output_keep_prob = tf.placeholder(tf.float32, name="output_keep_prob")
        # Dynamic batch size; zTox/generate feed this tensor directly because
        # they do not feed X.
        self.batch_size = tf.shape(self.X)[0]
        self.B = tf.placeholder(tf.float32, name='Beta_deterministic_warmup')
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
    # end method

    def single_cell(self, reuse=False):
        """Build one dropout-wrapped (and optionally residual) RNN cell.

        Args:
            reuse: forwarded to the LSTM cell constructor; ignored for GRU cells.
        """
        if self.cell_type == 'lstm':
            cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse)
        else:
            cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size)
        # DropoutWrapper's positional order is (cell, input_keep_prob,
        # output_keep_prob); pass by keyword so each placeholder controls the
        # dropout its name says it does.
        cell = tf.contrib.rnn.DropoutWrapper(
            cell,
            input_keep_prob=self.input_keep_prob,
            output_keep_prob=self.output_keep_prob)
        if self.residual:
            # The module is imported as ``import Util.myResidualCell``, which
            # binds only the name ``Util``.
            cell = Util.myResidualCell.ResidualWrapper(cell)
        return cell

    def add_encoder_layer(self):
        """Embed X and encode it with a bidirectional multi-layer RNN.

        ``encoder_out`` concatenates the last layer's final forward and backward
        states (the ``h`` part of the state tuple for LSTM cells).
        """
        encoder_embedding = tf.get_variable('encoder_embedding', [len(self.dp.X_w2id), self.encoder_embedding_dim],
                                             tf.float32, tf.random_uniform_initializer(-1.0, 1.0))

        self.encoder_inputs = tf.nn.embedding_lookup(encoder_embedding, self.X)
        bi_encoder_output, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw = tf.contrib.rnn.MultiRNNCell([self.single_cell() for _ in range(self.n_layers)]),
            cell_bw = tf.contrib.rnn.MultiRNNCell([self.single_cell() for _ in range(self.n_layers)]),
            inputs = self.encoder_inputs,
            sequence_length = self.X_seq_len,
            dtype = tf.float32,
            scope = 'bidirectional_rnn')
        if self.cell_type == 'lstm':
            self.encoder_out = tf.concat([bi_encoder_state[0][-1][1], bi_encoder_state[1][-1][1]], -1)
        else:
            self.encoder_out = tf.concat([bi_encoder_state[0][-1], bi_encoder_state[1][-1]], -1)

    def add_stochastic_layer(self):
        """Project the encoder output to ``z_mu`` / ``z_lgs2`` and sample ``z``.

        Also defines ``aux_loss = relu(var - MSE(batch mean of z_mu, z_mu))``,
        which is positive while the spread of ``z_mu`` within the batch is
        below ``var``.
        """
        # reparametrization trick
        self.z_mu = tf.layers.dense(self.encoder_out, self.latent_dim)
        z_mean_mu = tf.reduce_mean(self.z_mu, 0)

        z_mean_mu = tf.tile(tf.expand_dims(z_mean_mu, 0), [tf.shape(self.z_mu)[0], 1])

        self.aux_loss = tf.nn.relu(float(self.var) - tf.losses.mean_squared_error(z_mean_mu, self.z_mu))

        self.z_lgs2 = tf.layers.dense(self.encoder_out, self.latent_dim)
        noise = tf.random_normal(tf.shape(self.z_lgs2))
        # z_lgs2 is used as log(sigma^2): sigma = exp(0.5 * z_lgs2).
        self.z = self.z_mu + tf.exp(0.5 * self.z_lgs2) * noise

    def add_decoder_hidden(self):
        """Map ``z`` to one initial state per decoder layer (``2 * n_layers`` layers)."""
        hidden_state_list = []
        for _ in range(self.n_layers * 2):
            if self.cell_type == 'gru':
                hidden_state_list.append(tf.layers.dense(self.z, self.rnn_size))
            else:
                hidden_state_list.append(tf.contrib.rnn.LSTMStateTuple(tf.layers.dense(self.z, self.rnn_size), tf.layers.dense(self.z, self.rnn_size)))
        self.decoder_init_state = tuple(hidden_state_list)

    def processed_decoder_input(self):
        """Return Y with its last column dropped and a ``<GO>`` column prepended."""
        main = tf.strided_slice(self.Y, [0, 0], [self.batch_size, -1], [1, 1])  # remove last char
        decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._y_go), main], 1)
        return decoder_input

    def add_decoder_for_training(self):
        """Teacher-forced decoder; every input step is [token embedding, z]."""
        self.decoder_cell = tf.contrib.rnn.MultiRNNCell([self.single_cell() for _ in range(2 * self.n_layers)])
        decoder_embedding = tf.get_variable('decoder_embedding', [len(self.dp.Y_w2id), self.decoder_embedding_dim],
                                             tf.float32, tf.random_uniform_initializer(-1.0, 1.0))
        emb = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input())
        # Repeat z along the time axis and append it to every embedded token.
        inputs = tf.expand_dims(self.z, 1)
        inputs = tf.tile(inputs, [1, tf.shape(emb)[1], 1])
        inputs = tf.concat([emb, inputs], 2)
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs = inputs,
            sequence_length = self.Y_seq_len,
            time_major = False)
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell = self.decoder_cell,
            helper = training_helper,
            initial_state = self.decoder_init_state,
            output_layer = core_layers.Dense(len(self.dp.Y_w2id)))
        training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = training_decoder,
            impute_finished = True,
            maximum_iterations = tf.reduce_max(self.Y_seq_len))
        self.training_logits = training_decoder_output.rnn_output
        # Final state after consuming Y; prefix inference starts from it.
        self.init_prefix_state = training_final_state

    def add_decoder_for_inference(self):
        """Beam-search decoder started from ``decoder_init_state``."""
        decoder_embedding = tf.get_variable('decoder_embedding')
        # Embedding function handed to the beam-search decoder: looks up the
        # ids and appends z, tiling z along the second axis when ``ids`` is
        # not rank 1.
        self.beam_f = (lambda ids: tf.concat([tf.nn.embedding_lookup(decoder_embedding, ids),
                                    tf.tile(tf.expand_dims(self.z, 1),
                                            [1, int(tf.nn.embedding_lookup(decoder_embedding, ids).get_shape()[1]), 1]) if len(ids.get_shape()) != 1
                                             else self.z], -1))

        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell = self.decoder_cell,
            embedding = self.beam_f,
            start_tokens = tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size]),
            end_token = self._y_eos,
            initial_state = tf.contrib.seq2seq.tile_batch(self.decoder_init_state, self.beam_width),
            beam_width = self.beam_width,
            output_layer = core_layers.Dense(len(self.dp.Y_w2id), _reuse=True),
            length_penalty_weight = self.beam_penalty)
        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = predicting_decoder,
            impute_finished = False,
            maximum_iterations = self.max_infer_length)
        self.predicting_ids = predicting_decoder_output.predicted_ids
        self.score = predicting_decoder_output.beam_search_decoder_output.scores

    def add_decoder_for_prefix_inference(self):
        """Beam-search decoder that continues from the state reached after reading Y.

        The decoder's start inputs are replaced with the embedding of the
        ``prefix_go`` token ids instead of ``<GO>``.
        """
        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell = self.decoder_cell,
            embedding = self.beam_f,
            start_tokens = tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size]),
            end_token = self._y_eos,
            initial_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width),
            beam_width = self.beam_width,
            output_layer = core_layers.Dense(len(self.dp.Y_w2id), _reuse=True),
            length_penalty_weight = self.beam_penalty)

        self.prefix_go = tf.placeholder(tf.int32, [None])
        prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width])
        prefix_emb = self.beam_f(prefix_go_beam)
        # Overrides a private attribute of BeamSearchDecoder so decoding starts
        # from the prefix token rather than from <GO>.
        predicting_decoder._start_inputs = prefix_emb
        predicting_prefix_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = predicting_decoder,
            impute_finished = False,
            maximum_iterations = self.max_infer_length)
        self.predicting_prefix_ids = predicting_prefix_decoder_output.predicted_ids
        self.prefix_score = predicting_prefix_decoder_output.beam_search_decoder_output.scores

    def add_backward_path(self):
        """Define the losses and the clipped-gradient Adam training op.

        ``loss = reconstruct_loss + B * latent_weight * kl_loss + aux_loss``.
        """
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        self.reconstruct_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y,
                                                     weights = masks)
        # Same loss without averaging over time steps, summed over time.
        self.all_reconstruct_loss = tf.reduce_sum(tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y,
                                                     weights = masks,
                                                     average_across_timesteps=False))
        self.kl_loss = tf.reduce_mean(-0.5 * tf.reduce_sum(1 + self.z_lgs2 - tf.square(self.z_mu) - tf.exp(self.z_lgs2), 1))
        self.loss = self.reconstruct_loss + self.B * self.latent_weight * self.kl_loss + self.aux_loss
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip)
        self.learning_rate = tf.constant(self.lr)
        self.learning_rate = self.get_learning_rate_decay(self.decay_scheme)  # decay
        self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params), global_step=self.global_step)

    def register_symbols(self):
        """Cache the ids of the special tokens of both vocabularies."""
        self._x_go = self.dp.X_w2id['<GO>']
        self._x_eos = self.dp.X_w2id['<EOS>']
        self._x_pad = self.dp.X_w2id['<PAD>']
        self._x_unk = self.dp.X_w2id['<UNK>']

        self._y_go = self.dp.Y_w2id['<GO>']
        self._y_eos = self.dp.Y_w2id['<EOS>']
        self._y_pad = self.dp.Y_w2id['<PAD>']
        self._y_unk = self.dp.Y_w2id['<UNK>']

    def _ids_to_text(self, ids):
        """Convert one decoded id sequence to a space-joined token string.

        The sequence is cut at the first ``<EOS>`` (if any), reversed when
        ``self.reverse`` is set, and ids missing from ``dp.Y_id2w`` become '&'.
        """
        ids = np.asarray(ids).tolist()
        if self._y_eos in ids:
            ids = ids[:ids.index(self._y_eos)]
        if self.reverse:
            ids = ids[::-1]
        return ' '.join([self.dp.Y_id2w.get(i, u'&') for i in ids])

    def infer(self, input_word):
        """Encode ``input_word`` and return one decoded string per beam.

        Args:
            input_word: sequence of tokens (a string is iterated per character),
                or raw text when ``is_jieba`` is set.
        """
        if self.is_jieba:
            input_word = list(jieba.cut(input_word))
        if self.reverse:
            input_word = input_word[::-1]
        input_indices = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        out_indices = self.sess.run(self.predicting_ids, {
            self.X: [input_indices], self.X_seq_len: [len(input_indices)], self.output_keep_prob: 1, self.input_keep_prob: 1})
        # out_indices is indexed as [batch, time, beam]; the batch holds one item.
        return [self._ids_to_text(out_indices[0, :, beam]) for beam in range(out_indices.shape[-1])]

    def prefix_infer(self, input_word, prefix):
        """Decode a continuation of ``prefix`` conditioned on ``input_word``.

        Returns one string per beam: ``prefix`` followed by the decoded text,
        or the decoded text followed by ``prefix`` when ``self.reverse`` is set.
        The decoded ids are reversed whenever ``self.reverse`` is set, as in
        ``infer`` (previously this only happened when an ``<EOS>`` was emitted).

        NOTE(review): unlike ``infer``, the input tokens are not reversed here
        when ``self.reverse`` is set -- confirm that this is intended.
        ``prefix`` must be non-empty: its last token seeds the decoder.
        """
        if self.is_jieba:
            input_word = list(jieba.cut(input_word))
            prefix = list(jieba.cut(prefix))
        input_indices_X = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        input_indices_Y = [self.dp.Y_w2id.get(char, self._y_unk) for char in prefix]

        prefix_go = [input_indices_Y[-1]]
        out_indices, scores = self.sess.run([self.predicting_prefix_ids, self.prefix_score], {
            self.X: [input_indices_X], self.X_seq_len: [len(input_indices_X)], self.Y: [input_indices_Y], self.Y_seq_len: [len(input_indices_Y)],
            self.prefix_go: prefix_go, self.input_keep_prob: 1, self.output_keep_prob: 1})

        outputs = []
        for beam in range(out_indices.shape[-1]):
            decoded = self._ids_to_text(out_indices[0, :, beam])
            if self.reverse:
                outputs.append(decoded + prefix)
            else:
                outputs.append(prefix + decoded)
        return outputs

    def xToz(self, input_word):
        """Encode ``input_word`` and return one sample of the latent vector ``z``."""
        if self.is_jieba:
            input_word = list(jieba.cut(input_word))
        if self.reverse:
            input_word = input_word[::-1]
        input_indices = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        z = self.sess.run(self.z, {self.X: [input_indices], self.X_seq_len: [len(input_indices)], self.output_keep_prob: 1, self.input_keep_prob: 1})
        return z
    # end method

    def zTox(self, z):
        """Decode latent vectors ``z``; returns one string per beam of the first row only."""
        out_indices = self.sess.run(self.predicting_ids, {self.batch_size: z.shape[0],
            self.z: z, self.output_keep_prob: 1, self.input_keep_prob: 1})
        return [self._ids_to_text(out_indices[0, :, beam]) for beam in range(out_indices.shape[-1])]

    def generate(self, batch_size = 6):
        """Decode ``batch_size`` standard-normal latent samples; returns the top beam of each."""
        out_indices = self.sess.run(self.predicting_ids, {self.batch_size: batch_size,
            self.z: np.random.randn(batch_size, self.latent_dim), self.output_keep_prob: 1, self.input_keep_prob: 1})
        # Beam 0 of every batch item.
        return [self._ids_to_text(out_indices[row, :, 0]) for row in range(out_indices.shape[0])]

    def restore(self, path):
        """Restore all saved variables from the checkpoint at ``path``."""
        self.saver.restore(self.sess, path)
        print('restore %s success' % path)

    def get_learning_rate_decay(self, decay_scheme='luong234'):
        """Return the learning-rate tensor with staircase decay (factor 0.5).

        ``'luong10'`` starts halving after half of ``dp.num_steps`` and decays
        10 times; every other value starts after two thirds and decays 4 times.
        """
        num_train_steps = self.dp.num_steps
        if decay_scheme == "luong10":
            start_decay_step = int(num_train_steps / 2)
            n_decays = 10
        else:
            start_decay_step = int(num_train_steps * 2 / 3)
            n_decays = 4
        remain_steps = num_train_steps - start_decay_step
        # At least one step per decay: a zero interval would make
        # exponential_decay divide by zero on very short runs.
        decay_steps = max(1, int(remain_steps / n_decays))
        decay_factor = 0.5
        return tf.cond(
            self.global_step < start_decay_step,
            lambda: self.learning_rate,
            lambda: tf.train.exponential_decay(
                self.learning_rate,
                (self.global_step - start_decay_step),
                decay_steps, decay_factor, staircase=True),
            name="learning_rate_decay_cond")

    def setup_summary(self):
        """Create TensorBoard summaries.

        Scalar statistics computed in Python are pushed into graph variables via
        ``update_ops`` (one placeholder per variable, in the order train loss,
        train aux, train KL, train reconstruction, test loss, test aux, test KL,
        test reconstruction, beta) before ``summary_op`` is evaluated.

        Returns:
            (summary_placeholders, update_ops, summary_op)
        """
        train_loss = tf.Variable(0.)
        tf.summary.scalar('Train_loss', train_loss)
        train_aux_loss = tf.Variable(0.)
        tf.summary.scalar('Train_aux_loss', train_aux_loss)
        train_KL_loss = tf.Variable(0.)
        tf.summary.scalar('Train_KL_loss', train_KL_loss)
        train_r_loss = tf.Variable(0.)
        tf.summary.scalar('Train_R_loss', train_r_loss)
        test_loss = tf.Variable(0.)
        tf.summary.scalar('Test_loss', test_loss)
        test_aux_loss = tf.Variable(0.)
        tf.summary.scalar('Test_aux_loss', test_aux_loss)
        test_KL_loss = tf.Variable(0.)
        tf.summary.scalar('Test_KL_loss', test_KL_loss)
        test_r_loss = tf.Variable(0.)
        tf.summary.scalar('Test_R_loss', test_r_loss)
        beta = tf.Variable(0.)
        tf.summary.scalar('Beta', beta)
        tf.summary.scalar('lr_rate', self.learning_rate)
        tf.summary.histogram("z_mu", self.z_mu)
        tf.summary.histogram("z_ls2", self.z_lgs2)
        tf.summary.histogram("z", self.z)

        summary_vars = [train_loss, train_aux_loss, train_KL_loss, train_r_loss, test_loss, test_aux_loss, test_KL_loss, test_r_loss, beta]
        summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))]
        update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))]
        summary_op = tf.summary.merge_all()
        return summary_placeholders, update_ops, summary_op

In [None]:
class VRAE_DP:
    """Data provider: train/test split, shuffled mini-batches and padding.

    The first 10% of the paired sequences form the test split and the rest the
    training split. Sequences are stored in 1-D object arrays (one Python list
    per element) so that both ragged and equal-length inputs behave the same.
    """

    def __init__(self, X_indices, Y_indices, X_w2id, Y_w2id, BATCH_SIZE=256, n_epoch=15):
        """
        Args:
            X_indices, Y_indices: parallel sequences of token-id sequences.
            X_w2id, Y_w2id: token -> id vocabularies; each must contain '<PAD>'.
            BATCH_SIZE: number of sequence pairs per batch.
            n_epoch: number of epochs, used to derive ``num_steps``.
        """
        assert len(X_indices) == len(Y_indices)
        num_test = int(len(X_indices) * 0.1)
        self.n_epoch = n_epoch
        self.X_train = self._as_sequence_array(X_indices[num_test:])
        self.Y_train = self._as_sequence_array(Y_indices[num_test:])
        self.X_test = self._as_sequence_array(X_indices[:num_test])
        self.Y_test = self._as_sequence_array(Y_indices[:num_test])
        self.num_batch = int(len(self.X_train) / BATCH_SIZE)
        self.num_steps = self.num_batch * self.n_epoch
        self.batch_size = BATCH_SIZE
        self.X_w2id = X_w2id
        self.X_id2w = dict(zip(X_w2id.values(), X_w2id.keys()))
        self.Y_w2id = Y_w2id
        self.Y_id2w = dict(zip(Y_w2id.values(), Y_w2id.keys()))
        self._x_pad = self.X_w2id['<PAD>']
        self._y_pad = self.Y_w2id['<PAD>']
        print('Train_data: %d | Test_data: %d | Batch_size: %d | Num_batch: %d | X_vocab_size: %d | Y_vocab_size: %d' % (len(self.X_train), len(self.X_test), BATCH_SIZE, self.num_batch, len(self.X_w2id), len(self.Y_w2id)))

    @staticmethod
    def _as_sequence_array(sequences):
        """Return a 1-D object array whose elements are lists of token ids.

        ``np.array`` is avoided on purpose: it raises on ragged input with
        recent NumPy versions and silently builds a 2-D array when all
        sequences happen to have the same length.
        """
        arr = np.empty(len(sequences), dtype=object)
        for i, seq in enumerate(sequences):
            arr[i] = list(seq)
        return arr

    def next_batch(self, X, Y):
        """Yield shuffled, padded batches; a trailing partial batch is dropped.

        Yields:
            (padded X array, padded Y array, X lengths, Y lengths)
        """
        order = np.random.permutation(len(X))
        X = X[order]
        Y = Y[order]
        n_full = len(X) - len(X) % self.batch_size
        for start in range(0, n_full, self.batch_size):
            stop = start + self.batch_size
            padded_X_batch, X_batch_lens = self.pad_sentence_batch(X[start:stop], self._x_pad)
            padded_Y_batch, Y_batch_lens = self.pad_sentence_batch(Y[start:stop], self._y_pad)
            yield (np.array(padded_X_batch),
                   np.array(padded_Y_batch),
                   X_batch_lens,
                   Y_batch_lens)

    def sample_test_batch(self):
        """Return the first ``batch_size`` test pairs as one padded batch (no shuffling)."""
        padded_X_batch, X_batch_lens = self.pad_sentence_batch(self.X_test[: self.batch_size], self._x_pad)
        padded_Y_batch, Y_batch_lens = self.pad_sentence_batch(self.Y_test[: self.batch_size], self._y_pad)
        return np.array(padded_X_batch), np.array(padded_Y_batch), X_batch_lens, Y_batch_lens

    def pad_sentence_batch(self, sentence_batch, pad_int):
        """Right-pad every sentence with ``pad_int`` to the longest length in the batch.

        Returns:
            (list of padded sentences, list of original lengths); two empty
            lists for an empty batch.
        """
        if len(sentence_batch) == 0:
            return [], []
        seq_lens = [len(sentence) for sentence in sentence_batch]
        max_sentence_len = max(seq_lens)
        # list(...) copies the sentence and also accepts array rows, for which
        # ``+`` would mean element-wise addition rather than concatenation.
        padded_seqs = [list(sentence) + [pad_int] * (max_sentence_len - len(sentence))
                       for sentence in sentence_batch]
        return padded_seqs, seq_lens
    
   


In [None]:
import scipy.interpolate as si
from scipy import interpolate


def BetaGenerator(epoches, beta_decay_period, beta_decay_offset):
    """Build the KL warm-up schedule as a callable ``step -> beta``.

    The schedule is a cubic B-spline curve through the (beta, step) plane whose
    control points keep beta at 0 until ``beta_decay_offset``, raise it to 1
    over ``beta_decay_period`` steps and keep it at 1 up to ``epoches``. The
    curve is sampled at 100 points and turned into a function of the step by
    linear interpolation.

    Args:
        epoches: last step covered by the schedule (the caller passes a step count).
        beta_decay_period: number of steps over which beta rises from 0 to 1.
        beta_decay_offset: number of steps before the rise starts.

    Returns:
        ``scipy.interpolate.interp1d`` mapping a step to beta. Steps outside
        the sampled range are clamped to the schedule's first / last value
        instead of raising ``ValueError``.
    """
    control = np.array([
        [0, 0],
        [0, beta_decay_offset],
        [0, beta_decay_offset + 0.33 * beta_decay_period],
        [1, beta_decay_offset + 0.66 * beta_decay_period],
        [1, beta_decay_offset + beta_decay_period],
        [1, epoches],
    ])
    beta_ctrl = control[:, 0]
    step_ctrl = control[:, 1]
    param = np.arange(len(control))
    samples = np.linspace(0.0, len(control) - 1, 100)

    # splrep is only used to obtain the knot vector and degree; the fitted
    # coefficients are replaced by the control points (zero-padded to the
    # length of the knot vector), which turns the interpolating spline into a
    # B-spline curve that is merely guided by the points.
    beta_tck = list(si.splrep(param, beta_ctrl, k=3))
    step_tck = list(si.splrep(param, step_ctrl, k=3))
    beta_tck[1] = beta_ctrl.tolist() + [0.0, 0.0, 0.0, 0.0]
    step_tck[1] = step_ctrl.tolist() + [0.0, 0.0, 0.0, 0.0]

    beta_i = si.splev(samples, beta_tck)
    step_i = si.splev(samples, step_tck)
    return interpolate.interp1d(step_i, beta_i, bounds_error=False,
                                fill_value=(beta_i[0], beta_i[-1]))

class VRAE_util:
    def __init__(self, dp, model, display_freq=3):
        self.display_freq = display_freq
        self.dp = dp
        self.model = model
        self.summary_cnt = 0
        self.betaG = BetaGenerator(self.dp.n_epoch*self.dp.num_batch, self.model.beta_decay_period*self.dp.num_batch, self.model.beta_decay_offset*self.dp.num_batch)
        
    def train(self, epoch):
        avg_loss = 0.0
        avg_all_loss = 0.0
        avg_r_loss = 0.0
        avg_kl_loss = 0.0
        avg_aux_loss = 0.0
        tic = time.time()
        X_test_batch, Y_test_batch, X_test_batch_lens, Y_test_batch_lens = self.dp.sample_test_batch()
        for local_step, (X_train_batch, Y_train_batch, X_train_batch_lens, Y_train_batch_lens) in enumerate(
            self.dp.next_batch(self.dp.X_train, self.dp.Y_train)):
            beta = 0.001 + self.betaG(self.model.step) # add small value to avoid points to scatter
            self.model.step, _, loss, all_loss, r_loss, kl_loss, aux_loss = self.model.sess.run([self.model.global_step, self.model.train_op, 
                                                            self.model.loss, self.model.all_reconstruct_loss, self.model.reconstruct_loss, self.model.kl_loss, self.model.aux_loss], 
                                          {self.model.X: X_train_batch,
                                           self.model.Y: Y_train_batch,
                                           self.model.X_seq_len: X_train_batch_lens,
                                           self.model.Y_seq_len: Y_train_batch_lens,
                                           self.model.output_keep_prob:self.model._output_keep_prob,
                                           self.model.input_keep_prob:self.model._input_keep_prob,
                                          self.model.B:beta})
            avg_loss += loss
            avg_all_loss += all_loss
            avg_r_loss += r_loss
            avg_kl_loss += kl_loss
            avg_aux_loss += aux_loss
            # summary
            if local_step % 10 == 0:
                self.summary_cnt += 1
                val_loss, val_r_loss, val_kl_loss, val_aux_loss = self.model.sess.run([self.model.loss, self.model.reconstruct_loss, self.model.kl_loss, self.model.aux_loss], 
                                               {self.model.X: X_test_batch,
                                                     self.model.Y: Y_test_batch,
                                                     self.model.X_seq_len: X_test_batch_lens,
                                                     self.model.Y_seq_len: Y_test_batch_lens,
                                                     self.model.output_keep_prob:1,
                                                     self.model.input_keep_prob:1,
                                                     self.model.B:beta})
                stats = [avg_loss/(local_step+1), avg_aux_loss/(local_step+1), avg_kl_loss/(local_step+1), avg_r_loss/(local_step+1),
                         val_loss, val_aux_loss, val_kl_loss, val_r_loss, beta]
                for i in range(len(stats)):
                    self.model.sess.run(self.model.update_ops[i], feed_dict={
                        self.model.summary_placeholders[i]: float(stats[i])
                    })
                summary_str = self.model.sess.run(self.model.summary_op, {
                    self.model.X: X_test_batch, 
                    self.model.X_seq_len: X_test_batch_lens,
                    self.model.output_keep_prob:1,
                    self.model.input_keep_prob:1})
                self.summary_writer.add_summary(summary_str, self.summary_cnt)
                
            if local_step % (self.dp.num_batch / self.display_freq) == 0:
                val_loss, val_r_loss, val_kl_loss, val_aux_loss = self.model.sess.run([self.model.loss, self.model.reconstruct_loss, self.model.kl_loss, self.model.aux_loss], 
                                               {self.model.X: X_test_batch,
                                                     self.model.Y: Y_test_batch,
                                                     self.model.X_seq_len: X_test_batch_lens,
                                                     self.model.Y_seq_len: Y_test_batch_lens,
                                                     self.model.output_keep_prob:1,
                                                     self.model.input_keep_prob:1,
                                                     self.model.B:beta})
                print("Epoch %d/%d | Batch %d/%d | Train_loss: %.3f = %.3f(%.3f) + %.3f +%.3f | Test_loss: %.3f = %.3f + %.3f + %.3f | Time_cost:%.3f" % (epoch, self.n_epoch, local_step, self.dp.num_batch, 
                                                                                                                                       avg_loss / (local_step + 1),
                                                                                                                                       avg_r_loss / (local_step + 1),
                                                                                                                                       avg_all_loss / (local_step + 1),
                                                                                                                                       avg_kl_loss / (local_step + 1),
                                                                                                                                       avg_aux_loss / (local_step + 1),
                                                                                                                                       val_loss, val_r_loss, val_kl_loss, val_aux_loss, time.time()-tic))
                self.cal()
                tic = time.time()
        return avg_loss / self.dp.num_batch, avg_all_loss / self.dp.num_batch, avg_r_loss / self.dp.num_batch, avg_kl_loss / self.dp.num_batch, avg_aux_loss / self.dp.num_batch
    
    def test(self):
        avg_loss = 0.0
        avg_r_loss = 0.0
        avg_all_loss = 0.0
        avg_kl_loss = 0.0
        avg_aux_loss = 0.0
        beta = 0.001 + self.betaG(self.model.step) # add small value to avoid points to scatter
        for local_step, (X_test_batch, Y_test_batch, X_test_batch_lens, Y_test_batch_lens) in enumerate(
            self.dp.next_batch(self.dp.X_test, self.dp.Y_test)):
            val_loss, val_all_loss, val_r_loss, val_kl_loss, val_aux_loss = self.model.sess.run([self.model.loss, self.model.all_reconstruct_loss, self.model.reconstruct_loss, self.model.kl_loss, self.model.aux_loss], 
                                               {self.model.X: X_test_batch,
                                                     self.model.Y: Y_test_batch,
                                                     self.model.X_seq_len: X_test_batch_lens,
                                                     self.model.Y_seq_len: Y_test_batch_lens,
                                                     self.model.output_keep_prob:1,
                                                     self.model.input_keep_prob:1,
                                                     self.model.B:beta})
            avg_loss += val_loss
            avg_all_loss += val_all_loss
            avg_r_loss += val_r_loss
            avg_kl_loss += val_kl_loss
            avg_aux_loss += val_aux_loss
        return avg_loss / (local_step + 1), avg_all_loss / (local_step + 1), avg_r_loss / (local_step + 1), avg_kl_loss / (local_step + 1) , avg_aux_loss / (local_step + 1)
    
    def fit(self, train_dir, is_bleu=False):
        self.n_epoch = self.dp.n_epoch
        test_loss_list = []
        train_loss_list = []
        test_r_loss_list = []
        test_aux_loss_list = []
        train_r_loss_list = []
        test_kl_loss_list = []
        train_kl_loss_list = []
        train_aux_loss_list = []
        time_cost_list = []
        bleu_list = []
        timestamp = str(int(time.time()))
        out_dir = train_dir#os.path.abspath(os.path.join(train_dir, "runs", 'var_2_1'))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print("Writing to %s" % out_dir)
        checkpoint_prefix = os.path.join(out_dir, "model")
        self.summary_writer = tf.summary.FileWriter(os.path.join(out_dir, 'Summary'), self.model.sess.graph)
        for epoch in range(1, self.n_epoch+1):
            tic = time.time()
            train_loss, train_all_loss, train_r_loss, train_kl_loss, train_aux_loss = self.train(epoch)
            train_loss_list.append(train_loss)
            train_r_loss_list.append(train_r_loss)
            train_kl_loss_list.append(train_kl_loss)
            train_aux_loss_list.append(train_aux_loss)
            
            test_loss, test_all_loss, test_r_loss, test_kl_loss, test_aux_loss = self.test()
            test_loss_list.append(test_loss)
            test_r_loss_list.append(test_r_loss)
            test_kl_loss_list.append(test_kl_loss)
            test_aux_loss_list.append(test_aux_loss)
            toc = time.time()
            time_cost_list.append((toc - tic))
            if is_bleu:
                bleu = self.test_bleu()
                bleu_list.append(bleu)
                print("Epoch %d/%d | Train_loss: %.3f = %.3f(%.3f) + %.3f + %.3f | Test_loss: %.3f = %.3f(%.3f) + %.3f + %.3f | Bleu: %.3f" % (epoch, self.n_epoch, train_loss, train_r_loss, train_all_loss, train_kl_loss, train_aux_loss, test_loss, test_r_loss, test_all_loss, test_kl_loss, test_aux_loss, bleu))
            else:
                bleu = 0.0
                print("Epoch %d/%d | Train_loss: %.3f = %.3f(%.3f) + %.3f + %.3f | Test_loss: %.3f = %.3f(%.3f) + %.3f + %.3f | Bleu: %.3f" % (epoch, self.n_epoch, train_loss, train_r_loss, train_all_loss, train_kl_loss, train_aux_loss, test_loss, test_r_loss, test_all_loss, test_kl_loss, test_aux_loss, bleu))
            if self.model.is_save:
                cPickle.dump((train_loss_list, train_r_loss_list, train_kl_loss_list, train_aux_loss_list, test_loss_list, test_r_loss_list, test_kl_loss_list, test_aux_loss,time_cost_list, bleu_list), open(os.path.join(out_dir,"res.pkl"),'wb'))
                path = self.model.saver.save(self.model.sess, checkpoint_prefix, global_step=epoch)
                print("Saved model checkpoint to %s" % path)
    
    def show(self, sent, id2w):
        """Render a sequence of token ids as a space-separated string.

        Args:
            sent: iterable of token ids.
            id2w: mapping from token id to token string (must offer ``.get``).

        Returns:
            The looked-up tokens joined by single spaces; ids that are not in
            ``id2w`` are rendered as ``'&'``.
        """
        words = []
        for token_id in sent:
            words.append(id2w.get(token_id, u'&'))
        return " ".join(words)
    
    def cal(self, n_example=5):
        """Print a handful of qualitative model outputs.

        Decodes the first examples of the test split, then of the train
        split, printing input / model output / ground truth for each, and
        finally prints every item returned by ``self.model.generate()``.

        Args:
            n_example: total number of examples to show. ``int(n_example / 2)``
                come from the train split, the remainder from the test split.
        """
        train_count = int(n_example / 2)
        test_count = n_example - train_count

        # Test split first.
        for row in range(test_count):
            source_text = self.show(self.dp.X_test[row], self.dp.X_id2w)
            target_text = self.show(self.dp.Y_test[row], self.dp.Y_id2w)
            prediction = self.model.infer(source_text)[0]
            print('Test_Input: %s | Output: %s | GroundTruth: %s' % (source_text, prediction, target_text))

        # Then the train split.
        for row in range(train_count):
            source_text = self.show(self.dp.X_train[row], self.dp.X_id2w)
            target_text = self.show(self.dp.Y_train[row], self.dp.Y_id2w)
            prediction = self.model.infer(source_text)[0]
            print('Train_Input: %s | Output: %s | GroundTruth: %s' % (source_text, prediction, target_text))

        # Finally whatever the model generates on its own.
        generated = self.model.generate()
        print('generate:')
        for item in generated:
            print('【', item, '】')
        print("")
        
    def test_bleu(self, N=300, gram=4):
        all_score = []
        for i in range(N):
            input_indices = self.show(self.dp.X_test[i], self.dp.X_id2w)
            o = self.model.infer(input_indices)[0]
            refer4bleu = [[' '.join([self.dp.Y_id2w.get(w, u'&') for w in self.dp.Y_test[i]])]]
            candi = [' '.join(w for w in o)]
            score = BLEU(candi, refer4bleu, gram=gram)
            all_score.append(score)
        return np.mean(all_score)
    
    def show_res(self, path):
        """Plot the training history pickled by ``fit``.

        The pickle is the tuple written by ``fit``, laid out as
        ``(train_loss, train_r_loss, train_kl_loss, train_aux_loss,
        test_loss, test_r_loss, test_kl_loss, test_aux_loss, time_cost,
        bleu)``. Two figures are shown: the train curves (indices 0-2) and
        the test curves (indices 4-6) together with BLEU (last element).

        Args:
            path: path to the ``res.pkl`` file.
        """
        # The file is written in binary mode, so it must be read in binary
        # mode too; the context manager also closes the handle.
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as res_file:
            res = cPickle.load(res_file)

        plt.figure(1)
        plt.title('The train results')
        l1, = plt.plot(res[0], 'g')
        l2, = plt.plot(res[1], 'r')
        l3, = plt.plot(res[2], 'b')
        plt.legend(handles=[l1, l2, l3], labels=["Train_loss", "Train_r_loss", "Train_kl_loss"], loc='best')
        plt.show()

        plt.figure(1)
        plt.title('The test results')
        # Indices 4-6 hold the test series (index 3 is the train aux loss).
        # Colours mirror the train figure; BLEU gets its own colour.
        l4, = plt.plot(res[4], 'g')
        l5, = plt.plot(res[5], 'r')
        l6, = plt.plot(res[6], 'b')
        l7, = plt.plot(res[-1], 'k')
        plt.legend(handles=[l4, l5, l6, l7], labels=["Test_loss", "Test_r_loss", "Test_kl_loss", "BLEU"], loc='best')
        plt.show()
        
    def test_all(self, path, epoch_range, is_bleu=True):
        val_loss_list = []
        bleu_list = []
        for i in range(epoch_range[0], epoch_range[-1]):
            self.model.restore(path + str(i))
            val_loss = self.test()
            val_loss_list.append(val_loss)
            if is_bleu:
                bleu_score = self.test_bleu()
                bleu_list.append(bleu_score)
        plt.figure(1)
        plt.title('The results') 
        l1, = plt.plot(val_loss_list,'r')
        l2, = plt.plot(bleu_list,'b')
        plt.legend(handles = [l1, l2], labels = ["Test_loss","BLEU"], loc = 'best')
        plt.show()
        
    

In [None]:
# Load the vocabulary pair (w2id, id2w) and the id-encoded sentences.
# Context managers close the pickle files instead of leaking the handles.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open("Data/SQUAD/w2id_id2w.pkl", 'rb') as vocab_file:
    w2id, id2w = cPickle.load(vocab_file)
with open("Data/SQUAD/X_indices.pkl", 'rb') as indices_file:
    data = cPickle.load(indices_file)

# Model inputs are the sequences with their last id dropped.
# NOTE(review): presumably the dropped id is an end-of-sequence marker — confirm.
data_X = [o[:-1] for o in data]

# Sanity check: decode and print three randomly chosen sequences.
for t in random.sample(data, 3):
    print(' '.join(id2w[idx] for idx in t))

In [None]:

# Data provider: inputs are data_X, targets are data; the same w2id mapping
# is passed for both vocabulary arguments.
dp = VRAE_DP(data_X, data, w2id, w2id)

# Build the model inside its own graph / session pair.
g = tf.Graph()
sess = tf.Session(graph=g)

# Hyper-parameters for this run, gathered in one place.
model_kwargs = dict(
    dp=dp,
    rnn_size=1024,
    latent_dim=16,
    n_layers=1,
    var=2.5,
    encoder_embedding_dim=512,
    decoder_embedding_dim=512,
    cell_type='gru',
    max_infer_length=35,
    residual=False,
    is_jieba=False,
    is_save=True,
    beam_width=1,
    sess=sess,
)

with sess.as_default(), sess.graph.as_default():
    model = VRAE(**model_kwargs)
    util = VRAE_util(dp=dp, model=model)
    # Training
    util.fit(train_dir='Models/')
        

In [None]:
# Display what model.generate() returns as this cell's output (cal() prints
# the items of the same call one by one).
# NOTE(review): presumably these are sentences decoded from sampled latent
# codes — confirm against VRAE.generate.
model.generate()

In [None]:
# Reload saved weights into the model. fit(train_dir='Models/') saves with
# the prefix 'Models/model' and global_step=epoch, so 'model-10' is
# presumably the checkpoint written after epoch 10 — confirm it exists.
model.restore('Models/model-10')

In [None]:
# Preview: decode the first 51 inputs (indices 0..50) and print each input
# sentence followed by the raw result of model.infer on its token list.
for i, x in enumerate(data_X[:51]):
    x_str = ' '.join(id2w[ids] for ids in x)
    print(x_str)
    print(model.infer(x_str.split()))
    print('')

In [None]:
# Decode every input sentence with the 'Models/model-10' checkpoint and write
# the first inference result of each to vae_10.txt, one per line.
model.restore('Models/model-10')

# Reload the id-encoded sentences; the context manager closes the pickle file.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open("Data/SQUAD/X_indices.pkl", 'rb') as indices_file:
    data = cPickle.load(indices_file)
data_X = [o[:-1] for o in data]

with open('vae_10.txt', 'w', encoding='utf-8') as fout:
    for i, x in enumerate(data_X):
        x_str = ' '.join([id2w[ids] for ids in x])
        o_str = model.infer(x_str.split())[0]
        fout.write('{}\n'.format(o_str))
        # Lightweight progress indicator.
        if i % 100 == 0:
            print(i)