In [1]:
import numpy as np
import torch
import torch.nn as nn
from repvit import RepViT, RepViTBlock
from segment_anything.modeling import  PromptEncoder, MaskDecoder, TwoWayTransformer, ImageEncoderViT
from utils import replace_batchnorm
from collections import OrderedDict
from repvit_cfgs import repvit_m1_0_cfgs
from tiny_vit_sam import  TinyViT
# from torchsummary import  summary
from fvcore.nn import FlopCountAnalysis, parameter_count_table, flop_count

device = torch.device('cpu')


In [3]:
class MedSAM_Lite(nn.Module):
    def __init__(self, 
                image_encoder, 
                mask_decoder,
                prompt_encoder
                ):
        super().__init__()
        self.image_encoder = image_encoder  
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

    def forward(self, image, boxes):
        image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
        sparse_embeddings, dense_embeddings = self.prompt_encoder( 
            points=None,
            boxes=boxes,
            masks=None,
        ) # get sparse_embeddings (one-point based and bbox) and z()
        
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding, # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
          ) # (B, 1, 256, 256)

        return low_res_masks, iou_predictions


In [4]:
from functools import partial

medsam_lite_image_encoder = RepViT(repvit_m1_0_cfgs)
medsam_image_encoder = ImageEncoderViT(
    depth=12,
    embed_dim=768,
    img_size=256,
    mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12,
    patch_size=16,
    qkv_bias=True,
    use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11],
    window_size=14,
    out_chans=256,
).to(device)

medsam_lite_prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(256, 256),
    mask_in_chans=16
).to(device)

tiny_lite_image_encoder = TinyViT(
        img_size=256,
        in_chans=3,
        embed_dims=[
            64, ## (64, 256, 256)
            128, ## (128, 128, 128)
            160, ## (160, 64, 64)
            320 ## (320, 64, 64) 
        ],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8
    )
    
medsam_lite_mask_decoder = MaskDecoder(
    num_multimask_outputs=3,
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=256,
            mlp_dim=2048,
            num_heads=8,
        ),
        transformer_dim=256,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
).to(device)

rep_medsam =  MedSAM_Lite(
    image_encoder = medsam_lite_image_encoder,
    mask_decoder = medsam_lite_mask_decoder,
    prompt_encoder = medsam_lite_prompt_encoder
).to(device)

medsam = MedSAM_Lite(
    image_encoder = medsam_image_encoder,
    mask_decoder = medsam_lite_mask_decoder,
    prompt_encoder = medsam_lite_prompt_encoder
)

tiny_sam = MedSAM_Lite(
    image_encoder = tiny_lite_image_encoder,
    mask_decoder = medsam_lite_mask_decoder,
    prompt_encoder = medsam_lite_prompt_encoder
)



In [8]:
print(f"MedSAM Image Encoder size:{sum(p.numel() for p in medsam.image_encoder.parameters())}")

MedSAM Image Encoder size:89670912


In [9]:
print(f"MedSAM Image Encoder size:{sum(p.numel() for p in tiny_sam.image_encoder.parameters())}")

MedSAM Image Encoder size:5726740


In [10]:
print(f"MedSAM Image Encoder size:{sum(p.numel() for p in rep_medsam.image_encoder.parameters())}")

MedSAM Image Encoder size:6505532


In [11]:
replace_batchnorm(rep_medsam.image_encoder)

In [12]:
print(f"MedSAM Image Encoder size:{sum(p.numel() for p in rep_medsam.image_encoder.parameters())}")

MedSAM Image Encoder size:6463840


In [6]:
from time import  time


In [7]:
image = torch.rand(3, 256, 256).to(device)
curr_time = time()
output = rep_medsam.image_encoder(image.unsqueeze(0))
end_time = time()
print(f'cost = {end_time - curr_time}')

cost = 0.5688502788543701


In [8]:
image = torch.rand(3, 256, 256).to(device)
curr_time = time()
output = tiny_sam.image_encoder(image.unsqueeze(0))
end_time = time()
print(f'cost = {end_time - curr_time}')

cost = 0.9256021976470947


In [9]:
image = torch.rand(3, 256, 256).to(device)
curr_time = time()
output = medsam.image_encoder(image.unsqueeze(0))
end_time = time()
print(f'cost = {end_time - curr_time}')

cost = 0.6497838497161865


In [6]:
flops = flop_count(model=rep_medsam, inputs= (image, boxes))
print(flops.total())

: 

In [12]:
medsam_image_encoder_flops = FlopCountAnalysis(model=medsam_image_encoder, inputs= (image, ))

In [14]:
print(flop_count_str(medsam_image_encoder_flops))

NameError: name 'flop_count_str' is not defined

RuntimeError: The size of tensor a (16) must match the size of tensor b (64) at non-singleton dimension 2

In [None]:
replace_batchnorm(medsam_lite_image_encoder)
med
summary(model=medsam_lite_image_encoder, input_size=(3, 256, 256),device='cpu' ,batch_size=1)