In [None]:
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
import numpy as np
import os

# Export Segment Anything (ViT-B) as two ONNX graphs: the image encoder and a
# prompt-conditioned mask decoder wrapped by the official SamOnnxModel.
print("Starting SAM model export using official SamOnnxModel...")

# Input checkpoint and output ONNX file names (relative to the working directory).
model_path = "sam_vit_b_01ec64.pth"
encoder_path = "sam_encoder.onnx"
decoder_path = "sam_decoder.onnx"

# Build the ViT-B model from the checkpoint, move it to the GPU when one is
# available (CPU otherwise), and switch it to inference mode.
print("Loading SAM model...")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# nn.Module.to() and nn.Module.eval() both return the module itself,
# so chaining them leaves `sam` bound to the same object.
sam = sam_model_registry["vit_b"](checkpoint=model_path).to(device=device).eval()

# Warn, but do not stop, when a previous export is about to be overwritten.
for _existing_path in (encoder_path, decoder_path):
    if os.path.exists(_existing_path):
        print(f"Warning: {_existing_path} already exists! Overwriting...")

# 1. Export the image encoder.
# Tracing uses one random 1x3x1024x1024 tensor. Only dim 0 (batch) of the
# input and of the output is marked dynamic in the exported graph.
# NOTE(review): this exports sam.image_encoder on its own, so any input
# normalization/padding SAM applies before the encoder is not part of this
# graph. Callers presumably must replicate it; confirm against segment_anything.
print(f"Exporting image encoder to {encoder_path}...")
with torch.no_grad():
    torch.onnx.export(
        sam.image_encoder,
        torch.randn(1, 3, 1024, 1024, device=device),
        encoder_path,
        input_names=["images"],
        output_names=["image_embeddings"],
        dynamic_axes={
            axis_owner: {0: "batch_size"}
            for axis_owner in ("images", "image_embeddings")
        },
        opset_version=17,
    )
print(f"Successfully exported image encoder to {encoder_path}")

# 2. Export the mask decoder using SamOnnxModel
print(f"Exporting mask decoder to {decoder_path}...")

# Official export wrapper around SAM's prompt encoder + mask decoder.
# return_single_mask=False asks it to return multiple masks rather than one.
onnx_model = SamOnnxModel(sam, return_single_mask=False)

# Take the embedding geometry from the loaded model so the dummy tensors
# match the checkpoint instead of hard-coding sizes.
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
# The mask prompt is 4x the embedding grid along each spatial dimension.
mask_input_size = [x * 4 for x in embed_size]
print(f"SAM model dimensions - embed_dim: {embed_dim}, embed_size: {embed_size}, mask_input_size: {mask_input_size}")

# Dummy decoder inputs used only for tracing. Order matters: the values
# become the positional export arguments and the keys become the ONNX input
# names, in exactly this order.
dummy_inputs = dict(
    image_embeddings=torch.randn(1, embed_dim, *embed_size, device=device),
    point_coords=torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float, device=device),
    point_labels=torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float, device=device),
    mask_input=torch.zeros(1, 1, *mask_input_size, dtype=torch.float, device=device),
    has_mask_input=torch.zeros(1, dtype=torch.float, device=device),
    orig_im_size=torch.tensor([1024, 1024], dtype=torch.float, device=device),
)

# Export the SamOnnxModel wrapper with all necessary configs.
output_names = ["masks", "iou_predictions", "low_res_masks"]

# The number of prompt points (dim 1 of point_coords/point_labels) is dynamic.
# NOTE(review): image_embeddings and the prompt tensors all share the symbolic
# "batch_size" axis. SAM's mask decoder appears to expand the image embeddings
# once per prompt-batch entry, so feeding several images *and* several prompt
# sets at once may not work. Confirm before relying on image batch > 1.
with torch.no_grad():
    try:
        torch.onnx.export(
            model=onnx_model,
            args=tuple(dummy_inputs.values()),
            f=decoder_path,
            opset_version=17,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes={
                "image_embeddings": {0: "batch_size"},
                "point_coords": {0: "batch_size", 1: "num_points"},
                "point_labels": {0: "batch_size", 1: "num_points"},
                "masks": {0: "batch_size"},
                "iou_predictions": {0: "batch_size"},
                "low_res_masks": {0: "batch_size"}
            },
            do_constant_folding=True
        )
    except Exception as e:
        # Report, then re-raise so the run stops here. Otherwise it would fall
        # through to the completion summary, which lists decoder_path as
        # exported even though it is missing or a stale file from an earlier run.
        print(f"Error exporting mask decoder: {e}")
        raise
    else:
        print(f"Successfully exported mask decoder to {decoder_path}")

# Final report: which files were written and how the decoder is meant to be fed.
print("\nSAM ONNX export completed!")
print("Files exported:", f"1. {encoder_path}", f"2. {decoder_path}", sep="\n")

print("\nImportant notes for C++ integration:")
print("1. The decoder model expects the following inputs:")
# These are the shapes of the dummy tracing tensors. Axes listed in the
# decoder export's dynamic_axes may vary at inference time.
for name, tensor in dummy_inputs.items():
    print(f"   - {name}: {tensor.shape}")
print("\n".join((
    "2. The decoder returns multiple masks with IoU predictions",
    "3. Box prompts should be provided as point coordinates with special labels (2 & 3 for box corners)",
    "4. For usage examples, refer to the segment_anything/utils/onnx.py module in the SAM repository",
    "5. The output masks need to be processed (threshold > 0) to get binary masks",
)))