In [1]:
import sys

import tensorflow as tf
import cifar
tf.logging.set_verbosity(tf.logging.WARN)

## Download and extract the dataset

In [2]:
cifar.prepare_cifar_10()  # per the section header above: download and extract CIFAR-10 via the local `cifar` helper module
cifar10_labels = cifar.cifar10_labels()  # class-label lookup from the helper; not referenced in the cells shown here

## Define the model function

In [3]:
def _batch_norm(net, params, training):
    """Batch-normalize `net`; learned scale/shift are toggled by params['with_scsf']."""
    return tf.layers.batch_normalization(net,
                                         center=params['with_scsf'],
                                         scale=params['with_scsf'],
                                         momentum=params['momentum'],
                                         training=training)


def model(features, labels, mode, params):
    """Estimator model_fn: a small batch-normalized CNN classifier.

    Every conv layer and every hidden dense layer is preceded by batch
    normalization (i.e. BN sits after the previous layer's ReLU), and one more
    batch normalization is applied right before the logits layer.

    Args:
        features: dict whose 'images' entry is the input image batch.
        labels: integer class labels (unused in PREDICT mode).
        mode: a tf.estimator.ModeKeys value.
        params: hyper-parameter dict with keys
            filters, kern, strides -- per-conv-layer filter counts, kernel
                sizes and strides; all three must have the same length.
            dense -- units of each hidden fully-connected layer.
            n_classes -- number of output classes.
            with_scsf -- if True, batch norm learns a scale (gamma) and a
                shift (beta); if False, neither is created.
            momentum -- momentum of batch norm's moving statistics.
            print_shapes -- optional; print each layer's output shape while
                the graph is built (default False).
            learning_rate -- optional Adam learning rate (default 0.001,
                Adam's own default).

    Returns:
        A tf.estimator.EstimatorSpec for `mode`: class/score predictions in
        PREDICT, loss plus accuracy in EVAL, loss plus train_op in TRAIN.

    Raises:
        ValueError: if 'filters', 'kern' and 'strides' differ in length
            (zip would otherwise silently drop the extra entries).
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    print_shapes = params.get('print_shapes', False)

    def _log_shape(tensor):
        # Debug aid only: shows how each layer transforms the static shape.
        if print_shapes:
            print(tensor.shape)

    conv_specs = (params['filters'], params['kern'], params['strides'])
    if len({len(spec) for spec in conv_specs}) != 1:
        raise ValueError("'filters', 'kern' and 'strides' must have the same length, got %s"
                         % [len(spec) for spec in conv_specs])

    net = features['images']
    _log_shape(net)

    # Convolutional stack: BN -> conv + ReLU.
    for filt, kern, stride in zip(*conv_specs):
        net = _batch_norm(net, params, training)
        net = tf.layers.conv2d(net, filt, kern, stride, activation=tf.nn.relu)
        _log_shape(net)

    net = tf.layers.flatten(net)
    _log_shape(net)

    # Fully-connected stack: BN -> dense + ReLU.
    for units in params['dense']:
        net = _batch_norm(net, params, training)
        net = tf.layers.dense(net, units, activation=tf.nn.relu)
        _log_shape(net)

    net = _batch_norm(net, params, training)
    logits = tf.layers.dense(net, params['n_classes'])
    _log_shape(logits)
    cls = tf.argmax(logits, -1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={
            "class": cls,
            "score": tf.nn.softmax(logits)
        })

    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops={
            "accuracy": tf.metrics.accuracy(labels, cls)
        })

    adam = tf.train.AdamOptimizer(learning_rate=params.get('learning_rate', 0.001))

    # batch_normalization registers its moving mean/variance updates in
    # UPDATE_OPS; they must run with every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt = adam.minimize(loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=opt)

In [4]:
def _cifar_dataset(generator):
    """Wrap a CIFAR-10 batch generator from the `cifar` module in a tf.data.Dataset.

    `generator` must yield (features, labels) pairs that match the declared
    output types and shapes: {"images": float32 batch of 32x32x3 images} and
    int64 labels.
    """
    return tf.data.Dataset.from_generator(
        generator,
        ({"images": tf.float32}, tf.int64),
        ({"images": tf.TensorShape([None, 32, 32, 3])}, tf.TensorShape(None)))


def inp_fn():
    """Estimator input_fn for the CIFAR-10 training split."""
    return _cifar_dataset(cifar.cifar10_train)


def test_inp_fn():
    """Estimator input_fn for the CIFAR-10 test split."""
    return _cifar_dataset(cifar.cifar10_test)

In [5]:
hparams = {
    # Per-conv-layer settings; "filters", "kern" and "strides" must have equal length.
    "filters": [30, 50, 60],
    # One independent [3, 3] kernel size per conv layer. A comprehension avoids
    # `[[3, 3]] * 3`, which would alias the same inner list three times.
    "kern": [[3, 3] for _ in range(3)],
    "strides": [[2, 2], [1, 1], [1, 1]],
    # Units of the hidden fully-connected layers after flattening.
    "dense": [3500, 700],
    "n_classes": 10,
    # Whether batch normalization learns a scale (gamma) and shift (beta).
    "with_scsf": True,
    # Momentum of batch normalization's moving mean/variance.
    "momentum": 0.75,
    # Print each layer's output shape while the graph is built.
    "print_shapes": True
}

### View layer shapes

In [6]:
# Build the model once on dummy placeholders, only to print each layer's output
# shape. The throw-away tf.Graph keeps these dummy variables out of the default
# graph, and passing a copy of hparams with print_shapes=True keeps the cell
# re-runnable (it prints the shapes every time it is executed).
with tf.Graph().as_default():
    model({"images": tf.placeholder(tf.float32, (10, 32, 32, 3))},
          tf.placeholder(tf.int32, (10)), tf.estimator.ModeKeys.TRAIN,
          dict(hparams, print_shapes=True))
# Silence shape printing for the Estimator runs below.
hparams["print_shapes"] = False

(10, 32, 32, 3)
(10, 15, 15, 30)
(10, 13, 13, 50)
(10, 11, 11, 60)
(10, 7260)
(10, 3500)
(10, 700)
(10, 10)


## With Scale and Shift

In [7]:
# Estimator WITH batch-norm scale and shift (hparams["with_scsf"] is True here).
# Summaries are written every 2 global steps for TensorBoard.
# NOTE: an existing 'wscf-ckpts' directory is reused, so delete it for a fresh run.
wscf_run_config = tf.estimator.RunConfig(save_summary_steps=2)
wscf = tf.estimator.Estimator(model_fn=model,
                              model_dir='wscf-ckpts',
                              config=wscf_run_config,
                              params=hparams)

#### Start Tensorboard

In [8]:
# Launch TensorBoard in the background so the cell returns immediately:
# `start` spawns a separate process on Windows, a trailing `&` does so on POSIX shells.
_tensorboard_cmd = "tensorboard --logdir wscf-ckpts"
if sys.platform == "win32":
    get_ipython().system_raw("start " + _tensorboard_cmd)
else:
    get_ipython().system_raw(_tensorboard_cmd + " &")

In [9]:
%%time
for i in range(20):
    wscf.train(inp_fn)
    print(wscf.evaluate(test_inp_fn))

{'accuracy': 0.547, 'loss': 1.2823796, 'global_step': 20}
{'accuracy': 0.5949, 'loss': 1.1662079, 'global_step': 40}
{'accuracy': 0.5788, 'loss': 1.4566054, 'global_step': 60}
{'accuracy': 0.5681, 'loss': 1.7820822, 'global_step': 80}
{'accuracy': 0.5846, 'loss': 1.7017951, 'global_step': 100}
{'accuracy': 0.5846, 'loss': 1.8201271, 'global_step': 120}
{'accuracy': 0.5974, 'loss': 1.7575045, 'global_step': 140}
{'accuracy': 0.599, 'loss': 1.721942, 'global_step': 160}
{'accuracy': 0.6035, 'loss': 1.7390618, 'global_step': 180}
{'accuracy': 0.6043, 'loss': 1.750489, 'global_step': 200}
{'accuracy': 0.6045, 'loss': 1.7619331, 'global_step': 220}
{'accuracy': 0.6053, 'loss': 1.7707057, 'global_step': 240}
{'accuracy': 0.6055, 'loss': 1.7792139, 'global_step': 260}
{'accuracy': 0.6059, 'loss': 1.7870388, 'global_step': 280}
{'accuracy': 0.6065, 'loss': 1.7944248, 'global_step': 300}
{'accuracy': 0.6062, 'loss': 1.8013859, 'global_step': 320}
{'accuracy': 0.6068, 'loss': 1.8080118, 'global_

## Without Scale and Shift

In [10]:
# Second experiment: batch norm without learned scale (gamma) or shift (beta).
hparams.update(with_scsf=False)

In [11]:
# Estimator WITHOUT batch-norm scale and shift (hparams["with_scsf"] is False now).
# Summaries are written every 2 global steps for TensorBoard.
# NOTE: an existing 'woscf-ckpts' directory is reused, so delete it for a fresh run.
woscf_run_config = tf.estimator.RunConfig(save_summary_steps=2)
woscf = tf.estimator.Estimator(model_fn=model,
                               model_dir='woscf-ckpts',
                               config=woscf_run_config,
                               params=hparams)

#### Start Tensorboard

In [12]:
# Launch a second TensorBoard on port 6007 (the first one uses the default port)
# in the background: `start` on Windows, a trailing `&` on POSIX shells.
_tensorboard_cmd = "tensorboard --logdir woscf-ckpts --port 6007"
if sys.platform == "win32":
    get_ipython().system_raw("start " + _tensorboard_cmd)
else:
    get_ipython().system_raw(_tensorboard_cmd + " &")

In [13]:
%%time
for i in range(20):
    woscf.train(inp_fn)
    print(woscf.evaluate(test_inp_fn))

{'accuracy': 0.5043, 'loss': 1.3885775, 'global_step': 20}
{'accuracy': 0.5801, 'loss': 1.1778008, 'global_step': 40}
{'accuracy': 0.6099, 'loss': 1.1439934, 'global_step': 60}
{'accuracy': 0.5785, 'loss': 1.5351887, 'global_step': 80}
{'accuracy': 0.5711, 'loss': 1.701782, 'global_step': 100}
{'accuracy': 0.5644, 'loss': 1.8049095, 'global_step': 120}
{'accuracy': 0.5793, 'loss': 1.8751414, 'global_step': 140}
{'accuracy': 0.581, 'loss': 1.8933313, 'global_step': 160}
{'accuracy': 0.5866, 'loss': 1.8283887, 'global_step': 180}
{'accuracy': 0.6038, 'loss': 1.7691035, 'global_step': 200}
{'accuracy': 0.6004, 'loss': 1.8271089, 'global_step': 220}
{'accuracy': 0.5978, 'loss': 1.8227904, 'global_step': 240}
{'accuracy': 0.5941, 'loss': 1.8485936, 'global_step': 260}
{'accuracy': 0.596, 'loss': 1.8362767, 'global_step': 280}
{'accuracy': 0.5958, 'loss': 1.900719, 'global_step': 300}
{'accuracy': 0.5942, 'loss': 1.9070389, 'global_step': 320}
{'accuracy': 0.6057, 'loss': 1.891427, 'global_s

## Results
#### With Scale and Shift
![Graph with scale](images/wscf.png)
#### Without Scale and Shift
![Graph without scale](images/woscf.png)

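As we can see, removing the scale and shift parameters makes no noticeable difference to how the loss evolves during training. This is because, in this model, Batch Normalization is applied after the activation function, i.e. directly before the next convolutional or dense layer. That layer is linear, so any per-channel scale ($\gamma$) and shift ($\beta$) learned by Batch Normalization can be absorbed into its weights and bias. The scale and shift parameters are therefore redundant and can be dropped. Dropping them also speeds up training: in this notebook's run, the total training time decreased by about 30 seconds. Fewer parameters also lead to slimmer models.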

Read more about Batch Normalization in the [original paper](https://arxiv.org/abs/1502.03167). You can try running this notebook on your local machine or in [Google Colab](https://drive.google.com/file/d/1d6RkPZoZ1DanTRRJbZGGBo3wdJxW-dag/view?usp=sharing).