Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SAM export variants w/ input boxes #1681

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,8 @@ class SamOnnxConfig(OnnxConfig):
# Supported ONNX export layouts for SAM. Keys are the values accepted by the
# `--variant` export option; values are the user-facing help descriptions.
VARIANTS = {
    "monolith": "All the SAM model components are exported as a single model.onnx.",
    # Fix: "allows to encoder" -> "allows to encode" in the description below.
    "split": "The vision encoder is exported as a separate vision_encoder.onnx, and the prompt encoder and mask decoder are exported as a prompt_encoder_mask_decoder.onnx. This allows to encode the image only once for multiple point queries.",
    "split-with-boxes": "The same as `split`, but with `input_boxes` instead of `input_points`.",
    "split-with-points-and-boxes": "The same as `split`, but with `input_points` as well as `input_boxes`.",
}
# Variant used when the caller does not request one explicitly.
DEFAULT_VARIANT = "split"

Expand Down Expand Up @@ -1681,6 +1683,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
"pixel_values": {0: "batch_size"},
"input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
"input_labels": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
"input_boxes": {0: "batch_size", 1: "nb_boxes_per_image"},
}
else:
if self.vision_encoder:
Expand All @@ -1689,14 +1692,20 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
inputs = {
"image_positional_embeddings": {0: "batch_size"},
"image_embeddings": {0: "batch_size"},
"input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
"input_labels": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
}

if self.variant == "split" or "points" in self.variant:
inputs["input_points"] = {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"}
inputs["input_labels"] = {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"}

if "boxes" in self.variant:
inputs['input_boxes'] = {0: "batch_size", 1: "nb_boxes_per_image"}

return inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self.variant == "split" and self.vision_encoder:
if self.variant != "monolith" and self.vision_encoder:
return {"image_embeddings": {0: "batch_size"}, "image_positional_embeddings": {0: "batch_size"}}
else:
return {
Expand Down
10 changes: 6 additions & 4 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def patched_forward(
pixel_values=None,
input_points=None,
input_labels=None,
input_boxes=None,
image_embeddings=None,
image_positional_embeddings=None,
return_dict=True,
Expand All @@ -521,11 +522,12 @@ def patched_forward(
pixel_values=pixel_values,
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
image_embeddings=image_embeddings,
return_dict=return_dict,
**kwargs,
)
elif config.variant == "split":
else: # "split":
# return_dict = get_argument(args, kwargs, signature, "return_dict")
if config.vision_encoder:
# pixel_values = get_argument(args, kwargs, signature, "pixel_values")
Expand All @@ -551,13 +553,13 @@ def patched_forward(
"image_positional_embeddings": image_positional_embeddings,
}
else:
if input_points is None:
raise ValueError("input_points is required to export the prompt encoder / mask decoder.")
if input_points is None and input_boxes is None:
raise ValueError("`input_points` or `input_boxes` is required to export the prompt encoder / mask decoder.")

sparse_embeddings, dense_embeddings = model.prompt_encoder(
input_points=input_points,
input_labels=input_labels,
input_boxes=None, # Not supported in the ONNX export
input_boxes=input_boxes,
input_masks=None, # Not supported in the ONNX export
)

Expand Down
8 changes: 7 additions & 1 deletion optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def wrapper(*args, **kwargs):
"height": 64,
"num_channels": 3,
"point_batch_size": 3,
"nb_boxes_per_image": 3,
"nb_points_per_image": 2,
# audio
"feature_size": 80,
Expand Down Expand Up @@ -836,7 +837,7 @@ class DummyPointsGenerator(DummyInputGenerator):
Generates dummy time step inputs.
"""

SUPPORTED_INPUT_NAMES = ("input_points", "input_labels")
SUPPORTED_INPUT_NAMES = ("input_points", "input_labels", "input_boxes")

def __init__(
self,
Expand All @@ -845,18 +846,23 @@ def __init__(
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
point_batch_size: int = DEFAULT_DUMMY_SHAPES["point_batch_size"],
nb_points_per_image: int = DEFAULT_DUMMY_SHAPES["nb_points_per_image"],
nb_boxes_per_image: int = DEFAULT_DUMMY_SHAPES["nb_boxes_per_image"],
**kwargs,
):
self.task = task

self.batch_size = batch_size
self.point_batch_size = point_batch_size
self.nb_points_per_image = nb_points_per_image
self.nb_boxes_per_image = nb_boxes_per_image

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
    """Produce a dummy tensor for one of the SAM prompt inputs.

    `input_points` and `input_boxes` are float coordinate tensors;
    `input_labels` is an integer tensor with values in {0, 1}.
    """
    # Float-valued inputs, keyed by name with their expected shapes:
    # points are (batch, point_batch, nb_points, 2) x/y pairs, boxes are
    # (batch, nb_boxes, 4) corner coordinates.
    coordinate_shapes = {
        "input_points": [self.batch_size, self.point_batch_size, self.nb_points_per_image, 2],
        "input_boxes": [self.batch_size, self.nb_boxes_per_image, 4],
    }
    shape = coordinate_shapes.get(input_name)
    if shape is not None:
        return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
    # Anything else is treated as input_labels: one binary label per point.
    label_shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image]
    return self.random_int_tensor(label_shape, min_value=0, max_value=1, framework=framework, dtype=int_dtype)
Expand Down
Loading