Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Add tensor.matmul to PyTorch SDK for JSI
Browse files Browse the repository at this point in the history
Summary:
added `matmulImpl` function to `TensorHostObject.cpp` and set the propertyHostFunction so that `matmul` references the new implementation.

added tests in `TensorTests.cpp` to check that matrix multiplications in various dimensions result in the correct shape and errors are thrown when there is a shape mismatch or a non tensor is passed as the argument

added proper function signature and documentation information in `torch.ts`

Reviewed By: raedle, liuyinglao

Differential Revision: D40067104

fbshipit-source-id: 33036ee71007b752d85f4788e32721541127c65a
  • Loading branch information
zrfisher authored and raedle committed Nov 20, 2022
1 parent edf43de commit 7f633ca
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,19 @@ jsi::Value itemImpl(
}
}

jsi::Value matmulImpl(
jsi::Runtime& runtime,
const jsi::Value& thisValue,
const jsi::Value* arguments,
size_t count) {
auto args = utils::ArgumentParser(runtime, thisValue, arguments, count);
args.requireNumArguments(1);
auto thisTensor = args.thisAsHostObject<TensorHostObject>()->tensor;
const auto otherTensor = args.asHostObject<TensorHostObject>(0)->tensor;
return utils::helpers::createFromHostObject<TensorHostObject>(
runtime, torch_::matmul(thisTensor, otherTensor));
}

jsi::Value mulImpl(
jsi::Runtime& runtime,
const jsi::Value& thisValue,
Expand Down Expand Up @@ -656,6 +669,7 @@ TensorHostObject::TensorHostObject(jsi::Runtime& runtime, torch_::Tensor t)
setPropertyHostFunction(runtime, "div", 1, divImpl);
setPropertyHostFunction(runtime, "flip", 1, flipImpl);
setPropertyHostFunction(runtime, "item", 0, itemImpl);
setPropertyHostFunction(runtime, "matmul", 1, matmulImpl);
setPropertyHostFunction(runtime, "mul", 1, mulImpl);
setPropertyHostFunction(runtime, "permute", 1, permuteImpl);
setPropertyHostFunction(runtime, "reshape", 1, reshapeImpl);
Expand Down
60 changes: 60 additions & 0 deletions react-native-pytorch-core/cxx/test/TensorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,66 @@ TEST_F(TorchliveTensorRuntimeTest, TensorMulTest) {
EXPECT_THROW(eval("torch.arrange(3, 4).mul('foo')"), facebook::jsi::JSError);
}

TEST_F(TorchliveTensorRuntimeTest, TensorMatmulTest) {
std::string matMul2Dx1D =
R"(
const x = torch.randn([3, 4]);
const y = torch.randn([4]);
const z = x.matmul(y);
z.size()[0] === 3;
)";
EXPECT_TRUE(eval(matMul2Dx1D).getBool());

std::string matMul2Dx2D =
R"(
const x = torch.randn([3, 4]);
const y = torch.randn([4, 2]);
const z = x.matmul(y);
z.size()[0] === 3 && z.size()[1] === 2;
)";
EXPECT_TRUE(eval(matMul2Dx2D).getBool());

std::string matMul3Dx1D =
R"(
const x = torch.randn([10, 3, 4]);
const y = torch.randn([4]);
const z = x.matmul(y);
z.size()[0] === 10 && z.size()[1] === 3;
)";
EXPECT_TRUE(eval(matMul3Dx1D).getBool());

std::string matMul3Dx2D =
R"(
const x = torch.randn([10, 3, 4]);
const y = torch.randn([4, 5]);
const z = x.matmul(y);
z.size()[0] === 10 && z.size()[1] === 3 && z.size()[2] === 5;
)";
EXPECT_TRUE(eval(matMul3Dx2D).getBool());

std::string matMulShapeMismatch =
R"(
const x = torch.randn([10, 3]);
const y = torch.randn([10, 3]);
const z = x.matmul(y);
)";
EXPECT_THROW(eval(matMulShapeMismatch), facebook::jsi::JSError);

std::string matMulWithNumber =
R"(
const x = torch.randn([10, 3]);
const z = x.matmul(3);
)";
EXPECT_THROW(eval(matMulWithNumber), facebook::jsi::JSError);

std::string matMulWithString =
R"(
const x = torch.randn([10, 3]);
const z = x.matmul('foo');
)";
EXPECT_THROW(eval(matMulWithString), facebook::jsi::JSError);
}

TEST_F(TorchliveTensorRuntimeTest, TensorPermuteTest) {
std::string tensorPermute =
R"(
Expand Down
8 changes: 8 additions & 0 deletions react-native-pytorch-core/src/torchlive/torch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ export interface Tensor {
* @param shape The new shape.
*/
reshape(shape: number[]): Tensor;
/**
* Performs matrix multiplication with other tensor.
*
* {@link https://pytorch.org/docs/1.12/generated/torch.Tensor.matmul.html}
*
* @param other tensor matrix multiplied this tensor.
*/
matmul(other: Tensor): Tensor;
/**
* Multiplies input by other scalar or tensor.
*
Expand Down

0 comments on commit 7f633ca

Please sign in to comment.