diff --git a/README.md b/README.md index 916e5200b29841..ef5bdc66ef0313 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,10 @@ ----------------- -| **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | -|-----------------|---------------------|------------------|-------------------|---------------| -| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | + +| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | +|-----------------|---------------------|------------------|-------------------|---------------|---------------| +| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while @@ -21,20 +22,6 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. -**If you want to contribute to TensorFlow, be sure to review the [contribution -guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's -[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to -uphold this code.** - -**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs. So please see -[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions -and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** - -The TensorFlow project strives to abide by generally accepted best practices in open-source software development: - -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) - ## Installation *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* @@ -75,6 +62,22 @@ $ python >>> sess.close() ``` +## Contribution guidelines + +**If you want to contribute to TensorFlow, be sure to review the [contribution +guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's +[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to +uphold this code.** + +**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for +tracking requests and bugs. So please see +[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions +and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** + +The TensorFlow project strives to abide by generally accepted best practices in open-source software development: + +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + ## For more information * [TensorFlow Website](https://www.tensorflow.org) diff --git a/SECURITY.md b/SECURITY.md index 6ddac1f964dfba..fea24b27392088 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -233,7 +233,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= ### Known vulnerabilities -| Type | Versions affected | Reported by | Additional Information | -|------|:-----------------:|---------------------------------------| -| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +| Type | Versions affected | Reported by | Additional Information | +|-------------------|:-----------------:|--------------------|-----------------------------| +| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | diff --git a/configure b/configure index 9c21d2b03a2771..66b66ba54ed68a 100755 --- a/configure +++ b/configure @@ -8,7 +8,8 @@ if [ -z "$PYTHON_BIN_PATH" ]; then fi # Set all env variables -"$PYTHON_BIN_PATH" configure.py +CONFIGURE_DIR=$(dirname "$0") +"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" echo "Configuration finished" diff --git a/configure.py b/configure.py index 9744f6ac81a352..97f46757ee241b 100644 --- a/configure.py +++ b/configure.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import argparse import errno import os import platform @@ -32,10 +33,6 @@ from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), - '.tf_configure.bazelrc') -_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'WORKSPACE') _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' @@ -51,6 +48,11 @@ _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 +_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) +_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' +_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) +_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') + class UserInputError(Exception): pass @@ -119,22 +121,6 @@ def sed_in_place(filename, old, new): f.write(newdata) -def remove_line_with(filename, token): - """Remove lines that contain token from file. - - Args: - filename: string for filename. - token: string token to check if to remove a line from file or not. - """ - with open(filename, 'r') as f: - filedata = f.read() - - with open(filename, 'w') as f: - for line in filedata.strip().split('\n'): - if token not in line: - f.write(line + '\n') - - def write_to_bazelrc(line): with open(_TF_BAZELRC, 'a') as f: f.write(line + '\n') @@ -245,25 +231,30 @@ def setup_python(environ_cp): environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh - with open('tools/python_bin_path.sh', 'w') as f: + with open(os.path.join( + _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) -def reset_tf_configure_bazelrc(): +def reset_tf_configure_bazelrc(workspace_path): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - - home = os.path.expanduser('~') - if not os.path.exists('.bazelrc'): - if os.path.exists(os.path.join(home, '.bazelrc')): - with open('.bazelrc', 'a') as f: - f.write('import %s/.bazelrc\n' % home.replace('\\', '/')) + bazelrc_path = os.path.join(workspace_path, '.bazelrc') + + data = [] + if os.path.exists(bazelrc_path): + with open(bazelrc_path, 'r') as f: + data = f.read().splitlines() + with open(bazelrc_path, 'w') as f: + for l in data: + if _TF_BAZELRC_FILENAME in l: + continue + f.write('%s\n' % l) + if is_windows(): + tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/") else: - open('.bazelrc', 'w').close() - - remove_line_with('.bazelrc', 'tf_configure') - with open('.bazelrc', 'a') as f: - f.write('import %workspace%/.tf_configure.bazelrc\n') + tf_bazelrc_path = _TF_BAZELRC + f.write('import %s\n' % tf_bazelrc_path) def cleanup_makefile(): @@ -271,7 +262,8 @@ def cleanup_makefile(): These files could interfere with Bazel parsing. """ - makefile_download_dir = 'tensorflow/contrib/makefile/downloads' + makefile_download_dir = os.path.join( + _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') if os.path.isdir(makefile_download_dir): for root, _, filenames in os.walk(makefile_download_dir): for f in filenames: @@ -456,7 +448,7 @@ def check_bazel_version(min_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell(['bazel', '--batch', 'version']) + curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version']) for line in curr_version.split('\n'): if 'Build label: ' in line: @@ -502,7 +494,8 @@ def set_cc_opt_flags(environ_cp): for opt in cc_opt_flags.split(): write_to_bazelrc('build:opt --copt=%s' % opt) # It should be safe on the same build host. - write_to_bazelrc('build:opt --host_copt=-march=native') + if not is_ppc64le(): + write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') # TODO(mikecase): Remove these default defines once we are able to get # TF Lite targets building without them. @@ -1229,7 +1222,7 @@ def set_host_c_compiler(environ_cp): environ_cp, var_name='HOST_C_COMPILER', var_default=default_c_host_compiler, - ask_for_var=('Please specify which C compiler should be used as the host' + ask_for_var=('Please specify which C compiler should be used as the host ' 'C compiler.'), check_success=os.path.exists, error_msg='Invalid C compiler path. %s cannot be found.', @@ -1373,13 +1366,20 @@ def config_info_line(name, help_text): def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", + type=str, + default=_TF_WORKSPACE_ROOT, + help="The absolute path to your active Bazel workspace.") + args = parser.parse_args() + # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. environ_cp = dict(os.environ) check_bazel_version('0.5.4') - reset_tf_configure_bazelrc() + reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() setup_python(environ_cp) diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 9b732421e56b3e..0cb3132e94e381 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -182,6 +182,70 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + MaxPool3DGrad::Attrs grad_attrs; + auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper); + +Status AvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + internal::AvgPoolGrad::Attrs grad_attrs; + auto dx = + internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper); + +Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + AvgPool3DGrad::Attrs grad_attrs; + auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool3D", AvgPool3DGradHelper); + Status LRNGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 0cfe5f6e3c49f7..c4eba7ecb017fe 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -31,8 +31,11 @@ using ops::Elu; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; +using ops::AvgPool; +using ops::AvgPool3D; using ops::MaxPool; using ops::MaxPoolV2; +using ops::MaxPool3D; using ops::Placeholder; using ops::Relu; using ops::Relu6; @@ -70,9 +73,9 @@ class NNGradTest : public ::testing::Test { // Sets tensor with random values, ensuring that the max value is largest by // a reasonable amount. - // This is an issue for MaxPool and MaxPoolV2, in which perturbations by the - // numeric gradient computation in the gradient checker can change the max - // value if values are too close together. + // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which + // perturbations by the numeric gradient computation in the gradient checker + // can change the max value if values are too close together. template void SetRandomValuesWithBumpedMax(Tensor* tensor) { auto tensor_flat = tensor->flat(); @@ -203,6 +206,41 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, MaxPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one MaxPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesWithBumpedMax(&x_init_value); + RunTest(x, x_init_value, y, y_shape); +} + +TEST_F(NNGradTest, AvgPoolGradHelper) { + TensorShape x_shape({1, 2, 2, 1}); + TensorShape y_shape({1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool. + const std::vector ksize{1, 2, 2, 1}; + const std::vector strides{1, 2, 2, 1}; + auto y = AvgPool(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + +TEST_F(NNGradTest, AvgPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(NNGradTest, LRN){ TensorShape x_shape({1, 1, 2, 1}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h index 6077c45c5854fd..64edbb5766c360 100644 --- a/tensorflow/cc/profiler/profiler.h +++ b/tensorflow/cc/profiler/profiler.h @@ -61,18 +61,18 @@ class Profiler { /// Adds tracing information `run_meta` to profiler. A `run_meta` is /// generated by a TensorFlow session run call. `step` is the key /// to the `run_meta`. When calling ProfileXXX methods, caller can specify - /// `step` in `options` to seletively profile the corresponding `run_meta`. + /// `step` in `options` to selectively profile the corresponding `run_meta`. /// Multiple different `run_meta` can be keyed by the same `step` in order /// to group them together. void AddStep(int64 step, const RunMetadata& run_meta); /// Profiles the model by organizing nodes in graph structure. - /// Each node is an op and the nodes are contected by the op inputs/outputs. + /// Each node is an op and the nodes are connected by the op inputs/outputs. GraphNodeProto ProfileGraph(const Options& options); /// Profiles the model by organizing nodes in name scope structure. /// Each node is an op, and nodes are organized by the ops' name - /// scope, similar to a filesystem tree. + /// scope, similar to a file system tree. /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. GraphNodeProto ProfileNameScope(const Options& options); diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc index a50461cafd6527..beb574061bea8d 100644 --- a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc @@ -17,4 +17,6 @@ limitations under the License. #define __CUDACC__ #include "crt/host_config.h" -int main(void) { return 0; } +int main(void) { + return 0; +} diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py new file mode 100644 index 00000000000000..4ed7268e7a9212 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py @@ -0,0 +1,325 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental methods for tf.feature_column sequence input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import abc +import collections + + +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope + +# TODO(b/73160931): Fix pydoc. +# pylint: disable=g-doc-args,missing-docstring,protected-access +# TODO(b/73827486): Support SequenceExample. + + +def sequence_input_layer( + features, + feature_columns, + weight_collections=None, + trainable=True, + scope=None): + """"Builds input layer for sequence input. + + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. + + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. + + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. + + Example: + + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + Raises: + ValueError: If any of the `feature_columns` is the wrong type. + """ + feature_columns = fc._clean_feature_columns(feature_columns) + for c in feature_columns: + if not isinstance(c, _SequenceDenseColumn): + raise ValueError( + 'All feature_columns must be of type _SequenceDenseColumn. ' + 'Given (type {}): {}'.format(type(c), c)) + + with variable_scope.variable_scope( + scope, default_name='sequence_input_layer', values=features.values()): + builder = fc._LazyBuilder(features) + output_tensors = [] + sequence_lengths = [] + ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): + dense_tensor, sequence_length = column._get_sequence_dense_tensor( + builder, + weight_collections=weight_collections, + trainable=trainable) + # Flattens the final dimension to produce a 3D Tensor. + num_elements = column._variable_shape.num_elements() + shape = array_ops.shape(dense_tensor) + output_tensors.append( + array_ops.reshape( + dense_tensor, + shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) + # TODO(b/73160931): Verify sequence_length equality. + return array_ops.concat(output_tensors, -1), sequence_lengths[0] + + +# TODO(b/73160931): Add remaining categorical columns. +def sequence_categorical_column_with_identity( + key, num_buckets, default_value=None): + return _SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) + + +# TODO(b/73160931): Merge with embedding_column +def _sequence_embedding_column( + categorical_column, dimension, initializer=None, ckpt_to_load_from=None, + tensor_name_in_ckpt=None, max_norm=None, trainable=True): + if not isinstance(categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'categorical_column must be of type _SequenceCategoricalColumn. ' + 'Given (type {}): {}'.format( + type(categorical_column), categorical_column)) + return _SequenceEmbeddingColumn( + fc.embedding_column( + categorical_column, + dimension=dimension, + initializer=initializer, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable)) + + +def sequence_numeric_column( + key, + shape=(1,), + default_value=0., + dtype=dtypes.float32): + # TODO(b/73160931): Add validations. + return _SequenceNumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype) + + +class _SequenceDenseColumn(fc._FeatureColumn): + """Represents dense sequence data.""" + + __metaclass__ = abc.ABCMeta + + TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name + 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length']) + + @abc.abstractproperty + def _variable_shape(self): + """`TensorShape` without batch and sequence dimensions.""" + pass + + @abc.abstractmethod + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + """Returns a `TensorSequenceLengthPair`.""" + pass + + +def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): + with ops.name_scope(None, 'sequence_length') as name_scope: + row_ids = sp_tensor.indices[:, 0] + column_ids = sp_tensor.indices[:, 1] + column_ids += array_ops.ones_like(column_ids) + seq_length = ( + math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) + # If the last n rows do not have ids, seq_length will have shape + # [batch_size - n]. Pad the remaining values with zeros. + n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] + padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) + return array_ops.concat([seq_length, padding], axis=0, name=name_scope) + + +class _SequenceCategoricalColumn( + fc._CategoricalColumn, + collections.namedtuple( + '_SequenceCategoricalColumn', ['categorical_column'])): + + @property + def name(self): + return self.categorical_column.name + + @property + def _parse_example_spec(self): + return self.categorical_column._parse_example_spec + + def _transform_feature(self, inputs): + return self.categorical_column._transform_feature(inputs) + + @property + def _num_buckets(self): + return self.categorical_column._num_buckets + + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) + id_tensor = sparse_tensors.id_tensor + weight_tensor = sparse_tensors.weight_tensor + # Expands final dimension, so that embeddings are not combined during + # embedding lookup. + check_id_rank = check_ops.assert_equal( + array_ops.rank(id_tensor), 2, + data=[ + 'Column {} expected ID tensor of rank 2. '.format(self.name), + 'id_tensor shape: ', array_ops.shape(id_tensor)]) + with ops.control_dependencies([check_id_rank]): + id_tensor = sparse_ops.sparse_reshape( + id_tensor, + shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) + if weight_tensor is not None: + check_weight_rank = check_ops.assert_equal( + array_ops.rank(weight_tensor), 2, + data=[ + 'Column {} expected weight tensor of rank 2.'.format(self.name), + 'weight_tensor shape:', array_ops.shape(weight_tensor)]) + with ops.control_dependencies([check_weight_rank]): + weight_tensor = sparse_ops.sparse_reshape( + weight_tensor, + shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) + return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) + + def _sequence_length(self, inputs): + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) + return _sequence_length_from_sparse_tensor(sparse_tensors.id_tensor) + + +class _SequenceEmbeddingColumn( + _SequenceDenseColumn, + collections.namedtuple('_SequenceEmbeddingColumn', ['embedding_column'])): + + @property + def name(self): + return self.embedding_column.name + + @property + def _parse_example_spec(self): + return self.embedding_column._parse_example_spec + + def _transform_feature(self, inputs): + return self.embedding_column._transform_feature(inputs) + + @property + def _variable_shape(self): + return self.embedding_column._variable_shape + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + dense_tensor = self.embedding_column._get_dense_tensor( + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + sequence_length = self.embedding_column.categorical_column._sequence_length( + inputs) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + + +class _SequenceNumericColumn( + _SequenceDenseColumn, + collections.namedtuple( + '_SequenceNumericColumn', + ['key', 'shape', 'default_value', 'dtype'])): + + @property + def name(self): + return self.key + + @property + def _parse_example_spec(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + return inputs.get(self.key) + + @property + def _variable_shape(self): + return tensor_shape.TensorShape(self.shape) + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + # Do nothing with weight_collections and trainable since no variables are + # created in this function. + del weight_collections + del trainable + sp_tensor = inputs.get(self) + dense_tensor = sparse_ops.sparse_tensor_to_dense( + sp_tensor, default_value=self.default_value) + # Reshape into [batch_size, T, variable_shape]. + dense_shape = array_ops.concat( + [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], + axis=0) + dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) + sequence_length = _sequence_length_from_sparse_tensor( + sp_tensor, num_elements=self._variable_shape.num_elements()) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + +# pylint: enable=g-doc-args,missing-docstring,protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py new file mode 100644 index 00000000000000..59674869a27c3a --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py @@ -0,0 +1,471 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sequential_feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.feature_column.python.feature_column import sequential_feature_column as sfc +from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + + +class SequenceInputLayerTest(test.TestCase): + + def test_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension_a = 2 + embedding_values_a = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + embedding_dimension_b = 3 + embedding_values_b = ( + (11., 12., 13.), # id 0 + (14., 15., 16.), # id 1 + (17., 18., 19.) # id 2 + ) + def _get_initializer(embedding_dimension, embedding_values): + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = sfc._sequence_embedding_column( + categorical_column_a, dimension=embedding_dimension_a, + initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_b = sfc._sequence_embedding_column( + categorical_column_b, dimension=embedding_dimension_b, + initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[embedding_column_b, embedding_column_a]) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_embedding/embedding_weights:0', + 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) + self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_input_layer = [ + [[0.], [1.]], + [[10.], [0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column_multi_dim(self): + """Tests sequence_input_layer for multi-dimensional numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + # The output of numeric_column._get_dense_tensor should be flattened. + expected_input_layer = [ + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class SequenceCategoricalColumnWithIdentityTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((1, 2, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + def test_get_sparse_tensors_inputs3d(self): + """Tests _get_sparse_tensors when the input is already 3D Tensor.""" + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=(1, 2, 0), + dense_shape=(2, 2, 1)) + + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'Column aaa expected ID tensor of rank 2\.\s*' + r'id_tensor shape:\s*\[2 2 1\]'): + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': inputs})) + with monitored_session.MonitoredSession() as sess: + id_weight_pair.id_tensor.eval(session=sess) + + def test_sequence_length(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_zeros(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((1, 0), (3, 0), (3, 1)), + values=(1, 2, 0), + dense_shape=(5, 2)) + expected_sequence_length = [0, 1, 0, 2, 0] + + sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + expected_lookups = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceNumericColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_dense_tensor = [ + [[0.], [1.]], + [[10.], [0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa') + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_shape(self): + """Tests get_sequence_dense_tensor with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_dense_tensor = [ + [[0., 1., 2.], [3., 4., 5.]], + [[10., 11., 12.], [0., 0., 0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_dense_tensor_multi_dim(self): + """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + expected_dense_tensor = [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_sequence_length(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_shape(self): + """Tests _sequence_length with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [] + # example 1, values [[0.], [1.]] + # example 2, [[2.]] + # example 3, values [] + # example 4, [[3.]] + # example 5, values [] + indices=((1, 0), (1, 1), (2, 0), (4, 0)), + values=(0., 1., 2., 3.), + dense_shape=(6, 2)) + expected_sequence_length = [0, 2, 1, 0, 1, 0] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 5549df971db277..45eb108586bed0 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -71,10 +71,11 @@ def get_cyclegan_model(): class SummariesTest(test.TestCase): - def _test_add_gan_model_image_summaries_impl( - self, get_model_fn, expected_num_summary_ops, model_summaries): - summaries.add_gan_model_image_summaries( - get_model_fn(), grid_size=2, model_summaries=model_summaries) + def _test_add_gan_model_image_summaries_impl(self, get_model_fn, + expected_num_summary_ops, + model_summaries): + summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2, + model_summaries=model_summaries) self.assertEquals(expected_num_summary_ops, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 559c0c63dae891..350bcb3bca11b4 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -58,12 +58,12 @@ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', - 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', 'dropout', - 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', 'images_to_sequence', - 'layer_norm', 'linear', 'pool', 'max_pool2d', 'max_pool3d', - 'one_hot_encoding', 'relu', 'relu6', 'repeat', 'scale_gradient', - 'separable_conv2d', 'separable_convolution2d', 'sequence_to_images', - 'softmax', 'spatial_softmax', 'stack', 'unit_norm', + 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', + 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', + 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', + 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', + 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', + 'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm', 'legacy_fully_connected', 'legacy_linear', 'legacy_relu', 'maxout' ] @@ -2718,7 +2718,8 @@ def sequence_to_images(inputs, num_batches = -1 else: num_batches = num_batches // height - reshaped = array_ops.reshape(inputs, [width, num_batches, height, depth]) + reshaped = array_ops.reshape(inputs, + [width, num_batches, height, depth]) if output_data_format == 'channels_first': outputs = array_ops.transpose(reshaped, [1, 3, 2, 0]) else: diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index ba70432c48630b..997f910a2a9756 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -3447,8 +3447,9 @@ def testImagesToSequenceDims(self): num_time_steps = 11 num_channels = 5 desired_height = 7 - sequence = np.random.uniform( - size=(num_time_steps, num_batches, num_channels)).astype(np.float32) + sequence = np.random.uniform(size=(num_time_steps, + num_batches, + num_channels)).astype(np.float32) output = _layers.sequence_to_images(sequence, desired_height) self.assertListEqual(output.get_shape().as_list(), [2, 7, 11, 5]) @@ -3457,10 +3458,12 @@ def testImagesToSequenceNCHW(self): num_time_steps = 11 num_channels = 5 desired_height = 7 - sequence = np.random.uniform( - size=(num_time_steps, num_batches, num_channels)).astype(np.float32) - output = _layers.sequence_to_images( - sequence, desired_height, output_data_format='channels_first') + sequence = np.random.uniform(size=(num_time_steps, + num_batches, + num_channels)).astype(np.float32) + output = _layers.sequence_to_images(sequence, + desired_height, + output_data_format='channels_first') self.assertListEqual(output.get_shape().as_list(), [2, 5, 7, 11]) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index 2c91be9d62db58..c57bb348c5b386 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -20,6 +20,9 @@ import android.graphics.Bitmap; import android.os.SystemClock; import android.util.Log; + +import org.tensorflow.lite.Interpreter; + import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; @@ -34,9 +37,10 @@ import java.util.List; import java.util.Map; import java.util.PriorityQueue; -import org.tensorflow.lite.Interpreter; -/** Classifies images with Tensorflow Lite. */ +/** + * Classifies images with Tensorflow Lite. + */ public abstract class ImageClassifier { /** Tag for the {@link Log}. */ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java index 3108422952a58c..be17b85e0cd937 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java @@ -16,22 +16,24 @@ package com.example.android.tflitecamerademo; import android.app.Activity; + import java.io.IOException; /** - * This classifier works with the Inception-v3 slim model. It applies floating point inference - * rather than using a quantized model. + * This classifier works with the Inception-v3 slim model. + * It applies floating point inference rather than using a quantized model. */ public class ImageClassifierFloatInception extends ImageClassifier { - /** The inception net requires additional normalization of the used input. */ + /** + * The inception net requires additional normalization of the used input. + */ private static final int IMAGE_MEAN = 128; - private static final float IMAGE_STD = 128.0f; /** - * An array to hold inference results, to be feed into Tensorflow Lite as outputs. This isn't part - * of the super class, because we need a primitive array here. + * An array to hold inference results, to be feed into Tensorflow Lite as outputs. + * This isn't part of the super class, because we need a primitive array here. */ private float[][] labelProbArray = null; diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java index ee89dbd375eeea..c533de7927050d 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java @@ -16,14 +16,17 @@ package com.example.android.tflitecamerademo; import android.app.Activity; + import java.io.IOException; -/** This classifier works with the quantized MobileNet model. */ +/** + * This classifier works with the quantized MobileNet model. + */ public class ImageClassifierQuantizedMobileNet extends ImageClassifier { /** - * An array to hold inference results, to be feed into Tensorflow Lite as outputs. This isn't part - * of the super class, because we need a primitive array here. + * An array to hold inference results, to be feed into Tensorflow Lite as outputs. + * This isn't part of the super class, because we need a primitive array here. */ private byte[][] labelProbArray = null; diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 883c7f270dcefa..780401e052733c 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 2481add76912ad..5488b71fcf6440 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -36,6 +36,7 @@ import zipfile import numpy as np from six import StringIO +from six.moves import xrange # TODO(aselle): Disable GPU for now os.environ["CUDA_VISIBLE_DEVICES"] = "-1" diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index f21915ffbc00aa..63fdd91d368d97 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -1585,7 +1585,8 @@ def _cell_output(self, cell): with self.test_session() as sess: init = init_ops.constant_initializer(0.5) - with variable_scope.variable_scope("root", initializer=init): + with variable_scope.variable_scope("root", + initializer=init): x = array_ops.zeros([1, 2]) c0 = array_ops.zeros([1, 2]) h0 = array_ops.zeros([1, 2]) @@ -1595,12 +1596,11 @@ def _cell_output(self, cell): xout, sout = cell()(x, state0) sess.run([variables.global_variables_initializer()]) - res = sess.run( - [xout, sout], { - x.name: np.array([[1., 1.]]), - c0.name: 0.1 * np.asarray([[0, 1]]), - h0.name: 0.1 * np.asarray([[2, 3]]), - }) + res = sess.run([xout, sout], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + }) actual_state_c = res[1].c actual_state_h = res[1].h @@ -1611,8 +1611,9 @@ def testBasicCell(self): """Tests cell w/o peepholes and w/o normalisation.""" def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=False, use_peepholes=False) + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=False, + use_peepholes=False) actual_c, actual_h = self._cell_output(cell) @@ -1626,8 +1627,9 @@ def testNonbasicCell(self): """Tests cell with peepholes and w/o normalisation.""" def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=False, use_peepholes=True) + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=False, + use_peepholes=True) actual_c, actual_h = self._cell_output(cell) @@ -1641,8 +1643,9 @@ def testBasicCellWithNorm(self): """Tests cell w/o peepholes and with normalisation.""" def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=True, use_peepholes=False) + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=True, + use_peepholes=False) actual_c, actual_h = self._cell_output(cell) @@ -1656,8 +1659,9 @@ def testNonBasicCellWithNorm(self): """Tests cell with peepholes and with normalisation.""" def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=True, use_peepholes=True) + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=True, + use_peepholes=True) actual_c, actual_h = self._cell_output(cell) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 6e57ccd6dd21ed..03fe31abf736c0 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -722,7 +722,7 @@ def _mask_probs(probs, eos_token, finished): eos_token, vocab_size, dtype=probs.dtype, - on_value=0., + on_value=ops.convert_to_tensor(0., dtype=probs.dtype), off_value=probs.dtype.min) finished_probs = array_ops.tile( array_ops.reshape(finished_row, [1, 1, -1]), diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py index ad5e985487190e..b3343aef47d9f3 100644 --- a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py +++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py @@ -221,7 +221,7 @@ def parallel_read(data_sources, the data will be cycled through indefinitely. num_readers: a integer, number of Readers to create. reader_kwargs: an optional dict, of kwargs for the reader. - shuffle: boolean, wether should shuffle the files and the records by using + shuffle: boolean, whether should shuffle the files and the records by using RandomShuffleQueue as common_queue. dtypes: A list of types. The length of dtypes must equal the number of elements in each record. If it is None it will default to diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h index 04e6b0a735320d..dc3e9fe79d32a1 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h @@ -468,7 +468,7 @@ class FixedSizeSparseClassificationGrowStats : public ClassificationStats { void PackToProto(FertileSlot* slot) const override; void InitLeafClassStats(int best_split_index, LeafStat* left_stats, - LeafStat* right_stats) const; + LeafStat* right_stats) const override; protected: void ClassificationAddSplitStats() override { diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 3b7b68f61b0a8b..c832c6f2e0cefe 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -47,7 +47,10 @@ tf_cuda_cc_test( tf_custom_op_library( name = "python/ops/_trt_engine_op.so", - srcs = ["ops/trt_engine_op.cc"], + srcs = [ + "ops/trt_calib_op.cc", + "ops/trt_engine_op.cc", + ], deps = [ ":trt_engine_op_kernel", ":trt_shape_function", @@ -71,11 +74,18 @@ tf_cuda_library( cc_library( name = "trt_engine_op_kernel", - srcs = ["kernels/trt_engine_op.cc"], - hdrs = ["kernels/trt_engine_op.h"], + srcs = [ + "kernels/trt_calib_op.cc", + "kernels/trt_engine_op.cc", + ], + hdrs = [ + "kernels/trt_calib_op.h", + "kernels/trt_engine_op.h", + ], copts = tf_copts(), deps = [ ":trt_logging", + ":trt_resources", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor_headers_lib", @@ -87,7 +97,10 @@ cc_library( ) tf_gen_op_libs( - op_lib_names = ["trt_engine_op"], + op_lib_names = [ + "trt_engine_op", + "trt_calib_op", + ], deps = if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), @@ -109,6 +122,7 @@ tf_gen_op_wrapper_py( name = "trt_engine_op", gen_locally = True, deps = [ + ":trt_calib_op_op_lib", ":trt_engine_op_op_lib", ":trt_logging", ":trt_shape_function", @@ -172,6 +186,27 @@ tf_py_wrap_cc( ], ) +tf_cuda_library( + name = "trt_resources", + srcs = [ + "resources/trt_int8_calibrator.cc", + "resources/trt_resource_manager.cc", + ], + hdrs = [ + "resources/trt_int8_calibrator.h", + "resources/trt_resource_manager.h", + "resources/trt_resources.h", + ], + deps = [ + ":trt_logging", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + # Library for the node-level conversion portion of TensorRT operation creation tf_cuda_library( name = "trt_conversion", @@ -186,6 +221,7 @@ tf_cuda_library( deps = [ ":segment", ":trt_logging", + ":trt_resources", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core:framework", diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 4003ba056d28c6..9ee717dd7fb1ef 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -809,9 +809,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented( - "binary op: " + node_def.op() + - " not supported at: " + node_def.name()); + return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + + " not supported at: " + + node_def.name()); nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( *const_cast(tensor_l), @@ -1471,13 +1471,13 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( << std::to_string(op_info_vec.size()); // TODO(ben,jie): update TRT input format/dimension - nvinfer1::DimsCHW input_dim_pseudo_chw; - for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1; + nvinfer1::DimsCHW input_dim_psuedo_chw; + for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1; for (int i = 1; i < op_info.shape().dim_size(); i++) { VLOG(2) << "dimension: " << i << " , size: " << op_info.shape().dim(i).size(); - input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size(); + input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size(); } // TODO(ben,jie): proper way to restore input tensor name? @@ -1486,7 +1486,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( input_tensor_name = node_name + ":" + std::to_string(output_idx); nvinfer1::ITensor* input_tensor = converter.network()->addInput( - input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); + input_tensor_name.c_str(), dtype, input_dim_psuedo_chw); if (!input_tensor) return tensorflow::errors::InvalidArgument( diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc new file mode 100644 index 00000000000000..1dcb87e7683ad7 --- /dev/null +++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc @@ -0,0 +1,129 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h" +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda_runtime_api.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_)); + OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_)); + OP_REQUIRES_OK(context, context->GetAttr("resource_name", &resource_name_)); +}; + +#define TYPECASE(dt, X, Y) \ + case dt: { \ + return (void*)X->flat::Type>().data(); \ + } + +void* GetTensorAddress(const Tensor* tensor_ptr) { + auto tensor_type = tensor_ptr->dtype(); + switch (tensor_type) { + TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); + default: { + LOG(FATAL) << "Unsupported Data type " + << tensorflow::DataTypeString(tensor_type); + return nullptr; + } + } +} + +void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) { + // TODO(aaroey): make sure ctx->resource_mgr() is used in future PR. + auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); + auto res_mgr = trt_rm->getManager("TRTCalibOps"); + tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; + auto status = res_mgr->Lookup(resource_name_, resource_name_, &calib_res); + + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + int num_inputs = ctx->num_inputs(); + // first run instantiate calibrator + if (calib_res->calibrator_ == nullptr) { + dev_tensors_.resize(num_inputs); + int batch_size = ctx->input(0).dim_size(0); + VLOG(1) << " Constructing calibrator"; + for (int i = 0; i < num_inputs; i++) { + // allocate workspace on device for inputs + const tensorflow::Tensor& t = ctx->input(i); + OP_REQUIRES_OK(ctx, + ctx->allocate_persistent(t.dtype(), t.shape(), + &dev_tensors_.at(i), nullptr)); + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + void* device_address = GetTensorAddress(device_tensor); + device_buffers_.emplace(input_names_.at(i), + std::pair( + device_address, device_tensor->TotalBytes())); + } + + calib_res->calibrator_ = + new TRTInt8Calibrator(device_buffers_, batch_size, resource_name_); + string label(resource_name_); + calib_res->thr_ = new std::thread([calib_res, label]() { + VLOG(1) << "Starting calibration thread, Calibration Resource @ " + << calib_res; + calib_res->builder_->setInt8Calibrator(calib_res->calibrator_); + calib_res->builder_->setInt8Mode(true); + calib_res->engine_ = calib_res->builder_->buildCudaEngine( + *calib_res->network_); // will loop until we terminate calibrator + VLOG(1) << "Calibration loop terminated " << label; + }); + VLOG(1) << "initialized calibrator resource"; + } // calibrator initialized + + // Pass input data to calibrator + std::unordered_map input_data; + for (int i = 0; i < num_inputs; i++) { + const Tensor& t = ctx->input(i); + void* data_address = GetTensorAddress(&t); + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), + device_tensor->TotalBytes()); // use the tensor so FW keeps it + input_data.emplace(input_names_.at(i), data_address); + ctx->set_output(i, t); + } + VLOG(2) << "Filled map for sending"; + calib_res->calibrator_->setBatch(input_data); + VLOG(2) << "Passed calibration data"; + // TODO(aaroey): make sure we wait for the completion of calibration on the + // last batch in future PR. +}; + +#undef TYPECASE + +REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp); + +} // namespace tensorrt +} // namespace tensorflow +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h new file mode 100644 index 00000000000000..23df9db32f077a --- /dev/null +++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H +#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H + +#include +#include +#include +#include +#include +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +namespace tensorflow { +namespace tensorrt { +// TODO(sami): Convert this to async kernel! +class TRTCalibOp : public OpKernel { + public: + explicit TRTCalibOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* context) override; + + private: + string resource_name_; + std::vector segment_nodes_; + std::vector input_names_; + std::vector shapes_; + std::unordered_map> device_buffers_; + std::vector dev_tensors_; +}; +} // namespace tensorrt +} // namespace tensorflow +#endif +#endif +#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc new file mode 100644 index 00000000000000..4835e5065068ec --- /dev/null +++ b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +namespace tensorflow { + +REGISTER_OP("TRTCalibOp") + .Attr("segment_nodes: list(string)") // names of the ops in segment + .Attr("segment_output_names: list(string)") // names of the output ops in + // segment + .Attr("input_names: list(string)") // names of the inputs for + // passing into tensorrt + .Attr("resource_name: string") + .Attr("InT: list({int8, float16, float32})") + .Input("in_tensor: InT") + .Output("out_tensor: InT") + .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { + for (int i = 0; i < c->num_inputs(); i++) { + c->set_output(i, c->input(i)); + } + return Status::OK(); + }); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc new file mode 100644 index 00000000000000..3d5cc76c4256be --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" + +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda_runtime_api.h" + +namespace tensorflow { +namespace tensorrt { + +// set the batch size before constructing the thread to execute engine +int TRTInt8Calibrator::getBatchSize() const { return batch_size_; } + +TRTInt8Calibrator::TRTInt8Calibrator( + const std::unordered_map>& dev_buffers, + int batch_size, string engine_name) + : batch_size_(batch_size), + done_(false), + dev_buffers_(dev_buffers), + calib_running_(false), + engine_name_(engine_name) {} + +bool TRTInt8Calibrator::setBatch( + const std::unordered_map& data) { + // TODO(aaroey): make sure that in future PR: + // 1. the mutex_lock is outside of the loop + // 2. wait() is used instead of wait_for() + // 3. done_ is to be protected by the mutex + // 4. the first batch is not missed + if (done_) return false; + while (calib_running_.load( + std::memory_order_acquire)) { // wait while calibration is running + tensorflow::mutex_lock l(cond_mtx_); + cond_.wait_for(l, std::chrono::milliseconds(50)); + if (done_) return false; + } + VLOG(1) << "Set Batch Waiting finished"; + for (const auto it : data) { + auto devptr = dev_buffers_.find(it.first); + if (devptr == dev_buffers_.end()) { + LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first + << "' does not match with the buffer names"; + } + const auto& d = devptr->second; + + // TODO(aaroey): we should not use sync copy on default stream. Make sure + // stream->ThenMemcpy() is used in future PRs. + auto status = + cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice); + if (status != cudaSuccess) { + LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first + << "' failed with " << status; + } + } + calib_running_.store(true, std::memory_order_release); // release builder + cond_.notify_all(); + return true; +} + +bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, + int num_bindings) { + calib_running_.store(false, std::memory_order_release); // wait for new batch + cond_.notify_all(); + while (!calib_running_.load( + std::memory_order_acquire)) { // wait until new batch arrives + tensorflow::mutex_lock l(cond_mtx_); + cond_.wait_for(l, std::chrono::milliseconds(50)); + if (done_) return false; + } + if (done_) { + return false; + } + + for (int i = 0; i < num_bindings; i++) { + auto it = dev_buffers_.find(names[i]); + if (it == dev_buffers_.end()) { + LOG(FATAL) << "Calibration engine asked for unknown tensor name '" + << names[i] << "' at position " << i; + } + + bindings[i] = it->second.first; + } + return true; +} + +const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { + return nullptr; +} + +void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, + std::size_t length) {} +TRTInt8Calibrator::~TRTInt8Calibrator() { + VLOG(1) << "Destroying calibrator for " << engine_name_; +} + +} // namespace tensorrt +} // namespace tensorflow +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h new file mode 100644 index 00000000000000..8830f7efe75b42 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ + +#include +#include +#include +#include +#include "tensorflow/core/platform/mutex.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" +namespace tensorflow { +namespace tensorrt { +// This class provides a 1 element queue to match TFs push model to +// TRTs pull model for calibration. When TRT implements a means for +// a push calibration This class should be updated accordingly + +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { + public: + TRTInt8Calibrator( + const std::unordered_map>& dev_buffers, + int batch_size, string engine_name); + int getBatchSize() const override; + bool getBatch(void* bindings[], const char* names[], + int num_bindings) override; + bool setBatch(const std::unordered_map& data); + void setDone() { done_ = true; } + const void* readCalibrationCache(std::size_t& length) override; + void writeCalibrationCache(const void* ptr, std::size_t length) override; + ~TRTInt8Calibrator(); + + private: + const int batch_size_; + tensorflow::mutex cond_mtx_; // mutex for condition_variable + tensorflow::condition_variable cond_; // condition variable to implement + // producer-consumer queue for + // calibration + bool done_; + const std::unordered_map> + dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with + // buffer names + std::atomic_bool calib_running_; + string engine_name_; +}; +} // namespace tensorrt +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc new file mode 100644 index 00000000000000..e663eed4dd6704 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace tensorrt { + +std::shared_ptr +tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) { + // mutex is held for lookup only. Most instantiations where mutex will be held + // longer will be during op creation and should be ok. + tensorflow::mutex_lock lock(map_mutex_); + auto s = managers_.find(op_name); + if (s == managers_.end()) { + auto it = managers_.emplace( + op_name, std::make_shared(op_name)); + VLOG(1) << "Returning a new manager " << op_name; + return it.first->second; + } + VLOG(1) << "Returning old manager " << op_name; + return s->second; +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h new file mode 100644 index 00000000000000..5f8ad491d3c13e --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#include + +#include +#include +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace tensorrt { + +class TRTResourceManager { + TRTResourceManager() = default; + + public: + static std::shared_ptr instance() { + static std::shared_ptr instance_( + new TRTResourceManager); + return instance_; + } + // returns a manager for given op, if it doesn't exists it creates one + std::shared_ptr getManager(const string& op_name); + + private: + std::unordered_map> + managers_; + tensorflow::mutex map_mutex_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCE_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h new file mode 100644 index 00000000000000..3c85968ae7acf5 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_ + +#include +#include +#include +#include +#include +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/core/framework/resource_mgr.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +class TRTCalibrationResource : public tensorflow::ResourceBase { + public: + TRTCalibrationResource() + : calibrator_(nullptr), + builder_(nullptr), + network_(nullptr), + engine_(nullptr), + logger_(nullptr), + thr_(nullptr) {} + string DebugString() override { + std::stringstream oss; + oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl + << " Builder = " << std::hex << builder_ << std::dec << std::endl + << " Network = " << std::hex << network_ << std::dec << std::endl + << " Engine = " << std::hex << engine_ << std::dec << std::endl + << " Logger = " << std::hex << logger_ << std::dec << std::endl + << " Thread = " << std::hex << thr_ << std::dec << std::endl; + return oss.str(); + } + ~TRTCalibrationResource() { + VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + } + TRTInt8Calibrator* calibrator_; + nvinfer1::IBuilder* builder_; + nvinfer1::INetworkDefinition* network_; + nvinfer1::ICudaEngine* engine_; + tensorflow::tensorrt::Logger* logger_; + // TODO(sami): Use threadpool threads! + std::thread* thr_; +}; + +class TRTWeightStore : public tensorflow::ResourceBase { + public: + TRTWeightStore() {} + std::list> store_; + string DebugString() override { + std::stringstream oss; + size_t lenBytes = 0; + for (const auto& v : store_) { + lenBytes += v.size() * sizeof(uint8_t); + } + oss << " Number of entries = " << store_.size() << std::endl + << " Total number of bytes = " + << store_.size() * sizeof(std::vector) + lenBytes << std::endl; + return oss.str(); + } + virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); } +}; + +class TRTEngineResource : public tensorflow::ResourceBase { + public: + TRTEngineResource() : runtime_(nullptr), ctx_(nullptr){}; + string DebugString() override { return string(""); } + nvinfer1::IRuntime* runtime_; + nvinfer1::IExecutionContext* ctx_; +}; + +} // namespace tensorrt +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCEMGR_TRTRESOURCES_H_ +#endif +#endif diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index fff972c1f3277a..ed3ed4c0e1731d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -140,11 +140,13 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:state_ops", + "//tensorflow/python:summary", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/estimator:export", "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:metric_keys", ], ) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index 8731b10923af9b..f4d9351432ef32 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -26,6 +26,7 @@ from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.export import export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -35,6 +36,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest +from tensorflow.python.summary import summary def time_series_regression_head(model, @@ -71,14 +73,34 @@ def __init__(self, self.input_statistics_generator = input_statistics_generator self._name = name + @property + def name(self): + return self._name + + # TODO(terrytangyuan): consolidate `model_outputs` and `_Head.LossSpec` + # once `_Head.create_loss` becomes extendable + def create_loss(self, features, mode, logits=None, labels=None): + """See `_Head`.""" + model_outputs = self.state_manager.define_loss( + self.model, features, mode) + summary.scalar( + head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), + model_outputs.loss) + return model_outputs + + @property + def logits_dimension(self): + """See `_Head`.""" + return 1 + def _train_ops(self, features): """Add training ops to the graph.""" + mode = estimator_lib.ModeKeys.TRAIN with variable_scope.variable_scope( "model", # Use ResourceVariables to avoid race conditions. use_resource=True): - model_outputs = self.state_manager.define_loss( - self.model, features, estimator_lib.ModeKeys.TRAIN) + model_outputs = self.create_loss(features, mode) train_op = optimizers.optimize_loss( model_outputs.loss, @@ -88,31 +110,14 @@ def _train_ops(self, features): learning_rate=None) return estimator_lib.EstimatorSpec( loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.TRAIN, + mode=mode, train_op=train_op) - # TODO(terrytangyuan): suffix summary and metrics keys by `"/" + name` - @property - def name(self): - return self._name - - # TODO(terrytangyuan): unused for now. Need to decouple - # `state_manager.define_loss` to satisfy the extendable return signature of - # `_Head.create_loss`. - def create_loss(self, features, mode, logits, labels): - """See `_Head`.""" - return None - - # TODO(terrytangyuan): check label dimension - @property - def logits_dimension(self): - return None - def _evaluate_ops(self, features): """Add ops for evaluation (aka filtering) to the graph.""" + mode = estimator_lib.ModeKeys.EVAL with variable_scope.variable_scope("model", use_resource=True): - model_outputs = self.state_manager.define_loss( - self.model, features, estimator_lib.ModeKeys.EVAL) + model_outputs = self.create_loss(features, mode) metrics = {} # Just output in-sample predictions for the last chunk seen for prediction_key, prediction_value in model_outputs.predictions.items(): @@ -125,7 +130,7 @@ def _evaluate_ops(self, features): model_outputs.end_state)) return estimator_lib.EstimatorSpec( loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.EVAL, + mode=mode, eval_metric_ops=metrics, predictions={}) @@ -143,9 +148,8 @@ def _serving_ops(self, features): with variable_scope.variable_scope("model", use_resource=True): prediction_outputs = self.model.predict(features=features) with variable_scope.variable_scope("model", reuse=True): - filtering_outputs = self.state_manager.define_loss( - self.model, features, estimator_lib.ModeKeys.EVAL) - + filtering_outputs = self.create_loss( + features, estimator_lib.ModeKeys.EVAL) return estimator_lib.EstimatorSpec( mode=estimator_lib.ModeKeys.PREDICT, export_outputs={ @@ -194,7 +198,7 @@ def _gather_state(self, features): def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" - with ops.name_scope("head"): + with ops.name_scope(self._name, "head"): if labels: raise ValueError( "The model received a `labels` dictionary, which is " diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md index 58fed4e5cb4c24..4b6104a8b4d542 100644 --- a/tensorflow/contrib/verbs/README.md +++ b/tensorflow/contrib/verbs/README.md @@ -93,7 +93,7 @@ When the receiver receives the RDMA write, it will locate the relevant **RdmaTen 1. When the sender receives a tensor request, the source tensor may or may not be ready yet. The situation is handled through a process of tag matching: * If the request arrives before the tensor is ready, then a callback is put in a local table, and will be invoked once the tensor arrives. - * If the tensor is ready before the request arives, than the tensor is put in a local table. When the request arrives, it will invoke the callback immediately. + * If the tensor is ready before the request arrives, than the tensor is put in a local table. When the request arrives, it will invoke the callback immediately. In code it is done by calling **RecvLocalAsync()**, which receives the tensor's key, step-id, and the callback. 2. When the callback is invoked, the relevant tensor is removed from the tag matching table. In the case where we need to send the tensor's meta-data, the **RdmaTensorResponse** will store a copy of the tensor until the re-request arrives. 3. The sending of protocol messages (**RDMA_MESSAGE_TENSOR_REQUEST**, **RDMA_MESSAGE_META_DATA_RESPONSE** and **RDMA_MESSAGE_TENSOR_RE_REQUEST**) is done by the class **RdmaMessageBuffer**. All messages are sent using RDMA writes from/to fixed messages buffers. This implies that we cannot send on a specific channel more than one message at a time. In order to synchronize the messages, the **RdmaMessageBuffer** holds the a local and remote buffer statuses which can be either busy or idle. When a write is issued, both statuses will be changed to busy. When the write-complete event is received, the local status is changed to idle. When the write is received on the remote side, the remote side will parse the message, and return an ACK back to the sending side on which the sending side will update the remote status to idle. When both the local and remote statuses are idle, the next message can be sent. diff --git a/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md b/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md index 956b8f2147cf81..da6fdd48e19e9d 100644 --- a/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md +++ b/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md @@ -64,7 +64,7 @@ The protocol messages themselves will remain mostly unchanged at the first stage * type - The message type. * request_index - Request index. * is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data. -* **RDMA_MESSAGE_BUFFER_RESPONSE** - (receiver ==> sender) Tensor re-requset after meta-data update and reallocation of result/proxy tensors. +* **RDMA_MESSAGE_BUFFER_RESPONSE** - (receiver ==> sender) Tensor re-request after meta-data update and reallocation of result/proxy tensors. * type - The message type. * name (name_size) - Name of the requested tensor. * step_id - Step ID. diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 7d95b6522c5149..86350a08e57e50 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV2.pbtxt new file mode 100644 index 00000000000000..e21f56ba5b9268 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV2.pbtxt @@ -0,0 +1,85 @@ +op { + graph_op_name: "UniqueWithCountsV2" + in_arg { + name: "x" + description: < [1, 2, 4, 7, 8] +idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +count ==> [2, 1, 3, 1, 2] +``` + +For an `2-D` tensor `x` with `axis = 0`: + +``` +# tensor 'x' is [[1, 0, 0], +# [1, 0, 0], +# [2, 0, 0]] +y, idx, count = unique_with_counts(x, axis=0) +y ==> [[1, 0, 0], + [2, 0, 0]] +idx ==> [0, 0, 1] +count ==> [2, 1] +``` + +For an `2-D` tensor `x` with `axis = 1`: + +``` +# tensor 'x' is [[1, 0, 0], +# [1, 0, 0], +# [2, 0, 0]] +y, idx, count = unique_with_counts(x, axis=1) +y ==> [[1, 0], + [1, 0], + [2, 0]] +idx ==> [0, 1, 1] +count ==> [1, 2] +``` +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt index 4e69e0bc6302eb..4ca6780c95629d 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt @@ -14,20 +14,21 @@ Has same shape as data, except for dimension 0 which has size `num_segments`. END } - summary: "Computes the Max along segments of a tensor." + summary: "Computes the maximum along segments of a tensor." description: <::min()`. +If the maximum is empty for a given segment ID `i`, it outputs the smallest +possible value for the specific numeric type, +`output[i] = numeric_limits::lowest()`.
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt new file mode 100644 index 00000000000000..55ea69b5dd5f7f --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt @@ -0,0 +1,33 @@ +op { + graph_op_name: "UnsortedSegmentMin" + in_arg { + name: "segment_ids" + description: <::max()`. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt new file mode 100644 index 00000000000000..577ff53d60c5a1 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt @@ -0,0 +1,32 @@ +op { + graph_op_name: "UnsortedSegmentProd" + in_arg { + name: "segment_ids" + description: <