In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.contrib import slim
import os
import tarfile
from six.moves import urllib

In [None]:
from tensorflow.contrib.slim.nets import vgg

In [None]:
# Training hyper-parameters -- tune the run here.
learning_rate = 1e-3     # SGD step size
training_epochs = 5      # passes over the training subset
batch_size = 8           # images per optimisation step
n_data = 5000            # examples taken from each of the train and test splits

## Prepare Dataset

In [None]:
# Load CIFAR-10 through the Keras dataset helper; it yields a (train, test)
# pair of (images, labels) tuples.
cifar10 = keras.datasets.cifar10
train_split, test_split = cifar10.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

In [None]:
# The reinitializable iterator is built from train_dataset's output types, so
# the test labels must have the same dtype as the train labels for the test
# initializer to be accepted. Casting to train_labels.dtype (instead of a
# hard-coded 'uint8') keeps that true whatever dtype load_data() returns.
test_labels = test_labels.astype(train_labels.dtype)

In [None]:
def make_dataset(images_np, labels_np, shuffle):
    """Build an endlessly repeating, batched tf.data pipeline.

    Args:
        images_np: image array; only the first ``n_data`` rows are used.
        labels_np: label array aligned row-for-row with ``images_np``.
        shuffle: shuffle the examples (buffer of 10000) when True.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches of
        ``batch_size`` elements, repeated forever.
    """
    dataset = tf.data.Dataset.from_tensor_slices((images_np[:n_data], labels_np[:n_data]))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    # batch -> repeat -> prefetch: prefetch goes last so whole batches (not
    # individual examples) are staged ahead of the consuming step.
    return dataset.batch(batch_size).repeat().prefetch(1)


train_dataset = make_dataset(train_images, train_labels, shuffle=True)
# The test split is only used to compute an averaged accuracy, so shuffling it
# would just cost time.
test_dataset = make_dataset(test_images, test_labels, shuffle=False)

In [None]:
# A single reinitializable iterator whose element structure is taken from
# train_dataset; any dataset with matching types/shapes can be attached to it
# through iterator.make_initializer(...).
dataset_types = train_dataset.output_types
dataset_shapes = train_dataset.output_shapes
iterator = tf.data.Iterator.from_structure(dataset_types, dataset_shapes)
next_batch = iterator.get_next()
images, labels = next_batch

In [None]:
# Running one of these ops attaches `iterator` to the corresponding split.
train_init, test_init = [iterator.make_initializer(split) for split in (train_dataset, test_dataset)]

In [None]:
# Convert pixels to float32 and resize every image to 224x224 for VGG-16.
images = tf.cast(images, dtype=tf.float32)
resized_images = tf.image.resize_images(images, size=[224, 224])

# Flatten labels to rank 1 (presumably from (batch, 1) -- TODO confirm), then
# one-hot encode over the 10 classes.
labels = tf.reshape(labels, [-1])
onehot_labels = tf.one_hot(labels, depth=10)

## Download the Pretrained VGG-16 Checkpoint

In [None]:
# Local directory that holds the downloaded and extracted checkpoint; it is
# created only when nothing exists at that path yet.
ckpt_dir = "ckpt"
already_there = os.path.exists(ckpt_dir)
if not already_there:
    os.makedirs(ckpt_dir)

In [None]:
# Fetch the VGG-16 checkpoint archive once. HTTPS protects the download from
# tampering in transit. The archive is written under a temporary ".part" name
# and renamed only after urlretrieve returns, so an interrupted download is not
# mistaken for a complete archive on the next run.
ckpt_url = "https://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"
zpath = os.path.join(ckpt_dir, "vgg_16_2016_08_28.tar.gz")
if not os.path.exists(zpath):
    print ("Downloading %s ..." % (zpath))
    partial_path = zpath + ".part"
    urllib.request.urlretrieve(ckpt_url, partial_path)
    os.rename(partial_path, zpath)
    print ("Done!")
else:
    print ("%s Already Exists" % (zpath))

## Extract the Checkpoint Archive

In [None]:
# Extract the archive into cpath. Extraction goes into a temporary directory
# that is renamed to cpath only after extractall succeeds, so a crash part-way
# through does not leave a half-populated cpath that the existence check would
# then skip. The `with` block closes the tar file even if extraction raises.
cpath = os.path.join(ckpt_dir, "vgg_16_2016_08_28")
if not os.path.exists(cpath):
    print ("Extracting %s ..." % (cpath))
    partial_dir = cpath + ".part"
    # NOTE(review): extractall trusts member paths inside the archive; fine for
    # this known checkpoint, but do not reuse it on untrusted tarballs.
    with tarfile.open(zpath, "r:gz") as tar:
        tar.extractall(path=partial_dir)
    os.rename(partial_dir, cpath)
    print ("Done!")
else:
    print ("%s Already Exists" % (cpath))

## Model

In [None]:
# Build VGG-16 with a 10-way output layer on mean-centred inputs.
# `is_training` defaults to True, which reproduces the previous behaviour
# (dropout active). Evaluation code can feed it False, e.g.
# sess.run(accuracy, {is_training: False}), to disable dropout.
is_training = tf.placeholder_with_default(True, shape=(), name="is_training")
with slim.arg_scope(vgg.vgg_arg_scope()):
    # Per-channel RGB means subtracted from every pixel (presumably the values
    # used when the checkpoint was trained -- TODO confirm).
    mean = tf.constant([123.68, 116.78, 103.94], dtype=tf.float32, shape=[1,1,1,3])
    im_centered = resized_images - mean
    logits, end_points = vgg.vgg_16(inputs=im_centered, num_classes=10, is_training=is_training)
    preds = tf.nn.softmax(logits, axis=-1)

In [None]:
# Display the softmax prediction tensor as a quick sanity check of the model graph.
preds

In [None]:
# Show every trainable variable created in the graph so far.
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

## Fine Tuning

In [None]:
# Mean softmax cross-entropy over the batch.
xent_per_example = tf.nn.softmax_cross_entropy_with_logits_v2(labels=onehot_labels, logits=logits)
cost = tf.reduce_mean(xent_per_example)

# Plain SGD. `optimizer` holds the training op returned by minimize(), not the
# optimizer object. (tf.train.AdamOptimizer is an alternative, but it adds slot
# variables that are not in the pretrained checkpoint.)
sgd = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
optimizer = sgd.minimize(cost)

In [None]:
# Fraction of the batch whose highest-scoring class equals the true class.
predicted_class = tf.argmax(logits, axis=1)
true_class = tf.argmax(onehot_labels, axis=1)
correct_prediction = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [None]:
# Allocate GPU memory on demand rather than reserving it all at session start.
gpu_options = tf.GPUOptions(allow_growth=True)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=session_config)

In [None]:
# Give every global variable in the graph its initial value.
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
ckpt_name = "vgg_16.ckpt"
ckpt_path = os.path.join(cpath, ckpt_name)

# Restore pretrained weights for every VGG-16 layer except fc8. fc8 is rebuilt
# with num_classes=10, so it keeps its fresh initialisation.
# Only *model* variables under vgg_16/ are restored. Taking all global
# variables would also pick up optimizer state (e.g. Adam slots and beta
# powers if the optimizer is swapped). The checkpoint does not contain that
# state, so saver.restore would fail.
exclude = ['vgg_16/fc8']
variables_to_restore = [
    var for var in slim.get_model_variables('vgg_16')
    if not var.op.name.startswith(tuple(exclude))
]
assert variables_to_restore, "no VGG-16 model variables found to restore"
saver = tf.train.Saver(variables_to_restore)

saver.restore(sess, ckpt_path)

In [None]:
# Train the model. Each epoch makes one optimisation pass over the training
# subset, then one accuracy-only pass over the test subset, and prints the
# per-epoch averages.
print('Learning started. It takes sometime.')
max_test_acc = 0.
# Whole batches per pass; both loops cover n_data examples.
total_batch = n_data // batch_size
total_batch_test = n_data // batch_size
for epoch in range(training_epochs):
    avg_cost = 0.
    avg_train_acc = 0.
    avg_test_acc = 0.

    sess.run(train_init)
    for i in range(total_batch):
        acc, c, _ = sess.run([accuracy, cost, optimizer])
        avg_cost += c / total_batch
        avg_train_acc += acc / total_batch
        if i % 100 == 0:
            # i batches of batch_size images have been consumed so far.
            print("{} Epoch : {} images were used for training".format(epoch+1, i*batch_size))

    # NOTE(review): unless the model's dropout is switched to inference mode
    # for this loop, test accuracy is measured with dropout active.
    print("Calculating test accuracy...")
    sess.run(test_init)
    for i in range(total_batch_test):
        acc = sess.run(accuracy)
        avg_test_acc += acc / total_batch_test

    # Keep the best per-epoch test accuracy seen so far.
    max_test_acc = max(max_test_acc, avg_test_acc)

    print('Epoch:', '{}'.format(epoch + 1), 'cost =', '{:.8f}'.format(avg_cost), 
          'train accuracy = ', '{:.4f}'.format(avg_train_acc), 
          'test accuracy = ', '{:.4f}'.format(avg_test_acc))


print('Learning Finished!')
print('Best test accuracy = ', '{:.4f}'.format(max_test_acc))