https://github.com/facebookresearch/segment-anything

1 - Provide a Copy of the License: You must include a copy of the Apache License 2.0 in your project. This can be done by including a LICENSE file in the root of your project directory.

2 - State Changes: If you modify any files from the original project, you need to include a prominent notice stating that you have changed the files. This helps in distinguishing your work from the original.

3 - Preserve Notices: You must retain all the copyright, patent, trademark, and attribution notices from the source form of the original work in your derived work. This includes any notices contained in a NOTICE file if one exists.

4 - Include the NOTICE File: If the original project includes a NOTICE file, you need to include a readable copy of the attribution notices contained within it in at least one of the following places: within a NOTICE text file distributed as part of your project, within the source code or documentation, or within a display generated by your project, if applicable.

5 - Add Your Own Notices: You can add your own copyright statements and additional notices, as long as they do not contradict the terms of the Apache License 2.0.

6 - State Compliance: The Apache License 2.0 does not strictly require a separate compliance statement, but it is good practice to state in your project's documentation or README file that the original project's files are used under, and in compliance with, the Apache License 2.0.

Here's a brief example of how you can include the Apache License 2.0 in your project:

1 - LICENSE File: Create a LICENSE file in the root of your project directory with the full text of the Apache License 2.0.

2 - Notice of Changes: If you modify any files, add a comment at the top of each modified file:
    # Modified by [Your Name] on [Date]
    # Original file available at [URL to original file]

3 - Retain Notices: Ensure that all original notices are preserved in your project.

4 - README or Documentation: Include a statement in your README file that your project includes files licensed under the Apache License 2.0 and provide a link to the license.

By following these steps, you respect the terms of the Apache License 2.0 while developing a pip-installable library that uses content from the GitHub repository.

In [11]:
import cv2
import numpy as np
from skimage.morphology import flood_fill
import torch
# The SAM weights did not fit into the available GPU memory, so inference is
# pinned to the CPU. Set FORCE_CPU to False to use CUDA whenever it is available.
FORCE_CPU = True
device = "cpu" if FORCE_CPU or not torch.cuda.is_available() else torch.device("cuda")

from helper_exceptions import samPromptGenerationQuitException
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry

In [22]:
class samSegmentator():
    """Interactive wrapper around a Segment Anything (SAM) model.

    Wraps either a ``SamPredictor`` (prompted segmentation, prompts are drawn
    by the user in an OpenCV window) or a ``SamAutomaticMaskGenerator``
    (automatic segmentation) and converts the model output into a labeled
    image in which every segment carries its own positive id.

    Mouse bindings of the annotation window:
        * middle button drag : box prompt
        * left click         : foreground point prompt (label 1)
        * right click        : background point prompt (label 0)

    Key bindings of the annotation window:
        * space : finish and return the collected prompts
        * r     : discard all prompts and start over
        * q     : quit by raising ``samPromptGenerationQuitException``
    """

    WINDOW_NAME = "Annotations"  # single source of truth for the window title
    RECT_COLOR = [255, 0, 0]     # color used for drawing box prompts

    def __init__(self, sam, device:str="cpu"):
        """initializing samSegmentator object

        Args:
            sam (SamPredictor or SamAutomaticMaskGenerator): model wrapper used for segmentation
            device (str, optional): device name, only stored on the object. Defaults to "cpu".
        """
        self.sam = sam
        self.device = device
        self.sam_setted_image = None # last image passed to SamPredictor.set_image
        self.DRAW_BG = {"color" : [0,0,255], "val" : 0} # right click
        self.DRAW_FG = {"color" : [0,255,0], "val" : 1} # left click
        self.reset()

    def reset(self):
        """resetting samSegmentator object variables
        """
        self.paint_dict = None
        self.clicked = False                # flag for drawing action
        self.currently_drawing_rect = False # flag for rectangle action
        self.ix, self.iy = None, None       # start corner of the rectangle being drawn
        self.display_rects = []             # selected rectangles for displaying
        self.prompt_rects = []              # selected rectangles for prompting
        self.prompt_coords = []             # selected coords for prompting
        self.prompt_labels = []             # selected labels for prompting

    @staticmethod
    def _clip_box(x0, y0, x1, y1, width, height):
        """clips both corners of a box into the image and orders them

        Args:
            x0 (int): column of the first corner
            y0 (int): row of the first corner
            x1 (int): column of the second corner
            y1 (int): row of the second corner
            width (int): image width in pixels
            height (int): image height in pixels

        Returns:
            numpy.ndarray: box as [x_min, y_min, x_max, y_max]
        """
        xs = np.clip([x0, x1], 0, width - 1)
        ys = np.clip([y0, y1], 0, height - 1)
        return np.array([xs.min(), ys.min(), xs.max(), ys.max()])

    @staticmethod
    def _prompts_in_box(box, coords, labels):
        """selects the point prompts that lie inside a box (borders included)

        Args:
            box (sequence): box as [x_min, y_min, x_max, y_max]
            coords (list): point prompts as [x, y] pairs
            labels (list): label of each point prompt

        Returns:
            tuple: (coords array, labels array) or (None, None) if no point is inside
        """
        inside = [(pc, lb) for pc, lb in zip(coords, labels)
                  if box[0] <= pc[0] <= box[2] and box[1] <= pc[1] <= box[3]]
        if not inside:
            return None, None
        points, point_labels = zip(*inside)
        return np.array(points), np.array(point_labels)

    @staticmethod
    def _label_regions(labeled_image, target_value, segment_id):
        """gives a new id to every 4-connected region that holds target_value

        Each found region is grown by one pixel after labeling so that thin
        edges between neighbouring segments are absorbed.

        Args:
            labeled_image (numpy.ndarray): int16 label image, modified in place
            target_value (int): value that marks the pixels still to be labeled
            segment_id (int): first id to assign

        Returns:
            int: next unused segment id
        """
        rows, cols = np.where(labeled_image == target_value)
        while len(rows) != 0:
            seed = (rows[0], cols[0]) # any pixel of a not yet labeled region
            labeled_image = flood_fill(labeled_image, seed, segment_id, connectivity=1, in_place=True)
            extracted_segment = (labeled_image == segment_id).astype(np.int16) # only this segment as binary
            extracted_segment = cv2.dilate(extracted_segment, np.ones((3,3)), iterations=1) # grow borders by one pixel
            np.putmask(labeled_image, extracted_segment != 0, segment_id)

            segment_id = segment_id + 1
            rows, cols = np.where(labeled_image == target_value)
        return segment_id

    def annotation_event_listener(self, event, x:int, y:int, flags, param):
        """mouse callbacks for annotation types

        Args:
            event (opencv event): mouse event to detect
            x (int): column coordinate of mouse
            y (int): row coordinate of mouse
            flags (opencv flags): flags
            param (dictionary): parameters
        """
        # rectangle selection with middle button
        if event == cv2.EVENT_MBUTTONDOWN:
            self.currently_drawing_rect = True
            self.ix, self.iy = x, y
        elif event == cv2.EVENT_MOUSEMOVE:
            if self.currently_drawing_rect:
                # self.altered already holds every committed rectangle and point,
                # so only the rectangle in progress has to be drawn on top of it
                self.image = self.altered.copy()
                cv2.rectangle(self.image, (self.ix, self.iy), (x, y), self.RECT_COLOR, 1)
        elif event == cv2.EVENT_MBUTTONUP and self.currently_drawing_rect:
            # a release without a matching press is ignored, there is no start corner for it
            self.currently_drawing_rect = False
            height, width = self.image.shape[:2]
            # the release position can be outside of the window and the drag can go
            # in any direction, so clip and order the corners before storing the box
            box = self._clip_box(self.ix, self.iy, x, y, width, height)
            x0, y0, x1, y1 = (int(v) for v in box)
            cv2.rectangle(self.altered, (x0, y0), (x1, y1), self.RECT_COLOR, 1)
            self.display_rects.append((x0, y0, x1, y1))
            self.prompt_rects.append(box)
            self.image = self.altered.copy()
        if self.currently_drawing_rect:
            return

        # annotation type selection with left/right click
        if event == cv2.EVENT_LBUTTONDOWN:
            self.clicked = True
            self.paint_dict = self.DRAW_FG
        elif event == cv2.EVENT_RBUTTONDOWN:
            self.clicked = True
            self.paint_dict = self.DRAW_BG
        elif self.clicked and (event == cv2.EVENT_LBUTTONUP or event == cv2.EVENT_RBUTTONUP):
            self.clicked = False
            # keep the point inside of the image even if the button is released outside
            height, width = self.image.shape[:2]
            px = min(max(x, 0), width - 1)
            py = min(max(y, 0), height - 1)
            cv2.circle(self.altered, (px, py), 3, self.paint_dict["color"], -1)
            self.prompt_coords.append([px, py])
            self.prompt_labels.append(self.paint_dict["val"])
            self.image = self.altered.copy()

    def generate_prompt(self, image):
        """function to interactively generate sam prompt

        Args:
            image (numpy.ndarray): original image

        Raises:
            samPromptGenerationQuitException: if key q is pressed

        Returns:
            list: boxes, coords and labels for prompting
        """
        self.reset()
        self.original = image.copy()
        self.image = self.original.copy()   # what is shown, may hold a rectangle preview
        self.altered = self.original.copy() # committed annotations only

        cv2.namedWindow(self.WINDOW_NAME, flags= cv2.WINDOW_AUTOSIZE | cv2.WINDOW_KEEPRATIO | cv2.WINDOW_GUI_NORMAL)
        cv2.setMouseCallback(self.WINDOW_NAME, self.annotation_event_listener)

        while True:
            cv2.imshow(self.WINDOW_NAME, self.image)
            key = cv2.waitKey(1) & 0xFF # mask to the low byte, some backends set extra high bits

            # key bindings
            if key == ord("q"):
                cv2.destroyWindow(self.WINDOW_NAME)
                raise samPromptGenerationQuitException("samSegmentator received key q for quitting")
            if key == ord(" "):
                cv2.destroyWindow(self.WINDOW_NAME)
                return self.prompt_rects, self.prompt_coords, self.prompt_labels
            if key == ord("r"): # reset everything
                self.reset()
                # two separate copies, drawing on one canvas must not leak into the other
                self.image = self.original.copy()
                self.altered = self.original.copy()

    def get_label_from_sam_auto_output(self, sam_auto_output):
        """creates labeled image from sam output

        Args:
            sam_auto_output (list): list of dictionaries, "segmentation" key of each holds a boolean mask

        Returns:
            numpy.ndarray: labeled image
        """
        # get all masks and mark each of them with a unique id, ids start from 1
        # because 0 is reserved for pixels that are not covered by any mask
        masks = [x["segmentation"] for x in sam_auto_output]
        labeled_image = np.zeros(self.original.shape[:2], dtype=np.int16)
        for mask_id, mask in enumerate(masks, start=1):
            labeled_image[mask] = mask_id

        # segment the unlabeled pixels with ids following the mask ids
        self._label_regions(labeled_image, 0, labeled_image.max()+1)

        return labeled_image

    def get_label_from_sam_with_prompt_output_mask(self, sam_with_prompt_output_mask):
        """creates labeled image from sam output

        Args:
            sam_with_prompt_output_mask (numpy.ndarray): binary merged mask from sam output

        Returns:
            numpy.ndarray: labeled image
        """
        labeled_image = np.zeros(self.original.shape[:2], dtype=np.int16)
        # mark masked pixels with -1
        labeled_image[sam_with_prompt_output_mask] = -1

        # label the masked pixels starting from 1
        next_id = self._label_regions(labeled_image, -1, 1)

        # label the not masked pixels from last id
        self._label_regions(labeled_image, 0, next_id)

        return labeled_image

    def sam_predict(self, image):
        """segmentation using sam model

        Args:
            image (numpy.ndarray): image to segment

        Raises:
            TypeError: if self.sam is neither SamPredictor nor SamAutomaticMaskGenerator

        Returns:
            numpy.ndarray: segmented image
        """
        self.original = image.copy()
        # NOTE(review): segment_anything treats images as RGB by default while
        # cv2.imread returns BGR, confirm which channel order callers pass in

        if isinstance(self.sam, SamPredictor):
            # prompt generation
            prompt_rects, prompt_coords, prompt_labels = self.generate_prompt(image)

            # set the image, skipped if the predictor already holds the same image
            if not np.array_equal(self.sam_setted_image, image):
                self.sam.set_image(image)
                self.sam_setted_image = image.copy()

            # since multiple boxes and multiple coords/labels are not supported each box
            # is passed individually with the coords and labels that lie inside of it
            masks = []
            for box in prompt_rects:
                coords, labels = self._prompts_in_box(box, prompt_coords, prompt_labels)
                sam_output = self.sam.predict(point_coords=coords, point_labels=labels, box=box)
                masks.append(sam_output[0][0])

            # without any box the points would be dropped silently, use them on their own
            if len(prompt_rects) == 0 and len(prompt_coords) != 0:
                sam_output = self.sam.predict(point_coords=np.array(prompt_coords),
                                              point_labels=np.array(prompt_labels))
                masks.append(sam_output[0][0])

            # generate one mask and get labels
            if masks:
                final_mask = np.logical_or.reduce(masks)
            else: # no prompt at all, nothing is masked
                final_mask = np.zeros(image.shape[:2], dtype=bool)
            return self.get_label_from_sam_with_prompt_output_mask(final_mask)

        if isinstance(self.sam, SamAutomaticMaskGenerator):
            # get mask and label the segments
            sam_auto_output = self.sam.generate(image)
            return self.get_label_from_sam_auto_output(sam_auto_output)

        raise TypeError("sam must be a SamPredictor or a SamAutomaticMaskGenerator, got "
                        + type(self.sam).__name__)

In [23]:
image_path = "/home/mericdemirors/Pictures/araba/araba.jpg"
image = cv2.imread(image_path)
# cv2.imread reports a missing or unreadable file by returning None instead of
# raising, so fail here with a clear message rather than later inside the model
if image is None:
    raise FileNotFoundError("could not read image: " + image_path)

In [24]:
# prompted segmentation: build the ViT-B SAM model from its checkpoint, move it
# to the selected device and let the user draw the prompts interactively
sam_config = sam_model_registry["vit_b"](
    checkpoint="/home/mericdemirors/Downloads/sam_vit_b_01ec64.pth"
).to(device)
sam_with_prompt = SamPredictor(sam_config)
sam_generator = samSegmentator(sam=sam_with_prompt, device=device)
segment = sam_generator.sam_predict(image=image)

In [17]:
# automatic segmentation: build the ViT-B SAM model from its checkpoint, move it
# to the selected device and segment the whole image without any prompt
sam_config = sam_model_registry["vit_b"](
    checkpoint="/home/mericdemirors/Downloads/sam_vit_b_01ec64.pth"
).to(device)
sam_auto = SamAutomaticMaskGenerator(sam_config)
sam_generator = samSegmentator(sam=sam_auto, device=device)
segment = sam_generator.sam_predict(image=image)

In [25]:
# scale the labels into the displayable 0-255 range; done in floating point so
# that more than 255 segments neither wrap around in uint8 nor collapse to a
# zero scale factor, and guarded against an all-zero label image
max_label = max(int(segment.max()), 1)
preview = (segment.astype(np.float32) * (255.0 / max_label)).astype(np.uint8)
cv2.imshow("u", preview)
cv2.waitKey(0)
cv2.destroyAllWindows()