Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
error when BigInt required for Tensor.item, Tensor.data
Browse files Browse the repository at this point in the history
Summary: Now that int64 / long are supported by the Dtype type, give users a helpful message when attempting to use them in an unsupported way

Reviewed By: raedle

Differential Revision: D37387583

fbshipit-source-id: f2e51a96a3724dcb3e605aa3c9439ee85e3944e7
  • Loading branch information
chrisklaiber authored and facebook-github-bot committed Jun 23, 2022
1 parent a7ee796 commit 6f499ca
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,10 @@ jsi::Value dataImpl(
// BigIntArray
if (type == torch_::kInt64) {
throw jsi::JSError(
runtime, "the property 'data' of BigInt Tensor is not supported.");
runtime,
"the property 'data' for a tensor of dtype torch.int64 is not"
" supported. Work around this with .to({dtype: torch.int32})"
" This might alter the tensor values.");
}

std::string typedArrayName;
Expand Down Expand Up @@ -296,6 +299,16 @@ jsi::Value itemImpl(
size_t count) {
auto thiz =
thisValue.asObject(runtime).asHostObject<TensorHostObject>(runtime);

// TODO(T113480543): enable BigInt once Hermes supports it
if (thiz->tensor.dtype() == torch_::kInt64) {
throw jsi::JSError(
runtime,
"the property 'item' for a tensor of dtype torch.int64 is not"
" supported. Work around this with .to({dtype: torch.int32})"
" This might alter the tensor values.");
}

auto scalar = thiz->tensor.item();
if (scalar.isIntegral(/*includeBool=*/false)) {
return jsi::Value(scalar.toInt());
Expand Down
34 changes: 33 additions & 1 deletion react-native-pytorch-core/cxx/test/TensorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,20 @@ TEST_F(TorchliveTensorRuntimeTest, TensorDataTest) {
const tensor = torch.tensor([128, 255], {dtype: torch.long});
tensor.data();
)";
EXPECT_THROW(eval(tensorWithDtypeAsInt64), facebook::jsi::JSError);
EXPECT_THROW(
{
try {
eval(tensorWithDtypeAsInt64);
} catch (const facebook::jsi::JSError& e) {
EXPECT_TRUE(
std::string(e.what()).find(
"property 'data' for a tensor of dtype torch.int64 is not supported.") !=
std::string::npos)
<< e.what();
throw;
}
},
facebook::jsi::JSError);
}

TEST_F(TorchliveTensorRuntimeTest, TensorIndexing) {
Expand Down Expand Up @@ -746,6 +759,25 @@ TEST_F(TorchliveTensorRuntimeTest, TensorItemTest) {
tensor.item();
)";
EXPECT_THROW(eval(tensorItemForMultiElementTensor), facebook::jsi::JSError);

std::string tensorItemInt64 = R"(
const tensor = torch.tensor(1, {dtype: torch.int64});
tensor.item();
)";
EXPECT_THROW(
{
try {
eval(tensorItemInt64);
} catch (const facebook::jsi::JSError& e) {
EXPECT_TRUE(
std::string(e.what()).find(
"property 'item' for a tensor of dtype torch.int64 is not supported.") !=
std::string::npos)
<< e.what();
throw;
}
},
facebook::jsi::JSError);
}

TEST_F(TorchliveTensorRuntimeTest, TensorSqrtTest) {
Expand Down

0 comments on commit 6f499ca

Please sign in to comment.