In [1]:
import tensorflow as tf

# Number of target classes in CIFAR-10.
CIFAR10_CLASSES = 10

# Keyword arguments shared by every tf.layers.conv2d call: 128 filters of
# size 5x5, stride 1, ReLU, and 'same' padding (a stride-1 conv therefore
# keeps the spatial size of its input).
CONV_LAYER_PARAMS = dict(
    filters=128,
    kernel_size=(5, 5),
    strides=1,
    activation=tf.nn.relu,
    padding='same',
)

# Keyword arguments shared by every tf.layers.max_pooling2d call: 2x2 windows
# with stride 2, which halves each spatial dimension.
MAXPOOL_LAYER_PARAMS = dict(
    pool_size=(2, 2),
    strides=2,
    padding='same',
)

# Keyword arguments shared by the hidden fully-connected layers: 4096 ReLU units.
DENSE_LAYER_PARAMS = dict(
    units=4096,
    activation=tf.nn.relu,
)

In [2]:
# Non-trainable step counter starting at 0 (presumably handed to an optimizer
# later in the notebook -- not visible in this section).
global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)

# Image batches: 32x32 with 3 channels, channels-last; batch size is dynamic.
X = tf.placeholder(tf.float32, [None, 32, 32, 3], name='input')

# Label batches with one column per class (presumably one-hot). The width is
# tied to CIFAR10_CLASSES so it always matches the output layer.
y = tf.placeholder(tf.float32, [None, CIFAR10_CLASSES], name='labels')

# Dropout keep probability (scalar, as tf.nn.dropout requires). It falls back
# to 1.0 -- i.e. dropout disabled -- when not fed, so evaluation/inference runs
# need not feed it; feed a value below 1.0 in the training feed_dict.
keep_prob = tf.placeholder_with_default(1.0, shape=[], name='keep_prob')

In [3]:
# convolutional layers
def _conv_pool_stage(inputs, index):
    """Build one convolution + max-pooling stage from the shared layer params.

    Args:
        inputs: 4-D tensor fed into the convolution.
        index: 1-based stage number; the layers are named ``conv<index>``
            and ``maxpool<index>``.

    Returns:
        A ``(conv, pool)`` tuple so both tensors stay available for inspection.
    """
    conv = tf.layers.conv2d(inputs, name='conv%d' % index, **CONV_LAYER_PARAMS)
    pool = tf.layers.max_pooling2d(conv, name='maxpool%d' % index, **MAXPOOL_LAYER_PARAMS)
    return conv, pool


# Each stage halves the spatial size: 32 -> 16 -> 8 -> 4 -> 2 -> 1.
conv1, maxpool1 = _conv_pool_stage(X, 1)
conv2, maxpool2 = _conv_pool_stage(maxpool1, 2)
conv3, maxpool3 = _conv_pool_stage(maxpool2, 3)
conv4, maxpool4 = _conv_pool_stage(maxpool3, 4)
conv5, maxpool5 = _conv_pool_stage(maxpool4, 5)

# fully-connected layers
# Flatten the last feature map. num_elements() yields a plain Python int
# (rather than doing arithmetic on TF1 Dimension objects) and returns None
# when any non-batch dimension is unknown, which reshape cannot handle.
flat_size = maxpool5.shape[1:].num_elements()
assert flat_size is not None, 'flattening requires a fully static feature-map shape'
new_shape = [-1, flat_size]
dense1 = tf.layers.dense(tf.reshape(maxpool5, new_shape), name='dense1', **DENSE_LAYER_PARAMS)
dense1_dropout = tf.nn.dropout(dense1, name='dense1_dropout', keep_prob=keep_prob)

dense2 = tf.layers.dense(dense1_dropout, name='dense2', **DENSE_LAYER_PARAMS)
dense2_dropout = tf.nn.dropout(dense2, name='dense2_dropout', keep_prob=keep_prob)

# No activation is passed, so this layer emits unnormalized per-class logits.
output = tf.layers.dense(dense2_dropout, name='output', units=CIFAR10_CLASSES)

In [4]:
# Static shape of every conv / max-pool stage in graph order, one per line
# ('?' marks the dynamic batch dimension).
print(*(stage.shape for stage in (conv1, maxpool1, conv2, maxpool2, conv3,
                                  maxpool3, conv4, maxpool4, conv5, maxpool5)),
      sep='\n')

(?, 32, 32, 128)
(?, 16, 16, 128)
(?, 16, 16, 128)
(?, 8, 8, 128)
(?, 8, 8, 128)
(?, 4, 4, 128)
(?, 4, 4, 128)
(?, 2, 2, 128)
(?, 2, 2, 128)
(?, 1, 1, 128)


In [5]:
# Static shapes of the fully-connected layers and the logits, one per line.
print(*(layer.shape for layer in (dense1, dense2, output)), sep='\n')

(?, 4096)
(?, 4096)
(?, 10)
