In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install opencv-python matplotlib onnxruntime onnx
!pip install wget

In [None]:
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from segment_anything.utils.onnx import SamOnnxModel
import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

import torch
from copy import deepcopy

In [None]:
# Download the ViT-B checkpoint once (uncomment to fetch it into the working directory):
# !python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
# The checkpoint directory can be overridden with SAM_CHECKPOINT_DIR instead of
# editing a user-specific absolute path.
sam = sam_model_registry["vit_b"](
    checkpoint=os.path.join(
        os.environ.get("SAM_CHECKPOINT_DIR", R"C:\Users\mmoller\Downloads"),
        "sam_vit_b_01ec64.pth",
    )
)


In [None]:
# Load the test image as RGB. The path can be overridden with SAM_IMAGE_PATH.
# `convert` reads the pixel data, so the `with` block can close the file right away.
image_path = os.environ.get(
    "SAM_IMAGE_PATH",
    R"C:\Users\mmoller\OneDrive - NVIDIA Corporation\Pictures\Camera Roll\WIN_20240318_17_50_49_Pro.jpg",
)
with Image.open(image_path) as image_file:
    image = image_file.convert("RGB")

In [None]:
class ImageTensor:
    """An image plus the geometry recorded while preparing it for SAM.

    ``image`` is expected to be a PIL image (anything whose ``.size`` is
    ``(width, height)``). Preprocessing may later replace it in place with a
    numpy array. ``orig_width`` / ``orig_height`` capture the size at
    construction time. The ``resized_*`` and ``pad_*`` fields start as None
    until a preprocessing step records them.
    """

    def __init__(self, image):
        self.image = image
        self.orig_width, self.orig_height = image.size
        self.resized_width, self.resized_height = None, None
        self.pad_width, self.pad_height = None, None

    def size(self):
        """Return ``(width, height)`` of the current ``image``.

        PIL images report their own ``.size``. For numpy arrays, ``.size``
        would be the element count, so the shape is read instead:
        4-D arrays as NCHW, 3-D arrays as HWC, 2-D arrays as HW.
        """
        shape = getattr(self.image, "shape", None)
        if shape is None:
            return self.image.size
        if len(shape) == 4:
            height, width = shape[2], shape[3]
        elif len(shape) in (2, 3):
            height, width = shape[0], shape[1]
        else:
            raise ValueError(f"cannot infer width/height from array of shape {shape}")
        return width, height


In [None]:
# Wrap the PIL image so preprocessing can record its original/resized geometry.
image_object = ImageTensor(image=image)

In [None]:
class ImagePreprocessor:
    """Turn an ``ImageTensor`` holding a PIL image into a SAM image-encoder input.

    ``from_image_to_input`` runs these steps in order:
    1. RGB conversion.
    2. Resize so the long side equals ``long_side_max``.
    3. Per-channel ``(x - mean) / std`` normalization.
    4. Conversion to a float32 NCHW batch of one.
    5. Optional zero-padding on the bottom/right to a square.

    Every step updates the ``ImageTensor`` in place and returns it. The
    normalize, tensor and pad steps also accept a bare numpy array.
    """

    def __init__(self, long_side_max=1024, mean=None, std=None, image_format="RGB", pad_to_square=True):
        # long_side_max: target length of the longer side; None disables resizing.
        # mean/std: per-channel (R, G, B) statistics on a 0-255 scale. The defaults
        #   are the pixel_mean/pixel_std used by segment_anything's Sam model.
        # image_format: stored only; every path in this class produces RGB.
        # pad_to_square: zero-pad the normalized NCHW tensor to H == W.
        self.long_side_max = long_side_max
        self.mean = np.array([123.675, 116.28, 103.53]) if mean is None else mean
        self.std = np.array([58.395, 57.12, 57.375]) if std is None else std
        self.image_format = image_format
        self.pad_to_square = pad_to_square

    def resize_image_to_long_side(self, img: "ImageTensor"):
        """Resize ``img.image`` (PIL) so its longer side equals ``long_side_max``.

        The target size is computed like segment_anything's
        ``ResizeLongestSide.get_preprocess_shape`` (scale, then round half up).
        It is stored on ``img.resized_width`` / ``img.resized_height``. With
        ``long_side_max=None`` the image is left unchanged, and its current size
        is recorded as the resized size so later coordinate scaling still works.
        """
        orig_width, orig_height = img.image.size
        if self.long_side_max is None:
            img.resized_width, img.resized_height = orig_width, orig_height
            return img
        scale = self.long_side_max / max(orig_width, orig_height)
        # +0.5 then int() rounds to nearest. Plain truncation can come out one
        # pixel short of SAM's own transform (e.g. 1000x333 -> 340 vs 341).
        img.resized_width = int(orig_width * scale + 0.5)
        img.resized_height = int(orig_height * scale + 0.5)
        img.image = img.image.resize((img.resized_width, img.resized_height), Image.Resampling.BILINEAR)
        return img

    def make_image_rgb(self, image):
        """Ensure ``image.image`` is an RGB PIL image; return the same ``ImageTensor``."""
        if image.image.mode != "RGB":
            # Store the converted image back on the wrapper so the next pipeline
            # step still receives an ImageTensor, not a bare PIL image.
            image.image = image.image.convert("RGB")
        return image

    def pad_image_to_square(self, image):
        """Zero-pad an NCHW array on the bottom/right so that H == W.

        Given an ``ImageTensor`` (whose ``image`` must already be an NCHW
        array), pad it in place and record ``pad_height`` / ``pad_width``.
        Padding happens after normalization, so the pad value 0 corresponds
        to the per-channel mean.
        """
        if isinstance(image, ImageTensor):
            height, width = image.image.shape[2:]
            image.image = self.pad_image_to_square(image.image)
            padded_height, padded_width = image.image.shape[2:]
            image.pad_height = padded_height - height
            image.pad_width = padded_width - width
            return image
        h, w = image.shape[2:]
        side = max(h, w)
        return np.pad(image, ((0, 0), (0, 0), (0, side - h), (0, side - w)), mode="constant", constant_values=0)

    def normalize_image(self, image):
        """Apply ``(pixel - mean) / std`` per channel (channel is the last axis, HWC).

        Accepts an ``ImageTensor`` (updated in place) or anything convertible
        with ``np.asarray``, including a PIL image.
        """
        if isinstance(image, ImageTensor):
            image.image = self.normalize_image(image.image)
            return image
        # Convert explicitly rather than relying on `PIL.Image - ndarray`
        # falling back to ndarray's reflected subtraction.
        pixels = np.asarray(image, dtype=np.float64)
        return (pixels - self.mean) / self.std

    def to_tensor(self, image):
        """Convert an HWC array into a float32 ``(1, C, H, W)`` batch."""
        if isinstance(image, ImageTensor):
            image.image = self.to_tensor(image.image)
            return image
        return image.transpose(2, 0, 1)[np.newaxis].astype(np.float32)

    def from_image_to_input(self, image):
        """Run the full preprocessing pipeline on ``image`` (mutated in place) and return it."""
        image = self.make_image_rgb(image)
        image = self.resize_image_to_long_side(image)
        image = self.normalize_image(image)
        image = self.to_tensor(image)
        if self.pad_to_square:
            image = self.pad_image_to_square(image)
        return image

In [None]:
# SAM's encoder input: long side resized to 1024, then zero-padded to a 1024x1024 square.
image_preprocessor = ImagePreprocessor(long_side_max=1024, pad_to_square=True)

In [None]:
# Build a fresh ImageTensor on every run. Preprocessing mutates its argument in
# place (the PIL image is replaced by an ndarray), so re-running this cell on
# `image_object` would feed an already-processed array back into the pipeline.
input_image = image_preprocessor.from_image_to_input(ImageTensor(image))

In [None]:
# Geometry recorded during preprocessing: resized (w, h), padding (w, h), original (w, h).
(
    input_image.resized_width,
    input_image.resized_height,
    input_image.pad_width,
    input_image.pad_height,
    input_image.orig_width,
    input_image.orig_height,
)

In [None]:
# `input_image.image` is mean/std-normalized and NCHW. Undo the normalization and
# rescale to [0, 1] so imshow shows real colours instead of clipping.
# (The zero padding maps back to the mean colour, i.e. a grey border.)
preview = input_image.image[0].transpose(1, 2, 0) * image_preprocessor.std + image_preprocessor.mean
plt.imshow(np.clip(preview / 255.0, 0.0, 1.0))
plt.axis("off");

In [None]:
# SamPredictor.set_image expects an HxWx3 uint8 RGB numpy array, not a PIL image.
predictor = SamPredictor(sam)
predictor.set_image(np.asarray(image.convert("RGB")))

In [None]:
# SamAutomaticMaskGenerator.generate expects an HxWx3 uint8 RGB numpy array, not a PIL image.
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(np.asarray(image.convert("RGB")))

In [None]:
def show_anns(anns, ax=None, alpha=0.35, seed=None):
    """Overlay mask annotations on a matplotlib axes as translucent colour layers.

    Annotations are painted largest-area first, so smaller masks end up on top.

    Args:
        anns: list of dicts, each with ``'area'`` (sortable number) and
            ``'segmentation'`` (boolean (H, W) mask), e.g. the output of
            ``SamAutomaticMaskGenerator.generate``.
        ax: axes to draw on; defaults to the current axes (``plt.gca()``).
        alpha: opacity of every mask colour.
        seed: if given, colours come from a private, seeded RNG and are
            reproducible. If None, NumPy's global RNG is used (the previous
            behaviour).

    Returns:
        None. Does nothing when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    if ax is None:
        ax = plt.gca()
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax.set_autoscale_on(False)

    # Both np.random (module) and a Generator expose `.random(size)`.
    rng = np.random if seed is None else np.random.default_rng(seed)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0  # fully transparent wherever no mask is painted
    for ann in sorted_anns:
        overlay[ann['segmentation']] = np.concatenate([rng.random(3), [alpha]])
    ax.imshow(overlay)

In [None]:
# Original image with every automatically generated mask overlaid.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image)
show_anns(masks)  # draws on the current axes, which is `ax`
ax.axis('off')
plt.show()

# Export an ONNX model

In [None]:
# ViT-H checkpoint used for the ONNX decoder export below. The directory can be
# overridden with SAM_CHECKPOINT_DIR. A raw string takes single backslashes; the
# previous R"C:\\..." literal put doubled backslashes into the path.
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
checkpoint_path = os.path.join(
    os.environ.get("SAM_CHECKPOINT_DIR", R"C:\Users\mmoller\Downloads"), checkpoint
)

In [None]:
# Look up the builder for `model_type`, then build the model with weights from `checkpoint_path`.
sam_builder = sam_model_registry[model_type]
sam = sam_builder(checkpoint=checkpoint_path)

In [None]:
# Output file for the SamOnnxModel export in the next cell.
# NOTE(review): the export cell re-assigns this same value; keep the two in sync.
onnx_model_path = "sam_onnx_example.onnx"

In [None]:
import warnings

# Export SamOnnxModel (everything downstream of the image encoder: it consumes
# precomputed `image_embeddings` plus prompts) to ONNX.
onnx_model_path = "sam_onnx_example.onnx"
onnx_model = SamOnnxModel(sam, return_single_mask=True)

# Only the number of prompt points varies between calls.
dynamic_axes = {name: {1: "num_points"} for name in ("point_coords", "point_labels")}

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [side * 4 for side in embed_size]

# Tracing inputs. The insertion order defines both the positional argument order
# and the ONNX input names.
dummy_inputs = dict(
    image_embeddings=torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    point_coords=torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    point_labels=torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    mask_input=torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    has_mask_input=torch.tensor([1], dtype=torch.float),
    orig_im_size=torch.tensor([1500, 2250], dtype=torch.float),
)
output_names = ["masks", "iou_predictions", "low_res_masks"]

with warnings.catch_warnings():
    # Silence TracerWarning/UserWarning noise emitted while tracing the export.
    for category in (torch.jit.TracerWarning, UserWarning):
        warnings.filterwarnings("ignore", category=category)
    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

# Using the local ONNX model


In [None]:
# Reload the ViT-H SAM (the checkpoint the decoder ONNX above was exported from)
# and open the exported decoder.
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
checkpoint_path = os.path.join(
    os.environ.get("SAM_CHECKPOINT_DIR", R"C:\Users\mmoller\Downloads"), checkpoint
)
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)

# GPU builds of onnxruntime (>= 1.9) require `providers` to be passed explicitly;
# prefer CUDA when this build has it, otherwise fall back to CPU.
ort_providers = [
    p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if p in onnxruntime.get_available_providers()
]
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=ort_providers)

# An unconditional `.to("cuda")` crashes on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=device)
predictor = SamPredictor(sam)


In [None]:
# set_image expects an HxWx3 uint8 RGB numpy array, not a PIL image.
predictor.set_image(np.asarray(image.convert("RGB")))

In [None]:
# generate expects an HxWx3 uint8 RGB numpy array, not a PIL image.
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(np.asarray(image.convert("RGB")))

In [None]:
# ViT-H automatic masks overlaid on the original image.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image)
show_anns(masks)  # draws on the current axes, which is `ax`
ax.axis('off')
plt.show()

# Export the image encoder to ONNX

In [None]:
import warnings

# Load the ViT-B SAM. Each SAM checkpoint ships its own prompt encoder and mask
# decoder, trained with that checkpoint's image encoder. ViT-B embeddings must
# therefore be decoded by the ViT-B decoder, but the decoder exported earlier in
# this notebook came from the ViT-H checkpoint. Both halves are exported from this one model.
vit_b_checkpoint_path = os.path.join(
    os.environ.get("SAM_CHECKPOINT_DIR", R"C:\Users\mmoller\Downloads"),
    "sam_vit_b_01ec64.pth",
)
sam = sam_model_registry["vit_b"](checkpoint=vit_b_checkpoint_path)

# Image encoder: fixed-size 1x3x1024x1024 input (normalized, padded image).
torch.onnx.export(
    f="vit_b_encoder.onnx",
    model=sam.image_encoder,
    args=torch.randn(1, 3, 1024, 1024),
    input_names=["images"],
    output_names=["embeddings"],
    export_params=True,
    opset_version=17,  # same opset as the decoder export
)

# Prompt encoder + mask decoder from the SAME checkpoint. This overwrites
# "sam_onnx_example.onnx", the file the decoder cell below loads.
vit_b_embed_size = sam.prompt_encoder.image_embedding_size
vit_b_decoder_inputs = {
    "image_embeddings": torch.randn(1, sam.prompt_encoder.embed_dim, *vit_b_embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *[4 * s for s in vit_b_embed_size], dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    torch.onnx.export(
        SamOnnxModel(sam, return_single_mask=True),
        tuple(vit_b_decoder_inputs.values()),
        "sam_onnx_example.onnx",
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=list(vit_b_decoder_inputs),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    )

In [None]:
# Testing encoder and decoder on an image.
# GPU builds of onnxruntime (>= 1.9) require `providers` to be passed explicitly;
# prefer CUDA when this build has it, otherwise fall back to CPU.
ort_providers = [
    p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if p in onnxruntime.get_available_providers()
]
encoder = onnxruntime.InferenceSession("vit_b_encoder.onnx", providers=ort_providers)


In [None]:
# Run the exported image encoder on the preprocessed NCHW float32 batch.
encoder_feed = {"images": input_image.image}
outputs = encoder.run(None, encoder_feed)
embeddings = outputs[0]  # the single "embeddings" output
embeddings.shape


## Decoder 


## Decoder input instructions (from the official SAM ONNX example: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb)
The ONNX model has a different input signature than SamPredictor.predict. The following inputs must all be supplied. Note the special cases for both point and mask inputs. All inputs are np.float32.

- image_embeddings: The image embedding from predictor.get_image_embedding(). Has a batch index of length 1.
- point_coords: Coordinates of sparse input prompts, corresponding to both point inputs and box inputs. Boxes are encoded using two points, one for the top-left corner and one for the bottom-right corner. Coordinates must already be transformed to long-side 1024. Has a batch index of length 1.
- point_labels: Labels for the sparse input prompts. 0 is a negative input point, 1 is a positive input point, 2 is a top-left box corner, 3 is a bottom-right box corner, and -1 is a padding point. If there is no box input, a single padding point with label -1 and coordinates (0.0, 0.0) should be concatenated.
- mask_input: A mask input to the model with shape 1x1x256x256. This must be supplied even if there is no mask input. In this case, it can just be zeros.
- has_mask_input: An indicator for the mask input. 1 indicates a mask input, 0 indicates no mask input.
- orig_im_size: The size of the input image in (H,W) format, before any transformation.

Additionally, the ONNX model does not threshold the output mask logits. To obtain a binary mask, threshold at sam.mask_threshold (equal to 0.0).


In [None]:
# Use the entire (resized) image as a BOX prompt. Box corners are labelled
# 2 (top-left) and 3 (bottom-right). Labels 1 would mean two positive points.
# A box input needs no extra padding point. The coordinates are already in the
# resized long-side-1024 frame the decoder expects, and all inputs must be float32.
input_point = np.array([[0, 0], [input_image.resized_width, input_image.resized_height]])
input_label = np.array([2, 3])
onnx_coord = input_point[None, :, :].astype(np.float32)
onnx_label = input_label[None, :].astype(np.float32)

In [None]:
# Resized (width, height) of the image fed to the encoder. The previous
# `input_image.re` was an incomplete attribute access and raised AttributeError.
input_image.resized_width, input_image.resized_height

In [None]:
# Prompt: one negative point (label 0) plus a box (labels 2 and 3 for the
# top-left and bottom-right corners), all in ORIGINAL image pixels as (x, y).
input_box = np.array([0, 0, 1000, 1025]).reshape(2, 2)
box_labels = np.array([2, 3])
input_point = np.array([[140, 160]])
input_label = np.array([0])

onnx_coord = np.concatenate([input_point, input_box], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, box_labels], axis=0)[None, :].astype(np.float32)

# Map (x, y) from the original image into the resized long-side-1024 frame the
# decoder expects. x scales with width and y with height (previously swapped).
coords = onnx_coord.astype(float)  # astype returns a copy; onnx_coord is untouched
coords[..., 0] *= input_image.resized_width / input_image.orig_width
coords[..., 1] *= input_image.resized_height / input_image.orig_height

onnx_coord = coords.astype("float32")
onnx_coord


In [None]:
# Decoder session. GPU builds of onnxruntime (>= 1.9) require explicit providers;
# prefer CUDA when available, otherwise fall back to CPU.
ort_providers = [
    p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if p in onnxruntime.get_available_providers()
]
decoder = onnxruntime.InferenceSession("sam_onnx_example.onnx", providers=ort_providers)

# No mask prompt: a zero 1x1x256x256 mask with has_mask_input = 0 is still required.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

outputs = decoder.run(None, {
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    # (H, W) of the image BEFORE any resizing or padding, as the model contract
    # requires. Output masks come back at this resolution, aligned with `image`.
    "orig_im_size": np.array([input_image.orig_height, input_image.orig_width], dtype=np.float32),
})
masks = outputs[0]
masks.shape


In [None]:
# The decoder returns raw logits; binarize at SAM's mask threshold (0.0).
MASK_THRESHOLD = 0.0
mask = masks[0][0]
mask = (mask > MASK_THRESHOLD).astype('uint8') * 255
# A 2-D uint8 array is read as mode "L" automatically. Passing `mode` to
# fromarray is deprecated in recent Pillow releases.
img = Image.fromarray(mask)
img