# Transfer learning using Inception-v3 network

In [None]:
#import sys
#sys.path.append("$HOME/models/research/slim/")

import os
import time

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image

import tensorflow as tf

# Short alias for TF-Slim, used below for arg scopes and model variables.
slim = tf.contrib.slim

# Session configuration passed to tf.Session later on; `allow_growth=True`
# requests that GPU memory be allocated on demand rather than all at once.
sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# Restrict this process to the GPU with index 0.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# Directory that holds the flowers TFRecord shards.
data_path = '../data/flowers/'
# os.listdir() returns entries in arbitrary order; sort them so the shard
# order (and therefore the order in which records are read) is reproducible.
tfrecord_filenames = sorted(name for name in os.listdir(data_path) if 'tfrecord' in name)
# Split the shards into train / validation sets by substring of the file name.
train_data_filenames = [os.path.join(data_path, name) for name in tfrecord_filenames if 'train' in name]
validation_data_filenames = [os.path.join(data_path, name) for name in tfrecord_filenames if 'validation' in name]
print(train_data_filenames)
print(validation_data_filenames)

In [None]:
# Notebook cell output: show every TFRecord file name found in `data_path`.
tfrecord_filenames

In [None]:
def _parse_function(example_proto):
  """Decode one serialized example into an (image, label) pair.

  Args:
    example_proto (0-rank string Tensor): serialized example read from a TFRecord file

  Returns:
    image (3-rank Tensor): JPEG-decoded 3-channel image converted to tf.float32
    label (0-rank Tensor): class label cast to tf.int32
  """
  def scalar_feature(dtype, default):
    # Every field in the record is a single (scalar) value with a default.
    return tf.FixedLenFeature((), dtype, default_value=default)

  feature_spec = {
      'image/encoded': scalar_feature(tf.string, ""),
      'image/format': scalar_feature(tf.string, ""),
      'image/class/label': scalar_feature(tf.int64, 0),
      'image/height': scalar_feature(tf.int64, 0),
      'image/width': scalar_feature(tf.int64, 0),
  }
  parsed = tf.parse_single_example(example_proto, feature_spec)

  # Only the encoded bytes and the label are used; the other fields are parsed
  # but not returned.
  decoded = tf.image.decode_jpeg(parsed["image/encoded"], channels=3)
  image = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
  label = tf.cast(parsed["image/class/label"], dtype=tf.int32)
  return image, label

In [None]:
def _preprocessing(image, label):
  """Randomly augment one training image and rescale it to [-1, 1].

  The augmentation recipe follows the Inception preprocessing code.

  Args:
    image (3-rank Tensor): [?, ?, 3] for flower data
    label (0-rank Tensor): scalar value of corresponding image

  Returns:
    image (3-rank Tensor): [299, 299, 3] augmented image
    label (0-rank Tensor): the label, passed through unchanged
  """
  # Bring the image to the fixed 299x299 input size by cropping or padding.
  augmented = tf.image.resize_image_with_crop_or_pad(image, 299, 299)

  # Random geometric and color distortions (order matters for the result).
  augmented = tf.image.random_flip_left_right(augmented)
  augmented = tf.image.random_brightness(augmented, max_delta=32./255.)
  augmented = tf.image.random_saturation(augmented, lower=0.5, upper=1.5)
  augmented = tf.image.random_hue(augmented, max_delta=0.2)
  augmented = tf.image.random_contrast(augmented, lower=0.5, upper=1.5)

  # The distortions can leave the [0, 1] range; clamp before rescaling.
  augmented = tf.clip_by_value(augmented, 0.0, 1.0)

  # Map [0, 1] onto [-1, 1].
  augmented = (augmented - 0.5) * 2.0
  return augmented, label

In [None]:
def _central_crop(image, label):
  """Deterministic preprocessing used for the validation data.

  Keeps the central 87.5% of the image, resizes it to the network input size
  and rescales the pixel values to [-1, 1]. No random augmentation is applied.

  Args:
    image (3-rank Tensor): [?, ?, 3] for flower data
    label (0-rank Tensor): scalar value of corresponding image

  Returns:
    image (3-rank Tensor): [299, 299, 3] cropped and resized image
    label (0-rank Tensor): the label, passed through unchanged
  """
  cropped = tf.image.central_crop(image, central_fraction=0.875)
  resized = tf.image.resize_images(cropped, [299, 299])
  # Map [0, 1] onto [-1, 1].
  rescaled = (resized - 0.5) * 2.0
  return rescaled, label

In [None]:
batch_size = 32

# for train
train_dataset = tf.data.TFRecordDataset(train_data_filenames)
# Shuffle the serialized records *before* decoding them. The shuffle buffer
# stores whatever elements reach it, so shuffling after the map calls would
# keep up to 10000 decoded 299x299x3 float32 images (~1 MB each) in memory
# instead of the much smaller encoded records.
train_dataset = train_dataset.shuffle(buffer_size=10000)
train_dataset = train_dataset.map(_parse_function)
train_dataset = train_dataset.map(_preprocessing)
train_dataset = train_dataset.batch(batch_size=batch_size)
print(train_dataset)

# for validation
# No shuffling and no random augmentation: evaluation should be deterministic.
validation_dataset = tf.data.TFRecordDataset(validation_data_filenames)
validation_dataset = validation_dataset.map(_parse_function)
validation_dataset = validation_dataset.map(_central_crop)
# Validation uses its own, larger batch size than training.
validation_dataset = validation_dataset.batch(batch_size=350)
print(validation_dataset)

In [None]:
# Feedable iterator: the string fed into `handle` selects which dataset
# iterator produces the next batch.
# `output_shapes` of tf.data.Iterator.from_string_handle defaults to None,
# but it is better to always pass it explicitly.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle,
    output_types=train_dataset.output_types,
    output_shapes=train_dataset.output_shapes)
inputs, labels = iterator.get_next()

### Load an Inception-v3 graph

In [None]:
from nets import inception_v3

In [None]:
# Boolean fed at run time to switch the network between training and
# inference behaviour.
is_training = tf.placeholder(tf.bool)

# Build Inception-v3 with a 5-way classification head; the second return
# value (end points) is not used.
inception_arg_scope = inception_v3.inception_v3_arg_scope(weight_decay=0.00004)
with slim.arg_scope(inception_arg_scope):
  logits, _ = inception_v3.inception_v3(inputs,
                                        num_classes=5,
                                        is_training=is_training)

In [None]:
# Print the logits tensor as a quick sanity check of the model output.
print(logits)

In [None]:
def _get_variables_to_train(trainable_scopes):
  """Collect the trainable variables that live under the given scopes.

  Args:
    trainable_scopes (str or None): comma-separated scope names; surrounding
      whitespace of each name is ignored. None selects every trainable variable.

  Returns:
    A list of variables to train by the optimizer.
  """
  if trainable_scopes is None:
    return tf.trainable_variables()

  selected = []
  for raw_scope in trainable_scopes.split(','):
    selected += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  raw_scope.strip())
  return selected

In [None]:
# Scopes whose variables are selected for training.
trainable_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits'
variables_to_train = _get_variables_to_train(trainable_scopes)

# List the selected variables for inspection.
for var in variables_to_train:
  print(var)

In [None]:
# Labels are integer class ids (not one-hot vectors), so the sparse variant
# of softmax cross entropy is used.
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

# Sum of all terms registered in the graph's REGULARIZATION_LOSSES collection.
l2_regualrization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

# Objective minimized by the optimizer: data term + regularization term.
with tf.name_scope('total_loss'):
  total_loss = cross_entropy + l2_regualrization_loss

In [None]:
# Batch normalization update
batchnorm_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

# Add dependency to compute batchnorm_updates.
with tf.control_dependencies(batchnorm_update_ops):
  # Train exactly the variables selected through `trainable_scopes`. The
  # previous `tf.trainable_variables()[-10:]` slice depended on variable
  # creation order and ignored the `variables_to_train` list computed above.
  train_step = tf.train.RMSPropOptimizer(0.01).minimize(total_loss,
                                                        var_list=variables_to_train)

In [None]:
# Directory that receives the graph definition and the training summaries.
graph_location = 'graphs/11.transfer.learning.with.inception_v3'
print('Saving graph to: %s' % graph_location)
# Passing the graph to the constructor records it right away.
train_writer = tf.summary.FileWriter(graph_location,
                                     graph=tf.get_default_graph())

In [None]:
with tf.name_scope('summaries'):
  # Scalar summaries for each component of the loss.
  scalar_summaries = [
      ('loss/cross_entropy', cross_entropy),
      ('loss/l2_regualrization_loss', l2_regualrization_loss),
      ('loss/total_loss', total_loss),
  ]
  for tag, tensor in scalar_summaries:
    tf.summary.scalar(tag, tensor)

  # Input images of the current batch.
  tf.summary.image('images', inputs)

  # One histogram per trainable variable, named after the variable's op.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Single op that evaluates every summary defined so far.
  summary_op = tf.summary.merge_all()

In [None]:
def _get_init_fn(checkpoint_exclude_scopes):
  """Select the model variables that should be restored from a checkpoint.

  Despite its name, this helper does not build an init function; it returns
  the list of variables to hand to a saver.

  Args:
    checkpoint_exclude_scopes (str or None): comma-separated name prefixes;
      surrounding whitespace of each prefix is ignored. A model variable whose
      op name starts with any of them is left out. An empty or None value
      excludes nothing.

  Returns:
    A list of slim model variables that are not excluded.
  """
  excluded_prefixes = ()
  if checkpoint_exclude_scopes:
    excluded_prefixes = tuple(
        scope.strip() for scope in checkpoint_exclude_scopes.split(','))

  # str.startswith accepts a tuple of prefixes; an empty tuple never matches,
  # so every model variable is kept when nothing is excluded.
  return [var for var in slim.get_model_variables()
          if not var.op.name.startswith(excluded_prefixes)]

In [None]:
# Name prefixes of the variables that must NOT be loaded from the checkpoint.
checkpoint_exclude_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits'
variables_to_restore = _get_init_fn(checkpoint_exclude_scopes)
# List the variables that will be restored, for inspection.
for var in variables_to_restore:
  print(var)

### Download the Inception-v3 checkpoint: 

```
$ CHECKPOINT_DIR='../checkpoints'
$ mkdir ${CHECKPOINT_DIR}
$ wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
$ tar -xvf inception_v3_2016_08_28.tar.gz
$ mv inception_v3.ckpt ${CHECKPOINT_DIR}
$ rm inception_v3_2016_08_28.tar.gz
```

In [None]:
# Download the Inception-v3 checkpoint: 
# if you already have a inception_v3.ckpt then skip and comment below commands
#!mkdir ../checkpoints
#!wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
#!tar -xvf inception_v3_2016_08_28.tar.gz
#!mv inception_v3.ckpt ../checkpoints
#!rm inception_v3_2016_08_28.tar.gz
#print('done')

### Restore Inception_v3 weights using `tf.train.Saver.restore`

In [None]:
# Saver limited to the variables selected for restoring from the checkpoint.
saver = tf.train.Saver(variables_to_restore)

sess = tf.Session(config=sess_config)
# Initialize every variable first; the restore below then overwrites the
# subset handled by `saver` with the checkpoint values.
sess.run(tf.global_variables_initializer())
# use saver object to load variables from the saved model
saver.restore(sess, "../checkpoints/inception_v3.ckpt")

# train_iterator
# The string handle is what gets fed into the `handle` placeholder to select
# this iterator as the input source.
train_iterator = train_dataset.make_initializable_iterator()
train_handle = sess.run(train_iterator.string_handle())

# Train
max_epochs = 10
step = 0  # global step counter, carried across epochs
for epochs in range(max_epochs):
  # Restart the training iterator at the beginning of every epoch.
  sess.run(train_iterator.initializer)

  # Run training steps until the iterator signals the end of the epoch.
  while True:
    try:
      start_time = time.time()
      _, loss = sess.run([train_step, total_loss],
                         feed_dict={handle: train_handle,
                                    is_training: True})
      # Log progress every 10 steps; throughput is computed from `batch_size`.
      if step % 10 == 0:
        duration = time.time() - start_time
        examples_per_sec = batch_size / float(duration)
        print("epochs: {}, step: {}, loss: {:g}, ({:.2f} examples/sec; {:.3f} sec/batch)".format(epochs, step, loss, examples_per_sec, duration))
        
      # Write summaries every 2000 steps.
      # NOTE(review): this separate run evaluates `inputs` again and therefore
      # pulls another batch from the training iterator; that batch is not used
      # for a training step.
      if step % 2000 == 0:
        # summary
        summary_str = sess.run(summary_op, feed_dict={handle: train_handle, is_training: False})
        train_writer.add_summary(summary_str, global_step=step)
        
      step += 1
      #if step > 100:
      #  break

    except tf.errors.OutOfRangeError:
      # Raised when the iterator has no more batches: the epoch is over.
      print("End of dataset")  # ==> "End of dataset"
      break

train_writer.close()
print("training done!")

In [None]:
# validation_iterator
# Create the validation iterator, fetch the string handle that is fed into the
# `handle` placeholder, and position the iterator at the start of the data.
validation_iterator = validation_dataset.make_initializable_iterator()
validation_handle = sess.run(validation_iterator.string_handle())
sess.run(validation_iterator.initializer)

In [None]:
# Streaming accuracy: `acc_op` accumulates correct/total counts in local
# variables, `accuracy` reads the running result.
accuracy, acc_op = tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, 1), name='accuracy')
# The metric's counters are local variables and must be initialized (reset).
sess.run(tf.local_variables_initializer())

# Start from the first validation batch so this cell can be re-run safely.
sess.run(validation_iterator.initializer)

# Accumulate over every validation batch rather than only the first one, so
# the result stays correct when the validation set is larger than one batch.
while True:
  try:
    sess.run(acc_op, feed_dict={handle: validation_handle, is_training: False})
  except tf.errors.OutOfRangeError:
    break

print("test accuracy:", sess.run(accuracy))