# <div class="alert alert-block alert-info" style="border-width:4px">Resnet model training on CIFAR10 odd numbered classes on SBrain</div>

# Introduction

THIS NOTEBOOK CAN TAKE A LONG TIME TO RUN, SO IT IS MAINLY FOR REFERENCE. EXPECT A LONG WAIT IF YOU DO TRY TO RUN IT.

In this notebook, we will build a ResNet model and train it on the CIFAR-10 dataset, but only on the odd numbered classes. When we do transfer learning in the <a href="2_TransferLearning.ipynb">TransferLearning Notebook</a>, this model will serve as the model that we transfer from.

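Below, we define an input_function to be used in SBrain. More details on input_function are in the <a href="2_TransferLearning.ipynb">TransferLearning Notebook</a>.
In this function, we download the CIFAR-10 data, convert it to TFRecord files, load them as a TFRecordDataset and filter the dataset to keep only the odd numbered classes (the 1st, 3rd, 5th, ... classes, i.e. the original CIFAR-10 labels 0, 2, 4, 6 and 8, which are then renumbered 0 to 4).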



In [None]:
from sbrain.learning.experiment import *
from sbrain.dataset.dataset import *

def only_odd_classes_input_function(mode, batch_size, params):
    """Build the CIFAR-10 input pipeline restricted to every other class.

    On first use the CIFAR-10 python archive is downloaded into a fixed data
    directory and converted into three TFRecord files (train, validation,
    eval). The returned dataset only contains the examples whose original
    label satisfies ``label % 2 == CLASS_INDEX_MOD`` and the surviving labels
    are renumbered to ``label // 2``.

    Args:
      mode: a ``tf.estimator.ModeKeys`` value. TRAIN reads the 'train' subset
        with distortion enabled and repeats forever, EVAL reads the
        'validation' subset, every other mode reads the 'eval' subset.
      batch_size: number of examples per batch.
      params: not used by this function.

    Returns:
      A ``tf.data.Dataset`` of ``({"data": image}, label)`` batches.
    """
    import os
    import tensorflow as tf
    import sys
    import tarfile
    import pickle

    # Remainder of (label % 2) for the classes that are kept. With 0 the
    # original CIFAR-10 labels 0, 2, 4, 6, 8 survive, i.e. the 1st, 3rd, 5th,
    # ... classes when counting from one.
    # NOTE(review): confirm this is the intended reading of "odd numbered".
    CLASS_INDEX_MOD = 0

    # Geometry of a single CIFAR-10 image.
    HEIGHT = 32
    WIDTH = 32
    DEPTH = 3

    CIFAR_FILENAME = 'cifar-10-python.tar.gz'
    CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
    CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'

    def download_and_extract(data_dir):
        """Fetch the CIFAR-10 archive into data_dir and unpack it there."""
        # download CIFAR-10 if not already downloaded.
        tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir,
                                                      CIFAR_DOWNLOAD_URL)
        archive_path = os.path.join(data_dir, CIFAR_FILENAME)
        # The context manager closes the archive even if extraction fails.
        # NOTE(review): extractall() trusts the member paths stored in the
        # archive; acceptable only because the source URL is fixed.
        with tarfile.open(archive_path, 'r:gz') as archive:
            archive.extractall(data_dir)

    def _int64_feature(value):
        """Wrap a single integer in a tf.train.Feature."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        """Wrap a single bytes value in a tf.train.Feature."""
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _get_file_names():
        """Returns the file names expected to exist in the input_dir."""
        return {
            'train': ['data_batch_%d' % i for i in range(1, 5)],
            'validation': ['data_batch_5'],
            'eval': ['test_batch'],
        }

    def read_pickle_from_file(filename):
        """Load one pickled CIFAR-10 batch file and return its dictionary."""
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only safe as long as the downloaded archive is trusted.
        with tf.gfile.Open(filename, 'rb') as f:
            if sys.version_info >= (3, 0):
                # Python 3 needs encoding='bytes', which is why the keys
                # below are byte strings.
                data_dict = pickle.load(f, encoding='bytes')
            else:
                data_dict = pickle.load(f)
        return data_dict

    def convert_to_tfrecord(input_files, output_file):
        """Converts a file to TFRecords."""
        print('Generating %s' % output_file)
        with tf.python_io.TFRecordWriter(output_file) as record_writer:
            for input_file in input_files:
                data_dict = read_pickle_from_file(input_file)
                # One tf.train.Example per (image row, label) pair.
                for image, label in zip(data_dict[b'data'], data_dict[b'labels']):
                    example = tf.train.Example(features=tf.train.Features(
                        feature={
                            'image': _bytes_feature(image.tobytes()),
                            'label': _int64_feature(label)
                        }))
                    record_writer.write(example.SerializeToString())

    def setup_cifar10_data(data_dir):
        """Create the three TFRecord files in data_dir unless all exist."""
        expected_files = [os.path.join(data_dir, name + '.tfrecords')
                          for name in ('train', 'validation', 'eval')]
        if all(os.path.exists(path) for path in expected_files):
            print("Data already present.")
            return

        print('Download from {} and extract. Wait for download complete message..'.format(CIFAR_DOWNLOAD_URL))
        download_and_extract(data_dir)
        print('Download completed')
        input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
        # `subset_name` avoids shadowing the `mode` argument of the
        # enclosing input function.
        for subset_name, files in _get_file_names().items():
            input_files = [os.path.join(input_dir, f) for f in files]
            output_file = os.path.join(data_dir, subset_name + '.tfrecords')
            # Best effort: drop a stale/partial file from an earlier run.
            try:
                os.remove(output_file)
            except OSError:
                pass
            # Convert to tf.train.Example and write them to TFRecords.
            convert_to_tfrecord(input_files, output_file)
        print('Done!')

    class Cifar10DataSet(object):
        """Cifar10 data set.
        Described by http://www.cs.toronto.edu/~kriz/cifar.html.

        Reads one of the 'train', 'validation' or 'eval' TFRecord files from
        data_dir, keeps the classes selected by CLASS_INDEX_MOD and halves
        their labels.
        """

        def __init__(self, data_dir, subset='train', use_distortion=True):
            self.data_dir = data_dir
            self.subset = subset
            # Distortion is only applied when subset == 'train'.
            self.use_distortion = use_distortion

        def get_filenames(self):
            """Return the TFRecord path(s) for the configured subset.

            Raises:
              ValueError: if the subset is not train, validation or eval.
            """
            if self.subset in ['train', 'validation', 'eval']:
                return [os.path.join(self.data_dir, self.subset + '.tfrecords')]
            else:
                raise ValueError('Invalid data subset "%s"' % self.subset)

        def parser(self, serialized_example):
            """Parses a single tf.Example into image and label tensors."""
            # Dimensions of the images in the CIFAR-10 dataset.
            # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
            # input format.
            features = tf.parse_single_example(
                serialized_example,
                features={
                    'image': tf.FixedLenFeature([], tf.string),
                    'label': tf.FixedLenFeature([], tf.int64),
                })
            image = tf.decode_raw(features['image'], tf.uint8)
            image.set_shape([DEPTH * HEIGHT * WIDTH])

            # Reshape from [depth * height * width] to [depth, height, width],
            # then transpose to [height, width, depth] and cast to float.
            image = tf.cast(
                tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),
                tf.float32)
            label = tf.cast(features['label'], tf.int32)

            # Custom preprocessing.
            image = self.preprocess(image)

            return ({"data": image}, label)

        def adjust_labels(self, data, label):
            """Renumber a kept label to label // 2 (e.g. 0, 2, 4 -> 0, 1, 2)."""
            return data, tf.floordiv(label, 2)

        def filter_fun(self, data, label):
            """Return True for examples whose label % 2 equals CLASS_INDEX_MOD."""
            return tf.equal(tf.mod(label, 2), CLASS_INDEX_MOD)

        def get_dataset(self):
            """Read the images and labels from 'filenames'."""
            filenames = self.get_filenames()

            dataset = tf.data.TFRecordDataset(filenames)

            # Parse records, drop the unwanted classes, then renumber labels.
            dataset = dataset.map(self.parser).filter(self.filter_fun).map(self.adjust_labels)
            return dataset

        def preprocess(self, image):
            """Preprocess a single image in [height, width, depth] layout."""
            if self.subset == 'train' and self.use_distortion:
                # Pad to 40x40, take a random 32x32 crop and randomly mirror.
                image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
                image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
                image = tf.image.random_flip_left_right(image)
            return image

    # Map the estimator mode to a stored subset; only training is distorted.
    use_distortion = False
    if mode == tf.estimator.ModeKeys.TRAIN:
        subset = 'train'
        use_distortion = True
    elif mode == tf.estimator.ModeKeys.EVAL:
        subset = 'validation'
    else:
        subset = 'eval'
    data_dir = "/workspace/shared-dir/sample-notebooks/demo-data/learning/OddClasses/"

    setup_cifar10_data(data_dir)

    dataset = Cifar10DataSet(data_dir, subset, use_distortion).get_dataset()

    # NOTE(review): the shuffle is applied in every mode, including
    # EVAL/PREDICT; confirm that output order does not matter there.
    dataset = dataset.shuffle(1000).batch(batch_size)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Training repeats indefinitely; the step budget ends the run.
        dataset = dataset.repeat()
    return dataset

Here we define a model function that builds a ResNet model to be trained on the above dataset.

In [None]:
def cifar_model_function(features, labels, mode, params):
    """Estimator model function that builds a ResNet classifier.

    Builds a 44-layer ResNet (basic residual blocks, channels_last) on
    32x32x3 inputs and returns the tf.estimator.EstimatorSpec matching
    `mode`.

    Args:
      features: feature dict; the numeric column "data" with shape
        (32, 32, 3) is read from it.
      labels: class ids fed to the sparse softmax cross entropy loss and the
        accuracy metric (not used in PREDICT mode).
      mode: one of tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}.
      params: not used by this function.

    Returns:
      A tf.estimator.EstimatorSpec carrying predictions (PREDICT), loss and
      eval metrics (EVAL), or loss and train_op (TRAIN).
    """
    import tensorflow as tf
    import numpy as np
    import os
    ######################################## Resnet model #########################################

    class ResNet(object):
        """ResNet model.

        Base class with the shared building blocks (convolution, batch norm,
        residual unit, pooling, dense layer); subclasses implement
        forward_pass().
        """

        def __init__(self, is_training, data_format, batch_norm_decay, batch_norm_epsilon):
            """ResNet constructor.

            Args:
              is_training: if build training or inference model.
              data_format: the data_format used during computation.
                           one of 'channels_first' or 'channels_last'.
              batch_norm_decay: `decay` passed to every batch norm layer.
              batch_norm_epsilon: `epsilon` passed to every batch norm layer.
            """
            self._batch_norm_decay = batch_norm_decay
            self._batch_norm_epsilon = batch_norm_epsilon
            self._is_training = is_training
            assert data_format in ('channels_first', 'channels_last')
            self._data_format = data_format

        def forward_pass(self, x):
            """Build the network on x; must be overridden by subclasses."""
            raise NotImplementedError(
                'forward_pass() is implemented in ResNet sub classes')

        def _residual_v1(self,
                         x,
                         kernel_size,
                         in_filter,
                         out_filter,
                         stride,
                         activate_before_residual=False):
            """Residual unit with 2 sub layers, using Plan A for shortcut connection.

            Args:
              x: input tensor of the unit.
              kernel_size: kernel size of both convolutions.
              in_filter: number of channels of x.
              out_filter: number of channels produced by the unit.
              stride: stride of the first convolution (the second uses 1).
              activate_before_residual: accepted but ignored.

            Returns:
              relu(conv-bn-relu-conv-bn(x) + shortcut(x)).
            """

            # The flag is part of the signature only; it is discarded here.
            del activate_before_residual
            with tf.name_scope('residual_v1') as name_scope:
                orig_x = x

                x = self._conv(x, kernel_size, out_filter, stride)
                x = self._batch_norm(x)
                x = self._relu(x)

                x = self._conv(x, kernel_size, out_filter, 1)
                x = self._batch_norm(x)

                if in_filter != out_filter:
                    # Channel count changes: pool the shortcut with the same
                    # stride and pad its channel axis by `pad` on each side
                    # so that it can be added to x.
                    orig_x = self._avg_pool(orig_x, stride, stride)
                    pad = (out_filter - in_filter) // 2
                    if self._data_format == 'channels_first':
                        orig_x = tf.pad(orig_x, [[0, 0], [pad, pad], [0, 0], [0, 0]])
                    else:
                        orig_x = tf.pad(orig_x, [[0, 0], [0, 0], [0, 0], [pad, pad]])

                x = self._relu(tf.add(x, orig_x))

                tf.logging.info('image after unit %s: %s', name_scope, x.get_shape())
                return x

        def _conv(self, x, kernel_size, filters, strides, is_atrous=False):
            """Convolution.

            Bias-free 2-D convolution. For strides > 1 (and not atrous) the
            spatial axes are padded explicitly by kernel_size - 1 in total
            and the convolution uses 'VALID' padding; otherwise 'SAME'.
            """

            padding = 'SAME'
            if not is_atrous and strides > 1:
                pad = kernel_size - 1
                pad_beg = pad // 2
                pad_end = pad - pad_beg
                # Pad only the two spatial axes, whichever layout is in use.
                if self._data_format == 'channels_first':
                    x = tf.pad(x, [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
                else:
                    x = tf.pad(x, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
                padding = 'VALID'
            return tf.layers.conv2d(
                inputs=x,
                kernel_size=kernel_size,
                filters=filters,
                strides=strides,
                padding=padding,
                use_bias=False,
                data_format=self._data_format)

        def _batch_norm(self, x):
            """Fused batch normalization with the configured decay/epsilon."""
            # tf.contrib.layers.batch_norm takes 'NCHW'/'NHWC' rather than
            # 'channels_first'/'channels_last'.
            if self._data_format == 'channels_first':
                data_format = 'NCHW'
            else:
                data_format = 'NHWC'
            return tf.contrib.layers.batch_norm(
                x,
                decay=self._batch_norm_decay,
                center=True,
                scale=True,
                epsilon=self._batch_norm_epsilon,
                is_training=self._is_training,
                fused=True,
                data_format=data_format)

        def _relu(self, x):
            """ReLU activation."""
            return tf.nn.relu(x)

        def _fully_connected(self, x, out_dim):
            """Dense layer producing out_dim outputs."""
            with tf.name_scope('fully_connected') as name_scope:
                x = tf.layers.dense(x, out_dim)

            tf.logging.info('image after unit %s: %s', name_scope, x.get_shape())
            return x

        def _avg_pool(self, x, pool_size, stride):
            """2-D average pooling with 'SAME' padding."""
            with tf.name_scope('avg_pool') as name_scope:
                x = tf.layers.average_pooling2d(
                    x, pool_size, stride, 'SAME', data_format=self._data_format)

            tf.logging.info('image after unit %s: %s', name_scope, x.get_shape())
            return x

        def _global_avg_pool(self, x):
            """Average a rank-4 tensor over its two spatial axes."""
            with tf.name_scope('global_avg_pool') as name_scope:
                assert x.get_shape().ndims == 4
                if self._data_format == 'channels_first':
                    x = tf.reduce_mean(x, [2, 3])
                else:
                    x = tf.reduce_mean(x, [1, 2])
            tf.logging.info('image after unit %s: %s', name_scope, x.get_shape())
            return x

    ####################################### end Resnet base model definition ######################################

    ####################################### start cifar resnet subclassing ########################################

    class ResNetCifar10(ResNet):
        """Cifar10 model with ResNetV1 and basic residual block."""

        def __init__(self,
                     num_layers,
                     is_training,
                     batch_norm_decay,
                     batch_norm_epsilon,
                     data_format='channels_first'):
            """Configure the network.

            Args:
              num_layers: total depth; each of the 3 stages gets
                (num_layers - 2) // 6 residual units.
              is_training: if build training or inference model.
              batch_norm_decay: `decay` passed to every batch norm layer.
              batch_norm_epsilon: `epsilon` passed to every batch norm layer.
              data_format: 'channels_first' or 'channels_last'.
            """
            super(ResNetCifar10, self).__init__(
                is_training,
                data_format,
                batch_norm_decay,
                batch_norm_epsilon
            )
            # Residual units per stage.
            self.n = (num_layers - 2) // 6
            # Add one in case label starts with 1. No impact if label starts with 0.
            self.num_classes = 5 + 1
            # filters[i] -> filters[i + 1] are the in/out channels of stage i;
            # strides[i] is the stride of the first unit of stage i.
            self.filters = [16, 16, 32, 64]
            self.strides = [1, 2, 2]

        def forward_pass(self, x, input_data_format='channels_last'):
            """Build the core model within the graph.

            Args:
              x: input image batch laid out as input_data_format.
              input_data_format: layout of x; transposed to the model's
                data_format when the two differ.

            Returns:
              Logits with self.num_classes entries per example.
            """
            if self._data_format != input_data_format:
                if input_data_format == 'channels_last':
                    # Computation requires channels_first.
                    x = tf.transpose(x, [0, 3, 1, 2])
                else:
                    # Computation requires channels_last.
                    x = tf.transpose(x, [0, 2, 3, 1])

            # Image standardization.
            x = x / 128 - 1

            # Initial 3x3 convolution with 16 filters and stride 1.
            x = self._conv(x, 3, 16, 1)
            x = self._batch_norm(x)
            x = self._relu(x)

            # Use basic (non-bottleneck) block and ResNet V1 (post-activation).
            res_func = self._residual_v1

            # 3 stages of block stacking.
            for i in range(3):
                with tf.name_scope('stage'):
                    for j in range(self.n):
                        if j == 0:
                            # First block in a stage, filters and strides may change.
                            x = res_func(x, 3, self.filters[i], self.filters[i + 1],
                                         self.strides[i])
                        else:
                            # Following blocks in a stage, constant filters and unit stride.
                            x = res_func(x, 3, self.filters[i + 1], self.filters[i + 1], 1)

            x = self._global_avg_pool(x)
            x = self._fully_connected(x, self.num_classes)

            return x
    ####################################### end cifar resnet subclassing ########################################

    ####################################### start loss etc definitions ##########################################

    # Fixed hyper-parameters; with num_layers = 44 each stage has 7 units.
    num_layers = 44
    batch_norm_decay = 0.997
    batch_norm_epsilon = 1e-5
    weight_decay = 2e-4
    learning_rate = 0.1

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = ResNetCifar10(
        num_layers,
        batch_norm_decay=batch_norm_decay,
        batch_norm_epsilon=batch_norm_epsilon,
        is_training=is_training,
        data_format="channels_last")

    # input_layer yields a flattened feature tensor, hence the reshape back
    # to image layout.
    data = tf.feature_column.input_layer(features, [tf.feature_column.numeric_column("data", shape=(32,32,3))])
    data = tf.reshape(data, (-1,32,32,3))
    logits = model.forward_pass(data, input_data_format='channels_last')

    # PREDICT needs no labels, so return before any loss is built.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': tf.argmax(input=logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # Cross entropy plus an L2 penalty over all trainable variables.
    loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
    loss = tf.reduce_mean(loss)
    model_params = tf.trainable_variables()
    loss += weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in model_params])

    ####################################### end loss etc definitions ############################################
    # Compute evaluation metrics.
    accuracy = tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, axis=1),
                                   name='acc_op')
    metrics = {'accuracy': accuracy}
    # accuracy is a (value, update_op) pair; the update op is what is logged.
    tf.summary.scalar('accuracy', accuracy[1])
    # TODO tf.summary.scalar('accuracy', accuracy[1])
    # NOTE(review): the TODO above looks stale, the summary is already
    # emitted by the statement before it.

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec( mode, loss=loss, eval_metric_ops=metrics)

    # Create training op.
    assert mode == tf.estimator.ModeKeys.TRAIN
    #FIXME
    # The steps-per-epoch figure below is hard-coded for 45000 examples and
    # a batch size of 64; it does not follow the batch size or the filtered
    # dataset size actually used, so the schedule boundaries are approximate.
    num_batches_per_epoch = 45000 // 64  # * num_workers)
    # Learning rate steps through 1x, 0.1x, 0.01x and 0.002x of the base rate
    # at "epochs" 82, 123 and 300.
    boundaries = [ num_batches_per_epoch * x for x in np.array([82, 123, 300], dtype=np.int64)]
    staged_lr = [learning_rate * x for x in [1, 0.1, 0.01, 0.002]]
    learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(), boundaries, staged_lr)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)

    # optimizer = tf.train.SyncReplicasOptimizer(optimizer, 4, 4)
    # optimizer.make_session_run_hook()

    global_step = tf.train.get_global_step()
    # Debug output: shows which device holds the global step variable.
    print("Device is")
    print(global_step.device)
    train_op = optimizer.minimize(loss, global_step=global_step)


    # The op handed to the estimator is a tf.Print of the global step that
    # is created under control dependencies on both the UPDATE_OPS
    # collection (presumably the batch norm moving-average updates) and the
    # minimize op, so running it triggers all of them.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        with tf.control_dependencies([train_op]):
            train_op = tf.Print(global_step, [global_step])

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, training_chief_hooks=None)

### Run the job

Below we run this setup for 5000 iterations on two workers. We save this model under the name "Cifar10_Odd_Classes_5000_Iterations". Later, in the <a href="2_TransferLearning.ipynb">TransferLearning Notebook</a>, we will look this up and transfer from this model.

In [None]:
# Wrap the model function in an SBrain classification estimator and create
# it under a name (presumably this registers it with SBrain; verify against
# the SBrain API).
estimator = Estimator.NewClassificationEstimator(model_fn=cifar_model_function)
name = "CIFAR10_odd_class_estimator"
estimator = Estimator.create(name, "Resnet Cifar10 estimator", estimator)

# 5000 training iterations with a batch size of 128.
hyper_parameters = HParams(iterations=5000, batch_size=128)
# Two workers with evaluation and GPUs enabled; no_of_ps is presumably the
# number of parameter servers.
rc = RunConfig(no_of_ps=1, no_of_workers=2, run_eval=True, use_gpu=True)


# Name under which the experiment is run.
odd_class_experiment_name = "Cifar10_Odd_Classes_5000_Iterations"
# No dataset version is passed; the data comes from the input function.
experiment = Experiment.run(experiment_name=odd_class_experiment_name,
                     description="Cifar10 Odd classes Model",
                     estimator=estimator,
                     hyper_parameters=hyper_parameters,
                     run_config=rc,
                     dataset_version_split=None,
                     input_function=only_odd_classes_input_function)
job = experiment.get_single_job()
print("tensorboard url")
print(job.get_tensorboard_url())

print("Has the job finished? {}".format(job.has_finished()))

# Block until the job is done before reading the model metrics.
job.wait_until_finish()

print("Model metrics..")
print(job.get_model().model_metrics)

We are done!!