Skip to content

Commit

Permalink
Generate error (instead of segfault) when trying to copy string tensor
Browse files Browse the repository at this point in the history
to GPU in EagerTensor constructor.

PiperOrigin-RevId: 168457320
  • Loading branch information
tensorflower-gardener committed Sep 12, 2017
1 parent 655f26f commit 00c8655
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
9 changes: 9 additions & 0 deletions tensorflow/c/eager/c_api.cc
Expand Up @@ -222,6 +222,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
return nullptr;
}
tensorflow::Tensor* src = &(h->t);
if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Can't copy Tensor with type ",
tensorflow::DataTypeString(src->dtype()),
" to device ", DeviceName(dstd), ".")
.c_str());
return nullptr;
}
if (src_cpu) {
tensorflow::Tensor dst(
dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
Expand Down
8 changes: 5 additions & 3 deletions tensorflow/python/eager/BUILD
Expand Up @@ -92,13 +92,15 @@ py_library(
srcs_version = "PY2AND3",
)

py_test(
cuda_py_test(
name = "tensor_test",
srcs = ["tensor_test.py"],
srcs_version = "PY2AND3",
deps = [
additional_deps = [
":context",
":tensor",
":test",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//third_party/py/numpy",
],
)
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/eager/tensor_test.py
Expand Up @@ -20,9 +20,12 @@

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.eager import tensor
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util


Expand Down Expand Up @@ -136,6 +139,15 @@ def testStringTensor(self):
t_np = t.numpy()
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))

def testStringTensorOnGPU(self):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with ops.device("/device:GPU:0"):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Can't copy Tensor with type string to device"):
tensor.Tensor("test string")


if __name__ == "__main__":
test.main()

0 comments on commit 00c8655

Please sign in to comment.