In [None]:
# Set to False when running outside Google Colab so that torch, torchvision,
# opencv-python and numpy are installed as well (presumably they are already
# available on Colab, which is why they are skipped there).
colab = True
# install requirements
if not colab:
    !pip install torch torchvision opencv-python numpy
# NOTE(review): unpinned git install -- consider pinning a commit for
# reproducibility, and prefer `%pip` over `!pip` so packages are installed
# into the kernel's own environment.
!pip install git+https://github.com/facebookresearch/segment-anything.git
# Gradio is pinned: the UI cells below are written against this 3.x release.
!pip install gradio==3.40.0

In [None]:
import ast
import os

import cv2
import gradio as gr
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator

# Checkpoint file for each SAM backbone; the same name is used for the local
# file and for the download URL built in load_model.
models = dict(
    vit_h="sam_vit_h_4b8939.pth",
    vit_l="sam_vit_l_0b3195.pth",
    vit_b="sam_vit_b_01ec64.pth",
)

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

def load_model(model_type: str = "vit_b", checkpoint_dir: str = "."):
    """Build a SAM model and wrap it for point prompts and automatic masking.

    The checkpoint file for ``model_type`` is looked up in ``models`` and, if
    it is not already present in ``checkpoint_dir``, downloaded from
    ``https://dl.fbaipublicfiles.com/segment_anything/``.

    Args:
        model_type: backbone key, one of the keys of ``models``
            ("vit_h", "vit_l" or "vit_b").
        checkpoint_dir: directory used to cache downloaded checkpoints; the
            default (the current directory) matches the previous behaviour.

    Returns:
        ``(SamPredictor, SamAutomaticMaskGenerator)``, both wrapping the same
        model instance moved to ``DEVICE``.

    Raises:
        KeyError: if ``model_type`` has no entry in ``models``.
    """
    import urllib.request

    if model_type not in models:
        raise KeyError(f"unknown model_type {model_type!r}; expected one of {sorted(models)}")
    checkpoint_name = models[model_type]
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Download under a temporary name and rename only on success. Writing
        # straight to ``checkpoint_path`` would leave a truncated file behind
        # after an interrupted download. The existence check above would then
        # accept that file on every later run.
        partial_path = checkpoint_path + ".part"
        url = f"https://dl.fbaipublicfiles.com/segment_anything/{checkpoint_name}"
        try:
            urllib.request.urlretrieve(url, partial_path)
        except BaseException:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, checkpoint_path)

    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).to(device=DEVICE)
    return SamPredictor(sam), SamAutomaticMaskGenerator(sam)

# Load the ViT-H backbone once (downloading its checkpoint if it is missing);
# the returned predictor and mask_generator are shared by every demo below.
predictor, mask_generator = load_model("vit_h")

# One object, one click point

Upload an input image, then click on the object you want to keep; an output image containing only that object is generated automatically.

In [None]:
def extract_object(image: np.ndarray, click_x: int, click_y: int):
    """Cut out the object under a single foreground click.

    SAM is prompted with one positive point at ``(click_x, click_y)``. The
    candidate masks it returns are OR-ed together and used as an alpha
    channel, so pixels outside every candidate become fully transparent.

    Args:
        image: the uploaded image, passed straight to ``predictor.set_image``
            (presumably an RGB uint8 array from Gradio -- TODO confirm).
        click_x: x coordinate (pixels) of the click.
        click_y: y coordinate (pixels) of the click.

    Returns:
        ``image`` with an extra alpha channel appended (255 = keep, 0 = drop).
    """
    predictor.set_image(image)
    candidate_masks, _scores, _logits = predictor.predict(
        point_coords=np.array([[click_x, click_y]]),
        point_labels=np.array([1]),  # label 1 marks the point as foreground
        multimask_output=True,
    )
    n_candidates, height, width = candidate_masks.shape

    # Union of all candidates: a pixel is kept if any proposed mask covers it.
    keep = candidate_masks.any(axis=0)
    alpha = np.where(keep, 255, 0).astype(np.uint8)
    return cv2.merge((image, alpha))

def extract_object_by_event(image: np.ndarray, evt: gr.SelectData):
    """Gradio ``select`` callback for the single-click demo.

    The ``gr.SelectData`` annotation is how Gradio knows to pass the event
    data here; ``evt.index`` holds the clicked pixel as an (x, y) pair, which
    is forwarded to ``extract_object``.

    NOTE(review): ``extract_object`` is resolved by name at call time, and a
    later cell redefines it with a different signature.
    """
    x_pos, y_pos = evt.index
    return extract_object(image, click_x=x_pos, click_y=y_pos)


def get_coords(evt: gr.SelectData):
    """Return the clicked pixel as ``(x, y)`` for the two coordinate boxes.

    ``evt.index`` is the position Gradio reports for a select event on the image.
    """
    position = evt.index
    x_coord = position[0]
    y_coord = position[1]
    return x_coord, y_coord


# Build the single-click demo: clicking a pixel of the input image shows its
# coordinates and, separately, the cut-out of the object under that pixel.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Remove Background from Image")
    with gr.Row():
        coord_x = gr.Number(label="Mouse coords x")
        coord_y = gr.Number(label="Mouse coords y")

    with gr.Row():
        input_img = gr.Image(label="Input image", height=600)
        output_img = gr.Image(label="Output image", height=600)

    # Two independent listeners on the same click: one only displays the
    # coordinates, the other runs SAM and fills the output image.
    # NOTE(review): extract_object_by_event looks up `extract_object` by name
    # when it runs; a later cell redefines that name with another signature,
    # so stop this app before running the later cells.
    input_img.select(get_coords, None, [coord_x, coord_y])
    input_img.select(extract_object_by_event, [input_img], output_img)

# share=True requests a public share link and debug=True keeps the cell
# blocking so errors are printed here (see Gradio's launch() documentation).
app.launch(inline=False, share=True, debug=True)

# One object, multiple click points

The same as before, except that you can keep clicking on multiple parts of the desired object until you are satisfied. Each click adds a point to the list of selected points. Use the "Clear Points" button if you want to start over; otherwise, click the "Segment Object" button to retrieve the object without its background.

In [None]:
def extract_object(image: np.ndarray, points_str: str):
    """Cut out a single object described by one or more foreground clicks.

    All points are sent to SAM together as positive prompts. The candidate
    masks it returns are OR-ed together and used as an alpha channel.

    Args:
        image: the uploaded image, passed straight to ``predictor.set_image``.
        points_str: textual list of ``[x, y]`` pixel pairs such as
            ``"[[10, 20], [30, 40]]"`` (the "Selected Points" textbox
            content). Empty or whitespace-only text is treated as ``"[]"``.

    Returns:
        ``image`` with an alpha channel appended (255 where any candidate mask
        is set, 0 elsewhere), or ``image`` unchanged when there are no points.

    Raises:
        ValueError, SyntaxError: if ``points_str`` is not a literal list of
            ``[x, y]`` pairs.
    """
    # SECURITY: ``points_str`` comes from an editable Textbox served over a
    # public share link. ``eval`` would run arbitrary code; ``literal_eval``
    # only accepts Python literals.
    points = ast.literal_eval(points_str) if points_str and points_str.strip() else []
    if not points:
        return image

    # Validate the shape before the (costly) image embedding is computed.
    input_points = np.asarray(points)
    if input_points.ndim != 2 or input_points.shape[1] != 2:
        raise ValueError(f"expected a list of [x, y] pairs, got {points_str!r}")
    input_labels = np.ones(len(input_points), dtype=int)  # 1 = foreground point

    predictor.set_image(image)
    masks, _scores, _logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

    # Keep a pixel if any of the candidate masks covers it.
    keep = masks.any(axis=0)
    alpha_channel = np.where(keep, 255, 0).astype(np.uint8)
    return cv2.merge((image, alpha_channel))

def update_points(points, evt: gr.SelectData):
    """Append the clicked pixel to the textual list of selected points.

    Args:
        points: current "Selected Points" textbox content, a literal list
            such as ``"[[10, 20]]"``; empty text is treated as ``[]``.
        evt: Gradio select event; ``evt.index`` is the clicked (x, y) pixel.

    Returns:
        The updated list rendered back to text, e.g. ``"[[10, 20], [30, 40]]"``.

    Raises:
        ValueError, SyntaxError: if ``points`` is not a Python literal.
    """
    # SECURITY: the textbox is user-editable and served over a public share
    # link, so its text is parsed with literal_eval, never eval.
    points_list = list(ast.literal_eval(points)) if points and points.strip() else []
    x, y = evt.index
    points_list.append([x, y])
    return str(points_list)

def clear_points():
    """Return the empty point list, as text, to reset the "Selected Points" box."""
    no_points = []
    return str(no_points)

def segment_image(image, points):
    """Gradio handler for the "Segment Object" button.

    Args:
        image: the uploaded image (``None`` if nothing has been uploaded yet).
        points: "Selected Points" textbox content, a literal list of
            ``[x, y]`` pairs; empty text is treated as ``[]``.

    Returns:
        ``(output_image, status_message)`` for the output image and the
        "Result" textbox.

    Raises:
        ValueError, SyntaxError: if ``points`` is not a Python literal.
    """
    if image is None:
        return None, "No image uploaded. Please upload an image first."

    # literal_eval, not eval: the textbox content is user-controlled.
    points_list = ast.literal_eval(points) if points and points.strip() else []
    if not points_list:
        return image, "No points selected. Please click on the image to add points."

    result = extract_object(image, points)
    return result, f"Segmentation complete with {len(points_list)} points."

# Build the multi-point demo: each click on the input image appends its
# (x, y) pixel to the "Selected Points" textbox, and "Segment Object" then
# runs segment_image with all collected points.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Multi-Point Object selection")

    with gr.Row():
        input_img = gr.Image(label="Input image", height=500)
        output_img = gr.Image(label="Output image", height=500)

    # The textbox doubles as the click-point state: handlers read its text and
    # write the updated list back.
    # NOTE(review): it is an event input, so presumably user-editable -- treat
    # its content as untrusted text.
    points_display = gr.Textbox(label="Selected Points", value="[]")
    clear_button = gr.Button("Clear Points")
    segment_button = gr.Button("Segment Object")
    result_text = gr.Textbox(label="Result")

    input_img.select(update_points, points_display, points_display)
    clear_button.click(clear_points, None, points_display)
    segment_button.click(segment_image, [input_img, points_display], [output_img, result_text])

# Rebinds `app` to this demo; share/debug behave as in the previous cell.
app.launch(inline=False, share=True, debug=True)

# Multiple objects, multiple clicks

This time we generate masks for all potential objects in the image to allow multi-object retrieval. Users can select as many objects as they want, and the segmentation will include only the objects they have explicitly selected.

Users can click on different objects in the image to select them. Each click adds a point to the "Selected Points" list. When "Segment Objects" is clicked, the function finds, for each selected point, a generated segment that contains it. The result shows all selected objects, with the background set to transparent.

In [None]:
def extract_objects(image: np.ndarray, points_str: str):
    """Keep only the automatically generated segments that the user clicked on.

    ``mask_generator`` proposes segments for the whole image. For every
    selected point, the first proposed segment whose mask covers that pixel is
    kept. The union of the kept segments becomes the alpha channel of the
    result, so everything else is made fully transparent.

    Args:
        image: HxWxC array from the Gradio input (C == 3 is treated as RGB,
            anything else as already RGBA), or ``None`` when nothing has been
            uploaded yet.
        points_str: textual list of ``[x, y]`` pixel pairs, e.g.
            ``"[[10, 20], [30, 40]]"``, as kept in the "Selected Points"
            textbox. Empty text is treated as ``"[]"``.

    Returns:
        ``(image_out, message)``. ``image_out`` is the RGBA cut-out, or the
        input unchanged when nothing could be segmented. ``message`` is a
        status string for the "Result" textbox.

    Raises:
        ValueError, SyntaxError: if ``points_str`` is not a literal list of
            ``[x, y]`` pairs.
    """
    if image is None:
        return None, "No image uploaded. Please upload an image first."

    # SECURITY: ``points_str`` is the content of an editable Textbox served
    # over a public share link, so it must never reach ``eval``;
    # ``literal_eval`` only accepts Python literals.
    points = ast.literal_eval(points_str) if points_str and points_str.strip() else []
    if not points:
        return image, "No points selected. Please click on the image to add points."

    height, width = image.shape[:2]
    # Convert to integer (row, col) pixels and drop clicks outside the image.
    # Negative indices would otherwise silently wrap to the opposite edge, and
    # too-large ones would raise IndexError.
    pixels = []
    for x, y in points:
        col, row = int(x), int(y)
        if 0 <= row < height and 0 <= col < width:
            pixels.append((row, col))
    if not pixels:
        return image, "No objects found at the selected points. Try selecting different points."

    # Propose candidate segments for the whole image.
    masks = mask_generator.generate(image)

    # Map mask index -> segmentation so that several clicks on the same
    # object select (and count) it only once.
    selected = {}
    for row, col in pixels:
        for idx, mask in enumerate(masks):
            if mask['segmentation'][row, col]:
                selected.setdefault(idx, mask['segmentation'])
                break

    if not selected:
        return image, "No objects found at the selected points. Try selecting different points."

    # Combine selected masks
    combined_mask = np.zeros((height, width), dtype=bool)
    for segmentation in selected.values():
        combined_mask |= segmentation

    # Create a new RGBA image
    if image.shape[2] == 3:  # RGB input: append a fully opaque alpha channel
        opaque = np.full((height, width, 1), 255, dtype=np.uint8)
        result_image = np.concatenate([image, opaque], axis=-1)
    else:  # input already carries an alpha channel
        result_image = image.copy()

    result_image[~combined_mask] = [0, 0, 0, 0]  # transparent background

    return result_image, f"Segmentation complete. {len(selected)} objects segmented."

def update_points(points_str, evt: gr.SelectData):
    """Append the clicked pixel to the textual list of selected points.

    Args:
        points_str: current "Selected Points" textbox content, a literal
            list such as ``"[[10, 20]]"``; empty text is treated as ``[]``.
        evt: Gradio select event; ``evt.index`` is the clicked (x, y) pixel.

    Returns:
        The updated list rendered back to text, e.g. ``"[[10, 20], [30, 40]]"``.

    Raises:
        ValueError, SyntaxError: if ``points_str`` is not a Python literal.
    """
    # SECURITY: the textbox is user-editable and served over a public share
    # link, so its text is parsed with literal_eval, never eval.
    points_list = list(ast.literal_eval(points_str)) if points_str and points_str.strip() else []
    x, y = evt.index
    points_list.append([x, y])
    return str(points_list)

def clear_points():
    """Reset value for the "Selected Points" textbox: an empty list, as text."""
    return repr([])

# Build the multi-object demo: clicks are collected in "Selected Points" and
# "Segment Objects" keeps every automatically generated segment that was hit.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Multi-Object background removal")

    with gr.Row():
        input_img = gr.Image(label="Input image", height=500)
        output_img = gr.Image(label="Output image", height=500)

    # The textbox doubles as the click-point state that the handlers read and
    # write. NOTE(review): it is an event input, so presumably user-editable --
    # treat its content as untrusted text.
    points_display = gr.Textbox(label="Selected Points", value="[]")
    clear_button = gr.Button("Clear Points")
    segment_button = gr.Button("Segment Objects")
    result_text = gr.Textbox(label="Result")

    input_img.select(update_points, points_display, points_display)
    clear_button.click(clear_points, None, points_display)
    # extract_objects returns (image, message), matching these two outputs.
    segment_button.click(extract_objects, [input_img, points_display], [output_img, result_text])

# share=True requests a public share link and debug=True keeps the cell
# blocking (see Gradio's launch() documentation).
app.launch(inline=False, share=True, debug=True)