# Transfer leaning using Inception-v3 network

In [1]:
#import sys
#sys.path.append("$HOME/models/research/slim/")

import os
import time

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image

import tensorflow as tf

slim = tf.contrib.slim

sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

In [2]:
data_path = '../data/flowers/'
tfrecord_filenames = os.listdir(data_path)
train_data_filenames = [os.path.join(data_path, name) for name in tfrecord_filenames if 'train' in name]
validation_data_filenames = [os.path.join(data_path, name) for name in tfrecord_filenames if 'validation' in name]
print(train_data_filenames)
print(validation_data_filenames)

['../data/flowers/flowers_train_00003-of-00005.tfrecord', '../data/flowers/flowers_train_00002-of-00005.tfrecord', '../data/flowers/flowers_train_00001-of-00005.tfrecord', '../data/flowers/flowers_train_00004-of-00005.tfrecord', '../data/flowers/flowers_train_00000-of-00005.tfrecord']
['../data/flowers/flowers_validation_00004-of-00005.tfrecord', '../data/flowers/flowers_validation_00001-of-00005.tfrecord', '../data/flowers/flowers_validation_00000-of-00005.tfrecord', '../data/flowers/flowers_validation_00003-of-00005.tfrecord', '../data/flowers/flowers_validation_00002-of-00005.tfrecord']


In [3]:
# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=""),
              'image/format': tf.FixedLenFeature((), tf.string, default_value=""),
              'image/class/label': tf.FixedLenFeature((), tf.int64, default_value=0),
              'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
              'image/width': tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  image = tf.image.decode_jpeg(parsed_features["image/encoded"], channels=3)
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  label = tf.cast(parsed_features["image/class/label"], dtype=tf.int32)
  return image, label

In [4]:
def _preprocessing(image, label):
  """data augmentation function for training
  augmentation method is borrowed by inception code
  
  Args:
    image (3-rank Tensor): [?, ?, 3] for flower data
    label (0-rank Tensor): scalar value of corresponding image
    
  Returns:
    image (3-rank Tensor): [299, 299, 3] image transformed
    label (0-rank Tensor): scalar value of corresponding image
  """
  image = tf.image.resize_image_with_crop_or_pad(image, 299, 299)
  image = tf.image.random_flip_left_right(image)
  image = tf.image.random_brightness(image, max_delta=32./255.)
  image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
  image = tf.image.random_hue(image, max_delta=0.2)
  image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
  image = tf.clip_by_value(image, 0.0, 1.0)
  # Finally, rescale to [-1, 1] instead of [0, 1)
  image = tf.subtract(image, 0.5)
  image = tf.multiply(image, 2.0)
  return image, label

In [5]:
def _central_crop(image, label):
  """data augmentation function for training
  augmentation method is borrowed by inception code
  
  Args:
    image (3-rank Tensor): [?, ?, 3] for flower data
    label (0-rank Tensor): scalar value of corresponding image
    
  Returns:
    image (3-rank Tensor): [299, 299, 3] image transformed
    label (0-rank Tensor): scalar value of corresponding image
  """
  image = tf.image.central_crop(image, central_fraction=0.875)
  image = tf.image.resize_images(image, [299, 299])
  # Finally, rescale to [-1, 1] instead of [0, 1)
  image = tf.subtract(image, 0.5)
  image = tf.multiply(image, 2.0)
  return image, label

In [6]:
batch_size = 32

# for train
train_dataset = tf.data.TFRecordDataset(train_data_filenames)
train_dataset = train_dataset.map(_parse_function)
train_dataset = train_dataset.map(_preprocessing)
train_dataset = train_dataset.shuffle(buffer_size = 10000)
train_dataset = train_dataset.batch(batch_size = batch_size)
print(train_dataset)

# for validation
validation_dataset = tf.data.TFRecordDataset(validation_data_filenames)
validation_dataset = validation_dataset.map(_parse_function)
validation_dataset = validation_dataset.map(_central_crop)
validation_dataset = validation_dataset.shuffle(buffer_size = 10000)
validation_dataset = validation_dataset.batch(batch_size = batch_size)
print(validation_dataset)

<BatchDataset shapes: ((?, 299, 299, 3), (?,)), types: (tf.float32, tf.int32)>
<BatchDataset shapes: ((?, 299, 299, 3), (?,)), types: (tf.float32, tf.int32)>


In [7]:
# tf.data.Iterator.from_string_handle의 output_shapes는 default = None이지만 꼭 값을 넣는 게 좋음
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
                                               train_dataset.output_types,
                                               train_dataset.output_shapes)
inputs, labels = iterator.get_next()

### Load a Inception-v3 graph

In [8]:
from nets import inception_v3

In [9]:
is_training = tf.placeholder(tf.bool)
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
  logits, _ = inception_v3.inception_v3(inputs, num_classes=5, is_training=is_training)

In [10]:
print(logits)

Tensor("InceptionV3/Logits/SpatialSqueeze:0", shape=(?, 5), dtype=float32)


In [11]:
#y_one_hot = tf.one_hot(y, depth=10)
#cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=y_one_hot, logits=y_pred)
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

l2_regualrization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

with tf.name_scope('total_loss'):
  total_loss = cross_entropy + l2_regualrization_loss

In [12]:
# Batch normalization update
batchnorm_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

# Add dependency to compute batchnorm_updates.
with tf.control_dependencies(batchnorm_update_ops):
  train_step = tf.train.RMSPropOptimizer(1e-4).minimize(total_loss)

In [13]:
with tf.Session() as sess:
  writer = tf.summary.FileWriter("./graphs/code14_transfer_learning", sess.graph)
  writer.close()

### Download the Inception-v3 checkpoint: 

```
$ CHECKPOINT_DIR='checkpoints'
$ mkdir ${CHECKPOINT_DIR}
$ wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
$ tar -xvf inception_v3_2016_08_28.tar.gz
$ mv inception_v3_2016_08_28.tar.gz ${CHECKPOINT_DIR}
$ rm inception_v3_2016_08_28.tar.gz
```

### Restore Inception_v3 weights using `tf.saver.restore`

In [14]:
saver = tf.train.Saver(tf.trainable_variables()[:-8])

sess = tf.Session(config=sess_config)
sess.run(tf.global_variables_initializer())
# use saver object to load variables from the saved model
saver.restore(sess, "../checkpoints/inception_v3.ckpt")

# train_iterator
train_iterator = train_dataset.make_initializable_iterator()
train_handle = sess.run(train_iterator.string_handle())

# Train
max_epochs = 1
step = 0
for epochs in range(max_epochs):
  sess.run(train_iterator.initializer)

  while True:
    try:
      start_time = time.time()
      _, loss = sess.run([train_step, total_loss],
                         feed_dict={handle: train_handle,
                                    is_training: True})
      if step % 1 == 0:
        duration = time.time() - start_time
        examples_per_sec = batch_size / float(duration)
        print("epochs: {}, step: {}, loss: {:g}, ({:.2f} examples/sec; {:.3f} sec/batch)".format(epochs, step, loss, examples_per_sec, duration))
        
      #if step % 200 == 0:
      #  # summary
      #  summary_str = sess.run(summary_op, feed_dict={handle: train_handle, is_training: False})
      #  train_writer.add_summary(summary_str, global_step=step)
        
      step += 1
      #if step > 100:
      #  break

    except tf.errors.OutOfRangeError:
      print("End of dataset")  # ==> "End of dataset"

train_writer.close()
print("training done!")

INFO:tensorflow:Restoring parameters from ../checkpoints/inception_v3.ckpt
epochs: 0, step: 0, loss: 1.83767, (0.34 examples/sec; 92.832 sec/batch)
epochs: 0, step: 1, loss: 1.88018, (0.92 examples/sec; 34.608 sec/batch)


KeyboardInterrupt: 