In [1]:
import numpy as np
from cs231n import optim
import matplotlib.pyplot as plt
import tensorflow as tf

from cs231n.layers import *
from cs231n.rnn_layers import *

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
class CharacterTable(object):
    """Two-way lookup between characters and integer / one-hot codes.

    Supports:
    + Encoding a string as a one hot matrix
    + Decoding a sequence of integer indices back to a string
    + Decoding a matrix of per-character scores (via argmax) to a string
    """
    def __init__(self, chars):
        """Build the lookup tables.
        # Arguments
            chars: Characters that can appear in the input. Duplicates are
                dropped and the remaining characters are sorted.
        """
        self.chars = sorted(set(chars))
        self.char_indices = {ch: idx for idx, ch in enumerate(self.chars)}
        self.indices_char = dict(enumerate(self.chars))

    def encode(self, C, num_rows):
        """Return the one hot encoding of string C.
        # Arguments
            C: String to encode; every character must be in the table.
            num_rows: Number of rows in the returned matrix. Rows past
                len(C) stay all-zero, so every sample has the same shape.
        # Returns
            Array of shape (num_rows, len(self.chars)).
        """
        vocab_size = len(self.chars)
        one_hot = np.zeros((num_rows, vocab_size))
        for row, ch in enumerate(C):
            one_hot[row, self.char_indices[ch]] = 1
        return one_hot

    def decode(self, x, calc_argmax=True):
        """Turn encoded data back into a string.
        # Arguments
            x: Per-character score rows, or (when calc_argmax is False) an
                iterable of integer character indices.
            calc_argmax: Whether to take the argmax over the last axis of x
                before looking the indices up.
        """
        indices = x.argmax(axis=-1) if calc_argmax else x
        return ''.join(self.indices_char[idx] for idx in indices)


# Parameters for the model and dataset.
TRAINING_SIZE = 10000
DIGITS = 7
INVERT = True
# Seed NumPy's global RNG so the generated dataset, the shuffle and the
# train/validation split are the same on every run of this cell.
SEED = 0
np.random.seed(SEED)

# A query is 's<int>e' (e.g., 's345e'): an int of at most DIGITS digits plus
# the start ('s') and end ('e') markers.
MAXLEN = DIGITS + 2

# All the digits, the start/end markers and space for padding.
chars = '0123456789se '
ctable = CharacterTable(chars)


def f():
    """Draw a random non-negative integer with at most DIGITS digits.

    The digit count is drawn first, then each digit independently; leading
    zeros vanish in the int() conversion.
    """
    num_digits = np.random.randint(1, DIGITS + 1)
    return int(''.join(np.random.choice(list('0123456789'))
                       for _ in range(num_digits)))


questions = []
expected = []
seen = set()
print('Generating data...')
while len(questions) < TRAINING_SIZE:
    a = f()
    # The task is "increment": the expected answer is always a + 1.
    b = a + 1
    # Skip any number we've already generated (b is determined by a, so a
    # alone identifies the sample).
    if a in seen:
        continue
    seen.add(a)
    # Pad the query with spaces such that it is always MAXLEN.
    query = 's{}e'.format(a).ljust(MAXLEN)
    # a + 1 may carry into one extra digit, so answers are padded to
    # MAXLEN + 1.
    ans = 's{}e'.format(b).ljust(MAXLEN + 1)
    if INVERT:
        # Reverse the query, e.g., 's345e    ' becomes '    e543s'. (Note the
        # space used for padding.)
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print('Total addition questions:', len(questions))

print('Vectorization...')
# x holds one hot encoded queries; y holds the answers as character indices.
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.int32)
y = np.zeros((len(questions), MAXLEN + 1), dtype=np.int32)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = np.array([ctable.char_indices[z] for z in sentence])

# Shuffle (x, y) in unison so the train/validation split below is a random
# partition of the samples.
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]

# Explicitly set apart 10% for validation data that we never train over.
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]
(y_train, y_val) = y[:split_at], y[split_at:]

print('Training Data:')
print(x_train.shape)
print(y_train.shape)

print('Validation Data:')
print(x_val.shape)
print(y_val.shape)

Generating data...
('Total addition questions:', 10000)
Vectorization...
Training Data:
(9000, 9, 13)
(9000, 10)
Validation Data:
(1000, 9, 13)
(1000, 10)


In [None]:
# Peek at one encoded training pair, then sanity-check that a single one hot
# row decodes back to a character.
print(x_train[0], y_train[0])
print(ctable.decode(np.array([[0,1,0,0,0,0,0,0,0,0,0,0,0]])))

In [3]:
# Model dimensions.
D = C = len(chars)  # input feature size / number of output classes
T = MAXLEN          # number of input time steps
H = 150             # LSTM hidden size

# Training schedule.
N = 5               # rows per batch
batch_size = N
num_epochs = 100
num_batches = x_train.shape[0] // N

In [15]:
# Build the graph: an encoder LSTM reads the one hot query, then a decoder
# LSTM, started from the encoder's final state, emits T+1 output characters.
tf.reset_default_graph()

# One hot encoded queries, declared as [batch, T, D].
batchX_placeholder = tf.placeholder(tf.float32, [None,T,D])
# Target character indices, declared as [batch, T+1].
batchY_placeholder = tf.placeholder(tf.int32, [None,T+1])

# All-zero initial state for the encoder.
# NOTE(review): these arrays fix the batch dimension at N even though the
# placeholders declare it as None -- presumably every fed batch must contain
# exactly N rows; confirm.
cell_state = np.zeros((N,H), dtype=np.float32)
hidden_state = np.zeros((N,H), dtype=np.float32)
init_state = tf.contrib.rnn.LSTMStateTuple(cell_state, hidden_state)

# Output projection from the decoder's H-sized output to C logits. W2 starts
# from uniform [0, 1) samples, b2 from zeros.
W2 = tf.Variable(np.random.rand(H,C),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,C)), dtype=tf.float32)

# Forward passes
# Encoder: only `current_state` (the final LSTM state) is used below; the
# per-step outputs in `states_series` are overwritten inside the decoder loop.
with tf.variable_scope('rnn1'):
    cell = tf.contrib.rnn.BasicLSTMCell(H, state_is_tuple=True)
    states_series, current_state = tf.nn.dynamic_rnn(cell, batchX_placeholder, initial_state=init_state)

# Decoder: unrolled by hand for T+1 steps, collecting a loss and a prediction
# per step.
losses = []
predictions = []
with tf.variable_scope('rnn2'):
    cell = tf.contrib.rnn.BasicLSTMCell(H, state_is_tuple=True, reuse=tf.get_variable_scope().reuse)
    # First decoder input: the one hot row for the start marker 's',
    # broadcast across all N rows.
    current_input = np.zeros([N,D], dtype=np.float32)
    current_input += (ctable.encode('s',1)).astype(np.float32)
    for t in range(T+1):
        # The cell's variables are created at t == 0 and reused afterwards.
        if t > 0: tf.get_variable_scope().reuse_variables()
        # Here `states_series` is the cell output for this single step.
        states_series, current_state = cell(tf.stack(current_input), current_state)
        logits = tf.matmul(states_series, W2) + b2 # b2 (1, C) broadcasts over the batch rows
        # Target character index at output position t for every row.
        labels = tf.reshape(batchY_placeholder[:,t], [-1])
        losses.append(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
        prob = tf.nn.softmax(logits)
        # Greedy decoding: the one hot argmax prediction becomes the next
        # input; the ground-truth characters are never fed to the decoder.
        pred = tf.one_hot(tf.argmax(prob, axis=1), C, axis=-1)
        current_input = pred
        predictions.append(pred)

# Stack the per-step predictions along axis 1 (one entry per decoder step).
predictions = tf.stack(predictions, axis=1)
print (predictions.shape)
# Mean cross-entropy over every decoder step and batch row.
total_loss = tf.reduce_mean(losses)
train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)

(5, 10, 13)


In [17]:
# Train for num_epochs passes over the training set and, after each epoch,
# print the decoded predictions for N consecutive validation samples.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_list = []

    for epoch_idx in range(num_epochs):
        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size

            batchX = x_train[start_idx:end_idx]
            batchY = y_train[start_idx:end_idx]

            # Only the loss value is needed during training; running
            # train_step applies the parameter update.
            _total_loss, _ = sess.run(
                [total_loss, train_step],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY
                })

            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Batch loss", _total_loss)

        # The graph's initial LSTM state is built for exactly N rows and the
        # report below indexes N rows, so the window start must leave room
        # for a full batch. np.random.randint excludes its upper bound, hence
        # the + 1.
        val_start = np.random.randint(x_val.shape[0] - N + 1)
        batchX = x_val[val_start:val_start + N]
        batchY = y_val[val_start:val_start + N]

        batchP = sess.run(
            predictions,
            feed_dict={
                batchX_placeholder: batchX,
                batchY_placeholder: batchY
            })
        # Print (target, prediction, exact match) for each validation row.
        for row in range(N):
            target = ctable.decode(batchY[row], calc_argmax=False)
            guess = ctable.decode(batchP[row])
            print (target, guess, target == guess)

('New data, epoch', 0)
('Step', 0, 'Batch loss', 2.5585134)
('Step', 100, 'Batch loss', 1.9252435)
('Step', 200, 'Batch loss', 1.2742047)
('Step', 300, 'Batch loss', 0.99027258)
('Step', 400, 'Batch loss', 0.83592725)
('Step', 500, 'Batch loss', 0.88858068)
('Step', 600, 'Batch loss', 0.79825097)
('Step', 700, 'Batch loss', 1.0261595)
('Step', 800, 'Batch loss', 1.2549353)
('Step', 900, 'Batch loss', 0.80043638)
('Step', 1000, 'Batch loss', 0.67355829)
('Step', 1100, 'Batch loss', 0.47319546)
('Step', 1200, 'Batch loss', 0.56936693)
('Step', 1300, 'Batch loss', 0.68740243)
('Step', 1400, 'Batch loss', 0.45900881)
('Step', 1500, 'Batch loss', 0.83692008)
('Step', 1600, 'Batch loss', 0.4620496)
('Step', 1700, 'Batch loss', 0.47365409)
('s5854165e ', 's5551465e ', False)
('s6284e    ', 's6284e    ', True)
('s14507e   ', 's41557e   ', False)
('s11116e   ', 's11116e   ', True)
('s987253e  ', 's982523e  ', False)
('New data, epoch', 1)
('Step', 0, 'Batch loss', 0.27246445)
('Step', 100, 'Bat

KeyboardInterrupt: 