In [1]:
import numpy as np
import torch
import torch.nn as nn
from repvit import RepViT, RepViTBlock
from segment_anything.modeling import  PromptEncoder, MaskDecoder, TwoWayTransformer, ImageEncoderViT
from utils import replace_batchnorm
from collections import OrderedDict
from repvit_cfgs import repvit_m1_0_cfgs
from torchsummary import  summary
from fvcore.nn import FlopCountAnalysis, parameter_count_table, flop_count

device = torch.device('cpu')


In [2]:
class MedSAM_Lite(nn.Module):
    def __init__(self, 
                image_encoder, 
                mask_decoder,
                prompt_encoder
                ):
        super().__init__()
        self.image_encoder = image_encoder  
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

    def forward(self, image, boxes):
        image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
        sparse_embeddings, dense_embeddings = self.prompt_encoder( 
            points=None,
            boxes=boxes,
            masks=None,
        ) # get sparse_embeddings (one-point based and bbox) and z()
        
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding, # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
          ) # (B, 1, 256, 256)

        return low_res_masks, iou_predictions


In [3]:
from functools import partial

medsam_lite_image_encoder = RepViT(repvit_m1_0_cfgs)
medsam_image_encoder = ImageEncoderViT(
    depth=12,
    embed_dim=768,
    img_size=1024,
    mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12,
    patch_size=16,
    qkv_bias=True,
    use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11],
    window_size=14,
    out_chans=256,
).to(device)

medsam_lite_prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(256, 256),
    mask_in_chans=16
).to(device)

medsam_lite_mask_decoder = MaskDecoder(
    num_multimask_outputs=3,
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=256,
            mlp_dim=2048,
            num_heads=8,
        ),
        transformer_dim=256,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
).to(device)

rep_medsam =  MedSAM_Lite(
    image_encoder = medsam_lite_image_encoder,
    mask_decoder = medsam_lite_mask_decoder,
    prompt_encoder = medsam_lite_prompt_encoder
).to(device)

medsam = MedSAM_Lite(
    image_encoder = medsam_image_encoder,
    mask_decoder = medsam_lite_mask_decoder,
    prompt_encoder = medsam_lite_prompt_encoder
)

# medsam = MedSAM_Lite(
#     image_encoder = med_enc,
#     mask_decoder = medsam_lite_mask_decoder,
#     prompt_encoder = medsam_lite_prompt_encoder
# )



In [6]:
summary(model=medsam_lite_image_encoder, input_size=(3, 256, 256), device='cuda:0' ,batch_size=16)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [16, 28, 256, 256]             756
       BatchNorm2d-2         [16, 28, 256, 256]              56
              GELU-3         [16, 28, 256, 256]               0
            Conv2d-4         [16, 56, 256, 256]          14,112
       BatchNorm2d-5         [16, 56, 256, 256]             112
            Conv2d-6         [16, 56, 256, 256]             504
       BatchNorm2d-7         [16, 56, 256, 256]             112
            Conv2d-8         [16, 56, 256, 256]             112
       BatchNorm2d-9         [16, 56, 256, 256]             112
         RepVGGDW-10         [16, 56, 256, 256]               0
           Conv2d-11             [16, 16, 1, 1]             912
         Identity-12             [16, 16, 1, 1]               0
             ReLU-13             [16, 16, 1, 1]               0
           Conv2d-14             [16, 5

In [4]:
image = torch.rand(16, 3, 256, 256).to(device)
boxes = torch.randint(low=0, high=256, size=(16, 1, 4)).to(device)

In [5]:
if image.is_cuda:
    print('!')
if boxes.is_cuda:
    print('!')
print(next(rep_medsam.parameters()).device)


cpu


In [6]:
flops = flop_count(model=rep_medsam, inputs= (image, boxes))
print(flops.total())

: 

In [12]:
medsam_image_encoder_flops = FlopCountAnalysis(model=medsam_image_encoder, inputs= (image, ))

In [14]:
print(flop_count_str(medsam_image_encoder_flops))

NameError: name 'flop_count_str' is not defined

RuntimeError: The size of tensor a (16) must match the size of tensor b (64) at non-singleton dimension 2

In [None]:
replace_batchnorm(medsam_lite_image_encoder)
med
summary(model=medsam_lite_image_encoder, input_size=(3, 256, 256),device='cpu' ,batch_size=1)