diff --git a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp index b3ed9482f..54201270e 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp @@ -386,6 +386,19 @@ jsi::Value itemImpl( } } +jsi::Value matmulImpl( + jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) { + auto args = utils::ArgumentParser(runtime, thisValue, arguments, count); + args.requireNumArguments(1); + auto thisTensor = args.thisAsHostObject()->tensor; + const auto otherTensor = args.asHostObject(0)->tensor; + return utils::helpers::createFromHostObject( + runtime, torch_::matmul(thisTensor, otherTensor)); +} + jsi::Value mulImpl( jsi::Runtime& runtime, const jsi::Value& thisValue, @@ -656,6 +669,7 @@ TensorHostObject::TensorHostObject(jsi::Runtime& runtime, torch_::Tensor t) setPropertyHostFunction(runtime, "div", 1, divImpl); setPropertyHostFunction(runtime, "flip", 1, flipImpl); setPropertyHostFunction(runtime, "item", 0, itemImpl); + setPropertyHostFunction(runtime, "matmul", 1, matmulImpl); setPropertyHostFunction(runtime, "mul", 1, mulImpl); setPropertyHostFunction(runtime, "permute", 1, permuteImpl); setPropertyHostFunction(runtime, "reshape", 1, reshapeImpl); diff --git a/react-native-pytorch-core/cxx/test/TensorTests.cpp b/react-native-pytorch-core/cxx/test/TensorTests.cpp index 2ce5d9fd1..af6031f81 100644 --- a/react-native-pytorch-core/cxx/test/TensorTests.cpp +++ b/react-native-pytorch-core/cxx/test/TensorTests.cpp @@ -512,6 +512,66 @@ TEST_F(TorchliveTensorRuntimeTest, TensorMulTest) { EXPECT_THROW(eval("torch.arrange(3, 4).mul('foo')"), facebook::jsi::JSError); } +TEST_F(TorchliveTensorRuntimeTest, TensorMatmulTest) { + std::string matMul2Dx1D = + R"( + const x = torch.randn([3, 4]); + const y = torch.randn([4]); + const z = x.matmul(y); + z.size()[0] === 3; + )"; + EXPECT_TRUE(eval(matMul2Dx1D).getBool()); + + std::string matMul2Dx2D = + R"( + const x = torch.randn([3, 4]); + const y = torch.randn([4, 2]); + const z = x.matmul(y); + z.size()[0] === 3 && z.size()[1] === 2; + )"; + EXPECT_TRUE(eval(matMul2Dx2D).getBool()); + + std::string matMul3Dx1D = + R"( + const x = torch.randn([10, 3, 4]); + const y = torch.randn([4]); + const z = x.matmul(y); + z.size()[0] === 10 && z.size()[1] === 3; + )"; + EXPECT_TRUE(eval(matMul3Dx1D).getBool()); + + std::string matMul3Dx2D = + R"( + const x = torch.randn([10, 3, 4]); + const y = torch.randn([4, 5]); + const z = x.matmul(y); + z.size()[0] === 10 && z.size()[1] === 3 && z.size()[2] === 5; + )"; + EXPECT_TRUE(eval(matMul3Dx2D).getBool()); + + std::string matMulShapeMismatch = + R"( + const x = torch.randn([10, 3]); + const y = torch.randn([10, 3]); + const z = x.matmul(y); + )"; + EXPECT_THROW(eval(matMulShapeMismatch), facebook::jsi::JSError); + + std::string matMulWithNumber = + R"( + const x = torch.randn([10, 3]); + const z = x.matmul(3); + )"; + EXPECT_THROW(eval(matMulWithNumber), facebook::jsi::JSError); + + std::string matMulWithString = + R"( + const x = torch.randn([10, 3]); + const z = x.matmul('foo'); + )"; + EXPECT_THROW(eval(matMulWithString), facebook::jsi::JSError); +} + TEST_F(TorchliveTensorRuntimeTest, TensorPermuteTest) { std::string tensorPermute = R"( diff --git a/react-native-pytorch-core/src/torchlive/torch.ts b/react-native-pytorch-core/src/torchlive/torch.ts index 8c54e62d9..8041944eb 100644 --- a/react-native-pytorch-core/src/torchlive/torch.ts +++ b/react-native-pytorch-core/src/torchlive/torch.ts @@ -278,6 +278,14 @@ export interface Tensor { * @param shape The new shape. */ reshape(shape: number[]): Tensor; + /** + * Performs matrix multiplication with other tensor. + * + * {@link https://pytorch.org/docs/1.12/generated/torch.Tensor.matmul.html} + * + * @param other tensor matrix multiplied this tensor. + */ + matmul(other: Tensor): Tensor; /** * Multiplies input by other scalar or tensor. *