Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DET-2835] feat: support tf.data.Dataset for all TF versions #98

Merged
merged 1 commit into from
Apr 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion harness/determined/estimator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from determined.estimator._util import (
_cleanup_after_train_step,
_cleanup_after_validation_step,
_delete_input_pipeline_checkpoints,
_update_checkpoint_path_in_state_file,
_scan_checkpoint_directory,
)
Expand Down
11 changes: 0 additions & 11 deletions harness/determined/estimator/_estimator_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,16 +449,6 @@ def _init_model(self) -> None:

all_hooks = [*self.user_train_spec.hooks]

# The following skip_checkpointing_input flag is a workaround for
# stateful datasets in TF 1.14. Stateful input pipeline functions
# cannot be serialized and therefore checkpointing them should be
# skipped.
if (
not self.env.experiment_config.get("data", {}).get("skip_checkpointing_input", False)
and not self.env.experiment_config.input_from_dataflow()
):
all_hooks.append(tf.data.experimental.CheckpointInputPipelineHook(self.estimator))

if self.hvd_config.use:
all_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

Expand All @@ -478,7 +468,6 @@ def _init_model(self) -> None:
# In the short term, behave like other trials and reset input
# state if we are warm started. This will create an inconsistency
# wrt saved optimizer state.
estimator._delete_input_pipeline_checkpoints(str(self.estimator_dir))

self.train_spec = tf.estimator.TrainSpec(
input_fn=self.user_train_spec.input_fn, hooks=all_hooks
Expand Down
22 changes: 0 additions & 22 deletions harness/determined/estimator/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@
# TODO: DET-175 direct checkpoint file manipulation is deprecated in TF 1.13.


# This is the filename prefix used by `tf.data.estimator.CheckpointInputPipelineHook`
# to write its checkpoints.
_INPUT_CHECKPOINT_PREFIX = "input"


CheckpointInputPipelineHook = tf.data.experimental.CheckpointInputPipelineHook


class Checkpoint:
"""
The metadata about a checkpoint.
Expand All @@ -47,8 +39,6 @@ class Checkpoint:
the "model" checkpoint data, index and metadata for step 0
- checkpoint_input
the "input" checkpoint state
- input.ckpt-0.data-00000-of-00001, input.ckpt-0.index, input.ckpt-0.meta
the "input" checkpoint data, index and metadata for step 0
- graph.pbtxt
protobuf of graph in text form; typically present but not relevant to
`Checkpoint`
Expand Down Expand Up @@ -272,18 +262,6 @@ def delete_all_checkpoints_except_most_recent(model_dir: str) -> None:
)


def _delete_input_pipeline_checkpoints(model_dir: str) -> None:
    """
    Remove the files of every "input" checkpoint found under ``model_dir``.

    The directory is scanned with ``_scan_checkpoint_directory``; only
    checkpoints whose name equals ``_INPUT_CHECKPOINT_PREFIX`` are touched.
    For each checkpoint path listed in such a checkpoint's state, all files
    registered under that path's basename are deleted from disk.

    Arguments:
        model_dir:
            Directory to scan for checkpoints.
    """
    for ckpt in _scan_checkpoint_directory(model_dir):
        if ckpt.name == _INPUT_CHECKPOINT_PREFIX:
            for ckpt_path in ckpt.state.all_model_checkpoint_paths:
                # Files are indexed by the basename of the checkpoint path.
                for state_file in ckpt.paths[os.path.basename(ckpt_path)]:
                    logging.debug("Deleting input state file %s", state_file)
                    os.remove(state_file)


def load_global_step_from_checkpoint(checkpoint_dir: str) -> Optional[tf.Tensor]:
checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint is None:
Expand Down
1 change: 0 additions & 1 deletion harness/determined/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
_ArrayLikeAdapter,
_SequenceWithOffset,
_SequenceAdapter,
_TFDatasetAdapter,
InputData,
adapt_keras_data,
adapt_validation_data,
Expand Down
100 changes: 8 additions & 92 deletions harness/determined/keras/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,11 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import tensorflow
from packaging import version
import tensorflow as tf

import determined as det
from determined_common import check

# Handle TensorFlow compatibility issues.
if version.parse(tensorflow.__version__) >= version.parse("1.14.0"):
import tensorflow.compat.v1 as tf
else:
import tensorflow as tf

ArrayLike = Union[np.ndarray, List[np.ndarray], Dict[str, np.ndarray]]


Expand Down Expand Up @@ -219,82 +212,7 @@ def stop(self, timeout: Optional[int] = None) -> None:
self._enqueuer.stop(timeout=timeout)


class _TFDatasetAdapter:
    """
    A class to assist with restoring and saving iterators for a dataset.

    It wraps a ``tf.data.Dataset`` and can build a one-shot iterator over it,
    save an iterator's state to a checkpoint, and restore that state later.
    """

    def __init__(self, dataset: tf.data.Dataset, prefetch_buffer: int = 1) -> None:
        """
        Arguments:
            dataset:
                The dataset that iterators will be created from.
            prefetch_buffer:
                Buffer size handed to ``prefetch()`` when building an
                iterator. A value of 0 or less disables prefetching.
        """
        self.dataset = dataset
        self.prefetch_buffer = prefetch_buffer

    def get_iterator(self, repeat: bool = False) -> tf.data.Iterator:
        """
        Return a one-shot tf.data.Iterator over the wrapped dataset.

        The stored dataset itself is not modified; the optional repeat() and
        prefetch() transformations are applied to a derived dataset.

        Arguments:
            repeat:
                Indicate if dataset should be pre-transformed with a repeat().
        """
        pipeline = self.dataset
        if repeat:
            # Having an extra repeat should be ok, so we don't need to check if
            # the dataset already has one.
            pipeline = pipeline.repeat()

        if self.prefetch_buffer > 0:
            pipeline = pipeline.prefetch(self.prefetch_buffer)

        return pipeline.make_one_shot_iterator()

    def save_iterator(
        self, iterator: tf.data.Iterator, save_path: str, save_session: tf.Session
    ) -> None:
        """
        Save an iterator to a checkpoint.

        Arguments:
            iterator:
                The iterator to be saved.
            save_path:
                The path the iterator checkpoint is written to.
            save_session:
                The TensorFlow session which should be used for saving the
                iterator to a checkpoint.
        """
        saveable = tf.data.experimental.make_saveable_from_iterator(iterator)
        # The "iterator" key must match the one used in restore_iterator().
        saver = tf.train.Saver({"iterator": saveable})
        saver.save(save_session, save_path)

    def restore_iterator(
        self,
        iterator: tf.data.Iterator,
        restore_path: str,
        restore_session: tf.Session,
        run_options: Optional[tf.RunOptions] = None,
    ) -> tf.data.Iterator:
        """
        Restore an iterator from a checkpoint.

        Arguments:
            iterator:
                The iterator to be restored.
            restore_path:
                The path to a checkpoint used for restoring an iterator.
            restore_session:
                The TensorFlow session which should be used for restoring an
                iterator from a checkpoint.
            run_options:
                The tf.RunOptions to pass to the tf.Session during
                tf.Saver.restore().

        Returns:
            The same ``iterator`` object that was passed in, after its state
            has been restored.
        """
        saveable = tf.data.experimental.make_saveable_from_iterator(iterator)
        # The "iterator" key must match the one used in save_iterator().
        restorer = tf.train.Saver({"iterator": saveable})
        restorer.restore(restore_session, restore_path, options=run_options)
        return iterator


InputData = Union[tf.keras.utils.Sequence, tf.data.Dataset, _TFDatasetAdapter, _SequenceAdapter]
InputData = Union[tf.keras.utils.Sequence, tf.data.Dataset, _SequenceAdapter]


def adapt_keras_data(
Expand All @@ -307,8 +225,8 @@ def adapt_keras_data(
max_queue_size: int = 10,
drop_leftovers: bool = False,
) -> InputData:
"""adapt_keras_data adapts input and target data to a _SequenceAdapter or a
_TFDatasetAdapter, both of which are designed to support random access efficiently,
"""adapt_keras_data adapts input and target data to a _SequenceAdapter or leaves
it as a tf.data.Dataset, both of which are designed to support random access efficiently,
for the purpose of supporting a Determined-managed training loop.

Multiprocessing or multithreading for native Python generators is not supported.
Expand Down Expand Up @@ -380,9 +298,9 @@ def check_y_is_none(y: Any) -> None:

elif isinstance(x, tf.data.Dataset):
check_y_is_none(y)
return _TFDatasetAdapter(x)
return x

elif isinstance(x, (_SequenceAdapter, _TFDatasetAdapter)):
elif isinstance(x, _SequenceAdapter):
check_y_is_none(y)
return x

Expand All @@ -395,9 +313,7 @@ def check_y_is_none(y: Any) -> None:
)


ValidationData = Union[
tuple, tf.keras.utils.Sequence, tf.data.Dataset, _TFDatasetAdapter, _SequenceAdapter
]
ValidationData = Union[tuple, tf.keras.utils.Sequence, tf.data.Dataset, _SequenceAdapter]


def adapt_validation_data(
Expand All @@ -407,7 +323,7 @@ def adapt_validation_data(
workers: int = 1,
) -> InputData:
"""adapt_validation_data adapts inputs and targets of validation data to
a _SequenceAdapter or _TFDatasetAdapter, both of which are designed to
a _SequenceAdapter or leaves it as a tf.data.Dataset, both of which are designed to
support random access efficiently, for the purpose of supporting Determined
managed training loop.

Expand Down
76 changes: 11 additions & 65 deletions harness/determined/keras/_tf_keras_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,29 +387,14 @@ def supports_multi_gpu_training() -> bool:
return True

def set_data_loaders(self, train_config: keras.TFKerasTrainConfig) -> None:
if isinstance(train_config.training_data, keras._TFDatasetAdapter):
if isinstance(train_config.training_data, tf.data.Dataset):
self.is_tf_dataset = True

# TensorFlow 1.15.0 does not allow running model.fit() with a tf.data.Iterator. We use
# model.fit() with a tf.data.Iterator because creating the Iterator ourselves is the
# only way to access the saveable state of the iterator.
check.lt(
version.parse(tf.__version__),
version.parse("1.15.0"),
"TFKerasTrial does not accept tf.data.Dataset objects for training data with "
"TensorFlow 1.15.0 or higher, due to breaking changes in tf.keras. Please "
"downgrade TensorFlow or use a different dataset.",
)
else:
self.is_tf_dataset = False
Comment on lines 391 to 393
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need self.is_tf_dataset? Would be nice to get rid of it and just use isinstance() in cases where we need to treat it differently. As far as I can tell in this PR, it seems like only steps_per_epoch needs to be treated differently?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just leaving it in for now since I already have a separate (blocked) PR that cleans all of this up: #6


if self.is_tf_dataset:
self.training_tf_data_adapter = cast(
keras._TFDatasetAdapter, train_config.training_data
)
self.validation_tf_dataset = cast(
keras._TFDatasetAdapter, train_config.validation_data
).dataset
self.training_tf_dataset = train_config.training_data
self.validation_tf_dataset = train_config.validation_data
else:
self.training_keras_data_adapter = cast(
keras._SequenceAdapter, train_config.training_data
Expand All @@ -418,21 +403,6 @@ def set_data_loaders(self, train_config: keras.TFKerasTrainConfig) -> None:
keras._SequenceAdapter, train_config.validation_data
)

def _initialize_tf_dataset_iterators(self) -> None:
"""
Initialize training iterator for tensorflow dataset. It can be already
initialized if we are resuming from the checkpoint.
For distributed training, we don't need to do offset calculation for
sharding the data like how we do in _initialize_keras_data_iterators.
We will use dataset's shard api instead.

Note: We are using validation dataset instead of dataset iterator in
evaluate so not creating validation iterator here.
"""
# self.training_iterator may be not None if we restored it from a checkpoint.
if self.training_iterator is None:
self.training_iterator = self.training_tf_data_adapter.get_iterator(repeat=True)

def _initialize_keras_data_iterators(self) -> None:
"""
Initialize training and validation iterator for keras sequence or
Expand Down Expand Up @@ -465,26 +435,13 @@ def _initialize_iterators(self) -> None:
Initialize training and validation iterators, the training iterator
remains initialized throughout the lifetime of this process.
"""
if self.is_tf_dataset:
self._initialize_tf_dataset_iterators()
else:
if not self.is_tf_dataset:
self._initialize_keras_data_iterators()

def _load(self) -> None:
if not self.load_path:
return

# load training keras
if self.is_tf_dataset:
# TODO (DET-1792): Currently, in distributed training, this will
# load the checkpoint of process 0 on all processes during restart.
# This needs to be fixed. More details in the ticket.
iterator_path = self.load_path.joinpath("iterator_state.pkl")
self.training_iterator = self.training_tf_data_adapter.get_iterator(repeat=True)
self.training_iterator = self.training_tf_data_adapter.restore_iterator(
self.training_iterator, str(iterator_path), self.session
)

# load model
full_ckpt_path = self.load_path.joinpath("determined-keras-model")
logging.info(f"Restoring checkpoint from {full_ckpt_path}")
Expand All @@ -500,11 +457,6 @@ def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:

# save training data iterator position.
path.mkdir(parents=True, exist_ok=True)
if self.is_tf_dataset:
iterator_path = path.joinpath("iterator_state.pkl")
self.training_tf_data_adapter.save_iterator(
self.training_iterator, str(iterator_path), self.session
)

# save model weights
tf.keras.models.save_model(
Expand Down Expand Up @@ -560,7 +512,6 @@ def run(self) -> None:

def _launch_fit(self, initial_epoch: int) -> None:
check.false(self.fit_loop_started)
check.is_not_none(self.training_iterator)
self.fit_loop_started = True

self.tf_keras_callbacks.append(WaitForInstructionsCallback(self))
Expand All @@ -578,12 +529,14 @@ def _launch_fit(self, initial_epoch: int) -> None:

# Tensorflow dataset doesn't provide length api so use the configured batches_per_step
if self.is_tf_dataset:
training_input = self.training_tf_dataset
steps_per_epoch = self.batches_per_step
else:
training_input = self.training_iterator
steps_per_epoch = len(self.training_keras_data_adapter)

_ = self.model.fit(
self.training_iterator,
training_input,
callbacks=self.tf_keras_callbacks,
shuffle=False,
steps_per_epoch=steps_per_epoch,
Expand Down Expand Up @@ -693,11 +646,10 @@ def build_training_data_loader(self) -> keras.InputData:
interface, or a `tf.data.Dataset
<https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/data/Dataset>`__.

WARNING: If you are using ``tf.data.Dataset`` with
distributed training, Determined’s support for automatically checkpointing and
resuming workloads does not work correctly. Therefore, using
tf.data.Dataset inputs with distributed training is currently not
recommended.
WARNING: If you are using ``tf.data.Dataset``, Determined’s support for
automatically checkpointing the dataset does not currently work correctly.
This means that resuming workloads will start from the beginning of the dataset
if using ``tf.data.Dataset``.
"""
pass

Expand All @@ -710,12 +662,6 @@ def build_validation_data_loader(self) -> keras.InputData:
<https://tensorflow.org/api_docs/python/tf/keras/utils/Sequence>`__
interface, or a `tf.data.Dataset
<https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/data/Dataset>`__.

WARNING: If you are using ``tf.data.Dataset`` with
distributed training, Determined’s support for automatically checkpointing and
resuming workloads does not work correctly. Therefore, using
tf.data.Dataset inputs with distributed training is currently not
recommended.
"""
pass

Expand Down