In [None]:
!pip install -q ultralytics gradio

In [None]:
import gradio as gr
from ultralytics import SAM
import torch
import numpy as np
import time
import cv2
import warnings

# Silence warnings whose message starts with "UsingGradioCache": the second
# positional argument of filterwarnings() is the message regex, not a category.
# NOTE(review): if this was meant to name a warning class, it would have to be
# passed as category=<class> instead — confirm the intent.
warnings.filterwarnings("ignore", "UsingGradioCache")

In [None]:
# Load the SAM checkpoint once, at notebook start; the click handler reuses
# this module-level `model` and reports `global_load_time` in its stats text.
# perf_counter() is a monotonic, high-resolution clock, so the measured
# duration cannot be skewed by wall-clock adjustments the way time.time() can.
print("‚è≥ Loading SAM model...")
start_load = time.perf_counter()
model = SAM("sam_b.pt")
global_load_time = time.perf_counter() - start_load
print(f"‚úÖ Model loaded in {global_load_time:.2f}s")

In [None]:
def create_isolated_image(image, mask):
    """Return an RGBA copy of ``image`` in which only the masked pixels are visible.

    Args:
        image: H x W x 3 colour array. Its three channels are copied unchanged
            into the first three channels of the result.
        mask: H x W mask selecting the pixels to keep. May be a NumPy array or
            a tensor-like object exposing ``detach().cpu().numpy()`` (e.g. a
            torch tensor, possibly living on a GPU). Boolean masks are used as
            they are; any other dtype is truncated to ``uint8`` and treated as
            "non-zero means selected", the same rule ``create_overlay`` applies.

    Returns:
        A new H x W x 4 array with the dtype of ``image``: alpha is 255 inside
        the mask and 0 everywhere else. ``image`` itself is not modified.

    Raises:
        ValueError: if ``image`` is not H x W x 3 or ``mask`` is not H x W.
    """
    image = np.asarray(image)
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(
            f"image must have shape (H, W, 3), got {image.shape}"
        )

    # A tensor (especially a CUDA one) cannot be used to index a NumPy array
    # reliably, so bring it to host memory first.
    if hasattr(mask, "detach"):
        mask = mask.detach().cpu().numpy()
    mask_arr = np.asarray(mask)
    if mask_arr.dtype != np.bool_:
        # An integer/float 0-1 mask would otherwise trigger fancy indexing
        # (row lookup) instead of boolean selection.
        mask_arr = mask_arr.astype(np.uint8) != 0
    if mask_arr.shape != image.shape[:2]:
        raise ValueError(
            f"mask shape {mask_arr.shape} does not match image shape {image.shape[:2]}"
        )

    # Start fully transparent, copy the colour channels, then make only the
    # selected pixels opaque.
    rgba_image = np.zeros(image.shape[:2] + (4,), dtype=image.dtype)
    rgba_image[..., :3] = image
    rgba_image[mask_arr, 3] = 255
    return rgba_image

In [None]:
def create_overlay(image, mask_tensor, color=(0, 255, 0), alpha=0.5):
    """Blend a solid-colour mask layer over an image and return an RGB array.

    The blend covers the whole frame (``colour_layer * alpha + image *
    (1 - alpha)``). Because the colour layer is black outside the mask, pixels
    outside the mask are dimmed to ``1 - alpha`` of their brightness while
    pixels inside the mask are tinted with ``color``.

    Args:
        image: Input picture as an array: H x W (grayscale), H x W x 3 (RGB)
            or H x W x 4 (RGBA, alpha is dropped).
        mask_tensor: H x W mask. Either a tensor-like object exposing
            ``cpu().numpy()`` or anything ``np.asarray`` accepts. Values are
            truncated to ``uint8``; non-zero means "inside the mask".
        color: Three channel values for the mask tint. The tint is painted
            while the frame is in BGR order and the frame is converted back to
            RGB afterwards, so the values are effectively read as (B, G, R).
        alpha: Weight of the colour layer in the blend (0..1).

    Returns:
        A new H x W x 3 array in RGB channel order; ``image`` is not modified.

    Raises:
        ValueError: if ``image`` has an unsupported number of channels.
    """
    # Normalise every supported layout to 3-channel BGR so the tint and the
    # final BGR->RGB conversion always operate on a 3-channel frame.
    if image.ndim == 2:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.ndim == 3 and image.shape[2] == 3:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif image.ndim == 3 and image.shape[2] == 4:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    else:
        raise ValueError(
            f"image must be H x W, H x W x 3 or H x W x 4, got shape {image.shape}"
        )

    # Accept both tensors (moved to host memory first) and plain arrays.
    if hasattr(mask_tensor, "cpu"):
        if hasattr(mask_tensor, "detach"):
            mask_tensor = mask_tensor.detach()
        mask_tensor = mask_tensor.cpu().numpy()
    mask_bool = np.asarray(mask_tensor).astype(np.uint8) != 0

    # Black layer with the tint painted only where the mask is set.
    colored_mask = np.zeros_like(image_bgr, dtype=np.uint8)
    colored_mask[mask_bool] = color

    overlay = cv2.addWeighted(colored_mask, alpha, image_bgr, 1 - alpha, 0)

    return cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

In [None]:
def segment_with_point(original_image, evt: gr.SelectData):
    """Segment the object under a click and build the three UI outputs.

    Args:
        original_image: The clean uploaded image kept in ``image_state``, or
            ``None`` if nothing has been stored yet.
        evt: Gradio selection event; ``evt.index[0]`` / ``evt.index[1]`` are
            used as the x / y pixel position of the click.

    Returns:
        A 3-tuple ``(annotated_image, isolated_object, perf_string)``, one
        value per output component wired to this handler. When there is no
        image or no mask, the isolated-object slot is ``None`` so the tuple
        has the same length on every path.
    """
    if original_image is None:
        # Nothing to segment yet; still return one value per output.
        perf_string = (
            f"Model Load: {global_load_time:.2f}s\n"
            "Status: No image loaded."
        )
        return None, None, perf_string

    # cv2.circle needs plain Python ints for the centre coordinates.
    point = [int(evt.index[0]), int(evt.index[1])]

    print(f"\n‚è≥ Segmenting at point: {point}")

    # perf_counter() is monotonic, so the duration is immune to clock changes.
    start_inference = time.perf_counter()

    # labels=[1] marks the clicked point as a foreground prompt.
    results = model(original_image, points=point, labels=[1])

    inference_time = time.perf_counter() - start_inference
    print(f"‚úÖ Inference complete in {inference_time:.2f}s")

    if not results or not results[0].masks:
        print("‚ö†Ô∏è No objects found at this point.")
        perf_string = (
            f"Model Load: {global_load_time:.2f}s\n"
            f"Inference: {inference_time:.2f}s\n"
            "Status: No object found."
        )
        # Three values, matching the success path: the isolated-object panel
        # is cleared instead of being left out of the tuple.
        return original_image, None, perf_string

    # Only the first returned mask is used.
    best_mask = results[0].masks.data[0]

    annotated_image = create_overlay(original_image, best_mask, color=[0, 255, 0], alpha=0.6)

    # Mark the click: filled dot with a white ring, drawn in place on the
    # RGB overlay returned by create_overlay.
    click_point_cv2 = (point[0], point[1])
    cv2.circle(annotated_image, click_point_cv2, radius=8, color=(255, 0, 0), thickness=-1, lineType=cv2.LINE_AA)
    cv2.circle(annotated_image, click_point_cv2, radius=10, color=(255, 255, 255), thickness=2, lineType=cv2.LINE_AA)

    # create_isolated_image indexes a NumPy array with the mask, so hand it a
    # host-side boolean array rather than a (possibly CUDA) tensor.
    mask_np = best_mask.cpu().numpy().astype(bool)
    isolated_object = create_isolated_image(original_image, mask_np)

    perf_string = (
        f"Model Load Time: {global_load_time:.2f}s\n"
        f"Inference Time: {inference_time:.2f}s\n"
        f"Point Prompt: {point}"
    )

    return annotated_image, isolated_object, perf_string

In [None]:
def load_image_to_ui(image_upload):
    """Handle a fresh upload: log its shape and fan it out to the UI.

    Args:
        image_upload: The uploaded image; must expose ``.shape``.

    Returns:
        A 4-tuple ``(image, image, None, None)``. The upload handler wires
        these to the clickable image, the stored clean image, the
        isolated-object panel and the stats box, so the two ``None`` entries
        clear results left over from a previous image.
    """
    print(f"\nüñºÔ∏è New image uploaded. Shape: {image_upload.shape}")
    cleared = None
    return (image_upload, image_upload, cleared, cleared)

In [None]:

# Build the Gradio UI. Components are laid out in the order they are created:
# a narrow left column (upload, stats, isolated object) and a wider right
# column holding the clickable image.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # üöÄ Interactive Segment Anything Model (SAM) Demo

        This app uses the `ultralytics` library to run **SAM (sam_b.pt)**.

        ### How to Use:
        1.  Upload an image using the panel on the left.
        2.  The image will appear on the right. **Click on any object** in the right-hand panel.
        3.  The model will generate a segmentation mask for the object you clicked!
        """
    )

    # Keeps the untouched upload so every click segments the clean image
    # rather than the already-annotated one shown in image_output.
    image_state = gr.State()

    with gr.Row():
        # Left column: input and secondary outputs.
        with gr.Column(scale=1):
            image_upload = gr.Image(type="numpy", label="1. Upload Image", sources=['upload', 'clipboard'])
            performance_text = gr.Textbox(label="üìä Performance & Stats", lines=4, interactive=False)

            # format="png" so the alpha channel of the isolated object can be kept.
            image_isolated = gr.Image(type="numpy", label="3. Isolated Object (PNG)", format="png")
            gr.Markdown("*(Right-click or long-press to save the transparent PNG)*")

        # Right column: the image the user clicks on; also shows the overlay.
        with gr.Column(scale=2):
            image_output = gr.Image(type="numpy", label="2. Click on an Object to Segment")

    # On upload: show the image on the right, store the clean copy, and reset
    # the isolated-object panel and stats. load_image_to_ui must return one
    # value per entry in `outputs`, in this order.
    image_upload.upload(
        fn=load_image_to_ui,
        inputs=[image_upload],
        outputs=[image_output, image_state, image_isolated, performance_text] # Clear all outputs
    )

    # On click: segment_with_point receives the stored clean image plus the
    # selection event and must return one value per entry in `outputs`.
    image_output.select(
        fn=segment_with_point,
        inputs=[image_state], # Use the original, clean image for segmentation
        outputs=[image_output, image_isolated, performance_text] # Update all outputs
    )
# NOTE(review): share=True presumably requests a public share link and
# debug=True keeps the cell attached to the server for error output — confirm
# against the installed Gradio version before exposing this outside a sandbox.
demo.launch(debug=True, share=True)