# Download official SAM2 weights and modeling code

In [3]:
!wget "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"

--2025-01-14 23:00:34--  https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.154.161.79, 18.154.161.30, 18.154.161.85, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.154.161.79|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 155906050 (149M) [application/vnd.snesdev-page-table]
Saving to: ‘sam2_hiera_tiny.pt’


2025-01-14 23:00:59 (5.97 MB/s) - ‘sam2_hiera_tiny.pt’ saved [155906050/155906050]



# Install dependency with uv

In [1]:
!uv --version
!uv venv
!uv sync

uv 0.5.6 (b70c4f30e 2024-12-03)
Using CPython [36m3.11.11[39m
Creating virtual environment at: [36m.venv[39m
Activate with: [32msource .venv/bin/activate[39m
[2mResolved [1m81 packages[0m [2min 5ms[0m[0m
[2K[2mInstalled [1m61 packages[0m [2min 453ms[0m[0m                              [0m
 [32m+[39m [1mantlr4-python3-runtime[0m[2m==4.9.3[0m
 [32m+[39m [1mappnope[0m[2m==0.1.4[0m
 [32m+[39m [1masttokens[0m[2m==3.0.0[0m
 [32m+[39m [1mcoloredlogs[0m[2m==15.0.1[0m
 [32m+[39m [1mcomm[0m[2m==0.2.2[0m
 [32m+[39m [1mdebugpy[0m[2m==1.8.11[0m
 [32m+[39m [1mdecorator[0m[2m==5.1.1[0m
 [32m+[39m [1mexecuting[0m[2m==2.1.0[0m
 [32m+[39m [1mfilelock[0m[2m==3.16.1[0m
 [32m+[39m [1mflatbuffers[0m[2m==24.12.23[0m
 [32m+[39m [1mfsspec[0m[2m==2024.12.0[0m
 [32m+[39m [1mhumanfriendly[0m[2m==10.0[0m
 [32m+[39m [1mhydra-core[0m[2m==1.3.2[0m
 [32m+[39m [1miopath[0m[2m==0.1.10[0m
 [32m+[39m [1mipykernel[0m

# Split model into Encoder and Decoder

In [3]:
from typing import Optional, Tuple, Any
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from sam2.modeling.sam2_base import SAM2Base

class SAM2ImageEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base) -> None:
        super().__init__()
        self.model = sam_model
        self.image_encoder = sam_model.image_encoder
        self.no_mem_embed = sam_model.no_mem_embed

    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
        backbone_out = self.image_encoder(x)
        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
            backbone_out["backbone_fpn"][0]
        )
        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(
            backbone_out["backbone_fpn"][1]
        )

        feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels:]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels:]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
                 for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]

        return feats[0], feats[1], feats[2]


class SAM2ImageDecoder(nn.Module):
    def __init__(
            self,
            sam_model: SAM2Base,
            multimask_output: bool
    ) -> None:
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model
        self.multimask_output = multimask_output

    @torch.no_grad()
    def forward(
            self,
            image_embed: torch.Tensor,
            high_res_feats_0: torch.Tensor,
            high_res_feats_1: torch.Tensor,
            point_coords: torch.Tensor,
            point_labels: torch.Tensor,
            mask_input: torch.Tensor,
            has_mask_input: torch.Tensor,
            img_size: torch.Tensor
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        self.sparse_embedding = sparse_embedding
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        high_res_feats = [high_res_feats_0, high_res_feats_1]
        image_embed = image_embed

        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embed,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
            repeat_image=False,
            high_res_features=high_res_feats,
        )

        if self.multimask_output:
            masks = masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        else:
            masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(masks, iou_predictions)

        masks = torch.clamp(masks, -32.0, 32.0)
        print(masks.shape, iou_predictions.shape)

        masks = F.interpolate(masks, (img_size[0], img_size[1]), mode="bilinear", align_corners=False)

        return masks, iou_predictions

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:

        point_coords = point_coords + 0.5

        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (
                point_labels == -1
        )

        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
                1 - has_mask_input
        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

# Convert model to .onnx

## Encoder

In [4]:
import torch
from sam2.build_sam import build_sam2

model_type = "sam2_hiera_tiny"
model_cfg = "sam2_hiera_t.yaml"
input_size = 1024 
multimask_output = True

sam2_checkpoint = f"./{model_type}.pt"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")

img=torch.randn(1, 3, input_size, input_size).cpu()

sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
high_res_feats_0, high_res_feats_1, image_embed = sam2_encoder(img)
print(high_res_feats_0.shape)
print(high_res_feats_1.shape)
print(image_embed.shape)

torch.onnx.export(sam2_encoder,
      img,
      f"{model_type}_encoder.onnx",
      export_params=True,
      opset_version=17,
      do_constant_folding=True,
      input_names = ['image'],
      output_names = ['high_res_feats_0', 'high_res_feats_1', 'image_embed'],
    )

torch.Size([1, 32, 256, 256])
torch.Size([1, 64, 128, 128])
torch.Size([1, 256, 64, 64])


  if pad_h > 0 or pad_w > 0:
  if Hp > H or Wp > W:


## Decoder

In [5]:
sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=multimask_output).cpu()

embed_dim = sam2_model.sam_prompt_encoder.embed_dim
embed_size = (sam2_model.image_size // sam2_model.backbone_stride, sam2_model.image_size // sam2_model.backbone_stride)
mask_input_size = [4 * x for x in embed_size]
print(embed_dim, embed_size, mask_input_size)

point_coords = torch.randint(low=0, high=input_size, size=(1, 5, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(1, 5), dtype=torch.float)
mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
has_mask_input = torch.tensor([1], dtype=torch.float)
orig_im_size = torch.tensor([input_size, input_size], dtype=torch.int32)

masks, scores = sam2_decoder(image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size)

torch.onnx.export(sam2_decoder,
      (image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size),
      f"{model_type}_decoder.onnx",
      export_params=True,
      opset_version=16,
      do_constant_folding=True,
      input_names = ['image_embed', 'high_res_feats_0', 'high_res_feats_1', 'point_coords', 'point_labels', 'mask_input', 'has_mask_input', 'orig_im_size'],
      output_names = ['masks', 'iou_predictions'],
      dynamic_axes = {"point_coords": {0: "num_labels", 1: "num_points"},
                      "point_labels": {0: "num_labels", 1: "num_points"},
                      "mask_input": {0: "num_labels"},
                      "has_mask_input": {0: "num_labels"}
      }
    )


256 (64, 64) [256, 256]
torch.Size([1, 3, 256, 256]) torch.Size([1, 3])


  assert image_embeddings.shape[0] == tokens.shape[0]
  image_pe.size(0) == 1


torch.Size([1, 3, 256, 256]) torch.Size([1, 3])


# Test exported models with `onnxruntime`

## Encoder

In [6]:
import onnx

onnx_model = onnx.load("sam2_hiera_tiny_encoder.onnx")
onnx.checker.check_model(onnx_model)

In [7]:
import onnxruntime as ort
import numpy as np

ort_sess = ort.InferenceSession('sam2_hiera_tiny_encoder.onnx', {})

[0;93m2025-01-14 23:24:42.614163 [W:onnxruntime:, graph.cc:109 MergeShapeInfo] Error merging shape info for output. '/image_encoder/trunk/Concat_3_output_0' source:{4} target:{5}. Falling back to lenient merge.[m


In [8]:
img = torch.randn(1, 3, input_size, input_size).cpu()

outputs = ort_sess.run(None, {
    'image': img.numpy(),
})

for i in range(len(outputs)):
    print (ort_sess.get_outputs()[i].name)
    print (outputs[i].shape)

high_res_feats_0
(1, 32, 256, 256)
high_res_feats_1
(1, 64, 128, 128)
image_embed
(1, 256, 64, 64)


## Decoder

In [9]:
onnx_model = onnx.load("sam2_hiera_tiny_decoder.onnx")
onnx.checker.check_model(onnx_model)

In [10]:
ort_sess = ort.InferenceSession('sam2_hiera_tiny_decoder.onnx', {})

outputs = ort_sess.run(None, {
    'image_embed': image_embed.detach().numpy(), 
    'high_res_feats_0': high_res_feats_0.detach().numpy(), 
    'high_res_feats_1': high_res_feats_1.detach().numpy(), 
    'point_coords': point_coords.numpy(), 
    'point_labels': point_labels.numpy(), 
    'mask_input': mask_input.numpy(), 
    'has_mask_input': has_mask_input.numpy(), 
    'orig_im_size': orig_im_size.numpy()
})

for i in range(len(outputs)):
    print (ort_sess.get_outputs()[i].name)
    print (outputs[i].shape)

masks
(1, 3, 1024, 1024)
iou_predictions
(1, 3)


# Convert Encoder model .onnx to .ort

In [11]:
!python3 -m onnxruntime.tools.convert_onnx_models_to_ort sam2_hiera_tiny_encoder.onnx

Converting models with optimization style 'Fixed' and level 'all'
Converting optimized ONNX model /Users/giuseppeambrosio/Desktop/projects/vite-sam/notebooks/sam2_hiera_tiny_encoder.onnx to ORT format model /Users/giuseppeambrosio/Desktop/projects/vite-sam/notebooks/sam2_hiera_tiny_encoder.ort
[0;93m2025-01-14 23:25:25.476227 [W:onnxruntime:, graph.cc:109 MergeShapeInfo] Error merging shape info for output. '/image_encoder/trunk/Concat_3_output_0' source:{4} target:{5}. Falling back to lenient merge.[m
Converted 1/1 models successfully.
Generating config file from ORT format models with optimization style 'Fixed' and level 'all'
2025-01-14 23:25:26,213 ort_format_model.utils [INFO] - Created config in /Users/giuseppeambrosio/Desktop/projects/vite-sam/notebooks/sam2_hiera_tiny_encoder.required_operators.config
Converting models with optimization style 'Runtime' and level 'all'
Converting optimized ONNX model /Users/giuseppeambrosio/Desktop/projects/vite-sam/notebooks/sam2_hiera_tiny_e