In [1]:
import sys
sys.path.append("..")

import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis
from torchinfo import summary
from src.models.base_sam import BaseSAM
from src.models.efficientvit.sam_model_zoo import create_sam_model

In [2]:
# Random box prompt reused by every summary/FLOP call in this notebook.
# Presumably a single (x1, y1, x2, y2) box per image — TODO confirm against BaseSAM.
boxes = torch.randn(1, 4)


class FlopCountWrapper(nn.Module):
    """Expose a box-prompted model through a single-argument ``forward(image)``.

    The wrapped model is invoked as ``model(image=..., boxes=...)``; this adapter
    supplies the ``boxes`` keyword so that callers only need to pass the image.

    Args:
        model: Module called with the keyword arguments ``image`` and ``boxes``.
        boxes: Optional box prompt to forward to ``model``. When omitted, the
            notebook-level ``boxes`` tensor is looked up at call time, which is
            the behaviour the wrapper had before this argument existed.
    """

    def __init__(self, model, boxes=None):
        super().__init__()
        self.model = model
        # Stored as a plain attribute (not a parameter), so the wrapper adds
        # nothing to the wrapped model's parameter list.
        self.boxes = boxes

    def forward(self, image):
        """Run the wrapped model on ``image`` with the configured box prompt."""
        # Prefer the explicitly supplied prompt; otherwise fall back to the
        # notebook-level default defined above.
        prompt = self.boxes if self.boxes is not None else boxes
        return self.model(image=image, boxes=prompt)

# L0

In [3]:
# Build the "l0" variant from the efficientvit model zoo (pretrained=False),
# place it on the CPU, and wrap it in the project's BaseSAM interface.
l0_model = create_sam_model(name="l0", pretrained=False)
l0_model = l0_model.to("cpu")
l0_sam = BaseSAM.construct_from(original_sam=l0_model)

# Random input used for the L0 measurements below
# (presumably batch x channels x height x width — TODO confirm).
image = torch.randn((1, 3, 512, 512))

In [4]:
# Total parameter count of the wrapped L0 model, reported in millions.
l0_stats = summary(l0_sam, image=image, boxes=boxes, device="cpu")
l0_params_m = round(l0_stats.total_params / 1_000_000, 2)
print("L0 params:", l0_params_m, "M")

L0 params: 34.79 M


In [5]:
# FlopCountWrapper supplies the box prompt, so only the image is passed here.
l0_wrapped = FlopCountWrapper(l0_sam)
flops = FlopCountAnalysis(l0_wrapped, image)
# Report the total in units of 1e9 ("G"), rounded to two decimals.
l0_gflops = round(flops.total() / 1e9, 2)
print("L0 flops:", l0_gflops, "G")

Unsupported operator aten::add_ encountered 38 time(s)
Unsupported operator aten::gelu encountered 32 time(s)
Unsupported operator aten::add encountered 50 time(s)
Unsupported operator aten::mul encountered 18 time(s)
Unsupported operator aten::pad encountered 4 time(s)
Unsupported operator aten::div encountered 17 time(s)
Unsupported operator aten::upsample_bicubic2d encountered 3 time(s)
Unsupported operator aten::mean encountered 4 time(s)
Unsupported operator aten::sub encountered 7 time(s)
Unsupported operator aten::square encountered 1 time(s)
Unsupported operator aten::sqrt encountered 2 time(s)
Unsupported operator aten::clone encountered 1 time(s)
Unsupported operator aten::sin encountered 2 time(s)
Unsupported operator aten::cos encountered 2 time(s)
Unsupported operator aten::cumsum encountered 2 time(s)
Unsupported operator aten::repeat_interleave encountered 2 time(s)
Unsupported operator aten::softmax encountered 7 time(s)
Unsupported operator aten::pow encountered 1 time

L0 flops: 36.8 G


# L1

In [6]:
# Build the "l1" variant from the efficientvit model zoo (pretrained=False),
# place it on the CPU, and wrap it in the project's BaseSAM interface.
l1_model = create_sam_model(name="l1", pretrained=False)
l1_model = l1_model.to("cpu")
l1_sam = BaseSAM.construct_from(original_sam=l1_model)

# Random input used for the L1 measurements below
# (presumably batch x channels x height x width — TODO confirm).
image = torch.randn((1, 3, 512, 512))

In [7]:
# Parameter count of the L1 image encoder alone, reported in millions.
# torchinfo is given an input size here instead of a concrete tensor.
l1_encoder_stats = summary(l1_model.image_encoder, (1, 3, 512, 512), device="cpu")
l1_encoder_params_m = round(l1_encoder_stats.total_params / 1_000_000, 2)
print("L1 image encoder params:", l1_encoder_params_m, "M")

L1 image encoder params: 43.59 M


In [8]:
# The image encoder is analysed directly with the image tensor, so no
# FlopCountWrapper is used for this measurement.
flops = FlopCountAnalysis(l1_model.image_encoder, image)
# Report the total in units of 1e9 ("G"), rounded to two decimals.
l1_encoder_gflops = round(flops.total() / 1e9, 2)
print("L1 image encoder flops:", l1_encoder_gflops, "G")

Unsupported operator aten::add_ encountered 50 time(s)
Unsupported operator aten::gelu encountered 42 time(s)
Unsupported operator aten::add encountered 39 time(s)
Unsupported operator aten::mul encountered 7 time(s)
Unsupported operator aten::pad encountered 6 time(s)
Unsupported operator aten::div encountered 7 time(s)
Unsupported operator aten::upsample_bicubic2d encountered 3 time(s)
Unsupported operator aten::mean encountered 2 time(s)
Unsupported operator aten::sub encountered 1 time(s)
Unsupported operator aten::square encountered 1 time(s)
Unsupported operator aten::sqrt encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
backbone.stages.0.op_list.1.shortcut, backbone.stages.1.op_list.1.shor

L1 image encoder flops: 49.23 G


In [9]:
# Total parameter count of the wrapped L1 model, reported in millions.
l1_stats = summary(l1_sam, image=image, boxes=boxes, device="cpu")
l1_params_m = round(l1_stats.total_params / 1_000_000, 2)
print("L1 params:", l1_params_m, "M")

L1 params: 47.65 M


In [10]:
# FlopCountWrapper supplies the box prompt, so only the image is passed here.
l1_wrapped = FlopCountWrapper(l1_sam)
flops = FlopCountAnalysis(l1_wrapped, image)
# Report the total in units of 1e9 ("G"), rounded to two decimals.
l1_gflops = round(flops.total() / 1e9, 2)
print("L1 flops:", l1_gflops, "G")

Unsupported operator aten::add_ encountered 52 time(s)
Unsupported operator aten::gelu encountered 44 time(s)
Unsupported operator aten::add encountered 62 time(s)
Unsupported operator aten::mul encountered 20 time(s)
Unsupported operator aten::pad encountered 6 time(s)
Unsupported operator aten::div encountered 19 time(s)
Unsupported operator aten::upsample_bicubic2d encountered 3 time(s)
Unsupported operator aten::mean encountered 4 time(s)
Unsupported operator aten::sub encountered 7 time(s)
Unsupported operator aten::square encountered 1 time(s)
Unsupported operator aten::sqrt encountered 2 time(s)
Unsupported operator aten::clone encountered 1 time(s)
Unsupported operator aten::sin encountered 2 time(s)
Unsupported operator aten::cos encountered 2 time(s)
Unsupported operator aten::cumsum encountered 2 time(s)
Unsupported operator aten::repeat_interleave encountered 2 time(s)
Unsupported operator aten::softmax encountered 7 time(s)
Unsupported operator aten::pow encountered 1 time

L1 flops: 51.05 G


# L2

In [11]:
# Build the "l2" variant from the efficientvit model zoo (pretrained=False)
# and wrap it in the project's BaseSAM interface. The model is placed on the
# CPU explicitly, matching how the "l0" and "l1" models are created in this
# notebook, so all three variants are measured on the same device.
l2_model = create_sam_model(name="l2", pretrained=False).to("cpu")
l2_sam = BaseSAM.construct_from(original_sam=l2_model)

# Random input used for the L2 measurements below
# (presumably batch x channels x height x width — TODO confirm).
image = torch.randn(1, 3, 512, 512)

In [12]:
# Total parameter count of the wrapped L2 model, reported in millions.
l2_stats = summary(l2_sam, image=image, boxes=boxes, device="cpu")
l2_params_m = round(l2_stats.total_params / 1_000_000, 2)
print("L2 params:", l2_params_m, "M")

L2 params: 61.33 M


In [13]:
# FlopCountWrapper supplies the box prompt, so only the image is passed here.
l2_wrapped = FlopCountWrapper(l2_sam)
flops = FlopCountAnalysis(l2_wrapped, image)
# Report the total in units of 1e9 ("G"), rounded to two decimals.
l2_gflops = round(flops.total() / 1e9, 2)
print("L2 flops:", l2_gflops, "G")

Unsupported operator aten::add_ encountered 70 time(s)
Unsupported operator aten::gelu encountered 58 time(s)
Unsupported operator aten::add encountered 76 time(s)
Unsupported operator aten::mul encountered 22 time(s)
Unsupported operator aten::pad encountered 8 time(s)
Unsupported operator aten::div encountered 21 time(s)
Unsupported operator aten::upsample_bicubic2d encountered 3 time(s)
Unsupported operator aten::mean encountered 4 time(s)
Unsupported operator aten::sub encountered 7 time(s)
Unsupported operator aten::square encountered 1 time(s)
Unsupported operator aten::sqrt encountered 2 time(s)
Unsupported operator aten::clone encountered 1 time(s)
Unsupported operator aten::sin encountered 2 time(s)
Unsupported operator aten::cos encountered 2 time(s)
Unsupported operator aten::cumsum encountered 2 time(s)
Unsupported operator aten::repeat_interleave encountered 2 time(s)
Unsupported operator aten::softmax encountered 7 time(s)
Unsupported operator aten::pow encountered 1 time

L2 flops: 70.71 G


# MedSAM

In [14]:
from src.models.segment_anything import build_sam_vit_b

# Build the ViT-B SAM model with the builder's default arguments and wrap it
# in the project's BaseSAM interface.
medsam_model = build_sam_vit_b()
medsam_sam = BaseSAM.construct_from(original_sam=medsam_model)

# MedSAM is measured on a 1024 x 1024 random input, unlike the 512 x 512
# input used for the L0/L1/L2 models.
image = torch.randn((1, 3, 1024, 1024))

In [15]:
# Total parameter count of the wrapped MedSAM model, reported in millions.
medsam_stats = summary(medsam_sam, image=image, boxes=boxes, device="cpu")
medsam_params_m = round(medsam_stats.total_params / 1_000_000, 2)
print("MedSAM params:", medsam_params_m, "M")

MedSAM params: 93.74 M


In [16]:
# FlopCountWrapper supplies the box prompt, so only the image is passed here.
medsam_wrapped = FlopCountWrapper(medsam_sam)
flops = FlopCountAnalysis(medsam_wrapped, image)
# Report the total in units of 1e9 ("G"), rounded to two decimals.
medsam_gflops = round(flops.total() / 1e9, 2)
print("MedSAM flops:", medsam_gflops, "G")

Unsupported operator aten::add encountered 116 time(s)
Unsupported operator aten::rsub encountered 16 time(s)
Unsupported operator aten::pad encountered 8 time(s)
Unsupported operator aten::mul encountered 167 time(s)
Unsupported operator aten::div encountered 86 time(s)
Unsupported operator aten::sub encountered 58 time(s)
Unsupported operator aten::softmax encountered 19 time(s)
Unsupported operator aten::gelu encountered 14 time(s)
Unsupported operator aten::mean encountered 6 time(s)
Unsupported operator aten::pow encountered 3 time(s)
Unsupported operator aten::sqrt encountered 3 time(s)
Unsupported operator aten::clone encountered 1 time(s)
Unsupported operator aten::sin encountered 2 time(s)
Unsupported operator aten::cos encountered 2 time(s)
Unsupported operator aten::add_ encountered 2 time(s)
Unsupported operator aten::cumsum encountered 2 time(s)
Unsupported operator aten::repeat_interleave encountered 2 time(s)
The following submodules of the model were never called during

MedSAM flops: 488.24 G


# LiteMedSAM

In [17]:
from src.models.lite_medsam import build_lite_medsam

# Build LiteMedSAM with the builder's default arguments and wrap it in the
# project's BaseSAM interface.
lite_medsam_model = build_lite_medsam()
lite_medsam_sam = BaseSAM.construct_from(original_sam=lite_medsam_model)

# LiteMedSAM is measured on a 256 x 256 random input, the smallest input
# size used in this notebook.
image = torch.randn((1, 3, 256, 256))

In [18]:
# Total parameter count of the wrapped LiteMedSAM model, reported in millions.
lite_medsam_stats = summary(lite_medsam_sam, image=image, boxes=boxes, device="cpu")
lite_medsam_params_m = round(lite_medsam_stats.total_params / 1_000_000, 2)
print("LiteMedSAM params:", lite_medsam_params_m, "M")

LiteMedSAM params: 9.79 M


In [19]:
# FlopCountWrapper supplies the box prompt, so only the image is passed here.
lite_medsam_wrapped = FlopCountWrapper(lite_medsam_sam)
flops = FlopCountAnalysis(lite_medsam_wrapped, image)
# Report the total in units of 1e9 ("G"), rounded to two decimals.
lite_medsam_gflops = round(flops.total() / 1e9, 2)
print("LiteMedSAM flops:", lite_medsam_gflops, "G")

Unsupported operator aten::add_ encountered 31 time(s)
Unsupported operator aten::gelu encountered 25 time(s)
Unsupported operator aten::pad encountered 10 time(s)
Unsupported operator aten::mul encountered 45 time(s)
Unsupported operator aten::add encountered 57 time(s)
Unsupported operator aten::softmax encountered 17 time(s)
Unsupported operator aten::mean encountered 6 time(s)
Unsupported operator aten::sub encountered 10 time(s)
Unsupported operator aten::pow encountered 3 time(s)
Unsupported operator aten::sqrt encountered 3 time(s)
Unsupported operator aten::div encountered 14 time(s)
Unsupported operator aten::clone encountered 1 time(s)
Unsupported operator aten::sin encountered 2 time(s)
Unsupported operator aten::cos encountered 2 time(s)
Unsupported operator aten::cumsum encountered 2 time(s)
Unsupported operator aten::repeat_interleave encountered 2 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they we

LiteMedSAM flops: 39.98 G
