In [1]:
# Simple character-level prediction using RNN
# 2017-03-30 jkang
# 
# 'hello_world_good_morning_see_you_hello_great'
#
# input:  'ello_world_good_morning_see_you_hello_great'
# output: 'hello_world_good_morning_see_you_hello_grea'
# 
# Python3.5
# Tensorflow1.0.1
# ref: https://hunkim.github.io/ml/

import tensorflow as tf
import numpy as np

In [2]:
# Make input data
char_raw = 'hello_world_good_morning_see_you_hello_great'
char_list = list(set(char_raw))
char_idx = {c: i for i, c in enumerate(char_list)} # character with index
char_data = [char_idx[c] for c in char_raw]
char_data_onehot = tf.one_hot(char_data, 
                              depth=len(char_list), 
                              on_value=1., 
                              off_value=0.,
                              axis=1, 
                              dtype=tf.float32)
char_input = char_data_onehot[:-1] # 'ello_world_good_morning_see_you_hello_great'
char_output = char_data_onehot[1:] # 'hello_world_good_morning_see_you_hello_grea'
print('char_data:', char_data)
print('char_data_onehot:', char_data_onehot.shape)
print('char_input:', char_input.shape)
print('char_output:', char_output.shape)

char_data: [8, 12, 3, 3, 10, 4, 0, 10, 13, 3, 15, 4, 9, 10, 10, 15, 4, 5, 10, 13, 2, 11, 2, 9, 4, 6, 12, 12, 4, 1, 10, 16, 4, 8, 12, 3, 3, 10, 4, 9, 13, 12, 7, 14]
char_data_onehot: (44, 17)
char_input: (43, 17)
char_output: (43, 17)


In [3]:
# Set configurations
n_char = len(char_list)
rnn_size = n_char # number of one-hot coding vectors == output size for each cell
n_timestep = char_input.shape.as_list()[0] # length of the input
batch_size = 1  # one example
max_iter = 300

# Set RNN
rnn_cell = tf.contrib.rnn.BasicRNNCell(rnn_size)
init_state = tf.zeros([batch_size, rnn_cell.state_size])
input_split = tf.split(value=char_input, num_or_size_splits=n_timestep, axis=0)
outputs, state = tf.contrib.rnn.static_rnn(rnn_cell, input_split, init_state)

In [4]:
# logits: A 3D Tensor of shape [batch_size x sequence_length x num_decoder_symbols] and dtype float. 
# targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype int.
# weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype float.
# logits = tf.reshape(tf.concat(values=char_output, axis=1), [batch_size, n_timestep, rnn_size])
logits = tf.reshape(outputs, [batch_size, n_timestep, rnn_size])
targets = tf.reshape(char_data[1:], [batch_size, n_timestep]) # target as index
weights = tf.ones((batch_size, n_timestep))

loss = tf.contrib.seq2seq.sequence_loss(logits, targets, weights)
cost = tf.reduce_sum(loss) / batch_size
train_op = tf.train.RMSPropOptimizer(learning_rate = 0.01, decay = 0.9).minimize(cost)

In [5]:
# Launch the graph in a session
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(max_iter):
        _, c = sess.run([train_op, cost])
        result = sess.run(tf.arg_max(logits, 2))
        print('Epoch: {:>4}'.format(i + 1), '/', str(max_iter),
              'Cost: {:4f}'.format(c), 'Predict:', ''.join([char_list[t] for t in result[0]]))

Epoch:    1 / 300 Cost: 2.939752 Predict: iysnellehlywleellhehwrwrweygdeewlmynnellh_h
Epoch:    2 / 300 Cost: 2.937397 Predict: iysnellehlywleellhehwrwrweygdeewlmynnellh_h
Epoch:    3 / 300 Cost: 2.934924 Predict: iysnellehlywleelleehwrwrweygdeewlmynnellh_h
Epoch:    4 / 300 Cost: 2.932330 Predict: iysnellehlywleelleehwrwrweygdeewleynnellh_h
Epoch:    5 / 300 Cost: 2.929610 Predict: iysnellehlylleelleehwrwrweygdeewleynnellh_h
Epoch:    6 / 300 Cost: 2.926758 Predict: iysnellehlylleelleehwrwrweygdeewleynnellh_h
Epoch:    7 / 300 Cost: 2.923770 Predict: iysnellehlylleelleehwrwrweygdeewleynnellh_h
Epoch:    8 / 300 Cost: 2.920639 Predict: iysnellehlylleelleehwrwrweygdeewleynnellh_h
Epoch:    9 / 300 Cost: 2.917361 Predict: iysnellehlylleelleehwrerleygdeewleynnellh_h
Epoch:   10 / 300 Cost: 2.913930 Predict: iysnellehlylleelleehwrerleygdeewleynnellh_h
Epoch:   11 / 300 Cost: 2.910341 Predict: iysnellehlylleelleehwrerleygdeewleydnellh_h
Epoch:   12 / 300 Cost: 2.906589 Predict: iysnellehlyl