# Transfer Learning in Tensorflow with Pre-trained Inception-Resnet Model from TF-Slim

In this project, we use a pre-trained Inception-ResNet model from TensorFlow-Slim (TF-Slim)
to classify images from a different dataset. TF-Slim is a lightweight wrapper for TensorFlow that makes it particularly easy to use pre-trained models for image classification,
among other things.
<br>
<br> You can check out the official documentation of TF-Slim here (this tutorial is adapted from the documentation): https://github.com/tensorflow/models/blob/master/slim/README.md
<br>As well as a slim walk-through with examples: https://github.com/tensorflow/models/blob/master/slim/slim_walkthrough.ipynb

<br>To adapt a pre-trained model to your own use case, remove the final pre-softmax (logits) layer and replace it
with a new layer whose number of outputs matches the number of classes/labels in your (smaller) dataset. We then fine-tune the model on this new dataset.
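
<br>As a rough picture of what this means for the checkpoint, the sketch below filters a dictionary of variable names by scope prefix, which is roughly how excluding scopes from restoration behaves. The function name `variables_to_restore`, the variable names, and the shapes are illustrative assumptions; this is a plain-Python simplification and not TF-Slim's implementation.

```python
def variables_to_restore(variable_shapes, exclude_scopes):
    """Keep every variable whose name does not start with an excluded scope."""
    return {
        name: shape
        for name, shape in variable_shapes.items()
        if not any(name.startswith(scope) for scope in exclude_scopes)
    }


# Illustrative names and shapes only. The last dimension of the logits
# weights is the number of classes the checkpoint was trained on.
pretrained = {
    "InceptionResnetV2/Conv2d_1a_3x3/weights": (3, 3, 3, 32),
    "InceptionResnetV2/Logits/Logits/weights": (1536, 1001),
    "InceptionResnetV2/AuxLogits/Logits/weights": (768, 1001),
}

restored = variables_to_restore(
    pretrained,
    ["InceptionResnetV2/Logits", "InceptionResnetV2/AuxLogits"],
)
print(sorted(restored))
```

Only the feature-extraction weights survive the filter. The logits layers are created fresh with one output per class in the new dataset and are trained from scratch.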

<br>Before you write any code, be sure to install TF-Slim and the corresponding image models library.
The instructions are found in the official documentation: https://github.com/tensorflow/models/blob/master/slim/README.md
<br><br>Note: install the models library first, and then place slim inside that models directory.
<br>I found the following resource clearer than the documentation for figuring out how to create and structure the directories:
https://hackaday.io/project/20448-elephant-ai/log/56896-retraining-tensorflow-inception-v3-using-tensorflow-slim-part-1

<br>I also assume you've converted your data into TFRecord format; see my other post on how to do that.

<br> Okay! Now we can code!
First, let us import all the things we will need.

Note that the imports "inception_resnet_v2" and "inception_preprocessing" come from the TF-Slim models library.

In [None]:
import os
import tensorflow as tf
import inception_resnet_v2  #From: https://github.com/tensorflow/models/blob/master/slim/nets/inception_resnet_v2.py
import inception_preprocessing # From: https://github.com/tensorflow/models/blob/master/slim/preprocessing/inception_preprocessing.py
from tensorflow.contrib import slim

## Locate Directories

The next step is to locate the relevant directories (incl. checkpoint), and to provide some information about your dataset. 

In [None]:
#Dataset directory where the tfrecord files are located
dataset_dir = "/Users/charujaiswal/PycharmProjects/models/slim/DATASET/flowers"
# I created a directory within slim where I placed the converted tfrecord files. 

#Directory for the log files. If it doesn't exist yet, it will be created in the run function below.
log_dir = './log'

#Location of checkpoint file (see hackaday.io as mentioned above for how to structure the directory)
checkpoint_file = "/Users/charujaiswal/PycharmProjects/models/slim/CHECKPOINTDIR/inception_resnet_v2_2016_08_30.ckpt"

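#Locate the labels file and read it into a {class index: class name} dictionary.
#Assumes one "index:name" entry per line (e.g. "0:daisy"); adjust the parsing if your file differs.
labels_file = '/Users/charujaiswal/PycharmProjects/models/slim/DATASET/flowers/labels.txt'
labels_to_name = {}
with open(labels_file, 'r') as labels:
    for line in labels:
        line = line.strip()
        if line:
            label, string_name = line.split(':', 1)
            labels_to_name[int(label)] = string_name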

#The image size you're resizing your images to-- here we use the default inception size of 299.
image_size = inception_resnet_v2.inception_resnet_v2.default_image_size # 299

#Number of classes to predict:
num_classes = 5 # the flowers dataset has 5 classes

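#Batch size
batch_size = 10

#Number of epochs
num_epochs = 1

#Learning rate for the optimizer (assumed starting value; tune it for your dataset)
lr = 0.0002

#File name pattern of the tfrecord files; '%s' is replaced by the split name.
#Assumed naming; adjust it to match the names of your own tfrecord files.
file_pattern = 'flowers_%s_*.tfrecord'

#Descriptions of the items that the dataset provides
items_to_descriptions = {
    'image': 'A 3-channel RGB image.',
    'label': 'An integer class label between 0 and num_classes - 1.',
}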

Slim's source code includes an informative file called "learning.py", located here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py

<br> It contains a useful general overview of how to fine-tune a model from a checkpoint. We will loosely follow this process, which I've reproduced below:

In [None]:
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)
checkpoint_path = '/path/to/old_model_checkpoint'
# Specify the variables to restore via a list of inclusion or exclusion
# patterns:
variables_to_restore = slim.get_variables_to_restore(
  include=["conv"], exclude=["fc8", "fc9"])
# or
variables_to_restore = slim.get_variables_to_restore(exclude=["conv"])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
  checkpoint_path, variables_to_restore)
# Create an initial assignment function.
def InitAssignFn(sess):
  sess.run(init_assign_op, init_feed_dict)
# Run training.
slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)

## Reading data with TF-Slim: Dataset and DatasetDataProvider 
Before we can write the train_op and the rest of the training code, we need to work out how to read the data. Reading data in TF-Slim involves two parts: a Dataset (a descriptor of the dataset) and a DatasetDataProvider (which actually reads the data).

<br> We prepared the data behind the Dataset when we converted it into TFRecord files. DatasetDataProvider is a class that reads the data, and it can be configured to read in various ways, including from sharded data.
The first step in feeding the data correctly is to get a Dataset tuple with instructions for reading our data. We will write a function for this, adapted from the function "get_split" in the file models/slim/datasets/flowers.py.
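
<br>To make the role of the file pattern concrete, here is a minimal sketch of how a pattern containing '%s' expands into a per-split pattern and selects the matching shards. The helper `matching_shards`, the directory name, and the file names are hypothetical and only for illustration; the sketch uses the standard library instead of TensorFlow.

```python
import fnmatch
import os


def matching_shards(file_names, dataset_dir, file_pattern, split_name):
    """Return the full paths of the files that belong to one split."""
    # Insert the split name into the '%s' placeholder of the pattern.
    pattern = file_pattern % split_name
    return [
        os.path.join(dataset_dir, name)
        for name in sorted(file_names)
        if fnmatch.fnmatch(name, pattern)
    ]


# Hypothetical directory listing, used only for illustration.
example_files = [
    "flowers_train_00000-of-00002.tfrecord",
    "flowers_train_00001-of-00002.tfrecord",
    "flowers_validation_00000-of-00001.tfrecord",
    "labels.txt",
]
train_shards = matching_shards(
    example_files, "data", "flowers_%s_*.tfrecord", "train")
print(train_shards)
```

With these assumed names, the 'train' split picks up both training shards and ignores the validation shard and the labels file.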


In [None]:
########### LOADING THE DATASET ##########

# Here we create a function that builds a Dataset object describing the TFRecord files, which are later read in parallel through a queue.
def get_split(split_name, dataset_dir, file_pattern=file_pattern):

    """Gets a dataset tuple with instructions for reading the data.
    Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.

    Returns:
    A `Dataset` namedtuple.
    Raises:
    ValueError: if `split_name` is not a valid train/validation split.
    """
    # First check whether the split_name is train or validation
    if split_name not in ['train', 'validation']:
      raise ValueError('The split_name %s is not recognized.' % (split_name))

    # Create a path to locate the tfrecord files, using file_pattern of the name
    file_pattern_path = os.path.join(dataset_dir, file_pattern % (split_name))

    # Create a reader that outputs the records from a tfrecord file
    reader = tf.TFRecordReader

    #Count the number of samples in the tfrecord files that match the file pattern
    num_samples = 0
    tfrecords_to_count = tf.gfile.Glob(file_pattern_path)
    for tfrecord_file in tfrecords_to_count:
        for record in tf.python_io.tf_record_iterator(tfrecord_file):
            num_samples += 1

    # Next we create the keys_to_features and items_to_handlers dictionaries and the decoder, which are used later by the
    # DatasetDataProvider object to decode tf.Example protos into tensors.
    # (More detail on the DatasetDataProvider later.)

    # Create the keys_to_features dictionary for the decoder
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    # Create the items_to_handlers dictionary for the decoder.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    # Start to create the decoder
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    # Dictionary that maps each class index to its class name
    labels_to_name_dict = labels_to_name

    # Actually create the dataset
    dataset = slim.dataset.Dataset(
        data_sources=file_pattern_path,
        decoder=decoder,
        reader=reader,
        num_readers=4,
        num_samples=num_samples,
        num_classes=num_classes,
        labels_to_name=labels_to_name_dict,
        items_to_descriptions=items_to_descriptions)

    return dataset

## DatasetDataProvider and How to Load Data

The next thing we will do is load tensors from our dataset into batches for training, via the DatasetDataProvider object. We will create a function called "load_batch" that will 1) preprocess our images to fit the format required by the Inception network, 2) return the processed images as batches of tensors, 3) return the corresponding labels and 4) return the raw images in case you want to visualize them.
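
<br>One part of that preprocessing is easy to show in isolation: Inception-style preprocessing ends by rescaling pixel values to the range [-1, 1]. The sketch below covers only that final rescaling on a tiny made-up image, and it leaves out the cropping, distortion, and resizing that the real preprocessing also performs. The helper names `scale_pixel` and `scale_image` are hypothetical.

```python
def scale_pixel(value):
    """Map an 8-bit pixel value in [0, 255] to the range [-1, 1]."""
    return (value / 255.0 - 0.5) * 2.0


def scale_image(image):
    """Apply the rescaling to every channel value of a nested-list image."""
    return [[[scale_pixel(v) for v in pixel] for pixel in row] for row in image]


# A made-up 1x2 RGB image: one black pixel and one white pixel.
tiny_image = [[[0, 0, 0], [255, 255, 255]]]
scaled = scale_image(tiny_image)
print(scaled)
```

Black maps to -1.0 and white maps to 1.0, which is why the preprocessed images look wrong if you plot them directly. The raw images returned by load_batch are the ones to use for visualization.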


In [None]:
def load_batch(dataset, batch_size=batch_size, height=image_size, width=image_size, is_training=True):
    """Loads a single batch of data for training.

    Args:
      dataset: The dataset to load, created in the get_split function.
      batch_size: The number of images in the batch.
      height: The height (int) that each image is resized to during preprocessing.
      width: The width (int) that each image is resized to during preprocessing.
      is_training: Whether we are currently training or evaluating.

    Returns:
      images: A Tensor of size [batch_size, height, width, 3] that contains one batch of preprocessed image samples.
      raw_images: A Tensor of size [batch_size, height, width, 3], image samples that can be used for visualization.
      labels: A Tensor of size [batch_size], whose values range between 0 and dataset.num_classes - 1 (not yet one-hot encoded).
    """

    # First create the data_provider object
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        common_queue_capacity=32,
        common_queue_min=8)

    # Obtain the raw image using the get method
    raw_image, label = data_provider.get(['image', 'label'])

    # Perform the correct preprocessing for this image depending if it is training or evaluating
    image = inception_preprocessing.preprocess_image(raw_image, height, width, is_training)

    # Preprocess the image for display purposes.
    raw_image = tf.expand_dims(raw_image, 0)
    raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
    raw_image = tf.squeeze(raw_image)

    # Batch up the images by enqueueing the tensors internally in a FIFO queue and dequeueing many elements with tf.train.batch.
    images, raw_images, labels = tf.train.batch(
        [image, raw_image, label],
        batch_size=batch_size,
        num_threads=4,
        capacity=4 * batch_size,
        allow_smaller_final_batch=True)

    return images, raw_images, labels

## Create Graph

<br> Now we will create the graph inside a run function. We will also create a log directory, which is useful if you want to visualize the training in TensorBoard.


In [None]:
def run():
    # Create the log directory here. It holds the checkpoints and summaries that let you monitor training progress in real time.
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    # ======================= TRAINING =========================
    # Construct the graph and build our model
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)  # Sets the threshold for what messages will be logged, in this case it is set to 'INFO'

        # Get dataset and load a batch
        dataset = get_split('train', dataset_dir, file_pattern=file_pattern)
        images, _, labels = load_batch(dataset, batch_size=batch_size)


        #Create the model, use the default arg scope to configure the batch norm parameters.
        with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
            logits, end_points = inception_resnet_v2.inception_resnet_v2(images, num_classes=dataset.num_classes, is_training=True)

        #Scopes that you want to exclude for restoration, from the checkpoint
        exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)

        #One-hot encode the labels
        one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

        #Specify the loss function;
        # slim.losses.softmax_cross_entropy(logits, one_hot_labels)
        # total_loss = slim.losses.get_total_loss()
        loss = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels, logits=logits)
        total_loss = tf.losses.get_total_loss()  # obtain the regularization losses as well

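        # Define the metrics that you want to track. Adding the update op to UPDATE_OPS lets
        # create_train_op run it with every training step; otherwise the accuracy is never updated.
        predictions = tf.argmax(end_points['Predictions'], 1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, accuracy_update)

        # Specify the optimizer and create the train op:
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_op = slim.learning.create_train_op(total_loss, optimizer)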

        # Create some summaries to visualize the training process:
        tf.summary.scalar('losses/Total_Loss', total_loss)
        tf.summary.scalar('accuracy', accuracy)


        #Now we will restore the variables from the checkpoint file, skipping the excluded logits scopes
        init_fn = slim.assign_from_checkpoint_fn(checkpoint_file, variables_to_restore)


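        # number_of_steps counts minibatch updates, so convert the number of epochs into steps
        num_batches_per_epoch = (dataset.num_samples + batch_size - 1) // batch_size
        num_steps = num_epochs * num_batches_per_epoch

        # Run the training inside a session:
        final_loss = slim.learning.train(
            train_op,
            logdir=log_dir,
            init_fn=init_fn,
            number_of_steps=num_steps)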


    print('Finished training. Last batch loss %f' % final_loss)



if __name__ == '__main__':
    run()


Conveniently, the slim.learning.train function used above does the following:
1) On each iteration, it evaluates the train_op, which updates the parameters by applying the optimizer to the current minibatch. It also updates the global_step.
2) It periodically saves a model checkpoint in the specified directory. This is useful in case your machine crashes: you can then simply restart from the latest checkpoint.
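
<br>Because each iteration processes one minibatch, the `number_of_steps` argument counts minibatch updates, not passes over the data. A small sketch of the conversion from epochs to steps follows. The helper name `steps_for_epochs` and the sample counts are made up for illustration.

```python
def steps_for_epochs(num_samples, batch_size, num_epochs):
    """Number of minibatch updates needed to see the data num_epochs times."""
    # Round up so that a final, smaller batch still counts as one step.
    batches_per_epoch = (num_samples + batch_size - 1) // batch_size
    return batches_per_epoch * num_epochs


# Made-up dataset sizes, for illustration only.
print(steps_for_epochs(num_samples=1000, batch_size=10, num_epochs=1))
print(steps_for_epochs(num_samples=1005, batch_size=10, num_epochs=2))
```

With 1000 samples and a batch size of 10, one epoch is 100 steps. Passing the number of epochs directly as `number_of_steps` would therefore stop training after only a handful of minibatches.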

