Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Holistic Landmarker Python Tasks API #5028

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions mediapipe/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarker_graph",
"//mediapipe/tasks/cc/vision/holistic_landmarker:holistic_landmarker_graph",
],
)

Expand Down
1 change: 1 addition & 0 deletions mediapipe/tasks/python/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@ py_library(
"//mediapipe/calculators/core:flow_limiter_calculator_py_pb2",
"//mediapipe/framework:calculator_options_py_pb2",
"//mediapipe/framework:calculator_py_pb2",
"@com_google_protobuf//:protobuf_python"
],
)
36 changes: 23 additions & 13 deletions mediapipe/tasks/python/core/task_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mediapipe.framework import calculator_options_pb2
from mediapipe.framework import calculator_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from google.protobuf import any_pb2


@doc_controls.do_not_generate_docs
Expand Down Expand Up @@ -80,18 +81,31 @@ def add_stream_name_prefix(tag_index_name):
raise ValueError(
'`task_options` doesn`t provide `to_pb2()` method to convert itself to be a protobuf object.'
)
task_subgraph_options = calculator_options_pb2.CalculatorOptions()

task_options_proto = self.task_options.to_pb2()
task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom(
task_options_proto)

node_config = calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph,
input_stream=self.input_streams,
output_stream=self.output_streams
)

if hasattr(task_options_proto, 'ext'):
# Use the extension mechanism for task_subgraph_options (proto2)
task_subgraph_options = calculator_options_pb2.CalculatorOptions()
task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom(
task_options_proto)
node_config.options.CopyFrom(task_subgraph_options)
else:
# Use the Any type for task_subgraph_options (proto3)
task_subgraph_options = any_pb2.Any()
task_subgraph_options.Pack(self.task_options.to_pb2())
node_config.node_options.append(task_subgraph_options)

if not enable_flow_limiting:
return calculator_pb2.CalculatorGraphConfig(
node=[
calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph,
input_stream=self.input_streams,
output_stream=self.output_streams,
options=task_subgraph_options)
node_config
],
input_stream=self.input_streams,
output_stream=self.output_streams)
Expand Down Expand Up @@ -121,11 +135,7 @@ def add_stream_name_prefix(tag_index_name):
options=flow_limiter_options)
config = calculator_pb2.CalculatorGraphConfig(
node=[
calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph,
input_stream=task_subgraph_inputs,
output_stream=self.output_streams,
options=task_subgraph_options), flow_limiter
node_config, flow_limiter
],
input_stream=self.input_streams,
output_stream=self.output_streams)
Expand Down
21 changes: 21 additions & 0 deletions mediapipe/tasks/python/test/vision/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,27 @@ py_test(
],
)

py_test(
name = "holistic_landmarker_test",
srcs = ["holistic_landmarker_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
"//mediapipe/tasks/testdata/vision:test_protos",
],
tags = ["not_run:arm"],
deps = [
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2",
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:holistic_landmarker",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"@com_google_protobuf//:protobuf_python",
],
)

py_test(
name = "face_aligner_test",
srcs = ["face_aligner_test.py"],
Expand Down
Loading