Zero-shot image segmentation, the ability to segment objects in an image without prior training on those specific objects, has become an exciting new frontier in computer vision. The Segment Anything Model (SAM) from Meta AI has emerged as a powerful tool for tackling this challenge.

# Introduction to Image Segmentation

Image segmentation is a process in digital image processing and computer vision that involves dividing an image into multiple segments, regions, or objects. It is used to simplify and change the representation of an image to make it easier to analyze and extract features from.
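
As a toy illustration of the idea (this is not how SAM works), the sketch below splits a tiny synthetic grayscale image into "object" and "background" by thresholding. The pixel values and the threshold of 128 are made up for this example.

```python
import numpy as np

# A tiny synthetic 4x4 grayscale "image": a bright 2x2 object on a dark background.
image = np.array([
    [10,  12,  11,  9],
    [13, 200, 210, 10],
    [12, 205, 220, 11],
    [ 9,  10,  12, 10],
], dtype=np.uint8)

# Thresholding assigns every pixel to one of two segments:
# True = object, False = background.
mask = image > 128

object_pixels = int(mask.sum())
print(mask.astype(int))
print("object pixels:", object_pixels)
```

The boolean array plays the same role as the masks SAM returns later in this article: one value per pixel saying whether that pixel belongs to the segment.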

# Setting up the Working Environment

In [1]:
!pip install transformers 
!pip install timm 
!pip install torchvision 



Next, we define the helper functions used throughout this article to draw masks, bounding boxes, and prompt points on images.

In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [3]:
# Visualizes a mask overlay on an image.
# `ax` is the Matplotlib axis on which the mask is drawn.
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3),
                                np.array([0.6])],
                               axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# Draws a bounding box on an image.
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0),
                               w,
                               h, edgecolor='green',
                               facecolor=(0,0,0,0),
                               lw=2))
#Displays an image with multiple bounding boxes overlaid.
def show_boxes_on_image(raw_image, boxes):
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    for box in boxes:
      show_box(box, plt.gca())
    plt.axis('on')
    plt.show()

# Displays an image with points overlaid, categorized by labels (positive or negative points).
def show_points_on_image(raw_image, input_points, input_labels=None):
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
      labels = np.ones_like(input_points[:, 0])
    else:
      labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    plt.axis('on')
    plt.show()

#Combines points and bounding boxes visualization on an image
def show_points_and_boxes_on_image(raw_image,
                                   boxes,
                                   input_points,
                                   input_labels=None):
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
      labels = np.ones_like(input_points[:, 0])
    else:
      labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    for box in boxes:
      show_box(box, plt.gca())
    plt.axis('on')
    plt.show()

#  Visualizes points with positive and negative labels.
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0],
               pos_points[:, 1],
               color='green',
               marker='*',
               s=marker_size,
               edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0],
               neg_points[:, 1],
               color='red',
               marker='*',
               s=marker_size,
               edgecolor='white',
               linewidth=1.25)

#  Converts a Matplotlib figure into a PIL image.
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    import io
    from PIL import Image  # imported here so the helper does not depend on a later cell

    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img

#Displays a mask over an image.
def show_mask_on_image(raw_image, mask, return_image=False):
    if not isinstance(mask, torch.Tensor):
      mask = torch.Tensor(mask)

    if len(mask.shape) == 4:
      mask = mask.squeeze()

    fig, axes = plt.subplots(1, 1, figsize=(15, 15))

    mask = mask.cpu().detach()
    axes.imshow(np.array(raw_image))
    show_mask(mask, axes)
    axes.axis("off")
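    # Convert the figure before plt.show(); afterwards plt.gcf() can return a new, empty figure.
    image = fig2img(fig) if return_image else None
    plt.show()

    return image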

# Visualizes multiple masks generated by a model on an image.
def show_pipe_masks_on_image(raw_image, outputs):
  plt.imshow(np.array(raw_image))
  ax = plt.gca()
  for mask in outputs["masks"]:
      show_mask(mask, ax=ax, random_color=True)
  plt.axis("off")
  plt.show()
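
To make the broadcasting step inside show_mask more concrete, here is a minimal sketch on a made-up 2x3 mask. It reuses the same default color as the helper; the mask values are hypothetical.

```python
import numpy as np

# Hypothetical 2x3 binary mask (1 = object pixel).
mask = np.array([[0, 1, 1],
                 [0, 0, 1]], dtype=np.float32)

# Same default color as show_mask: RGB (30, 144, 255) with 60% opacity.
color = np.array([30/255, 144/255, 255/255, 0.6])

h, w = mask.shape[-2:]
# (h, w, 1) * (1, 1, 4) broadcasts to an (h, w, 4) RGBA image.
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

print(mask_image.shape)
print(mask_image[0, 0])  # background pixel: all zeros, fully transparent
print(mask_image[0, 1])  # object pixel: the overlay color
```

Background pixels end up with an alpha of 0, so ax.imshow draws the color only where the mask is set and the underlying image stays visible elsewhere.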

# Mask Generation with SAM

**The Segment Anything Model (SAM)** is an image segmentation model developed by Meta AI. It can identify the precise location of either specific objects or every object in an image. SAM was released in April 2023 and is open source under the Apache 2.0 license. Project page: https://segment-anything.com/

SAM produces high-quality object masks from input prompts such as points or boxes and can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks and demonstrates strong zero-shot performance on a variety of segmentation tasks.

In July 2024, Meta AI released Segment Anything 2 (SAM 2), which is reported to be more accurate than the original SAM model at image segmentation tasks while running about 6 times faster.

The model used below is SlimSAM: https://huggingface.co/Zigeng/SlimSAM-uniform-77


In [4]:
#from transformers import pipeline
#generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=0)

In [5]:
from PIL import Image

In [6]:
raw_image = Image.open(r"F:\learnings projects\hugging face with youssef hosni\zero-shot image classificatio\dd76db4d5d6bc43560e1d822084dd7cf.png").convert("RGB")
resize_image = raw_image.resize((720, 375))

In [7]:
raw_image.show()

The final step is to apply the model to the image and then use show_pipe_masks_on_image to draw the resulting masks on it.

In [8]:
#outputs = generator(resize_image, points_per_batch=32)  # 32 is an example value; tune it to the available memory

points_per_batch: This parameter sets how many prompt points the generator processes in a single batch during inference. Larger values can be faster but need more memory, so it is commonly used to balance inference speed against memory usage.

#show_pipe_masks_on_image(resize_image, outputs)  # run together with the generator call above, on the same image
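
To illustrate what batching prompt points means, the sketch below builds a regular grid of points and splits it into chunks. It is only an illustration of the idea, not the pipeline's actual internals; make_point_grid and iter_batches are hypothetical helper names, and the grid size and batch size are example values.

```python
import numpy as np

def make_point_grid(points_per_side):
    """Return an (N, 2) array of evenly spaced (x, y) points in [0, 1]."""
    offset = 1 / (2 * points_per_side)
    coords = np.linspace(offset, 1 - offset, points_per_side)
    xs, ys = np.meshgrid(coords, coords)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)

def iter_batches(points, points_per_batch):
    """Yield consecutive chunks of at most points_per_batch points."""
    for start in range(0, len(points), points_per_batch):
        yield points[start:start + points_per_batch]

grid = make_point_grid(points_per_side=8)  # 8 x 8 = 64 prompt points
batches = list(iter_batches(grid, points_per_batch=24))
batch_sizes = [len(batch) for batch in batches]
print(batch_sizes)
```

A smaller batch size means more forward passes with less memory per pass; a larger one means fewer passes, each needing more memory.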

# Faster Inference: Infer an Image and a Single Point

In [9]:
from transformers import AutoProcessor, AutoModelForMaskGeneration

processor = AutoProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
model = AutoModelForMaskGeneration.from_pretrained("Zigeng/SlimSAM-uniform-77") 

In [10]:
input_points = [[[400, 500]]]

We pass the processor the raw image, the points marking the object we want to segment, and the format in which the results should be returned. We choose "pt", which stands for PyTorch tensors.

In [11]:
inputs = processor(raw_image, input_points=input_points, return_tensors="pt")

Next, we run a forward pass through the pre-trained SAM model using the inputs we got from the processor and retrieve the model's output. The torch.no_grad() context manager disables gradient tracking, since the goal is inference rather than training.

In [12]:
import torch

with torch.no_grad():
    outputs = model(**inputs)

In [13]:
#Now let's predict the mask for the input image

predicted_masks = processor.image_processor.post_process_masks(
    outputs.pred_masks,
    inputs["original_sizes"],
    inputs["reshaped_input_sizes"]
)
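
The two size arguments exist because the model does not see the image at its original resolution. The sketch below shows the arithmetic, assuming the processor follows the original SAM convention of resizing the longer side to 1024 pixels while keeping the aspect ratio. The helper name longest_side_resize is hypothetical, and the 900 x 1600 size is the one reported for the predicted mask in this notebook.

```python
def longest_side_resize(height, width, target=1024):
    """Scale (height, width) so the longer side equals target, keeping the aspect ratio."""
    scale = target / max(height, width)
    new_h = int(height * scale + 0.5)
    new_w = int(width * scale + 0.5)
    return scale, (new_h, new_w)

# Original image size (height, width) as reported by the predicted mask shape.
scale, reshaped = longest_side_resize(900, 1600)
print(scale, reshaped)

# The prompt point (x=400, y=500) in original pixels, mapped into the resized image.
x, y = 400, 500
scaled_point = (x * scale, y * scale)
print(scaled_point)
```

Post-processing reverses this step: the low-resolution masks are upscaled, cropped to the reshaped size, and then resized back to the original size so they line up with the raw image.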

In [14]:
# predicted_masks holds one entry per input image, so its length is 1 here.

len(predicted_masks)

1

In [15]:
#We can also print the predicted mask shape.

predicted_mask = predicted_masks[0]
predicted_mask.shape

torch.Size([1, 3, 900, 1600])

The pre-trained SAM model returns three candidate masks for the single point prompt, each 900 x 1600 pixels (the size of the original image). Finally, let's print the IoU scores.

In [16]:
outputs.iou_scores

tensor([[[0.7597, 0.8831, 0.9271]]])
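
IoU (intersection over union) measures how much two masks overlap. The scores printed above are the model's own quality estimates for each candidate mask, not measurements against a ground truth. The sketch below shows how IoU is computed for two made-up masks and how the highest-scoring candidate could be picked; mask_iou is a hypothetical helper, and the scores are copied from the printout above.

```python
import numpy as np

def mask_iou(a, b):
    """Intersection over union of two masks of the same shape."""
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(a, b).sum() / union)

# Two made-up 3x3 masks sharing 2 pixels, with 4 pixels in their union.
mask_a = np.array([[1, 1, 0],
                   [1, 0, 0],
                   [0, 0, 0]])
mask_b = np.array([[1, 1, 0],
                   [0, 1, 0],
                   [0, 0, 0]])
iou = mask_iou(mask_a, mask_b)
print(iou)

# Scores copied from the outputs.iou_scores printout (image, point, candidate).
iou_scores = np.array([[[0.7597, 0.8831, 0.9271]]])
best_index = int(iou_scores[0, 0].argmax())
print(best_index)
```

With these scores the third candidate (index 2) ranks highest, so predicted_mask[:, 2] would be the natural one to keep if only a single mask is needed.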

In [None]:
for i in range(3):
    show_mask_on_image(raw_image, predicted_mask[:, i])


Notebook link: https://www.kaggle.com/code/youssef19/zero-shot-image-segmentation-using-sam