Fixes the repeating token bug
Turns out the decoder output (target y) should be offset by one from the decoder input, i.e. it should exclude the leading <GO> token.
See tensorflow/nmt#3
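
To make the offset concrete, here is a minimal sketch (plain Python with made-up token IDs and a hypothetical helper name; not the exact repo code). The decoder input keeps the leading <GO>, while the training target is the same sequence shifted left by one with <EOS> appended:

GO, EOS = 1, 2  # illustrative IDs only

def make_decoder_pair(token_ids):
    # Build (decoder input, decoder target) for teacher forcing.
    decoder_input = [GO] + token_ids    # fed to the decoder during training
    decoder_target = token_ids + [EOS]  # compared against by the loss, offset by one
    return decoder_input, decoder_target

# e.g. a response tokenized as [7, 9]:
#   decoder input  -> [<GO>, 7, 9]
#   decoder target -> [7, 9, <EOS>]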
gundamMC committed Sep 8, 2018
1 parent ac27898 commit 6011c2f
Showing 2 changed files with 147 additions and 77 deletions.
207 changes: 134 additions & 73 deletions ProjectWaifu/Chatbot/ChatbotNetwork.py
@@ -4,11 +4,15 @@
import ProjectWaifu.Chatbot.ParseData as ParseData
import ProjectWaifu.WordEmbedding as WordEmbedding
from ProjectWaifu.Utils import get_mini_batches
from hyperdash import Experiment


exp = Experiment("Chatbot")


class ChatbotNetwork(Network):

def __init__(self, learning_rate=0.001, batch_size=16, restore=False):
def __init__(self, learning_rate=0.001, batch_size=8, restore=False):
# hyperparameters
self.learning_rate = learning_rate
self.batch_size = batch_size
@@ -26,33 +30,74 @@ def __init__(self, learning_rate=0.001, batch_size=16, restore=False):
self.y_length = tf.placeholder(tf.int32, [None])
self.word_embedding = tf.Variable(tf.constant(0.0, shape=(self.word_count, self.n_vector)), trainable=False)

self.y_target = tf.placeholder(tf.int32, [None, self.max_sequence])
# decoder target, i.e. y without the leading <GO>

# Network parameters
def get_gru_cell():
return tf.contrib.rnn.GRUCell(self.n_hidden)

self.cell_encode = tf.contrib.rnn.MultiRNNCell([get_gru_cell() for _ in range(3)])
self.cell_decode = tf.contrib.rnn.MultiRNNCell([get_gru_cell() for _ in range(3)])
self.cell_encode = tf.contrib.rnn.MultiRNNCell([get_gru_cell() for _ in range(2)])
self.cell_decode = tf.contrib.rnn.MultiRNNCell([get_gru_cell() for _ in range(2)])
self.projection_layer = tf.layers.Dense(self.word_count)

# Optimization
dynamic_max_sequence = tf.reduce_max(self.y_length)
mask = tf.sequence_mask(self.y_length, maxlen=dynamic_max_sequence, dtype=tf.float32)
# crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
# labels=self.y[:, :dynamic_max_sequence], logits=self.network())
# self.cost = tf.reduce_sum(crossent * mask)
self.cost = tf.contrib.seq2seq.sequence_loss(self.network(), self.y[:, :dynamic_max_sequence], weights=mask)

# pred_network = self.network()
#
# pred_train = tf.cond(tf.less(tf.shape(pred_network)[-1], self.max_sequence),
# lambda: tf.concat([tf.squeeze(pred_network[:, 0]),
# tf.zeros(
# [tf.shape(pred_network)[0],
# self.max_sequence - tf.shape(pred_network)[-1]],
# tf.float32)], 1),
# lambda: pred_network[:, 20]
# )

crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=self.y_target[:, :dynamic_max_sequence], logits=self.network())
self.cost = tf.reduce_sum(crossent * mask) / tf.cast(tf.shape(self.y)[0], tf.float32)
# self.cost = tf.contrib.seq2seq.sequence_loss(self.network(), self.y, weights=mask)
self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(self.cost)
self.infer = self.network(mode="infer")

pred_infer = tf.cond(tf.less(tf.shape(self.infer)[2], self.max_sequence),
lambda: tf.concat([tf.squeeze(self.infer[:, 0]),
tf.zeros(
[tf.shape(self.infer)[0],
self.max_sequence - tf.shape(self.infer)[-1]],
tf.int32)], 1),
lambda: tf.squeeze(self.infer[:, 0, :20])
)

correct_pred = tf.equal(
pred_infer,
self.y_target)
self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Tensorboard
tf.summary.scalar('cost', self.cost)
tf.summary.scalar('accuracy', self.accuracy)
self.merged = tf.summary.merge_all()

# Tensorflow initialization
self.saver = tf.train.Saver()
self.sess = tf.Session()

if restore:
self.tensorboard_writer = tf.summary.FileWriter('/tmp')
else:
self.tensorboard_writer = tf.summary.FileWriter('/tmp', self.sess.graph)

self.sess.run(tf.global_variables_initializer())

if restore is False:
embedding_placeholder = tf.placeholder(tf.float32, shape=WordEmbedding.embeddings.shape)
self.sess.run(self.word_embedding.assign(embedding_placeholder),
feed_dict={embedding_placeholder: WordEmbedding.embeddings})

else:
self.saver.restore(self.sess, tf.train.latest_checkpoint('./model'))

@@ -93,61 +138,61 @@ def network(self, mode="train"):
output_layer=self.projection_layer
)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=self.max_sequence, impute_finished=True)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=self.max_sequence)

return outputs.rnn_output
else:

with tf.variable_scope('decode', reuse=True):
with tf.variable_scope('decode', reuse=tf.AUTO_REUSE):

# Greedy search
infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.word_embedding, tf.tile(tf.constant([WordEmbedding.start], dtype=tf.int32), [tf.shape(self.x)[0]]), WordEmbedding.end)

decoder = tf.contrib.seq2seq.BasicDecoder(
attn_decoder_cell,
infer_helper,
decoder_initial_state,
output_layer=self.projection_layer
)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=self.max_sequence,
impute_finished=True)

return outputs.sample_id

# Beam search
# beam_width = 3
# encoder_outputs_beam = tf.contrib.seq2seq.tile_batch(encoder_outputs, beam_width)
# encoder_state_beam = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
# batch_size_beam = tf.shape(encoder_outputs_beam)[0]
#
# attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
# num_units=self.n_hidden, memory=encoder_outputs_beam)
# infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.word_embedding, tf.tile(tf.constant([WordEmbedding.start], dtype=tf.int32), [tf.shape(self.x)[0]]), WordEmbedding.end)
#
# attn_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
# self.cell_decode, attention_mechanism, attention_layer_size=self.n_hidden)
#
# decoder_initial_state = attn_decoder_cell.zero_state(dtype=tf.float32, batch_size=batch_size_beam)
#
# decoder = tf.contrib.seq2seq.BeamSearchDecoder(
# cell=attn_decoder_cell,
# embedding=self.word_embedding,
# start_tokens=tf.tile(tf.constant([WordEmbedding.start], dtype=tf.int32), [tf.shape(self.x)[0]]),
# end_token=WordEmbedding.end,
# initial_state=decoder_initial_state,
# beam_width=beam_width,
# output_layer=self.projection_layer
# decoder = tf.contrib.seq2seq.BasicDecoder(
# attn_decoder_cell,
# infer_helper,
# decoder_initial_state,
# output_layer=self.projection_layer
# )
#
# outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=self.max_sequence)
# outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=self.max_sequence,
# impute_finished=True)
#
# return tf.transpose(outputs.predicted_ids, perm=[0, 2, 1]) # [batch size, beam width, sequence length]
# return outputs.sample_id

# Beam search
beam_width = 3
encoder_outputs_beam = tf.contrib.seq2seq.tile_batch(encoder_outputs, beam_width)
encoder_state_beam = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
batch_size_beam = tf.shape(encoder_outputs_beam)[0]

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
num_units=self.n_hidden, memory=encoder_outputs_beam)

attn_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
self.cell_decode, attention_mechanism, attention_layer_size=self.n_hidden)

decoder_initial_state = attn_decoder_cell.zero_state(dtype=tf.float32, batch_size=batch_size_beam)

decoder = tf.contrib.seq2seq.BeamSearchDecoder(
cell=attn_decoder_cell,
embedding=self.word_embedding,
start_tokens=tf.tile(tf.constant([WordEmbedding.start], dtype=tf.int32), [tf.shape(self.x)[0]]),
end_token=WordEmbedding.end,
initial_state=decoder_initial_state,
beam_width=beam_width,
output_layer=self.projection_layer
)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=self.max_sequence)

return tf.transpose(outputs.predicted_ids, perm=[0, 2, 1]) # [batch size, beam width, sequence length]

def setTrainingData(self, train_x, train_y):
train_x = ParseData.split_data(train_x)
train_y = ParseData.split_data(train_y)

train_x, train_y, x_length, y_length = \
train_x, train_y, x_length, y_length, y_target = \
ParseData.data_to_index(train_x, train_y,
WordEmbedding.words_to_index)

@@ -157,11 +202,12 @@ def setTrainingData(self, train_x, train_y):
self.train_y = np.array(train_y)
self.train_x_length = np.array(x_length)
self.train_y_length = np.array(y_length)
self.train_y_target = np.array(y_target)

def train(self, epochs=800, display_step=10):
def train(self, epochs=800, display_step=10, epoch_offset=0):
for epoch in range(epochs):
mini_batches_x, mini_batches_x_length, mini_batches_y, mini_batches_y_length \
= get_mini_batches([self.train_x, self.train_x_length, self.train_y, self.train_y_length], self.batch_size)
mini_batches_x, mini_batches_x_length, mini_batches_y, mini_batches_y_length, mini_batches_y_target \
= get_mini_batches([self.train_x, self.train_x_length, self.train_y, self.train_y_length, self.train_y_target], self.batch_size)

# mini_batches_x = [self.train_x]
# mini_batches_x_length = [self.train_x_length]
@@ -173,25 +219,39 @@ def train(self, epochs=800, display_step=10):
batch_x_length = mini_batches_x_length[batch]
batch_y = mini_batches_y[batch]
batch_y_length = mini_batches_y_length[batch]
batch_y_target = mini_batches_y_target[batch]

if epoch % display_step == 0 or display_step == 0:
if (epoch % display_step == 0 or display_step == 0) and (batch % 100 == 0 or batch == 0):
_, cost_value = self.sess.run([self.train_op, self.cost], feed_dict={
self.x: batch_x,
self.x_length: batch_x_length,
self.y: batch_y,
self.y_length: batch_y_length
self.y_length: batch_y_length,
self.y_target: batch_y_target
})

print("epoch:", epoch, "- (", batch, "/", len(mini_batches_x), ") -", cost_value)
print("epoch:", epoch_offset + epoch, "- (", batch, "/", len(mini_batches_x), ") -", cost_value)
exp.metric("cost", cost_value)

else:
self.sess.run(self.train_op, feed_dict={
self.sess.run([self.train_op], feed_dict={
self.x: batch_x,
self.x_length: batch_x_length,
self.y: batch_y,
self.y_length: batch_y_length
self.y_length: batch_y_length,
self.y_target: batch_y_target
})

summary = self.sess.run(self.merged, feed_dict={
self.x: self.train_x,
self.x_length: self.train_x_length,
self.y: self.train_y,
self.y_length: self.train_y_length,
self.y_target: self.train_y_target
})

self.tensorboard_writer.add_summary(summary, epoch_offset + epoch)

def predict(self, sentence):

input_x, x_length, _ = ParseData.sentence_to_index(ParseData.split_sentence(sentence.lower()),
@@ -203,19 +263,19 @@ def predict(self, sentence):
self.x_length: np.array([x_length])
})

result = ""
for i in range(len(test_output)):
result = result + WordEmbedding.words[int(test_output[i])] + "(" + str(test_output[i]) + ") "
return result

# list_res = []
# for index in range(len(test_output)):
# result = ""
# for i in range(len(test_output[index])):
# result = result + WordEmbedding.words[int(test_output[index][i])] + " "
# list_res.append(result)
#
# return list_res
# result = ""
# for i in range(len(test_output)):
# result = result + WordEmbedding.words[int(test_output[i])] + "(" + str(test_output[i]) + ") "
# return result

list_res = []
for index in range(len(test_output)):
result = ""
for i in range(len(test_output[index])):
result = result + WordEmbedding.words[int(test_output[index][i])] + " "
list_res.append(result)

return list_res

def predictAll(self, path, save_path=None):
pass
@@ -229,20 +289,21 @@ def save(self, step=None, meta=True):

question, response = ParseData.load_twitter("./Data/chat.txt")

WordEmbedding.create_embedding(".\\Data\\glove.twitter.27B.100d.txt")
WordEmbedding.create_embedding("./Data/glove.twitter.27B.100d.txt")

test = ChatbotNetwork(learning_rate=0.0001, restore=True)
test = ChatbotNetwork(learning_rate=0.00001, batch_size=4, restore=True)

test.setTrainingData(question, response)
test.setTrainingData(question[0:500], response[0:500])

# clear reference (for gc)
question = None
response = None

step = 1
step = 55

while True:

test.train(1, 1)
test.train(5, 1, 5 * step)

if step > 0:
test.save(step, False)
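
A brief note on the new epoch_offset argument used in the __main__ loop above: it only shifts the epoch number that gets printed and logged to Tensorboard, so summaries from successive train() calls keep increasing instead of overwriting each other. A hypothetical resume, for example (variable names made up, not repo code):

net = ChatbotNetwork(learning_rate=0.00001, batch_size=4, restore=True)
# ... call net.setTrainingData(...) as in __main__ ...
runs_done = 55                                                     # completed 5-epoch runs so far
net.train(epochs=5, display_step=1, epoch_offset=5 * runs_done)    # logs epochs 275..279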
17 changes: 13 additions & 4 deletions ProjectWaifu/Chatbot/ParseData.py
@@ -103,9 +103,13 @@ def split_data(data):
return result


def sentence_to_index(sentence, word_to_index):
result = [word_to_index["<GO>"]]
length = 1
def sentence_to_index(sentence, word_to_index, target=False):
if not target:
result = [word_to_index["<GO>"]]
length = 1
else:
result = []
length = 0
unk = 0
for word in sentence:
length += 1
@@ -136,6 +140,7 @@ def data_to_index(data_x, data_y, word_to_index):
result_y = []
lengths_x = []
lengths_y = []
result_y_target = []
index = 0

while index < len(data_x):
@@ -152,4 +157,8 @@
lengths_x.append(x_length)
lengths_y.append(y_length)

return result_x, result_y, lengths_x, lengths_y
y_target = y[1:]
y_target.append(word_to_index["<EOS>"])
result_y_target.append(y_target)

return result_x, result_y, lengths_x, lengths_y, result_y_target
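
For context, here is a rough, self-contained sketch of how the shifted target is meant to feed the masked loss in ChatbotNetwork (TF 1.x style; the shapes, lengths, and token IDs below are invented for illustration and this is not the exact repo graph):

import numpy as np
import tensorflow as tf

batch, max_sequence, vocab = 2, 5, 10
logits = tf.constant(np.random.randn(batch, max_sequence, vocab), tf.float32)  # decoder outputs
y_target = tf.constant([[3, 4, 2, 0, 0], [5, 2, 0, 0, 0]], tf.int32)           # shifted targets, <EOS>=2
y_length = tf.constant([3, 2], tf.int32)                                       # true lengths, excluding padding

max_t = tf.reduce_max(y_length)
mask = tf.sequence_mask(y_length, maxlen=max_t, dtype=tf.float32)              # zero out padded steps
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=y_target[:, :max_t], logits=logits[:, :max_t])
loss = tf.reduce_sum(crossent * mask) / tf.cast(batch, tf.float32)             # average over the batch

with tf.Session() as sess:
    print(sess.run(loss))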
