In [1]:
import numpy as np
import torch
import torch.nn as nn
from repvit import RepViT, RepViTBlock
from segment_anything.modeling import  PromptEncoder, MaskDecoder, TwoWayTransformer, ImageEncoderViT
from repvit_cfgs import repvit_m1_0_cfgs


In [None]:
# Load the teacher image-encoder checkpoint. The loaded object is handed to
# `load_state_dict` in a later cell, so it is used as a state dict even though
# the variable is called `model`.
# map_location='cpu' lets the file deserialize on machines without a GPU.
model = torch.load('teacher/MedSAM_Enc.pth', map_location='cpu')


In [5]:
class MedSAM_Lite(nn.Module):
    """Box-prompted segmentation model assembled from three sub-modules.

    Args:
        image_encoder: module called with the image batch; its result is used
            as the image embedding.
        mask_decoder: module called with the image embedding, the dense
            positional encoding and both prompt embeddings; must return
            ``(low_res_masks, iou_predictions)``.
        prompt_encoder: module called with ``points``/``boxes``/``masks``
            keyword arguments; must return
            ``(sparse_embeddings, dense_embeddings)`` and expose
            ``get_dense_pe()``.
        distillation: when True, ``forward`` returns only the image embedding
            and skips the prompt encoder and mask decoder. Defaults to False.
    """

    def __init__(self,
                 image_encoder,
                 mask_decoder,
                 prompt_encoder,
                 distillation=False,
                 ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        # This flag used to be read from a global `args` object that is not
        # defined in this notebook, so every forward pass raised NameError.
        self.distillation = distillation

    def forward(self, image, boxes):
        """Run the model on ``image`` with box prompts ``boxes``.

        Returns:
            The image embedding alone when ``self.distillation`` is True,
            otherwise the tuple ``(low_res_masks, iou_predictions)`` produced
            by the mask decoder.
        """
        # Shape notes below are the original author's, e.g. (B, 256, 64, 64);
        # the actual shapes depend on how the sub-modules are configured.
        image_embedding = self.image_encoder(image)  # (B, 256, 64, 64)

        if self.distillation:
            # Only the encoder output is needed here, so the prompt encoder and
            # mask decoder (whose results were previously discarded) are skipped.
            return image_embedding

        # Box-only prompting: no point or mask prompts are supplied.
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=None,
            boxes=boxes,
            masks=None,
        )

        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding,  # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=False,
        )  # (B, 1, 256, 256)

        return low_res_masks, iou_predictions


In [9]:
# Total number of elements across all parameters of `medsam_lite_model`.
# NOTE(review): `medsam_lite_model` is only assigned in a cell further down the
# notebook, so this cell depends on that cell having been run first.
parameters = sum(map(lambda weight: weight.numel(), medsam_lite_model.parameters()))
print(parameters)

10570092


In [7]:
from functools import partial

# ViT image encoder; a later cell loads 'teacher/MedSAM_Enc.pth' into it.
med_enc = ImageEncoderViT(
    img_size=1024,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    use_rel_pos=True,
    window_size=14,
    global_attn_indexes=[2, 5, 8, 11],
    out_chans=256,
)

# Prompt encoder for a 64x64 embedding grid with 256-dimensional embeddings.
# NOTE(review): input_image_size is (256, 256) while `med_enc` is built with
# img_size=1024 - confirm this pairing is intended.
medsam_lite_prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(256, 256),
    mask_in_chans=16,
)

# Mask decoder driven by a depth-2 two-way transformer.
medsam_lite_mask_decoder = MaskDecoder(
    num_multimask_outputs=3,
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=256,
        mlp_dim=2048,
        num_heads=8,
    ),
    transformer_dim=256,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)

# Full model wrapping the ViT encoder together with the two modules above.
medsam = MedSAM_Lite(
    image_encoder=med_enc,
    mask_decoder=medsam_lite_mask_decoder,
    prompt_encoder=medsam_lite_prompt_encoder,
)



In [10]:
# Copy the checkpoint held in `model` into the ViT image encoder. With
# strict=True the checkpoint keys must match the module's keys exactly.
# The call is the cell's trailing expression, so its result is displayed.
medsam.image_encoder.load_state_dict(model, strict=True)

<All keys matched successfully>

In [11]:
# Load the mask-decoder checkpoint into the decoder; strict=True requires the
# keys to match exactly. map_location='cpu' lets the file deserialize on
# machines without a GPU (the values are copied into the existing parameters).
medsam.mask_decoder.load_state_dict(
    torch.load('teacher/mask_decoder.pth', map_location='cpu'),
    strict=True,
)

<All keys matched successfully>

In [12]:
# Load the prompt-encoder checkpoint into the prompt encoder; strict=True
# requires the keys to match exactly. map_location='cpu' lets the file
# deserialize on machines without a GPU (the values are copied into the
# existing parameters).
medsam.prompt_encoder.load_state_dict(
    torch.load('teacher/prompt_encoder.pth', map_location='cpu'),
    strict=True,
)

<All keys matched successfully>

In [13]:
# Write the combined state dict (encoder, mask decoder, prompt encoder) of
# `medsam` to a single checkpoint file.
torch.save(obj=medsam.state_dict(), f='teacher/medsam.pth')

In [7]:
# RepViT image encoder built from the m1.0 configuration with img_size=256.
medsam_lite_image_encoder = RepViT(
    cfgs=repvit_m1_0_cfgs,
    img_size=256,
)

# Lite model: the RepViT encoder combined with whatever mask decoder and
# prompt encoder are currently bound to the `medsam_lite_*` names.
medsam_lite_model = MedSAM_Lite(
    image_encoder=medsam_lite_image_encoder,
    mask_decoder=medsam_lite_mask_decoder,
    prompt_encoder=medsam_lite_prompt_encoder,
)

# An unfinished `medsam_model = MedSAM_Lite()` statement used to follow here.
# It always raised TypeError because the three required sub-modules were not
# passed, which aborted the cell. It has been removed.
# TODO: construct `medsam_model` with explicit image_encoder / mask_decoder /
# prompt_encoder arguments if it is still needed.

In [None]:
from tiny_vit_sam import TinyViT

# NOTE(review): this cell rebinds `medsam_lite_image_encoder`,
# `medsam_lite_prompt_encoder` and `medsam_lite_mask_decoder`, replacing any
# objects that earlier cells bound to those names.

# TinyViT image encoder; the `##` per-stage notes are the original author's.
medsam_lite_image_encoder = TinyViT(
    img_size=256,
    in_chans=3,
    embed_dims=[
        64, ## (64, 256, 256)
        128, ## (128, 128, 128)
        160, ## (160, 64, 64)
        320 ## (320, 64, 64)
    ],
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 5, 10],
    window_sizes=[7, 7, 14, 7],
    mlp_ratio=4.,
    drop_rate=0.,
    drop_path_rate=0.0,
    use_checkpoint=False,
    mbconv_expand_ratio=4.0,
    local_conv_size=3,
    layer_lr_decay=0.8
)

medsam_lite_prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(256, 256),
    mask_in_chans=16
)

medsam_lite_mask_decoder = MaskDecoder(
    num_multimask_outputs=3,
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=256,
        mlp_dim=2048,
        num_heads=8,
    ),
    transformer_dim=256,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)

tinyvit_lite_model = MedSAM_Lite(
    image_encoder=medsam_lite_image_encoder,
    mask_decoder=medsam_lite_mask_decoder,
    prompt_encoder=medsam_lite_prompt_encoder
)

# Report the size of the model built in this cell. The count was previously
# taken from `medsam` (the ViT-based model), which did not match the label.
print(f"tinymedsam size:{sum(p.numel() for p in tinyvit_lite_model.parameters())}")

In [3]:
# Load one precomputed embedding from disk and drop its leading axis.
embedding = np.load('/mnt/embeddings_npy_1024/MR_BraTS_FLAIR_BraTS-GLI-00000-000-000.npy').squeeze(0)
print(embedding.shape)
# NOTE(review): per the recorded output the array is already (256, 64, 64)
# here, so this second squeeze(0) does not change the shape; `tensor_embedding`
# is also not used by the lines below.
tensor_embedding = torch.tensor(embedding).squeeze(0)
# FIXME(review): `out_put` is not defined anywhere in the visible notebook and
# the recorded run of this cell stopped on this line with NameError. Define the
# second tensor to stack before re-running.
embeddings = torch.stack([torch.tensor(embedding), out_put])
print(embeddings.shape)

(256, 64, 64)


NameError: name 'out_put' is not defined

In [24]:
from tiny_vit_sam import TinyViT

# Stand-alone TinyViT instance built with img_size=256 and in_chans=3.
tiny_vit = TinyViT(
    img_size=256,
    in_chans=3,
    # Original author's per-stage notes for embed_dims (unverified):
    # (64, 256, 256), (128, 128, 128), (160, 64, 64), (320, 64, 64)
    embed_dims=[64, 128, 160, 320],
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 5, 10],
    window_sizes=[7, 7, 14, 7],
    mlp_ratio=4.0,
    mbconv_expand_ratio=4.0,
    local_conv_size=3,
    layer_lr_decay=0.8,
    drop_rate=0.0,
    drop_path_rate=0.0,
    use_checkpoint=False,
)

# Put the network in evaluation mode, then draw a random 1x3x256x256 input.
tiny_vit.eval()
input_image = torch.randn((1, 3, 256, 256))

In [52]:
# Load one image array from disk and print its shape.
img = np.load(file='/cvpr-data/train_npy_256/imgs/PET_Lesion_PETCT_0beb67c923-019.npy')
print(np.shape(img))

(256, 256, 3)


In [11]:
# Demo: a transposed convolution with kernel 4 and stride 4 applied to a
# random 1x3x16x16 tensor; the resulting shape is printed.
input_tensor = torch.randn((1, 3, 16, 16))

upsample_layer = nn.Sequential(
    nn.ConvTranspose2d(3, 3, kernel_size=4, stride=4),
)

output_tensor = upsample_layer(input_tensor)
print(output_tensor.shape)

torch.Size([1, 3, 64, 64])
