In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
class SeriesPredictor:
    def __init__(self, input_dim, seq_size, hidden_dim=10):
        self.input_dim = input_dim
        self.seq_size = seq_size
        self.hidden_dim = hidden_dim
        
        self.W_out = tf.Variable(tf.random_normal([hidden_dim, 1]), name='W_out')
        self.b_out = tf.Variable(tf.random_normal([1]), name='b_out')
        self.x = tf.placeholder(tf.float32, [None, seq_size, input_dim])
        self.y = tf.placeholder(tf.float32, [None, seq_size])
        self.cost = tf.reduce_mean(tf.square(self.model() - self.y)) 
        self.train_op = tf.train.AdamOptimizer().minimize(self.cost)
        self.saver = tf.train.Saver()
        
    def model(self): 
        """
        :param x: inputs of size [T, batch_size, input_size] 
        :param W: matrix of fully-connected output layer weights 
        :param b: vector of fully-connected output layer biases 
        """

        cell = rnn.BasicLSTMCell(self.hidden_dim)
        outputs, states = tf.nn.dynamic_rnn(cell, self.x, dtype=tf.float32)
        num_examples = tf.shape(self.x)[0]
        print('Shape '+str(np.shape(self.x))) # this turns out to be [3, 4, 1] - 3 samples of size 4 in a matrix
        print('Num Ex: '+str(np.shape(self.x)[0])) # this is 3 for self.x
        W_repeated = tf.tile(tf.expand_dims(self.W_out, 0), [num_examples, 1, 1]) # creates repeated array of [10, 1, 1] * [3, 1, 1] - our "memory" of what we picked for each hidden dim in the 3 states
        out = tf.matmul(outputs, W_repeated) + self.b_out # [batch_size, max_time, cell.output_size] so [3, 3, 10] * [30, 1, 1] + [3]
        out = tf.squeeze(out) # removes all size 1 dimensions, so takes us down from 
        return out
    
    def train(self, train_x, train_y):
        with tf.Session() as sess:
            tf.get_variable_scope().reuse_variables()
            sess.run(tf.global_variables_initializer())
            for i in range(1000):
                _, mse = sess.run([self.train_op, self.cost], feed_dict={self.x: train_x, self.y: train_y})
                if i % 100 == 0:
                    print(i, mse)
            save_path = self.saver.save(sess, 'ch10-model.ckpt')
            print('Model saved to {}'.format(save_path))
            
    def test(self, test_x):
        with tf.Session() as sess:
            tf.get_variable_scope().reuse_variables()
            self.saver.restore(sess, './ch10-model.ckpt')
            output = sess.run(self.model(), feed_dict={self.x: test_x})
            print(output)

In [3]:
predictor = SeriesPredictor(input_dim=1, seq_size=4, hidden_dim=10)
train_x = [[[1], [2], [5], [6]],
           [[5], [7], [7], [8]],
           [[3], [4], [5], [7]]]
train_y = [[1, 3, 7, 11],
           [5, 12, 14, 15],
           [3, 7, 9, 12]]
predictor.train(train_x, train_y)
test_x = [[[1], [2], [3], [4]],
          [[4], [5], [6], [7]]]
predictor.test(test_x)

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Shape (?, 4, 1)
Num Ex: ?
Instructions for updating:
Use tf.cast instead.
0 55.205883
100 31.606201
200 10.045015
300 4.8394103
400 2.9640453
500 1.9300634
600 1.2844108
700 0.79509085
800 0.4975969
900 0.3301896
Model saved to ch10-model.ckpt
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ./ch10-model.ckpt
Shape (?, 4, 1)
Num Ex: ?
[[ 1.9399943  2.621526   4.988883   7.8408556]
 [ 4.5700607  9.042807  12.229868  13.867058 ]]
