In [1]:
import functools
import tensorflow as tf
import codecs
import numpy as np
import random

In [58]:
class BatchGenerator:

    def __init__(self, text, seq_len, batch_size, state_size):
        self.text = text
        self.text_size = len(text)
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.state_size = state_size
        self._init_dict()
        self.state_holders = [self._create_state_holder() for _ in range(self.batch_size)]

    def _create_state_holder(self):
        index = random.randint(0, self.text_size - self.seq_len * 10)
        state = np.zeros(shape=(self.state_size), dtype=np.float)
        return (index, state)

    def _init_dict(self):
        vocab = set()
        for c in text:
            vocab.update([c])
        self.index2char = ['<pad>', '<go>', '<eos>'] + [c for c in vocab]
        self.char2index = {}
        for i, c in enumerate(self.index2char):
            self.char2index[c] = i
        self.vocab_size = len(self.index2char)

    def _next_seq(self, ind):
        seq = np.zeros(shape=(self.seq_len), dtype=np.int32)
        seq_lrn = np.zeros(shape=(self.seq_len), dtype=np.int32)
        seq[0] = self.char2index['<go>']
        l = 0
        for i in range(self.seq_len):
            if ind + i >= self.text_size:
                break
            if i+1 < self.seq_len:
                seq[i+1] = self.char2index[text[ind+i]]
            seq_lrn[i] = self.char2index[text[ind+i]]
        return (seq, seq_lrn, i+1)

    def get_batch(self):
        seqs = []
        seq_lrns = []
        lens = []
        states = []
        for i, _ in enumerate(self.state_holders):
            if self.state_holders[i][0] >= self.text_size:
                self.state_holders[i] = self._create_state_holder()
            (seq, seq_lrn, seq_len) = self._next_seq(self.state_holders[i][0])
            self.state_holders[i] = (self.state_holders[i][0] + seq_len, self.state_holders[i][1])
            seqs.append(seq)
            seq_lrns.append(seq_lrn)
            lens.append(seq_len)
            states.append(self.state_holders[i][1])
        return (seqs, seq_lrns, lens, states)

    def update_states(self, states):
        for i, _ in enumerate(self.state_holders):
            self.state_holders[i] = (self.state_holders[i][0], states[i])

    def seq2text(self, seq):
        return ''.join([self.index2char[i] for i in seq])

In [81]:
tf.reset_default_graph()
tf.enable_eager_execution()

seq_len = 64
batch_size = 128
state_size = 256
epoches = 1000
embedding_size = 128
layers = 3

with codecs.open('text.txt', 'r', encoding='utf-8') as f:
    text = f.read()

batch_generator = BatchGenerator(text, seq_len, batch_size, state_size)
sample_batch_generator = BatchGenerator(text, 100, 1, state_size)
steps_num = batch_generator.text_size // seq_len

inputs = tf.placeholder(tf.int32, [None, None], name='inputs')
labels = tf.placeholder(tf.int32, [None, None], name='labels')
lengths = tf.placeholder(tf.int32, [None], name='lengths')
states = tf.placeholder(tf.float32, [None, state_size], name='states')

embedding_table = tf.Variable(tf.random_uniform([batch_generator.vocab_size, embedding_size]))
embedding = tf.nn.embedding_lookup(embedding_table, inputs)

cell = tf.contrib.rnn.GRUCell(state_size)
projection_layer = tf.layers.Dense(batch_generator.vocab_size, use_bias=False)
print('vocab size: ', batch_generator.vocab_size)

helper = tf.contrib.seq2seq.TrainingHelper(embedding, lengths)
train_decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper, states, output_layer=projection_layer)
train_outputs, train_states, _ = tf.contrib.seq2seq.dynamic_decode(train_decoder)
train_output = train_outputs.rnn_output
train_sample_id = train_outputs.sample_id

p1 = tf.print('output: ', tf.shape(train_output))
p2 = tf.print('sample_id: ', train_sample_id)
p3 = tf.print('states: ', tf.shape(train_states))
p4 = tf.print('inputs: ', inputs)
p5 = tf.print('labels: ', tf.shape(labels))
p6 = tf.print('lengths: ', tf.shape(lengths))

masks = tf.sequence_mask(lengths=lengths, dtype=tf.float32)
p7 = tf.print('masks: ', masks)
#with tf.control_dependencies([p2]):
loss = tf.contrib.seq2seq.sequence_loss(logits=train_output, targets=labels, weights=masks)

optimize = tf.train.AdamOptimizer(learning_rate=.001).minimize(loss)
accuracy = tf.metrics.accuracy(inputs, train_sample_id, masks)

# prediction decoder
#prediction_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
#    embedding=embedding,
#    start_tokens=tf.fill([batch_size], tf.to_int32(batch_generator.char2index['<go>'])),
#    end_token=tf.to_int32(tf.fill([], batch_generator.char2index['<eos>'])))

#prediction_decoder = tf.contrib.seq2seq.BasicDecoder(cell, prediction_helper, states, output_layer=projection_layer)
#prediction_output, _, _ = tf.contrib.seq2seq.dynamic_decode(prediction_decoder, maximum_iterations=seq_len)
#preds = batch_generator.seq2text(tf.to_int64(prediction_output.sample_id))


print('epoches: ', epoches)
print('steps: ', steps_num)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    for epoch in range(epoches):
        print('start epoch: {}...'.format(epoch+1))
        for step in range(steps_num):
            (x, y, l, s) = batch_generator.get_batch()
            _, new_states, acc = sess.run([optimize, train_states, accuracy], feed_dict={inputs: x, labels: y, lengths: l, states: s})
            batch_generator.update_states(new_states)
            if step % 100 == 0:
                print('step {}, acc: {}'.format(step, acc))
                #res = sess.run(preds, feed_dict={states: s})
                # print(res)

vocab size:  83
epoches:  1000
steps:  2860
start epoch: 1...
step 0, acc: (0.0, 0.0012207031)
step 100, acc: (0.05314414, 0.052707348)
step 200, acc: (0.03202327, 0.03191616)
step 300, acc: (0.025483606, 0.025435433)
step 400, acc: (0.02259134, 0.022569701)
step 500, acc: (0.021152006, 0.021141216)


KeyboardInterrupt: 

In [4]:
def gen_batch(batch, vocab_size):
    target = []
    input = []
    for seq in batch:
        x = []
        y = []
        for i in range(len(seq)):
            t_x = [0] * vocab_size
            t_y = [0] * vocab_size
            c_i = int(seq[i])-1
            n_i = 0
            if i < len(seq) - 1:
                n_i = int(seq[i+1])-1
            t_x[c_i] = 1    
            t_y[n_i] = 1
            x.append(t_x)
            y.append(t_y)

        input.append(x)
        target.append(y)

    return np.array(input), np.array(target)

In [19]:
tf.reset_default_graph()
class SoftmaxPredictionRnn:

    def __init__(self, input, target, num_hidden=64, num_layers=3):
        self._num_hidden = num_hidden
        self._num_layers = num_layers
        self._max_grad_norm = .2
        self._learning_rate = .001
        self._input = input
        self._target = target
        self.prediction
        self.error
        self.optimize


    @lazy_property
    def length(self):
        used = tf.sign(tf.reduce_max(tf.abs(self._input), reduction_indices=2))
        length = tf.reduce_sum(used, reduction_indices=1)
        length = tf.cast(length, tf.int32)
        return length


    @lazy_property
    def prediction(self):
        # Recurrent network.
        cells = []
        for _ in range(self._num_layers):
            cells.append(tf.contrib.rnn.GRUCell(self._num_hidden))
        cell = tf.contrib.rnn.MultiRNNCell(cells)

        # Get dimensions
        self._max_length = int(self._input.get_shape()[1])
        self._num_classes = int(self._input.get_shape()[2])
        batch_size = tf.shape(self._input)[0]

        states = cell.zero_state(batch_size, tf.float32)
        state_type = type(states)
        self._initial_state = [
            tf.placeholder_with_default(zero_state, [None, self._num_hidden]) for zero_state in states]
        self._initial_state = state_type(self._initial_state)
        self._zero_state = self._initial_state
        
        self._output, self._final_state = tf.nn.dynamic_rnn(cell, self._input,
                                                            dtype=tf.float32, sequence_length=self.length,
                                                            initial_state=self._initial_state)

        # Softmax layer.
        weight = tf.get_variable('W', [self._num_hidden, self._num_classes])
        bias = tf.get_variable('b', [self._num_classes], initializer=tf.constant_initializer(0.1))

        # Flatten to apply same weights to all time steps.
        output = tf.reshape(self._output, [-1, self._num_hidden])
        self._raw_logits = tf.matmul(output, weight) + bias
        self._logits = tf.nn.softmax(tf.matmul(output, weight) + bias)
        prediction = tf.reshape(self._logits, [-1, self._max_length, self._num_classes])

        tf.summary.histogram("rnn_output", output)
        for w in cell.weights:
            tf.summary.histogram("rnn_weight", w)
        tf.summary.histogram("softmax_w", weight)
        tf.summary.histogram("softmax_bias", bias)
        tf.summary.histogram("prediction", prediction)

        return prediction


    @lazy_property
    def cost(self):
        # Compute cross entropy for each frame.
        cross_entropy = self._target * tf.log(self.prediction)
        cross_entropy = -tf.reduce_sum(cross_entropy, reduction_indices=2)
        mask = tf.sign(tf.reduce_max(tf.abs(self._target), reduction_indices=2))
        cross_entropy *= mask

        # Average over actual sequence lengths.
        cross_entropy = tf.reduce_sum(cross_entropy, reduction_indices=1)
        cross_entropy /= tf.cast(self.length, tf.float32)

        loss = tf.reduce_mean(cross_entropy)
        tf.summary.scalar('cross_entropy', loss)

        return loss


    @lazy_property
    def optimize(self):
        tvars = tf.trainable_variables()
        grads = tf.gradients(self.cost, tvars)
        clip_grads, _ = tf.clip_by_global_norm(grads, self._max_grad_norm)
        optimizer = tf.train.AdamOptimizer(self._learning_rate)

        #tf.summary.histogram("gradients", grads)
        #tf.summary.histogram("clip_gradients", clip_grads)

        return optimizer.apply_gradients(zip(clip_grads, tvars))

    @lazy_property
    def error(self):
        mistakes = tf.not_equal(tf.argmax(self._target, 2), tf.argmax(self.prediction, 2))
        mistakes = tf.cast(mistakes, tf.float32)
        mask = tf.sign(tf.reduce_max(tf.abs(self._target), reduction_indices=2))
        mistakes *= mask

        # Average over actual sequence lengths.
        mistakes = tf.reduce_sum(mistakes, reduction_indices=1)
        mistakes /= tf.cast(self.length, tf.float32)
        mistake = tf.reduce_mean(mistakes)

        tf.summary.scalar('error', mistake)

        return mistake


    def train_epoch(self, session, batch_generator, model, epoch, steps, x, y):
        batch = batch_generator.start()
        batch_x, batch_y = gen_batch(batch, batch_generator._vocab_size)
        state = session.run(model._zero_state, feed_dict={self._input: batch_x})
        
        for step in range(steps):
            _, state, s = sess.run([model.optimize, model._final_state, summaries],
                                   feed_dict={x: batch_x, y: batch_y, model._initial_state: state})
            writer.add_summary(s, epoch * steps + step)
            batch = batch_generator.get_batch()
            batch_x, batch_y = gen_batch(batch, batch_generator._vocab_size)

    def sample(self, session, start_text, length, temperature=1., max_prob=True):
        def get_input(symbol, seq_len, vocab_size):
            input = []
            one_hot_one = [0.] * vocab_size
            one_hot_one[int(symbol)-1] = 1.
            
            seq = []
            seq.append(one_hot_one)
            for _ in range(1, seq_len):
                seq.append([0.] * vocab_size)

            input.append(seq)
            return np.array(input, dtype=np.float32)

        # Prepare network's state to generate
        x = get_input(start_text[0], self._max_length, self._num_classes)
        state = session.run(self._zero_state, feed_dict={self._input: x})
        sample = start_text[0]
        for char in start_text[:-1]:
            x = get_input(char, self._max_length, self._num_classes)
            state = session.run(self._final_state, {self._input: x, self._initial_state: state})
            
        # Generate symbols
        x = get_input(start_text[-1], self._max_length, self._num_classes)
        seq = []
        
        for i in range(length):
            state, logits = session.run([self._final_state, self._logits],
                                        {self._input: x, self._initial_state: state})

            sample = np.argmax(logits[0]) + 1
            seq.append(sample)
            x = get_input(sample, self._max_length, self._num_classes)

        return seq