# Introduction

This example implements [Gradient Centralization](https://arxiv.org/abs/2004.01461), a new optimization technique for Deep Neural Networks by Yong et al., and demonstrates it on Laurence Moroney's [Horses or Humans Dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans). Gradient Centralization can both speed up the training process and improve the final generalization performance of DNNs. It operates directly on gradients by centralizing the gradient vectors to have zero mean. Moreover, Gradient Centralization improves the Lipschitzness of the loss function and its gradient, so that the training process becomes more efficient and stable.
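
As a compact statement of the centralization step, the operation can be written as follows. The notation is an assumption borrowed from the paper, not from the code in this example: $W \in \mathbb{R}^{M \times N}$ is a layer's weight matrix, $w_i$ is its $i$-th column, and $\mathcal{L}$ is the loss.

$$
\Phi_{GC}(\nabla_{w_i}\mathcal{L}) = \nabla_{w_i}\mathcal{L} - \mu_{\nabla_{w_i}\mathcal{L}},
\qquad
\mu_{\nabla_{w_i}\mathcal{L}} = \frac{1}{M}\sum_{j=1}^{M}\nabla_{w_{i,j}}\mathcal{L}
$$

Each column of the gradient has its own mean subtracted, so the entries of every centralized column sum to zero.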

This example requires TensorFlow 2.2 or higher as well as `tensorflow_datasets`, which can be installed with this command:

```
pip install tensorflow-datasets
```

We will implement Gradient Centralization from scratch in this example, but you can also use it easily through a package I built, [gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow).

# Setup

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
from tensorflow.keras.optimizers import RMSprop
# Use the backend bundled with tf.keras rather than standalone Keras
from tensorflow.keras import backend as K

from time import time

# Prepare the data

For this example, we will be using the [Horses or Humans dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans).

In [2]:
num_classes = 2
input_shape = (300, 300, 3)
dataset_name = 'horses_or_humans'

(train_ds, test_ds), metadata = tfds.load(name=dataset_name, 
                                split=[tfds.Split.TRAIN, tfds.Split.TEST],
                                with_info=True,
                                as_supervised=True)

Downloading and preparing dataset horses_or_humans/3.0.0 (download: 153.59 MiB, generated: Unknown size, total: 153.59 MiB) to /root/tensorflow_datasets/horses_or_humans/3.0.0...
Dataset horses_or_humans downloaded and prepared to /root/tensorflow_datasets/horses_or_humans/3.0.0. Subsequent calls will reuse this data.


In [3]:
print(f"Image shape: {metadata.features['image'].shape}")
print(f"Training images: {metadata.splits['train'].num_examples}")
print(f"Test images: {metadata.splits['test'].num_examples}")

Image shape: (300, 300, 3)
Training images: 1027
Test images: 256


# Use Data Augmentation

We will rescale the data to `[0, 1]` and perform simple augmentations on it.

In [4]:
rescale = tf.keras.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255)
])

data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
    layers.experimental.preprocessing.RandomRotation(0.3),
    layers.experimental.preprocessing.RandomZoom(0.2),
])

In [5]:
batch_size = 128
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
    # Rescale dataset
    ds = ds.map(lambda x, y: (rescale(x), y), 
                num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1024)

    # Batch dataset
    ds = ds.batch(batch_size)

    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), 
                    num_parallel_calls=AUTOTUNE)

    # Use buffered prefetching
    return ds.prefetch(buffer_size=AUTOTUNE)

Rescale both splits and augment the training data:

In [6]:
train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)

# Define a model

In this section we will define a convolutional neural network.

In [7]:
model = tf.keras.Sequential([
    layers.Conv2D(16, (3,3), activation='relu', input_shape=(300, 300, 3)),
    layers.MaxPooling2D(2, 2),
    layers.Conv2D(32, (3,3), activation='relu'),
    layers.Dropout(0.5),
    layers.MaxPooling2D(2,2),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.Dropout(0.5),
    layers.MaxPooling2D(2,2),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),
    
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(512, activation='relu'),

    layers.Dense(1, activation='sigmoid')
])

We will also create a callback that records the time taken by each epoch, from which the total training time follows, since we are interested in comparing the effect of Gradient Centralization on the model we built above.

In [8]:
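class TimeHistory(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Reset the per-epoch timings at the start of every fit() call
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time()

    def on_epoch_end(self, epoch, logs=None):
        # Record the wall-clock duration of the epoch that just finished
        self.times.append(time() - self.epoch_time_start)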

# Train the model without GC

We now train the model we built earlier without Gradient Centralization, so that we can compare it with the training performance of the model trained with Gradient Centralization.

In [14]:
time_callback_no_gc = TimeHistory()
model.compile(loss='binary_crossentropy',
              optimizer=RMSprop(learning_rate=1e-4),
              metrics=['accuracy'])

model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 298, 298, 16)      448       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 149, 149, 16)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 147, 147, 32)      4640      
_________________________________________________________________
dropout (Dropout)            (None, 147, 147, 32)      0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 73, 73, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 71, 71, 64)        18496     
_________________________________________________________________
dropout_1 (Dropout)          (None, 71, 71, 64)       

We also save the history, since we later want to compare the model trained with Gradient Centralization against the one trained without it.

In [9]:
history_no_gc = model.fit(
      train_ds,
      epochs=10,
      verbose=1,
      callbacks = [time_callback_no_gc])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


# Train the model with GC 

We will now train the same model, this time using Gradient Centralization. To do so, we subclass the `RMSprop` optimizer class and override the `tf.keras.optimizers.Optimizer.get_gradients()` method, which is where we implement Gradient Centralization. At a high level, the idea is as follows: suppose we obtain the gradients for a Dense or Convolution layer through backpropagation. We then compute the mean of each column vector of the layer's gradient matrix and subtract that mean from the column.
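
The centralization step described above can be tried in isolation. The following is a minimal NumPy sketch, not the optimizer used in this example: the helper name `centralize_gradient` and the sample gradients are hypothetical, and the kernel layouts (`(inputs, units)` for Dense, `(height, width, in_channels, filters)` for Conv2D) are assumed to follow the Keras convention, so the mean is taken over every axis except the last.

```python
import numpy as np

def centralize_gradient(grad):
    """Subtract the mean over all axes except the last one.

    Gradients with a single axis (e.g. biases) are returned unchanged.
    """
    grad = np.asarray(grad, dtype=np.float64)
    if grad.ndim > 1:
        axes = tuple(range(grad.ndim - 1))
        grad = grad - grad.mean(axis=axes, keepdims=True)
    return grad

# Hypothetical gradient of a Dense kernel with 3 inputs and 2 output units
dense_grad = np.array([[1.0, 10.0],
                       [2.0, 20.0],
                       [3.0, 30.0]])
centralized = centralize_gradient(dense_grad)
print(centralized)
# Each column of the centralized gradient has zero mean
print(centralized.mean(axis=0))

# Hypothetical Conv2D kernel gradient: (height, width, in_channels, filters)
rng = np.random.default_rng(0)
conv_grad = rng.normal(size=(3, 3, 16, 32))
conv_centralized = centralize_gradient(conv_grad)
# The mean over the first three axes vanishes for every filter
print(np.allclose(conv_centralized.mean(axis=(0, 1, 2)), 0.0))

# A bias gradient has one axis and passes through untouched
bias_grad = np.array([0.5, -1.5])
print(centralize_gradient(bias_grad))
```

Keeping the reduced axes with `keepdims=True` lets the per-column mean broadcast back against the full gradient, which is the same pattern the optimizer subclass relies on.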

The experiments in [this paper](https://arxiv.org/abs/2004.01461) on various applications, including general image classification, fine-grained image classification, detection and segmentation, and Person ReID, demonstrate that GC can consistently improve the performance of DNN learning.

In [10]:
class gc_rmsprop(RMSprop):
    def get_gradients(self, loss, params):
        # We only override get_gradients() here, since all we need to change
        # is how the gradients are computed: they are centralized before use.
        raw_grads = K.gradients(loss, params)
        if any(g is None for g in raw_grads):
            raise ValueError('An operation has `None` for gradient. '
                             'Please make sure that all of your ops have a '
                             'gradient defined (i.e. are differentiable). '
                             'Common ops without gradient: '
                             'K.argmax, K.round, K.eval.')

        grads = []
        for grad in raw_grads:
            grad_len = len(grad.shape)
            # Only centralize weight matrices and kernels; skip 1-D biases
            if grad_len > 1:
                # Average over every axis except the last (output) axis
                axis = list(range(grad_len - 1))
                grad -= tf.reduce_mean(grad,
                                       axis=axis,
                                       keepdims=True)
            grads.append(grad)

        # Read the clipping settings from this optimizer instance rather
        # than from a global variable
        clipnorm = getattr(self, 'clipnorm', None)
        if clipnorm is not None and clipnorm > 0:
            # Scale all gradients by their joint (global) norm
            grads, _ = tf.clip_by_global_norm(grads, clipnorm)
        clipvalue = getattr(self, 'clipvalue', None)
        if clipvalue is not None and clipvalue > 0:
            grads = [K.clip(g, -clipvalue, clipvalue) for g in grads]
        return grads

optimizer = gc_rmsprop(learning_rate = 1e-4)

We will now train our model, this time using Gradient Centralization. Notice that our optimizer is now the one using Gradient Centralization.


In [15]:
time_callback_gc = TimeHistory()
model.compile(loss='binary_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy'])

model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 298, 298, 16)      448       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 149, 149, 16)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 147, 147, 32)      4640      
_________________________________________________________________
dropout (Dropout)            (None, 147, 147, 32)      0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 73, 73, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 71, 71, 64)        18496     
_________________________________________________________________
dropout_1 (Dropout)          (None, 71, 71, 64)       

In [11]:
history_gc = model.fit(
      train_ds,
      epochs=10,
      verbose=1,
      callbacks = [time_callback_gc])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


# Comparing performance

In [12]:
print("Not using Gradient Centralization")
print(f"Loss: {history_no_gc.history['loss'][-1]}")
print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_no_gc.times)}")

Not using Gradient Centralization
Loss: 0.6229214668273926
Accuracy: 0.7059396505355835
Training Time: 231.0875632762909


In [13]:
print("Using Gradient Centralization")
print(f"Loss: {history_gc.history['loss'][-1]}")
print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_gc.times)}")

Using Gradient Centralization
Loss: 0.452496737241745
Accuracy: 0.8111003041267395
Training Time: 200.79542303085327
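
To put the two summaries side by side, the differences can be worked out directly from the printed values. This is a small sketch in plain Python: the numbers are copied from the outputs above, while the dictionary names and the `relative_reduction` helper are hypothetical and introduced only for illustration. The figures come from a single run each, so they illustrate the arithmetic rather than a general result.

```python
# Final-epoch metrics copied from the two printed summaries above
without_gc = {"loss": 0.6229214668273926,
              "accuracy": 0.7059396505355835,
              "time_s": 231.0875632762909}
with_gc = {"loss": 0.452496737241745,
           "accuracy": 0.8111003041267395,
           "time_s": 200.79542303085327}

def relative_reduction(before, after):
    """Fraction by which `after` is smaller than `before`."""
    return (before - after) / before

loss_reduction = relative_reduction(without_gc["loss"], with_gc["loss"])
time_reduction = relative_reduction(without_gc["time_s"], with_gc["time_s"])
# Accuracy is compared as an absolute difference, not a ratio
accuracy_gain = with_gc["accuracy"] - without_gc["accuracy"]

print(f"Loss reduction: {loss_reduction:.1%}")
print(f"Accuracy gain: {accuracy_gain * 100:.1f} percentage points")
print(f"Training time reduction: {time_reduction:.1%}")
```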


Readers are encouraged to try out Gradient Centralization on different datasets from different domains and experiment with its effect. You are strongly advised to check out the [original paper](https://arxiv.org/abs/2004.01461) as well: the authors present several studies on Gradient Centralization showing how it can improve general performance, generalization, and training time, and make training more efficient.

Many thanks to [Ali Mustufa Shaikh](https://github.com/ialimustufa) for reviewing this implementation.