diff --git a/.bazelrc b/.bazelrc index 1014d1506c8d00..8c645260972dca 100644 --- a/.bazelrc +++ b/.bazelrc @@ -594,6 +594,12 @@ build:release_cpu_linux --config=avx_linux build:release_cpu_linux --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain" test:release_cpu_linux --test_env=LD_LIBRARY_PATH +# manylinux2014 config for cpu +build:release_cpu_linux_manylinux2014 --config=release_base +build:release_cpu_linux_manylinux2014 --config=avx_linux +build:release_cpu_linux_manylinux2014 --crosstool_top="@ubuntu18.04-gcc8_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain" +test:release_cpu_linux_manylinux2014 --test_env=LD_LIBRARY_PATH + build:release_cpu_macos --config=release_base build:release_cpu_macos --config=avx_linux @@ -616,6 +622,12 @@ build:release_gpu_linux_11_4 --action_env=TF_CUDA_VERSION="11.4" build:release_gpu_linux_11_4 --action_env=TF_CUDNN_VERSION="8.2" build:release_gpu_linux_11_4 --crosstool_top=@ubuntu18.04-gcc7_manylinux2010-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain +# manylinux2014 config for gpu +build:release_gpu_linux_manylinux2014 --config=release_gpu_linux +build:release_gpu_linux_manylinux2014 --action_env=GCC_HOST_COMPILER_PATH="/dt8/usr/bin/gcc" +build:release_gpu_linux_manylinux2014 --crosstool_top=@ubuntu18.04-gcc8_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain + + build:release_cpu_windows --config=release_base build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true diff --git a/.bazelversion b/.bazelversion index 0b2eb36f508590..fae6e3d04b2cab 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -3.7.2 +4.2.1 diff --git a/ACKNOWLEDGMENTS b/ACKNOWLEDGMENTS deleted file mode 100644 index 7eb20334c45cc7..00000000000000 --- a/ACKNOWLEDGMENTS +++ /dev/null @@ -1,50 +0,0 @@ -## Some of TensorFlow's code is derived from Caffe, which is subject to the 
following copyright notice: - -COPYRIGHT - -All contributions by the University of California: - -Copyright (c) 2014, The Regents of the University of California (Regents) -All rights reserved. - -All other contributions: - -Copyright (c) 2014, the respective contributors -All rights reserved. - -Caffe uses a shared copyright model: each contributor holds copyright over -their contributions to Caffe. The project versioning records all such -contribution and copyright details. If a contributor wants to further mark -their specific copyright on a particular contribution, they should indicate -their copyright solely in the commit message of the change when it is -committed. - -LICENSE - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -CONTRIBUTION AGREEMENT - -By contributing to the BVLC/caffe repository through pull-request, comment, -or otherwise, the contributor releases their content to the -license and copyright terms herein. - diff --git a/LICENSE b/LICENSE index 9f6ace032ef128..12d255f8e0f049 100644 --- a/LICENSE +++ b/LICENSE @@ -200,31 +200,27 @@ See the License for the specific language governing permissions and limitations under the License. ------------------- -Files: third_party/compute_library/... - -MIT License - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - ------------------- -Files: ACKNOWLEDGEMENTS +## Some of TensorFlow's code is derived from Caffe, which is subject to the following copyright notice: + +COPYRIGHT + +All contributions by the University of California: + +Copyright (c) 2014, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: + +Copyright (c) 2014, the respective contributors +All rights reserved. 
+ +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + LICENSE Redistribution and use in source and binary forms, with or without @@ -248,37 +244,8 @@ modification, are permitted provided that the following conditions are met: (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ------------------- -Files: third_party/hexagon +CONTRIBUTION AGREEMENT -Copyright (c) 2016-2019, The Linux Foundation. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted (subject to the limitations in the -disclaimer below) provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - - * Neither the name of The Linux Foundation nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE -GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT -HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED -WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
-IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE -GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER -IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR -OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN -IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. \ No newline at end of file diff --git a/RELEASE.md b/RELEASE.md index 81ab910333f592..8f825e8263de48 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -16,9 +16,12 @@ # Major Features and Improvements * `tf.lite`: - * Where operation support is added for these data types - 'int32/uint32/int8/uint8/int64' - * Add builtin support for `Bucketize` op on CPU. + * Added TFLite builtin op support for the following TF ops: + * `tf.raw_ops.Bucketize` op on CPU. + * `tf.where` op for data types `tf.int32`/`tf.uint32`/`tf.int8`/`tf.uint8`/`tf.int64`. + * `tf.random.normal` op for output data type `tf.float32` on CPU. + * `tf.random.uniform` op for output data type `tf.float32` on CPU. + * `tf.random.categorical` op for output data type `tf.int64` on CPU. * `tensorflow.experimental.tensorrt`: * `conversion_params` is now deprecated inside `TrtGraphConverterV2` in @@ -29,6 +32,16 @@ `.save()` function inside `TrtGraphConverterV2`. When `False`, the `.save()` function won't save any TRT engines that have been built. When `True` (default), the original behavior is preserved. +* `tf.tpu.experimental.embedding`: + * `tf.tpu.experimental.embedding.FeatureConfig` now takes an additional + argument `output_shape` which can specify the shape of the output + activation for the feature. 
+ * `tf.tpu.experimental.embedding.TPUEmbedding` now has the same behavior + as `tf.tpu.experimental.embedding.serving_embedding_lookup` which can + take arbitrary rank of dense and sparse tensor. For ragged tensor, + though the input tensor remains to be rank 2, the activations now can be + rank 2 or above by specifying the output shape in the feature config + or via the build method. * @@ -42,6 +55,9 @@ * `tf.data`: * The optimization `parallel_batch` now becomes default if not disabled by users, which will parallelize copying of batch elements. + * Added the ability for `TensorSliceDataset` to identify and handle inputs + that are files. This enables creating hermetic SavedModels when using + datasets created from files. * `tf.lite`: * GPU @@ -161,6 +177,7 @@ This release contains contributions from many people at Google, as well as: * `tf.lite`: * Add experimental API `experimental_from_jax` to support conversion from Jax models to TensorFlow Lite. * Support uint32 data type for cast op. + * Support int8 data type for cast op. * Add experimental quantization debugger `tf.lite.QuantizationDebugger` * Add lite.experimental.authoring.compatible API * A Python decorator to provide a way to check TFLite compatibility diff --git a/configure.py b/configure.py index bff9abfe154797..6cdd109783e0e8 100644 --- a/configure.py +++ b/configure.py @@ -45,7 +45,7 @@ _TF_WORKSPACE_ROOT = '' _TF_BAZELRC = '' _TF_CURRENT_BAZEL_VERSION = None -_TF_MIN_BAZEL_VERSION = '3.7.2' +_TF_MIN_BAZEL_VERSION = '4.2.1' _TF_MAX_BAZEL_VERSION = '4.99.0' NCCL_LIB_PATHS = [ diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 216e8ed0cd01d9..27c7919912d9d0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -46,7 +46,6 @@ licenses(["notice"]) exports_files([ "LICENSE", - "ACKNOWLEDGMENTS", # The leakr files are used by //third_party/cloud_tpu and # //third_party/tensorboard/google:copybara_config_test. 
"leakr_badwords.dic", diff --git a/tensorflow/c/eager/abstract_context.h b/tensorflow/c/eager/abstract_context.h index 07a78f97bd5a9f..2132daf2cfa388 100644 --- a/tensorflow/c/eager/abstract_context.h +++ b/tensorflow/c/eager/abstract_context.h @@ -42,7 +42,7 @@ class AbstractContext { // Release any underlying resources, including the interface object. // // WARNING: The destructor of this class is marked as protected to disallow - // clients from directly destroying this object since it may manage it's own + // clients from directly destroying this object since it may manage its own // lifetime through ref counting. Thus clients MUST call Release() in order to // destroy an instance of this class. virtual void Release() = 0; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 0afb69bb82ce79..7ad77587d6fe70 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -119,7 +119,7 @@ TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(TFE_Context* ctx); // A tensorflow.ServerDef specifies remote workers (in addition to the current -// workers name). Operations created on this context can then be executed on +// workers name). Operations created in this context can then be executed on // any of these remote workers by setting an appropriate device. // // If the following is set, all servers identified by the @@ -134,7 +134,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, // // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, // type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors -// placed in memory of different devices or remote address spaces. +// placed in the memory of different devices or remote address spaces. 
typedef struct TFE_TensorHandle TFE_TensorHandle; TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, @@ -442,7 +442,7 @@ TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx); // Ends a step. When there is no active step (that is, every started step has // been ended) step containers will be cleared. Note: it is not safe to call -// TFE_ContextEndStep while ops which rely on the step container may be running. +// TFE_ContextEndStep while ops that rely on the step container may be running. TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx); #ifdef __cplusplus diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index d21cadfd0cbcdc..208ce427478b72 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -161,7 +161,7 @@ void TestFunctionWithPackedInput(const bool remote) { TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task2_name); TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task0_name); - // Add a sync point in order to make sure that variables have been initialized + // Add a sync point to make sure that variables have been initialized // before the function execution starts. TFE_ContextAsyncWait(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index f976b4b876c851..ee9cf9f950fd5e 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -140,7 +140,7 @@ TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge, typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell; TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet( TFE_MonitoringStringGaugeCell* cell, const char* value); -// Retrieves the string value and saves it in buffer. +// Retrieves the string value and saves it in the buffer. 
TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue( TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf); @@ -248,7 +248,7 @@ TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd( TFE_MonitoringSamplerCell* cell, double value); // Retrieves the current value of the cell. The return value is a HistogramProto -// saved in buffer. +// saved in the buffer. TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue( TFE_MonitoringSamplerCell* cell, TF_Buffer* buf); @@ -353,7 +353,7 @@ TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*); TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes( TFE_Executor*, TF_Status* status); -// When an error happens, any pending operations are discarded and newly issued +// When an error happens, any pending operations are discarded, and newly issued // ops return an error. This call clears the error state and re-enables // execution of newly issued ops. // @@ -362,12 +362,12 @@ TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes( // TODO(agarwal): mark the affected handles and raise errors if they are used. TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*); -// Sets a custom Executor for current thread. All nodes created by this thread -// will be added to this Executor. It will override current executor. +// Sets a custom Executor for the current thread. All nodes created by this +// thread will be added to this Executor. It will override the current executor. TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*, TFE_Executor*); -// Returns the Executor for current thread. +// Returns the Executor for the current thread. TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread( TFE_Context*); @@ -376,7 +376,7 @@ TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread( // Update an existing context with a new set of servers defined in a ServerDef // proto. 
Servers can be added to and removed from the list of remote workers -// in the context. New set of servers identified by the ServerDef must be up +// in the context. A new set of servers identified by the ServerDef must be up // when the context is updated. // // This API is for experimental usage and may be subject to change. @@ -527,8 +527,8 @@ typedef struct TFE_CustomDevice { // names of wrapped devices. // // There are currently no graph semantics implemented for registered custom -// devices, so executing tf.functions which contain operations placed on custom -// devices will fail. +// devices, so executing tf.functions which contain operations placed on the +// custom devices will fail. // // `device_name` must not name an existing physical or custom device. It must // follow the format: @@ -646,8 +646,8 @@ TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status); // Returns the status for the tensor handle. In TFRT, a tensor handle can carry -// error info if error happens. If so, status will be set with the error info. -// If not, status will be set as OK. +// error info if an error happens. If so, the status will be set with the error +// info. If not, status will be set as OK. TF_CAPI_EXPORT extern void TFE_TensorHandleGetStatus(TFE_TensorHandle* h, TF_Status* status); @@ -673,7 +673,7 @@ TF_CAPI_EXPORT extern void TFE_SetLogicalCpuDevices(TFE_Context* ctx, // setting the same key will lead to errors. // // Note that the key-values are only expected to be used for cluster -// configuration data, and should not be used for storing large amount of data +// configuration data, and should not be used for storing a large amount of data // or being accessed very frequently. 
TF_CAPI_EXPORT extern void TFE_InsertConfigKeyValue(TFE_Context* ctx, const char* key, diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index beaca6c4ffd22f..e3a038489ff270 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -2174,92 +2174,249 @@ TEST(CAPI, ShareVariableAcrossContextsWorks) { worker_server2.release(); } +void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index, + const string& host, int port) { + tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0); + job_def->mutable_tasks()->at(task_index) = + tensorflow::strings::StrCat(host, ":", port); +} + +TEST(CAPI, ShareVariableAcrossContextsAfterUpdateContextWorks) { + tensorflow::ServerDef server_def_0 = GetServerDef(3); + server_def_0.mutable_default_session_config()->set_isolate_session_state( + false); + tensorflow::ServerDef server_def_1 = + ReplaceTaskInServerDef(server_def_0, /*task_index=*/0); + + // These server defs have task index set to 0. + string serialized_server_def_0 = server_def_0.SerializeAsString(); + string serialized_server_def_1 = server_def_1.SerializeAsString(); + + // Create two worker tasks. + server_def_0.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def_0, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + server_def_0.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def_0, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + // Create two contexts. + TFE_Context* ctx_0 = CreateContext(serialized_server_def_0, + /*isolate_session_state=*/false); + TFE_Context* ctx_1 = CreateContext(serialized_server_def_1, + /*isolate_session_state=*/false); + + // Remote device on `worker2`. 
+ const char remote_device[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + // `ctx_0` and `ctx_1` both contain `remote_device`. + { + const std::vector<std::string>& device_names = ListDeviceNames(ctx_0); + ASSERT_TRUE(std::find(device_names.begin(), device_names.end(), + remote_device) != device_names.end()); + } + + { + const std::vector<std::string>& device_names = ListDeviceNames(ctx_1); + ASSERT_TRUE(std::find(device_names.begin(), device_names.end(), + remote_device) != device_names.end()); + } + + // Create a variable using `ctx_0`. + // Replace worker1 using a new worker, and update the contexts. + // Read the variable using `ctx_1`. This read should succeed. + // + // 1. Create a variable on `remote_device`, using `ctx_0`. + TFE_TensorHandle* handle_0 = + CreateVariable(ctx_0, 1.2, remote_device, /*variable_name=*/"var"); + + // 2. Wait for `var` to be created and initialized on the worker. + TF_Status* status = TF_NewStatus(); + TFE_ContextAsyncWait(ctx_0, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + int port = tensorflow::testing::PickUnusedPortOrDie(); + // 3. Replace worker1 with a new worker in server_def_0 and server_def_1. + ReplaceTaskInServerDef(&server_def_0, /*task_index=*/1, "localhost", port); + ReplaceTaskInServerDef(&server_def_1, /*task_index=*/1, "localhost", port); + // 4. Start a new task to replace worker1. + server_def_0.set_task_index(1); + worker_server1.release(); + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def_0, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + // 5a. Update `ctx_0` with updated `server_def_0`. 
+ { + server_def_0.set_task_index(0); + string serialized_update = server_def_0.SerializeAsString(); + TF_Status* status = TF_NewStatus(); + TFE_ContextUpdateServerDef(ctx_0, 0, serialized_update.data(), + serialized_update.size(), status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + } + + // 5b. Update `ctx_1` with updated `server_def_1`. + { + server_def_1.set_task_index(0); + string serialized_update = server_def_1.SerializeAsString(); + TF_Status* status = TF_NewStatus(); + TFE_ContextUpdateServerDef(ctx_1, 0, serialized_update.data(), + serialized_update.size(), status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + } + + // 6. Read `var` using `ctx_1`. This read should succeed since `ctx_1` was + // created with `isolate_session_state` set to false, and update should + // preserve it. + { + // Create a handle to `var`, using `ctx_1`. + TFE_TensorHandle* var_handle = + CreateVarHandle(ctx_1, remote_device, /*variable_name=*/"var"); + + TFE_TensorHandle* handle_1 = nullptr; + int num_retvals = 1; + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(ctx_1, "ReadVariableOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Execute(op, &handle_1, &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + + ASSERT_EQ(1, num_retvals); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1)); + EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status)); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Read the value of tensor handle `handle_1`. 
+ float value = 0.0f; + TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(sizeof(float), TF_TensorByteSize(t)); + memcpy(&value, TF_TensorData(t), sizeof(float)); + TF_DeleteTensor(t); + EXPECT_EQ(1.2f, value); + TFE_DeleteTensorHandle(handle_1); + TF_DeleteStatus(status); + TFE_DeleteTensorHandle(var_handle); + } + + TFE_DeleteTensorHandle(handle_0); + + TFE_DeleteContext(ctx_0); + TFE_DeleteContext(ctx_1); + + worker_server1.release(); + worker_server2.release(); +} + tensorflow::ServerDef CreateSingleHostServerDef( const tensorflow::ServerDef& cluster_server_def, int task_index) { tensorflow::ServerDef single_host_server_def; - single_host_server_def.set_job_name(cluster_server_def.job_name()); + single_host_server_def.set_job_name("worker"); single_host_server_def.set_protocol(cluster_server_def.protocol()); single_host_server_def.set_task_index(0); tensorflow::ClusterDef* cluster_def = single_host_server_def.mutable_cluster(); tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name(cluster_server_def.job_name()); + job_def->set_name("client"); // Add a client. 
- single_host_server_def.mutable_cluster() - ->mutable_job(0) - ->mutable_tasks() - ->insert( - {0, tensorflow::strings::StrCat( - "localhost:", tensorflow::testing::PickUnusedPortOrDie())}); + job_def->mutable_tasks()->insert( + {0, tensorflow::strings::StrCat( + "localhost:", tensorflow::testing::PickUnusedPortOrDie())}); + + tensorflow::JobDef* job_def2 = cluster_def->add_job(); + job_def2->set_name("worker"); // Copy over `host:port` at `task_index` for (auto task : cluster_server_def.cluster().job(0).tasks()) { if (task.first == task_index) { - single_host_server_def.mutable_cluster() - ->mutable_job(0) - ->mutable_tasks() - ->insert({task.first, task.second}); + job_def2->mutable_tasks()->insert({task.first, task.second}); } } return single_host_server_def; } +tensorflow::ServerDef GetClusterServerDef(const string& worker_job_name, + int num_workers) { + tensorflow::ServerDef server_def = GetServerDef(worker_job_name, num_workers); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + + // Add a client. + tensorflow::JobDef* job_def2 = cluster_def->add_job(); + job_def2->set_name("client"); + job_def2->mutable_tasks()->insert( + {0, tensorflow::strings::StrCat( + "localhost:", tensorflow::testing::PickUnusedPortOrDie())}); + return server_def; +} + TEST(CAPI, SingleHostServerDefWorks) { - // Create a server def that represents a 2-process cluster. + // Create a server def that represents a 2-process cluster and a client. 
// Example: // - // cluster { job { name: "localhost" - // tasks { key: 0 value: "localhost:14319" } <--client - // tasks { key: 1 value: "localhost:15022" } <--worker1 - // tasks { key: 2 value: "localhost:15023" } <--worker2 - // } } - // job_name: "localhost" protocol: "grpc" + // cluster { job { name: "worker" + // tasks { key: 0 value: "localhost:14522" } + // tasks { key: 1 value: "localhost:14523" } + // } + // job { name: "client" + // tasks { key: 0 value: "localhost:14524" } + // } + // } job_name: "worker" protocol: "grpc" // - tensorflow::ServerDef cluster_server_def = GetServerDef(3); - // These server defs have task index set to 0. - string serialized_cluster_server_def = cluster_server_def.SerializeAsString(); + tensorflow::ServerDef cluster_server_def = GetClusterServerDef("worker", 2); // Create two worker tasks, using single host server defs. // A single host server def contains a client and the remote host. // Example: // - // Worker2: - // cluster { job { name: "localhost" - // tasks { key: 0 value: "localhost:15226" } <--client - // tasks { key: 2 value: "localhost:15023" } <--worker2 - // } } - // job_name: "localhost" task_index: 2 protocol: "grpc" - // // Worker1: - // cluster { job { name: "localhost" - // tasks { key: 0 value: "localhost:15024" } <--client - // tasks { key: 1 value: "localhost:15022" } <--worker1 - // } } - // job_name: "localhost" task_index: 1 protocol: "grpc" + // cluster { job { name: "client" tasks { key: 0 value: "localhost:14525" } } + // job { name: "worker" tasks { key: 1 value: "localhost:14523" } } + // } job_name: "worker" task_index: 1 protocol: "grpc" + // + // Worker0: + // cluster { job { name: "client" tasks { key: 0 value: "localhost:14526" } } + // job { name: "worker" tasks { key: 0 value: "localhost:14522" } } + // } job_name: "worker" protocol: "grpc" // - // Create `worker_2` using single host server def `worker_2_server_def`. 
- tensorflow::ServerDef worker_2_server_def = - CreateSingleHostServerDef(cluster_server_def, 2); - worker_2_server_def.set_task_index(2); + // Create `worker_1` using single host server def `worker_1_server_def`. + tensorflow::ServerDef worker_1_server_def = + CreateSingleHostServerDef(cluster_server_def, 1); + worker_1_server_def.set_task_index(1); + worker_1_server_def.set_job_name("worker"); - std::unique_ptr worker_server2; - ASSERT_TRUE(tensorflow::GrpcServer::Create(worker_2_server_def, + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create(worker_1_server_def, tensorflow::Env::Default(), - &worker_server2) + &worker_server1) .ok()); - ASSERT_TRUE(worker_server2->Start().ok()); + ASSERT_TRUE(worker_server1->Start().ok()); // Create context `local_ctx` using single host server def - - // `worker_2_server_def`. - worker_2_server_def.set_task_index(0); + // `worker_1_server_def`. + worker_1_server_def.set_task_index(0); + worker_1_server_def.set_job_name("client"); TFE_Context* local_ctx = - CreateContext(worker_2_server_def.SerializeAsString(), + CreateContext(worker_1_server_def.SerializeAsString(), /*isolate_session_state=*/false); - const char remote_device[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + const char remote_device[] = "/job:worker/replica:0/task:1/device:CPU:0"; - // Create a variable `var` on `worker2` using `local_ctx`. + // Create a variable `var` on `worker1` using `local_ctx`. TFE_TensorHandle* handle_0 = @@ -2270,21 +2427,24 @@ TEST(CAPI, SingleHostServerDefWorks) { TF_DeleteStatus(status); TFE_DeleteTensorHandle(handle_0); - // Create `worker1` using single host server def `worker_1_server_def`. + // Create `worker0` using single host server def `worker_0_server_def`. 
+ tensorflow::ServerDef worker_0_server_def = + CreateSingleHostServerDef(cluster_server_def, 0); + worker_0_server_def.set_task_index(0); - std::unique_ptr worker_server1; - ASSERT_TRUE(tensorflow::GrpcServer::Create(worker_1_server_def, + std::unique_ptr worker_server0; + ASSERT_TRUE(tensorflow::GrpcServer::Create(worker_0_server_def, tensorflow::Env::Default(), - &worker_server1) + &worker_server0) .ok()); - ASSERT_TRUE(worker_server1->Start().ok()); + ASSERT_TRUE(worker_server0->Start().ok()); // Create a remote context, `remote_ctx`, using `cluster_server_def`. - TFE_Context* remote_ctx = CreateContext(serialized_cluster_server_def, - /*isolate_session_state=*/false); + cluster_server_def.set_task_index(0); + cluster_server_def.set_job_name("client"); + TFE_Context* remote_ctx = + CreateContext(cluster_server_def.SerializeAsString(), + /*isolate_session_state=*/false); // Read variable `var` using `remote_ctx`, created using `cluster_server_def`. { @@ -2326,7 +2486,7 @@ TEST(CAPI, SingleHostServerDefWorks) { TFE_DeleteContext(remote_ctx); worker_server1.release(); - worker_server2.release(); + worker_server0.release(); } } // namespace diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 04af1bd952c4e9..ce8546fb4f4186 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -119,7 +119,7 @@ TFE_Op* RecvOp(TFE_Context* ctx, const std::string& op_name, const std::string& send_device, const std::string& recv_device, tensorflow::uint64 send_device_incarnation); -// Return an 1-D INT32 tensor containing a single value 1. +// Return a 1-D INT32 tensor containing a single value 1. TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx); // Return an op taking minimum of `input` long `axis` dimension. 
diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index ee22695632fd12..41228f07e70fd4 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -32,7 +32,7 @@ extern "C" { // ----------------------------------------------------------------------------- // A TF_ExecutionContext stores knowledge about how to execute an operation. -// E.g. it could know whether we're in eager mode or in graph mode, keeps track +// E.g. it could know whether we're in eager mode or graph mode, keeps track // of gradient tapes, etc. typedef struct TF_ExecutionContext TF_ExecutionContext; diff --git a/tensorflow/c/eager/gradients_internal.h b/tensorflow/c/eager/gradients_internal.h index 5ddf017413a31d..1e14302c1721c1 100644 --- a/tensorflow/c/eager/gradients_internal.h +++ b/tensorflow/c/eager/gradients_internal.h @@ -24,7 +24,7 @@ namespace internal { // Helper functions which delegate to `AbstractOperation`, update // the state of the ForwardOperation and call the tape as appropriate. -// These APIs are mainly to faciliate testing and are subject to change. +// These APIs are mainly to facilitate testing and are subject to change. // Records the op name in the `ForwardOperation`. Status Reset(AbstractOperation*, const char* op, const char* raw_device_name, diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index 4a7586f0e5bce6..eab9314b3ec377 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -81,7 +81,7 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle { // Release any underlying resources, including the interface object. 
// // WARNING: The destructor of this class is marked as protected to disallow - // clients from directly destroying this object since it may manage it's own + // clients from directly destroying this object since it may manage its own // lifetime through ref counting. Thus this must be allocated on the heap and // clients MUST call Release() in order to destroy an instance of this class. virtual void Release() = 0; diff --git a/tensorflow/cc/gradients/README.md b/tensorflow/cc/gradients/README.md index 3253163cc735cf..e2f7badcfebcfe 100644 --- a/tensorflow/cc/gradients/README.md +++ b/tensorflow/cc/gradients/README.md @@ -13,31 +13,35 @@ below. 2. Write the op gradient with the following naming scheme: - Status OpNameGrad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { - ... - return scope.status(); - } - REGISTER_GRADIENT_OP("OpName", OpNameGrad); - -3. Ops gradients are implemented by using the [C++ - API](https://www.tensorflow.org/api_docs/cc/). + ``` + Status OpNameGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + ... + return scope.status(); + } + REGISTER_GRADIENT_OP("OpName", OpNameGrad); + ``` + +3. Ops gradients are implemented by using the + [C++ API](https://www.tensorflow.org/api_docs/cc/). 4. Tests should be included in `foo_grad_test.cc`. Please see [`array_grad_test.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/gradients/array_grad_test.cc) - for an many examples. Tests are as simple as, creating a placeholder input - for the op's inputs and calling `RunTest` (`RunTest` uses a [gradient - checker](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/framework/gradient_checker.cc) + for many examples. 
Tests are as simple as, creating a placeholder input for + the op's inputs and calling `RunTest` (`RunTest` uses a + [gradient checker](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/framework/gradient_checker.cc) to verify that the theoretical gradient matches the numeric gradient). For example: - TEST_F(ArrayGradTest, IdentityGrad) { - TensorShape shape({5, 2}); - auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); - auto y = Identity(scope_, x); - RunTest(x, shape, y, shape); - } + ``` + TEST_F(ArrayGradTest, IdentityGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Identity(scope_, x); + RunTest(x, shape, y, shape); + } + ``` NOTE: There are some ops that require features from the C++ API that are not yet implemented. diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 53a997de688adb..ecb5fcf3a3e840 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -491,6 +491,14 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } + if (!op_filter_.allow_where_op && node.type_string() == "Where") { + absl::string_view uncompilable_reason = "Where op"; + MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, + encapsulating_function, uncompilable_nodes); + LogNotCompilable(node, uncompilable_reason); + return false; + } + if (!op_filter_.allow_ops_producing_or_consuming_variant && OpProducesOrConsumesVariant(node)) { absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer"; diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index d25444a5bf4216..687add5d2714cb 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -136,6 +136,13 @@ class 
RecursiveCompilabilityChecker { // Whether to allow the compilation of CollectiveReduceV2Op. bool allow_collective_reduce_v2 = true; + // Whether to allow the compilation of WhereOp. Compilation of the WhereOp + // generates output with bounded dynamic shape that may cause failures with + // auto clustering. + // TODO(b/203693252): Enable tf.where during autoclustering after all the + // legalization issues are fixed. + bool allow_where_op = true; + // Whether ops that are marked as outside compiled are always considered // compilable. // TODO(b/191502757): Make this behavior true by default and remove this diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index acd20bde8806cd..843eccfa0c1f89 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1203,6 +1203,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { filter.require_always_compilable = true; filter.allow_string_consts = false; filter.allow_collective_reduce_v2 = false; + filter.allow_where_op = false; RecursiveCompilabilityChecker checker( filter, DeviceType{registration->compilation_device_name}); diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 595d358d8caaad..9f8133894843d7 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -930,7 +930,23 @@ cc_library( deps = [ ":hlo", ":lhlo", - ":map_hlo_to_lhlo_op", + ":map_lhlo_to_hlo_op", + ":map_mhlo_to_scalar_op", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:StandardOps", + ], +) + +cc_library( + name = "map_mhlo_to_scalar_op", + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"], + deps = [ + ":hlo", 
"@llvm-project//llvm:Support", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:ComplexDialect", @@ -959,6 +975,15 @@ cc_library( ], ) +cc_library( + name = "map_lhlo_to_hlo_op", + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_lhlo_to_hlo_op.h"], + deps = [ + ":hlo", + ":lhlo", + ], +) + cc_library( name = "lhlo_legalize_to_affine", srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc"], @@ -1115,8 +1140,7 @@ cc_library( ], deps = [ ":hlo", - ":lhlo", - ":map_lmhlo_to_scalar_op", + ":map_mhlo_to_scalar_op", ":pass_details", ":type_conversion", "@llvm-project//llvm:Support", @@ -1253,6 +1277,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", @@ -1280,6 +1305,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index eed6fe34d7a9ef..7a6bbd4aa3100a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1491,6 +1491,13 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []> { `call_target_name` should be short as it may be used in labels. `backend_config` can encode arbitrarily large amounts of information. + `has_side_effect` must be true if the custom call has side-effects. + `api_version` specifies the version of the API used by the custom call + function. + + A custom call may apply functions within the scope of the parent module. 
+ They can be referenced using `called_computations` attribute. + A custom call can also have layout constraints on operands and results which can be specified as optional `operand_layouts` and `result_layouts` attributes. The layout attribute is an array of rank-1 index tensors and the @@ -1517,6 +1524,7 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []> { DefaultValuedAttr: $api_version, + DefaultValuedAttr:$called_computations, OptionalAttr:$operand_layouts, OptionalAttr:$result_layouts ); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 6ec6c818afe698..15de7bf33f9989 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -115,6 +115,13 @@ def HLO_LayoutAttr : Attr< def HLO_ArrayOfLayoutAttr : TypedArrayAttrBase; +// An array of FlatSymbolRef attributes that can be used as a default valued +// attribute. +def HLO_FlatSymbolRefArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)"; +} + //===----------------------------------------------------------------------===// // Common convolution attributes diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td index 8a91a549f51ad1..454925588a6bb8 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -15,12 +15,6 @@ limitations under the License. 
include "mlir/Pass/PassBase.td" -def LhloLegalizeToLinalgPass : FunctionPass<"lhlo-legalize-to-linalg"> { - let summary = "Legalize from LHLO dialect to Linalg dialect."; - let constructor = "createLegalizeLhloToLinalgPass()"; -} - - def LhloFuseLinalgPass : FunctionPass<"lhlo-fuse-linalg"> { let summary = "Greedily fuse linalg ops obtained after LHLO lowering."; let constructor = "createLhloFuseLinalgPass()"; @@ -33,7 +27,6 @@ def LhloFuseLinalgPass : FunctionPass<"lhlo-fuse-linalg"> { ]; } - def LhloLegalizeToAffinePass : FunctionPass<"lhlo-legalize-to-affine"> { let summary = "Legalize from LHLO dialect to affine dialect."; let constructor = "createLhloLegalizeToAffinePass()"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lhlo_to_hlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lhlo_to_hlo_op.h new file mode 100644 index 00000000000000..248fc18bda9f0b --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lhlo_to_hlo_op.h @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H_ + +#include + +#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace lmhlo { + +template +struct LhloToHloOpImpl { + using Type = std::false_type; +}; +template +using LhloToHloOp = typename LhloToHloOpImpl::Type; + +#define MAP_LHLO_TO_HLO(OpName) \ + template <> \ + struct LhloToHloOpImpl { \ + using Type = mhlo::OpName; \ + } + +MAP_LHLO_TO_HLO(AbsOp); +MAP_LHLO_TO_HLO(AddOp); +MAP_LHLO_TO_HLO(AndOp); +MAP_LHLO_TO_HLO(Atan2Op); +MAP_LHLO_TO_HLO(BitcastConvertOp); +MAP_LHLO_TO_HLO(BroadcastInDimOp); +MAP_LHLO_TO_HLO(CeilOp); +MAP_LHLO_TO_HLO(ClampOp); +MAP_LHLO_TO_HLO(ConstOp); +MAP_LHLO_TO_HLO(CompareOp); +MAP_LHLO_TO_HLO(ComplexOp); +MAP_LHLO_TO_HLO(ConcatenateOp); +MAP_LHLO_TO_HLO(ConvOp); +MAP_LHLO_TO_HLO(ConvertOp); +MAP_LHLO_TO_HLO(CopyOp); +MAP_LHLO_TO_HLO(CosOp); +MAP_LHLO_TO_HLO(CustomCallOp); +MAP_LHLO_TO_HLO(DivOp); +MAP_LHLO_TO_HLO(DotOp); +MAP_LHLO_TO_HLO(DynamicBroadcastInDimOp); +MAP_LHLO_TO_HLO(DynamicGatherOp); +MAP_LHLO_TO_HLO(DynamicIotaOp); +MAP_LHLO_TO_HLO(DynamicPadOp); +MAP_LHLO_TO_HLO(DynamicReshapeOp); +MAP_LHLO_TO_HLO(ExpOp); +MAP_LHLO_TO_HLO(Expm1Op); +MAP_LHLO_TO_HLO(FloorOp); +MAP_LHLO_TO_HLO(GatherOp); +MAP_LHLO_TO_HLO(ImagOp); +MAP_LHLO_TO_HLO(IotaOp); +MAP_LHLO_TO_HLO(IsFiniteOp); +MAP_LHLO_TO_HLO(LogOp); +MAP_LHLO_TO_HLO(LogisticOp); +MAP_LHLO_TO_HLO(Log1pOp); +MAP_LHLO_TO_HLO(MaxOp); +MAP_LHLO_TO_HLO(MinOp); +MAP_LHLO_TO_HLO(MulOp); +MAP_LHLO_TO_HLO(NegOp); +MAP_LHLO_TO_HLO(NotOp); +MAP_LHLO_TO_HLO(OrOp); +MAP_LHLO_TO_HLO(PowOp); +MAP_LHLO_TO_HLO(RealDynamicSliceOp); +MAP_LHLO_TO_HLO(RealOp); +MAP_LHLO_TO_HLO(ReduceOp); +MAP_LHLO_TO_HLO(ReshapeOp); +MAP_LHLO_TO_HLO(RemOp); 
+MAP_LHLO_TO_HLO(RsqrtOp); +MAP_LHLO_TO_HLO(SelectOp); +MAP_LHLO_TO_HLO(ShiftLeftOp); +MAP_LHLO_TO_HLO(ShiftRightArithmeticOp); +MAP_LHLO_TO_HLO(ShiftRightLogicalOp); +MAP_LHLO_TO_HLO(SignOp); +MAP_LHLO_TO_HLO(SinOp); +MAP_LHLO_TO_HLO(SliceOp); +MAP_LHLO_TO_HLO(SqrtOp); +MAP_LHLO_TO_HLO(SubOp); +MAP_LHLO_TO_HLO(TanhOp); +MAP_LHLO_TO_HLO(TransposeOp); +MAP_LHLO_TO_HLO(XorOp); + +#undef MAP_LHLO_TO_HLO + +} // namespace lmhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 967036d7747849..857cc03ad2b9d5 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -16,843 +16,21 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/ADT/iterator_range.h" -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/TypeUtilities.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_lhlo_to_hlo_op.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" namespace mlir { namespace lmhlo { -namespace impl { -// A struct to map LhloBinaryOpTy type to the corresponding floating-point and -// integer scalar operation types. 
-template -struct LhloToScalarOp { - using FOp = void; - using IOp = void; - using UOp = void; - using COp = void; -}; - -template <> -struct LhloToScalarOp { - using FOp = ::mlir::arith::AddFOp; - using IOp = ::mlir::arith::AddIOp; - using UOp = ::mlir::arith::AddIOp; - using COp = ::mlir::complex::AddOp; -}; -template <> -struct LhloToScalarOp { - using IOp = ::mlir::arith::AndIOp; - using UOp = ::mlir::arith::AndIOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::arith::CmpFOp; - using IOp = ::mlir::arith::CmpIOp; - using UOp = ::mlir::arith::CmpIOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::CeilOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::CosOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::arith::DivFOp; - using IOp = ::mlir::arith::DivSIOp; - using UOp = ::mlir::arith::DivUIOp; - using COp = ::mlir::complex::DivOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::ExpOp; - using COp = ::mlir::complex::ExpOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::ExpM1Op; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::FloorOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::arith::MaxFOp; - using IOp = ::mlir::arith::MaxSIOp; - using UOp = ::mlir::arith::MaxUIOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::arith::MinFOp; - using IOp = ::mlir::arith::MinSIOp; - using UOp = ::mlir::arith::MinUIOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::LogOp; - using COp = ::mlir::complex::LogOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::Log1pOp; - using COp = ::mlir::complex::Log1pOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::arith::MulFOp; - using IOp = ::mlir::arith::MulIOp; - using UOp = ::mlir::arith::MulIOp; - using COp = ::mlir::complex::MulOp; -}; -template <> -struct LhloToScalarOp { - using IOp = 
::mlir::arith::OrIOp; - using UOp = ::mlir::arith::OrIOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::arith::RemFOp; - using IOp = ::mlir::arith::RemSIOp; - using UOp = ::mlir::arith::RemUIOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::RsqrtOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::arith::SubFOp; - using IOp = ::mlir::arith::SubIOp; - using UOp = ::mlir::arith::SubIOp; - using COp = ::mlir::complex::SubOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::SqrtOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::SinOp; -}; -template <> -struct LhloToScalarOp { - using IOp = ::mlir::arith::ShLIOp; - using UOp = ::mlir::arith::ShLIOp; -}; -template <> -struct LhloToScalarOp { - using IOp = ::mlir::arith::ShRSIOp; - using UOp = ::mlir::arith::ShRSIOp; -}; -template <> -struct LhloToScalarOp { - using IOp = ::mlir::arith::ShRUIOp; - using UOp = ::mlir::arith::ShRUIOp; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::Atan2Op; -}; -template <> -struct LhloToScalarOp { - using FOp = ::mlir::math::TanhOp; -}; -template <> -struct LhloToScalarOp { - using IOp = ::mlir::arith::XOrIOp; - using UOp = ::mlir::arith::XOrIOp; -}; - -// Alias for the map from LHLO binary op type to STD floating-point op type. -template -using ScalarFOp = typename LhloToScalarOp::FOp; -// Alias for the map from LHLO binary op type to STD signed integer op type. -template -using ScalarIOp = typename LhloToScalarOp::IOp; -// Alias for the map from LHLO binary op type to STD unsigned integer op type. -template -using ScalarUOp = typename LhloToScalarOp::UOp; -// Alias for the map from LHLO binary op type to STD complex op type. 
-template -using ScalarCOp = typename LhloToScalarOp::COp; - -template -struct MapLhloOpToScalarOpImpl { - Value operator()(Location loc, ArrayRef result_types, - ArrayRef arg_types, ValueRange args, OpBuilder* b) { - return nullptr; - } -}; - -template -struct MapLhloOpToScalarOpImpl { - Value operator()(Location loc, ArrayRef result_types, - ArrayRef arg_types, ValueRange args, OpBuilder* b) { - return b->template create(loc, result_types, args, mlir::None); - } -}; - -template -struct MapLhloOpToScalarOpImpl { - Value operator()(Location loc, ArrayRef result_types, - ArrayRef arg_types, ValueRange args, OpBuilder* b) { - Type element_type = getElementTypeOrSelf(arg_types.front()); - if (SupportedType{}(element_type)) { - return b->template create(loc, result_types, args, - mlir::None); - } - return MapLhloOpToScalarOpImpl{}(loc, result_types, arg_types, - args, b); - } -}; - -template -struct MapLhloOpToScalarOpImpl { - Value operator()(Location loc, ArrayRef result_types, - ArrayRef arg_types, ValueRange args, OpBuilder* b) { - return MapLhloOpToScalarOpImpl{}(loc, result_types, arg_types, - args, b); - } -}; - -struct isAnyIntegerType { - bool operator()(Type t) { return t.isa(); } -}; - -struct isSignedIntegerType { - bool operator()(Type t) { - // Pretend that signless is signed. This will change eventually. - return t.isa() && !t.isUnsignedInteger(); - } -}; - -struct isUnsignedIntegerType { - bool operator()(Type t) { return t.isUnsignedInteger(); } -}; - -struct isFloatType { - bool operator()(Type t) { return t.isa(); } -}; - -struct isComplexType { - bool operator()(Type t) { return t.isa(); } -}; - -template