From 68e67821e8f279500ddaa3a46a5af708eecd8af5 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Wed, 10 May 2023 09:25:52 -0400
Subject: [PATCH 1/2] Update Image segmentation description

---
 src/transformers/tools/image_segmentation.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/transformers/tools/image_segmentation.py b/src/transformers/tools/image_segmentation.py
index 1a43afbd18d27..42cb77127be1c 100644
--- a/src/transformers/tools/image_segmentation.py
+++ b/src/transformers/tools/image_segmentation.py
@@ -28,10 +28,9 @@
 
 class ImageSegmentationTool(PipelineTool):
     description = (
-        "This is a tool that creates a segmentation mask identifiying elements inside an image according to a prompt. "
-        "It takes two arguments named `image` which should be the original image, and `prompt` which should be a text "
-        "describing the elements what should be identified in the segmentation mask. The tool returns the mask as a "
-        "black-and-white image."
+        "This is a tool that creates a segmentation mask of an image according to a label. It cannot create an image."
+        "It takes two arguments named `image` which should be the original image, and `label` which should be a text "
+        "describing the elements what should be identified in the segmentation mask. The tool returns the mask."
     )
     default_checkpoint = "CIDAS/clipseg-rd64-refined"
     name = "image_segmenter"

From e7b0a4ab417dfd9ed4bc379eaf913b6a1a930097 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Wed, 10 May 2023 09:27:43 -0400
Subject: [PATCH 2/2] prompt -> label

---
 src/transformers/tools/image_segmentation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/tools/image_segmentation.py b/src/transformers/tools/image_segmentation.py
index 42cb77127be1c..4471b84905278 100644
--- a/src/transformers/tools/image_segmentation.py
+++ b/src/transformers/tools/image_segmentation.py
@@ -43,9 +43,9 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision"])
         super().__init__(*args, **kwargs)
 
-    def encode(self, image: "Image", prompt: str):
+    def encode(self, image: "Image", label: str):
         self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
-        return self.pre_processor(text=[prompt], images=[image], padding=True, return_tensors="pt")
+        return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")
 
     def forward(self, inputs):
         with torch.no_grad():
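
For context, a minimal usage sketch of the renamed argument. This assumes the tool is instantiated directly from the module touched by these patches and invoked through the generic PipelineTool call path; the input file name and label value are illustrative and not part of the patch.

from PIL import Image

from transformers.tools.image_segmentation import ImageSegmentationTool

# Instantiate the tool; per the class attributes above, the default
# checkpoint is CIDAS/clipseg-rd64-refined.
segmenter = ImageSegmentationTool()

# Hypothetical input image, used only to illustrate the call.
image = Image.open("photo.png")

# After the second patch, the keyword argument accepted by encode() is
# `label` (previously `prompt`).
mask = segmenter(image, label="cat")

# Per the updated description, the tool returns the segmentation mask;
# saving it here is only for inspection and assumes a PIL-style image.
mask.save("cat_mask.png")

The rename keeps the tool's two-argument interface intact; only the keyword changes, so callers that passed prompt= by name need to switch to label=.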