In [1]:
import io

import numpy as np
import tensorflow as tf

In [2]:
def load_text_to_id(filename):
    """Read a UTF-8 text file and encode it as a sequence of character ids.

    Ideographic (full-width) spaces U+3000 and newlines are stripped before
    encoding, so every remaining character becomes one token.

    Args:
        filename: path to a UTF-8 encoded text file.

    Returns:
        A tuple ``(word_ids, vocab, inv_vocab)``:
          - ``word_ids``: list of int ids, one per character of the text;
          - ``vocab``: dict mapping token -> id; id 0 is reserved for
            ``'<unk>'`` and the text's characters get ids 1..N in sorted order;
          - ``inv_vocab``: numpy array where ``inv_vocab[i]`` is the token
            whose id is ``i``.
    """
    # io.open decodes with an explicit encoding on both Python 2 and 3
    # (str.decode does not exist on Python 3), and its universal-newline
    # mode turns '\r\n' into '\n' so no stray '\r' tokens reach the vocab.
    with io.open(filename, encoding='utf-8') as f:
        text = f.read().replace(u'\u3000', u'').replace(u'\n', u'')
    words = list(text)

    # A token's position in this list is its id, so the list is also the
    # inverse mapping.
    tokens = ['<unk>'] + sorted(set(words))
    vocab = {token: idx for idx, token in enumerate(tokens)}
    # Every character of the text is in ``vocab`` by construction.
    word_ids = [vocab[w] for w in words]
    inv_vocab = np.array(tokens)
    assert len(vocab) == len(inv_vocab)
    return word_ids, vocab, inv_vocab

# Training corpus: one UTF-8 novel, tokenized per character.
CORPUS_PATH = 'raw_novel.txt'
word_ids, vocab, inv_vocab = load_text_to_id(CORPUS_PATH)

In [3]:
def get_model(scope_name, n_steps, dim_input, dim_hidden, batch_size, vocab_size, n_layer=1):
    """Build a character-level GRU language-model graph (TensorFlow 0.x API).

    Args:
        scope_name: variable scope holding the model's variables. If creating
            the 'embedding' variable raises ValueError (presumably because it
            already exists in this scope), the scope is switched to reuse mode
            and the existing variables are fetched instead.
        n_steps: number of unrolled time steps per example.
        dim_input: not used by this function.
        dim_hidden: embedding width and GRU hidden size (both use this value).
        batch_size: fixed batch dimension of the input/target placeholders.
        vocab_size: number of token ids (rows of the embedding, columns of the
            output projection).
        n_layer: number of stacked GRU layers.

    Returns:
        dict of graph handles:
          'input', 'target': int32 placeholders of shape [batch_size, n_steps];
          'p_keep': scalar dropout keep probability, 1.0 unless fed;
          'init_state': zero state of the stacked cell (can be fed to carry
              state across session runs);
          'final_state': RNN state after the last unrolled step;
          'logits', 'probs': [batch_size * n_steps, vocab_size] scores and
              their softmax;
          'cost': per-token loss averaged over batch_size * n_steps;
          'train': RMSProp training op minimizing 'cost';
          'cell', 'embedding': the stacked RNN cell and the embedding variable.
    """
    input_data = tf.placeholder('int32', [batch_size, n_steps])
    targets = tf.placeholder('int32', [batch_size, n_steps])
    # Dropout keep probability; defaults to 1.0 (no dropout) when not fed.
    p_keep = tf.placeholder_with_default(tf.constant(1.0), [])

    with tf.variable_scope(scope_name) as scope:
        # Embedding variable and lookup are pinned to the CPU.
        with tf.device("/cpu:0"):
            # On ValueError (presumably: 'embedding' already exists from an
            # earlier call with the same scope_name) switch the scope to reuse
            # mode and fetch it; reuse then presumably also applies to Wy/by.
            try:
                embedding = tf.get_variable('embedding', [vocab_size, dim_hidden],
                    initializer=tf.contrib.layers.xavier_initializer())
            except ValueError:
                scope.reuse_variables()
                embedding = tf.get_variable('embedding', [vocab_size, dim_hidden],
                                            initializer=tf.contrib.layers.xavier_initializer())
            inputs = tf.nn.embedding_lookup(embedding, input_data)
            # Split [batch_size, n_steps, dim_hidden] along axis 1 into a list
            # of n_steps tensors of shape [batch_size, dim_hidden] for
            # tf.nn.rnn. Uses the pre-1.0 tf.split(split_dim, num_split, value)
            # argument order.
            inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, n_steps, inputs)]

        # NOTE(review): only the cell wrappers and zero_state are built inside
        # this device scope; the GRU's variables and ops are created when
        # tf.nn.rnn is called below, outside it, so they are likely NOT pinned
        # to '/gpu:0'. A hard-coded '/gpu:0' also presumably requires a GPU or
        # a session with allow_soft_placement -- confirm.
        with tf.device('/gpu:0'):
            cell = tf.nn.rnn_cell.GRUCell(dim_hidden)
            # Same keep probability applied to the cell's inputs and outputs.
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, p_keep, p_keep)
            # NOTE(review): [cell] * n_layer repeats the same cell object;
            # newer TF releases may reject this or share weights across
            # layers -- confirm against the installed version.
            cell = tf.nn.rnn_cell.MultiRNNCell([cell] * n_layer)
            initial_state = cell.zero_state(batch_size, 'float32')

        # Statically unrolled RNN over the per-step input list.
        outputs, state = tf.nn.rnn(cell, inputs, initial_state=initial_state)
        # Concatenate per-step outputs along axis 1, then reshape so rows are
        # ordered (example 0, t=0..n_steps-1), (example 1, ...), matching the
        # row order of tf.reshape(targets, [-1]) below.
        output = tf.reshape(tf.concat(1, outputs), [-1, dim_hidden])
        with tf.device('/gpu:0'):
            Wy = tf.get_variable('Wy', [dim_hidden, vocab_size],
                                 initializer=tf.contrib.layers.xavier_initializer())
            # NOTE(review): the bias uses the Xavier initializer rather than
            # zeros.
            by = tf.get_variable('by', [vocab_size],
                                 initializer=tf.contrib.layers.xavier_initializer())
            logits = tf.matmul(output, Wy) + by
            probs = tf.nn.softmax(logits)
            # Per-token losses with unit weights. NOTE(review): vocab_size is
            # passed as the 4th positional argument, whose meaning differs
            # between TF 0.x releases (num_decoder_symbols vs.
            # average_across_timesteps) -- confirm.
            loss = tf.nn.seq2seq.sequence_loss_by_example(
                [logits], [tf.reshape(targets, [-1])],
                [tf.ones([batch_size * n_steps], dtype='float32')], vocab_size)
            # Average over all batch_size * n_steps tokens.
            cost = tf.reduce_sum(loss) / batch_size / n_steps
            final_state = state
            # RMSProp with learning rate 0.001 and decay 0.9.
            train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)

    return {'train': train_op, 'final_state': final_state, 'cost': cost,
            'logits': logits, 'input': input_data, 'target': targets,
            'init_state': initial_state, 'cell': cell, 'p_keep': p_keep,
            'embedding': embedding, 'probs': probs}

In [4]:
# Sampling graph: one time step and a batch of one, so characters can be fed
# one at a time. dim_hidden presumably has to match the size the restored
# checkpoint was trained with; dim_input is not used by get_model.
model = get_model(
    'default',
    n_steps=1,
    dim_input=15,
    dim_hidden=30,
    batch_size=1,
    vocab_size=len(vocab),
)

In [21]:
# Seed text run through the model to warm up its hidden state before
# sampling: a sentence introducing a company's core team and their expertise
# (data science, cloud computing, AI, ...). It ends on a comma so the
# continuation starts mid-sentence.
prime = (u'集雅科技的主要成員在'
         u'資料科學、雲端運算、人工智慧等領域都各有專精，')

In [36]:
# Sampling configuration.
CHECKPOINT_PATH = './lstm_zh.checkpoint'
MAX_SAMPLE_LEN = 120        # hard cap on sampled characters after the first
END_CHARS = u'。？！」'       # stop as soon as one of these is generated
GREEDY = False              # True: always take the most likely character


def _encode_char(ch):
    """Return a [1, 1] int array holding the id of ``ch`` (0 / '<unk>' if unknown)."""
    return np.array([[vocab.get(ch, 0)]])


def _next_char(prob_row, greedy=GREEDY):
    """Pick the next character from one row of softmax probabilities.

    In sampling mode the float32 softmax output is re-normalized in float64
    first: np.random.choice raises ValueError when ``p`` does not sum to 1
    within its tolerance, which float32 rounding can trigger.
    """
    if greedy:
        return inv_vocab[np.argmax(prob_row)]
    prob_row = np.asarray(prob_row, dtype=np.float64)
    return np.random.choice(inv_vocab, p=prob_row / prob_row.sum())


# Without at least one prime character there is no prediction to start from.
if not prime:
    raise ValueError('prime must contain at least one character')

sent = []
with tf.Session() as sess:
    # Initialize, then overwrite every variable from the checkpoint.
    tf.global_variables_initializer().run()
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, CHECKPOINT_PATH)

    init_state = model['init_state']
    final_state = model['final_state']
    probs = model['probs']
    X = model['input']
    # Evaluate the model's existing zero-state tensor rather than adding new
    # zero_state ops to the graph on every run of this cell.
    state = sess.run(init_state)

    # Warm up the hidden state on the prime text; only the prediction that
    # follows its last character is used.
    for w in prime:
        p, state = sess.run([probs, final_state],
                            feed_dict={X: _encode_char(w), init_state: state})
    # The first generated character is always the most likely one.
    word = _next_char(p[0], greedy=True)
    sent.append(word)

    for _ in range(MAX_SAMPLE_LEN):
        p, state = sess.run([probs, final_state],
                            feed_dict={X: _encode_char(word), init_state: state})
        word = _next_char(p[0])
        sent.append(word)
        if word in END_CHARS:
            break
    print(''.join(sent))

我像坊演門口，我看到他們個眼模人的吉他。
