In [1]:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import math
import tflearn
import random

from tensorflow.contrib.slim.python.slim.nets import resnet_v2
from tensorflow.contrib.slim.python.slim.nets import resnet_utils
from tflearn.datasets import cifar100

In [2]:
# initialize data
(X, Y), (X_test, Y_test) = cifar100.load_data(one_hot=True)

image_size, channels, classes = 32, 3, 100
split_prob = 0.95

# Fixed seed so the random train/validation partition below is repeatable.
random.seed(31415)


def _as_image_batch(batch):
    """Reshape `batch` to (count, image_size, image_size, channels)."""
    return batch.reshape(batch.shape[0], image_size, image_size, channels)


# split training/validation sets
# One random draw per (image, label) pair, in dataset order: a draw below
# split_prob sends the pair to the training set, otherwise to validation.
train_pairs, validation_pairs = [], []
for pair in zip(X, Y):
    destination = train_pairs if random.random() < split_prob else validation_pairs
    destination.append(pair)

images = _as_image_batch(np.asarray([img for img, _ in train_pairs]))
labels = np.asarray([lbl for _, lbl in train_pairs])

validation_images = _as_image_batch(
    np.asarray([img for img, _ in validation_pairs])
)
validation_labels = np.asarray([lbl for _, lbl in validation_pairs])

# Number of training examples; the training cell derives its iteration count
# from this value.
num_datapoints = images.shape[0]

test_images = _as_image_batch(X_test)
test_labels = Y_test


In [3]:
# define model

# NOTE(review): not referenced by any of the visible cells.
dropout_keep_prob = 0.5


def my_res_net(images):
    """Build the classifier graph on top of `images` and return its output tensor.

    Layers are created in this order:
      1. 5x5 convolution with 64 outputs (scope 'conv1'), then 2x2 max pooling
         (scope 'pool1').
      2. A slim `resnet_v2` made of two bottleneck blocks (scope 'resnet_v2'),
         called with `classes` and `include_root_block=True`.
      3. Flatten (scope 'flatten3'), a 1024-unit fully connected layer
         (scope 'fully_connected4'), and a final `classes`-unit fully connected
         layer with no activation (scope 'fully_connected_out').

    Args:
        images: input tensor handed directly to `slim.conv2d`. The notebook
            calls this with a float32 placeholder of shape
            [None, image_size, image_size, channels].

    Returns:
        The output of the final fully connected layer.
    """
    stem = slim.conv2d(images, 64, [5,5], scope='conv1')
    stem = slim.max_pool2d(stem, [2,2], scope='pool1')

    # Each tuple configures one bottleneck unit of the block.
    # presumably (depth, bottleneck depth, stride) — TODO confirm against the
    # installed resnet_utils.Block.
    block1_units = [(256, 64, 1)] * 2 + [(256, 64, 2)]
    block2_units = [(512, 128, 1)] * 3 + [(512, 128, 2)]
    residual_blocks = [
        resnet_utils.Block('block1', resnet_v2.bottleneck, block1_units),
        resnet_utils.Block('block2', resnet_v2.bottleneck, block2_units),
    ]

    # NOTE(review): `classes, True, None` are passed positionally; which
    # parameters they bind to depends on the installed resnet_v2 signature —
    # confirm before converting them to keyword arguments.
    # The second return value is discarded.
    resnet_out, _ = resnet_v2.resnet_v2(
        stem,
        residual_blocks,
        classes,
        True,
        None,
        include_root_block=True,
        reuse=None,
        scope='resnet_v2',
    )

    flat = slim.flatten(resnet_out, scope='flatten3')
    hidden = slim.fully_connected(flat, 1024, scope='fully_connected4')

    # activation_fn=None: the final layer's output is returned without an
    # activation applied.
    return slim.fully_connected(
        hidden, classes, activation_fn=None, scope='fully_connected_out'
    )
        

# Placeholders: `x` receives image batches and `y` the matching label rows
# (one row of `classes` values per example).
input_shape = [None, image_size, image_size, channels]
label_shape = [None, classes]
x = tf.placeholder(tf.float32, shape=input_shape)
y = tf.placeholder(tf.float32, shape=label_shape)

# Output of the network for the fed images; used as logits below.
predictions = my_res_net(x)

# NOTE(review): the cross-entropy result is handed to minimize() below without
# an explicit reduction — confirm that this is intended.
total_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=y, logits=predictions
)

# Accuracy: fraction of examples whose arg-max prediction matches the arg-max
# of the label row.
predicted_class = tf.argmax(predictions, 1)
expected_class = tf.argmax(y, 1)
correct = tf.equal(predicted_class, expected_class)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

learn_rate = 1e-4
optimizer = tf.train.AdamOptimizer(learn_rate)
train_step = optimizer.minimize(total_loss)

In [4]:
# train model
#
# Runs `epochs` passes over the training arrays in fixed order, feeding
# consecutive slices of `batch_size` examples to `train_step`, and periodically
# reports accuracy on the held-out validation split.

batch_size = 20

epochs = 100
# Round up so the final, possibly smaller, batch is still used.
iterations = math.ceil(num_datapoints / batch_size)
print ("datapoints: {}".format(num_datapoints))
print ("epochs: {}".format(epochs))
print ("iterations: {}".format(iterations))

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Validation interval in iterations, equal to iterations * epochs / 100.
# Integer arithmetic keeps the value exact and avoids dividing by `epochs`;
# the clamp to 1 prevents a modulo-by-zero below when the run is small.
# With epochs == 100 the interval equals `iterations`, so the check fires
# once per epoch, at i == 0.
validate = max(1, (iterations * epochs) // 100)

for e in range(epochs):
    for i in range(iterations):
        idx = batch_size*i
        # `images` already has shape (N, image_size, image_size, channels), so
        # the slice is fed as-is. Arrays are passed directly to feed_dict
        # instead of being converted to nested Python lists first.
        batch_images = images[idx:idx+batch_size]
        batch_labels = labels[idx:idx+batch_size]
        # Reshape kept from the original code: it pins the label slice to the
        # [batch, classes] layout declared for `y`.
        batch_labels = batch_labels.reshape(batch_labels.shape[0], classes)

        sess.run(train_step, feed_dict={x: batch_images, y: batch_labels})

        if i%validate == 0:
            # Measured on the validation split, so it is reported as such.
            validation_accuracy = sess.run(
                accuracy,
                feed_dict={x: validation_images, y: validation_labels}
            )
            print("epoch: {}, iteration: {}, validation accuracy: {}"
                  .format(e, i, validation_accuracy))

datapoints: 47467
epochs: 100
iterations: 2374
epoch: 0, iteration: 0, training accuracy: 0.015791552141308784
epoch: 1, iteration: 0, training accuracy: 0.20607975125312805
epoch: 2, iteration: 0, training accuracy: 0.2751677930355072
epoch: 3, iteration: 0, training accuracy: 0.3189893364906311
epoch: 4, iteration: 0, training accuracy: 0.33241215348243713
epoch: 5, iteration: 0, training accuracy: 0.3268851041793823
epoch: 6, iteration: 0, training accuracy: 0.34267666935920715
epoch: 7, iteration: 0, training accuracy: 0.33201736211776733
epoch: 8, iteration: 0, training accuracy: 0.3126727342605591
epoch: 9, iteration: 0, training accuracy: 0.3217528760433197
epoch: 10, iteration: 0, training accuracy: 0.3174101710319519
epoch: 11, iteration: 0, training accuracy: 0.3095144033432007
epoch: 12, iteration: 0, training accuracy: 0.3146466612815857
epoch: 13, iteration: 0, training accuracy: 0.3103039860725403
epoch: 14, iteration: 0, training accuracy: 0.3047769367694855
epoch: 15, i

In [5]:
# evaluate test data
# Feeds the entire test set through the `accuracy` op in a single run.

test_feed = {x: test_images, y: test_labels}
test_accuracy = sess.run(accuracy, feed_dict=test_feed)

print("test accuracy: {}".format(test_accuracy))

test accuracy: 0.34689998626708984
