diff --git a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp index 7edfd6846..60ec8aecb 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp @@ -227,7 +227,10 @@ jsi::Value dataImpl( // BigIntArray if (type == torch_::kInt64) { throw jsi::JSError( - runtime, "the property 'data' of BigInt Tensor is not supported."); + runtime, + "the property 'data' for a tensor of dtype torch.int64 is not" + " supported. Work around this with .to({dtype: torch.int32})" + " This might alter the tensor values."); } std::string typedArrayName; @@ -296,6 +299,16 @@ jsi::Value itemImpl( size_t count) { auto thiz = thisValue.asObject(runtime).asHostObject(runtime); + + // TODO(T113480543): enable BigInt once Hermes supports it + if (thiz->tensor.dtype() == torch_::kInt64) { + throw jsi::JSError( + runtime, + "the property 'item' for a tensor of dtype torch.int64 is not" + " supported. Work around this with .to({dtype: torch.int32})" + " This might alter the tensor values."); + } + auto scalar = thiz->tensor.item(); if (scalar.isIntegral(/*includeBool=*/false)) { return jsi::Value(scalar.toInt()); diff --git a/react-native-pytorch-core/cxx/test/TensorTests.cpp b/react-native-pytorch-core/cxx/test/TensorTests.cpp index b9215c1d4..ee3b338f4 100644 --- a/react-native-pytorch-core/cxx/test/TensorTests.cpp +++ b/react-native-pytorch-core/cxx/test/TensorTests.cpp @@ -245,7 +245,20 @@ TEST_F(TorchliveTensorRuntimeTest, TensorDataTest) { const tensor = torch.tensor([128, 255], {dtype: torch.long}); tensor.data(); )"; - EXPECT_THROW(eval(tensorWithDtypeAsInt64), facebook::jsi::JSError); + EXPECT_THROW( + { + try { + eval(tensorWithDtypeAsInt64); + } catch (const facebook::jsi::JSError& e) { + EXPECT_TRUE( + std::string(e.what()).find( + "property 'data' for a tensor of dtype torch.int64 is not supported.") != + std::string::npos) + << e.what(); + throw; + } + }, + facebook::jsi::JSError); } TEST_F(TorchliveTensorRuntimeTest, TensorIndexing) { @@ -746,6 +759,25 @@ TEST_F(TorchliveTensorRuntimeTest, TensorItemTest) { tensor.item(); )"; EXPECT_THROW(eval(tensorItemForMultiElementTensor), facebook::jsi::JSError); + + std::string tensorItemInt64 = R"( + const tensor = torch.tensor(1, {dtype: torch.int64}); + tensor.item(); + )"; + EXPECT_THROW( + { + try { + eval(tensorItemInt64); + } catch (const facebook::jsi::JSError& e) { + EXPECT_TRUE( + std::string(e.what()).find( + "property 'item' for a tensor of dtype torch.int64 is not supported.") != + std::string::npos) + << e.what(); + throw; + } + }, + facebook::jsi::JSError); } TEST_F(TorchliveTensorRuntimeTest, TensorSqrtTest) {