In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

In [None]:
# Original TensorFlow/Keras 2.x implementation of `Model.train_step`, kept for reference.
# Defined at module level (not indented) so this cell executes on a fresh kernel.

def train_step(self, data):
    """The logic for one training step.

    This method can be overridden to support custom training logic.
    For concrete examples of how to override this method see
    [Customizing what happens in fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit).
    This method is called by `Model.make_train_function`.

    This method should contain the mathematical logic for one step of training.
    This typically includes the forward pass, loss calculation, backpropagation,
    and metric updates.

    Configuration details for *how* this logic is run (e.g. `tf.function` and
    `tf.distribute.Strategy` settings), should be left to
    `Model.make_train_function`, which can also be overridden.

    Args:
      data: A nested structure of `Tensor`s.

    Returns:
      A `dict` containing values that will be passed to
      `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
      values of the `Model`'s metrics are returned. Example:
      `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # `data_adapter` is a Keras-internal module; import it locally so this
    # reference cell does not depend on a later cell having been run first.
    from keras.engine import data_adapter

    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    # Run forward pass.
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    self._validate_target_and_loss(y, loss)
    # Run backwards pass.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    return self.compute_metrics(x, y, y_pred, sample_weight)

In [None]:
# `tensorflow.keras.engine` is not a public module in TF 2.x; the Keras-internal
# path used elsewhere in this notebook is `keras.engine`.
from keras.engine import data_adapter


def train_step(self, data):
    """Custom training step mirroring the default Keras 2.x ``Model.train_step``.

    Args:
        self: The ``keras.Model`` being trained (meant to be attached as a method).
        data: One batch from the dataset; an ``x``, ``(x, y)`` or
            ``(x, y, sample_weight)`` structure, unpacked by ``data_adapter``.

    Returns:
        The ``dict`` of metric results from ``self.compute_metrics``, which
        ``fit`` forwards to callbacks and the progress bar.
    """
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    # Forward pass under the tape so gradients w.r.t. the weights can be taken.
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    self._validate_target_and_loss(y, loss)
    # Backward pass + weight update; `minimize` reuses the tape recorded above.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    # Returning the metrics dict is required: `fit` raises on empty logs.
    return self.compute_metrics(x, y, y_pred, sample_weight)

The implementations below are adapted from the Keras guide [Customizing what happens in `fit()`](https://keras.io/guides/customizing_what_happens_in_fit/).

In [24]:
from tensorflow import keras
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import Sequence

import copy
from tensorflow.python.eager import context
from keras import callbacks as callbacks_module
from keras.engine import base_layer
from keras.engine import data_adapter
from keras.engine import training_utils
from keras.utils import tf_utils
from keras.utils import traceback_utils
from keras.utils import version_utils

In [21]:
class BatchGenerator(Sequence):
    """Keras ``Sequence`` that serves pre-generated random regression batches.

    Each item is one whole batch ``(X, y)`` where ``X`` has shape
    ``(batch_size, n_features)`` and ``y`` has shape ``(batch_size, 1)``.
    The generator also counts completed epochs and prints the count from
    ``on_epoch_end`` so the notebook can see when Keras calls that hook.

    Args:
        length: Number of batches per epoch (the value of ``len(self)``).
        batch_size: Number of samples in every batch.
        n_features: Width of each input row; defaults to 14.
        seed: Optional seed for the random generator, for reproducible data.
            ``None`` (the default) draws fresh random data on every run.
    """

    def __init__(self, length, batch_size, n_features=14, seed=None):
        super().__init__()
        rng = np.random.default_rng(seed)
        # All batches are generated up front:
        # X -> (length, batch_size, n_features), y -> (length, batch_size, 1).
        self.X = rng.random((length, batch_size, n_features))
        self.y = rng.random((length, batch_size, 1))
        self.epoch_counter = 0

    def __len__(self):
        # One Sequence item per pre-generated batch.
        return self.X.shape[0]

    def __getitem__(self, index):
        # Return the full batch stored at `index` along the first axis.
        return self.X[index], self.y[index]

    def on_epoch_end(self):
        self.epoch_counter += 1
        print(f'[BG] updated self.epoch_counter: {self.epoch_counter}')

In [33]:
# 20 batches per epoch, 32 samples per batch.
batch_generator = BatchGenerator(length=20, batch_size=32)
# Index the Sequence the same way Keras does (`seq[i]`) rather than calling the dunder.
X, y = batch_generator[0]
assert X.shape == (32, 14) and y.shape == (32, 1), (X.shape, y.shape)
print(f'X: {X.shape}, y: {y.shape}')

X: (32, 14), y: (32, 1)


In [34]:
# Module-level metric objects updated by `CustomModel.train_step` and exposed via
# `CustomModel.metrics` so Keras resets them at the start of each epoch.
loss_tracker, mae_metric = (
    keras.metrics.Mean(name="loss"),
    keras.metrics.MeanAbsoluteError(name="mae"),
)

def _disallow_inside_tf_function(method_name):
    if tf.inside_function():
        error_msg = (
                'Detected a call to `Model.{method_name}` inside a `tf.function`. '
                '`Model.{method_name} is a high-level endpoint that manages its own '
                '`tf.function`. Please move the call to `Model.{method_name}` outside '
                'of all enclosing `tf.function`s. Note that you can call a `Model` '
                'directly on `Tensor`s inside a `tf.function` like: `model(x)`.'
        ).format(method_name=method_name)
        raise RuntimeError(error_msg)

class CustomModel(keras.Model):
    """``keras.Model`` with an instrumented copy of ``fit`` and a hand-written ``train_step``.

    ``fit`` is a copy of the Keras 2.x ``Model.fit`` implementation with debug
    output added: it prints ``vars(data_handler)`` before the epoch loop and the
    epoch iterator before every step, and it stores the ``DataHandler`` on
    ``self.data_handler`` so ``train_step`` can inspect it.

    ``train_step`` follows the Keras guide "Customizing what happens in fit()":
    it computes an MSE loss itself and updates the module-level
    ``loss_tracker`` / ``mae_metric`` objects, which ``metrics`` exposes.
    """

    # Takes no instance state, so it must be a staticmethod: as a plain method
    # (no `self`), calling it through an instance would bind the model itself to
    # `method_name`. Delegates to the module-level guard of the same name.
    @staticmethod
    def _disallow_inside_tf_function(method_name):
        _disallow_inside_tf_function(method_name)

    @traceback_utils.filter_traceback
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose='auto',
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
        """Train the model; same arguments and return value as ``keras.Model.fit``.

        Returns:
            ``self.history``, the ``History`` callback populated during training.

        Raises:
            ValueError: If ``verbose=1`` is used with ``ParameterServerStrategy``,
                or if ``train_function`` produced no logs.
        """
        base_layer.keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')
        _disallow_inside_tf_function('fit')
        if verbose == 'auto':
            if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
                verbose = 2  # Default to epoch-level logging for PSStrategy.
            else:
                verbose = 1  # Default to batch-level logging otherwise.
        elif verbose == 1 and self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
            raise ValueError(
                '`verbose=1` is not allowed with `ParameterServerStrategy` for '
                f'performance reasons. Received: `verbose`={verbose}')
        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight), validation_split=validation_split))
        if validation_data:
            val_x, val_y, val_sample_weight = (
                data_adapter.unpack_x_y_sample_weight(validation_data))
        if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
            self._cluster_coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
                self.distribute_strategy)
        with self.distribute_strategy.scope(), \
                training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = data_adapter.get_data_handler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self,
                steps_per_execution=self._steps_per_execution)
            # Exposed so `train_step` can inspect the underlying adapter.
            self.data_handler = data_handler
            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=verbose != 0,
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)
            self.stop_training = False
            self.train_function = self.make_train_function()
            self._train_counter.assign(0)
            callbacks.on_train_begin()
            training_logs = None
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (  # pylint: disable=protected-access
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            logs = None

            # Debug: dump the DataHandler's state before training starts.
            print(vars(data_handler))

            for epoch, iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                with data_handler.catch_stop_iteration():
                    for step in data_handler.steps():
                        with tf.profiler.experimental.Trace(
                                'train',
                                epoch_num=epoch,
                                step_num=step,
                                batch_size=batch_size,
                                _r=1):
                            callbacks.on_train_batch_begin(step)
                            # Debug: show which iterator object feeds this step.
                            print(f'iterator: {iterator}')
                            tmp_logs = self.train_function(iterator)
                            if data_handler.should_sync:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            end_step = step + data_handler.step_increment
                            callbacks.on_train_batch_end(end_step, logs)
                            if self.stop_training:
                                break

                logs = tf_utils.sync_to_numpy_or_python_type(logs)
                if logs is None:
                    raise ValueError('Unexpected result of `train_function` '
                                     '(Empty logs). Please use '
                                     '`Model.compile(..., run_eagerly=True)`, or '
                                     '`tf.config.run_functions_eagerly(True)` for more '
                                     'information of where went wrong, or file a '
                                     'issue/bug to `tf.keras`.')
                epoch_logs = copy.copy(logs)
                # Run validation.
                if validation_data and self._should_eval(epoch, validation_freq):
                    # Create data_handler for evaluation and cache it.
                    if getattr(self, '_eval_data_handler', None) is None:
                        self._eval_data_handler = data_adapter.get_data_handler(
                            x=val_x,
                            y=val_y,
                            sample_weight=val_sample_weight,
                            batch_size=validation_batch_size or batch_size,
                            steps_per_epoch=validation_steps,
                            initial_epoch=0,
                            epochs=1,
                            max_queue_size=max_queue_size,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            model=self,
                            steps_per_execution=self._steps_per_execution)
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True,
                        _use_cached_eval_dataset=True)
                    val_logs = {'val_' + name: val for name, val in val_logs.items()}
                    epoch_logs.update(val_logs)
                callbacks.on_epoch_end(epoch, epoch_logs)
                training_logs = epoch_logs
                if self.stop_training:
                    break
            # If eval data_handler exists, delete it after all epochs are done.
            if getattr(self, '_eval_data_handler', None) is not None:
                del self._eval_data_handler
            callbacks.on_train_end(logs=training_logs)
            return self.history

    def train_step(self, data):
        """Run one optimisation step on a single ``(x, y)`` batch.

        Args:
            data: The batch produced by the dataset; unpacked as an ``(x, y)``
                pair (sample weights are not supported by this step).

        Returns:
            ``dict`` with the running ``"loss"`` and ``"mae"`` results.
        """
        # Debug: show the wrapped Keras Sequence's attributes (e.g. `epoch_counter`).
        # `data_handler` is only set by this class's `fit`, and only the Sequence
        # adapter has `_keras_sequence`, so look both up defensively instead of
        # raising AttributeError for NumPy/tf.data inputs or other entry points.
        handler = getattr(self, 'data_handler', None)
        adapter = getattr(handler, '_adapter', None)
        sequence = getattr(adapter, '_keras_sequence', None)
        if sequence is not None:
            print(vars(sequence))
        x, y = data
        print(f'data: {type(data)}, X: {type(x)} {x.shape}, y: {type(y)} {y.shape}')

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss
            loss = keras.losses.mean_squared_error(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute our own metrics
        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [loss_tracker, mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(14,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here: `train_step` computes them itself.
model.compile(optimizer="adam")

# Stand-alone NumPy data whose feature width matches the model input (14).
# Not used by the `fit` call below, which trains on `batch_generator`.
x = np.random.random((1000, 14))
y = np.random.random((1000, 1))

# Just use `fit` as usual -- you can use callbacks, etc.
model.fit(batch_generator, epochs=5)

{'_initial_epoch': 0, '_epochs': 5, '_insufficient_data': False, '_model': <__main__.CustomModel object at 0x000001EEBCEF8100>, '_steps_per_execution': <tf.Variable 'Variable:0' shape=() dtype=int64, numpy=1>, '_adapter': <keras.engine.data_adapter.KerasSequenceAdapter object at 0x000001EEBCED1E20>, '_current_step': 0, '_step_increment': 0, '_inferred_steps': 20, '_dataset': <PrefetchDataset element_spec=(TensorSpec(shape=(None, None), dtype=tf.float32, name=None), TensorSpec(shape=(None, None), dtype=tf.float32, name=None))>}
Epoch 1/5
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
{'X': array([[[0.34819391, 0.18740289, 0.31448158, ..., 0.34214086,
         0.30489974, 0.34475161],
        [0.64632784, 0.40692046, 0.73995774, ..., 0.94228406,
         0.05641092, 0.44217661],
        [0.62172679, 0.54165388, 0.10476333, ..., 0.54237517,
         0.93648416, 0.77627781],
        ...,
        [0.20314842, 0.0586868 , 0.73262245, ..., 0.857

 1/20 [>.............................] - ETA: 5s - loss: 0.4883 - mae: 0.5946iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBDDCEB0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator obj

iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBE666A0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBBE666A0>
[BG] updated self.epoch_counter: 4
Epoch 5/5
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBCF661C0>
 1/20 [>.............................] - ETA: 0s - loss: 0.1507 - mae: 0.3270iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBCF661C0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBCF661C0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBCF661C0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBCF661C0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBCF661C0>
iterator: <tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x000001EEBCF661C0>
iterator: <tensorflow.p

<keras.callbacks.History at 0x1eebced1250>