Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: Segment Anything #5829

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 31 additions & 45 deletions invokeai/app/invocations/controlnet_image_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from typing import Dict, List, Literal, Union
from typing import List, Literal, Union

import cv2
import numpy as np
Expand All @@ -18,10 +18,9 @@
MLSDdetector,
NormalBaeDetector,
PidiNetDetector,
SamDetector,
ZoeDetector,
)
from controlnet_aux.util import HWC3, ade_palette
from controlnet_aux.util import HWC3
from PIL import Image
from pydantic import BaseModel, Field, field_validator, model_validator

Expand All @@ -39,6 +38,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.segment_anything.sam_image_predictor import SAMImagePredictor

from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output

Expand Down Expand Up @@ -497,48 +497,6 @@ def run_processor(self, img):
return processed_image


@invocation(
    "segment_anything_processor",
    title="Segment Anything Processor",
    tags=["controlnet", "segmentanything"],
    category="controlnet",
    version="1.2.1",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
    """Applies segment anything processing to image"""

    def run_processor(self, image):
        # The detector is built through SamDetectorReproducibleColors (not the
        # stock SamDetector) and loaded from the "checkpoints" subfolder of the
        # "ybelkada/segment-anything" repository.
        detector = SamDetectorReproducibleColors.from_pretrained(
            "ybelkada/segment-anything",
            subfolder="checkpoints",
        )
        # The detector is called with a uint8 numpy array built from `image`.
        pixels = np.array(image, dtype=np.uint8)
        return detector(pixels)


class SamDetectorReproducibleColors(SamDetector):
    # Replaces SamDetector.show_anns() so that segment colors are taken from the
    # ADE20k palette. The inherited implementation picks random colors, which
    # also seems to make downstream image generation non-reproducible.
    def show_anns(self, anns: List[Dict]):
        """Render the annotations as one RGB uint8 array, largest segment first.

        Returns None when `anns` is empty.
        """
        if len(anns) == 0:
            return

        # Paint big regions first so smaller ones are pasted on top of them.
        largest_first = sorted(anns, key=lambda entry: entry["area"], reverse=True)
        height, width = anns[0]["segmentation"].shape
        canvas = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8), mode="RGB")
        colors = ade_palette()

        for index, entry in enumerate(largest_first):
            mask = entry["segmentation"]
            fill = np.empty((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
            # Wrap around the palette in case there are more regions than colors.
            fill[...] = colors[index % len(colors)]
            stencil = Image.fromarray(np.uint8(mask * 255))
            canvas.paste(Image.fromarray(fill, mode="RGB"), (0, 0), stencil)

        return np.array(canvas, dtype=np.uint8)


@invocation(
"color_map_image_processor",
title="Color Map Processor",
Expand Down Expand Up @@ -619,3 +577,31 @@ def run_processor(self, image: Image.Image):
resolution=self.image_resolution,
)
return processed_image


# Closed set of model identifiers accepted by the `model_type` input of
# SegmentAnythingImageProcessorInvocation (passed on to SAMImagePredictor.load_model).
SEGMENT_ANYTHING_MODEL_TYPES = Literal["small", "medium", "large", "small_hq", "medium_hq", "large_hq", "mobile"]


@invocation(
    "segment_anything_image_processor",
    title="Segment Anything Image Processor",
    tags=["controlnet", "segment anything", "mask"],
    category="controlnet",
    version="1.0.0",
)
class SegmentAnythingImageProcessorInvocation(ImageProcessorInvocation):
    """Generates a mask of the subject at the given coordinates using Facebook's Segment Anything"""

    # The subject is selected by a single (x, y) point; this node has no text input.
    model_type: SEGMENT_ANYTHING_MODEL_TYPES = InputField(default="small", description="SAM Model")
    x_coordinate: int = InputField(default=0, ge=0, description="X-coordinate of your subject")
    y_coordinate: int = InputField(default=0, ge=0, description="Y-coordinate of your subject")
    background: bool = InputField(default=False, description="Object to mask is in the background")
    invert: bool = InputField(default=False, description="Invert the generated mask")

    def run_processor(self, image: Image.Image):
        # NOTE(review): a new predictor is created and load_model() is called on
        # every run; whether the weights are cached inside SAMImagePredictor is
        # not visible from here -- confirm before relying on this in hot paths.
        sam_predictor = SAMImagePredictor()
        sam_predictor.load_model(self.model_type)
        # The two coordinate fields are handed to the predictor as one (x, y) tuple.
        mask = sam_predictor(
            image, background=self.background, position=(self.x_coordinate, self.y_coordinate), invert=self.invert
        )
        return mask
15 changes: 15 additions & 0 deletions invokeai/backend/image_util/segment_anything/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license provided at https://github.com/facebookresearch/segment-anything

from invokeai.backend.image_util.segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator # noqa F401
from invokeai.backend.image_util.segment_anything.build_sam import ( # noqa F401
build_sam,
build_sam_vit_b,
build_sam_vit_h,
build_sam_vit_l,
sam_model_registry,
)
from invokeai.backend.image_util.segment_anything.build_sam_baseline import sam_model_registry_baseline  # noqa F401

Check failure on line 14 in invokeai/backend/image_util/segment_anything/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

invokeai/backend/image_util/segment_anything/__init__.py:14:77: F401 `invokeai.backend.image_util.segment_anything.build_sam_baseline.sam_model_registry_baseline` imported but unused
from invokeai.backend.image_util.segment_anything.predictor import SamPredictor # noqa F401
Loading
Loading