From 68f3568eb983dc7e27528e074aa5d64551718fee Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 30 Sep 2020 15:53:09 -0600 Subject: [PATCH] docs: add a tf.layers-in-Estimator example (#1383) We sometimes encounter folks looking to run low-level TF code in our platform. We don't support that directly anymore; TF recommends wrapping low-level graphs in higher-level APIs. The mnist_tp_to_estimator examples used to show this use case, but since we removed Tensorpack entirely, we needed a new example. --- docs/examples.txt | 4 + docs/faq.txt | 2 +- examples/README.md | 2 +- .../computer_vision/mnist_tf_layers/README.md | 33 +++++ .../mnist_tf_layers/const.yaml | 13 ++ .../mnist_tf_layers/model_def.py | 140 ++++++++++++++++++ .../mnist_tf_layers/startup-hook.sh | 4 + examples/tests/requirements.txt | 1 + examples/tests/test_official.py | 5 + 9 files changed, 202 insertions(+), 2 deletions(-) create mode 100644 examples/computer_vision/mnist_tf_layers/README.md create mode 100644 examples/computer_vision/mnist_tf_layers/const.yaml create mode 100644 examples/computer_vision/mnist_tf_layers/model_def.py create mode 100644 examples/computer_vision/mnist_tf_layers/startup-hook.sh diff --git a/docs/examples.txt b/docs/examples.txt index 01326a4fc76..3ffaf95bfe9 100644 --- a/docs/examples.txt +++ b/docs/examples.txt @@ -56,6 +56,10 @@ Computer Vision - MNIST - :download:`mnist_estimator.tgz ` + * - TensorFlow (tf.layers via Estimator API) + - MNIST + - :download:`mnist_tf_layers.tgz ` + * - TensorFlow (tf.keras) - Fashion MNIST - :download:`fashion_mnist_tf_keras.tgz ` diff --git a/docs/faq.txt b/docs/faq.txt index 679b4d88cc5..91b019c3d28 100644 --- a/docs/faq.txt +++ b/docs/faq.txt @@ -199,7 +199,7 @@ Determined has support for TensorFlow models that use the :ref:`tf.keras that use the low-level TensorFlow Core APIs, we recommend porting your model to use :ref:`Estimator Trial `. `Example of converting a Tensorflow graph into an Estimator -`_. +`_. ***************** PyTorch Support diff --git a/examples/README.md b/examples/README.md index 93ff5efaac4..0f8063fe617 100644 --- a/examples/README.md +++ b/examples/README.md @@ -2,7 +2,7 @@ | Domain | Example | Dataset | Framework | |:------:|:-----:|:-------:|:------------------:| | Tutorials |

[mnist_pytorch](tutorials/mnist_pytorch)
[fashion_mnist_tf_keras](tutorials/fashion_mnist_tf_keras)

|

MNIST
Fashion MNIST

|

PyTorch
TensorFlow (tf.keras)

| -| Computer Vision |

[cifar10_pytorch](computer_vision/cifar10_pytorch)
[mnist_multi_output_pytorch](computer_vision/mnist_multi_output_pytorch)
[fasterrcnn_coco_pytorch](computer_vision/fasterrcnn_coco_pytorch)
[mnist_estimator](computer_vision/mnist_estimator)
[cifar10_tf_keras](computer_vision/cifar10_tf_keras)
[iris_tf_keras](computer_vision/iris_tf_keras)
[unets_tf_keras](computer_vision/unets_tf_keras)

|

CIFAR-10
MNIST
Penn-Fudan Dataset
MNIST
CIFAR-10
Iris Dataset
Oxford-IIIT Pet Dataset

|

PyTorch
PyTorch
PyTorch
TensorFlow (Estimator API)
TensorFlow (tf.keras)
TensorFlow (tf.keras)
TensorFlow (tf.keras)

| +| Computer Vision |

[cifar10_pytorch](computer_vision/cifar10_pytorch)
[mnist_multi_output_pytorch](computer_vision/mnist_multi_output_pytorch)
[fasterrcnn_coco_pytorch](computer_vision/fasterrcnn_coco_pytorch)
[mnist_estimator](computer_vision/mnist_estimator)
[mnist_tf_layers](computer_vision/mnist_tf_layers)
[cifar10_tf_keras](computer_vision/cifar10_tf_keras)
[iris_tf_keras](computer_vision/iris_tf_keras)
[unets_tf_keras](computer_vision/unets_tf_keras)

|

CIFAR-10
MNIST
Penn-Fudan Dataset
MNIST
MNIST
CIFAR-10
Iris Dataset
Oxford-IIIT Pet Dataset

|

PyTorch
PyTorch
PyTorch
TensorFlow (Estimator API)
TensorFlow (tf.layers via Estimator API)
TensorFlow (tf.keras)
TensorFlow (tf.keras)
TensorFlow (tf.keras)

| | Natural Language Processing (NLP) |

[bert_squad_pytorch](nlp/bert_squad_pytorch)
[bert_glue_pytorch](nlp/bert_glue_pytorch)

|

SQuAD
GLUE

|

PyTorch
PyTorch

| | HP Search Benchmarks |

[darts_cifar10_pytorch](hp_search_benchmarks/darts_cifar10_pytorch)
[darts_penntreebank_pytorch](hp_search_benchmarks/darts_penntreebank_pytorch)

|

CIFAR-10
Penn Treebank Dataset

|

PyTorch
PyTorch

| | Neural Architecture Search (NAS) | [gaea_pytorch](nas/gaea_pytorch) | DARTS | PyTorch | diff --git a/examples/computer_vision/mnist_tf_layers/README.md b/examples/computer_vision/mnist_tf_layers/README.md new file mode 100644 index 00000000000..967ae2c68c5 --- /dev/null +++ b/examples/computer_vision/mnist_tf_layers/README.md @@ -0,0 +1,33 @@ +# Training a TensorFlow Graph in Determined (via Estimator API) + +This example shows how wrap a graph defined in low-level TensorFlow APIs in a +custom Estimator, and then run it in Determined. + +## Files +* **model_def.py**: The core code for the model. This includes code for +defining the model in low-level TensorFlow APIs, as well as for defining the +custom Estimator and the EstimatorTrial. + +* **startup-hook.sh**: Predownload the dataset in the container. This ensures +that the dataset download does not cause conflicts between multiple workers +trying to download to the same directory if you were to reconfigure the +experiment for distributed training. + +### Configuration Files +* **const.yaml**: Train the model with constant hyperparameter values. + +## Data +Estimators require tf.data.Datasets as inputs. This examples uses the +`tensorflow_datasets` MNIST dataset as input. + +## To Run +If you have not yet installed Determined, installation instructions can be found +under `docs/install-admin.html` or at https://docs.determined.ai/latest/index.html + +Run the following command: `det -m experiment create -f +const.yaml .`. The other configurations can be run by specifying the appropriate +configuration file in place of `const.yaml`. + +## Results +Training the model with the hyperparameter settings in `const.yaml` should yield +a validation error of < 2%. diff --git a/examples/computer_vision/mnist_tf_layers/const.yaml b/examples/computer_vision/mnist_tf_layers/const.yaml new file mode 100644 index 00000000000..b0a5ad22b1d --- /dev/null +++ b/examples/computer_vision/mnist_tf_layers/const.yaml @@ -0,0 +1,13 @@ +description: mnist_tf_core_to_estimator +hyperparameters: + learning_rate: 1e-3 + global_batch_size: 64 + n_filters_1: 10 + n_filters_2: 40 +searcher: + name: single + metric: error + smaller_is_better: true + max_length: + batches: 2000 +entrypoint: model_def:MNistTrial diff --git a/examples/computer_vision/mnist_tf_layers/model_def.py b/examples/computer_vision/mnist_tf_layers/model_def.py new file mode 100644 index 00000000000..ecb3c98b502 --- /dev/null +++ b/examples/computer_vision/mnist_tf_layers/model_def.py @@ -0,0 +1,140 @@ +""" +An example showing how to use a graph defined in low-level TensorFlow APIs in Determined. + +We will be wrapping the TensorFlow graph in an Estimator and using Determined's EstimatorTrial. 
+""" +from typing import Any, Callable, Dict + +import tensorflow.compat.v1 as tf +import tensorflow_datasets as tfds + +from determined import estimator + +NUM_CLASSES = 10 + + +def calculate_logits(hparams: Dict[str, Any], images: tf.Tensor, training: bool) -> tf.Tensor: + """This example assumes you already have something like this written for defining your graph.""" + conv1 = tf.layers.conv2d( + inputs=tf.cast(images, tf.float32), + filters=hparams["n_filters_1"], + kernel_size=[5, 5], + padding="same", + activation=tf.nn.relu, + ) + pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) + + conv2 = tf.layers.conv2d( + inputs=pool1, + filters=hparams["n_filters_2"], + kernel_size=[5, 5], + padding="same", + activation=tf.nn.relu, + ) + pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) + pool2_shape = pool2.get_shape().as_list() + + pool2_flat = tf.reshape( + pool2, [-1, pool2_shape[1] * pool2_shape[2] * pool2_shape[3]] + ) + dense = tf.layers.dense(inputs=pool2_flat, units=512, activation=tf.nn.relu) + + if training: + dropout = tf.layers.dropout(inputs=dense, rate=0.5) + logits = tf.layers.dense(inputs=dropout, units=NUM_CLASSES) + else: + logits = tf.layers.dense(inputs=dense, units=NUM_CLASSES) + + return logits + + +def calculate_loss(labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: + """This example assumes you already have something like this written for defining your graph.""" + return tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits) + ) + + +def calculate_predictions(logits: tf.Tensor) -> tf.Tensor: + """This example assumes you already have something like this written for defining your graph.""" + return tf.argmax(logits, axis=1) + + +def calculate_error(predictions: tf.Tensor, labels: tf.Tensor) -> tf.Tensor: + """This example assumes you already have something like this written for defining your graph.""" + correct = tf.cast(tf.equal(predictions, labels), tf.float32) + return 1 - tf.reduce_mean(correct) + + +def make_model_fn(context: estimator.EstimatorTrialContext) -> Callable: + # Define a model_fn which is the magic ingredient for wrapping a tensorflow graph in an + # Estimator. The Estimator training loop will call this function with different modes to + # build graphs for either training or validation (or prediction, but that's not used by + # Determined). + # + # Read more at https://www.tensorflow.org/guide/estimator. + def model_fn(features: Any, mode: tf.estimator.ModeKeys) -> tf.estimator.EstimatorSpec: + # The "features" argument must be named "features", but in this simple example, it + # contains the full output of our dataset, including the images and the labels. + images = features["image"] + labels = features["label"] + + if mode == tf.estimator.ModeKeys.TRAIN: + # Build a graph for training. + logits = calculate_logits(context.get_hparams(), images, training=True) + loss = calculate_loss(labels, logits) + + learning_rate = context.get_hparam("learning_rate") + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) + optimizer = context.wrap_optimizer(optimizer) + + train_op = optimizer.minimize( + loss, global_step=tf.train.get_global_step() + ) + return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) + + if mode == tf.estimator.ModeKeys.EVAL: + # Build a graph for validation. 
+ logits = calculate_logits( + context.get_hparams(), images, training=False + ) + loss = calculate_loss(labels, logits) + predictions = calculate_predictions(logits) + error = calculate_error(predictions, labels) + return tf.estimator.EstimatorSpec( + mode, + loss=loss, + eval_metric_ops={"error": tf.metrics.mean(error)}, + ) + + return model_fn + + +class MNistTrial(estimator.EstimatorTrial): + def __init__(self, context: estimator.EstimatorTrialContext) -> None: + self.context = context + + def build_estimator(self) -> tf.estimator.Estimator: + return tf.estimator.Estimator(model_fn=make_model_fn(self.context)) + + def build_train_spec(self) -> tf.estimator.TrainSpec: + # Write a function which returns your dataset for training... + def input_fn() -> tf.data.Dataset: + ds = tfds.image.MNIST().as_dataset()["train"] + ds = self.context.wrap_dataset(ds) + ds = ds.batch(self.context.get_per_slot_batch_size()) + return ds + + # ... then return a TrainSpec which includes that function. + return tf.estimator.TrainSpec(input_fn) + + def build_validation_spec(self) -> tf.estimator.EvalSpec: + # Write a function which returns your dataset for validation... + def input_fn() -> tf.data.Dataset: + ds = tfds.image.MNIST().as_dataset()["test"] + ds = self.context.wrap_dataset(ds) + ds = ds.batch(self.context.get_per_slot_batch_size()) + return ds + + # ... then return an EvalSpec which includes that function. + return tf.estimator.EvalSpec(input_fn, steps=None) diff --git a/examples/computer_vision/mnist_tf_layers/startup-hook.sh b/examples/computer_vision/mnist_tf_layers/startup-hook.sh new file mode 100644 index 00000000000..1fc50cc24f3 --- /dev/null +++ b/examples/computer_vision/mnist_tf_layers/startup-hook.sh @@ -0,0 +1,4 @@ +#!/bin/sh + +# Download the dataset before starting training. +python -c "import tensorflow_datasets as tfds; tfds.image.MNIST().download_and_prepare()" diff --git a/examples/tests/requirements.txt b/examples/tests/requirements.txt index 9bb5e68f1c0..8d90f9ca818 100644 --- a/examples/tests/requirements.txt +++ b/examples/tests/requirements.txt @@ -4,3 +4,4 @@ tensorflow==2.2.0 torch==1.4.0 torchvision==0.5.0 pandas==1.0.3 +tensorflow_datasets diff --git a/examples/tests/test_official.py b/examples/tests/test_official.py index 032b35614a2..06061ad9566 100644 --- a/examples/tests/test_official.py +++ b/examples/tests/test_official.py @@ -18,6 +18,7 @@ ), ("computer_vision/iris_tf_keras", "computer_vision/iris_tf_keras/const.yaml"), ("computer_vision/mnist_estimator", "computer_vision/mnist_estimator/const.yaml"), + ("computer_vision/mnist_tf_layers", "computer_vision/mnist_tf_layers/const.yaml"), ("tutorials/mnist_pytorch", "tutorials/mnist_pytorch/const.yaml"), ("gan/gan_mnist_pytorch", "gan/gan_mnist_pytorch/const.yaml"), ] @@ -29,6 +30,10 @@ def test_official(model_def: str, config_file: str) -> None: model_def_absolute = examples_dir.joinpath(model_def) config_file_absolute = examples_dir.joinpath(config_file) + startup_hook = model_def_absolute.joinpath("startup-hook.sh") + if startup_hook.exists(): + subprocess.check_output(("sh", str(startup_hook))) + subprocess.check_output( ( "det",