In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
def gen_data(size=1000000):
    """Generate one long random bit stream with short-range dependencies.

    Bit ``t`` is 1 with probability ``p_one``, which starts at 0.5 and is
    adjusted by the bits 3 and 8 positions earlier:

    * +0.5 when bit ``t - 3`` is 1 (``p_one`` reaches 1.0, forcing a 1, unless
      the next rule also applies);
    * -0.25 when bit ``t - 8`` is 1.

    Bits are produced strictly left to right because each depends on earlier
    ones; one ``np.random.rand()`` draw is consumed per position.

    Args:
        size: number of bits to generate.

    Returns:
        float ndarray of shape ``[size]`` holding 0.0 / 1.0 values.
    """
    bits = np.zeros(size)
    for t in range(size):
        p_one = 0.5
        if t >= 3 and bits[t - 3] == 1:
            p_one += 0.5
        if t >= 8 and bits[t - 8] == 1:
            p_one -= 0.25
        # A uniform draw above p_one yields 0, otherwise 1.
        bits[t] = 0 if np.random.rand() > p_one else 1
    return bits
def gen_data_sequence(examples=50000, num_seq=2, num_steps=21):
    """Generate examples made of consecutive sequences with short-range dependencies.

    Each example is ``num_seq`` sequences of ``num_steps`` symbols, produced as
    one flat stream of ``num_seq * num_steps`` positions. The last position of
    every sequence is the EOL marker ``2``. Every other position is a random
    bit that is 1 with probability ``threshold``, computed from the stream 3
    and 8 positions earlier (these look-backs may cross sequence boundaries):

    * start at 0.5;
    * +0.5 if the symbol 3 positions back is 1 (threshold 1.0 forces a 1
      unless the next rule also applies);
    * -0.25 if the symbol 8 positions back is 1.

    Positions are filled left to right, but every example is drawn at once for
    a given position. The work is therefore ``num_seq * num_steps`` vectorised
    steps instead of an interpreted loop over every element. The distribution
    is unchanged. The exact values drawn for a given seed differ from a
    per-element loop, because the RNG is consumed in a different order.

    Args:
        examples: number of independent examples to generate.
        num_seq: number of sequences per example.
        num_steps: length of each sequence, including its trailing EOL marker.

    Returns:
        float ndarray of shape ``[examples, num_seq, num_steps]`` with values
        in {0, 1, 2}.
    """
    stream_len = num_seq * num_steps
    # Work on a flat [examples, stream_len] array. The row-major reshape at the
    # end maps flat position j to (j // num_steps, j % num_steps).
    stream = np.zeros([examples, stream_len])
    for j in range(stream_len):
        if (j + 1) % num_steps == 0:
            stream[:, j] = 2  # EOL character
            continue
        threshold = np.full(examples, 0.5)
        if j >= 3:
            threshold += 0.5 * (stream[:, j - 3] == 1)
        if j >= 8:
            threshold -= 0.25 * (stream[:, j - 8] == 1)
        # A uniform draw above the threshold yields 0, otherwise 1.
        stream[:, j] = np.random.rand(examples) <= threshold
    return stream.reshape(examples, num_seq, num_steps)

def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """Yield fixed-size mini-batches from ``data`` for ``num_epochs`` passes.

    ``data`` is first converted with ``np.array``. Batches index its first axis.
    Only complete batches of exactly ``batch_size`` rows are produced. In every
    epoch, a trailing remainder smaller than ``batch_size`` is skipped, which
    suits graphs whose placeholders have a fixed batch dimension.

    Args:
        data: array-like of samples (first axis = sample).
        batch_size: rows per yielded batch.
        num_epochs: number of passes over the data.
        shuffle: when True, draw a fresh ``np.random`` permutation each epoch;
            otherwise keep the original order.

    Yields:
        ndarray batches of ``batch_size`` rows.
    """
    samples = np.array(data)
    n_samples = len(samples)
    n_full_batches = n_samples // batch_size
    for _ in range(num_epochs):
        if not shuffle:
            order = np.arange(n_samples)
        else:
            order = np.random.permutation(np.arange(n_samples))
        for b in range(n_full_batches):
            lo = b * batch_size
            yield samples[order[lo:lo + batch_size]]

In [None]:
# Start from an empty default graph. The graph-building cell below calls
# tf.reset_default_graph() itself, so this standalone reset is redundant
# under Restart & Run All.
tf.reset_default_graph()

In [None]:
# Hyperparameters
num_steps = 20      # symbols per sequence, including the trailing EOL marker
num_seq = 2         # sequences per example: [0] is encoded, [1] is decoded
n_hidden_enc = 4    # encoder GRU state size
n_hidden_dec = 8    # decoder GRU state size
batch_size = 200    # fixed batch dimension of the input placeholder
learning_rate = 0.1  # NOTE(review): far above AdamOptimizer's default of 0.001 - confirm intended

# Seed NumPy's global RNG, which drives data generation and batch shuffling,
# so that Restart & Run All reproduces the same data. TensorFlow ops keep
# their own seeds.
random_seed = 0
np.random.seed(random_seed)

In [None]:
## Build graph ##
# Seq2seq model:
# - a GRU encoder reads sequence 0;
# - its final state is mapped by a bias-free linear layer to the decoder's
#   initial state;
# - a GRU decoder predicts sequence 1 one symbol at a time, fed the true
#   previous symbol (teacher forcing).
tf.reset_default_graph()

n_symbols = 3  # symbol vocabulary: 0, 1 and the EOL marker 2 written by gen_data_sequence

# Input: integer-coded symbols
sequence_input = tf.placeholder(tf.int32, [batch_size, num_seq, num_steps])

one_hot = tf.one_hot(sequence_input, n_symbols)  # [batch_size, num_seq, num_steps, n_symbols]
seq_one_hot = tf.unpack(one_hot, axis=1)  # list(num_seq * [batch_size, num_steps, n_symbols])

encoder_inputs = tf.unpack(seq_one_hot[0], axis=1)  # list(num_steps * [batch_size, n_symbols])
target_one_hot = tf.unpack(seq_one_hot[1], axis=1)  # list(num_steps * [batch_size, n_symbols])

# Encoder
init_state_enc = tf.zeros([batch_size, n_hidden_enc])
encoder_cell = tf.nn.rnn_cell.GRUCell(n_hidden_enc)
_, final_state_enc = tf.nn.rnn(encoder_cell, encoder_inputs, initial_state=init_state_enc)

# Encoder to decoder
W_enc_to_dec = tf.get_variable('W_enc_to_dec', [n_hidden_enc, n_hidden_dec])
init_state_dec = tf.matmul(final_state_enc, W_enc_to_dec)

# Decoder. Its inputs are an all-zero "GO" vector followed by the targets
# shifted right by one step, so step t sees the true symbol t - 1.
go_input = tf.zeros([batch_size, n_symbols])
decoder_inputs = [go_input] + target_one_hot[:-1]
decoder_cell = tf.nn.rnn_cell.GRUCell(n_hidden_dec)
outputs, _ = tf.nn.seq2seq.rnn_decoder(decoder_inputs, init_state_dec, decoder_cell)

# To output
W_out = tf.get_variable('W_out', [n_hidden_dec, n_symbols])
b_out = tf.get_variable('b_out', [n_symbols])

logits = [tf.matmul(o, W_out) + b_out for o in outputs]
predictions = [tf.nn.softmax(l) for l in logits]

# Integer class targets, one [batch_size] tensor per step
targets = [tf.argmax(t, 1) for t in target_one_hot]
loss_weights = [tf.ones([batch_size]) for _ in range(num_steps)]

# Cost and training
losses = tf.nn.seq2seq.sequence_loss_by_example(logits, targets, loss_weights)
total_loss = tf.reduce_mean(losses)
train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(total_loss)


In [None]:
# Generate the data with the same sequence geometry the graph was built for.
# Generating with the default num_steps=21 and then reshaping to num_steps=20
# would re-flow values across rows and misplace the EOL markers.
X = gen_data_sequence(num_seq=num_seq, num_steps=num_steps).astype(np.int32)
assert X.shape[1:] == (num_seq, num_steps), X.shape

report_every = 100  # print the mean training loss over this many steps
batches = batch_iter(X, batch_size=batch_size, num_epochs=3)

acc_loss = 0.0
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for i, x_batch in enumerate(batches):
        # sequence_input is an int32 placeholder, hence the cast above.
        feed_dict = {sequence_input: x_batch}
        loss, _ = sess.run([total_loss, train_step], feed_dict=feed_dict)
        acc_loss += loss
        # Test i + 1 so that every report averages exactly `report_every` steps.
        if (i + 1) % report_every == 0:
            print(acc_loss / report_every)
            acc_loss = 0.0