diff --git a/onnx/test/cpp/shape_inference_test.cc b/onnx/test/cpp/shape_inference_test.cc index 614fa10bec8..9791958cc5c 100644 --- a/onnx/test/cpp/shape_inference_test.cc +++ b/onnx/test/cpp/shape_inference_test.cc @@ -507,8 +507,7 @@ TEST(GraphInferencerImplTest, Scan9_BasicTest) { doInferencingTest(false); } -void RunReshapeShapeInfTest(const char* modelStr, TensorShapeProto& expectedShape) { - ModelProto model; +void ParseAndInfer(ModelProto& model, const char* modelStr) { OnnxParser parser(modelStr); auto status = parser.Parse(model); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); @@ -516,6 +515,11 @@ void RunReshapeShapeInfTest(const char* modelStr, TensorShapeProto& expectedShap ShapeInferenceOptions options{true, 1, true}; ONNX_NAMESPACE::shape_inference::InferShapes(model, ONNX_NAMESPACE::OpSchemaRegistry::Instance(), options); +} + +void RunReshapeShapeInfTest(const char* modelStr, TensorShapeProto& expectedShape) { + ModelProto model; + ParseAndInfer(model, modelStr); const auto inferredShape = model.graph().output(0).type().tensor_type().shape(); EXPECT_TRUE(inferredShape.dim_size() == expectedShape.dim_size()); @@ -620,5 +624,37 @@ TEST(ShapeInferenceTest, CheckShapesAndTypesTest) { #endif } +TEST(ShapeInferenceTest, CustomOpTest) { + const char* modelStr = R"ONNX( + +agraph (float[256, 768, 3] x) => (z1, z2) +{ + z1 = custom.domain.CustomOp (x) + # Inference cannot determine the type/shape of z1 + z2 = Abs(x) + # Inference SHOULD determine the type/shape of z2 (same as that of x) +} +)ONNX"; + + ModelProto model; + ParseAndInfer(model, modelStr); + + auto& z1_value_info = model.graph().output(0); + // Check no inferred type for z1 (It's a quirk of the implementation that it + // has a dummy TypeProto, but it should have no values filled in.) + ASSERT_TRUE(z1_value_info.has_type()); + ASSERT_FALSE(z1_value_info.type().has_tensor_type()); + + // Check inferred type for z2: + auto& z2_value_info = model.graph().output(1); + ASSERT_TRUE(z2_value_info.has_type()); + ASSERT_TRUE(z2_value_info.type().has_tensor_type()); + EXPECT_EQ(z2_value_info.type().tensor_type().elem_type(), TensorProto_DataType_FLOAT); + EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim_size(), 3); + EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(0).dim_value(), 256); + EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(1).dim_value(), 768); + EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(2).dim_value(), 3); +} + } // namespace Test } // namespace ONNX_NAMESPACE