From 9d5b0e50f99251c897075689fa8b48da6537c02a Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Wed, 10 May 2023 15:36:15 +0200
Subject: [PATCH] Update Image segmentation description (#23261)

* Update Image segmentation description

* prompt -> label
---
 src/transformers/tools/image_segmentation.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/transformers/tools/image_segmentation.py b/src/transformers/tools/image_segmentation.py
index 1a43afbd18d27..4471b84905278 100644
--- a/src/transformers/tools/image_segmentation.py
+++ b/src/transformers/tools/image_segmentation.py
@@ -28,10 +28,9 @@ class ImageSegmentationTool(PipelineTool):
     description = (
-        "This is a tool that creates a segmentation mask identifiying elements inside an image according to a prompt. "
-        "It takes two arguments named `image` which should be the original image, and `prompt` which should be a text "
-        "describing the elements what should be identified in the segmentation mask. The tool returns the mask as a "
-        "black-and-white image."
+        "This is a tool that creates a segmentation mask of an image according to a label. It cannot create an image."
+        "It takes two arguments named `image` which should be the original image, and `label` which should be a text "
+        "describing the elements what should be identified in the segmentation mask. The tool returns the mask."
     )
     default_checkpoint = "CIDAS/clipseg-rd64-refined"
     name = "image_segmenter"
@@ -44,9 +43,9 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision"])
         super().__init__(*args, **kwargs)
 
-    def encode(self, image: "Image", prompt: str):
+    def encode(self, image: "Image", label: str):
         self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
-        return self.pre_processor(text=[prompt], images=[image], padding=True, return_tensors="pt")
+        return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")
 
     def forward(self, inputs):
         with torch.no_grad():