In [1]:
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from modeling.efficient_3dsam.efficient_3dsam_encoder import ImageEncoderViT_3d
from functools import partial
import torch
from modeling.efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits

In [2]:
# Build the EfficientSAM model; only the state dict of its image encoder is
# used further down, after which the model object is dropped.
efficient_sam = build_efficient_sam_vitt()

# Construct the 3D image encoder that receives the EfficientSAM weights.
img_encoder = ImageEncoderViT_3d(
    img_size=1024,
    patch_size=16,
    num_slice=16,
    embed_dim=192,
    depth=12,
    num_heads=12,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    use_rel_pos=True,
    window_size=14,
    cubic_window_size=8,
    global_attn_indexes=[2, 5, 8, 11],
    out_chans=256,
)

print("img_encoder", img_encoder)

# Non-strict load: keys that do not match between the two encoders are
# skipped instead of raising, so unmatched weights keep their constructor
# initialisation.
# NOTE(review): the output recorded for this cell lists `blocks.*.mlp.lin1/lin2`
# (next to the adapter / rel_pos / slice_embed entries) among the missing keys,
# i.e. those MLP weights were not taken from the checkpoint -- confirm that
# this is intended.
img_encoder.load_state_dict(
    efficient_sam.image_encoder.state_dict(),
    strict=False,
)

# Drop the name bound to the source model now that its weights were copied.
del efficient_sam
# img_encoder.to(device)

# Freeze every encoder parameter first, then switch gradients back on only
# for the pieces selected below.
for param in img_encoder.parameters():
    param.requires_grad = False

# Re-enable gradients for the depth embedding and the slice embedding.
img_encoder.depth_embed.requires_grad = True
for param in img_encoder.slice_embed.parameters():
    param.requires_grad = True

for block in img_encoder.blocks:
    # Per block: both norm layers and the adapter are trainable.
    for submodule in (block.norm1, block.adapter, block.norm2):
        for param in submodule.parameters():
            param.requires_grad = True
    # Replace `rel_pos_d` with a trainable parameter initialised to the mean
    # of `rel_pos_h` and `rel_pos_w`.
    # `torch.nn.Parameter` is spelled out because the notebook imports only
    # `torch`; the bare name `nn` is undefined and raised NameError here.
    block.attn.rel_pos_d = torch.nn.Parameter(
        0.5 * (block.attn.rel_pos_h + block.attn.rel_pos_w),
        requires_grad=True,
    )

# Every parameter of every layer in the 3D neck is trainable.
for layer in img_encoder.neck_3d:
    for param in layer.parameters():
        param.requires_grad = True

# NOTE(review): `PromptEncoder`, `TwoWayTransformer` and `VIT_MLAHead` are not
# brought into scope by the notebook's import cell; import them from the
# project package before running this code.

# `device` was referenced here without ever being defined in the notebook
# (NameError). Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build four prompt encoders and collect their trainable parameters in one
# flat list.
prompt_encoder_list = []
parameter_list = []
for _ in range(4):
    prompt_encoder = PromptEncoder(
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=256,
            mlp_dim=2048,
            num_heads=8,
        )
    )
    prompt_encoder.to(device)
    prompt_encoder_list.append(prompt_encoder)
    # Only parameters that still require gradients are collected.
    parameter_list.extend(
        param for param in prompt_encoder.parameters() if param.requires_grad
    )

mask_decoder = VIT_MLAHead(img_size=96, num_classes=2)
mask_decoder.to(device)



img_encoder ImageEncoderViT_3d(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (slice_embed): Conv3d(192, 192, kernel_size=(1, 1, 16), stride=(1, 1, 16), groups=192)
  (blocks): ModuleList(
    (0-11): 12 x Block_3d(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention_3d(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (proj): Linear(in_features=192, out_features=192, bias=True)
      )
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=192, out_features=768, bias=True)
        (lin2): Linear(in_features=768, out_features=192, bias=True)
        (act): GELU(approximate='none')
      )
      (adapter): Adapter(
        (linear1): Linear(in_features=192, out_features=96, bias=True)
        (conv): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=96)
        (l

_IncompatibleKeys(missing_keys=['pos_embed_3d', 'depth_embed', 'slice_embed.weight', 'slice_embed.bias', 'blocks.0.attn_mask', 'blocks.0.attn.rel_pos_h', 'blocks.0.attn.rel_pos_w', 'blocks.0.attn.rel_pos_d', 'blocks.0.attn.lr', 'blocks.0.mlp.lin1.weight', 'blocks.0.mlp.lin1.bias', 'blocks.0.mlp.lin2.weight', 'blocks.0.mlp.lin2.bias', 'blocks.0.adapter.linear1.weight', 'blocks.0.adapter.linear1.bias', 'blocks.0.adapter.conv.weight', 'blocks.0.adapter.conv.bias', 'blocks.0.adapter.linear2.weight', 'blocks.0.adapter.linear2.bias', 'blocks.1.attn.rel_pos_h', 'blocks.1.attn.rel_pos_w', 'blocks.1.attn.rel_pos_d', 'blocks.1.attn.lr', 'blocks.1.mlp.lin1.weight', 'blocks.1.mlp.lin1.bias', 'blocks.1.mlp.lin2.weight', 'blocks.1.mlp.lin2.bias', 'blocks.1.adapter.linear1.weight', 'blocks.1.adapter.linear1.bias', 'blocks.1.adapter.conv.weight', 'blocks.1.adapter.conv.bias', 'blocks.1.adapter.linear2.weight', 'blocks.1.adapter.linear2.bias', 'blocks.2.attn_mask', 'blocks.2.attn.rel_pos_h', 'blocks.2.

In [None]:
# Instantiate the "vit_b" SAM model from the local checkpoint file and wrap
# it in an automatic mask generator.
sam = sam_model_registry["vit_b"](checkpoint="weights/sam_vit_b_01ec64.pth")
mask_generator = SamAutomaticMaskGenerator(sam)

# NOTE: this rebinds `img_encoder`, which an earlier cell also assigns; here
# the encoder is built with embed_dim=768 instead of 192.
img_encoder = ImageEncoderViT_3d(
    img_size=1024,
    patch_size=16,
    num_slice=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    use_rel_pos=True,
    window_size=14,
    cubic_window_size=8,
    global_attn_indexes=[2, 5, 8, 11],
    out_chans=256,
)

# Non-strict load of the SAM image-encoder weights. As the last expression of
# the cell, the returned report of incompatible keys is shown as cell output.
img_encoder.load_state_dict(
    mask_generator.predictor.model.image_encoder.state_dict(),
    strict=False,
)
