From dc82ee9acaa4a8246e2967381b9c0e33c59c8609 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 12 Dec 2025 14:58:19 +0000 Subject: [PATCH 1/7] Fix tied weight keys sam2 video --- src/transformers/models/sam2_video/modeling_sam2_video.py | 5 ----- src/transformers/models/sam2_video/modular_sam2_video.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/src/transformers/models/sam2_video/modeling_sam2_video.py b/src/transformers/models/sam2_video/modeling_sam2_video.py index d107016ccfc2..f9d85c217bef 100644 --- a/src/transformers/models/sam2_video/modeling_sam2_video.py +++ b/src/transformers/models/sam2_video/modeling_sam2_video.py @@ -1559,11 +1559,6 @@ class Sam2VideoModel(Sam2VideoPreTrainedModel): input_modalities = ("video", "text") _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)} _keys_to_ignore_on_load_unexpected = [] - _tied_weights_keys = { - "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" - } - # need to be ignored, as it's a buffer and will not be correctly detected as tied weight - _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] def __init__(self, config: Sam2VideoConfig): super().__init__(config) diff --git a/src/transformers/models/sam2_video/modular_sam2_video.py b/src/transformers/models/sam2_video/modular_sam2_video.py index e876ee5bda5b..ea2e5d7cadb1 100644 --- a/src/transformers/models/sam2_video/modular_sam2_video.py +++ b/src/transformers/models/sam2_video/modular_sam2_video.py @@ -1446,11 +1446,6 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000): @auto_docstring class Sam2VideoModel(Sam2Model): input_modalities = ("video", "text") - _tied_weights_keys = { - "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" - } - # need to be ignored, as it's a buffer and will not be correctly detected as tied weight - _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _keys_to_ignore_on_load_unexpected = [] _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)} From 36f163e7636b4e8160cfa28dee7d171555c321b1 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 30 Jan 2026 03:00:13 +0000 Subject: [PATCH 2/7] add pvs pipeline --- src/transformers/__init__.py | 2 + src/transformers/models/auto/modeling_auto.py | 22 + src/transformers/pipelines/__init__.py | 8 + .../promptable_visual_segmentation.py | 393 ++++++++++++++++++ ...ipelines_promptable_visual_segmentation.py | 283 +++++++++++++ 5 files changed, 708 insertions(+) create mode 100644 src/transformers/pipelines/promptable_visual_segmentation.py create mode 100644 tests/pipelines/test_pipelines_promptable_visual_segmentation.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e5677dd872f9..efee915394c4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -162,6 +162,7 @@ "TextToAudioPipeline", "TokenClassificationPipeline", "VideoClassificationPipeline", + "PromptableVisualSegmentationPipeline", "VisualQuestionAnsweringPipeline", "ZeroShotAudioClassificationPipeline", "ZeroShotClassificationPipeline", @@ -663,6 +664,7 @@ from .pipelines import PipedPipelineDataFormat as PipedPipelineDataFormat from .pipelines import Pipeline as Pipeline from .pipelines import PipelineDataFormat as PipelineDataFormat + from .pipelines import PromptableVisualSegmentationPipeline as 
PromptableVisualSegmentationPipeline from .pipelines import QuestionAnsweringPipeline as QuestionAnsweringPipeline from .pipelines import TableQuestionAnsweringPipeline as TableQuestionAnsweringPipeline from .pipelines import TextClassificationPipeline as TextClassificationPipeline diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index fa3b1c979939..e4eded9455db 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1628,6 +1628,18 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ) +MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Promptable Visual Segmentation mapping + ("sam3_tracker", "Sam3TrackerModel"), + ("sam2", "Sam2Model"), + # facebook/sam2.1-hiera-large checkpoint uses sam2_video config but can be used for single-image inference + ("sam2_video", "Sam2Model"), + ("sam", "SamModel"), + ("edgetam", "EdgeTamModel"), + ] +) + MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict( [ ("superpoint", "SuperPointForKeypointDetection"), @@ -1797,6 +1809,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) +MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING_NAMES +) + MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES ) @@ -1826,6 +1842,10 @@ class AutoModelForMaskGeneration(_BaseAutoModelClass): _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING +class AutoModelForPromptableVisualSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING + + class AutoModelForKeypointDetection(_BaseAutoModelClass): _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING @@ -2168,6 +2188,7 @@ class AutoModelForAudioTokenization(_BaseAutoModelClass): "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "MODEL_FOR_OBJECT_DETECTION_MAPPING", "MODEL_FOR_PRETRAINING_MAPPING", + "MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING", "MODEL_FOR_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", @@ -2214,6 +2235,7 @@ class AutoModelForAudioTokenization(_BaseAutoModelClass): "AutoModelForNextSentencePrediction", "AutoModelForObjectDetection", "AutoModelForPreTraining", + "AutoModelForPromptableVisualSegmentation", "AutoModelForQuestionAnswering", "AutoModelForSemanticSegmentation", "AutoModelForSeq2SeqLM", diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 57c7a806fdf2..59bd41213aaa 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -71,6 +71,7 @@ from .keypoint_matching import KeypointMatchingPipeline from .mask_generation import MaskGenerationPipeline from .object_detection import ObjectDetectionPipeline +from .promptable_visual_segmentation import PromptableVisualSegmentationPipeline from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline from .text_classification import TextClassificationPipeline @@ -107,6 +108,7 @@ AutoModelForMaskGeneration, AutoModelForMultimodalLM, AutoModelForObjectDetection, + 
AutoModelForPromptableVisualSegmentation, AutoModelForQuestionAnswering, AutoModelForSemanticSegmentation, AutoModelForSeq2SeqLM, @@ -298,6 +300,12 @@ "default": {"model": ("magic-leap-community/superglue_outdoor", "f4041f8")}, "type": "image", }, + "promptable-visual-segmentation": { + "impl": PromptableVisualSegmentationPipeline, + "pt": (AutoModelForPromptableVisualSegmentation,) if is_torch_available() else (), + "default": {"model": ("facebook/sam3", "main")}, + "type": "multimodal", + }, "any-to-any": { "impl": AnyToAnyPipeline, "tf": (), diff --git a/src/transformers/pipelines/promptable_visual_segmentation.py b/src/transformers/pipelines/promptable_visual_segmentation.py new file mode 100644 index 000000000000..dac53f39dfb2 --- /dev/null +++ b/src/transformers/pipelines/promptable_visual_segmentation.py @@ -0,0 +1,393 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Union, overload + +from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends +from .base import Pipeline, build_pipeline_init_args + + +if is_vision_available(): + from PIL import Image + + from ..image_utils import load_image, valid_images + +if is_torch_available(): + import torch + + from ..models.auto.modeling_auto import MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING_NAMES + +logger = logging.get_logger(__name__) + + +@add_end_docstrings(build_pipeline_init_args(has_processor=True)) +class PromptableVisualSegmentationPipeline(Pipeline): + """ + Promptable Visual Segmentation pipeline using SAM-family models. This pipeline predicts segmentation masks + for objects when you provide an image and visual prompts. Visual prompts can be points (with positive/negative + labels) or bounding boxes. + + This task is supported by models: Sam3TrackerModel, Sam2Model, SamModel, and EdgeTamModel. + + Example: + + ```python + >>> from transformers import pipeline + + >>> segmenter = pipeline(model="facebook/sam2.1-hiera-large", task="promptable-visual-segmentation") + >>> # Single point prompt + >>> segmenter( + ... "http://images.cocodataset.org/val2017/000000077595.jpg", + ... input_points=[[[[450, 600]]]], + ... input_labels=[[[1]]], + ... ) + [[{'score': 0.87, 'mask': tensor([...])}]] + + >>> # Box prompt + >>> segmenter( + ... "http://images.cocodataset.org/val2017/000000136466.jpg", + ... input_boxes=[[[59, 144, 76, 163]]], + ... ) + [[{'score': 0.92, 'mask': tensor([...])}]] + + >>> # Multiple points for refinement (positive and negative) + >>> segmenter( + ... "http://images.cocodataset.org/val2017/000000136466.jpg", + ... input_points=[[[[450, 600], [500, 620]]]], + ... input_labels=[[[1, 0]]], # 1=positive (include), 0=negative (exclude) + ... 
) + [[{'score': 0.85, 'mask': tensor([...])}]] + ``` + + Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) + + This promptable visual segmentation pipeline can currently be loaded from [`pipeline`] using the following task + identifier: `"promptable-visual-segmentation"`. + + See the list of available models on + [huggingface.co/models](https://huggingface.co/models?filter=promptable-visual-segmentation). + """ + + _load_processor = True + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = False + + def __init__(self, **kwargs): + super().__init__(**kwargs) + requires_backends(self, "vision") + self.check_model_type(MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING_NAMES) + + # Handle processor compatibility: Sam3VideoProcessor → Sam3TrackerProcessor + # facebook/sam3 checkpoint loads Sam3VideoProcessor by default, but this pipeline needs Sam3TrackerProcessor + if self.processor is not None and self.processor.__class__.__name__ == "Sam3VideoProcessor": + from ..models.sam3_tracker import Sam3TrackerProcessor + + # Get checkpoint name from model (empty string if instantiated from config, so use 'or' for fallback) + model_name = getattr(self.model, "name_or_path", "") or "facebook/sam3" + self.processor = Sam3TrackerProcessor.from_pretrained(model_name) + + # Determine if using SamProcessor (needs reshaped_input_sizes in post_process_masks) + self._needs_reshaped_sizes = self.processor.__class__.__name__ == "SamProcessor" + + @overload + def __call__( + self, + image: Union[str, "Image.Image"], + input_points: list[list[list[list[float]]]] | None = None, + input_labels: list[list[list[int]]] | None = None, + input_boxes: list[list[list[float]]] | None = None, + **kwargs: Any, + ) -> list[list[dict[str, Any]]]: ... + + @overload + def __call__(self, image: list[dict[str, Any]], **kwargs: Any) -> list[list[dict[str, Any]]]: ... + + def __call__( + self, + image: Union[str, "Image.Image", list[dict[str, Any]]], + input_points: list[list[list[list[float]]]] | None = None, + input_labels: list[list[list[int]]] | None = None, + input_boxes: list[list[list[float]]] | None = None, + **kwargs: Any, + ) -> list[list[dict[str, Any]]]: + """ + Segment objects in the image(s) based on visual prompts. + + Args: + image (`str`, `PIL.Image`, or `list[dict[str, Any]]`): + The pipeline handles three types of images: + + - A string containing an http url pointing to an image + - A string containing a local path to an image + - An image loaded in PIL directly + + You can use this parameter to send directly a list of images, or a dataset or a generator like so: + + ```python + >>> from transformers import pipeline + + >>> segmenter = pipeline(model="facebook/sam2.1-hiera-large", task="promptable-visual-segmentation") + >>> segmenter( + ... [ + ... { + ... "image": "http://images.cocodataset.org/val2017/000000077595.jpg", + ... "input_points": [[[[450, 600]]]], + ... "input_labels": [[[1]]], + ... }, + ... { + ... "image": "http://images.cocodataset.org/val2017/000000136466.jpg", + ... "input_boxes": [[[59, 144, 76, 163]]], + ... }, + ... ] + ... ) + [[{'score': 0.87, 'mask': ...}], [{'score': 0.92, 'mask': ...}]] + ``` + + input_points (`list[list[list[list[float]]]]`, *optional*): + Point prompts in (x, y) format. + Structure: [batch, objects, num_points, 2]. + Each point specifies a location on the image to guide segmentation. + + input_labels (`list[list[list[int]]]`, *optional*): + Labels for the point prompts. 
+ Structure: [batch, objects, num_points]. + Values: 1 = positive (include in mask), 0 = negative (exclude from mask). + Must match the structure of `input_points`. + + input_boxes (`list[list[list[float]]]`, *optional*): + Bounding box prompts in xyxy format [x1, y1, x2, y2] in pixel coordinates. + Structure: [batch, num_boxes, 4]. + + multimask_output (`bool`, *optional*, defaults to False): + Whether to output multiple mask candidates per prompt. When True, returns 3 masks per object + ranked by IoU score. When False, returns only the best mask per object. + + mask_threshold (`float`, *optional*, defaults to 0.0): + Threshold for binarizing the predicted masks. + + top_k (`int`, *optional*, defaults to None): + The number of top predictions that will be returned by the pipeline. If the provided number is `None` + or higher than the number of predictions available, it will default to the number of predictions. + + timeout (`float`, *optional*, defaults to None): + The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and + the call may block forever. + + Return: + A list of lists containing prediction results, one list per input image. Each list contains dictionaries + with the following keys: + + - **score** (`float`) -- IoU confidence score for the predicted mask. + - **mask** (`torch.Tensor`) -- Binary segmentation mask for the object, shape (height, width). + """ + # Handle different input formats + if isinstance(image, (str, Image.Image)): + inputs = { + "image": image, + "input_points": input_points, + "input_labels": input_labels, + "input_boxes": input_boxes, + } + elif isinstance(image, (list, tuple)) and valid_images(image): + # Batch of images - create individual inputs for each image + batch_inputs = self._prepare_batch_inputs(image, input_points, input_labels, input_boxes) + return list(super().__call__(batch_inputs, **kwargs)) + else: + """ + Supports the following format + - {"image": image, "input_points": points, "input_labels": labels} + - [{"image": image, "input_points": points, "input_labels": labels}] + - Generator and datasets + """ + inputs = image + + results = super().__call__(inputs, **kwargs) + return results + + def _prepare_batch_inputs(self, images, input_points, input_labels, input_boxes): + """Helper method to prepare batch inputs from separate parameters.""" + # Expand single values to match batch size + num_images = len(images) + points_list = input_points if input_points is not None else [None] * num_images + labels_list = input_labels if input_labels is not None else [None] * num_images + boxes_list = input_boxes if input_boxes is not None else [None] * num_images + + # Create input dict for each image + return ( + { + "image": img, + "input_points": points, + "input_labels": labels, + "input_boxes": boxes, + } + for img, points, labels, boxes in zip(images, points_list, labels_list, boxes_list) + ) + + def _sanitize_parameters(self, **kwargs): + preprocess_params = {} + if "timeout" in kwargs: + preprocess_params["timeout"] = kwargs["timeout"] + + forward_params = {} + if "multimask_output" in kwargs: + forward_params["multimask_output"] = kwargs["multimask_output"] + + postprocess_params = {} + if "mask_threshold" in kwargs: + postprocess_params["mask_threshold"] = kwargs["mask_threshold"] + if "top_k" in kwargs: + postprocess_params["top_k"] = kwargs["top_k"] + + return preprocess_params, forward_params, postprocess_params + + def preprocess(self, inputs, timeout=None): + """ + Preprocess inputs for the 
model. + + Args: + inputs: Dictionary containing 'image' and optionally 'input_points', 'input_labels', 'input_boxes' + timeout: Timeout for image loading + + Returns: + Dictionary with preprocessed model inputs + """ + image = load_image(inputs["image"], timeout=timeout) + input_points = inputs.get("input_points") + input_labels = inputs.get("input_labels") + input_boxes = inputs.get("input_boxes") + + # Validate that at least one prompt type is provided + if input_points is None and input_boxes is None: + raise ValueError( + "You must provide at least one prompt type: either 'input_points' (with 'input_labels') or 'input_boxes'. " + "For example: input_points=[[[[450, 600]]]], input_labels=[[[1]]] or input_boxes=[[[100, 150, 200, 250]]]" + ) + + # Validate that if input_points is provided, input_labels must also be provided + if input_points is not None and input_labels is None: + raise ValueError("When providing 'input_points', you must also provide 'input_labels'.") + + # Process inputs - pass all prompts as explicit parameters + processor_kwargs = { + "images": image, + "return_tensors": "pt", + } + + if input_points is not None: + processor_kwargs["input_points"] = input_points + processor_kwargs["input_labels"] = input_labels + + if input_boxes is not None: + processor_kwargs["input_boxes"] = input_boxes + + model_inputs = self.processor(**processor_kwargs) + model_inputs = model_inputs.to(self.dtype) + + # Store original size for post-processing + target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32) + model_inputs["original_sizes"] = target_size + + # For SamProcessor, we also need to store reshaped_input_sizes + if self._needs_reshaped_sizes and "reshaped_input_sizes" in model_inputs: + model_inputs["_reshaped_input_sizes"] = model_inputs["reshaped_input_sizes"] + + return model_inputs + + def _forward(self, model_inputs, multimask_output=False): + """ + Forward pass through the model. + + Args: + model_inputs: Preprocessed model inputs + multimask_output: Whether to output multiple masks per prompt + + Returns: + Model outputs with additional metadata + """ + original_sizes = model_inputs.pop("original_sizes") + reshaped_input_sizes = model_inputs.pop("_reshaped_input_sizes", None) + + outputs = self.model(**model_inputs, multimask_output=multimask_output) + + return { + "outputs": outputs, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } + + def postprocess(self, model_outputs, mask_threshold=0.0, top_k=None): + """ + Post-process model outputs into final predictions. 
+ + Args: + model_outputs: Raw model outputs + mask_threshold: Threshold for binarizing masks + top_k: Maximum number of predictions to return per image + + Returns: + List of lists of dictionaries with 'score' and 'mask' keys + """ + outputs = model_outputs["outputs"] + original_sizes = model_outputs["original_sizes"] + reshaped_input_sizes = model_outputs["reshaped_input_sizes"] + + # Get masks and IoU scores from outputs + pred_masks = outputs.pred_masks # (batch, objects, num_masks, H, W) + iou_scores = outputs.iou_scores # (batch, objects, num_masks) + + # Post-process masks to original image size + post_process_kwargs = { + "masks": pred_masks.cpu(), + "original_sizes": original_sizes.tolist(), + "mask_threshold": mask_threshold, + "binarize": True, + } + + # For SamProcessor, we need to pass reshaped_input_sizes + if self._needs_reshaped_sizes and reshaped_input_sizes is not None: + post_process_kwargs["reshaped_input_sizes"] = reshaped_input_sizes.tolist() + + masks = self.processor.post_process_masks(**post_process_kwargs) + + # Format output as per-image list of dictionaries + final_results = [] + batch_size = pred_masks.shape[0] + + for batch_idx in range(batch_size): + image_results = [] + num_objects = pred_masks.shape[1] + num_masks_per_object = pred_masks.shape[2] + + for obj_idx in range(num_objects): + for mask_idx in range(num_masks_per_object): + score = iou_scores[batch_idx, obj_idx, mask_idx].item() + mask_tensor = masks[batch_idx][obj_idx, mask_idx] + + result = { + "score": score, + "mask": mask_tensor, + } + image_results.append(result) + + # Sort results by score in descending order + image_results = sorted(image_results, key=lambda x: x["score"], reverse=True) + + # Apply top_k filtering + if top_k is not None and len(image_results) > top_k: + image_results = image_results[:top_k] + + final_results.append(image_results) + + # If single image, return as list with one element (for consistency) + return final_results if batch_size > 1 or isinstance(pred_masks, (list, tuple)) else final_results diff --git a/tests/pipelines/test_pipelines_promptable_visual_segmentation.py b/tests/pipelines/test_pipelines_promptable_visual_segmentation.py new file mode 100644 index 000000000000..a42c3ac7eb3c --- /dev/null +++ b/tests/pipelines/test_pipelines_promptable_visual_segmentation.py @@ -0,0 +1,283 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from transformers import ( + Sam2Model, + Sam2Processor, + SamModel, + SamProcessor, + is_torch_available, + is_vision_available, + pipeline, +) +from transformers.testing_utils import require_torch, require_vision, slow + + +if is_torch_available(): + pass + +if is_vision_available(): + import requests + from PIL import Image + + +@require_torch +@require_vision +class PromptableVisualSegmentationPipelineTests(unittest.TestCase): + # Test image URLs + test_image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" + + def get_test_image(self): + """Helper to load test image.""" + return Image.open(requests.get(self.test_image_url, stream=True).raw).convert("RGB") + + def test_sam2_single_point(self): + """Test SAM2 with single point prompt.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] # Single point + input_labels = [[[1]]] # Positive + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=False) + + self.assertEqual(len(results), 1, "Should return results for 1 image") + self.assertGreater(len(results[0]), 0, "Should return at least 1 mask") + self.assertIn("score", results[0][0]) + self.assertIn("mask", results[0][0]) + self.assertIsInstance(results[0][0]["score"], float) + self.assertTrue(0 <= results[0][0]["score"] <= 1, "Score should be between 0 and 1") + + def test_sam2_box_prompt(self): + """Test SAM2 with box prompt.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_boxes = [[[75, 275, 1725, 850]]] # Box around truck + + results = segmenter(image, input_boxes=input_boxes, multimask_output=False) + + self.assertEqual(len(results), 1) + self.assertGreater(len(results[0]), 0) + self.assertIn("score", results[0][0]) + self.assertIn("mask", results[0][0]) + + def test_sam2_multiple_points(self): + """Test SAM2 with multiple points per object.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375], [1125, 625]]]] # Multiple points + input_labels = [[[1, 1]]] # Both positive + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=False) + + self.assertEqual(len(results), 1) + self.assertGreater(len(results[0]), 0) + + def test_sam2_multiple_objects(self): + """Test SAM2 with multiple objects in same image.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + # Points for two different objects + input_points = [[[[500, 375]], [[650, 750]]]] + input_labels = [[[1], [1]]] + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=False) + + 
self.assertEqual(len(results), 1) + self.assertGreaterEqual(len(results[0]), 2, "Should return at least 2 masks for 2 objects") + + def test_sam2_multimask_output(self): + """Test SAM2 with multimask_output=True.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=True) + + self.assertEqual(len(results), 1) + # With multimask_output=True, should return 3 masks per object + self.assertGreaterEqual(len(results[0]), 3, "Should return at least 3 masks with multimask_output=True") + + def test_sam2_mask_threshold(self): + """Test SAM2 with mask_threshold parameter.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + results = segmenter( + image, input_points=input_points, input_labels=input_labels, mask_threshold=0.5, multimask_output=False + ) + + self.assertEqual(len(results), 1) + self.assertGreater(len(results[0]), 0) + + def test_sam2_top_k(self): + """Test SAM2 with top_k parameter.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + results = segmenter( + image, input_points=input_points, input_labels=input_labels, multimask_output=True, top_k=2 + ) + + self.assertEqual(len(results), 1) + self.assertLessEqual(len(results[0]), 2, "Should return at most 2 masks with top_k=2") + + def test_sam_single_point(self): + """Test SAM with single point prompt.""" + model = SamModel.from_pretrained("facebook/sam-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam-vit-base") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=False) + + self.assertEqual(len(results), 1) + self.assertGreater(len(results[0]), 0) + self.assertIn("score", results[0][0]) + self.assertIn("mask", results[0][0]) + + def test_results_sorted_by_score(self): + """Test that results are sorted by score in descending order.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=True) + + scores = [r["score"] for r in results[0]] + sorted_scores = sorted(scores, reverse=True) + self.assertEqual(scores, sorted_scores, "Results should be sorted by score in descending order") + + def 
test_error_no_prompts(self): + """Test that error is raised when no prompts are provided.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + + with self.assertRaises(ValueError) as context: + segmenter(image) + + self.assertIn("at least one prompt type", str(context.exception)) + + def test_error_points_without_labels(self): + """Test that error is raised when points are provided without labels.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + + with self.assertRaises(ValueError) as context: + segmenter(image, input_points=input_points) + + self.assertIn("input_labels", str(context.exception)) + + @slow + def test_sam2_automatic_loading(self): + """Test that SAM2 can be loaded automatically with checkpoint name.""" + segmenter = pipeline("promptable-visual-segmentation", model="facebook/sam2.1-hiera-large") + + self.assertIsInstance(segmenter.model, Sam2Model) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=False) + + self.assertEqual(len(results), 1) + self.assertGreater(len(results[0]), 0) + + @slow + def test_sam_automatic_loading(self): + """Test that SAM can be loaded automatically with checkpoint name.""" + segmenter = pipeline("promptable-visual-segmentation", model="facebook/sam-vit-base") + + self.assertIsInstance(segmenter.model, SamModel) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=False) + + self.assertEqual(len(results), 1) + self.assertGreater(len(results[0]), 0) + + def test_mask_shape(self): + """Test that mask shape matches original image size.""" + model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny") + processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny") + + segmenter = pipeline("promptable-visual-segmentation", model=model, processor=processor) + + image = self.get_test_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=False) + + mask = results[0][0]["mask"] + expected_shape = (image.height, image.width) + self.assertEqual( + mask.shape, expected_shape, f"Mask shape {mask.shape} should match image size {expected_shape}" + ) From 2ef91b879562145dcb2c46ed45a84480185727b4 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 30 Jan 2026 23:47:30 +0000 Subject: [PATCH 3/7] Add docs, fix tests --- docs/source/en/main_classes/pipelines.md | 6 + docs/source/en/model_doc/auto.md | 4 + docs/source/en/model_doc/edgetam.md | 50 ++- docs/source/en/model_doc/sam.md | 42 ++ docs/source/en/model_doc/sam2.md | 40 +- docs/source/en/model_doc/sam3_tracker.md | 40 +- .../tasks/promptable_visual_segmentation.md | 381 ++++++++++++++++++ src/transformers/models/auto/modeling_auto.py | 12 +- src/transformers/pipelines/__init__.py | 2 +- tests/models/edgetam/test_modeling_edgetam.py | 8 
+- tests/models/sam/test_modeling_sam.py | 8 +- tests/models/sam2/test_modeling_sam2.py | 8 +- .../test_modeling_sam3_tracker.py | 8 +- utils/check_docstrings.py | 1 + 14 files changed, 591 insertions(+), 19 deletions(-) create mode 100644 docs/source/en/tasks/promptable_visual_segmentation.md diff --git a/docs/source/en/main_classes/pipelines.md b/docs/source/en/main_classes/pipelines.md index e2a40ac3cc3e..54badb440f88 100644 --- a/docs/source/en/main_classes/pipelines.md +++ b/docs/source/en/main_classes/pipelines.md @@ -473,6 +473,12 @@ Pipelines available for multimodal tasks include the following. - __call__ - all +### PromptableVisualSegmentationPipeline + +[[autodoc]] PromptableVisualSegmentationPipeline + - __call__ + - all + ### VisualQuestionAnsweringPipeline [[autodoc]] VisualQuestionAnsweringPipeline diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index b45b3bfdb187..616ba3295918 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -113,6 +113,10 @@ The following auto classes are available for the following natural language proc [[autodoc]] AutoModelForMaskGeneration +### AutoModelForPromptableVisualSegmentation + +[[autodoc]] AutoModelForPromptableVisualSegmentation + ### AutoModelForSeq2SeqLM [[autodoc]] AutoModelForSeq2SeqLM diff --git a/docs/source/en/model_doc/edgetam.md b/docs/source/en/model_doc/edgetam.md index 173b89533c83..8149f770a9ba 100644 --- a/docs/source/en/model_doc/edgetam.md +++ b/docs/source/en/model_doc/edgetam.md @@ -39,14 +39,52 @@ The original code can be found [here](https://github.com/facebookresearch/EdgeTA ## Usage example -### Automatic Mask Generation with Pipeline +### Promptable Visual Segmentation Pipeline + +The easiest way to use EdgeTAM is through the `promptable-visual-segmentation` pipeline: + +```python +>>> from transformers import pipeline + +>>> segmenter = pipeline(model="yonigozlan/EdgeTAM-hf", task="promptable-visual-segmentation") +>>> # Single point prompt +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000077595.jpg", +... input_points=[[[[450, 600]]]], +... input_labels=[[[1]]], +... ) +[[{'score': 0.87, 'mask': tensor([...])}]] + +>>> # Box prompt +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000136466.jpg", +... input_boxes=[[[59, 144, 76, 163]]], +... ) +[[{'score': 0.92, 'mask': tensor([...])}]] + +>>> # Multiple points for refinement +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000136466.jpg", +... input_points=[[[[450, 600], [500, 620]]]], +... input_labels=[[[1, 0]]], # 1=positive, 0=negative +... ) +[[{'score': 0.85, 'mask': tensor([...])}]] +``` + + + +**Note:** The pipeline output format differs from using the model and processor manually. The pipeline returns a standardized format (list of lists of dicts with `score` and `mask`) to ensure consistency across all transformers pipelines, while the processor's `post_process_masks()` returns raw tensors. 
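+
+As an illustrative sketch of how to consume that pipeline format, the top-ranked prediction for an image can be read off directly, since predictions are sorted by score in descending order (here, `results` is assumed to hold the output of one of the pipeline calls above):
+
+```python
+>>> best = results[0][0]  # `results` assumed from a call above: first image, highest-scoring prediction
+>>> score, mask = best["score"], best["mask"]  # float IoU score and binary (height, width) mask tensor
+>>> int(mask.sum())  # number of pixels assigned to the segmented object
+```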
+ + + +### Automatic Mask Generation Pipeline EdgeTAM can be used for automatic mask generation to segment all objects in an image using the `mask-generation` pipeline: ```python >>> from transformers import pipeline ->>> generator = pipeline("mask-generation", model="yonigozlan/edgetam-1", device=0) +>>> generator = pipeline("mask-generation", model="yonigozlan/EdgeTAM-hf", device=0) >>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" >>> outputs = generator(image_url, points_per_batch=64) @@ -69,8 +107,8 @@ from accelerate import Accelerator >>> device = Accelerator().device ->>> model = EdgeTamModel.from_pretrained("yonigozlan/edgetam-1").to(device) ->>> processor = Sam2Processor.from_pretrained("yonigozlan/edgetam-1") +>>> model = EdgeTamModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(device) +>>> processor = Sam2Processor.from_pretrained("yonigozlan/EdgeTAM-hf") >>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" >>> raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB") @@ -166,8 +204,8 @@ from accelerate import Accelerator >>> device = Accelerator().device ->>> model = EdgeTamModel.from_pretrained("yonigozlan/edgetam-1").to(device) ->>> processor = Sam2Processor.from_pretrained("yonigozlan/edgetam-1") +>>> model = EdgeTamModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(device) +>>> processor = Sam2Processor.from_pretrained("yonigozlan/EdgeTAM-hf") >>> # Load multiple images >>> image_urls = [ diff --git a/docs/source/en/model_doc/sam.md b/docs/source/en/model_doc/sam.md index b770e41663e1..41b0b099eeec 100644 --- a/docs/source/en/model_doc/sam.md +++ b/docs/source/en/model_doc/sam.md @@ -44,6 +44,48 @@ Tips: This model was contributed by [ybelkada](https://huggingface.co/ybelkada) and [ArthurZ](https://huggingface.co/ArthurZ). The original code can be found [here](https://github.com/facebookresearch/segment-anything). +## Usage examples with 🤗 Transformers + +### Promptable Visual Segmentation Pipeline + +The easiest way to use SAM is through the `promptable-visual-segmentation` pipeline: + +```python +>>> from transformers import pipeline + +>>> segmenter = pipeline(model="facebook/sam-vit-base", task="promptable-visual-segmentation") +>>> # Single point prompt +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000077595.jpg", +... input_points=[[[[450, 600]]]], +... input_labels=[[[1]]], +... ) +[[{'score': 0.87, 'mask': tensor([...])}]] + +>>> # Box prompt +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000136466.jpg", +... input_boxes=[[[59, 144, 76, 163]]], +... ) +[[{'score': 0.92, 'mask': tensor([...])}]] + +>>> # Multiple points for refinement +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000136466.jpg", +... input_points=[[[[450, 600], [500, 620]]]], +... input_labels=[[[1, 0]]], # 1=positive, 0=negative +... ) +[[{'score': 0.85, 'mask': tensor([...])}]] +``` + + + +**Note:** The pipeline output format differs from using the model and processor manually. The pipeline returns a standardized format (list of lists of dicts with `score` and `mask`) to ensure consistency across all transformers pipelines, while the processor's `post_process_masks()` returns raw tensors. 
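+
+As a minimal sketch, a predicted mask can also be saved to disk with PIL (here, `results` is assumed to hold the output of one of the pipeline calls above):
+
+```python
+>>> import numpy as np
+>>> from PIL import Image
+
+>>> mask = results[0][0]["mask"]  # `results` assumed from a call above; binary mask of shape (height, width)
+>>> Image.fromarray(mask.numpy().astype(np.uint8) * 255).save("mask.png")
+```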
+ + + +### Basic Usage with Model and Processor + Below is an example on how to run mask generation given an image and a 2D point: ```python diff --git a/docs/source/en/model_doc/sam2.md b/docs/source/en/model_doc/sam2.md index 3d0514de57cb..6e7a9c0f9299 100644 --- a/docs/source/en/model_doc/sam2.md +++ b/docs/source/en/model_doc/sam2.md @@ -47,7 +47,45 @@ The original code can be found [here](https://github.com/facebookresearch/sam2/t ## Usage example -### Automatic Mask Generation with Pipeline +### Promptable Visual Segmentation Pipeline + +The easiest way to use SAM2 is through the `promptable-visual-segmentation` pipeline: + +```python +>>> from transformers import pipeline + +>>> segmenter = pipeline(model="facebook/sam2.1-hiera-large", task="promptable-visual-segmentation") +>>> # Single point prompt +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000077595.jpg", +... input_points=[[[[450, 600]]]], +... input_labels=[[[1]]], +... ) +[[{'score': 0.87, 'mask': tensor([...])}]] + +>>> # Box prompt +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000136466.jpg", +... input_boxes=[[[59, 144, 76, 163]]], +... ) +[[{'score': 0.92, 'mask': tensor([...])}]] + +>>> # Multiple points for refinement +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000136466.jpg", +... input_points=[[[[450, 600], [500, 620]]]], +... input_labels=[[[1, 0]]], # 1=positive, 0=negative +... ) +[[{'score': 0.85, 'mask': tensor([...])}]] +``` + + + +**Note:** The pipeline output format differs from using the model and processor manually. The pipeline returns a standardized format (list of lists of dicts with `score` and `mask`) to ensure consistency across all transformers pipelines, while the processor's `post_process_masks()` returns raw tensors. + + + +### Automatic Mask Generation Pipeline SAM2 can be used for automatic mask generation to segment all objects in an image using the `mask-generation` pipeline: diff --git a/docs/source/en/model_doc/sam3_tracker.md b/docs/source/en/model_doc/sam3_tracker.md index c64c8b711c45..927474e154e9 100644 --- a/docs/source/en/model_doc/sam3_tracker.md +++ b/docs/source/en/model_doc/sam3_tracker.md @@ -43,7 +43,45 @@ This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan) an ## Usage example -### Automatic Mask Generation with Pipeline +### Promptable Visual Segmentation Pipeline + +The easiest way to use Sam3Tracker is through the `promptable-visual-segmentation` pipeline: + +```python +>>> from transformers import pipeline + +>>> segmenter = pipeline(model="facebook/sam3", task="promptable-visual-segmentation") +>>> # Single point prompt +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000077595.jpg", +... input_points=[[[[450, 600]]]], +... input_labels=[[[1]]], +... ) +[[{'score': 0.87, 'mask': tensor([...])}]] + +>>> # Box prompt +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000136466.jpg", +... input_boxes=[[[59, 144, 76, 163]]], +... ) +[[{'score': 0.92, 'mask': tensor([...])}]] + +>>> # Multiple points for refinement +>>> segmenter( +... "http://images.cocodataset.org/val2017/000000136466.jpg", +... input_points=[[[[450, 600], [500, 620]]]], +... input_labels=[[[1, 0]]], # 1=positive, 0=negative +... ) +[[{'score': 0.85, 'mask': tensor([...])}]] +``` + + + +**Note:** The pipeline output format differs from using the model and processor manually. 
The pipeline returns a standardized format (list of lists of dicts with `score` and `mask`) to ensure consistency across all transformers pipelines, while the processor's `post_process_masks()` returns raw tensors. + + + +### Automatic Mask Generation Pipeline Sam3Tracker can be used for automatic mask generation to segment all objects in an image using the `mask-generation` pipeline: diff --git a/docs/source/en/tasks/promptable_visual_segmentation.md b/docs/source/en/tasks/promptable_visual_segmentation.md new file mode 100644 index 000000000000..862548860ba3 --- /dev/null +++ b/docs/source/en/tasks/promptable_visual_segmentation.md @@ -0,0 +1,381 @@ + + +# Promptable Visual Segmentation + +[[open-in-colab]] + +Promptable Visual Segmentation (PVS) is a computer vision task that segments objects in an image based on interactive visual prompts. Unlike automatic segmentation methods, PVS lets you specify **exactly which objects** to segment by providing: + +- **Point prompts** with labels (positive points to include, negative points to exclude) +- **Bounding box prompts** (rectangular regions around objects) +- **Combinations** of points and boxes for refined segmentation + +For each prompted object, PVS returns: +- Binary segmentation masks +- Quality/confidence scores (IoU predictions) + +> [!NOTE] +> This task is supported by the SAM-family models on the Hub: [SAM3Tracker](https://huggingface.co/facebook/sam3), [SAM2](https://huggingface.co/facebook/sam2.1-hiera-large), [SAM](https://huggingface.co/facebook/sam-vit-base), and [EdgeTAM](https://huggingface.co/yonigozlan/EdgeTAM-hf). + +In this guide, you will learn how to: + +- Use the pipeline for quick inference +- Segment objects with single point clicks +- Refine segmentation with multiple points +- Use bounding boxes as prompts +- Segment multiple objects simultaneously +- Process batches of images efficiently + +Before you begin, make sure you have all the necessary libraries installed: + +```bash +pip install -q transformers +``` + +## Promptable Visual Segmentation pipeline + +The simplest way to try out promptable visual segmentation is to use the [`pipeline`]. Instantiate a pipeline from a [checkpoint on the Hugging Face Hub](https://huggingface.co/models?other=sam2): + +```python +>>> from transformers import pipeline + +>>> segmenter = pipeline("promptable-visual-segmentation", model="facebook/sam2.1-hiera-large") +``` + +Next, choose an image you'd like to segment objects in. Here we'll use an image from the [COCO dataset](https://cocodataset.org/): + +```py +>>> from PIL import Image +>>> import requests + +>>> url = "http://images.cocodataset.org/val2017/000000077595.jpg" +>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB") +>>> image +``` + +
+*(Image: cats on a couch)*
+ +### Single point segmentation + +Pass the image and a point prompt. Points are specified as `[[[x, y]]]` coordinates with corresponding labels `[[[1]]]` where `1` means "include this object": + +```py +>>> # Click on a cat's body +>>> input_points = [[[[450, 600]]]] # [batch, objects, points_per_object, coordinates] +>>> input_labels = [[[1]]] # [batch, objects, points_per_object] - 1=positive click + +>>> results = segmenter(image, input_points=input_points, input_labels=input_labels) +>>> results +[[{'score': 0.8731, + 'mask': tensor([[False, False, False, ..., False, False, False], + [False, False, False, ..., False, False, False], + ...])}]] +``` + +The results are a list of lists (one inner list per input image). Each object gets multiple mask predictions ranked by quality score: +- `score`: Quality score (typically IoU prediction, 0-1) +- `mask`: Binary segmentation mask (same size as original image) + +By default, the model returns 3 masks per prompt, ranked by quality. To get only the best mask: + +```py +>>> results = segmenter(image, input_points=input_points, input_labels=input_labels, multimask_output=False) +>>> print(f"Returned {len(results[0])} mask(s)") # 1 mask +Returned 1 mask(s) +``` + +### Visualizing results + +Let's visualize the segmentation mask: + +```py +>>> import numpy as np +>>> import matplotlib.pyplot as plt + +>>> fig, axes = plt.subplots(1, 2, figsize=(15, 5)) + +>>> # Show original image with point +>>> axes[0].imshow(image) +>>> point_x, point_y = input_points[0][0][0] +>>> axes[0].plot(point_x, point_y, "ro", markersize=10, markeredgewidth=2, markeredgecolor="white") +>>> axes[0].set_title("Input: Image + Point") +>>> axes[0].axis("off") + +>>> # Show segmentation result +>>> mask = results[0][0]["mask"].numpy() +>>> score = results[0][0]["score"] + +>>> axes[1].imshow(image) +>>> # Create colored overlay +>>> overlay = np.zeros((*mask.shape, 4)) +>>> overlay[mask] = [1, 0, 0, 0.5] # Red with 50% transparency +>>> axes[1].imshow(overlay) +>>> axes[1].set_title(f"Segmentation (score: {score:.3f})") +>>> axes[1].axis("off") + +>>> plt.tight_layout() +>>> plt.show() +``` + +### Multiple points for refinement + +You can provide multiple points to refine the segmentation. Use positive points (label=1) to include regions and negative points (label=0) to exclude them: + +```py +>>> # First positive point on cat body, second negative point on the couch +>>> input_points = [[[[450, 600], [300, 400]]]] +>>> input_labels = [[[1, 0]]] # 1=include, 0=exclude + +>>> results = segmenter( +... image, +... input_points=input_points, +... input_labels=input_labels, +... multimask_output=False, +... ) +>>> # This will segment the cat while excluding couch regions +``` + +### Bounding box segmentation + +You can also use bounding boxes as prompts. Boxes are specified in `[x1, y1, x2, y2]` format (top-left and bottom-right corners): + +```py +>>> # Define a box around the left cat +>>> input_boxes = [[[100, 200, 350, 550]]] # [batch, objects, 4] + +>>> results = segmenter(image, input_boxes=input_boxes, multimask_output=False) +>>> mask = results[0][0]["mask"] +>>> print(f"Segmented object with box prompt, score: {results[0][0]['score']:.3f}") +``` + +### Multiple objects segmentation + +Segment multiple objects in the same image by providing multiple prompts: + +```py +>>> # Points for two cats - each cat gets its own point +>>> input_points = [ +... [[[450, 600]], [[200, 300]]] # Two objects, each with one point +... 
] +>>> input_labels = [[[1], [1]]] # Both positive + +>>> results = segmenter( +... image, +... input_points=input_points, +... input_labels=input_labels, +... multimask_output=False, +... ) + +>>> print(f"Segmented {len(results[0])} objects") +>>> for i, obj_result in enumerate(results[0]): +... print(f"Object {i+1}: score={obj_result['score']:.3f}") +``` + +### Combining points and boxes + +For maximum precision, you can combine point and box prompts: + +```py +>>> # Box around an object + refinement points +>>> input_boxes = [[[100, 200, 350, 550]]] +>>> input_points = [[[[200, 300], [150, 250]]]] # Positive and negative points +>>> input_labels = [[[1, 0]]] + +>>> results = segmenter( +... image, +... input_points=input_points, +... input_labels=input_labels, +... input_boxes=input_boxes, +... multimask_output=False, +... ) +``` + +## Manual inference with model and processor + +While the pipeline is convenient, you may want more control over the inference process. Here's how to use the model and processor directly: + +```py +>>> from transformers import Sam2Processor, Sam2Model +>>> import torch + +>>> device = "cuda" if torch.cuda.is_available() else "cpu" +>>> model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-large").to(device) +>>> processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-large") +``` + +Load an image: + +```py +>>> url = "http://images.cocodataset.org/val2017/000000077595.jpg" +>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB") +``` + +Prepare inputs and run inference: + +```py +>>> input_points = [[[[450, 600]]]] +>>> input_labels = [[[1]]] + +>>> inputs = processor( +... images=image, +... input_points=input_points, +... input_labels=input_labels, +... return_tensors="pt", +... ).to(device) + +>>> with torch.no_grad(): +... outputs = model(**inputs, multimask_output=False) + +>>> # Post-process masks to original image size +>>> masks = processor.post_process_masks( +... outputs.pred_masks.cpu(), +... inputs["original_sizes"], +... )[0] + +>>> print(f"Mask shape: {masks.shape}") # [num_objects, num_masks_per_object, height, width] +>>> print(f"IoU scores: {outputs.iou_scores}") +>>> # Results contain: +>>> # - masks: Segmentation masks (torch.Tensor) +>>> # - iou_scores: Quality predictions for each mask (torch.Tensor) +``` + +> [!TIP] +> **Pipeline vs Manual Output Format**: The pipeline returns a standardized format (list of lists of dicts with `score` and `mask`) for consistency across transformers. The processor's `post_process_masks()` returns raw tensors for more flexible post-processing. + +## Batch processing + +You can process multiple images efficiently by batching them together: + +```py +>>> cat_url = "http://images.cocodataset.org/val2017/000000077595.jpg" +>>> kitchen_url = "http://images.cocodataset.org/val2017/000000136466.jpg" +>>> images = [ +... Image.open(requests.get(cat_url, stream=True).raw).convert("RGB"), +... Image.open(requests.get(kitchen_url, stream=True).raw).convert("RGB"), +... ] + +>>> # Different prompts for each image +>>> input_points = [ +... [[[450, 600]]], # Cat image: single point +... [[[300, 250]]], # Kitchen image: single point +... ] +>>> input_labels = [[[1]], [[1]]] + +>>> inputs = processor( +... images=images, +... input_points=input_points, +... input_labels=input_labels, +... return_tensors="pt", +... ).to(device) + +>>> with torch.no_grad(): +... 
outputs = model(**inputs, multimask_output=False) + +>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"]) + +>>> for i, image_masks in enumerate(masks): +... print(f"Image {i+1}: {image_masks.shape[0]} object(s) segmented") +``` + +## Efficient multi-prompt inference + +When running multiple prompts on the same image, pre-compute image embeddings to avoid redundant computation: + +```py +>>> # Pre-process image and compute image embeddings once +>>> img_inputs = processor(images=image, return_tensors="pt").to(device) +>>> with torch.no_grad(): +... image_embeddings = model.get_image_features(pixel_values=img_inputs.pixel_values) + +>>> # Run multiple prompts efficiently +>>> point_prompts = [ +... [[[[450, 600]]]], # Point on left cat +... [[[[200, 300]]]], # Point on right cat +... [[[[150, 450]]]], # Point on couch +... ] +>>> all_results = [] + +>>> for points in point_prompts: +... labels = [[[1]]] +... prompt_inputs = processor( +... input_points=points, +... input_labels=labels, +... original_sizes=img_inputs["original_sizes"], +... return_tensors="pt", +... ).to(device) +... +... with torch.no_grad(): +... outputs = model( +... input_points=prompt_inputs["input_points"], +... input_labels=prompt_inputs["input_labels"], +... image_embeddings=image_embeddings, +... multimask_output=False, +... ) +... +... masks = processor.post_process_masks( +... outputs.pred_masks.cpu(), +... img_inputs["original_sizes"], +... )[0] +... all_results.append({"points": points, "masks": masks, "scores": outputs.iou_scores}) + +>>> print(f"Processed {len(all_results)} prompts efficiently") +``` + +This approach significantly speeds up inference when testing multiple points on the same image! + +## Advanced usage: Interactive segmentation + +PVS is ideal for interactive applications where users click to segment objects. Here's a simple iterative refinement workflow: + +```py +>>> def interactive_segment(image, positive_points, negative_points=None): +... """Segment an object with interactive point clicks.""" +... all_points = positive_points + (negative_points or []) +... labels = [1] * len(positive_points) + [0] * len(negative_points or []) +... +... input_points = [[all_points]] +... input_labels = [[labels]] +... +... results = segmenter( +... image, +... input_points=input_points, +... input_labels=input_labels, +... multimask_output=False, +... ) +... return results[0][0] + +>>> # Simulated interactive clicks +>>> # Initial click +>>> result = interactive_segment(image, positive_points=[[450, 600]]) +>>> print(f"Initial segmentation score: {result['score']:.3f}") + +>>> # Refine with additional positive click +>>> result = interactive_segment(image, positive_points=[[450, 600], [380, 550]]) +>>> print(f"Refined segmentation score: {result['score']:.3f}") + +>>> # Further refine with negative click to exclude background +>>> result = interactive_segment( +... image, +... positive_points=[[450, 600], [380, 550]], +... negative_points=[[300, 400]], +... ) +>>> print(f"Final segmentation score: {result['score']:.3f}") +``` + +This demonstrates how PVS can be used in interactive tools where users iteratively refine segmentation masks by adding positive and negative clicks! 
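+
+If you want to inspect the effect of each refinement step, you can reuse the same overlay pattern from the visualization section above. The snippet below is only an illustrative sketch, reusing `image` and the last `result` returned by `interactive_segment`:
+
+```python
+>>> import numpy as np
+>>> import matplotlib.pyplot as plt
+
+>>> mask = result["mask"].numpy()  # binary mask from the last refinement step
+>>> plt.imshow(image)
+>>> overlay = np.zeros((*mask.shape, 4))
+>>> overlay[mask] = [0, 1, 0, 0.5]  # green overlay with 50% transparency
+>>> plt.imshow(overlay)
+>>> plt.title(f"Refined mask (score: {result['score']:.3f})")
+>>> plt.axis("off")
+>>> plt.show()
+```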
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e4eded9455db..4f62e255cd7d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1627,16 +1627,16 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ] ) - +# Model for Promptable Visual Segmentation mapping +# facebook/sam2.1-hiera-large checkpoint uses sam2_video config but can be used for single-image inference MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING_NAMES = OrderedDict( [ - # Model for Promptable Visual Segmentation mapping - ("sam3_tracker", "Sam3TrackerModel"), + ("edgetam", "EdgeTamModel"), + ("sam", "SamModel"), ("sam2", "Sam2Model"), - # facebook/sam2.1-hiera-large checkpoint uses sam2_video config but can be used for single-image inference ("sam2_video", "Sam2Model"), - ("sam", "SamModel"), - ("edgetam", "EdgeTamModel"), + ("sam3_tracker", "Sam3TrackerModel"), + ("sam3_video", "Sam3TrackerModel"), ] ) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 59bd41213aaa..136f590682d0 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -303,7 +303,7 @@ "promptable-visual-segmentation": { "impl": PromptableVisualSegmentationPipeline, "pt": (AutoModelForPromptableVisualSegmentation,) if is_torch_available() else (), - "default": {"model": ("facebook/sam3", "main")}, + "default": {"model": ("facebook/sam3", "3c879f3")}, "type": "multimodal", }, "any-to-any": { diff --git a/tests/models/edgetam/test_modeling_edgetam.py b/tests/models/edgetam/test_modeling_edgetam.py index 36d0f3ac21fd..03d59193fa63 100644 --- a/tests/models/edgetam/test_modeling_edgetam.py +++ b/tests/models/edgetam/test_modeling_edgetam.py @@ -232,7 +232,13 @@ class EdgeTamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) all_model_classes = (EdgeTamModel,) if is_torch_available() else () pipeline_model_mapping = ( - {"feature-extraction": EdgeTamModel, "mask-generation": EdgeTamModel} if is_torch_available() else {} + { + "feature-extraction": EdgeTamModel, + "mask-generation": EdgeTamModel, + "promptable-visual-segmentation": EdgeTamModel, + } + if is_torch_available() + else {} ) test_resize_embeddings = False diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index 1c8957e3ca7e..ac2484d3c149 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -504,7 +504,13 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (SamModel,) if is_torch_available() else () pipeline_model_mapping = ( - {"feature-extraction": SamModel, "mask-generation": SamModel} if is_torch_available() else {} + { + "feature-extraction": SamModel, + "mask-generation": SamModel, + "promptable-visual-segmentation": SamModel, + } + if is_torch_available() + else {} ) test_resize_embeddings = False diff --git a/tests/models/sam2/test_modeling_sam2.py b/tests/models/sam2/test_modeling_sam2.py index 300f872ea082..14f50f3e76db 100644 --- a/tests/models/sam2/test_modeling_sam2.py +++ b/tests/models/sam2/test_modeling_sam2.py @@ -458,7 +458,13 @@ class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Sam2Model,) if is_torch_available() else () pipeline_model_mapping = ( - {"feature-extraction": Sam2Model, "mask-generation": Sam2Model} if is_torch_available() else {} + { + 
"feature-extraction": Sam2Model, + "mask-generation": Sam2Model, + "promptable-visual-segmentation": Sam2Model, + } + if is_torch_available() + else {} ) test_resize_embeddings = False diff --git a/tests/models/sam3_tracker/test_modeling_sam3_tracker.py b/tests/models/sam3_tracker/test_modeling_sam3_tracker.py index 3a4af37ea24c..393163199c70 100644 --- a/tests/models/sam3_tracker/test_modeling_sam3_tracker.py +++ b/tests/models/sam3_tracker/test_modeling_sam3_tracker.py @@ -242,7 +242,13 @@ class Sam3TrackerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC all_model_classes = (Sam3TrackerModel,) if is_torch_available() else () pipeline_model_mapping = ( - {"feature-extraction": Sam3TrackerModel, "mask-generation": Sam3TrackerModel} if is_torch_available() else {} + { + "feature-extraction": Sam3TrackerModel, + "mask-generation": Sam3TrackerModel, + "promptable-visual-segmentation": Sam3TrackerModel, + } + if is_torch_available() + else {} ) test_resize_embeddings = False diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index ea03cfe4e48d..22c708c064d3 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -109,6 +109,7 @@ class DecoratedItem: "SmolLM3Config", "Gemma3nVisionConfig", "Llama4Processor", + "PromptableVisualSegmentationPipeline", # Deprecated "InputExample", "InputFeatures", From d3fa667a1952d84f7128ea9c1f55fc614b01b184 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 30 Jan 2026 23:50:51 +0000 Subject: [PATCH 4/7] style --- src/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index efee915394c4..4fe5218e0f41 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -155,6 +155,7 @@ "PipedPipelineDataFormat", "Pipeline", "PipelineDataFormat", + "PromptableVisualSegmentationPipeline", "QuestionAnsweringPipeline", "TableQuestionAnsweringPipeline", "TextClassificationPipeline", @@ -162,7 +163,6 @@ "TextToAudioPipeline", "TokenClassificationPipeline", "VideoClassificationPipeline", - "PromptableVisualSegmentationPipeline", "VisualQuestionAnsweringPipeline", "ZeroShotAudioClassificationPipeline", "ZeroShotClassificationPipeline", From f53bf7b200403c7987f4ab0fdbfe07accb7fcb57 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Sat, 31 Jan 2026 00:03:58 +0000 Subject: [PATCH 5/7] fix metadata --- src/transformers/pipelines/__init__.py | 2 ++ utils/update_metadata.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 136f590682d0..d0dc5b1a7599 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -446,6 +446,8 @@ def pipeline(task: Literal["mask-generation"], model: str | PreTrainedModel | No @overload def pipeline(task: Literal["object-detection"], model: str | PreTrainedModel | None = None, config: str | PreTrainedConfig | None = None, tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast | None = None, feature_extractor: str | PreTrainedFeatureExtractor | None = None, image_processor: str | BaseImageProcessor | None = None, processor: str | ProcessorMixin | None = None, revision: str | None = None, use_fast: bool = True, token: str | bool | None = None, device: int | str | torch.device | None = None, device_map: str | dict[str, int | str] | None = None, dtype: str | torch.dtype | None = "auto", trust_remote_code: bool | None = None, model_kwargs: dict[str, Any] | 
None = None, pipeline_class: Any | None = None, **kwargs: Any) -> ObjectDetectionPipeline: ... @overload +def pipeline(task: Literal["promptable-visual-segmentation"], model: str | PreTrainedModel | None = None, config: str | PreTrainedConfig | None = None, tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast | None = None, feature_extractor: str | PreTrainedFeatureExtractor | None = None, image_processor: str | BaseImageProcessor | None = None, processor: str | ProcessorMixin | None = None, revision: str | None = None, use_fast: bool = True, token: str | bool | None = None, device: int | str | torch.device | None = None, device_map: str | dict[str, int | str] | None = None, dtype: str | torch.dtype | None = "auto", trust_remote_code: bool | None = None, model_kwargs: dict[str, Any] | None = None, pipeline_class: Any | None = None, **kwargs: Any) -> PromptableVisualSegmentationPipeline: ... +@overload def pipeline(task: Literal["question-answering"], model: str | PreTrainedModel | None = None, config: str | PreTrainedConfig | None = None, tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast | None = None, feature_extractor: str | PreTrainedFeatureExtractor | None = None, image_processor: str | BaseImageProcessor | None = None, processor: str | ProcessorMixin | None = None, revision: str | None = None, use_fast: bool = True, token: str | bool | None = None, device: int | str | torch.device | None = None, device_map: str | dict[str, int | str] | None = None, dtype: str | torch.dtype | None = "auto", trust_remote_code: bool | None = None, model_kwargs: dict[str, Any] | None = None, pipeline_class: Any | None = None, **kwargs: Any) -> QuestionAnsweringPipeline: ... @overload def pipeline(task: Literal["table-question-answering"], model: str | PreTrainedModel | None = None, config: str | PreTrainedConfig | None = None, tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast | None = None, feature_extractor: str | PreTrainedFeatureExtractor | None = None, image_processor: str | BaseImageProcessor | None = None, processor: str | ProcessorMixin | None = None, revision: str | None = None, use_fast: bool = True, token: str | bool | None = None, device: int | str | torch.device | None = None, device_map: str | dict[str, int | str] | None = None, dtype: str | torch.dtype | None = "auto", trust_remote_code: bool | None = None, model_kwargs: dict[str, Any] | None = None, pipeline_class: Any | None = None, **kwargs: Any) -> TableQuestionAnsweringPipeline: ... 
diff --git a/utils/update_metadata.py b/utils/update_metadata.py index 231380df8f58..cae32e39f9ba 100755 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -74,6 +74,11 @@ "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForZeroShotObjectDetection", ), + ( + "promptable-visual-segmentation", + "MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING_NAMES", + "AutoModelForPromptableVisualSegmentation", + ), ("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"), ("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"), ("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"), From 7393443fd9bb06fbfaf8522e03396077c2a4bd6a Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Sat, 31 Jan 2026 00:14:18 +0000 Subject: [PATCH 6/7] update tree --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 22d127edda1a..5a1129b09fff 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -309,6 +309,8 @@ title: Image Feature Extraction - local: tasks/mask_generation title: Mask Generation + - local: tasks/promptable_visual_segmentation + title: Promptable Visual Segmentation - local: tasks/keypoint_detection title: Keypoint detection - local: tasks/knowledge_distillation_for_image_classification From ca36b36d957f9d5c85ca2cea7a7c64592c03d122 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Sat, 31 Jan 2026 00:22:11 +0000 Subject: [PATCH 7/7] fix after review --- .../promptable_visual_segmentation.py | 3 +- ...ipelines_promptable_visual_segmentation.py | 57 +++++++++++++++++-- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/src/transformers/pipelines/promptable_visual_segmentation.py b/src/transformers/pipelines/promptable_visual_segmentation.py index dac53f39dfb2..10e70037a48b 100644 --- a/src/transformers/pipelines/promptable_visual_segmentation.py +++ b/src/transformers/pipelines/promptable_visual_segmentation.py @@ -389,5 +389,4 @@ def postprocess(self, model_outputs, mask_threshold=0.0, top_k=None): final_results.append(image_results) - # If single image, return as list with one element (for consistency) - return final_results if batch_size > 1 or isinstance(pred_masks, (list, tuple)) else final_results + return final_results diff --git a/tests/pipelines/test_pipelines_promptable_visual_segmentation.py b/tests/pipelines/test_pipelines_promptable_visual_segmentation.py index a42c3ac7eb3c..5c118ff470ca 100644 --- a/tests/pipelines/test_pipelines_promptable_visual_segmentation.py +++ b/tests/pipelines/test_pipelines_promptable_visual_segmentation.py @@ -15,31 +15,78 @@ import unittest from transformers import ( + MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING, + PromptableVisualSegmentationPipeline, Sam2Model, Sam2Processor, SamModel, SamProcessor, - is_torch_available, is_vision_available, pipeline, ) -from transformers.testing_utils import require_torch, require_vision, slow +from transformers.testing_utils import is_pipeline_test, require_torch, require_vision, slow -if is_torch_available(): - pass - if is_vision_available(): import requests from PIL import Image +@is_pipeline_test @require_torch @require_vision class PromptableVisualSegmentationPipelineTests(unittest.TestCase): + model_mapping = ( + dict(list(MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING.items())) + if 
MODEL_FOR_PROMPTABLE_VISUAL_SEGMENTATION_MAPPING
+        else {}
+    )
+
     # Test image URLs
     test_image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
 
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        dtype="float32",
+    ):
+        segmenter = PromptableVisualSegmentationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            dtype=dtype,
+        )
+        examples = [
+            {
+                "image": "./tests/fixtures/tests_samples/COCO/000000039769.png",
+                "input_points": [[[[450, 600]]]],
+                "input_labels": [[[1]]],
+            },
+            {
+                "image": "./tests/fixtures/tests_samples/COCO/000000039769.png",
+                "input_boxes": [[[100, 200, 350, 550]]],
+            },
+        ]
+        return segmenter, examples
+
+    def run_pipeline_test(self, segmenter, examples):
+        for example in examples:
+            result = segmenter(**example)
+            self.assertIsInstance(result, list)
+            self.assertGreater(len(result), 0)
+            # Each result should be a list of objects (for multiple images)
+            for obj_list in result:
+                self.assertIsInstance(obj_list, list)
+                for obj in obj_list:
+                    self.assertIn("mask", obj)
+                    self.assertIn("score", obj)
+
     def get_test_image(self):
         """Helper to load test image."""
         return Image.open(requests.get(self.test_image_url, stream=True).raw).convert("RGB")