# Introduction
- When you need to customize what `fit()` does, you should **override the training step function of the `Model` class**. 
    - This is the function that is called by `fit()` for every batch of data. 
    - You will then be able to call `fit()` as usual -- and it will be running your own learning algorithm.
- Note that this pattern does not prevent you from building models with the Functional API. 
    - You can do this whether you're building Sequential models, Functional API models, or subclassed models.

# Setup

In [1]:
import tensorflow as tf
from tensorflow import keras

# A first simple example
- Let's start from a simple example.
    - We create a new class that subclasses `keras.Model`.
    - We just override the method `train_step(self, data)`.
    - We return a dictionary mapping metric names (including the loss) to their current value.
- The input argument `data` is what gets passed to `fit()` as training data.
    - If you pass NumPy arrays, by calling `fit(X, y, ...)`, then `data` will be the tuple `(X, y)`.
    - If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be whatever `dataset` yields at each batch.
- In the body of the `train_step` method, we implement a regular training update, similar to what you are already familiar with.
    - Importantly, we **compute the loss via `self.compiled_loss`**, which wraps the loss function(s) that were passed to `compile()`.
    - Similarly, we call **`self.compiled_metrics.update_state(y, y_pred)`** to update the state of the metrics that were passed in `compile()`, and we query results from `self.metrics` at the end to retrieve their current value.

In [2]:
class CustomModel(keras.Model):
    def train_step(self, data):
        X, y = data
        
        with tf.GradientTape() as tape:
            y_pred = self(X, training=True) # forward pass
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        # Update metrics
        self.compiled_metrics.update_state(y, y_pred)
        
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
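
- To make the "regular training update" concrete, here is a minimal plain-Python sketch of the same four stages (forward pass, loss, gradients, weight update) for a tiny linear model with a mean-squared-error loss.
    - It does not use Keras at all: the function name, the `params` tuple, and the toy data are made up for illustration, and the gradients are written by hand instead of coming from `tf.GradientTape`.

```python
# Plain-Python sketch of one training step for a linear model y = w.x + b.
# Names here (train_step, params, learning_rate) are illustrative, not Keras API.

def train_step(params, batch, learning_rate=0.1):
    """Run one gradient-descent update and return (new_params, loss)."""
    w, b = params
    X, y = batch  # same unpacking as in the Keras version
    n = len(X)

    # Forward pass
    y_pred = [sum(wi * xi for wi, xi in zip(w, x)) + b for x in X]

    # Mean squared error
    errors = [p - t for p, t in zip(y_pred, y)]
    loss = sum(e * e for e in errors) / n

    # Hand-written gradients of the loss with respect to w and b
    grad_w = [
        2.0 / n * sum(e * x[j] for e, x in zip(errors, X))
        for j in range(len(w))
    ]
    grad_b = 2.0 / n * sum(errors)

    # Weight update (plain gradient descent)
    new_w = [wi - learning_rate * g for wi, g in zip(w, grad_w)]
    new_b = b - learning_rate * grad_b
    return (new_w, new_b), loss


params = ([0.0, 0.0], 0.0)
batch = ([[1.0, 2.0], [2.0, 1.0], [0.0, 1.0]], [5.0, 4.0, 2.0])

losses = []
for _ in range(50):
    params, loss = train_step(params, batch)
    losses.append(loss)

print(losses[0], losses[-1])
```

- In the Keras version, the optimizer passed to `compile()` takes the place of the hand-written update rule.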

In [3]:
import numpy as np

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
X = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(X, y, epochs=3)

Train on 1000 samples
Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fd3e1deb690>

# Going lower-level
- Naturally, you could just skip passing a loss function in `compile()`, and instead do everything manually in `train_step`.
- Likewise for metrics.
- Here's a lower-level example which only uses `compile()` to configure the optimizer.

In [4]:
mae_metric = keras.metrics.MeanAbsoluteError(name='mae')
loss_tracker = keras.metrics.Mean(name='loss')

In [5]:
class CustomModel(keras.Model):
    def train_step(self, data):
        X, y = data

        with tf.GradientTape() as tape:
            y_pred = self(X, training=True)  # Forward pass
            
            # Compute our own loss (this loss is not from compile())
            loss = keras.losses.mean_squared_error(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute our own metrics (this metric is not from compile())
        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}
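
- The metric objects used above are stateful: `result()` reports a running value over everything seen since the last reset.
    - The toy class below is a plain-Python stand-in (not the Keras implementation) that shows why the reset matters.
    - As far as I know, `fit()` resets the metrics listed in the model's `metrics` property at the start of each epoch; the module-level `loss_tracker` and `mae_metric` are not listed there, so they would probably have to be exposed through that property or reset by hand.

```python
class RunningMean:
    """Toy stand-in for a stateful metric such as keras.metrics.Mean."""

    def __init__(self, name):
        self.name = name
        self.reset_state()

    def reset_state(self):
        # Forget everything seen so far
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        # Accumulate one more observation
        self.total += value
        self.count += 1

    def result(self):
        # Mean of all observations since the last reset
        return self.total / self.count if self.count else 0.0


tracker = RunningMean("loss")

# Epoch 1: three batches
for batch_loss in [4.0, 2.0, 3.0]:
    tracker.update_state(batch_loss)
epoch_1 = tracker.result()

# Without this reset, epoch 2 would also average over epoch 1's batches
tracker.reset_state()
for batch_loss in [1.0, 1.0]:
    tracker.update_state(batch_loss)
epoch_2 = tracker.result()

print(epoch_1, epoch_2)
```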

In [6]:
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

In [7]:
# We don't pass a loss or metrics here.
# Note: this cell did not run in my environment, so it is left commented out.

# model.compile(optimizer="adam")
# Just use `fit` as usual -- you can use callbacks, etc.
# X = np.random.random((1000, 32))
# y = np.random.random((1000, 1))
# model.fit(X, y, epochs=3)
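
- A possible explanation for the failing cell, offered as an assumption rather than a diagnosis: overriding `train_step` is supported from TensorFlow 2.2 onward, and the `Train on 1000 samples` line in the earlier output looks like the log format of older releases.
    - The helper below (`supports_custom_train_step` is a made-up name) checks a version string; in the notebook you would pass it `tf.__version__`.

```python
def supports_custom_train_step(version):
    """Return True if a TensorFlow version string is 2.2 or newer."""
    # Only the major and minor components matter for this check
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= (2, 2)


# In a notebook: supports_custom_train_step(tf.__version__)
checks = {v: supports_custom_train_step(v) for v in ["2.1.0", "2.2.0", "2.15.1"]}
print(checks)
```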

# Supporting `sample_weight` & `class_weight`
- If you want to support the `fit()` arguments `sample_weight` and `class_weight`, you'd simply do the following:
    - Unpack `sample_weight` from the `data` argument.
    - Pass it to `compiled_loss` & `compiled_metrics` (of course, you could also just apply it manually if you don't rely on `compile()` for losses & metrics).

In [8]:
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data
        if len(data) == 3:
            X, y, sample_weight = data
        else:
            X, y = data
            sample_weight = None  # no weights were passed to fit()
            
        with tf.GradientTape() as tape:
            y_pred = self(X, training=True)  # Forward pass
            
            # Compute the loss value
            # The loss function is configured in `compile()`
            loss = self.compiled_loss(
                y,
                y_pred,
                sample_weight=sample_weight,
                regularization_losses=self.losses,
            )
            
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics
        # Metrics are configured in `compile()`
        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics)
        return {m.name: m.result() for m in self.metrics}
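
- The unpacking logic at the top of `train_step` can be factored into a small helper that always returns three values.
    - The sketch below is plain Python, and `unpack_batch` is a hypothetical name.
    - Recent TensorFlow releases ship a similar utility, `tf.keras.utils.unpack_x_y_sample_weight`, which may be preferable where available.

```python
def unpack_batch(data):
    """Split a batch into (x, y, sample_weight), filling gaps with None."""
    if not isinstance(data, (tuple, list)):
        # A bare array: inputs only
        return data, None, None
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    raise ValueError(f"Expected 1 to 3 elements, got {len(data)}")


# Strings stand in for arrays to keep the example dependency-free
with_weights = unpack_batch(("X", "y", "sw"))
without_weights = unpack_batch(("X", "y"))
print(with_weights, without_weights)
```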

In [9]:
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)

model = CustomModel(inputs, outputs)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# You can now use the `sample_weight` argument
X = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
# model.fit(X, y, sample_weight=sw, epochs=3)  # did not run in my environment either
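
- If you do not rely on `compile()` for the loss, applying the weights manually amounts to scaling each sample's loss before averaging.
    - The plain-Python sketch below illustrates this for a mean squared error; `weighted_mse` is a made-up name.
    - Dividing by the batch size (rather than by the sum of the weights) is an assumption about the reduction, not something stated in this notebook.

```python
def weighted_mse(y_true, y_pred, sample_weight=None):
    """Mean of per-sample squared errors, each scaled by its weight."""
    if sample_weight is None:
        sample_weight = [1.0] * len(y_true)
    per_sample = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    weighted = [w * e for w, e in zip(sample_weight, per_sample)]
    # Assumption: divide by the batch size, not by the sum of the weights
    return sum(weighted) / len(weighted)


y_true = [1.0, 2.0, 3.0]
y_pred = [1.0, 1.0, 1.0]

unweighted = weighted_mse(y_true, y_pred)                       # (0 + 1 + 4) / 3
down_weighted = weighted_mse(y_true, y_pred, [1.0, 1.0, 0.0])   # (0 + 1 + 0) / 3
print(unweighted, down_weighted)
```

- A weight of zero removes a sample's contribution to the loss, and therefore to the gradients.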

# Providing your own evaluation step
- What if you want to do the same for calls to `model.evaluate()`?
- Then you would **override `test_step`** in exactly the same way.

In [10]:
class CustomModel(keras.Model):
    def test_step(self, data):
        X, y = data
        
        y_pred = self(X, training=False)
        
        # Update the metrics tracking the loss
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        
        # Update the metrics
        self.compiled_metrics.update_state(y, y_pred)
        
        # Return a dict mapping metric names to current value
        # Note that it will include the loss tracked by self.metrics
        return {m.name: m.result() for m in self.metrics}

In [11]:
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

In [12]:
# Evaluate with our custom test_step
X = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(X, y, verbose=0)

[0.569937665939331, 0.63245946]
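
- Conceptually, `evaluate()` calls `test_step` once per batch and reports the accumulated metric values at the end; the two numbers above are the loss and the MAE.
    - The plain-Python sketch below mimics that loop for a toy one-dimensional model. The function and argument names are illustrative only.

```python
def evaluate(predict, X, y, batch_size=2):
    """Average MSE and MAE over all samples, one batch at a time."""
    sq_sum = 0.0
    abs_sum = 0.0
    seen = 0
    for start in range(0, len(X), batch_size):
        xb = X[start:start + batch_size]
        yb = y[start:start + batch_size]
        # "test_step": forward pass only, no gradients, no weight update
        preds = [predict(x) for x in xb]
        sq_sum += sum((p - t) ** 2 for p, t in zip(preds, yb))
        abs_sum += sum(abs(p - t) for p, t in zip(preds, yb))
        seen += len(xb)
    return {"loss": sq_sum / seen, "mae": abs_sum / seen}


results = evaluate(lambda x: 2.0 * x, X=[1.0, 2.0, 3.0], y=[2.0, 5.0, 4.0])
print(results)
```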

# Wrapping up: an end-to-end GAN example
- Let's walk through an end-to-end example that leverages everything you just learned.
- Let's consider:
    - A generator network meant to generate 28x28x1 images
    - A discriminator network meant to classify 28x28x1 images into two classes ("fake" and "real")
    - One optimizer for each network
    - A loss function to train the discriminator

In [13]:
from tensorflow.keras import layers

# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1)
    ],
    name='discriminator'
)

In [14]:
# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid")
    ],
    name='generator'
)
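
- The spatial sizes in both networks follow from the strides: with `padding="same"`, a strided convolution divides the size by the stride (rounding up), and a strided transposed convolution multiplies it.
    - The short check below traces the sizes in plain Python; the helper names are made up.

```python
import math


def conv_same(size, stride):
    """Output size of a strided convolution with padding="same"."""
    return math.ceil(size / stride)


def conv_transpose_same(size, stride):
    """Output size of a strided transposed convolution with padding="same"."""
    return size * stride


# Discriminator: two stride-2 convolutions
disc_sizes = [28]
for stride in (2, 2):
    disc_sizes.append(conv_same(disc_sizes[-1], stride))

# Generator: two stride-2 transposed convolutions, starting from the 7x7 map
gen_sizes = [7]
for stride in (2, 2):
    gen_sizes.append(conv_transpose_same(gen_sizes[-1], stride))

# Number of units the first Dense layer must produce for the 7x7x128 reshape
dense_units = 7 * 7 * 128

print(disc_sizes, gen_sizes, dense_units)
```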

- Here is a feature-complete GAN class, overriding `compile()` to use its own signature, and implementing the entire GAN algorithm in 17 lines in `train_step`.

In [15]:
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        
    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        
    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
            
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        
        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)
        
        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)
        
        # Assemble labels discriminating real from fake images
        labels = tf.concat([
            tf.ones((batch_size, 1)),
            tf.zeros((batch_size, 1))], axis=0)
        
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        
        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )
        
        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))
        
        # Train the generator
        # Note that we should NOT update the weights of the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        
        return {"d_loss": d_loss, "g_loss": g_loss}
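
- The label convention in `train_step` is easy to get backwards: generated images are labelled 1, real images are labelled 0, and the generator is trained against all-zero ("real") targets.
    - The plain-Python sketch below rebuilds just the labels, with lists in place of tensors; the function names are hypothetical.

```python
import random


def discriminator_labels(batch_size, noise=0.05, rng=random):
    """Labels for [generated..., real...]: 1 marks generated, 0 marks real."""
    labels = [1.0] * batch_size + [0.0] * batch_size
    # Add uniform noise in [0, noise) to every label
    return [label + noise * rng.random() for label in labels]


def generator_labels(batch_size):
    """Targets for the generator step: claim every generated image is real."""
    return [0.0] * batch_size


rng = random.Random(0)  # seeded so the example is repeatable
d_labels = discriminator_labels(4, rng=rng)
g_labels = generator_labels(4)
print(d_labels, g_labels)
```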

In [16]:
# Prepare the dataset
# We use both the training & test MNIST digits
batch_size = 64
(X_train, _), (X_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([X_train, X_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

In [17]:
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)

In [18]:
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
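
- The discriminator ends in `Dense(1)` with no activation, so it outputs raw logits; this is why the loss is created with `from_logits=True`.
    - As an illustration, here is the usual numerically stable form of binary cross-entropy on logits for a single sample, next to the naive sigmoid-then-log version. This is a plain-Python sketch with made-up function names, not the Keras source.

```python
import math


def bce_from_logits(label, logit):
    """Binary cross-entropy for one sample, computed directly from a logit."""
    # Stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def bce_from_probability(label, logit):
    """Same loss via an explicit sigmoid; only safe for moderate logits."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


# Both forms agree for moderate logits, including noisy labels such as 0.05
pairs = [(1.0, 2.0), (0.0, 2.0), (1.0, -3.0), (0.05, 0.5)]
max_gap = max(
    abs(bce_from_logits(z, x) - bce_from_probability(z, x)) for z, x in pairs
)

# The stable form still works where the naive one would hit log(0)
extreme = bce_from_logits(0.0, 1000.0)
print(max_gap, extreme)
```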

In [19]:
# To limit execution time, we only train on 100 batches.
# You can train on the entire dataset instead;
# you will need about 20 epochs to get nice results.

# gan.fit(dataset.take(100), epochs=1)  # did not run in my notebook
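
- A quick bit of arithmetic shows how much data `take(100)` covers, and why `train_step` reads the batch size from `tf.shape(real_images)[0]` instead of hard-coding 64: the last batch of a full epoch is smaller.
    - The sketch assumes the standard MNIST split of 60,000 training and 10,000 test images.

```python
import math

num_images = 60000 + 10000  # MNIST train + test (assumed standard split)
batch_size = 64

# Batches in one full pass, including the final partial batch
batches_per_epoch = math.ceil(num_images / batch_size)

# Size of that final batch
last_batch_size = num_images - (batches_per_epoch - 1) * batch_size

# Images seen when training on only the first 100 batches
images_in_100_batches = 100 * batch_size

print(batches_per_epoch, last_batch_size, images_in_100_batches)
```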