In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf

from tensorflow.python.estimator.inputs.queues import feeding_functions as ff
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators.dynamic_rnn_estimator import PredictionType

In [2]:
RGB_INPUT = 'data/train.csv'
BATCH_SIZE = 32
CHARACTERS = [chr(i) for i in range(256)]

SEQUENCE_LENGTH_KEY = 'sequence_length'
COLOR_NAME_KEY = 'color_name'
RGB_KEY = 'rgb'

In [3]:
# Input Function                                                 
def get_input_fn(csv_file, batch_size, epochs=None):
  with open(RGB_INPUT, 'r') as f:
    df = pd.read_csv(f)
    df['sequence_length'] = df.name.str.len().astype(np.int32)
    
  def input_fn():
    pandas_queue = ff._enqueue_data(df,
                                    capacity=1024,
                                    shuffle=True,
                                    min_after_dequeue=256,
                                    num_threads=4,
                                    enqueue_size=16,
                                    num_epochs=epochs)
    
    _, color_name, r, g, b, seq_len = pandas_queue.dequeue_up_to(batch_size)
    
    split_color_name = tf.string_split(color_name, delimiter='')
    rgb = tf.to_float(tf.stack([r, g, b], axis=1)) / 255.0

    batched = tf.train.shuffle_batch({COLOR_NAME_KEY: split_color_name,
                                      SEQUENCE_LENGTH_KEY: seq_len,
                                      RGB_KEY: rgb},
                                     batch_size,
                                     min_after_dequeue=100,
                                     num_threads=4,
                                     capacity=1000,
                                     enqueue_many=True,
                                     allow_smaller_final_batch=True)
    label = batched.pop(RGB_KEY)
    return batched, label
  return input_fn



In [4]:
train_input_fn = get_input_fn(RGB_INPUT, BATCH_SIZE)

In [12]:
vocabulary = tf.constant(list(" abcdefghijklmnopqrstuvwxyz"), name="vocab")
vocab = tf.contrib.lookup.index_table_from_tensor(vocabulary)

with tf.Graph().as_default():
  train_input = train_input_fn()
  with tf.train.MonitoredSession() as sess:
    # print (train_input)
    k = sess.run(train_input)
   
    # for each character, lookup the index
    print(k[0][COLOR_NAME_KEY])
    encoded = vocab.lookup(k[0][COLOR_NAME_KEY])
    print(sess.run(encoded))

TypeError: Fetch argument array([[ 0,  0],
       [ 0,  1],
       [ 0,  2],
       [ 0,  3],
       [ 0,  4],
       [ 0,  5],
       [ 0,  6],
       [ 0,  7],
       [ 0,  8],
       [ 0,  9],
       [ 0, 10],
       [ 1,  0],
       [ 1,  1],
       [ 1,  2],
       [ 1,  3],
       [ 1,  4],
       [ 1,  5],
       [ 1,  6],
       [ 2,  0],
       [ 2,  1],
       [ 2,  2],
       [ 2,  3],
       [ 2,  4],
       [ 2,  5],
       [ 2,  6],
       [ 2,  7],
       [ 2,  8],
       [ 2,  9],
       [ 2, 10],
       [ 2, 11],
       [ 2, 12],
       [ 3,  0],
       [ 3,  1],
       [ 3,  2],
       [ 3,  3],
       [ 3,  4],
       [ 4,  0],
       [ 4,  1],
       [ 4,  2],
       [ 4,  3],
       [ 4,  4],
       [ 4,  5],
       [ 4,  6],
       [ 4,  7],
       [ 4,  8],
       [ 4,  9],
       [ 4, 10],
       [ 4, 11],
       [ 4, 12],
       [ 5,  0],
       [ 5,  1],
       [ 5,  2],
       [ 5,  3],
       [ 6,  0],
       [ 6,  1],
       [ 6,  2],
       [ 6,  3],
       [ 6,  4],
       [ 6,  5],
       [ 6,  6],
       [ 6,  7],
       [ 6,  8],
       [ 6,  9],
       [ 6, 10],
       [ 6, 11],
       [ 7,  0],
       [ 7,  1],
       [ 7,  2],
       [ 7,  3],
       [ 7,  4],
       [ 7,  5],
       [ 7,  6],
       [ 7,  7],
       [ 7,  8],
       [ 7,  9],
       [ 7, 10],
       [ 7, 11],
       [ 8,  0],
       [ 8,  1],
       [ 8,  2],
       [ 8,  3],
       [ 8,  4],
       [ 8,  5],
       [ 8,  6],
       [ 8,  7],
       [ 9,  0],
       [ 9,  1],
       [ 9,  2],
       [ 9,  3],
       [ 9,  4],
       [ 9,  5],
       [ 9,  6],
       [ 9,  7],
       [ 9,  8],
       [ 9,  9],
       [10,  0],
       [10,  1],
       [10,  2],
       [10,  3],
       [10,  4],
       [10,  5],
       [10,  6],
       [10,  7],
       [10,  8],
       [10,  9],
       [10, 10],
       [10, 11],
       [10, 12],
       [10, 13],
       [10, 14],
       [10, 15],
       [10, 16],
       [10, 17],
       [10, 18],
       [11,  0],
       [11,  1],
       [11,  2],
       [11,  3],
       [11,  4],
       [11,  5],
       [11,  6],
       [11,  7],
       [11,  8],
       [11,  9],
       [11, 10],
       [11, 11],
       [12,  0],
       [12,  1],
       [12,  2],
       [12,  3],
       [12,  4],
       [12,  5],
       [12,  6],
       [12,  7],
       [12,  8],
       [12,  9],
       [12, 10],
       [12, 11],
       [12, 12],
       [12, 13],
       [12, 14],
       [12, 15],
       [12, 16],
       [12, 17],
       [13,  0],
       [13,  1],
       [13,  2],
       [13,  3],
       [13,  4],
       [13,  5],
       [13,  6],
       [13,  7],
       [14,  0],
       [14,  1],
       [14,  2],
       [14,  3],
       [14,  4],
       [14,  5],
       [14,  6],
       [14,  7],
       [14,  8],
       [14,  9],
       [14, 10],
       [14, 11],
       [14, 12],
       [14, 13],
       [14, 14],
       [14, 15],
       [14, 16],
       [15,  0],
       [15,  1],
       [15,  2],
       [15,  3],
       [15,  4],
       [15,  5],
       [15,  6],
       [15,  7],
       [15,  8],
       [15,  9],
       [16,  0],
       [16,  1],
       [16,  2],
       [16,  3],
       [16,  4],
       [16,  5],
       [16,  6],
       [16,  7],
       [16,  8],
       [16,  9],
       [16, 10],
       [16, 11],
       [17,  0],
       [17,  1],
       [17,  2],
       [17,  3],
       [17,  4],
       [17,  5],
       [17,  6],
       [17,  7],
       [17,  8],
       [17,  9],
       [17, 10],
       [17, 11],
       [17, 12],
       [17, 13],
       [18,  0],
       [18,  1],
       [18,  2],
       [18,  3],
       [18,  4],
       [18,  5],
       [19,  0],
       [19,  1],
       [19,  2],
       [19,  3],
       [19,  4],
       [19,  5],
       [19,  6],
       [19,  7],
       [19,  8],
       [19,  9],
       [19, 10],
       [19, 11],
       [19, 12],
       [19, 13],
       [19, 14],
       [20,  0],
       [20,  1],
       [20,  2],
       [20,  3],
       [20,  4],
       [20,  5],
       [20,  6],
       [20,  7],
       [20,  8],
       [20,  9],
       [20, 10],
       [20, 11],
       [21,  0],
       [21,  1],
       [21,  2],
       [21,  3],
       [21,  4],
       [21,  5],
       [21,  6],
       [21,  7],
       [21,  8],
       [21,  9],
       [21, 10],
       [21, 11],
       [21, 12],
       [22,  0],
       [22,  1],
       [22,  2],
       [22,  3],
       [22,  4],
       [22,  5],
       [22,  6],
       [22,  7],
       [22,  8],
       [22,  9],
       [22, 10],
       [22, 11],
       [22, 12],
       [22, 13],
       [22, 14],
       [22, 15],
       [23,  0],
       [23,  1],
       [23,  2],
       [23,  3],
       [23,  4],
       [23,  5],
       [23,  6],
       [23,  7],
       [23,  8],
       [23,  9],
       [23, 10],
       [24,  0],
       [24,  1],
       [24,  2],
       [24,  3],
       [24,  4],
       [24,  5],
       [24,  6],
       [24,  7],
       [24,  8],
       [24,  9],
       [24, 10],
       [24, 11],
       [24, 12],
       [24, 13],
       [25,  0],
       [25,  1],
       [25,  2],
       [25,  3],
       [25,  4],
       [25,  5],
       [25,  6],
       [25,  7],
       [25,  8],
       [25,  9],
       [25, 10],
       [25, 11],
       [25, 12],
       [25, 13],
       [26,  0],
       [26,  1],
       [26,  2],
       [26,  3],
       [26,  4],
       [26,  5],
       [26,  6],
       [26,  7],
       [27,  0],
       [27,  1],
       [27,  2],
       [27,  3],
       [27,  4],
       [27,  5],
       [28,  0],
       [28,  1],
       [28,  2],
       [28,  3],
       [28,  4],
       [28,  5],
       [28,  6],
       [28,  7],
       [28,  8],
       [28,  9],
       [28, 10],
       [28, 11],
       [28, 12],
       [28, 13],
       [29,  0],
       [29,  1],
       [29,  2],
       [29,  3],
       [29,  4],
       [29,  5],
       [29,  6],
       [29,  7],
       [29,  8],
       [29,  9],
       [29, 10],
       [29, 11],
       [29, 12],
       [29, 13],
       [29, 14],
       [30,  0],
       [30,  1],
       [30,  2],
       [30,  3],
       [30,  4],
       [30,  5],
       [30,  6],
       [30,  7],
       [30,  8],
       [31,  0],
       [31,  1],
       [31,  2],
       [31,  3],
       [31,  4],
       [31,  5],
       [31,  6],
       [31,  7]]) has invalid type <class 'numpy.ndarray'>, must be a string or Tensor. (Can not convert a ndarray into a Tensor or Operation.)

In [None]:
def model_fn(features, target):
    
    # use the vocabulary lookup table
    vocabulary = tf.constant(list(" abcdefghijklmnopqrstuvwxyz"), name="vocab")
    vocab = tf.contrib.lookup.index_table_from_tensor(vocabulary)

    # for each character, lookup the index
    encoded = vocab.lookup(features[COLOR_NAME_KEY])
    
    
    # perform one_hot encoding
    dense_encoding = tf.sparse_tensor_to_dense(encoded, default_value=-1)
    

    one_hot = tf.one_hot(dense_encoding, vocabulary.get_shape()[0])
    
    lengths = features[SEQUENCE_LENGTH_KEY]
  
    rnn_layers = []
    rnn_cell_sizes = [128, 128]
    for size in rnn_cell_sizes:
        rnn_layers.append(tf.contrib.rnn.LSTMCell(size))

    multi_rnn_cell = tf.contrib.rnn.MultiRNNCell(rnn_layers)
    outputs, encoding = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                       inputs=one_hot,
                                       sequence_length=lengths,
                                       dtype=tf.float32)
    
    # slice to keep only the last cell of the RNN
    output = outputs[-1]
    
    #logits = tf.contrib.layers.fully_connected(output, 3)
    weight = tf.Variable(tf.random_normal([128, 3]))
    bias = tf.Variable(tf.random_normal([3]))
    logits = tf.matmul(output, weight) + bias

    # target and logits have different shape, not really sure why..
    # check Lackshman example
    loss = tf.contrib.losses.mean_squared_error(target, logits)
    
    """Character level recurrent neural network model to predict classes.
    target = tf.one_hot(target, 15, 1, 0)
    byte_list = tf.one_hot(features, 256, 1, 0)
    byte_list = tf.unstack(byte_list, axis=1)

    cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE)
    output, _ = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32)
    logits = tf.contrib.layers.fully_connected(output, 3, activation_fn=None)
    loss = tf.contrib.losses.mean_squared_error(logits, target)
"""

    train_op = tf.contrib.layers.optimize_loss(
      loss,
      tf.contrib.framework.get_global_step(),
      optimizer='Adam',
      learning_rate=0.01)

    return ({
      'predictions': logits
    }, loss, train_op)

estimator = tf.contrib.learn.Estimator(model_fn=model_fn)

estimator.fit(input_fn=train_input_fn, steps=2000)
p = estimator.predict(input_fn=get_test_inputs)
for e in p:
	print(e)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_num_worker_replicas': 0, '_master': '', '_task_type': None, '_num_ps_replicas': 0, '_is_chief': True, '_tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1
}
, '_task_id': 0, '_keep_checkpoint_every_n_hours': 10000, '_save_checkpoints_secs': 600, '_evaluation_master': '', '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fc6a386abe0>, '_environment': 'local', '_save_checkpoints_steps': None, '_keep_checkpoint_max': 5, '_tf_random_seed': None, '_model_dir': None, '_save_summary_steps': 100}
here
