In [5]:
%load_ext autoreload
%autoreload 2
import torch
import torchvision
import coremltools as ct
import sys
import os
sys.path.append('..')
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

scikit-learn version 1.2.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.


In [2]:
# Checkpoint and architecture used for the prompt -> mask export below.
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Directory holding the checkpoint file. Set SAM_CHECKPOINT_DIR to point at
# another location instead of editing the notebook; defaults to the cwd.
checkpoint_dir = os.environ.get("SAM_CHECKPOINT_DIR", ".")

The script `segment-anything/scripts/export_onnx_model.py` can be used to export the necessary portion of SAM. Alternatively, run the following code to export an ONNX model. If you have already exported a model, set the path below and skip to the next section. Ensure that the exported ONNX model matches the checkpoint and model type set above. Note: the upstream SAM example expects the model to be exported with `return_single_mask=True`, but the export code below uses `return_single_mask=False`.

In [3]:
class MyModelWrapper(torch.nn.Module):
    """Expose only part of a wrapped model's output for tracing / conversion.

    Every positional input is forwarded unchanged to ``model``; of the
    sequence it returns, only the elements at index 1 and 2 are kept and
    returned as a 2-tuple (index 0 is dropped).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *x):
        result = self.model(*x)
        # The Core ML conversion later in this notebook names these two
        # outputs "iou_predictions" and "low_res_masks".
        second, third = result[1], result[2]
        return second, third


In [4]:
import warnings

onnx_model_path = "sam_onnx_example.onnx"
sam = sam_model_registry[model_type](checkpoint=os.path.join(checkpoint_dir, checkpoint))
# Note: exported with return_single_mask=False (the upstream SAM example uses True).
onnx_model = SamOnnxModel(sam, return_single_mask=False)

# Only the number of prompt points may vary between calls; every axis not
# listed here is exported with the fixed size of the dummy inputs below.
dynamic_axes = {
    "point_coords": {1: "num_points"},
    "point_labels": {1: "num_points"},
}

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
# The mask prompt is 4x the embedding grid in each spatial dimension.
mask_input_size = [4 * x for x in embed_size]
# Random example inputs used only to fix shapes/dtypes for export and tracing.
# Dict order matters: it defines the positional argument order of the model.
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
# The model returns three tensors (MyModelWrapper keeps indices 1 and 2 as
# iou_predictions / low_res_masks). ONNX assigns output names positionally,
# so all three must be named; with only two names, "iou_predictions" would
# label the upscaled masks and "low_res_masks" the scores.
output_names = ["masks", "iou_predictions", "low_res_masks"]

with warnings.catch_warnings():
    # Tracing emits expected TracerWarnings/UserWarnings; silence them for both
    # the ONNX export and the TorchScript trace below.
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    # TorchScript trace of the two-output wrapper; this trace (not the ONNX
    # file) is what the next cell converts to Core ML.
    wrapper_model = MyModelWrapper(onnx_model)
    trace = torch.jit.trace(wrapper_model.eval(), tuple(dummy_inputs.values()))

# ONNX quantization (onnxruntime.quantization.quantize_dynamic) is deliberately
# skipped: the quantized model would never be loaded in this notebook.

verbose: False, log level: Level.ERROR



  i = torch.as_tensor(torch.arange(x.size()[0] * n, device=x.device) // n, dtype=torch.int32)
  attn = attn / math.sqrt(c_per_head)


In [None]:
# Core ML input specs mirror the traced example inputs one-to-one:
# same names, same fixed shapes, same order.
coreml_inputs = {
    name: ct.TensorType(name=name, shape=example.size())
    for name, example in dummy_inputs.items()
}
# The two tensors returned by MyModelWrapper, in order.
coreml_outputs = {
    name: ct.TensorType(name=name)
    for name in ("iou_predictions", "low_res_masks")
}
model = ct.convert(
    trace,
    inputs=list(coreml_inputs.values()),
    outputs=list(coreml_outputs.values()),
    minimum_deployment_target=ct.target.iOS15,
)

In [None]:
# Persist the embeddings-to-masks Core ML model converted above.
mask_model_package_path = "sam.mlpackage"
model.save(mask_model_package_path)

Everything below this point exports the image-embedding part of SAM (image → `image_embeddings`) directly from PyTorch to Core ML, instead of the embeddings-to-masks part exported above via the ONNX-export wrapper (`SamOnnxModel`).

The final cells test the PyTorch model on a single image with the automatic mask generator, which does not need any prompts (points or bounding boxes).

In [8]:
# Example input for the embedder: a single 3xHxW float image, already in the
# format the model takes. Pixel values are whole numbers in [0, 255)
# (randint's upper bound is exclusive).
dummy_inputs_for_sam = dict(
    image=torch.randint(0, 255, (3, 1024, 1024), dtype=torch.float),
)

dummy_inputs_for_sam['image'].shape

torch.Size([3, 1024, 1024])

In [5]:



from segment_anything.utils.coreml import SamEmbedder

# The Core ML embedder export uses the ViT-B checkpoint. This rebinds the
# checkpoint settings from the earlier cell; switch to "sam_vit_h_4b8939.pth" /
# "vit_h" to export the ViT-H variant instead.
checkpoint = 'sam_vit_b_01ec64.pth'
model_type = 'vit_b'

# Directory holding the checkpoint file. Set SAM_CHECKPOINT_DIR to point at
# another location instead of editing the notebook; defaults to the cwd.
checkpoint_dir = os.environ.get('SAM_CHECKPOINT_DIR', '.')

class MyModelWrapper3(torch.nn.Module):
    """Pass-through module wrapper.

    Forwards every positional input to ``model`` and returns its output
    unchanged; it exists only as a hook point for adapting outputs before
    tracing.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *x):
        return self.model(*x)


# Build SAM from the checkpoint configured above and wrap its embedder.
# NOTE: this rebinds `sam`, which the automatic mask generator cell reuses.
sam = sam_model_registry[model_type](checkpoint=os.path.join(checkpoint_dir, checkpoint))
mymodel3 = MyModelWrapper3(SamEmbedder(sam))

# Tracing input: one 3x1024x1024 float image with whole-number values in [0, 255).
dummy_inputs_for_sam = dict(
    image=torch.randint(0, 255, (3, 1024, 1024), dtype=torch.float),
)

sam_traced_model = torch.jit.trace(mymodel3.eval(), tuple(dummy_inputs_for_sam.values()))

# Core ML I/O: one image input shaped like the tracing example, one embeddings
# output. NOTE: this rebinds the coreml_inputs/coreml_outputs used for the mask model.
coreml_inputs = dict(
    image=ct.ImageType(name='image', shape=dummy_inputs_for_sam['image'].shape, channel_first=True),
)
coreml_outputs = dict(
    image_embeddings=ct.TensorType(name="image_embeddings"),
)

# NOTE(review): in the recorded run this conversion failed with
# "RuntimeError: BlobWriter not loaded", which appears to be a coremltools
# installation/environment problem rather than a model problem — confirm.
embedder_model = ct.convert(
    sam_traced_model,
    inputs=list(coreml_inputs.values()),
    outputs=list(coreml_outputs.values()),
    minimum_deployment_target=ct.target.iOS15,
)

embedder_model.save("sambedder.mlpackage")

  and input_image_torch.shape[1] == 3
  and max(*input_image_torch.shape[2:]) == self.not_model.image_encoder.img_size
  if pad_h > 0 or pad_w > 0:
  max_rel_dist = int(2 * max(q_size, k_size) - 1)
  max_rel_dist = int(2 * max(q_size, k_size) - 1)
  if rel_pos.shape[0] != max_rel_dist:
  q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  if Hp > H or Wp > W:
Converting PyTorch Frontend ==> MIL Ops:   0%|                                                                                                                                     | 0/1875 [00:00<?, ? ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Saving value type of int64 into a builtin type of int32, might lose precision!
Saving value type of int64 into a builtin type of int32, might lose precision!
Saving value type of

RuntimeError: BlobWriter not loaded

In [None]:
import cv2
from segment_anything import SamAutomaticMaskGenerator

# Test image for the automatic mask generator. Set SAM_TEST_IMAGE to use your
# own file instead of editing this machine-specific absolute default.
image_path = os.environ.get('SAM_TEST_IMAGE', '/Users/anatoli/Downloads/Untitled_anatoli.jpeg')
image = cv2.imread(image_path)
# cv2.imread signals an unreadable/missing file by returning None instead of
# raising; fail here with a clear message rather than inside cvtColor.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV decodes to BGR channel order; convert to RGB before mask generation.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

In [None]:
from matplotlib import pyplot as plt
import numpy as np
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent colored overlay.

    The last two dimensions of ``mask`` are taken as (height, width). Each
    pixel is multiplied by an RGBA color with alpha 0.6, and the resulting
    image is passed to ``ax.imshow``. The color is a fixed blue unless
    ``random_color`` is True, in which case its RGB part comes from
    ``np.random.random``.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.append(np.random.random(3), 0.6)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as white-edged stars.

    Points whose label is 1 are drawn in green, then points whose label is 0
    in red. Points with any other label are not drawn. ``coords`` is indexed
    as ``coords[mask][:, 0]`` / ``[:, 1]`` for x / y.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Outline a box given as ``(x0, y0, x1, y1)`` on ``ax``.

    Draws a green rectangle with a 2px edge and fully transparent fill,
    anchored at (x0, y0) with width x1 - x0 and height y1 - y0.
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle((left, top), width, height, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)


In [None]:
# One figure per generated mask: the input image with that mask overlaid.
for i, mask in enumerate(masks):
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    ax.axis('off')
    show_mask(mask['segmentation'], ax)
    ax.set_title(f"Mask {i + 1} of {len(masks)}")
    plt.show()
    # Release each figure explicitly so many masks don't keep many figures alive.
    plt.close(fig)