In [9]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Automatically generating object masks with SAM

Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image. This method was used to generate the dataset SA-1B.

The class `SamAutomaticMaskGenerator` implements this capability. It works by sampling single-point input prompts in a grid over the image; from each of these points SAM can predict multiple masks. The masks are then filtered for quality and deduplicated using non-maximum suppression. Additional options can further improve mask quality and quantity, such as running prediction on multiple crops of the image or post-processing masks to remove small disconnected regions and holes.

In [10]:
from IPython.display import display, HTML

# "Open In Colab" badge pointing at the upstream SAM automatic-mask-generator notebook.
COLAB_BADGE_HTML = """
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
"""
badge = HTML(COLAB_BADGE_HTML)
display(badge)

## Environment Set-up

If running locally in Jupyter, first install `segment_anything` in your environment by following the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running in Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit' → 'Notebook settings' → 'Hardware accelerator'.

In [11]:
import sys

# Detect Colab instead of hard-coding True: a hard-coded True makes a local run
# pip-install packages and download the large ViT-H checkpoint in the next cell.
# `google.colab` should already be in sys.modules inside a Colab kernel (common
# detection idiom). Override with an explicit True/False if this misdetects.
using_colab = "google.colab" in sys.modules

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

    !mkdir images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg

    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

PyTorch version: 2.6.0+cu124
Torchvision version: 0.21.0+cu124
CUDA is available: True
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-su45qfi4
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-su45qfi4
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf


## Set-up

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
sys.path.append("..")
import hashlib
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor


In [None]:
# cv2.imread returns None (it does not raise) for a missing or unreadable file;
# passing None to cvtColor would fail with an opaque OpenCV assertion, so check first.
image1_bgr = cv2.imread('/content/image_A.jpg')
if image1_bgr is None:
    raise ValueError("Image not found at /content/image_A.jpg")
image1 = cv2.cvtColor(image1_bgr, cv2.COLOR_BGR2RGB)

image2_bgr = cv2.imread('/content/image_B.jpg')
if image2_bgr is None:
    raise ValueError("Image not found at /content/image_B.jpg")
image2 = cv2.cvtColor(image2_bgr, cv2.COLOR_BGR2RGB)

In [None]:
# Show each input image on its own large, axis-free figure.
for preview_image in (image1, image2):
    plt.figure(figsize=(20, 20))
    plt.imshow(preview_image)
    plt.axis('off')
    plt.show()

In [None]:

def load_image_rgb(path):
    """Read the image at `path` with OpenCV and return it in RGB channel order.

    Raises:
        ValueError: if cv2.imread could not read the file (it returns None).
    """
    bgr_pixels = cv2.imread(path)
    if bgr_pixels is not None:
        # OpenCV decodes to BGR; matplotlib and SAM here are fed RGB.
        return cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    raise ValueError(f"Image not found at {path}")

def mask_hash(mask):
    """Return an MD5 hex digest identifying a mask by its shape and pixel values.

    The mask is cast to uint8 first, so a bool mask and its 0/1 uint8 copy hash
    the same. The shape is folded into the digest: raw bytes alone would make,
    e.g., an all-zero 2x3 mask and an all-zero 3x2 mask collide.

    This is a content fingerprint for deduplication, not a security hash.
    """
    mask_u8 = np.asarray(mask).astype(np.uint8, copy=False)
    digest = hashlib.md5(repr(mask_u8.shape).encode("ascii"))
    digest.update(mask_u8.tobytes())
    return digest.hexdigest()

def generate_masks(image, sam_model, **generator_kwargs):
    """Run SAM's automatic mask generator on one RGB image and return its masks.

    Args:
        image: RGB image array as produced by `load_image_rgb`.
        sam_model: a loaded SAM model (e.g. from `sam_model_registry`).
        **generator_kwargs: optional overrides for any `SamAutomaticMaskGenerator`
            setting. Unspecified settings keep the defaults below, so existing
            two-argument calls behave exactly as before.

    Returns:
        Whatever `SamAutomaticMaskGenerator.generate` returns. Downstream code in
        this notebook reads the 'segmentation' and 'area' keys of each record.
    """
    settings = {
        "points_per_side": 5,               # sparse prompt grid along each side
        "pred_iou_thresh": 0.95,            # keep only high predicted-quality masks
        "stability_score_thresh": 0.95,     # keep only very stable masks
        "crop_n_layers": 1,                 # also run on one layer of image crops
        "crop_n_points_downscale_factor": 1,
        "min_mask_region_area": 8,          # post-process away tiny regions/holes
    }
    settings.update(generator_kwargs)
    mask_generator = SamAutomaticMaskGenerator(model=sam_model, **settings)
    return mask_generator.generate(image)

def build_mask_dict(masks):
    """Index mask records by the content hash of their segmentation.

    Each record's 'segmentation' is cast to uint8 and stored under
    `mask_hash(segmentation)`. Masks with identical pixels collapse to a single
    entry, and the last such mask's array is the one kept.
    """
    segmentations = (record['segmentation'].astype(np.uint8) for record in masks)
    return {mask_hash(seg): seg for seg in segmentations}

def show_anns(image, anns, alpha=0.35, rng=None):
    """Overlay mask records on the current matplotlib axes, one random colour each.

    Args:
        image: unused; kept so existing calls keep working. The caller is expected
            to have drawn the image already (e.g. with `plt.imshow`).
        anns: list of mask records with a 'segmentation' (H x W mask) and an 'area'.
        alpha: opacity applied to mask pixels (default 0.35, the previous value).
        rng: optional `numpy.random.Generator` for reproducible colours. When None,
            colours come from the global `np.random` state, as before.

    Returns None. Does nothing for an empty `anns`.
    """
    if len(anns) == 0:
        return
    ax = plt.gca()
    # Keep the axes limits set by the underlying image.
    ax.set_autoscale_on(False)
    # Largest first so smaller masks are drawn later, on top, and stay visible.
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        seg = ann['segmentation']
        color = rng.random(3) if rng is not None else np.random.random((1, 3))[0]
        # RGBA overlay: constant colour, opacity `alpha` on mask pixels, 0 elsewhere.
        overlay = np.empty((seg.shape[0], seg.shape[1], 4))
        overlay[:, :, :3] = color
        overlay[:, :, 3] = seg * alpha
        ax.imshow(overlay)

sam_checkpoint = "/content/sam_vit_h_4b8939.pth"
# Hard-coding "cuda" crashes on CPU-only runtimes; fall back to CPU there
# (works, but expect ViT-H mask generation to be very slow).
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
sam.to(device=device)


# Load both images (A first, then B) and run the same SAM pipeline on each.
img_A, img_B = (load_image_rgb(p) for p in ("/content/image_A.jpg", "/content/image_B.jpg"))

masks_A, masks_B = (generate_masks(img, sam) for img in (img_A, img_B))

mask_dict_A, mask_dict_B = map(build_mask_dict, (masks_A, masks_B))

# Keys are pixel-exact content hashes, so a mask of B counts as "new or changed"
# unless A contains a mask with exactly the same pixels; one differing pixel is enough.
new_keys = mask_dict_B.keys() - mask_dict_A.keys()
print(f"New or changed objects in Image B: {len(new_keys)}")


# One figure per image: the image with its SAM masks overlaid.
for panel_title, panel_image, panel_masks in (
    ("SAM Segmentations - Image A", img_A, masks_A),
    ("SAM Segmentations - Image B", img_B, masks_B),
):
    plt.figure(figsize=(10, 10))
    plt.imshow(panel_image)
    show_anns(panel_image, panel_masks)
    plt.title(panel_title)
    plt.axis('off')
    plt.show()

# A set of str keys iterates in an order that can change between interpreter runs
# (str hashing is randomised), which would renumber "New Object N" on every run.
# Order by mask area (largest first), breaking ties by key, for stable numbering.
ordered_new_keys = sorted(new_keys, key=lambda k: (-int(mask_dict_B[k].sum()), k))
for i, key in enumerate(ordered_new_keys):
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(mask_dict_B[key], cmap='gray')
    ax.set_title(f"New Object {i+1}")
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when there are many new masks

