feat: support tf.data.Dataset for all TF versions
Do not checkpoint tf.data.Datasets. Resuming training from the
correct place for tf.data.Dataset inputs will be supported once the
data layer is enabled.

DET-2835  #Done.
aaron276h committed Apr 14, 2020
1 parent 1f0552c commit 5e14162
Showing 8 changed files with 20 additions and 271 deletions.
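
To illustrate what this change enables from the user's side, here is a minimal, hypothetical sketch of a `TFKerasTrial` (the class referenced in `_tf_keras_trial.py` below) returning plain `tf.data.Dataset` objects from its data loaders on any supported TF version. The trial name, tensors, and elided methods (e.g. `build_model`) are illustrative only; note that the dataset position is not checkpointed, so a resumed trial re-reads the dataset from the beginning.

```python
# Illustrative sketch, not part of this commit. Assumes the TFKerasTrial base
# class and keras.InputData export referenced in the diff below; build_model
# and other required trial methods are elided.
import tensorflow as tf

from determined import keras


class ExampleTrial(keras.TFKerasTrial):  # hypothetical trial subclass
    def build_training_data_loader(self) -> keras.InputData:
        # Any tf.data.Dataset is accepted, regardless of TF version. Its
        # position is NOT checkpointed: a resumed workload starts reading
        # from the beginning of the dataset.
        xs = tf.random.uniform((1024, 784))
        ys = tf.random.uniform((1024,), maxval=10, dtype=tf.int32)
        return tf.data.Dataset.from_tensor_slices((xs, ys)).batch(32)

    def build_validation_data_loader(self) -> keras.InputData:
        xs = tf.random.uniform((128, 784))
        ys = tf.random.uniform((128,), maxval=10, dtype=tf.int32)
        return tf.data.Dataset.from_tensor_slices((xs, ys)).batch(32)
```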
1 change: 0 additions & 1 deletion harness/determined/estimator/__init__.py
@@ -7,7 +7,6 @@
from determined.estimator._util import (
_cleanup_after_train_step,
_cleanup_after_validation_step,
_delete_input_pipeline_checkpoints,
_update_checkpoint_path_in_state_file,
_scan_checkpoint_directory,
)
11 changes: 0 additions & 11 deletions harness/determined/estimator/_estimator_trial.py
@@ -449,16 +449,6 @@ def _init_model(self) -> None:

all_hooks = [*self.user_train_spec.hooks]

# The following skip_checkpointing_input flag is a workaround for
# stateful datasets in TF 1.14. Stateful input pipeline functions
# cannot be serialized and therefore checkpointing them should be
# skipped.
if (
not self.env.experiment_config.get("data", {}).get("skip_checkpointing_input", False)
and not self.env.experiment_config.input_from_dataflow()
):
all_hooks.append(tf.data.experimental.CheckpointInputPipelineHook(self.estimator))

if self.hvd_config.use:
all_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

@@ -478,7 +468,6 @@ def _init_model(self) -> None:
# In the short term, behave like other trials and reset input
# state if we are warm started. This will create an inconsistency
# wrt saved optimizer state.
estimator._delete_input_pipeline_checkpoints(str(self.estimator_dir))

self.train_spec = tf.estimator.TrainSpec(
input_fn=self.user_train_spec.input_fn, hooks=all_hooks
22 changes: 0 additions & 22 deletions harness/determined/estimator/_util.py
@@ -14,14 +14,6 @@
# TODO: DET-175 direct checkpoint file manipulation is deprecated in TF 1.13.


# This is the filename prefix used by `tf.data.estimator.CheckpointInputPipelineHook`
# to write its checkpoints.
_INPUT_CHECKPOINT_PREFIX = "input"


CheckpointInputPipelineHook = tf.data.experimental.CheckpointInputPipelineHook


class Checkpoint:
"""
The metadata about a checkpoint.
@@ -47,8 +39,6 @@ class Checkpoint:
the "model" checkpoint data, index and metadata for step 0
- checkpoint_input
the "input" checkpoint state
- input.ckpt-0.data-00000-of-00001, input.ckpt-0.index, input.ckpt-0.meta
the "input" checkpoint data, index and metadata for step 0
- graph.pbtxt
protobuf of graph in text form; typically present but not relevant to
`Checkpoint`
@@ -272,18 +262,6 @@ def delete_all_checkpoints_except_most_recent(model_dir: str) -> None:
)


def _delete_input_pipeline_checkpoints(model_dir: str) -> None:
for checkpoint in _scan_checkpoint_directory(model_dir):
if checkpoint.name != _INPUT_CHECKPOINT_PREFIX:
continue

for path in checkpoint.state.all_model_checkpoint_paths:
basename = os.path.basename(path)
for p in checkpoint.paths[basename]:
logging.debug("Deleting input state file %s", p)
os.remove(p)


def load_global_step_from_checkpoint(checkpoint_dir: str) -> Optional[tf.Tensor]:
checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint is None:
1 change: 0 additions & 1 deletion harness/determined/keras/__init__.py
@@ -3,7 +3,6 @@
_ArrayLikeAdapter,
_SequenceWithOffset,
_SequenceAdapter,
_TFDatasetAdapter,
InputData,
adapt_keras_data,
adapt_validation_data,
100 changes: 8 additions & 92 deletions harness/determined/keras/_data.py
@@ -2,18 +2,11 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import tensorflow
from packaging import version
import tensorflow as tf

import determined as det
from determined_common import check

# Handle TensorFlow compatibility issues.
if version.parse(tensorflow.__version__) >= version.parse("1.14.0"):
import tensorflow.compat.v1 as tf
else:
import tensorflow as tf

ArrayLike = Union[np.ndarray, List[np.ndarray], Dict[str, np.ndarray]]


@@ -219,82 +212,7 @@ def stop(self, timeout: Optional[int] = None) -> None:
self._enqueuer.stop(timeout=timeout)


class _TFDatasetAdapter:
"""
A class to assist with restoring and saving iterators for a dataset.
"""

def __init__(self, dataset: tf.data.Dataset, prefetch_buffer: int = 1) -> None:
self.dataset = dataset
self.prefetch_buffer = prefetch_buffer

def get_iterator(self, repeat: bool = False) -> tf.data.Iterator:
"""
Return a tf.data.Iterator
Arguments:
repeat:
Indicate if dataset should be pre-transformed with a repeat().
"""
temp = self.dataset
if repeat:
# Having an extra repeat should be ok, so we don't need to check if
# the dataset already has one.
temp = temp.repeat()

if self.prefetch_buffer > 0:
temp = temp.prefetch(self.prefetch_buffer)

return temp.make_one_shot_iterator()

def save_iterator(
self, iterator: tf.data.Iterator, save_path: str, save_session: tf.Session
) -> None:
"""
Save an iterator to a checkpoint.
Arguments:
iterator:
The iterator to be saved.
save_path:
The path to a checkpoint used for restoring an iterator.
save_session:
The TensorFlow session which should be used for restoring an
iterator from a checkpoint.
"""
saveable = tf.data.experimental.make_saveable_from_iterator(iterator)
saver = tf.train.Saver({"iterator": saveable})
saver.save(save_session, save_path)

def restore_iterator(
self,
iterator: tf.data.Iterator,
restore_path: str,
restore_session: tf.Session,
run_options: tf.RunOptions = None,
) -> tf.data.Iterator:
"""
Restore an iterator from a checkpoint.
Arguments:
iterator:
The iterator to be restored.
restore_path:
The path to a checkpoint used for restoring an iterator.
restore_session:
The TensorFlow session which should be used for restoring an
iterator from a checkpoint.
run_options:
The tf.RunOptions to pass to the tf.Session during
tf.Saver.restore().
"""
saveable = tf.data.experimental.make_saveable_from_iterator(iterator)
restorer = tf.train.Saver({"iterator": saveable})
restorer.restore(restore_session, restore_path, options=run_options)
return iterator


InputData = Union[tf.keras.utils.Sequence, tf.data.Dataset, _TFDatasetAdapter, _SequenceAdapter]
InputData = Union[tf.keras.utils.Sequence, tf.data.Dataset, _SequenceAdapter]


def adapt_keras_data(
@@ -307,8 +225,8 @@ def adapt_keras_data(
max_queue_size: int = 10,
drop_leftovers: bool = False,
) -> InputData:
"""adapt_keras_data adapts input and target data to a _SequenceAdapter or a
_TFDatasetAdapter, both of which are designed to support random access efficiently,
"""adapt_keras_data adapts input and target data to a _SequenceAdapter or leaves
it as a tf.data.Dataset, both of which are designed to support random access efficiently,
for the purpose of supporting a Determined-managed training loop.
Multiprocessing or multithreading for native Python generators is not supported.
@@ -380,9 +298,9 @@ def check_y_is_none(y: Any) -> None:

elif isinstance(x, tf.data.Dataset):
check_y_is_none(y)
return _TFDatasetAdapter(x)
return x

elif isinstance(x, (_SequenceAdapter, _TFDatasetAdapter)):
elif isinstance(x, _SequenceAdapter):
check_y_is_none(y)
return x

@@ -395,9 +313,7 @@ def check_y_is_none(y: Any) -> None:
)


ValidationData = Union[
tuple, tf.keras.utils.Sequence, tf.data.Dataset, _TFDatasetAdapter, _SequenceAdapter
]
ValidationData = Union[tuple, tf.keras.utils.Sequence, tf.data.Dataset, _SequenceAdapter]


def adapt_validation_data(
Expand All @@ -407,7 +323,7 @@ def adapt_validation_data(
workers: int = 1,
) -> InputData:
"""adapt_validation_data adapts inputs and targets of validation data to
a _SequenceAdapter or _TFDatasetAdapter, both of which are designed to
a _SequenceAdapter or leaves it as a tf.data.Dataset, both of which are designed to
support random access efficiently, for the purpose of supporting Determined
managed training loop.
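
For context on the removal above: `_TFDatasetAdapter.save_iterator`/`restore_iterator` depended on the TF 1.x graph-mode pattern sketched below (a `tf.Session` plus a `tf.train.Saver` over `make_saveable_from_iterator`), which does not carry over to TF 2.x eager execution. This standalone sketch uses the `tf.compat.v1` shim and is shown only for reference; it is not part of the new code path.

```python
# Sketch of the TF 1.x iterator-checkpointing pattern the removed
# _TFDatasetAdapter relied on, run here under the tf.compat.v1 shim.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

dataset = tf.data.Dataset.range(10).repeat()
iterator = tf.data.make_one_shot_iterator(dataset)
saveable = tf.data.experimental.make_saveable_from_iterator(iterator)
saver = tf.train.Saver({"iterator": saveable})

with tf.Session() as sess:
    sess.run(iterator.get_next())               # advance past the first element
    saver.save(sess, "/tmp/input_iterator")     # persist the iterator position
    saver.restore(sess, "/tmp/input_iterator")  # later: restore that position
```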
76 changes: 11 additions & 65 deletions harness/determined/keras/_tf_keras_trial.py
@@ -387,29 +387,14 @@ def supports_multi_gpu_training() -> bool:
return True

def set_data_loaders(self, train_config: keras.TFKerasTrainConfig) -> None:
if isinstance(train_config.training_data, keras._TFDatasetAdapter):
if isinstance(train_config.training_data, tf.data.Dataset):
self.is_tf_dataset = True

# TensorFlow 1.15.0 does not allow running model.fit() with a tf.data.Iterator. We use
# model.fit() with a tf.data.Iterator because creating the Iterator ourselves is the
# only way to access the saveable state of the iterator.
check.lt(
version.parse(tf.__version__),
version.parse("1.15.0"),
"TFKerasTrial does not accept tf.data.Dataset objects for training data with "
"TensorFlow 1.15.0 or higher, due to breaking changes in tf.keras. Please "
"downgrade TensorFlow or use a different dataset.",
)
else:
self.is_tf_dataset = False

if self.is_tf_dataset:
self.training_tf_data_adapter = cast(
keras._TFDatasetAdapter, train_config.training_data
)
self.validation_tf_dataset = cast(
keras._TFDatasetAdapter, train_config.validation_data
).dataset
self.training_tf_dataset = train_config.training_data
self.validation_tf_dataset = train_config.validation_data
else:
self.training_keras_data_adapter = cast(
keras._SequenceAdapter, train_config.training_data
@@ -418,21 +403,6 @@ def set_data_loaders(self, train_config: keras.TFKerasTrainConfig) -> None:
keras._SequenceAdapter, train_config.validation_data
)

def _initialize_tf_dataset_iterators(self) -> None:
"""
Initialize training iterator for tensorflow dataset. It can be already
initialized if we are resuming from the checkpoint.
For distributed training, we don't need to do offset calculation for
sharding the data like how we do in _initialize_keras_data_iterators.
We will use dataset's shard api instead.
Note: We are using validation dataset instead of dataset iterator in
evaluate so not creating validation iterator here.
"""
# self.training_iterator may be not None if we restored it from a checkpoint.
if self.training_iterator is None:
self.training_iterator = self.training_tf_data_adapter.get_iterator(repeat=True)

def _initialize_keras_data_iterators(self) -> None:
"""
Initialize training and validation iterator for keras sequence or
@@ -465,26 +435,13 @@ def _initialize_iterators(self) -> None:
Initialize training and validation iterators, the training iterator
remains initialized throughout the lifetime of this process.
"""
if self.is_tf_dataset:
self._initialize_tf_dataset_iterators()
else:
if not self.is_tf_dataset:
self._initialize_keras_data_iterators()

def _load(self) -> None:
if not self.load_path:
return

# load training keras
if self.is_tf_dataset:
# TODO (DET-1792): Currently, in distributed training, this will
# load the checkpoint of process 0 on all processes during restart.
# This needs to be fixed. More details in the ticket.
iterator_path = self.load_path.joinpath("iterator_state.pkl")
self.training_iterator = self.training_tf_data_adapter.get_iterator(repeat=True)
self.training_iterator = self.training_tf_data_adapter.restore_iterator(
self.training_iterator, str(iterator_path), self.session
)

# load model
full_ckpt_path = self.load_path.joinpath("determined-keras-model")
logging.info(f"Restoring checkpoint from {full_ckpt_path}")
@@ -500,11 +457,6 @@ def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:

# save training data iterator position.
path.mkdir(parents=True, exist_ok=True)
if self.is_tf_dataset:
iterator_path = path.joinpath("iterator_state.pkl")
self.training_tf_data_adapter.save_iterator(
self.training_iterator, str(iterator_path), self.session
)

# save model weights
tf.keras.models.save_model(
@@ -560,7 +512,6 @@ def run(self) -> None:

def _launch_fit(self, initial_epoch: int) -> None:
check.false(self.fit_loop_started)
check.is_not_none(self.training_iterator)
self.fit_loop_started = True

self.tf_keras_callbacks.append(WaitForInstructionsCallback(self))
@@ -578,12 +529,14 @@ def _launch_fit(self, initial_epoch: int) -> None:

# Tensorflow dataset doesn't provide length api so use the configured batches_per_step
if self.is_tf_dataset:
training_input = self.training_tf_dataset
steps_per_epoch = self.batches_per_step
else:
training_input = self.training_iterator
steps_per_epoch = len(self.training_keras_data_adapter)

_ = self.model.fit(
self.training_iterator,
training_input,
callbacks=self.tf_keras_callbacks,
shuffle=False,
steps_per_epoch=steps_per_epoch,
@@ -693,11 +646,10 @@ def build_training_data_loader(self) -> keras.InputData:
interface, or a `tf.data.Dataset
<https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/data/Dataset>`__.
WARNING: If you are using ``tf.data.Dataset`` with
distributed training, Determined’s support for automatically checkpointing and
resuming workloads does not work correctly. Therefore, using
tf.data.Dataset inputs with distributed training is currently not
recommended.
WARNING: If you are using ``tf.data.Dataset``, Determined’s support for
automatically checkpointing the dataset does not currently work correctly.
This means that resuming workloads will start from the beginning of the dataset
if using `tf.data.Dataset`.
"""
pass

@@ -710,12 +662,6 @@ def build_validation_data_loader(self) -> keras.InputData:
<https://tensorflow.org/api_docs/python/tf/keras/utils/Sequence>`__
interface, or a `tf.data.Dataset
<https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/data/Dataset>`__.
WARNING: If you are using ``tf.data.Dataset`` with
distributed training, Determined’s support for automatically checkpointing and
resuming workloads does not work correctly. Therefore, using
tf.data.Dataset inputs with distributed training is currently not
recommended.
"""
pass

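
As a final illustration of the `_launch_fit` change, a plain `tf.keras` sketch (independent of Determined) of what the training loop now effectively does: the `tf.data.Dataset` is passed to `model.fit()` directly, and because datasets expose no length API, the number of steps per epoch must be supplied explicitly, mirroring the `batches_per_step` logic above. All names and shapes here are illustrative.

```python
# Plain tf.keras illustration: feed a tf.data.Dataset to model.fit() directly,
# with an explicit steps_per_epoch since datasets have no length API.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")

dataset = (
    tf.data.Dataset.from_tensor_slices(
        (tf.random.uniform((64, 4)), tf.random.uniform((64, 1)))
    )
    .repeat()  # repeat so steps_per_epoch, not dataset size, bounds the epoch
    .batch(8)
)

model.fit(dataset, steps_per_epoch=8, epochs=1, shuffle=False)
```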
