# Introduction

This example implements [Gradient Centralization](https://arxiv.org/abs/2004.01461), an optimization technique for deep neural networks proposed by Yong et al., and demonstrates it on Laurence Moroney's [Horses or Humans Dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans). Gradient Centralization can both speed up the training process and improve the final generalization performance of DNNs. It operates directly on gradients by centralizing the gradient vectors to have zero mean.
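
To make the centralizing step concrete, here is a minimal NumPy sketch of the operation. It is an illustration only, not the code used by the `gctf` package: the function name `centralize_gradient` and the toy values are assumptions made for this example. It assumes TensorFlow's kernel layout, where the last axis indexes output units, so the mean is taken over every other axis and rank-1 gradients such as biases are left untouched.

```python
import numpy as np


def centralize_gradient(grad):
    """Subtract the mean over all axes except the last from a gradient.

    Hypothetical helper for illustration. Gradients of rank 0 or 1
    (for example biases) are returned unchanged.
    """
    grad = np.asarray(grad, dtype=np.float64)
    if grad.ndim <= 1:
        return grad
    # Average over every axis but the last, so each output unit's
    # gradient slice ends up with zero mean.
    axes = tuple(range(grad.ndim - 1))
    return grad - grad.mean(axis=axes, keepdims=True)


# Toy gradient for a Dense kernel with 3 inputs and 2 output units.
dense_grad = np.array([[1.0, 4.0],
                       [2.0, 5.0],
                       [3.0, 9.0]])
centralized = centralize_gradient(dense_grad)

# Toy gradient for a bias vector: left as it is.
bias_grad = centralize_gradient(np.array([0.5, -0.5]))
```

After this step each column of `centralized` sums to zero, while the relative differences within a column are preserved.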

This example requires TensorFlow 2.2 or higher, as well as a package I built, [gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow), which can be installed using the following command:

```shell
pip install gradient-centralization-tf
```

# Setup

In [1]:
!pip install gradient-centralization-tf

Collecting gradient-centralization-tf
  Downloading https://files.pythonhosted.org/packages/d0/a5/c4f43ea30718fbac477b0386e6f57981214359c917d39e2f6237e5110ff6/gradient_centralization_tf-0.0.3-py3-none-any.whl
Installing collected packages: gradient-centralization-tf
Successfully installed gradient-centralization-tf-0.0.3


In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import RMSprop


import gctf

from time import time
import os

# Prepare the data

For this example, we will be using the [Horses or Humans dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans).

In [3]:
num_classes = 2
input_shape = (300, 300, 3)
dataset_name = 'horses_or_humans'

(train_ds, test_ds), metadata = tfds.load(name=dataset_name, 
                                split=[tfds.Split.TRAIN, tfds.Split.TEST],
                                with_info=True,
                                as_supervised=True)

Downloading and preparing dataset horses_or_humans/3.0.0 (download: 153.59 MiB, generated: Unknown size, total: 153.59 MiB) to /root/tensorflow_datasets/horses_or_humans/3.0.0...


Dataset horses_or_humans downloaded and prepared to /root/tensorflow_datasets/horses_or_humans/3.0.0. Subsequent calls will reuse this data.


In [4]:
print(f"Image shape: {metadata.features['image'].shape}")
print(f"Training images: {metadata.splits['train'].num_examples}")
print(f"Test images: {metadata.splits['test'].num_examples}")

Image shape: (300, 300, 3)
Training images: 1027
Test images: 256


# Use Data Augmentation

We will rescale the data to `[0, 1]` and perform simple augmentations on it.

In [5]:
rescale = tf.keras.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255)
])

data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
    layers.experimental.preprocessing.RandomRotation(0.3),
    layers.experimental.preprocessing.RandomZoom(0.2),
])
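
The factor passed to `RandomRotation` is expressed as a fraction of a full turn (2π), not in degrees, so `0.3` permits fairly large rotations. The short sketch below converts a factor into the range it implies; `rotation_range` is a hypothetical helper written for this example and is not part of Keras.

```python
import math


def rotation_range(factor):
    """Return the rotation limits implied by a RandomRotation factor.

    Hypothetical helper: the factor is a fraction of a full turn (2*pi),
    and the layer samples an angle between -factor and +factor turns.
    """
    degrees = factor * 360.0
    radians = factor * 2.0 * math.pi
    return (-degrees, degrees), (-radians, radians)


# The factor used in the augmentation pipeline above.
deg_range, rad_range = rotation_range(0.3)
print(f"Rotation range: {deg_range[0]:.0f} to {deg_range[1]:.0f} degrees")
```

With a factor of `0.3`, images may be rotated by up to about 108 degrees in either direction.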

In [6]:
batch_size = 128
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
    # Rescale dataset
    ds = ds.map(lambda x, y: (rescale(x), y), 
                num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1024)

    # Batch dataset
    ds = ds.batch(batch_size)

    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), 
                    num_parallel_calls=AUTOTUNE)

    # Use buffered prefetching
    return ds.prefetch(buffer_size=AUTOTUNE)

Rescale both splits and augment only the training set:

In [7]:
train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)
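
As a quick sanity check on the pipeline, the number of batches per epoch follows from the split sizes printed earlier and the batch size of 128. The sketch below assumes the default behaviour of `Dataset.batch`, which keeps the final partial batch; `batches_per_epoch` is a hypothetical helper used only for this arithmetic.

```python
import math


def batches_per_epoch(num_examples, batch_size):
    """Number of batches when the final partial batch is kept."""
    return math.ceil(num_examples / batch_size)


batch_size = 128
train_batches = batches_per_epoch(1027, batch_size)  # 1027 training images
test_batches = batches_per_epoch(256, batch_size)    # 256 test images

# Size of the last, partial training batch.
last_train_batch = 1027 - (train_batches - 1) * batch_size

print(train_batches, test_batches, last_train_batch)
```

Each training epoch therefore runs nine steps, the last of which contains only three images.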

# Define a model

In this section we will define a convolutional neural network.

In [8]:
model = tf.keras.Sequential([
    layers.Conv2D(16, (3,3), activation='relu', input_shape=(300, 300, 3)),
    layers.MaxPooling2D(2, 2),
    layers.Conv2D(32, (3,3), activation='relu'),
    layers.Dropout(0.5),
    layers.MaxPooling2D(2,2),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.Dropout(0.5),
    layers.MaxPooling2D(2,2),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),
    
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(512, activation='relu'),

    layers.Dense(1, activation='sigmoid')
])
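
To see how the feature maps shrink through this stack, the spatial sizes can be worked out by hand. The sketch below assumes the Keras defaults used above: `Conv2D` with `"valid"` padding and stride 1, and 2x2 max pooling with stride 2. Dropout layers do not change shapes. The helper names are hypothetical.

```python
def conv_valid(size, kernel=3):
    """Output size of a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1


def max_pool(size, window=2):
    """Output size of non-overlapping max pooling with 'valid' padding."""
    return size // window


size = 300  # input images are 300 x 300
sizes_after_pool = []
for _ in range(5):  # five Conv2D + MaxPooling2D pairs
    size = max_pool(conv_valid(size))
    sizes_after_pool.append(size)

last_filters = 64  # filters in the final Conv2D layer
flattened_units = size * size * last_filters
# Weights plus biases of the Dense(512) layer that follows Flatten.
dense_params = flattened_units * 512 + 512

print(sizes_after_pool, flattened_units, dense_params)
```

Under these assumptions the final feature map is 7 x 7 x 64, so `Flatten` produces 3136 values and the `Dense(512)` layer holds most of the model's parameters.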

We will also create a callback that records the time taken by each epoch, from which the total training time follows. We need this because we want to compare the model built above when trained with and without Gradient Centralization.

In [9]:
class TimeHistory(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # One entry per epoch: wall-clock duration in seconds
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time() - self.epoch_time_start)
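
Once training has finished, the callback's `times` list holds one duration per epoch. A small helper like the one below could turn that list into the totals we want to compare. `summarize_epoch_times` is a hypothetical name, and the durations in the usage lines are placeholder values, not measurements from this notebook.

```python
from statistics import mean


def summarize_epoch_times(times):
    """Summarize a list of per-epoch durations given in seconds."""
    return {
        "epochs": len(times),
        "total_seconds": sum(times),
        "mean_seconds": mean(times),
    }


# Placeholder durations standing in for a TimeHistory instance's `times`.
example_times = [2.0, 1.5, 1.5]
summary = summarize_epoch_times(example_times)
print(f"Total: {summary['total_seconds']:.2f}s over {summary['epochs']} epochs")
```

In the notebook, the same helper would be called on the `times` attribute of a `TimeHistory` instance after `model.fit` returns.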

# Train the model without GC

We now train the model we built earlier without Gradient Centralization, so that we can compare its training performance with that of the model trained with Gradient Centralization.

In [12]:
time_callback_no_gc = TimeHistory()
model.compile(loss='binary_crossentropy',
              optimizer=RMSprop(learning_rate=1e-4),
              metrics=['accuracy'])

history_no_gctf = model.fit(
      train_ds,
      epochs=10,
      verbose=1,
      callbacks = [time_callback_no_gc])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
