# In [1]:
import tensorflow as tf

# Model hyperparameters.
sequence_length = 8  # token ids per example
vocab_size = 128  # rows in the embedding table (valid ids are 0..vocab_size-1)
embedding_size = 4  # width of each embedding vector

# Integer token ids, shape [batch, sequence_length]; the batch size is left open.
input_x = tf.placeholder(dtype=tf.int32, shape=[None, sequence_length], name="input_x")
# Trainable embedding table, initialised uniformly in [-1.0, 1.0).
emb = tf.Variable(
    initial_value=tf.random_uniform(shape=[vocab_size, embedding_size], minval=-1.0, maxval=1.0),
    name="emb",
)
# Gather one row per token id -> [batch, sequence_length, embedding_size].
emb_input = tf.nn.embedding_lookup(params=emb, ids=input_x)
# Append a trailing channel axis so conv2d can treat each sequence as a
# single-channel "image": [batch, sequence_length, embedding_size, 1].
emb_input_expanded = tf.expand_dims(input=emb_input, axis=-1)

# One convolution + max-pool "tower" per filter size; a tower with filter size k
# looks at windows of k consecutive tokens.
filter_sizes = [2, 3, 4] 
# Output channels (feature maps) produced by each tower.
num_filters = 4

pooled_outputs = []
# NOTE(review): `i` is unused. The loop variables (W, b, conv, h, pooled, ...)
# stay bound at module level to the last tower's tensors after the loop.
for i, filter_size in enumerate(filter_sizes):
    # Separate name scope per tower, so variables are named e.g. "conv-maxpool-2/W".
    with tf.name_scope("conv-maxpool-%s" % filter_size):
        # Kernel covers `filter_size` tokens and the full embedding width,
        # mapping 1 input channel to `num_filters` output channels.
        filter_shape = [filter_size, embedding_size, 1, num_filters]
        W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
        b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
        # VALID padding with a full-width kernel on the
        # [batch, sequence_length, embedding_size, 1] input gives
        # [batch, sequence_length - filter_size + 1, 1, num_filters].
        conv = tf.nn.conv2d(emb_input_expanded, 
            W, strides=[1, 1, 1, 1], 
            padding="VALID", name="conv")
        h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
        # The pooling window spans every conv position at once (max over the
        # sequence), leaving [batch, 1, 1, num_filters].
        pooled = tf.nn.max_pool( h, ksize=[1, 
            sequence_length - filter_size + 1, 1, 1], 
            strides=[1, 1, 1, 1], padding='VALID', name="pool")
        pooled_outputs.append(pooled)

# Concatenate the towers on the channel axis -> [batch, 1, 1, num_filters_total],
# then flatten to [batch, num_filters_total].
num_filters_total = num_filters * len(filter_sizes)
h_pool = tf.concat(pooled_outputs, 3)
h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])

# Dropout keep probability. It was previously hard-coded to 0.5, which left
# dropout active on every evaluation and made `scores` non-deterministic.
# It now defaults to 1.0 (dropout disabled) for inference; feed e.g.
# {dropout_keep_prob: 0.5} when running a training step.
dropout_keep_prob = tf.placeholder_with_default(1.0, shape=[], name="dropout_keep_prob")
with tf.name_scope("dropout"):
    # Zeroes each unit with probability 1 - keep_prob and scales survivors
    # by 1 / keep_prob.
    h_drop = tf.nn.dropout(h_pool_flat, dropout_keep_prob)

# Final (unnormalized) scores and predictions.
num_classes = 1
with tf.name_scope("output"):
    # Linear layer: [batch, num_filters_total] -> [batch, num_classes].
    W = tf.Variable(tf.random_uniform([num_filters_total, num_classes], -1.0, 1.0), name="W")
    b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
    scores = tf.nn.xw_plus_b(h_drop, W, b, name="scores")
    if num_classes > 1:
        # Index of the highest-scoring class, shape [batch], int64.
        predictions = tf.argmax(scores, 1, name="predictions")
    else:
        # argmax over a single column is always 0, so it predicts nothing.
        # Treat the lone score as a binary logit instead (sigmoid > 0.5 <=> score > 0).
        # The result keeps argmax's [batch] shape and int64 dtype.
        # NOTE(review): assumes num_classes == 1 means binary classification.
        predictions = tf.cast(tf.greater(scores[:, 0], 0.0), tf.int64, name="predictions")

# Tip: list the variables of one conv tower with
#   [v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv-maxpool-2')]
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Write the graph definition to ./cnn_text for TensorBoard.
writer = tf.summary.FileWriter("./cnn_text", graph=tf.get_default_graph())
# FileWriter updates its event file asynchronously. Flush so the graph reaches
# disk even though the writer is never closed here; the writer stays usable.
writer.flush()


# Toy batch: six sequences of exactly `sequence_length` (8) token ids,
# all below `vocab_size`.
batch_x = [
    [1, 4, 6, 8, 20, 2, 8, 9],
    [11, 14, 16, 18, 20, 12, 18, 19],
    [21, 24, 26, 28, 20, 22, 28, 29],
    [31, 34, 36, 38, 20, 32, 38, 39],
    [41, 44, 46, 48, 40, 42, 48, 49],
    [51, 54, 56, 58, 50, 52, 58, 49],
]

# Evaluate the output-layer scores for the batch. This is the last expression in
# the cell, so the notebook displays the resulting [6, num_classes] array.
feed = {input_x: batch_x}
sess.run(fetches=scores, feed_dict=feed)


# Out[1]:
# array([[ 0.33229524],
#        [ 0.34212485],
#        [ 1.48613679],
#        [ 0.66891855],
#        [-0.19099435],
#        [ 1.77867949]], dtype=float32)