In [1]:
import os
import pickle
import numpy as np
import tensorflow as tf
from datetime import datetime

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

  from ._conv import register_converters as _register_converters


### Global helper functions

In [524]:
def create_path(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    Parameters
    ----------
    path : str
        Directory to create.

    Raises
    ------
    OSError
        Any failure from ``os.makedirs`` other than "already exists"
        (errno EEXIST), e.g. a permission error.
    """
    # ``errno`` is not among the notebook's top-level imports; without this
    # local import the except-branch below raised NameError whenever the
    # directory already existed.
    import errno

    try:
        os.makedirs(path)
    except OSError as exception:
        # An already-existing path is tolerated; anything else propagates.
        if exception.errno != errno.EEXIST:
            raise


def txt_to_string(text_file_path):
    """Read a whole text file and return its contents as a single string.

    Parameters
    ----------
    text_file_path : str
        Path of the text file. It is opened in text mode with the
        platform's default encoding, exactly as before.

    Returns
    -------
    str
        The complete file contents, line endings included ('' for an
        empty file).
    """
    # ``with`` guarantees the handle is closed even if reading fails, and a
    # single ``read()`` avoids the quadratic cost of growing a string one
    # line at a time.
    with open(text_file_path, 'r') as f:
        return f.read()


def create_rnn_dataset(txt_array, max_chars):
    """Cut an encoded character sequence into fixed-length training rows.

    Each row of the input matrix holds ``max_chars`` consecutive symbols.
    The matching row of the label matrix holds the same window shifted
    forward by one position, i.e. the next-character targets. Trailing
    symbols that cannot fill a complete row are dropped.

    Parameters
    ----------
    txt_array : np.ndarray
        1-D array of encoded characters.
    max_chars : int
        Number of characters per row (sequence length).

    Returns
    -------
    (np.ndarray, np.ndarray)
        Input matrix and label matrix, both of shape ``(nrows, max_chars)``.
        Rows are sentence observations; columns are character positions.
    """
    txt_array_label = txt_array[1:]
    # Size the matrices from the label sequence, which is one element
    # shorter than the input. Sizing from the input length made the label
    # reshape fail whenever len(txt_array) was an exact multiple of max_chars.
    nrows = txt_array_label.shape[0] // max_chars
    txt_array = txt_array[:nrows * max_chars]
    txt_array_label = txt_array_label[:nrows * max_chars]

    txt_array = txt_array.reshape([-1, max_chars])
    txt_array_label = txt_array_label.reshape([-1, max_chars])

    return txt_array, txt_array_label


def train_val_split(train_mtx, label_mtx, train_proportion=0.8,
                    random_state=666):
    """Randomly split row-aligned matrices into training and validation parts.

    Parameters
    ----------
    train_mtx, label_mtx : np.ndarray
        Feature and label matrices. Row ``i`` of one pairs with row ``i`` of
        the other, so both are indexed with the same row selection.
    train_proportion : float
        Fraction of rows assigned to the training split, rounded to the
        nearest whole row.
    random_state : int
        Seed for the row sampling, so the split is reproducible.

    Returns
    -------
    tuple
        ``(train_X, val_X, train_y, val_y)``. Training rows are in sampled
        order; validation rows are in ascending row order.
    """
    num_rows = train_mtx.shape[0]
    num_train_rows = int(np.round(num_rows * train_proportion))
    # A private generator draws the same sample that seeding the global one
    # would, without resetting numpy's global RNG as a side effect.
    rng = np.random.RandomState(random_state)
    rows_selected = rng.choice(num_rows, num_train_rows, replace=False)
    # The complement is built with a boolean mask: vectorised, and with a
    # well-defined ascending order, unlike iterating a Python set.
    is_val = np.ones(num_rows, dtype=bool)
    is_val[rows_selected] = False
    rows_not_selected = np.flatnonzero(is_val)

    return (train_mtx[rows_selected], train_mtx[rows_not_selected],
            label_mtx[rows_selected], label_mtx[rows_not_selected])


class RNNDataset:
    """Container holding private copies of a feature matrix and its labels."""

    def __init__(self, X, y):
        # Copies decouple the dataset from later in-place edits to the
        # caller's arrays.
        self.X, self.y = X.copy(), y.copy()
        

class BatchManager():
    """Serve fixed-size mini-batches from a dataset for a number of epochs.

    Row indices of successive epochs are queued end-to-end, so a single
    batch may straddle an epoch boundary.
    """

    def __init__(self, train_set, num_epochs, shuffle=True,
                 random_state=666):
        """
        Parameters
        ----------
        train_set : RNNDataset
            Object exposing row-aligned ``X`` and ``y`` arrays.
        num_epochs : int
            Number of passes over ``train_set`` before ``next_batch``
            starts returning None.
        shuffle : bool
            Whether to permute the row order of every epoch.
        random_state : int
            Seed of the shuffling generator; makes the batch sequence
            reproducible.
        """
        self.train_set = train_set
        self.num_epochs = num_epochs
        self.shuffle = shuffle
        self.random_state = random_state
        # One generator, seeded once. Each epoch gets a different permutation
        # while the whole run stays reproducible. Re-seeding before every
        # epoch yields the identical order each time (and resets numpy's
        # global RNG).
        self._rng = np.random.RandomState(random_state)
        self.current_epoch = 0
        # Queue of row indices not yet handed out.
        self.rows_in_batch = []

    def next_batch(self, batch_size):
        """Return the next ``(X, y)`` batch, or None once ``num_epochs`` is exceeded.

        Parameters
        ----------
        batch_size : int
            Number of rows per batch; must be at least 1.

        Returns
        -------
        tuple of np.ndarray or None
            ``(X, y)``, each with ``batch_size`` rows. Returns None when
            filling the batch required starting epoch ``num_epochs + 1``; the
            rows left over from the final epoch are then discarded.

        Raises
        ------
        ValueError
            If ``batch_size`` < 1 or the dataset has no rows. Either case
            would otherwise loop forever.
        """
        num_rows = self.train_set.X.shape[0]
        if batch_size < 1:
            raise ValueError('batch_size must be positive, got {}'.format(batch_size))
        if num_rows == 0:
            raise ValueError('cannot draw batches from an empty dataset')

        # Top up the index queue one epoch at a time until a full batch fits.
        while len(self.rows_in_batch) < batch_size:
            self.current_epoch += 1
            row_nums = list(range(num_rows))
            if self.shuffle:
                self._rng.shuffle(row_nums)
            self.rows_in_batch += row_nums

        batch_rows = self.rows_in_batch[:batch_size]
        selected_X = self.train_set.X[batch_rows]
        selected_y = self.train_set.y[batch_rows]
        self.rows_in_batch = self.rows_in_batch[batch_size:]

        if self.current_epoch > self.num_epochs:
            return None
        return selected_X, selected_y

### Data preprocessing

In [472]:
# Sequence length: the encoded corpus is cut into non-overlapping windows of
# this many characters, one training row per window.
max_chars = 40

In [473]:
# Load the raw corpus as one string.
txt_string = txt_to_string('./data/shakespeare.txt')

# Integer-encode the text: fit the encoder on the distinct characters, then
# map every character of the corpus to its code.
txt_char_ls = list(txt_string)
unique_chars = np.unique(txt_char_ls)
le_char = LabelEncoder().fit(unique_chars)
txt_num_ls = le_char.transform(txt_char_ls)
txt_num_array = np.array(txt_num_ls)

# Rows of `max_chars` codes paired with their one-step-ahead targets.
txt_mtx, txt_mtx_label = create_rnn_dataset(txt_num_array, max_chars)

# Hold out 20% of the rows for validation (seeded, reproducible split).
(train_x, val_x,
 train_y, val_y) = train_val_split(txt_mtx, txt_mtx_label,
                                   train_proportion=0.8, random_state=666)

### Neural network (two-layer character-level LSTM)

In [536]:
# --- model architecture ---
VOCAB_SIZE = len(unique_chars)    # one output class per distinct character
EMBEDDING_SIZE = 100              # width of each character embedding vector
NUM_RNN_LAYER_UNITS = [128, 128]  # hidden units of the two stacked LSTM layers
KEEP_PROB = 0.8                   # dropout keep-probability on LSTM outputs

# --- optimisation ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 500
NUM_EPOCHS = 200
SHUFFLE_BATCH = True

# --- evaluation / early stopping ---
EVAL_FREQUENCY = 10               # evaluate every N training steps
EARLY_STOPPING_EVAL_ROUNDS = 5    # stop after more than N consecutive evals without a new best val loss

In [537]:
tf.reset_default_graph()
graph = tf.Graph()

# Per-run, timestamped (UTC) output directories.
time_now = datetime.utcnow().strftime('%Y%m%d%H%M%S')
tf_graph_dir = './tf_graph/run-{}/'.format(time_now)
tf_model_dir = './tf_model/model-{}/'.format(time_now)

# Save the fitted character encoder next to this run's model checkpoints.
create_path(tf_model_dir)
with open(tf_model_dir+'char_encoder.pkl', 'wb') as f:
    pickle.dump(le_char, f)

with graph.as_default():
    # Integer character codes, [batch, time]. The label placeholder receives
    # the same windows shifted one character ahead (see create_rnn_dataset).
    txt_input = tf.placeholder(tf.int32, [None, None], 'text_input')
    txt_input_next = tf.placeholder(tf.int32, [None, None], 'text_label')
    # float32 to match the logits fed to the cross-entropy op, instead of
    # relying on an implicit int32 -> float cast inside the loss.
    txt_input_next_onehot = tf.one_hot(txt_input_next, depth=VOCAB_SIZE,
                                       axis=2, dtype=tf.float32,
                                       name='text_label_onehot')

    # Dropout keep-probability as a feedable tensor. It defaults to KEEP_PROB,
    # so training is unchanged when nothing is fed; feeding 1.0 for this
    # tensor disables dropout for evaluation runs.
    keep_prob = tf.placeholder_with_default(KEEP_PROB, shape=(), name='keep_prob')

    with tf.variable_scope('embedding'):
        embed_matrix = tf.Variable(tf.random_uniform([VOCAB_SIZE, EMBEDDING_SIZE]))
        txt_embedded = tf.nn.embedding_lookup(embed_matrix, txt_input)

    # Two stacked LSTM layers, each with dropout applied to its outputs.
    with tf.variable_scope('rnn_1'):
        lstm_1 = tf.nn.rnn_cell.BasicLSTMCell(NUM_RNN_LAYER_UNITS[0], activation=tf.nn.tanh)
        lstm_dropout_1 = tf.nn.rnn_cell.DropoutWrapper(lstm_1, output_keep_prob=keep_prob)
        rnn_1 = tf.nn.dynamic_rnn(lstm_dropout_1, txt_embedded, dtype=tf.float32)

    with tf.variable_scope('rnn_2'):
        lstm_2 = tf.nn.rnn_cell.BasicLSTMCell(NUM_RNN_LAYER_UNITS[1], activation=tf.nn.tanh)
        lstm_dropout_2 = tf.nn.rnn_cell.DropoutWrapper(lstm_2, output_keep_prob=keep_prob)
        # rnn_1[0] is the per-step output sequence of the first layer.
        rnn_2 = tf.nn.dynamic_rnn(lstm_dropout_2, rnn_1[0], dtype=tf.float32)

    # Skip connections: the output layer sees the embeddings and both LSTM
    # layers' outputs, concatenated along the feature axis.
    with tf.variable_scope('concat'):
        concat_out = tf.concat([txt_embedded, rnn_1[0], rnn_2[0]], axis=2)

    with tf.variable_scope('output'):
        logit_out = tf.layers.dense(concat_out, VOCAB_SIZE)
        softmax_out = tf.nn.softmax(logit_out)

    # Per-character cross-entropy, summed over each sequence, then averaged
    # over the batch.
    with tf.variable_scope('loss'):
        loss_word = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=txt_input_next_onehot, logits=logit_out)
        loss_sum_sentence = tf.reduce_sum(loss_word, axis=1)
        loss_avg_batch = tf.reduce_mean(loss_sum_sentence)

    with tf.variable_scope('optimization'):
        optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
        train = optimizer.minimize(loss_avg_batch)

    init = tf.global_variables_initializer()

# Uncomment to write the graph for TensorBoard:
# tf.summary.FileWriter(tf_graph_dir, graph)

In [None]:
# Early-stopping state: best validation loss seen so far, and the number of
# consecutive evaluations that failed to beat it.
bst_score = float('inf')
step_counter = 1
early_stopping_counter = 0
train_set = RNNDataset(train_x, train_y)
val_set = RNNDataset(val_x, val_y)
batch_manager = BatchManager(train_set, num_epochs=NUM_EPOCHS,
                             shuffle=SHUFFLE_BATCH, random_state=666)

with graph.as_default():
    saver = tf.train.Saver()

    # The `with` block closes the session on exit; no explicit close needed.
    with tf.Session() as sess:
        sess.run(init)

        while True:
            batch = batch_manager.next_batch(BATCH_SIZE)
            if batch is None:  # NUM_EPOCHS exhausted
                break
            batch_x, batch_y = batch

            if step_counter % EVAL_FREQUENCY == 0:
                train_loss = sess.run(loss_avg_batch, feed_dict={
                    txt_input: batch_x,
                    txt_input_next: batch_y,
                })

                # NOTE(review): nothing here disables the dropout configured in
                # the graph, so this validation loss comes from a stochastic
                # forward pass. The whole validation set is also evaluated in
                # one run, which may need batching if memory is tight.
                val_loss = sess.run(loss_avg_batch, feed_dict={
                    txt_input: val_set.X,
                    txt_input_next: val_set.y,
                })

                print('Training Loss: {} | Validation Loss: {}'.format(
                    train_loss, val_loss))

                if val_loss < bst_score:
                    # Remember the new best. Otherwise every loss below the
                    # initial sentinel counts as an "improvement", early
                    # stopping never fires, and the checkpoint is merely the
                    # latest rather than the best.
                    bst_score = val_loss
                    early_stopping_counter = 0
                    saver.save(sess, tf_model_dir+'char_generator.ckpt')
                else:
                    early_stopping_counter += 1

                if early_stopping_counter > EARLY_STOPPING_EVAL_ROUNDS:
                    break

            sess.run(train, feed_dict={
                txt_input: batch_x,
                txt_input_next: batch_y,
            })

            step_counter += 1