In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Produces masks from prompts using an ONNX model

SAM's prompt encoder and mask decoder are very lightweight, which allows for efficient computation of a mask given user input. This notebook shows an example of how to export and use this lightweight component of the model in ONNX format, allowing it to run on a variety of platforms that support an ONNX runtime.

## Environment Set-up

If running locally in Jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. The latest stable versions of PyTorch and ONNX are recommended for this notebook. If running in Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'.

In [None]:
# Set True when running in Google Colab: the next cell then prints the runtime versions,
# installs the dependencies and downloads the ViT-H SAM checkpoint.
using_colab = True

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib onnx onnxruntime
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
    
    !mkdir images
    # !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
        
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
    !pip install piexif

PyTorch version: 2.0.0+cu118
Torchvision version: 0.15.1+cu118
CUDA is available: True
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting onnx
  Downloading onnx-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.5/13.5 MB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting onnxruntime
  Downloading onnxruntime-1.14.1-cp39-cp39-manylinux_2_27_x86_64.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Collecting coloredlogs
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [3

In [None]:
# IPython autoreload: reload changed imported modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from google.colab import drive

# The input image folder and all output folders used below live under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


## Set-up

Note that this notebook requires both the `onnx` and `onnxruntime` optional dependencies, in addition to `opencv-python` and `matplotlib` for visualization.

In [None]:
import torch
import numpy as np
import piexif
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel

import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

In [None]:
def show_mask(mask, ax, color=None):
    """Overlay a binary mask on a matplotlib axes.

    Args:
        mask: Boolean (or 0/1) array whose last two dimensions are (H, W). Any
            leading dimensions must have size 1, because the mask is reshaped to
            (H, W, 1) - e.g. the (1, 1, H, W) boolean output of the ONNX model.
        ax: Axes to draw on (only ``ax.imshow`` is used).
        color: Optional RGBA sequence with components in [0, 1]. Defaults to
            translucent white (1, 1, 1, 0.7), the previously hard-coded colour.
    """
    if color is None:
        color = (1.0, 1.0, 1.0, 0.7)
    rgba = np.asarray(color, dtype=float)
    h, w = mask.shape[-2:]
    # Broadcast each pixel's mask value against the colour: masked pixels get
    # the colour, unmasked pixels become fully transparent zeros.
    mask_image = mask.reshape(h, w, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    """Draw prompt points on ``ax``.

    Positive points (label 1) are drawn as green stars and negative points
    (label 0) as red stars; points with any other label are not drawn.
    Previously the scatter calls were commented out, so this function computed
    the point sets and then drew nothing.

    Args:
        coords: (N, 2) array of (x, y) pixel coordinates.
        labels: Length-N array of point labels.
        ax: Axes to draw on (only ``ax.scatter`` is used).
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    coords = np.asarray(coords)
    labels = np.asarray(labels)
    for label, colour in ((1, 'green'), (0, 'red')):
        points = coords[labels == label]
        if len(points):
            ax.scatter(points[:, 0], points[:, 1], color=colour, marker='*',
                       s=marker_size, edgecolor='white', linewidth=1.25)
    
def show_box(box, ax):
    """Outline a box given as (x0, y0, x1, y1) with an unfilled green rectangle."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

## Export an ONNX model

Set the path below to a SAM model checkpoint, then load the model. The model is needed both to export the ONNX model and to calculate image embeddings.

In [None]:
# SAM backbone variant; the checkpoint file below must contain weights for it.
model_type = "vit_h"
checkpoint = "sam_vit_h_4b8939.pth"

In [None]:
# Look up the builder for the chosen variant and load its weights from `checkpoint`.
build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=checkpoint)

The script `segment-anything/scripts/export_onnx_model.py` can be used to export the necessary portion of SAM. Alternatively, run the following code to export an ONNX model. If you have already exported a model, set the path below and skip to the next section. Ensure that the exported ONNX model matches the checkpoint and model type set above. This notebook expects the model to have been exported with the parameter `return_single_mask=True`.

In [None]:
# Path to an ONNX model exported earlier; None means no pre-exported model is being reused.
onnx_model_path = None  # Set to use an already exported model, then skip to the next section.

In [None]:
import warnings

# Export only when no model path was supplied in the previous cell, so setting
# `onnx_model_path` there really reuses that model instead of being overwritten.
# globals().get() also covers a fresh kernel in which that cell was skipped.
if globals().get("onnx_model_path") is None:
    onnx_model_path = "sam_onnx_example.onnx"

    # This notebook expects a single mask per prompt (return_single_mask=True).
    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    # Only the number of prompt points may vary between calls.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    # Example inputs used to trace the model; their keys become the ONNX input names.
    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]  # mask prompt is 4x the embedding resolution
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    output_names = ["masks", "iou_predictions", "low_res_masks"]

    # Silence tracer/user warnings emitted during export.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(onnx_model_path, "wb") as f:
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=17,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

verbose: False, log level: Level.ERROR



If desired, the model can additionally be quantized and optimized. We find this significantly improves web runtime with a negligible change in qualitative performance. Run the next cell to quantize the model, or otherwise skip to the next section.

In [None]:
# Dynamic quantization settings: unsigned 8-bit weights, per-tensor, full range,
# with graph optimization enabled.
quantization_options = dict(
    weight_type=QuantType.QUInt8,
    per_channel=False,
    reduce_range=False,
    optimize_model=True,
)

onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"
quantize_dynamic(
    model_input=onnx_model_path,
    model_output=onnx_model_quantized_path,
    **quantization_options,
)

# From here on, every cell loads the quantized model.
onnx_model_path = onnx_model_quantized_path

Ignore MatMul due to non constant B: /[/transformer/layers.0/self_attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/layers.0/self_attn/MatMul_1]
Ignore MatMul due to non constant B: /[/transformer/layers.0/cross_attn_token_to_image/MatMul]
Ignore MatMul due to non constant B: /[/transformer/layers.0/cross_attn_token_to_image/MatMul_1]
Ignore MatMul due to non constant B: /[/transformer/layers.0/cross_attn_image_to_token/MatMul]
Ignore MatMul due to non constant B: /[/transformer/layers.0/cross_attn_image_to_token/MatMul_1]
Ignore MatMul due to non constant B: /[/transformer/layers.1/self_attn/MatMul]
Ignore MatMul due to non constant B: /[/transformer/layers.1/self_attn/MatMul_1]
Ignore MatMul due to non constant B: /[/transformer/layers.1/cross_attn_token_to_image/MatMul]
Ignore MatMul due to non constant B: /[/transformer/layers.1/cross_attn_token_to_image/MatMul_1]
Ignore MatMul due to non constant B: /[/transformer/layers.1/cross_attn_image_to_token/MatMul]
Ignore Ma

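## Example Images

The cells below run the ONNX model over every PNG frame in a Google Drive folder and save the segmented results back to Drive.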

In [None]:
import glob as gl

# Input frames to segment (on the mounted Google Drive).
IMAGE_GLOB = '/content/drive/MyDrive/NeRF Supervision/nerf-supervision-public/data/fork/images/*.png'

# glob returns paths in arbitrary file-system order; sort once so every loop
# below visits the frames in the same deterministic order and the counter-based
# output file names line up with the input frames.
files = sorted(gl.glob(IMAGE_GLOB))

In [None]:
# Segment every frame with the same prompts and save a copy of the frame in
# which the segmented pixels are blacked out.

# Loop-invariant set-up: build the ONNX session and the image predictor once
# instead of once per image (set_image() below re-targets the predictor).
ort_session = onnxruntime.InferenceSession(onnx_model_path)
sam.to(device='cuda')
predictor = SamPredictor(sam)

# Prompts shared by every frame: one point labelled 0 plus a box whose two
# corners are labelled 2 and 3.
# NOTE(review): the box assumes 4032x3024 frames - derive it from image.shape if not.
input_box = np.array([0, 0, 4032, 3024])
input_point = np.array([[575, 750]])
input_label = np.array([0])

onnx_box_coords = input_box.reshape(2, 2)
onnx_box_labels = np.array([2, 3])

# Prompt coordinates in original-image pixels; rescaled per image inside the loop.
onnx_coord_pixels = np.concatenate([input_point, onnx_box_coords], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, onnx_box_labels], axis=0)[None, :].astype(np.float32)

# No previous mask is fed back in.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

for counter, item in enumerate(sorted(files), start=1):
    print(item)
    image = cv2.imread(item)
    if image is None:
        raise FileNotFoundError(f"cv2 could not read image: {item}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Only the metadata is needed from PIL; close the file handle right away.
    with Image.open(item) as im:
        metadata = dict(im.info)

    predictor.set_image(image)
    image_embedding = predictor.get_image_embedding().cpu().numpy()

    # Map the prompt coordinates into the model's resized-input frame for this image.
    onnx_coord = predictor.transform.apply_coords(onnx_coord_pixels, image.shape[:2]).astype(np.float32)

    ort_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    }

    masks, _, _ = ort_session.run(None, ort_inputs)
    masks = masks > predictor.model.mask_threshold

    # Black out the masked pixels in all three colour channels.
    repeat_mask = masks[0, 0, :, :, np.newaxis].repeat(3, axis=2)
    masked_img = np.where(repeat_mask, 0, image)
    plt.imshow(masked_img)
    plt.show()

    modified_img = Image.fromarray(np.uint8(masked_img))
    modified_img.info.update(metadata)
    modified_img.save('/content/drive/MyDrive/NeRF Supervision/Segmented_fork/seg_fork/{:4d}.png'.format(2000 + counter))

Output hidden; open in https://colab.research.google.com to view.

In [None]:
counter = 0

for item in files:
  counter += 1
  image = cv2.imread(item)
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  im = Image.open(item)
  metadata = im.info

  ort_session = onnxruntime.InferenceSession(onnx_model_path)
  sam.to(device='cuda')
  predictor = SamPredictor(sam)
  predictor.set_image(image)
  image_embedding = predictor.get_image_embedding().cpu().numpy()

  input_point = np.array([[500, 375]])
  input_label = np.array([1])

  onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
  onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

  onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
  onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
  onnx_has_mask_input = np.zeros(1, dtype=np.float32)
  ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
  }

  masks, _, low_res_logits = ort_session.run(None, ort_inputs)
  masks = masks > predictor.model.mask_threshold
  plt.figure(figsize=(40,30))
  plt.imshow(image)
  show_mask(masks, plt.gca())
  # show_points(input_point, input_label, plt.gca())
  plt.axis('off')
  plt.savefig('/content/drive/MyDrive/NeRF Supervision/Segmented_fork/realone/{:4d}.png'.format(counter), pil_kwargs={'icc_profile': im.info.get('icc_profile')})
  plt.show() 

Output hidden; open in https://colab.research.google.com to view.