## Distribution strategies using a custom TensorFlow training loop



In [1]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

### Simulate an environment with 4 GPUs 

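Skip these cells when running on a machine that already has several physical GPUs (for example a multi-GPU EC2 instance). They split the first physical GPU into four logical GPUs so that multi-GPU training can be simulated on a single card.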

In [2]:
# Show which physical GPUs TensorFlow can see (last expression is displayed).
physical_GPUS = tf.config.list_physical_devices('GPU')
physical_GPUS

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Split the first physical GPU into several logical GPUs so that a
# MirroredStrategy with multiple replicas can be exercised on one card.
NUM_LOGICAL_GPUS = 4
LOGICAL_GPU_MEMORY_MB = 2000  # memory limit given to each logical GPU

physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    raise RuntimeError(
        'No physical GPU found: this cell needs at least one GPU to simulate '
        f'{NUM_LOGICAL_GPUS} logical GPUs.'
    )

try:
    tf.config.set_logical_device_configuration(
        physical_devices[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=LOGICAL_GPU_MEMORY_MB)
         for _ in range(NUM_LOGICAL_GPUS)],
    )
except RuntimeError as err:
    # Logical devices can only be configured before the GPUs are initialized;
    # re-running this cell in the same kernel ends up here. The existing
    # configuration is kept and listed below.
    print('Could not (re)configure logical GPUs:', err)

logical_devices = tf.config.list_logical_devices('GPU')
print('logical_devices', logical_devices)

logical_devices [LogicalDevice(name='/device:GPU:0', device_type='GPU'), LogicalDevice(name='/device:GPU:1', device_type='GPU'), LogicalDevice(name='/device:GPU:2', device_type='GPU'), LogicalDevice(name='/device:GPU:3', device_type='GPU')]


### Get any dataset for training

Any dataset works for this demo; this notebook uses the Fashion-MNIST dataset.

In [6]:
GLOBAL_BATCH_SIZE = 64


fashion_mnist, info = tfds.load('fashion_mnist', split=['train', 'test'], as_supervised=True, with_info=True)
# With as_supervised=True each split yields (image, label) pairs.
train_images, test_images = fashion_mnist
print(f"Number of training examples: {len(train_images)}")
print(f"Number of test examples: {len(test_images)}")


def normalize_image(image, label):
    """Drop the label and rescale pixel values to [-1, 1].

    The GAN is unsupervised, so only the image is kept. Assumes 8-bit pixel
    values in [0, 255] (the affine map sends 0 -> -1 and 255 -> 1).
    """
    image = tf.image.resize(image, (28, 28))
    return (tf.cast(image, dtype=tf.float32) - 127.5) / 127.5


train_images = train_images.map(normalize_image).batch(GLOBAL_BATCH_SIZE)
test_images = test_images.map(normalize_image).batch(GLOBAL_BATCH_SIZE)


# FIRST ADDITIONAL STEP: create a strategy and use it to distribute the dataset.
# The returned tf.distribute.DistributedDataset can be iterated like a regular
# dataset, but no more transformations can be added to it: you can only create
# an iterator or examine the tf.TypeSpec of its elements. See the API docs of
# tf.distribute.DistributedDataset to learn more.
# MirroredStrategy takes a list of device-name strings, so pass the names of
# the logical GPUs configured above.
strategy = tf.distribute.MirroredStrategy([device.name for device in logical_devices])
train_dist_dataset = strategy.experimental_distribute_dataset(train_images)
train_images, train_dist_dataset

Number of examples: 60000
Number of examples: 10000








INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


(<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>,
 <tensorflow.python.distribute.input_lib.DistributedDataset at 0x7f298d578fd0>)

## Functions to instantiate the models

In [7]:
# Shape of one input image: (height, width, channels).
IMG_SHAPE = (28, 28, 1)
noise_dim = 128  # length of the generator's input noise vector


def conv_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5,
):
    """Apply Conv2D -> optional BatchNormalization -> activation -> optional Dropout.

    Args:
        x: input Keras tensor.
        filters: number of convolution filters.
        activation: callable (e.g. a Keras layer) applied after conv / BN.
        kernel_size, strides, padding, use_bias: forwarded to ``layers.Conv2D``.
        use_bn: insert a BatchNormalization layer before the activation.
        use_dropout: append a Dropout layer after the activation.
        drop_value: dropout rate used when ``use_dropout`` is True.

    Returns:
        Output tensor of the last layer applied.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        use_bias=use_bias,
    )
    out = conv(x)
    out = layers.BatchNormalization()(out) if use_bn else out
    out = activation(out)
    return layers.Dropout(drop_value)(out) if use_dropout else out


def get_discriminator_model():
    """Build the WGAN critic: four strided conv blocks and a linear score.

    Inputs of shape IMG_SHAPE are zero-padded by 2 pixels on each side
    (28x28 -> 32x32) before four stride-2 convolutions. The final Dense(1)
    has no activation, so the critic outputs an unbounded score per image.

    Returns:
        keras.Model named "discriminator".
    """
    # (filters, use_dropout) per conv block. Every block uses a 5x5 kernel,
    # stride 2, bias, no batch norm, LeakyReLU(0.2) and dropout rate 0.3.
    block_specs = ((64, False), (128, True), (256, True), (512, False))

    img_input = layers.Input(shape=IMG_SHAPE)
    features = layers.ZeroPadding2D((2, 2))(img_input)
    for n_filters, with_dropout in block_specs:
        features = conv_block(
            features,
            n_filters,
            activation=layers.LeakyReLU(0.2),
            kernel_size=(5, 5),
            strides=(2, 2),
            use_bias=True,
            use_bn=False,
            use_dropout=with_dropout,
            drop_value=0.3,
        )

    features = layers.Flatten()(features)
    features = layers.Dropout(0.2)(features)
    score = layers.Dense(1)(features)

    return keras.models.Model(img_input, score, name="discriminator")


def upsample_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    up_size=(2, 2),
    padding="same",
    use_bn=False,
    use_bias=True,
    use_dropout=False,
    drop_value=0.3,
):
    """Apply UpSampling2D -> Conv2D -> optional BN -> optional activation -> optional Dropout.

    Args:
        x: input Keras tensor.
        filters: number of convolution filters.
        activation: callable applied after conv / BN; a falsy value skips it.
        kernel_size, strides, padding, use_bias: forwarded to ``layers.Conv2D``.
        up_size: upsampling factor passed to ``layers.UpSampling2D``.
        use_bn: insert a BatchNormalization layer after the convolution.
        use_dropout: append a Dropout layer at the end.
        drop_value: dropout rate used when ``use_dropout`` is True.

    Returns:
        Output tensor of the last layer applied.
    """
    upsampled = layers.UpSampling2D(up_size)(x)
    out = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        use_bias=use_bias,
    )(upsampled)
    if use_bn:
        out = layers.BatchNormalization()(out)
    out = activation(out) if activation else out
    if use_dropout:
        out = layers.Dropout(drop_value)(out)
    return out


def get_generator_model():
    """Build the generator that maps a noise vector to a 28x28x1 image.

    The noise is projected to a 4x4x256 feature map, upsampled three times
    (4 -> 8 -> 16 -> 32) and cropped by 2 pixels per side back to 28x28.
    The last block uses tanh, so outputs lie in [-1, 1].

    Returns:
        keras.Model named "generator".
    """
    latent = layers.Input(shape=(noise_dim,))
    h = layers.Dense(4 * 4 * 256, use_bias=False)(latent)
    h = layers.BatchNormalization()(h)
    h = layers.LeakyReLU(0.2)(h)
    h = layers.Reshape((4, 4, 256))(h)

    # Every upsample block: 2x upsampling, 3x3 conv without bias, batch norm.
    # Activations are created lazily so layers are built in the same order
    # as one call per block would build them.
    for n_filters, make_activation in (
        (128, lambda: layers.LeakyReLU(0.2)),
        (64, lambda: layers.LeakyReLU(0.2)),
        (1, lambda: layers.Activation("tanh")),
    ):
        h = upsample_block(
            h,
            n_filters,
            make_activation(),
            strides=(1, 1),
            use_bias=False,
            use_bn=True,
            padding="same",
            use_dropout=False,
        )

    image = layers.Cropping2D((2, 2))(h)
    return keras.models.Model(latent, image, name="generator")

# Abstraction of a WGAN-GP, which holds the discriminator (critic), the
# generator, their optimizers / loss functions and the gradient penalty.
class WGAN(keras.Model):
    """Container for a Wasserstein GAN with gradient penalty (WGAN-GP).

    This class does not override ``train_step``; it stores the two networks,
    the hyper-parameters, the optimizers / losses set by ``compile`` and
    computes the gradient penalty.

    Args:
        discriminator: critic model returning one score per image.
        generator: model mapping latent vectors of size ``latent_dim`` to images.
        latent_dim: size of the generator's input noise vector.
        discriminator_extra_steps: stored as ``d_steps``.
        gp_weight: multiplier applied to the gradient penalty.
    """

    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Store the optimizers and loss functions for both networks."""
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculate the WGAN-GP gradient penalty.

        The penalty is computed on random interpolations between real and
        fake images and is meant to be added to the discriminator loss.

        Args:
            batch_size: number of images in ``real_images`` / ``fake_images``.
            real_images: rank-4 batch of real images.
            fake_images: generated images with the same shape as ``real_images``.

        Returns:
            Scalar tensor: mean over the batch of (||grad D(x_hat)||_2 - 1)^2.
        """
        # The interpolation coefficient must lie in [0, 1] so that every
        # x_hat sits on the segment between a real and a fake image;
        # a normal draw would place many points outside that segment.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real_images + alpha * (fake_images - real_images)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Discriminator output for the interpolated images.
            pred = self.discriminator(interpolated, training=True)

        # 2. Gradients of the output w.r.t. the interpolated images.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Per-image L2 norm of the gradients, penalized towards 1.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        return tf.reduce_mean((norm - 1.0) ** 2)


## Create and compile the model inside the distribution strategy scope

In [9]:
def discriminator_loss(real_img, fake_img):
    """Wasserstein critic loss: mean fake score minus mean real score.

    Args:
        real_img: discriminator outputs for real images.
        fake_img: discriminator outputs for generated images.

    Returns:
        Scalar tensor; minimizing it raises real scores relative to fake ones.
    """
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)


def generator_loss(fake_img):
    """Generator loss: the negated mean discriminator score on generated images."""
    mean_fake_score = tf.reduce_mean(fake_img)
    return -mean_fake_score

# SECOND ADDITIONAL STEP: instantiate the models and optimizers inside the
# strategy scope so their variables are created under the distribution strategy.
with strategy.scope():
    # Build both networks and print their layer summaries.
    d_model = get_discriminator_model()
    d_model.summary()
    g_model = get_generator_model()
    g_model.summary()

    # Adam for both networks (learning_rate=0.0002, beta_1=0.5, beta_2=0.9).
    generator_optimizer = keras.optimizers.Adam(
        learning_rate=0.0002,
        beta_1=0.5,
        beta_2=0.9,
    )
    discriminator_optimizer = keras.optimizers.Adam(
        learning_rate=0.0002,
        beta_1=0.5,
        beta_2=0.9,
    )

    # Number of training epochs.
    # NOTE(review): the training loop further down iterates over EPOCHS, not this name.
    epochs = 2

    # Wrap both networks; the critic gets 3 updates per generator update.
    model = WGAN(
        generator=g_model,
        discriminator=d_model,
        discriminator_extra_steps=3,
        latent_dim=noise_dim,
    )

    # Attach the optimizers and loss functions.
    model.compile(
        d_optimizer=discriminator_optimizer,
        d_loss_fn=discriminator_loss,
        g_optimizer=generator_optimizer,
        g_loss_fn=generator_loss,
    )


Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
zero_padding2d_1 (ZeroPaddin (None, 32, 32, 1)         0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 16, 16, 64)        1664      
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 8, 8, 128)         204928    
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 8, 8, 128)       

## Define a custom train step and use it inside the distributed train step

   

In [10]:
## DEFINE YOUR CUSTOM TRAIN STEP. Apart from scaling the losses by the number
## of replicas (which is 1 outside any strategy), it doesn't depend on the strategy.
def train_step(real_images):
    """Run one WGAN-GP update on a batch of real images.

    Performs ``model.d_steps`` discriminator updates followed by one generator
    update, using the networks, optimizers and losses of the global ``model``.

    Args:
        real_images: batch of real images (the per-replica slice when called
            through ``strategy.run``).

    Returns:
        Dict with the unscaled ``"d_loss"`` (from the last discriminator step)
        and ``"g_loss"``, for logging.
    """
    batch_size = tf.shape(real_images)[0]
    # Under a distribution strategy, apply_gradients SUMS the gradients of all
    # replicas. Each replica's loss is a mean over its own slice, so dividing it
    # by the number of replicas makes the summed gradient equal the gradient of
    # the mean loss instead of num_replicas times it. The default (no-op)
    # strategy reports 1 replica, so this is a no-op outside a strategy.
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    for _ in range(model.d_steps):
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, model.latent_dim)
        )
        with tf.GradientTape() as tape:
            fake_images = model.generator(random_latent_vectors, training=True)
            fake_logits = model.discriminator(fake_images, training=True)
            real_logits = model.discriminator(real_images, training=True)

            # Wasserstein loss plus weighted gradient penalty.
            d_cost = model.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
            gp = model.gradient_penalty(batch_size, real_images, fake_images)
            d_loss = d_cost + gp * model.gp_weight
            # Scale inside the tape so the division is recorded for backprop.
            scaled_d_loss = d_loss / num_replicas

        d_gradient = tape.gradient(
            scaled_d_loss, model.discriminator.trainable_variables
        )
        model.d_optimizer.apply_gradients(
            zip(d_gradient, model.discriminator.trainable_variables)
        )

    # Train the generator against the updated discriminator.
    random_latent_vectors = tf.random.normal(shape=(batch_size, model.latent_dim))
    with tf.GradientTape() as tape:
        generated_images = model.generator(random_latent_vectors, training=True)
        gen_img_logits = model.discriminator(generated_images, training=True)
        g_loss = model.g_loss_fn(gen_img_logits)
        scaled_g_loss = g_loss / num_replicas

    gen_gradient = tape.gradient(scaled_g_loss, model.generator.trainable_variables)
    # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#example_12
    model.g_optimizer.apply_gradients(
        zip(gen_gradient, model.generator.trainable_variables)
    )
    # Report the unscaled losses; distributed_train_step averages them.
    return {"d_loss": d_loss, "g_loss": g_loss}

   
# THIRD ADDITIONAL STEP: a distributed train step that runs train_step on every
# replica via strategy.run, then reduces each per-replica loss with
# strategy.reduce (MEAN) to get a single scalar estimate per loss.
@tf.function
def distributed_train_step(dataset_inputs):
    """Run ``train_step`` on all replicas and average each returned loss.

    Args:
        dataset_inputs: one element of ``train_dist_dataset``.

    Returns:
        Dict with the same keys as ``train_step``'s result, each reduced
        across replicas with ``ReduceOp.MEAN``.
    """
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return {
        name: strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_value, axis=None)
        for name, per_replica_value in per_replica_losses.items()
    }

### Training the model

You can use your own custom training loop; the only change needed is to call
`distributed_train_step` instead of `train_step`.

In [12]:
EPOCHS = 2  # number of passes over train_dist_dataset

In [13]:
# Cap on training steps per epoch (quick smoke test); set to None to train on
# every batch of the dataset.
MAX_STEPS_PER_EPOCH = 11

counter = 0  # total number of steps run across all epochs

for epoch in range(EPOCHS):
    # TRAIN LOOP
    total_d_loss = 0.0
    total_g_loss = 0.0
    num_batches = 0
    for x in train_dist_dataset:
        step_losses = distributed_train_step(x)
        print('step_losses', step_losses)

        total_d_loss += float(step_losses['d_loss'])
        total_g_loss += float(step_losses['g_loss'])
        num_batches += 1
        counter += 1
        # The cap is counted per epoch: a counter shared across epochs would
        # let every epoch after the first run a single step before breaking.
        if MAX_STEPS_PER_EPOCH is not None and num_batches >= MAX_STEPS_PER_EPOCH:
            break

    if num_batches:
        print(
            f'epoch {epoch + 1}/{EPOCHS}: '
            f'mean d_loss={total_d_loss / num_batches:.4f}, '
            f'mean g_loss={total_g_loss / num_batches:.4f}'
        )
    

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3').


step_losses {'d_loss': <tf.Tensor: shape=(), dtype=float32, numpy=8.237641>, 'g_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-0.12954344>}
step_losses {'d_loss': <tf.Tensor: shape=(), dtype=float32, numpy=2.116333>, 'g_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-1.8905858>}
step_losses {'d_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-6.1784964>, 'g_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-6.995822>}
step_losses {'d_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-8.308886>, 'g_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-6.760395>}
step_losses {'d_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-13.45362>, 'g_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-4.194737>}
step_losses {'d_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-15.94053>, 'g_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-3.1482012>}
step_losses {'d_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-14.065365>, 'g_loss': <tf.Tensor: shape=(), dtype=float32, numpy=-3.9245