Skip to content

Commit

Permalink
Allow tile op to work on variant dtype.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 312382133
Change-Id: I3a0f95865ca0f782fa73f7ba55b3d987de006332
  • Loading branch information
tensorflower-gardener committed May 19, 2020
1 parent db57348 commit f8a918c
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/BUILD
Expand Up @@ -1339,6 +1339,7 @@ tf_kernel_library(
"tile_functor_cpu_int8.cc",
"tile_functor_cpu_tstring.cc",
"tile_functor_cpu_uint8.cc",
"tile_functor_cpu_variant.cc",
"tile_functor_sycl.cc",
],
hdrs = ["tile_functor.h"],
Expand Down Expand Up @@ -6907,6 +6908,7 @@ filegroup(
"tile_functor_cpu_int8.cc",
"tile_functor_cpu_tstring.cc",
"tile_functor_cpu_uint8.cc",
"tile_functor_cpu_variant.cc",
"tile_ops.cc",
"tile_ops_cpu_impl_1.cc",
"tile_ops_cpu_impl_2.cc",
Expand Down
30 changes: 30 additions & 0 deletions tensorflow/core/kernels/tile_functor_cpu_variant.cc
@@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/tile_functor_cpu.h"

namespace tensorflow {
namespace functor {

// Eigen device type used for multi-threaded CPU execution.
using CPUDevice = Eigen::ThreadPoolDevice;

// Explicit instantiations of the Tile functor for variant-dtype tensors on
// CPU, covering both index types used for the tile multiples. These live in
// their own translation unit (like the other tile_functor_cpu_*.cc files) to
// keep per-file compile times down.
template struct Tile<CPUDevice, Variant, int64>;
template struct Tile<CPUDevice, Variant, int32>;

}  // namespace functor
}  // namespace tensorflow
3 changes: 3 additions & 0 deletions tensorflow/core/kernels/tile_ops.cc
Expand Up @@ -143,6 +143,7 @@ TF_CALL_half(DECLARE_TYPE);
TF_CALL_complex64(DECLARE_TYPE);
TF_CALL_complex128(DECLARE_TYPE);
TF_CALL_tstring(DECLARE_TYPE);
TF_CALL_variant(DECLARE_TYPE);
#undef DECLARE_TYPE

#define DECLARE_DIM(T, NDIM) \
Expand Down Expand Up @@ -244,6 +245,7 @@ class TileOp : public OpKernel {
TF_CALL_tstring(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice.
TF_CALL_complex64(HANDLE_TYPE_NAME);
TF_CALL_complex128(HANDLE_TYPE_NAME);
TF_CALL_variant(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice

#undef HANDLE_TYPE_NAME
#undef HANDLE_TYPE
Expand Down Expand Up @@ -323,6 +325,7 @@ TF_CALL_half(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
TF_CALL_tstring(HANDLE_TYPE_NAME_CPU);
TF_CALL_variant(HANDLE_TYPE_NAME_CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_bool(HANDLE_TYPE_NAME_GPU);
Expand Down
28 changes: 28 additions & 0 deletions tensorflow/python/kernel_tests/array_ops_test.py
Expand Up @@ -42,6 +42,7 @@
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
Expand Down Expand Up @@ -1994,5 +1995,32 @@ def repeat_fn(array, repeats):
self.assertAllEqual(v_tf_fn, v_np)


@test_util.run_all_in_graph_and_eager_modes
class TileVariantTest(test_util.TensorFlowTestCase):
  """Tests tf.tile on DT_VARIANT tensors (here, TensorList handles)."""

  def test_tile_tensor_list(self):
    # Build a TensorList from the rows of a random (2, 3, 4) tensor.
    source = constant_op.constant(np.random.uniform(size=[2, 3, 4]))
    handle = list_ops.tensor_list_from_tensor(source, element_shape=None)
    # Tile the scalar variant handle into a vector of two handles on CPU.
    with ops.device("CPU:0"):
      tiled_handles = array_ops.tile(array_ops.reshape(handle, [1]), [2])
    stacked_first = list_ops.tensor_list_stack(tiled_handles[0], source.dtype,
                                               2, [3, 4])
    stacked_second = list_ops.tensor_list_stack(tiled_handles[1], source.dtype,
                                                2, [3, 4])
    # Each tiled handle should reproduce the original tensor.
    self.assertAllEqual(source, stacked_first)
    self.assertAllEqual(source, stacked_second)
    # Now mutate some of the lists and make sure the changes are not reflected
    # in the tiled handles.
    mutations = [
        list_ops.tensor_list_scatter([source[0] + 1], [0],
                                     input_handle=handle),
        list_ops.tensor_list_set_item(tiled_handles[0], 0, source[0] + 2),
    ]
    with ops.control_dependencies(mutations):
      stacked_first = list_ops.tensor_list_stack(tiled_handles[0],
                                                 source.dtype, 2, [3, 4])
      stacked_second = list_ops.tensor_list_stack(tiled_handles[1],
                                                  source.dtype, 2, [3, 4])
    self.assertAllEqual(source, stacked_first)
    self.assertAllEqual(source, stacked_second)


if __name__ == "__main__":
  # Run all tests in this module when executed directly.
  test_lib.main()

0 comments on commit f8a918c

Please sign in to comment.