In [1]:
import tensorflow as tf
import numpy as np
from collections import Counter

In [2]:
def load_text_to_id(filename):
    """Read a UTF-8 text file and encode it as a sequence of character ids.

    Full-width spaces and newlines are stripped from the text; every remaining
    character is treated as one token.

    Args:
        filename: path of a UTF-8 encoded text file.

    Returns:
        A tuple ``(word_ids, vocab, inv_vocab)``:
            word_ids: list of int ids, one per character of the cleaned text.
            vocab: dict mapping token -> id. Id 0 is reserved for ``'<unk>'``;
                the remaining ids follow the sorted order of the characters.
            inv_vocab: numpy array such that ``inv_vocab[id]`` is the token.
    """
    import io  # local import so the cell does not depend on the import cell

    # io.open decodes while reading and behaves identically on Python 2 and 3
    # (open().read().decode() only works on Python 2). Text mode also
    # normalises '\r\n' to '\n', so no stray '\r' ends up in the vocabulary.
    with io.open(filename, encoding='utf8') as f:
        text = f.read().replace(u'　', u'').replace(u'\n', u'')
    words = list(text)

    # Token list in id order: '<unk>' first, then the distinct characters.
    tokens = [u'<unk>'] + sorted(set(words))
    vocab = dict(zip(tokens, range(len(tokens))))
    # Every character is in vocab by construction, so no fallback is needed.
    word_ids = [vocab[w] for w in words]
    inv_vocab = np.array(tokens)
    assert len(vocab) == len(inv_vocab)
    return word_ids, vocab, inv_vocab

# Encode the corpus at character level. `vocab` and `inv_vocab` are read as
# module-level globals by the predict_* functions defined in later cells.
word_ids, vocab, inv_vocab = load_text_to_id('raw_novel.txt')

In [3]:
def get_model(scope_name, n_steps, dim_input, dim_hidden, batch_size, vocab_size, n_layer=1):
    """Build a GRU language model inside a freshly created ``tf.Graph``.

    Args:
        scope_name: name of the variable scope the model variables live in.
        n_steps: number of time steps; the input/target placeholders have
            shape ``[batch_size, n_steps]``.
        dim_input: size of the embedding vectors.
        dim_hidden: number of units passed to ``GRUCell``.
        batch_size: fixed batch dimension of the placeholders and of the
            initial RNN state.
        vocab_size: number of rows in the embedding and of output classes.
        n_layer: how many times the cell is repeated in ``MultiRNNCell``.

    Returns:
        dict with keys 'train', 'final_state', 'cost', 'logits', 'input',
        'target', 'init_state', 'cell', 'p_keep', 'embedding', 'probs' and
        'graph', each holding the corresponding op/tensor/object built below.
    """
    g = tf.Graph()
    with g.as_default():
        input_data = tf.placeholder('int32', [batch_size, n_steps])
        targets = tf.placeholder('int32', [batch_size, n_steps])
        # Keep-probability for dropout; defaults to 1.0 when not fed.
        p_keep = tf.placeholder_with_default(tf.constant(1.0), [])
        xavier = tf.contrib.layers.xavier_initializer()

        with tf.variable_scope(scope_name) as scope:
            with tf.device("/cpu:0"):
                # NOTE(review): `g` is created on every call, so the
                # "variable already exists" fallback looks unreachable here;
                # confirm before relying on it.
                try:
                    embedding = tf.get_variable('embedding', [vocab_size, dim_input], initializer=xavier)
                except ValueError:
                    scope.reuse_variables()
                    embedding = tf.get_variable('embedding', [vocab_size, dim_input], initializer=xavier)
                inputs = tf.nn.embedding_lookup(embedding, input_data)
                # Split along axis 1 into n_steps pieces and drop that axis,
                # giving one tensor per time step (argument order is the
                # pre-1.0 TensorFlow one: axis, num_split, value).
                inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, n_steps, inputs)]

            with tf.device('/gpu:0'):
                cell = tf.nn.rnn_cell.GRUCell(dim_hidden)
                cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=p_keep)
                # The same wrapped cell object is repeated n_layer times.
                cell = tf.nn.rnn_cell.MultiRNNCell([cell] * n_layer)
                initial_state = cell.zero_state(batch_size, 'float32')

            outputs, state = tf.nn.rnn(cell, inputs, initial_state=initial_state)
            # Join the per-step outputs along axis 1, then flatten to rows of
            # dim_hidden so a single matmul produces the logits for all steps.
            output = tf.reshape(tf.concat(1, outputs), [-1, dim_hidden])
            with tf.device('/gpu:0'):
                Wy = tf.get_variable('Wy', [dim_hidden, vocab_size], initializer=xavier)
                by = tf.get_variable('by', [vocab_size], initializer=xavier)
                logits = tf.matmul(output, Wy) + by
                probs = tf.nn.softmax(logits)
                # Targets are flattened to match the logits rows; every
                # position gets weight 1.
                # NOTE(review): the meaning of the 4th positional argument
                # depends on the TensorFlow version - confirm.
                loss = tf.nn.seq2seq.sequence_loss_by_example(
                    [logits], [tf.reshape(targets, [-1])],
                    [tf.ones([batch_size * n_steps], dtype='float32')], vocab_size)
                # Summed loss divided by the number of (batch, step) positions.
                cost = tf.reduce_sum(loss) / batch_size / n_steps
                final_state = state
                train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)

    return {'train': train_op, 'final_state': final_state, 'cost': cost,
            'logits': logits, 'input': input_data, 'target': targets,
            'init_state': initial_state, 'cell': cell, 'p_keep': p_keep,
            'embedding': embedding, 'probs': probs, 'graph': g}

In [4]:
def predict_by_max(prefix, max_len=120):
    """Extend `prefix` greedily, always taking the most probable character.

    A single-step model (batch size 1, one time step) is built and its weights
    are restored from './lstm_zh.checkpoint'. The prefix is fed character by
    character to warm up the RNN state, then generation starts.

    Args:
        prefix: non-empty text to continue. Characters missing from the
            global `vocab` are fed as id 0 ('<unk>').
        max_len: number of generation steps after the first generated
            character, so ``max_len + 1`` characters are appended in total.

    Returns:
        `prefix` followed by the generated characters, as one string.

    Raises:
        ValueError: if `prefix` is empty.
    """
    if not prefix:
        # With no prefix there is no output distribution to pick the first
        # character from (this used to surface as a NameError on `p` only
        # after the model had been built and restored).
        raise ValueError('prefix must contain at least one character')

    model = get_model('default', 1, 50, 100, 1, vocab_size=len(vocab))
    sent = list(prefix)
    with model['graph'].as_default():
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver(tf.global_variables())
            saver.restore(sess, './lstm_zh.checkpoint')

            state = sess.run(model['cell'].zero_state(1, tf.float32))
            init_state = model['init_state']
            final_state = model['final_state']
            probs = model['probs']
            X = model['input']

            # Warm-up: run the prefix through the network, carrying the state.
            for w in prefix:
                x = np.array([[vocab.get(w, 0)]])
                p, state = sess.run([probs, final_state], feed_dict={X: x, init_state: state})
            # `p` now holds the distribution that follows the whole prefix.
            word = inv_vocab[np.argmax(p[0])]
            sent.append(word)

            for _ in range(max_len):
                # Feed the character just produced to obtain the next one.
                x = np.array([[vocab[word]]])
                p, state = sess.run([probs, final_state], feed_dict={X: x, init_state: state})
                word = inv_vocab[np.argmax(p[0])]
                sent.append(word)
            return ''.join(sent)

In [5]:
def predict_by_sample(prefix, max_len=120):
    """Extend `prefix` by sampling each character from the model's output.

    A single-step model (batch size 1, one time step) is built and its weights
    are restored from './lstm_zh.checkpoint'. The prefix is fed character by
    character to warm up the RNN state, then every generated character -
    including the first one - is drawn at random according to the predicted
    probabilities.

    Args:
        prefix: non-empty text to continue. Characters missing from the
            global `vocab` are fed as id 0 ('<unk>').
        max_len: number of generation steps after the first generated
            character, so ``max_len + 1`` characters are appended in total.

    Returns:
        `prefix` followed by the generated characters, as one string.

    Raises:
        ValueError: if `prefix` is empty.
    """
    if not prefix:
        # With no prefix there is no output distribution to sample the first
        # character from (this used to surface as a NameError on `p` only
        # after the model had been built and restored).
        raise ValueError('prefix must contain at least one character')

    def _draw(dist):
        # Renormalise in float64: the softmax comes back as float32 and may
        # miss np.random.choice's "probabilities sum to 1" check by rounding.
        dist = np.asarray(dist, dtype='float64')
        return np.random.choice(inv_vocab, p=dist / dist.sum())

    model = get_model('default', 1, 50, 100, 1, vocab_size=len(vocab))
    sent = list(prefix)
    with model['graph'].as_default():
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver(tf.global_variables())
            saver.restore(sess, './lstm_zh.checkpoint')

            state = sess.run(model['cell'].zero_state(1, tf.float32))
            init_state = model['init_state']
            final_state = model['final_state']
            probs = model['probs']
            X = model['input']

            # Warm-up: run the prefix through the network, carrying the state.
            for w in prefix:
                x = np.array([[vocab.get(w, 0)]])
                p, state = sess.run([probs, final_state], feed_dict={X: x, init_state: state})
            # Sample the first character too (it used to be picked by argmax,
            # making the start of every "sampled" continuation deterministic).
            word = _draw(p[0])
            sent.append(word)

            for _ in range(max_len):
                # Feed the character just produced to obtain the next one.
                x = np.array([[vocab[word]]])
                p, state = sess.run([probs, final_state], feed_dict={X: x, init_state: state})
                word = _draw(p[0])
                sent.append(word)
            return ''.join(sent)

In [6]:
# Greedy decoding demo. print(...) with a single argument gives the same
# output on Python 2 and is also valid on Python 3.
print(predict_by_max(u'身為一個程式設計師，'))

身為一個程式設計師，在我的姐姐，我們就在我的姐姐，我知道，我們就在我的姐姐，」我說：「你要你的姐姐！」我說：「你要你的姐姐！」我說：「你要你的姐姐！」我說：「你要你的姐姐！」我說：「你要你的姐姐！」我說：「你要你的姐姐！」我說：「你要你的姐姐！」我說：「你要你的


In [7]:
# Sampling demo. print(...) with a single argument gives the same output on
# Python 2 and is also valid on Python 3.
print(predict_by_sample(u'身為一個程式設計師，'))

身為一個程式設計師，在這法握下著頓：「那麼，楚濂陶劍波，陶劍波呢？這份關愁楚份，一些吃很。才，怎樣？」我問：「我很好說！」母親直奔進身向忽然一句搖頭。「不上考去，我愛十忍著費雲帆被費雲帆粉房門，我走著嘴：「別是放生的酒，這種信爸爸。你並在「我干什麼辦？幾句不再放


## Beam Search

In [8]:
def predict_by_beam(prefix, beam_size=8, max_len=120):
    """Extend `prefix` with beam search over the character model.

    A single-step model (batch size 1, one time step) is built and its weights
    are restored from './lstm_zh.checkpoint'. Each hypothesis is a tuple
    ``(score, rnn_state, text)`` where `score` is the accumulated log
    probability. At every step each hypothesis is expanded with its
    `beam_size` most probable next characters; a character that already
    occurs in the hypothesis has its log probability doubled, which
    penalises repetition. Candidates ending in one of '。？！」' are dropped
    from the beam, and the `beam_size` best remaining ones are kept.

    Args:
        prefix: non-empty text to continue. Characters missing from the
            global `vocab` are fed as id 0 ('<unk>').
        beam_size: number of expansions per hypothesis and number of
            hypotheses kept per step; must be at least 1.
        max_len: total length (prefix included) at which the search stops.

    Returns:
        The text of the best-scoring hypothesis once it reaches `max_len`
        characters. If every candidate of a step ended a sentence, so that
        the beam would be empty, the best of those finished candidates is
        returned instead.

    Raises:
        ValueError: if `prefix` is empty or `beam_size` is smaller than 1.
    """
    if not prefix:
        raise ValueError('prefix must contain at least one character')
    if beam_size < 1:
        raise ValueError('beam_size must be at least 1')

    model = get_model('default', 1, 50, 100, 1, vocab_size=len(vocab))
    last_sent = ''
    with model['graph'].as_default():
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver(tf.global_variables())
            saver.restore(sess, './lstm_zh.checkpoint')

            state = sess.run(model['cell'].zero_state(1, tf.float32))
            init_state = model['init_state']
            final_state = model['final_state']
            probs = model['probs']
            X = model['input']

            # Warm-up on everything but the last prefix character; that one is
            # fed during the first expansion step below.
            for w in prefix[:-1]:
                x = np.array([[vocab.get(w, 0)]])
                state = sess.run(final_state, feed_dict={X: x, init_state: state})
            beam = [(0.0, state, prefix)]

            while len(beam[0][2]) < max_len:
                candidates = []
                for score, prev_state, sent in beam:
                    x = np.array([[vocab.get(sent[-1], 0)]])
                    dist, next_state = sess.run(
                        [probs, final_state], feed_dict={X: x, init_state: prev_state})
                    # Indices of the beam_size most probable characters,
                    # most probable first.
                    top_ids = np.argsort(dist[0])[::-1][:beam_size]
                    for w, w_prob in zip(inv_vocab[top_ids], dist[0][top_ids]):
                        log_p = np.log(w_prob)
                        if w in sent:
                            # Repetition penalty: log probabilities are
                            # negative, so doubling lowers the score.
                            log_p = log_p * 2
                        candidates.append((score + log_p, next_state, sent + w))

                # Candidates that end a sentence leave the beam; remember the
                # best one of this step as a fallback result.
                finished = [c for c in candidates if c[2][-1] in u'。？！」']
                if finished:
                    last_sent = max(finished, key=lambda c: c[0])[2]
                survivors = [c for c in candidates if c[2][-1] not in u'。？！」']

                # Sort on the score only: comparing whole tuples would fall
                # through to the RNN state arrays whenever two scores tie.
                beam = sorted(survivors, key=lambda c: c[0], reverse=True)[:beam_size]
                if not beam:
                    # Every candidate ended a sentence; nothing left to expand.
                    return last_sent
            return beam[0][2]

In [9]:
# Beam-search demo with a beam of 8. print(...) with a single argument gives
# the same output on Python 2 and is also valid on Python 3.
print(predict_by_beam(u'身為一個程式設計師，', 8))

身為一個程式設計師，和綠萍的姐姐，我不知道什麼地方去，你還有辦法，但是，我看到楚濂，他們又開始默默的說：「紫菱，這些年來，我在做什麼，我也沒有幾句話吧，你確實以為她彈吉他，爸爸……天哪，我就可能的故事，讓我聽著那珠簾幽夢的時候，你要進去了，
