In [14]:
from utils.load_data import load_sentences, trim_and_pad_audio_data
import utils.layers as layers
from models.emotion_recognition import ClassifyEmotion
from collections import Counter
from sklearn.model_selection import train_test_split
import tensorflow as tf
import numpy as np 
import matplotlib.pyplot as plt 
%matplotlib inline

In [5]:
%load_ext autoreload
%autoreload 2

# Emotion Recognition

In [None]:
# Encoders used by the next cell to turn the raw labels into one-hot rows.
from sklearn.preprocessing import LabelEncoder, OneHotEncoder 

# load_sentences() (from utils.load_data) returns three values, unpacked here;
# `data` and `label` feed the encoding/split cell below, `metadata` is not used
# in the visible part of this notebook.
data, label, metadata = load_sentences()

In [10]:
# Map each label to an integer id, then expand the ids into one-hot rows.
int_labels = LabelEncoder().fit_transform(label)
int_labels = int_labels.reshape(len(int_labels), 1)  # OneHotEncoder expects a 2-D column
labels = OneHotEncoder().fit_transform(int_labels).toarray()

# Split into train and test.
# A fixed random_state makes the 80/20 split identical on every
# "Restart Kernel -> Run All", so losses are comparable between runs.
train_x, val_x, train_y, val_y = train_test_split(data, labels,
                                                  test_size=0.2,
                                                  random_state=0)

In [11]:
# Number of samples per mini-batch.
batch_size = 128

# Cut the training features and labels into consecutive slices of at most
# batch_size rows each (the last slice may be shorter).
batched_data = []
for start in range(0, len(train_x), batch_size):
    batched_data.append(train_x[start:start + batch_size])

batched_labels = []
for start in range(0, len(train_y), batch_size):
    batched_labels.append(train_y[start:start + batch_size])

# Sanity check: number of batches and the shape of the first batch.
print(len(batched_data),len(batched_labels),batched_data[0].shape, batched_labels[0].shape)

47 47 (128, 2000, 26) (128, 6)


In [17]:
# Hyper-parameters of the emotion classifier.
num_features = 26
num_classes = 6
learning_rate = 1e-4
epochs = 5

tf.reset_default_graph()
# The graph-level seed must be set BEFORE any op is created: an op only picks up
# the seed that is active when it is added to the graph, so seeding after the
# model has been built leaves the weight initializers unseeded.
tf.set_random_seed(0)

with tf.name_scope('inputs'):
    # x: (batch, time, num_features) audio features; y: (batch, num_classes) one-hot labels.
    x = tf.placeholder(shape=(None, None, num_features), dtype=tf.float32)
    y = tf.placeholder(shape=(None, num_classes), dtype=tf.float32)

# Two stacked bidirectional LSTM layers built by utils.layers.blstm.
# The printed shapes show the second call yields a (batch, 128) tensor;
# presumably return_all=False keeps only the last time step -- confirm in utils/layers.py.
concat_lstm1 = layers.blstm(0, 64, x, return_all=True)
concat_lstm2 = layers.blstm(1, 64, concat_lstm1, return_all=False)

with tf.name_scope('dense'):
    dense_0 = tf.layers.dense(concat_lstm2, 512, activation=tf.nn.tanh)
    print(dense_0.shape)
    logits= tf.layers.dense(dense_0, num_classes)
    print(logits.shape)

with tf.name_scope('loss'):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))

step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(epochs):
        batch_losses = []
        for i, batch in enumerate(batched_data):
            labels_batch = batched_labels[i]
            feed_dict = {
                x: batch,
                y: labels_batch
            }
            _, err = sess.run([step, loss], feed_dict=feed_dict)
            batch_losses.append(err)
        # Report the mean training loss of every epoch.
        print(np.mean(batch_losses))

(?, ?, 128)
(?, 128)
(?, 512)
(?, 6)


KeyboardInterrupt: 

# Conversation Generation

In [1]:
# ---------------------------------------------------------------------------
# Data locations
# ---------------------------------------------------------------------------
# Training conversations
training_data_path = 'data/conversations_lenmax22_formersents2_with_former'
# One utterance per line; source for the vocabulary
all_words_path = 'data/all_words.txt'

# ---------------------------------------------------------------------------
# Training / checkpoint settings
# ---------------------------------------------------------------------------
CHECKPOINT = True                   # restore a saved model instead of starting fresh
train_model_path = 'model'
train_model_name = 'model-55'
start_epoch, start_batch = 56, 0
batch_size = 25

# ---------------------------------------------------------------------------
# Reinforcement-learning training
# ---------------------------------------------------------------------------
# 'normal' for seq2seq training, 'pg' for policy gradient
training_type = 'normal'
reversed_model_path = 'Adam_encode22_decode22_reversed-maxlen22_lr0.0001_batch25_wordthres6'
reversed_model_name = 'model-63'

# ---------------------------------------------------------------------------
# Data reader shuffle index list
# ---------------------------------------------------------------------------
load_list = False
index_list_file = 'data/shuffle_index_list'
# Index of the first training sample to read when resuming at start_batch
cur_train_index = start_batch * batch_size

# ---------------------------------------------------------------------------
# Vocabulary: minimum word count to keep a word
# ---------------------------------------------------------------------------
WC_threshold = 20
reversed_WC_threshold = 6

# ---------------------------------------------------------------------------
# Dialog simulation: maximum number of turns
# ---------------------------------------------------------------------------
MAX_TURNS = 10

Cornell Movie-Dialogs Corpus

In [None]:
pip install -r requirements.txt

In [None]:
./script/parse.sh

# coding=utf-8

from __future__ import print_function
import pickle
import codecs
import re
import os
import time
import numpy as np
import config

def preProBuildWordVocab(word_count_threshold=5, all_words_path=config.all_words_path):
    """Build the vocabulary and the output-bias initializer from the utterance file.

    Args:
        word_count_threshold: a word is kept only if it occurs at least this
            many times in the corpus.
        all_words_path: text file with one utterance per line; it is created
            with parse_all_words() when it does not exist yet.

    Returns:
        (wordtoix, ixtoword, bias_init_vector) where indices 0-3 are reserved
        for '<pad>', '<bos>', '<eos>' and '<unk>', the remaining words follow in
        first-occurrence order, and bias_init_vector holds the log relative
        frequency of every index, shifted so that its maximum is 0.
    """
    # borrowed this function from NeuralTalk

    if not os.path.exists(all_words_path):
        parse_all_words(all_words_path)

    # Read with the same encoding parse_all_words() writes with, and make sure
    # the handle is closed again.
    with open(all_words_path, 'r', encoding='utf-8', errors='ignore') as words_file:
        corpus = words_file.read().split('\n')[:-1]

    # Characters removed from every sentence before counting words.
    strip_chars = ('.', ',', '"', '\n', '?', '!', '\\', '/')

    print('preprocessing word counts and creating vocab based on word count threshold %d' % (word_count_threshold))
    word_counts = {}
    nsents = 0
    for sent in corpus:
        for ch in strip_chars:
            sent = sent.replace(ch, '')
        nsents += 1
        for w in sent.lower().split(' '):
            word_counts[w] = word_counts.get(w, 0) + 1
    # Dict iteration follows insertion order, so the vocabulary (and therefore
    # every word index) is stable for a given corpus file.
    vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
    print('filtered words from %d to %d' % (len(word_counts), len(vocab)))

    ixtoword = {}
    ixtoword[0] = '<pad>'
    ixtoword[1] = '<bos>'
    ixtoword[2] = '<eos>'
    ixtoword[3] = '<unk>'

    wordtoix = {}
    wordtoix['<pad>'] = 0
    wordtoix['<bos>'] = 1
    wordtoix['<eos>'] = 2
    wordtoix['<unk>'] = 3

    for idx, w in enumerate(vocab):
        wordtoix[w] = idx+4
        ixtoword[idx+4] = w

    # The special tokens are counted once per sentence.
    word_counts['<pad>'] = nsents
    word_counts['<bos>'] = nsents
    word_counts['<eos>'] = nsents
    word_counts['<unk>'] = nsents

    bias_init_vector = np.array([1.0 * word_counts[ixtoword[i]] for i in ixtoword])
    bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies
    bias_init_vector = np.log(bias_init_vector)
    bias_init_vector -= np.max(bias_init_vector) # shift to nice numeric range

    return wordtoix, ixtoword, bias_init_vector

def parse_all_words(all_words_path, movie_lines_path='data/movie_lines.txt'):
    """Write the utterance text of every movie line to `all_words_path`.

    Each input line is split on ' +++$+++ ' and only its last field (the
    utterance) is written, one utterance per output line.

    Args:
        all_words_path: destination file (utf-8), overwritten if it exists.
        movie_lines_path: raw corpus file to read; defaults to the location
            that was previously hard-coded.
    """
    # The context manager closes the corpus file instead of leaking the handle.
    with open(movie_lines_path, 'r', encoding='utf-8', errors='ignore') as movie_lines_file:
        # [:-1] drops the empty string that follows the final newline.
        raw_movie_lines = movie_lines_file.read().split('\n')[:-1]

    with codecs.open(all_words_path, "w", encoding='utf-8', errors='ignore') as f:
        for line in raw_movie_lines:
            utterance = line.split(' +++$+++ ')[-1]
            f.write(utterance + '\n')

def refine(data):
    """Extract only the vocabulary part of the data.

    Keeps runs of letters, apostrophes and hyphens, removes the apostrophes
    (hyphens are kept) and joins the resulting pieces with single spaces.
    """
    pieces = []
    for token in re.findall("[a-zA-Z'-]+", data):
        pieces.append(token.replace("'", ""))
    return ' '.join(pieces)

if __name__ == '__main__':
    # 1) Plain utterance dump used to build the vocabulary.
    parse_all_words(config.all_words_path)

    # 2) Tokenized utterances plus a line-id -> utterance lookup table.
    # Every file is opened in a `with` block so its handle is closed (and, for
    # the pickle, flushed) deterministically.
    with open('data/movie_lines.txt', 'r', encoding='utf-8', errors='ignore') as movie_lines_file:
        raw_movie_lines = movie_lines_file.read().split('\n')[:-1]

    utterance_dict = {}
    with codecs.open('data/tokenized_all_words.txt', "w", encoding='utf-8', errors='ignore') as f:
        for line in raw_movie_lines:
            line = line.split(' +++$+++ ')
            line_ID = line[0]
            utterance = line[-1]
            # The dictionary keeps the raw utterance; only the file gets the refined text.
            utterance_dict[line_ID] = utterance
            utterance = " ".join([refine(w) for w in utterance.lower().split()])
            f.write(utterance + '\n')

    with open('data/utterance_dict', 'wb') as dict_file:
        pickle.dump(utterance_dict, dict_file, True)

In [None]:
./script/train.sh
#-*- coding: utf-8 -*-

from __future__ import print_function

from gensim.models import KeyedVectors
from data_reader import Data_Reader
import data_parser
import config

from model import Seq2Seq_chatbot
import tensorflow as tf
import numpy as np

import os
import time


### Global Parameters ###
# Checkpoint settings are read from config.py; train() restores
# os.path.join(model_path, model_name) when `checkpoint` is truthy.
checkpoint = config.CHECKPOINT
model_path = config.train_model_path
model_name = config.train_model_name
# First epoch number of the training loop (range(start_epoch, epochs)).
start_epoch = config.start_epoch

# Minimum word count passed to data_parser.preProBuildWordVocab().
word_count_threshold = config.WC_threshold

### Train Parameters ###
# Size of one word vector (also the size of the zero vectors used for padding)
# and the hidden size handed to the model.
dim_wordvec = 300
dim_hidden = 1000

# Encoder inputs are truncated/padded to this many steps; targets to n_decode_lstm_step.
n_encode_lstm_step = 22 + 22
n_decode_lstm_step = 22

epochs = 500
# NOTE(review): this is a local literal, not config.batch_size -- confirm that is intended.
batch_size = 100
learning_rate = 0.0001


def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
    """Pad or truncate every sequence to a common length and stack them.

    Args:
        sequences: sized collection of sized sequences.
        maxlen: target length; defaults to the length of the longest sequence.
        dtype: dtype of the returned array.
        padding: 'pre' or 'post' -- side on which `value` is filled in.
        truncating: 'pre' or 'post' -- side from which over-long sequences are cut.
        value: fill value for the padded positions.

    Returns:
        Array of shape (len(sequences), maxlen) + per-step sample shape.

    Raises:
        ValueError: for non-sized inputs, unknown padding/truncating modes, or
            a sequence whose per-step shape differs from the first non-empty one.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')

    seq_lengths = []
    for seq in sequences:
        if not hasattr(seq, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(seq))
        seq_lengths.append(len(seq))

    n_rows = len(sequences)
    if maxlen is None:
        maxlen = np.max(seq_lengths)

    # The per-step shape is taken from the first non-empty sequence; every
    # other sequence is checked against it in the loop below.
    first_filled = next((seq for seq in sequences if len(seq) > 0), None)
    if first_filled is None:
        sample_shape = tuple()
    else:
        sample_shape = np.asarray(first_filled).shape[1:]

    out_shape = (n_rows, maxlen) + sample_shape
    padded = (np.ones(out_shape) * value).astype(dtype)

    for row, seq in enumerate(sequences):
        if len(seq) == 0:
            continue  # empty list/array was found

        if truncating == 'post':
            kept = seq[:maxlen]
        elif truncating == 'pre':
            kept = seq[-maxlen:]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # check `kept` has expected shape
        kept = np.asarray(kept, dtype=dtype)
        if kept.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (kept.shape[1:], row, sample_shape))

        if padding == 'pre':
            padded[row, -len(kept):] = kept
        elif padding == 'post':
            padded[row, :len(kept)] = kept
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return padded

def train():
    """Train the Seq2Seq chatbot and save a checkpoint after every epoch.

    Builds the vocabulary, loads the pre-trained word vectors, constructs the
    model and either restores the configured checkpoint or initializes fresh
    variables. Each batch turns the input words into word vectors (zero-padded
    or truncated to n_encode_lstm_step) and the target sentences into index and
    mask matrices before running one optimizer step.
    """
    wordtoix, ixtoword, bias_init_vector = data_parser.preProBuildWordVocab(word_count_threshold=word_count_threshold)
    word_vector = KeyedVectors.load_word2vec_format('model/word_vector.bin', binary=True)

    model = Seq2Seq_chatbot(
            dim_wordvec=dim_wordvec,
            n_words=len(wordtoix),
            dim_hidden=dim_hidden,
            batch_size=batch_size,
            n_encode_lstm_step=n_encode_lstm_step,
            n_decode_lstm_step=n_decode_lstm_step,
            bias_init_vector=bias_init_vector,
            lr=learning_rate)

    train_op, tf_loss, word_vectors, tf_caption, tf_caption_mask, inter_value = model.build_model()

    saver = tf.train.Saver(max_to_keep=100)

    sess = tf.InteractiveSession()

    if checkpoint:
        print("Use Model {}.".format(model_name))
        saver.restore(sess, os.path.join(model_path, model_name))
        print("Model {} restored.".format(model_name))
    else:
        print("Restart training...")
        tf.global_variables_initializer().run()

    dr = Data_Reader()

    # Characters removed from every target sentence before tokenizing.
    strip_chars = ('.', ',', '"', '\n', '?', '!', '\\', '/')

    for epoch in range(start_epoch, epochs):
        n_batch = dr.get_batch_num(batch_size)
        for batch in range(n_batch):
            start_time = time.time()

            batch_X, batch_Y = dr.generate_training_batch(batch_size)

            # Encoder input: word -> vector (zeros for unknown words), then
            # truncate or zero-pad every sample to n_encode_lstm_step steps.
            for i in range(len(batch_X)):
                batch_X[i] = [word_vector[w] if w in word_vector else np.zeros(dim_wordvec) for w in batch_X[i]]
                if len(batch_X[i]) > n_encode_lstm_step:
                    batch_X[i] = batch_X[i][:n_encode_lstm_step]
                else:
                    for _ in range(len(batch_X[i]), n_encode_lstm_step):
                        batch_X[i].append(np.zeros(dim_wordvec))

            current_feats = np.array(batch_X)

            # Decoder target: a real list is required here because the entries
            # are re-assigned by index below; a Python 3 map object supports
            # neither item assignment nor a second iteration.
            current_captions = []
            for caption in batch_Y:
                caption = '<bos> ' + caption
                for ch in strip_chars:
                    caption = caption.replace(ch, '')
                current_captions.append(caption)

            # Close every caption with '<eos>', cutting it to
            # n_decode_lstm_step tokens when it is too long.
            for idx, each_cap in enumerate(current_captions):
                word = each_cap.lower().split(' ')
                if len(word) < n_decode_lstm_step:
                    current_captions[idx] = current_captions[idx] + ' <eos>'
                else:
                    new_word = ''
                    for i in range(n_decode_lstm_step-1):
                        new_word = new_word + word[i] + ' '
                    current_captions[idx] = new_word + '<eos>'

            current_caption_ind = []
            for cap in current_captions:
                current_word_ind = []
                for word in cap.lower().split(' '):
                    if word in wordtoix:
                        current_word_ind.append(wordtoix[word])
                    else:
                        current_word_ind.append(wordtoix['<unk>'])
                current_caption_ind.append(current_word_ind)

            current_caption_matrix = pad_sequences(current_caption_ind, padding='post', maxlen=n_decode_lstm_step)
            current_caption_matrix = np.hstack([current_caption_matrix, np.zeros([len(current_caption_matrix), 1])]).astype(int)
            current_caption_masks = np.zeros((current_caption_matrix.shape[0], current_caption_matrix.shape[1]))
            # Mask length per row = number of non-pad tokens + 1; built from a
            # list so that the result is a proper 1-D array under Python 3.
            nonzeros = np.array([(row != 0).sum() + 1 for row in current_caption_matrix])

            for ind, row in enumerate(current_caption_masks):
                row[:nonzeros[ind]] = 1

            if batch % 100 == 0:
                # Fetch and report the loss only every 100 batches.
                _, loss_val = sess.run(
                        [train_op, tf_loss],
                        feed_dict={
                            word_vectors: current_feats,
                            tf_caption: current_caption_matrix,
                            tf_caption_mask: current_caption_masks
                        })
                print("Epoch: {}, batch: {}, loss: {}, Elapsed time: {}".format(epoch, batch, loss_val, time.time() - start_time))
            else:
                _ = sess.run(train_op,
                             feed_dict={
                                word_vectors: current_feats,
                                tf_caption: current_caption_matrix,
                                tf_caption_mask: current_caption_masks
                            })


        print("Epoch ", epoch, " is done. Saving the model ...")
        saver.save(sess, os.path.join(model_path, 'model'), global_step=epoch)

if __name__ == "__main__":
    train()

In [None]:
./script/test.sh <PATH TO MODEL> <INPUT FILE> <OUTPUT FILE>

#-*- coding: utf-8 -*-

from __future__ import print_function

import re
import os
import time
import sys

sys.path.append("python")
import data_parser
import config

from gensim.models import KeyedVectors
from rl_model import PolicyGradient_chatbot
import tensorflow as tf
import numpy as np

#=====================================================
# Global Parameters
#=====================================================
# Checkpoint restored when the requested model cannot be loaded.
default_model_path = './model/RL/model-56-3000'
# argv[1] is the model path (consumed in the __main__ guard), argv[2] the
# input file and argv[3] the output file; the defaults apply when they are omitted.
testing_data_path = 'sample_input.txt' if len(sys.argv) <= 2 else sys.argv[2]
output_path = 'sample_output_RL.txt' if len(sys.argv) <= 3 else sys.argv[3]

# Minimum word count passed to data_parser.preProBuildWordVocab().
word_count_threshold = config.WC_threshold

#=====================================================
# Train Parameters
#=====================================================
dim_wordvec = 300
dim_hidden = 1000

# NOTE(review): the training scripts in this notebook use 22 + 22 encoder steps;
# confirm that 22 + 1 matches the checkpoint that is restored here.
n_encode_lstm_step = 22 + 1 # one random normal as the first timestep
n_decode_lstm_step = 22

# One question is answered per session run.
batch_size = 1

def refine(data):
    """Extract only the vocabulary part of the data.

    Runs of letters, apostrophes and hyphens are kept; apostrophes are then
    dropped (hyphens stay) and the pieces are joined with single spaces.
    """
    matches = re.findall("[a-zA-Z'-]+", data)
    return ' '.join(match.replace("'", "") for match in matches)

def test(model_path=default_model_path):
    """Generate one reply per input line and write the replies to `output_path`.

    Args:
        model_path: checkpoint to restore; when restoring it fails the
            checkpoint at `default_model_path` is used instead.
    """
    # Close the input file instead of leaking the handle.
    with open(testing_data_path, 'r') as testing_file:
        testing_data = testing_file.read().split('\n')

    word_vector = KeyedVectors.load_word2vec_format('model/word_vector.bin', binary=True)

    _, ixtoword, bias_init_vector = data_parser.preProBuildWordVocab(word_count_threshold=word_count_threshold)

    model = PolicyGradient_chatbot(
            dim_wordvec=dim_wordvec,
            n_words=len(ixtoword),
            dim_hidden=dim_hidden,
            batch_size=batch_size,
            n_encode_lstm_step=n_encode_lstm_step,
            n_decode_lstm_step=n_decode_lstm_step,
            bias_init_vector=bias_init_vector)

    word_vectors, caption_tf, feats = model.build_generator()

    sess = tf.InteractiveSession()

    saver = tf.train.Saver()
    try:
        print('\n=== Use model', model_path, '===\n')
        saver.restore(sess, model_path)
    except Exception as err:
        # Keep the fallback, but do not hide why the requested model failed and
        # do not swallow KeyboardInterrupt/SystemExit like a bare `except:` would.
        print('\nCould not restore', model_path, '-', err)
        print('\nUse default model\n')
        saver.restore(sess, default_model_path)

    with open(output_path, 'w') as out:
        for question in testing_data:
            print('question =>', question)

            question = [refine(w) for w in question.lower().split()]
            question = [word_vector[w] if w in word_vector else np.zeros(dim_wordvec) for w in question]
            question.insert(0, np.random.normal(size=(dim_wordvec,))) # insert random normal at the first step

            if len(question) > n_encode_lstm_step:
                question = question[:n_encode_lstm_step]
            else:
                for _ in range(len(question), n_encode_lstm_step):
                    question.append(np.zeros(dim_wordvec))

            question = np.array([question]) # shape: (1, n_encode_lstm_step, dim_wordvec)

            generated_word_index, prob_logit = sess.run([caption_tf, feats['probs']], feed_dict={word_vectors: question})
            generated_word_index = np.array(generated_word_index).reshape(batch_size, n_decode_lstm_step)[0]
            prob_logit = np.array(prob_logit).reshape(batch_size, n_decode_lstm_step, -1)[0]

            # replace <unk> (index 3) by the word with the second highest score
            for i in range(len(generated_word_index)):
                if generated_word_index[i] == 3:
                    second_best = np.sort(prob_logit[i])[-2]
                    generated_word_index[i] = np.where(prob_logit[i] == second_best)[0][0]

            generated_words = []
            for ind in generated_word_index:
                generated_words.append(ixtoword[ind])

            # generate sentence: cut after the first <eos>. Without an <eos>
            # the whole reply is kept (argmax over an all-False array is 0,
            # which used to truncate such replies to their first word).
            if '<eos>' in generated_words:
                generated_words = generated_words[:generated_words.index('<eos>') + 1]
            generated_sentence = ' '.join(generated_words)

            # modify the output sentence
            generated_sentence = generated_sentence.replace('<bos> ', '')
            generated_sentence = generated_sentence.replace(' <eos>', '')
            generated_sentence = generated_sentence.replace('--', '')
            generated_sentence = generated_sentence.split('  ')
            for i in range(len(generated_sentence)):
                generated_sentence[i] = generated_sentence[i].strip()
                if len(generated_sentence[i]) > 1:
                    generated_sentence[i] = generated_sentence[i][0].upper() + generated_sentence[i][1:] + '.'
                else:
                    generated_sentence[i] = generated_sentence[i].upper()
            generated_sentence = ' '.join(generated_sentence)
            generated_sentence = generated_sentence.replace(' i ', ' I ')
            generated_sentence = generated_sentence.replace("i'm", "I'm")
            generated_sentence = generated_sentence.replace("i'd", "I'd")
            generated_sentence = generated_sentence.replace("i'll", "I'll")
            generated_sentence = generated_sentence.replace("i'v", "I'v")
            generated_sentence = generated_sentence.replace(" - ", "")

            print('generated_sentence =>', generated_sentence)
            out.write(generated_sentence + '\n')


if __name__ == "__main__":
    if len(sys.argv) > 1:
        test(model_path=sys.argv[1])
    else:
        test()

In [None]:
./script/simulate.sh <PATH TO MODEL> <SIMULATE TYPE> <INPUT FILE> <OUTPUT FILE>

You need to set the `training_type` parameter in `python/config.py`:

'normal' for seq2seq training, 'pg' for policy gradient

First train with 'normal' for some epochs until the loss is stable (at least 30 epochs is highly recommended).

Then change `training_type` to 'pg' to optimize the reward function.



In [None]:
./script/train_RL.sh

#-*- coding: utf-8 -*-

from __future__ import print_function

import os
import time
import sys
import copy

sys.path.append("python")
from model import Seq2Seq_chatbot
from data_reader import Data_Reader
import data_parser
import config
import re

from gensim.models import KeyedVectors
from rl_model import PolicyGradient_chatbot
from scipy import spatial
import tensorflow as tf
import numpy as np
import math


### Global Parameters ###
# Checkpoint settings are read from config.py; train() restores
# os.path.join(model_path, model_name) when `checkpoint` is truthy.
checkpoint = config.CHECKPOINT
model_path = config.train_model_path
model_name = config.train_model_name
start_epoch = config.start_epoch
start_batch = config.start_batch

# reversed model
reversed_model_path = config.reversed_model_path
reversed_model_name = config.reversed_model_name

# Minimum word counts for the forward and the reversed vocabulary.
word_count_threshold = config.WC_threshold
r_word_count_threshold = config.reversed_WC_threshold

# dialog simulation turns
max_turns = config.MAX_TURNS

# Generic replies; train() truncates or pads this list (with '') to batch_size
# entries and encodes it with make_batch_Y().
dull_set = ["I don't know what you're talking about.", "I don't know.", "You don't know.", "You know what I mean.", "I know what you mean.", "You know what I'm saying.", "You don't know anything."]

### Train Parameters ###
training_type = config.training_type    # 'normal' for seq2seq training, 'pg' for policy gradient

dim_wordvec = 300
dim_hidden = 1000

# Step counts of the forward model ...
n_encode_lstm_step = 22 + 22
n_decode_lstm_step = 22

# ... and of the reversed model.
r_n_encode_lstm_step = 22
r_n_decode_lstm_step = 22

learning_rate = 0.0001
epochs = 500
# Both models use the batch size from config.py.
batch_size = config.batch_size
reversed_batch_size = config.batch_size

def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
    """Bring all sequences to the same length and return them as one array.

    Args:
        sequences: sized collection of sized sequences.
        maxlen: common length; the longest sequence length when None.
        dtype: dtype of the result.
        padding: 'pre' fills `value` in front of short sequences, 'post' behind.
        truncating: 'pre' drops leading items of long sequences, 'post' trailing ones.
        value: fill value.

    Returns:
        Array of shape (len(sequences), maxlen) + per-step sample shape.

    Raises:
        ValueError: for non-sized inputs, unknown padding/truncating modes, or
            inconsistent per-step shapes.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')

    lengths = []
    for item in sequences:
        if not hasattr(item, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(item))
        lengths.append(len(item))

    if maxlen is None:
        maxlen = np.max(lengths)

    # The first non-empty sequence defines the expected per-step shape.
    sample_shape = tuple()
    for item in sequences:
        if len(item) > 0:
            sample_shape = np.asarray(item).shape[1:]
            break

    result_shape = (len(sequences), maxlen) + sample_shape
    result = (np.ones(result_shape) * value).astype(dtype)

    for position, item in enumerate(sequences):
        if len(item) == 0:
            continue  # nothing to copy, the row stays filled with `value`

        if truncating == 'post':
            clipped = item[:maxlen]
        elif truncating == 'pre':
            clipped = item[-maxlen:]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # make sure the clipped sequence has the expected per-step shape
        clipped = np.asarray(clipped, dtype=dtype)
        if clipped.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (clipped.shape[1:], position, sample_shape))

        if padding == 'pre':
            result[position, -len(clipped):] = clipped
        elif padding == 'post':
            result[position, :len(clipped)] = clipped
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return result

def refine(data):
    """Extract only the vocabulary part of the data.

    Everything except runs of letters, apostrophes and hyphens is discarded;
    apostrophes are removed from the kept runs (hyphens are preserved) and the
    runs are joined with single spaces.
    """
    cleaned = [chunk.replace("'", "") for chunk in re.findall("[a-zA-Z'-]+", data)]
    return ' '.join(cleaned)

def make_batch_X(batch_X, n_encode_lstm_step, dim_wordvec, word_vector, noise=False):
    """Turn a batch of token lists into a fixed-length array of word vectors.

    Every token is looked up in `word_vector` (a zero vector is used for
    unknown tokens); each sample is then truncated or zero-padded to
    `n_encode_lstm_step` steps. With `noise=True` a random normal vector is
    inserted as the first step before truncating/padding.

    Note: `batch_X` is modified in place -- each entry is replaced by its list
    of vectors.

    Returns:
        np.array built from the modified `batch_X`.
    """
    for sample_idx, tokens in enumerate(batch_X):
        vectors = [word_vector[tok] if tok in word_vector else np.zeros(dim_wordvec) for tok in tokens]
        if noise:
            vectors.insert(0, np.random.normal(size=(dim_wordvec,))) # insert random normal at the first step

        if len(vectors) > n_encode_lstm_step:
            vectors = vectors[:n_encode_lstm_step]
        else:
            n_missing = n_encode_lstm_step - len(vectors)
            vectors.extend(np.zeros(dim_wordvec) for _ in range(n_missing))

        batch_X[sample_idx] = vectors

    return np.array(batch_X)

def make_batch_Y(batch_Y, wordtoix, n_decode_lstm_step):
    """Encode target sentences as padded index and mask matrices.

    Every sentence is prefixed with '<bos>', stripped of a fixed set of
    punctuation characters, closed with '<eos>' (after cutting it to
    `n_decode_lstm_step` tokens when necessary) and mapped to vocabulary
    indices; words missing from `wordtoix` become '<unk>'.

    Args:
        batch_Y: iterable of target sentences (strings).
        wordtoix: mapping word -> index; must contain '<unk>'.
        n_decode_lstm_step: number of decoder steps.

    Returns:
        (current_caption_matrix, current_caption_masks), both of shape
        (len(batch_Y), n_decode_lstm_step + 1). The mask holds 1 for the first
        (number of non-zero indices + 1) positions of a row and 0 elsewhere.
    """
    # Characters removed from every sentence before tokenizing.
    strip_chars = ('.', ',', '"', '\n', '?', '!', '\\', '/')

    # A real list is built here (instead of chained map() calls): under
    # Python 3 a map object is lazy, cannot be assigned to by index and can be
    # iterated only once, all of which the original code relied on.
    current_captions = []
    for caption in batch_Y:
        caption = '<bos> ' + caption
        for ch in strip_chars:
            caption = caption.replace(ch, '')

        word = caption.lower().split(' ')
        if len(word) < n_decode_lstm_step:
            caption = caption + ' <eos>'
        else:
            # Keep the first n_decode_lstm_step - 1 (lower-cased) tokens.
            caption = ''.join(w + ' ' for w in word[:n_decode_lstm_step - 1]) + '<eos>'
        current_captions.append(caption)

    current_caption_ind = []
    for cap in current_captions:
        current_word_ind = []
        for word in cap.lower().split(' '):
            if word in wordtoix:
                current_word_ind.append(wordtoix[word])
            else:
                current_word_ind.append(wordtoix['<unk>'])
        current_caption_ind.append(current_word_ind)

    current_caption_matrix = pad_sequences(current_caption_ind, padding='post', maxlen=n_decode_lstm_step)
    # One extra all-zero column is appended to the right of the index matrix.
    current_caption_matrix = np.hstack([current_caption_matrix, np.zeros([len(current_caption_matrix), 1])]).astype(int)
    current_caption_masks = np.zeros((current_caption_matrix.shape[0], current_caption_matrix.shape[1]))
    nonzeros = np.array([(row != 0).sum() + 1 for row in current_caption_matrix])

    for ind, row in enumerate(current_caption_masks):
        row[:nonzeros[ind]] = 1

    return current_caption_matrix, current_caption_masks

def index2sentence(generated_word_index, prob_logit, ixtoword):
    """Convert generated word indices into a cleaned-up, capitalized sentence.

    Args:
        generated_word_index: sequence of word indices, one per decode step.
            It is modified in place when special tokens are replaced.
        prob_logit: per-step scores over the vocabulary; prob_logit[i] belongs
            to generated_word_index[i].
        ixtoword: mapping index -> word.

    Returns:
        The sentence up to the first '<eos>' (or the whole sequence when no
        '<eos>' was generated), without '<bos>'/'<eos>' markers.
    """
    # Replace <pad> (0), <bos> (1) and <unk> (3) by the best-scoring regular
    # word (index > 3) below the top-scoring entry.
    for i in range(len(generated_word_index)):
        if generated_word_index[i] == 3 or generated_word_index[i] <= 1:
            # Indices ordered from highest to lowest score; a stable sort keeps
            # the result deterministic when scores tie.
            ranked = np.argsort(prob_logit[i], kind='stable')[::-1]
            for candidate in ranked[1:]:
                if candidate > 3:
                    generated_word_index[i] = candidate
                    break

    generated_words = []
    for ind in generated_word_index:
        generated_words.append(ixtoword[ind])

    # generate sentence: cut after the first <eos>. When there is no <eos> the
    # full sequence is kept (argmax over an all-False array is 0, which used to
    # truncate such sentences to their first word).
    if '<eos>' in generated_words:
        generated_words = generated_words[:generated_words.index('<eos>') + 1]
    generated_sentence = ' '.join(generated_words)

    # modify the output sentence
    generated_sentence = generated_sentence.replace('<bos> ', '')
    generated_sentence = generated_sentence.replace('<eos>', '')
    generated_sentence = generated_sentence.replace('--', '')
    generated_sentence = generated_sentence.split('  ')
    for i in range(len(generated_sentence)):
        generated_sentence[i] = generated_sentence[i].strip()
        if len(generated_sentence[i]) > 1:
            generated_sentence[i] = generated_sentence[i][0].upper() + generated_sentence[i][1:] + '.'
        else:
            generated_sentence[i] = generated_sentence[i].upper()
    generated_sentence = ' '.join(generated_sentence)
    generated_sentence = generated_sentence.replace(' i ', ' I ')
    generated_sentence = generated_sentence.replace("i'm", "I'm")
    generated_sentence = generated_sentence.replace("i'd", "I'd")

    return generated_sentence

def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def count_rewards(dull_loss, forward_entropy, backward_entropy, forward_target, backward_target, reward_type='pg'):
    """Compute the per-step reward matrix for one batch.

    Args:
        dull_loss: indexable per-sample value; dull_loss[i] is added to every
            step of sample i.
        forward_entropy: values reshaped to (batch_size, n_decode_lstm_step).
        backward_entropy: values reshaped to (batch_size, n_decode_lstm_step).
        forward_target: per-sample strings; the summed forward entropy of a
            sample is divided by the number of whitespace-separated tokens.
        backward_target: same as forward_target, for the backward entropy.
        reward_type: 'normal' or 'pg'.

    Returns:
        Array of shape (batch_size, n_decode_lstm_step): all ones for
        'normal'; for 'pg' sigmoid(sum of the three terms) * 1.1, constant
        along each row. None for any other reward_type.
    """
    # normal training, rewards all equal to 1
    if reward_type == 'normal':
        return np.ones([batch_size, n_decode_lstm_step])

    if reward_type != 'pg':
        return None

    fwd_entropy = np.array(forward_entropy).reshape(batch_size, n_decode_lstm_step)
    bwd_entropy = np.array(backward_entropy).reshape(batch_size, n_decode_lstm_step)
    rewards = np.zeros([batch_size, n_decode_lstm_step])

    for sample in range(batch_size):
        # ease of answering
        rewards[sample, :] += dull_loss[sample]

        # semantic coherence: entropy sums normalized by target length;
        # empty targets contribute nothing.
        n_forward = len(forward_target[sample].split())
        n_backward = len(backward_target[sample].split())
        if n_forward > 0:
            rewards[sample, :] += (np.sum(fwd_entropy[sample]) / n_forward)
        if n_backward > 0:
            rewards[sample, :] += (np.sum(bwd_entropy[sample]) / n_backward)

    return sigmoid(rewards) * 1.1

def train():
    """Train the PolicyGradient_chatbot, either normally or with policy-gradient rewards.

    Everything is configured through module-level settings (batch_size, epochs,
    start_epoch, start_batch, training_type, checkpoint, model paths, ...).

    Outline:
        1. Build the vocabulary and load the word2vec vectors.
        2. Force the global ``dull_set`` to exactly ``batch_size`` sentences and
           encode it once as a caption matrix / mask.
        3. Graph g1: build the PolicyGradient_chatbot (training + generator ops) and
           either restore it from ``model_name`` (when ``checkpoint`` is truthy) or
           initialise its variables.
        4. Graph g2: build the reversed Seq2Seq_chatbot and restore it from
           ``reversed_model_name``.
        5. For every batch:
             training_type == 'pg':     generate actions, compute the dull-response loss
                                        and the forward / backward entropies, turn them
                                        into rewards with count_rewards() and run one
                                        training step weighted by those rewards.
             training_type == 'normal': run one training step with all rewards equal to 1.
        6. Save a checkpoint after every epoch (and, in 'pg' mode, every 1000 batches).

    side effects:
        rebinds / extends the global ``dull_set``, prints progress, writes checkpoints
        under ``model_path``.
    """
    # dull_set is rebound (truncation) or extended in place (padding) below
    global dull_set

    wordtoix, ixtoword, bias_init_vector = data_parser.preProBuildWordVocab(word_count_threshold=word_count_threshold)
    word_vector = KeyedVectors.load_word2vec_format('model/word_vector.bin', binary=True)

    # make dull_set exactly batch_size long: cut the surplus or pad with empty sentences
    if len(dull_set) > batch_size:
        dull_set = dull_set[:batch_size]
    else:
        for _ in range(len(dull_set), batch_size):
            dull_set.append('')
    dull_matrix, dull_mask = make_batch_Y(
                                batch_Y=dull_set, 
                                wordtoix=wordtoix, 
                                n_decode_lstm_step=n_decode_lstm_step)

    # uniform reward of 1 per decode step: used for 'normal' training and whenever
    # the model is only run to read a loss / intermediate value
    ones_reward = np.ones([batch_size, n_decode_lstm_step])

    # two separate graphs: g1 for the policy model being trained, g2 for the reversed model
    g1 = tf.Graph()
    g2 = tf.Graph()

    # NOTE(review): default_graph is not read anywhere else in this function
    default_graph = tf.get_default_graph() 

    with g1.as_default():
        model = PolicyGradient_chatbot(
                dim_wordvec=dim_wordvec,
                n_words=len(wordtoix),
                dim_hidden=dim_hidden,
                batch_size=batch_size,
                n_encode_lstm_step=n_encode_lstm_step,
                n_decode_lstm_step=n_decode_lstm_step,
                bias_init_vector=bias_init_vector,
                lr=learning_rate)
        train_op, loss, input_tensors, inter_value = model.build_model()
        tf_states, tf_actions, tf_feats = model.build_generator()
        sess = tf.InteractiveSession()
        saver = tf.train.Saver(max_to_keep=100)
        if checkpoint:
            print("Use Model {}.".format(model_name))
            saver.restore(sess, os.path.join(model_path, model_name))
            print("Model {} restored.".format(model_name))
        else:
            print("Restart training...")
            tf.global_variables_initializer().run()

    # the reversed model uses its own vocabulary threshold, batch size and step counts
    r_wordtoix, r_ixtoword, r_bias_init_vector = data_parser.preProBuildWordVocab(word_count_threshold=r_word_count_threshold)
    with g2.as_default():
        reversed_model = Seq2Seq_chatbot(
            dim_wordvec=dim_wordvec,
            n_words=len(r_wordtoix),
            dim_hidden=dim_hidden,
            batch_size=reversed_batch_size,
            n_encode_lstm_step=r_n_encode_lstm_step,
            n_decode_lstm_step=r_n_decode_lstm_step,
            bias_init_vector=r_bias_init_vector,
            lr=learning_rate)
        _, _, word_vectors, caption, caption_mask, reverse_inter = reversed_model.build_model()
        sess2 = tf.InteractiveSession()
        saver2 = tf.train.Saver()
        saver2.restore(sess2, os.path.join(reversed_model_path, reversed_model_name))
        print("Reversed model {} restored.".format(reversed_model_name))


    dr = Data_Reader(cur_train_index=config.cur_train_index, load_list=config.load_list)

    for epoch in range(start_epoch, epochs):
        n_batch = dr.get_batch_num(batch_size)
        # only the first (resumed) epoch starts at start_batch; later epochs start at 0
        sb = start_batch if epoch == start_epoch else 0
        for batch in range(sb, n_batch):
            start_time = time.time()

            batch_X, batch_Y, former = dr.generate_training_batch_with_former(batch_size)

            # deep copies are passed to the batch builders so batch_X / batch_Y stay untouched
            current_feats = make_batch_X(
                            batch_X=copy.deepcopy(batch_X), 
                            n_encode_lstm_step=n_encode_lstm_step, 
                            dim_wordvec=dim_wordvec,
                            word_vector=word_vector)

            current_caption_matrix, current_caption_masks = make_batch_Y(
                                                                batch_Y=copy.deepcopy(batch_Y), 
                                                                wordtoix=wordtoix, 
                                                                n_decode_lstm_step=n_decode_lstm_step)

            if training_type == 'pg':
                # action: generate batch_size sents
                action_word_indexs, inference_feats = sess.run([tf_actions, tf_feats],
                                                                feed_dict={
                                                                   tf_states: current_feats
                                                                })
                # NOTE(review): reshape reinterprets the element order and is not a transpose.
                # The count_rewards docstring gives 'probs' as (n_decode_lstm_step, batch_size, n_words);
                # if the generator output is step-major, confirm these rows really are per-sentence.
                action_word_indexs = np.array(action_word_indexs).reshape(batch_size, n_decode_lstm_step)
                action_probs = np.array(inference_feats['probs']).reshape(batch_size, n_decode_lstm_step, -1)

                # actions: generated sentences as strings; actions_list: the same sentences as word lists
                actions = []
                actions_list = []
                for i in range(len(action_word_indexs)):
                    action = index2sentence(
                                generated_word_index=action_word_indexs[i], 
                                prob_logit=action_probs[i],
                                ixtoword=ixtoword)
                    actions.append(action)
                    actions_list.append(action.split())

                action_feats = make_batch_X(
                                batch_X=copy.deepcopy(actions_list), 
                                n_encode_lstm_step=n_encode_lstm_step, 
                                dim_wordvec=dim_wordvec,
                                word_vector=word_vector)

                action_caption_matrix, action_caption_masks = make_batch_Y(
                                                                batch_Y=copy.deepcopy(actions), 
                                                                wordtoix=wordtoix, 
                                                                n_decode_lstm_step=n_decode_lstm_step)

                # ease of answering
                dull_loss = []
                for vector in action_feats:
                    # repeat one action batch_size times so each copy is paired with one dull_set row
                    action_batch_X = np.array([vector for _ in range(batch_size)])
                    d_loss = sess.run(loss,
                                 feed_dict={
                                    input_tensors['word_vectors']: action_batch_X,
                                    input_tensors['caption']: dull_matrix,
                                    input_tensors['caption_mask']: dull_mask,
                                    input_tensors['reward']: ones_reward
                                })
                    # negate the loss and divide by the number of dull sentences
                    # (len(dull_set) == batch_size after the truncation / padding above)
                    d_loss = d_loss * -1. / len(dull_set)
                    dull_loss.append(d_loss)

                # Information Flow
                # (not implemented: this reward term is a placeholder)
                pass

                # semantic coherence
                # forward direction: run the policy model on the current input with the generated action as caption
                forward_inter = sess.run(inter_value,
                                 feed_dict={
                                    input_tensors['word_vectors']: current_feats,
                                    input_tensors['caption']: action_caption_matrix,
                                    input_tensors['caption_mask']: action_caption_masks,
                                    input_tensors['reward']: ones_reward
                                })
                forward_entropies = forward_inter['entropies']
                former_caption_matrix, former_caption_masks = make_batch_Y(
                                                                batch_Y=copy.deepcopy(former), 
                                                                wordtoix=wordtoix, 
                                                                n_decode_lstm_step=n_decode_lstm_step)
                # action_feats is rebuilt here with the reversed model's encoder length
                action_feats = make_batch_X(
                                batch_X=copy.deepcopy(actions_list), 
                                n_encode_lstm_step=r_n_encode_lstm_step, 
                                dim_wordvec=dim_wordvec,
                                word_vector=word_vector)
                # backward direction: run the reversed model (sess2 / g2) on the action with the former sentence as caption
                backward_inter = sess2.run(reverse_inter,
                                 feed_dict={
                                    word_vectors: action_feats,
                                    caption: former_caption_matrix,
                                    caption_mask: former_caption_masks
                                })
                backward_entropies = backward_inter['entropies']

                # reward: count goodness of actions
                rewards = count_rewards(dull_loss, forward_entropies, backward_entropies, actions, former, reward_type='pg')
    
                # policy gradient: train batch with rewards
                # every 10th batch additionally fetches the loss so it can be logged
                if batch % 10 == 0:
                    _, loss_val = sess.run(
                            [train_op, loss],
                            feed_dict={
                                input_tensors['word_vectors']: current_feats,
                                input_tensors['caption']: current_caption_matrix,
                                input_tensors['caption_mask']: current_caption_masks,
                                input_tensors['reward']: rewards
                            })
                    print("Epoch: {}, batch: {}, loss: {}, Elapsed time: {}".format(epoch, batch, loss_val, time.time() - start_time))
                else:
                    _ = sess.run(train_op,
                                 feed_dict={
                                    input_tensors['word_vectors']: current_feats,
                                    input_tensors['caption']: current_caption_matrix,
                                    input_tensors['caption_mask']: current_caption_masks,
                                    input_tensors['reward']: rewards
                                })
                # intermediate checkpoints are only written in 'pg' mode
                if batch % 1000 == 0 and batch != 0:
                    print("Epoch {} batch {} is done. Saving the model ...".format(epoch, batch))
                    saver.save(sess, os.path.join(model_path, 'model-{}-{}'.format(epoch, batch)))
            if training_type == 'normal':
                # same training step as above, but every decode step gets reward 1
                if batch % 10 == 0:
                    _, loss_val = sess.run(
                            [train_op, loss],
                            feed_dict={
                                input_tensors['word_vectors']: current_feats,
                                input_tensors['caption']: current_caption_matrix,
                                input_tensors['caption_mask']: current_caption_masks,
                                input_tensors['reward']: ones_reward
                            })
                    print("Epoch: {}, batch: {}, loss: {}, Elapsed time: {}".format(epoch, batch, loss_val, time.time() - start_time))
                else:
                    _ = sess.run(train_op,
                                 feed_dict={
                                    input_tensors['word_vectors']: current_feats,
                                    input_tensors['caption']: current_caption_matrix,
                                    input_tensors['caption_mask']: current_caption_masks,
                                    input_tensors['reward']: ones_reward
                                })

        # end-of-epoch checkpoint; global_step=epoch is passed so each epoch gets its own checkpoint name
        print("Epoch ", epoch, " is done. Saving the model ...")
        saver.save(sess, os.path.join(model_path, 'model'), global_step=epoch)

# Script entry point: start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    train()

In [None]:
./script/download_reversed.sh



In [None]:
./script/test_RL.sh <PATH TO MODEL> <INPUT FILE> <OUTPUT FILE>

#-*- coding: utf-8 -*-

from __future__ import print_function

import re
import os
import time
import sys

sys.path.append("python")
import data_parser
import config

from gensim.models import KeyedVectors
from rl_model import PolicyGradient_chatbot
import tensorflow as tf
import numpy as np

#=====================================================
# Global Parameters
#=====================================================
# Command line layout: <script> [model path] [input file] [output file]
default_model_path = './model/RL/model-56-3000'
testing_data_path = sys.argv[2] if len(sys.argv) > 2 else 'sample_input.txt'
output_path = sys.argv[3] if len(sys.argv) > 3 else 'sample_output_RL.txt'

# vocabulary cut-off handed to data_parser.preProBuildWordVocab
word_count_threshold = config.WC_threshold

#=====================================================
# Model Parameters (passed to the model constructor at inference time)
#=====================================================
dim_wordvec = 300
dim_hidden = 1000

# encoder input = one random-normal timestep followed by up to 22 word timesteps
n_encode_lstm_step = 1 + 22
n_decode_lstm_step = 22

# questions are answered one at a time
batch_size = 1

def refine(data):
    """Extract only the vocabulary part of the data.

    Every maximal run of ASCII letters, apostrophes and hyphens counts as one
    token. Apostrophes are then removed from each token (hyphens are kept) and
    the tokens are joined with single spaces.
    """
    cleaned_tokens = []
    for token in re.findall("[a-zA-Z'-]+", data):
        cleaned_tokens.append(token.replace("'", ""))
    return ' '.join(cleaned_tokens)

def test(model_path=default_model_path):
    ''' Answer every line of the input file with the RL chatbot and save the answers.

        Questions are read (one per line) from `testing_data_path`; each generated
        answer is printed and written as one line to `output_path`.

        args:
            model_path:     <type 'str'> checkpoint to restore; when restoring it fails,
                            `default_model_path` is restored instead
    '''
    # vocabulary id of '<unk>' (the value the original check was hard-coded to)
    unk_index = 3

    # read all questions up front; the context manager closes the file handle
    with open(testing_data_path, 'r') as data_file:
        testing_data = data_file.read().split('\n')

    word_vector = KeyedVectors.load_word2vec_format('model/word_vector.bin', binary=True)

    _, ixtoword, bias_init_vector = data_parser.preProBuildWordVocab(word_count_threshold=word_count_threshold)

    model = PolicyGradient_chatbot(
            dim_wordvec=dim_wordvec,
            n_words=len(ixtoword),
            dim_hidden=dim_hidden,
            batch_size=batch_size,
            n_encode_lstm_step=n_encode_lstm_step,
            n_decode_lstm_step=n_decode_lstm_step,
            bias_init_vector=bias_init_vector)

    word_vectors, caption_tf, feats = model.build_generator()

    sess = tf.InteractiveSession()
    try:
        saver = tf.train.Saver()
        try:
            print('\n=== Use model', model_path, '===\n')
            saver.restore(sess, model_path)
        except Exception:
            # fall back to the default checkpoint, but let KeyboardInterrupt / SystemExit through
            print('\nUse default model\n')
            saver.restore(sess, default_model_path)

        with open(output_path, 'w') as out:
            for question in testing_data:
                print('question =>', question)

                question = [refine(w) for w in question.lower().split()]
                question = [word_vector[w] if w in word_vector else np.zeros(dim_wordvec) for w in question]
                question.insert(0, np.random.normal(size=(dim_wordvec,))) # insert random normal at the first step

                # truncate or zero-pad to exactly n_encode_lstm_step timesteps
                if len(question) > n_encode_lstm_step:
                    question = question[:n_encode_lstm_step]
                else:
                    for _ in range(len(question), n_encode_lstm_step):
                        question.append(np.zeros(dim_wordvec))

                question = np.array([question]) # 1 x n_encode_lstm_step x dim_wordvec

                generated_word_index, prob_logit = sess.run([caption_tf, feats['probs']], feed_dict={word_vectors: question})
                generated_word_index = np.array(generated_word_index).reshape(batch_size, n_decode_lstm_step)[0]
                prob_logit = np.array(prob_logit).reshape(batch_size, n_decode_lstm_step, -1)[0]

                # replace every <unk> by the best-scoring other word. Masking the <unk>
                # score (instead of looking the second-highest value up again) also
                # works when several words share the same score.
                for i in range(len(generated_word_index)):
                    if generated_word_index[i] == unk_index:
                        step_scores = np.array(prob_logit[i], dtype=float)  # private copy
                        step_scores[unk_index] = -np.inf
                        generated_word_index[i] = np.argmax(step_scores)

                generated_words = [ixtoword[ind] for ind in generated_word_index]

                # keep the words in front of the first <eos>; when no <eos> was generated
                # the whole output is kept (argmax over an all-False mask is 0, which
                # used to cut such answers down to their first word)
                if '<eos>' in generated_words:
                    generated_words = generated_words[:generated_words.index('<eos>')]
                generated_words = [word for word in generated_words if word != '<bos>']
                generated_sentence = ' '.join(generated_words)

                # modify the output sentence: removing '--' leaves a double space, which
                # is then used as the boundary between sentence fragments
                generated_sentence = generated_sentence.replace('--', '')
                fragments = generated_sentence.split('  ')
                for i in range(len(fragments)):
                    fragments[i] = fragments[i].strip()
                    if len(fragments[i]) > 1:
                        fragments[i] = fragments[i][0].upper() + fragments[i][1:] + '.'
                    else:
                        fragments[i] = fragments[i].upper()
                generated_sentence = ' '.join(fragments)
                generated_sentence = generated_sentence.replace(' i ', ' I ')
                generated_sentence = generated_sentence.replace("i'm", "I'm")
                generated_sentence = generated_sentence.replace("i'd", "I'd")
                generated_sentence = generated_sentence.replace("i'll", "I'll")
                generated_sentence = generated_sentence.replace("i'v", "I'v")
                generated_sentence = generated_sentence.replace(" - ", "")

                print('generated_sentence =>', generated_sentence)
                out.write(generated_sentence + '\n')
    finally:
        # release the session's resources even when inference fails half-way
        sess.close()


if __name__ == "__main__":
    # the optional first command line argument selects the checkpoint;
    # without it test() falls back to its own default
    if len(sys.argv) <= 1:
        test()
    else:
        test(model_path=sys.argv[1])

In [None]:
./script/simulate.sh <PATH TO MODEL> <SIMULATE TYPE> <INPUT FILE> <OUTPUT FILE>

#-*- coding: utf-8 -*-

from __future__ import print_function

from gensim.models import KeyedVectors
import data_parser
import config

from model import Seq2Seq_chatbot
import tensorflow as tf
import numpy as np

import re
import os
import sys
import time


#=====================================================
# Global Parameters
#=====================================================
# Command line layout: <script> [model path] [simulate type] [input file] [output file]
default_model_path = './model/model-20'
default_simulate_type = 1  # type 1 use one former sent, type 2 use two former sents

testing_data_path = sys.argv[3] if len(sys.argv) > 3 else 'sample_input.txt'
output_path = sys.argv[4] if len(sys.argv) > 4 else 'sample_dialog_output.txt'

# number of A/B exchanges simulated per start sentence
max_turns = config.MAX_TURNS
# vocabulary cut-off handed to data_parser.preProBuildWordVocab
word_count_threshold = config.WC_threshold

#=====================================================
# Model Parameters (passed to the model constructor at inference time)
#=====================================================
dim_wordvec = 300
dim_hidden = 1000

# per-sentence length only: the entry point below rescales this to
# sentence length * simulate_type + 1 (one random normal as the first timestep)
n_encode_lstm_step = 22
n_decode_lstm_step = 22

# one dialog state is fed to the model at a time
batch_size = 1

def refine(data):
    """Extract only the vocabulary part of the data.

    Tokens are maximal runs of ASCII letters, apostrophes and hyphens; the
    apostrophes are dropped from every token (hyphens stay) before the tokens
    are joined back together with single spaces.
    """
    return ' '.join(
        token.replace("'", "")
        for token in re.findall("[a-zA-Z'-]+", data)
    )

def generate_question_vector(state, word_vector, dim_wordvec, n_encode_lstm_step):
    ''' Encode a dialog state (plain text) as the encoder input of the model.

        args:
            state:              <type 'str'> text to encode; lower-cased and split on whitespace
            word_vector:        word -> vector lookup; words it does not contain become zero vectors
            dim_wordvec:        <type 'int'> length of one word vector
            n_encode_lstm_step: <type 'int'> number of timesteps of the result

        returns:
            array of shape 1 x n_encode_lstm_step x dim_wordvec whose first timestep is
            random normal noise; the rest holds the word vectors, truncated or
            zero-padded to fit
    '''
    embedded = []
    for raw_word in state.lower().split():
        word = refine(raw_word)
        if word in word_vector:
            embedded.append(word_vector[word])
        else:
            embedded.append(np.zeros(dim_wordvec))

    # the noise vector always occupies the first timestep
    timesteps = [np.random.normal(size=(dim_wordvec,))] + embedded

    missing = n_encode_lstm_step - len(timesteps)
    if missing < 0:
        timesteps = timesteps[:n_encode_lstm_step]
    else:
        timesteps.extend(np.zeros(dim_wordvec) for _ in range(missing))

    return np.array([timesteps]) # 1 x n_encode_lstm_step x dim_wordvec

def generate_answer_sentence(generated_word_index, prob_logit, ixtoword, unk_index=3):
    ''' Turn the decoder output into a readable, capitalised answer sentence.

        args:
            generated_word_index:   mutable sequence of word ids, one per decode step;
                                    <unk> ids are overwritten IN PLACE with their replacement
            prob_logit:             per-step word scores; prob_logit[i][0] is the score
                                    vector of decode step i
            ixtoword:               mapping from word id to word
            unk_index:              <type 'int'> word id of '<unk>' (default 3, the value
                                    this check used to be hard-coded to)

        returns:
            <type 'str'> the words in front of the first '<eos>' (all words when no
            '<eos>' was generated), without '<bos>', split into fragments at '--',
            each fragment capitalised and ended with '.'
    '''
    # replace every <unk> by the best-scoring other word. Masking the <unk> score
    # (instead of looking the second-highest value up again) also works when
    # several words share the same score.
    for i in range(len(generated_word_index)):
        if generated_word_index[i] == unk_index:
            step_scores = np.array(prob_logit[i][0], dtype=float)  # private copy
            step_scores[unk_index] = -np.inf
            generated_word_index[i] = np.argmax(step_scores)

    generated_words = [ixtoword[ind] for ind in generated_word_index]

    # keep the words in front of the first <eos>; when no <eos> was generated the
    # whole output is kept (argmax over an all-False mask is 0, which used to cut
    # such answers down to their first word)
    if '<eos>' in generated_words:
        generated_words = generated_words[:generated_words.index('<eos>')]
    generated_words = [word for word in generated_words if word != '<bos>']
    generated_sentence = ' '.join(generated_words)

    # modify the output sentence: removing '--' leaves a double space, which is
    # then used as the boundary between sentence fragments
    generated_sentence = generated_sentence.replace('--', '')
    fragments = generated_sentence.split('  ')
    for i in range(len(fragments)):
        fragments[i] = fragments[i].strip()
        if len(fragments[i]) > 1:
            fragments[i] = fragments[i][0].upper() + fragments[i][1:] + '.'
        else:
            fragments[i] = fragments[i].upper()
    generated_sentence = ' '.join(fragments)
    generated_sentence = generated_sentence.replace(' i ', ' I ')
    generated_sentence = generated_sentence.replace("i'm", "I'm")
    generated_sentence = generated_sentence.replace("i'd", "I'd")

    return generated_sentence

def init_history(simulate_type, start_sentence):
    ''' Create the initial dialog history.

        The start sentence is preceded by (simulate_type - 1) empty sentences, so
        the last `simulate_type` entries can be taken from the very first turn on.
    '''
    padding = [''] * (simulate_type - 1)
    return padding + [start_sentence]

def get_cur_state(simulate_type, dialog_history):
    ''' Return the current state: the last `simulate_type` sentences of the
        history joined by single spaces, with surrounding whitespace stripped.
    '''
    recent_sentences = dialog_history[-simulate_type:]
    joined = ' '.join(recent_sentences)
    return joined.strip()

def simulate(model_path=default_model_path, simulate_type=default_simulate_type):
    ''' Let the chatbot talk to itself, starting from every line of the input file.

        For each start sentence, `max_turns` exchanges are generated: speaker B answers
        the current state, then speaker A answers the updated state. The dialog is
        printed and written to `output_path`.

        The encoder length is read from the module-level `n_encode_lstm_step`, which
        the script entry point rescales for the chosen simulate_type before calling
        this function.

        args:
            model_path:     <type 'str'> the pre-trained model using for inference;
                            when restoring it fails, `default_model_path` is restored instead
            simulate_type:  <type 'int'> how many former sents should use as state
    '''

    # read all start sentences up front; the context manager closes the file handle
    with open(testing_data_path, 'r') as data_file:
        testing_data = data_file.read().split('\n')

    word_vector = KeyedVectors.load_word2vec_format('model/word_vector.bin', binary=True)

    _, ixtoword, bias_init_vector = data_parser.preProBuildWordVocab(word_count_threshold=word_count_threshold)

    model = Seq2Seq_chatbot(
            dim_wordvec=dim_wordvec,
            n_words=len(ixtoword),
            dim_hidden=dim_hidden,
            batch_size=batch_size,
            n_encode_lstm_step=n_encode_lstm_step,
            n_decode_lstm_step=n_decode_lstm_step,
            bias_init_vector=bias_init_vector)

    word_vectors, caption_tf, probs, _ = model.build_generator()

    sess = tf.InteractiveSession()

    def _next_utterance(dialog_history):
        # generate the reply to the current state and append it to the history
        question = generate_question_vector(state=get_cur_state(simulate_type, dialog_history),
                                            word_vector=word_vector,
                                            dim_wordvec=dim_wordvec,
                                            n_encode_lstm_step=n_encode_lstm_step)

        generated_word_index, prob_logit = sess.run([caption_tf, probs], feed_dict={word_vectors: question})

        sentence = generate_answer_sentence(generated_word_index=generated_word_index,
                                            prob_logit=prob_logit,
                                            ixtoword=ixtoword)
        dialog_history.append(sentence)
        return sentence

    try:
        saver = tf.train.Saver()
        try:
            print('\n=== Use model {} ===\n'.format(model_path))
            saver.restore(sess, model_path)
        except Exception:
            # fall back to the default checkpoint, but let KeyboardInterrupt / SystemExit through
            print('\nUse default model\n')
            saver.restore(sess, default_model_path)

        with open(output_path, 'w') as out:
            for idx, start_sentence in enumerate(testing_data):
                print('dialog {}'.format(idx))
                print('A => {}'.format(start_sentence))
                out.write('dialog {}\nA: {}\n'.format(idx, start_sentence))

                dialog_history = init_history(simulate_type, start_sentence)

                for _ in range(max_turns):
                    # B answers first, then A answers the state that now includes B's reply
                    generated_sentence = _next_utterance(dialog_history)
                    print('B => {}'.format(generated_sentence))

                    generated_sentence_2 = _next_utterance(dialog_history)
                    print('A => {}'.format(generated_sentence_2))
                    out.write('B: {}\nA: {}\n'.format(generated_sentence, generated_sentence_2))
    finally:
        # release the session's resources even when the simulation fails half-way
        sess.close()


if __name__ == "__main__":
    # optional command line arguments: [model path] [simulate type]
    model_path = sys.argv[1] if len(sys.argv) > 1 else default_model_path
    simulate_type = int(sys.argv[2]) if len(sys.argv) > 2 else default_simulate_type
    # rebind the module-level encoder length that simulate() reads:
    # sent len * sent num + one random normal
    n_encode_lstm_step = 1 + simulate_type * n_encode_lstm_step
    print('simulate_type', simulate_type)
    print('n_encode_lstm_step', n_encode_lstm_step)
    simulate(model_path=model_path, simulate_type=simulate_type)