Skip to content

Commit

Permalink
fix: seg mask bug fixes
Browse files Browse the repository at this point in the history
- Make sure the output more is correct so other nodes can interact with it
- Add option to invert mask if needed
  • Loading branch information
blessedcoolant committed Feb 29, 2024
1 parent 07bbe1e commit ca9b7ee
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
5 changes: 4 additions & 1 deletion invokeai/app/invocations/controlnet_image_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,9 +597,12 @@ class SegmentAnythingImageProcessorInvocation(ImageProcessorInvocation):
x_coordinate: int = InputField(default=0, ge=0, description="X-coordinate of your subject")
y_coordinate: int = InputField(default=0, ge=0, description="Y-coordinate of your subject")
background: bool = InputField(default=False, description="Object to mask is in the background")
invert: bool = InputField(default=False, description="Invert the generated mask")

def run_processor(self, image: Image.Image):
sam_predictor = SAMImagePredictor()
sam_predictor.load_model(self.model_type)
mask = sam_predictor(image, background=self.background, position=(self.x_coordinate, self.y_coordinate))
mask = sam_predictor(
image, background=self.background, position=(self.x_coordinate, self.y_coordinate), invert=self.invert
)
return mask
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Literal, Tuple

import numpy as np
from PIL import Image
from PIL import Image, ImageOps

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.image_util.segment_anything import SamPredictor, sam_model_registry
Expand Down Expand Up @@ -72,7 +72,9 @@ def load_model(self, model_type: SEGMENT_ANYTHING_MODEL_TYPES = "small"):
)
sam_predictor = sam_predictor = SamPredictor(sam_model)

def __call__(self, image: Image.Image, background: bool = False, position: Tuple[int, int] = (0, 0)) -> Image.Image:
def __call__(
self, image: Image.Image, background: bool = False, position: Tuple[int, int] = (0, 0), invert: bool = False
) -> Image.Image:
global sam_predictor

input_image = np.array(image.convert("RGB")) if image.mode != "RGB" else np.array(image)
Expand All @@ -82,7 +84,9 @@ def __call__(self, image: Image.Image, background: bool = False, position: Tuple
if sam_predictor:
sam_predictor.set_image(input_image)
masks, _, _ = sam_predictor.predict(input_point, input_label)
mask = Image.fromarray(masks[0])
mask = Image.fromarray(masks[0]).convert("RGB")
if invert:
mask = ImageOps.invert(mask)
return mask
else:
return Image.new("RGB", (image.width, image.height), color="black")

0 comments on commit ca9b7ee

Please sign in to comment.