## Image segmentation with SAM 3

This notebook demonstrates how to use SAM 3 to segment images with text prompts, visual (bounding-box) prompts, or both together. It covers the following capabilities:

- **Text prompts**: Using natural language descriptions to segment objects (e.g., "person", "face")
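- **Box prompts**: Using bounding boxes as exemplar visual prompts
- **Combined prompts**: Pairing a text prompt with a box prompt — the batch loop below segments each image with the text "nose from basal view" plus that image's annotated bounding box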

In [1]:
import os
import sys
sys.path.insert(0, "/home/groups/sammer/haogeh/util/models/sam3/")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [2]:
import sam3
from PIL import Image
from sam3 import build_sam3_image_model
from sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results

# Root of the sam3 checkout: the parent of the directory holding the `sam3`
# package. The asset paths used further down are built relative to it.
_sam3_pkg_dir = os.path.dirname(sam3.__file__)
sam3_root = os.path.join(_sam3_pkg_dir, os.pardir)

In [None]:
# Dev-only: re-execute sam3/__init__.py to pick up edits to the local sam3 checkout.
# NOTE(review): importlib.reload only re-runs the top-level package module. Submodules
# that are already imported (e.g. sam3.model.sam3_image_processor) are not reloaded.
# Names bound earlier via `from sam3... import ...` (build_sam3_image_model,
# Sam3Processor, ...) keep pointing at the old objects. Consider
# `%load_ext autoreload` + `%autoreload 2` in the setup cell instead.
import importlib
importlib.reload(sam3)

In [4]:
import torch

# turn on tfloat32 for Ampere GPUs
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# use bfloat16 for the entire notebook
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

## Build the model

In [None]:
# Build the SAM 3 image model from the BPE vocabulary file and the checkpoint
# stored under the repository's assets/ directory.
bpe_path = f"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz"
checkpoint_path = f"{sam3_root}/assets/checkpoint/sam3.pt"
model = build_sam3_image_model(
    bpe_path=bpe_path,
    checkpoint_path=checkpoint_path,
)

In [6]:
# Basal-view image dataset; the per-image bounding boxes live in a pickle inside it.
dataset_path = f"{sam3_root}/assets/images/dataset_basal_view_consented"
bbox_file = os.path.join(
    dataset_path,
    "df_with_bbox.pkl",
)

In [7]:
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
bbox_df = pd.read_pickle(bbox_file)

# Fail fast, with a clear message, if a column read by the segmentation loop is missing.
_required_cols = {"save_path_rel", "bbox", "mrn"}
_missing_cols = _required_cols - set(bbox_df.columns)
assert not _missing_cols, f"{bbox_file} is missing required columns: {sorted(_missing_cols)}"

In [8]:
def xyxy_to_xywh(bbox):
    """Convert a box from corner format to origin-plus-size format.

    Args:
        bbox: 4-element sequence ``(x1, y1, x2, y2)``, i.e. the top-left and
            bottom-right corners.

    Returns:
        Tuple ``(x, y, w, h)``: ``(x, y)`` is the top-left corner,
        ``w = x2 - x1`` and ``h = y2 - y1``.
    """
    left, top, right, bottom = bbox
    box_w, box_h = right - left, bottom - top
    return left, top, box_w, box_h

In [10]:
# Batch segmentation: for every row of `bbox_df`, prompt SAM 3 with a text
# description plus that row's bounding box, then write two files next to each
# source image:
#   <stem>_result.jpg : image with the prompt box (cyan) and mask (green) overlaid
#   <stem>_mask.npz   : bbox_xyxy, bbox_xywh, mask, mask_logit, score
# Reload a saved result with: np.load(mask_path)


def _output_paths(image_path):
    """Return ``(result_image_path, mask_path)`` for the outputs saved beside ``image_path``.

    ``os.path.splitext`` strips only the final extension. File names that
    contain extra dots (e.g. ``a.b.jpg``) therefore keep distinct output names
    instead of collapsing onto ``a_result.jpg``.
    """
    stem = os.path.splitext(image_path)[0]
    return stem + '_result.jpg', stem + '_mask.npz'


def _save_overlay(image, bbox_xyxy, mask, title, out_path):
    """Save ``image`` with the prompt box and a translucent mask overlay to ``out_path``.

    Args:
        image: PIL image that was segmented.
        bbox_xyxy: prompt box as ``(x1, y1, x2, y2)`` in pixel coordinates.
        mask: 2-D array; truthy pixels are painted green at 50% opacity.
        title: figure title.
        out_path: destination image file.
    """
    x1, y1, x2, y2 = bbox_xyxy
    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   edgecolor='cyan', facecolor='none', linewidth=1))
        mask_np = np.asarray(mask).astype(bool)
        overlay = np.zeros((mask_np.shape[0], mask_np.shape[1], 4), dtype=float)
        overlay[mask_np] = [0.0, 1.0, 0.0, 0.5]  # RGBA; only mask pixels get color
        ax.imshow(overlay)
        ax.set_title(title)
        ax.axis('off')
        fig.savefig(out_path, bbox_inches='tight', pad_inches=2)
    finally:
        # Always release the figure; figures left open in a loop accumulate memory.
        plt.close(fig)


# The processor is configured only with `model` and the threshold; per-image
# state is passed around explicitly as `inference_state`. Build it once
# instead of once per image.
processor = Sam3Processor(model, confidence_threshold=0.5)

for _, row in bbox_df.iterrows():
    image_path = os.path.join(dataset_path, row['save_path_rel'])
    result_image_path, mask_path = _output_paths(image_path)

    # Close the file handle immediately (convert() loads the pixels). Normalise to
    # 3-channel RGB so grayscale/RGBA files are handled the same way by the model
    # and by the overlay plot.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    width, height = image.size

    bbox = row['bbox']
    x1, y1, x2, y2 = bbox
    x, y, w, h = xyxy_to_xywh(bbox)
    # The box prompt is passed to the processor as a normalised (cx, cy, w, h) list.
    box_input_xywh = torch.tensor([x, y, w, h]).view(-1, 4)
    box_input_cxcywh = box_xywh_to_cxcywh(box_input_xywh)
    norm_box_cxcywh = normalize_bbox(box_input_cxcywh, width, height).flatten().tolist()

    inference_state = processor.set_image(image)
    processor.reset_all_prompts(inference_state)
    inference_state = processor.set_text_prompt(state=inference_state, prompt="nose from basal view")
    inference_state = processor.add_geometric_prompt(
        state=inference_state, box=norm_box_cxcywh, label=True
    )

    # Assumes one score per detection, aligned with the first axis of 'masks'.
    scores = inference_state['scores'].float().cpu().numpy()
    if scores.shape[0] == 0:
        # Nothing cleared confidence_threshold. Indexing [0] would raise IndexError
        # and abort the whole batch, so report the miss and move on.
        print(f"No detection above threshold for {image_path}; skipping")
    else:
        # Keep the highest-scoring detection. If the processor already returns
        # detections sorted by score, this is index 0, as before.
        best = int(np.argmax(scores))
        mask = inference_state['masks'][best, 0].cpu().numpy()
        mask_logit = inference_state['masks_logits'][best, 0].cpu().numpy()
        score = scores[best]

        save_dict = {
            'bbox_xyxy': (x1, y1, x2, y2),
            'bbox_xywh': (x, y, w, h),
            'mask': mask,
            'mask_logit': mask_logit,
            'score': score,
        }

        # NOTE(review): the title embeds the patient's MRN in every saved figure.
        # Make sure result images are kept under the same access controls as the data.
        _save_overlay(image, (x1, y1, x2, y2), mask, f"MRN: {row['mrn']}", result_image_path)
        np.savez(mask_path,
                 **{k: v.cpu().numpy() if isinstance(v, torch.Tensor) else v
                    for k, v in save_dict.items()})

    # Drop this image's tensors and return cached GPU memory before the next image.
    del inference_state
    torch.cuda.empty_cache()