
loading the trained checkpoint throws error #2

Closed
pearlmary opened this issue Apr 21, 2023 · 1 comment

@pearlmary

Hi,

**If I load the trained last.ckpt to look at the output results, it throws an error while loading.** Could you please explain how to test the checkpoint?

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "last.ckpt"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

Results:


RuntimeError Traceback (most recent call last)
Input In [14], in <cell line: 10>()
6 model_type = "vit_h"
8 device = "cuda"
---> 10 sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
11 sam.to(device=device)
13 mask_generator = SamAutomaticMaskGenerator(sam)

File /workspace/segment-anything-finetuner/segment-anything/segment_anything/build_sam.py:15, in build_sam_vit_h(checkpoint)
14 def build_sam_vit_h(checkpoint=None):
---> 15 return _build_sam(
16 encoder_embed_dim=1280,
17 encoder_depth=32,
18 encoder_num_heads=16,
19 encoder_global_attn_indexes=[7, 15, 23, 31],
20 checkpoint=checkpoint,
21 )

File /workspace/segment-anything-finetuner/segment-anything/segment_anything/build_sam.py:106, in _build_sam(encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint)
104 with open(checkpoint, "rb") as f:
105 state_dict = torch.load(f)
--> 106 sam.load_state_dict(state_dict)
107 return sam

File /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:2041, in Module.load_state_dict(self, state_dict, strict)
2036 error_msgs.insert(
2037 0, 'Missing key(s) in state_dict: {}. '.format(
2038 ', '.join('"{}"'.format(k) for k in missing_keys)))
2040 if len(error_msgs) > 0:
-> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2042 self.__class__.__name__, "\n\t".join(error_msgs)))
2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Sam:
Missing key(s) in state_dict: "image_encoder.pos_embed", "image_encoder.patch_embed.proj.weight", "image_encoder.patch_embed.proj.bias", "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encoder.blocks.0.attn.qkv.bias", "image_encoder.blocks.0.attn.proj.weight", "image_encoder.blocks.0.attn.proj.bias", "image_encoder.blocks.0.norm2.weight", "image_encoder.blocks.0.norm2.bias", "image_encoder.blocks.0.mlp.lin1.weight", "image_encoder.blocks.0.mlp.lin1.bias", "image_encoder.blocks.0.mlp.lin2.weight", "image_encoder.blocks.0.mlp.lin2.bias", "image_encoder.blocks.1.norm1.weight", "image_encoder.blocks.1.norm1.bias", "image_encoder.blocks.1.attn.rel_pos_h", "image_encoder.blocks.1.attn.rel_pos_w", "image_encoder.blocks.1.attn.qkv.weight", "image_encoder.blocks.1.attn.qkv.bias", "image_encoder.blocks.1.attn.proj.weight", "image_encoder.blocks.1.attn.proj.bias", "image_encoder.blocks.1.norm2.weight", "image_encoder.blocks.1.norm2.bias", "image_encoder.blocks.1.mlp.lin1.weight", "image_encoder.blocks.1.mlp.lin1.bias", "image_encoder.blocks.1.mlp.lin2.weight", "image_encoder.blocks.1.mlp.lin2.bias", "image_encoder.blocks.2.norm1.weight", "image_encoder.blocks.2.norm1.bias", "image_encoder.blocks.2.attn.rel_pos_h", "image_encoder.blocks.2.attn.rel_pos_w", "image_encoder.blocks.2.attn.qkv.weight", "image_encoder.blocks.2.attn.qkv.bias", "image_encoder.blocks.2.attn.proj.weight", "image_encoder.blocks.2.attn.proj.bias", "image_encoder.blocks.2.norm2.weight", "image_encoder.blocks.2.norm2.bias", "image_encoder.blocks.2.mlp.lin1.weight", "image_encoder.blocks.2.mlp.lin1.bias", "image_encoder.blocks.2.mlp.lin2.weight", "image_encoder.blocks.2.mlp.lin2.bias", "image_encoder.blocks.3.norm1.weight", "image_encoder.blocks.3.norm1.bias", "image_encoder.blocks.3.attn.rel_pos_h", "image_encoder.blocks.3.attn.rel_pos_w", 
"image_encoder.blocks.3.attn.qkv.weight", "image_encoder.blocks.3.attn.qkv.bias", "image_encoder.blocks.3.attn.proj.weight", "image_encoder.blocks.3.attn.proj.bias", "image_encoder.blocks.3.norm2.weight", "image_encoder.blocks.3.norm2.bias", "image_encoder.blocks.3.mlp.lin1.weight", "image_encoder.blocks.3.mlp.lin1.bias", "image_encoder.blocks.3.mlp.lin2.weight", "image_encoder.blocks.3.mlp.lin2.bias", "image_encoder.blocks.4.norm1.weight", "image_encoder.blocks.4.norm1.bias", "image_encoder.blocks.4.attn.rel_pos_h", "image_encoder.blocks.4.attn.rel_pos_w", "image_encoder.blocks.4.attn.qkv.weight", "image_encoder.blocks.4.attn.qkv.bias", "image_encoder.blocks.4.attn.proj.weight", "image_encoder.blocks.4.attn.proj.bias", "image_encoder.blocks.4.norm2.weight", "image_encoder.blocks.4.norm2.bias", "image_encoder.blocks.4.mlp.lin1.weight", "image_encoder.blocks.4.mlp.lin1.bias", "image_encoder.blocks.4.mlp.lin2.weight", "image_encoder.blocks.4.mlp.lin2.bias", "image_encoder.blocks.5.norm1.weight", "image_encoder.blocks.5.norm1.bias", "image_encoder.blocks.5.attn.rel_pos_h", "image_encoder.blocks.5.attn.rel_pos_w", "image_encoder.blocks.5.attn.qkv.weight", "image_encoder.blocks.5.attn.qkv.bias", "image_encoder.blocks.5.attn.proj.weight", "image_encoder.blocks.5.attn.proj.bias", "image_encoder.blocks.5.norm2.weight", "image_encoder.blocks.5.norm2.bias", "image_encoder.blocks.5.mlp.lin1.weight", "image_encoder.blocks.5.mlp.lin1.bias", "image_encoder.blocks.5.mlp.lin2.weight", "image_encoder.blocks.5.mlp.lin2.bias", "image_encoder.blocks.6.norm1.weight", "image_encoder.blocks.6.norm1.bias", "image_encoder.blocks.6.attn.rel_pos_h", "image_encoder.blocks.6.attn.rel_pos_w", "image_encoder.blocks.6.attn.qkv.weight", "image_encoder.blocks.6.attn.qkv.bias", "image_encoder.blocks.6.attn.proj.weight", "image_encoder.blocks.6.attn.proj.bias", "image_encoder.blocks.6.norm2.weight", "image_encoder.blocks.6.norm2.bias", "image_encoder.blocks.6.mlp.lin1.weight", 
"image_encoder.blocks.6.mlp.lin1.bias", "image_encoder.blocks.6.mlp.lin2.weight", "image_encoder.blocks.6.mlp.lin2.bias", "image_encoder.blocks.7.norm1.weight", "image_encoder.blocks.7.norm1.bias", "image_encoder.blocks.7.attn.rel_pos_h", "image_encoder.blocks.7.attn.rel_pos_w", "image_encoder.blocks.7.attn.qkv.weight", "image_encoder.blocks.7.attn.qkv.bias", "image_encoder.blocks.7.attn.proj.weight", "image_encoder.blocks.7.attn.proj.bias", "image_encoder.blocks.7.norm2.weight", "image_encoder.blocks.7.norm2.bias", "image_encoder.blocks.7.mlp.lin1.weight", "image_encoder.blocks.7.mlp.lin1.bias", "image_encoder.blocks.7.mlp.lin2.weight", "image_encoder.blocks.7.mlp.lin2.bias", "image_encoder.blocks.8.norm1.weight", "image_encoder.blocks.8.norm1.bias", "image_encoder.blocks.8.attn.rel_pos_h", "image_encoder.blocks.8.attn.rel_pos_w", "image_encoder.blocks.8.attn.qkv.weight", "image_encoder.blocks.8.attn.qkv.bias", "image_encoder.blocks.8.attn.proj.weight", "image_encoder.blocks.8.attn.proj.bias", "image_encoder.blocks.8.norm2.weight", "image_encoder.blocks.8.norm2.bias", "image_encoder.blocks.8.mlp.lin1.weight", "image_encoder.blocks.8.mlp.lin1.bias", "image_encoder.blocks.8.mlp.lin2.weight", "image_encoder.blocks.8.mlp.lin2.bias", "image_encoder.blocks.9.norm1.weight", "image_encoder.blocks.9.norm1.bias", "image_encoder.blocks.9.attn.rel_pos_h", "image_encoder.blocks.9.attn.rel_pos_w", "image_encoder.blocks.9.attn.qkv.weight", "image_encoder.blocks.9.attn.qkv.bias", "image_encoder.blocks.9.attn.proj.weight", "image_encoder.blocks.9.attn.proj.bias", "image_encoder.blocks.9.norm2.weight", "image_encoder.blocks.9.norm2.bias", "image_encoder.blocks.9.mlp.lin1.weight", "image_encoder.blocks.9.mlp.lin1.bias", "image_encoder.blocks.9.mlp.lin2.weight", "image_encoder.blocks.9.mlp.lin2.bias", "image_encoder.blocks.10.norm1.weight", "image_encoder.blocks.10.norm1.bias", "image_encoder.blocks.10.attn.rel_pos_h", "image_encoder.blocks.10.attn.rel_pos_w", 
"image_encoder.blocks.10.attn.qkv.weight", "image_encoder.blocks.10.attn.qkv.bias", "image_encoder.blocks.10.attn.proj.weight", "image_encoder.blocks.10.attn.proj.bias", "image_encoder.blocks.10.norm2.weight", "image_encoder.blocks.10.norm2.bias", "image_encoder.blocks.10.mlp.lin1.weight", "image_encoder.blocks.10.mlp.lin1.bias", "image_encoder.blocks.10.mlp.lin2.weight", "image_encoder.blocks.10.mlp.lin2.bias", "image_encoder.blocks.11.norm1.weight", "image_encoder.blocks.11.norm1.bias", "image_encoder.blocks.11.attn.rel_pos_h", "image_encoder.blocks.11.attn.rel_pos_w", "image_encoder.blocks.11.attn.qkv.weight", "image_encoder.blocks.11.attn.qkv.bias", "image_encoder.blocks.11.attn.proj.weight", "image_encoder.blocks.11.attn.proj.bias", "image_encoder.blocks.11.norm2.weight", "image_encoder.blocks.11.norm2.bias", "image_encoder.blocks.11.mlp.lin1.weight", "image_encoder.blocks.11.mlp.lin1.bias", "image_encoder.blocks.11.mlp.lin2.weight", "image_encoder.blocks.11.mlp.lin2.bias", "image_encoder.blocks.12.norm1.weight", "image_encoder.blocks.12.norm1.bias", "image_encoder.blocks.12.attn.rel_pos_h", "image_encoder.blocks.12.attn.rel_pos_w", "image_encoder.blocks.12.attn.qkv.weight", "image_encoder.blocks.12.attn.qkv.bias", "image_encoder.blocks.12.attn.proj.weight", "image_encoder.blocks.12.attn.proj.bias", "image_encoder.blocks.12.norm2.weight", "image_encoder.blocks.12.norm2.bias", "image_encoder.blocks.12.mlp.lin1.weight", "image_encoder.blocks.12.mlp.lin1.bias", "image_encoder.blocks.12.mlp.lin2.weight", "image_encoder.blocks.12.mlp.lin2.bias", "image_encoder.blocks.13.norm1.weight", "image_encoder.blocks.13.norm1.bias", "image_encoder.blocks.13.attn.rel_pos_h", "image_encoder.blocks.13.attn.rel_pos_w", "image_encoder.blocks.13.attn.qkv.weight", "image_encoder.blocks.13.attn.qkv.bias", "image_encoder.blocks.13.attn.proj.weight", "image_encoder.blocks.13.attn.proj.bias", "image_encoder.blocks.13.norm2.weight", "image_encoder.blocks.13.norm2.bias", 
"image_encoder.blocks.13.mlp.lin1.weight", "image_encoder.blocks.13.mlp.lin1.bias", "image_encoder.blocks.13.mlp.lin2.weight", "image_encoder.blocks.13.mlp.lin2.bias", "image_encoder.blocks.14.norm1.weight", "image_encoder.blocks.14.norm1.bias", "image_encoder.blocks.14.attn.rel_pos_h", "image_encoder.blocks.14.attn.rel_pos_w", "image_encoder.blocks.14.attn.qkv.weight", "image_encoder.blocks.14.attn.qkv.bias", "image_encoder.blocks.14.attn.proj.weight", "image_encoder.blocks.14.attn.proj.bias", "image_encoder.blocks.14.norm2.weight", "image_encoder.blocks.14.norm2.bias", "image_encoder.blocks.14.mlp.lin1.weight", "image_encoder.blocks.14.mlp.lin1.bias", "image_encoder.blocks.14.mlp.lin2.weight", "image_encoder.blocks.14.mlp.lin2.bias", "image_encoder.blocks.15.norm1.weight", "image_encoder.blocks.15.norm1.bias", "image_encoder.blocks.15.attn.rel_pos_h", "image_encoder.blocks.15.attn.rel_pos_w", "image_encoder.blocks.15.attn.qkv.weight", "image_encoder.blocks.15.attn.qkv.bias", "image_encoder.blocks.15.attn.proj.weight", "image_encoder.blocks.15.attn.proj.bias", "image_encoder.blocks.15.norm2.weight", "image_encoder.blocks.15.norm2.bias", "image_encoder.blocks.15.mlp.lin1.weight", "image_encoder.blocks.15.mlp.lin1.bias", "image_encoder.blocks.15.mlp.lin2.weight", "image_encoder.blocks.15.mlp.lin2.bias", "image_encoder.blocks.16.norm1.weight", "image_encoder.blocks.16.norm1.bias", "image_encoder.blocks.16.attn.rel_pos_h", "image_encoder.blocks.16.attn.rel_pos_w", "image_encoder.blocks.16.attn.qkv.weight", "image_encoder.blocks.16.attn.qkv.bias", "image_encoder.blocks.16.attn.proj.weight", "image_encoder.blocks.16.attn.proj.bias", "image_encoder.blocks.16.norm2.weight", "image_encoder.blocks.16.norm2.bias", "image_encoder.blocks.16.mlp.lin1.weight", "image_encoder.blocks.16.mlp.lin1.bias", "image_encoder.blocks.16.mlp.lin2.weight", "image_encoder.blocks.16.mlp.lin2.bias", "image_encoder.blocks.17.norm1.weight", "image_encoder.blocks.17.norm1.bias", 
"image_encoder.blocks.17.attn.rel_pos_h", "image_encoder.blocks.17.attn.rel_pos_w", "image_encoder.blocks.17.attn.qkv.weight", "image_encoder.blocks.17.attn.qkv.bias", "image_encoder.blocks.17.attn.proj.weight", "image_encoder.blocks.17.attn.proj.bias", "image_encoder.blocks.17.norm2.weight", "image_encoder.blocks.17.norm2.bias", "image_encoder.blocks.17.mlp.lin1.weight", "image_encoder.blocks.17.mlp.lin1.bias", "image_encoder.blocks.17.mlp.lin2.weight", "image_encoder.blocks.17.mlp.lin2.bias", "image_encoder.blocks.18.norm1.weight", "image_encoder.blocks.18.norm1.bias", "image_encoder.blocks.18.attn.rel_pos_h", "image_encoder.blocks.18.attn.rel_pos_w", "image_encoder.blocks.18.attn.qkv.weight", "image_encoder.blocks.18.attn.qkv.bias", "image_encoder.blocks.18.attn.proj.weight", "image_encoder.blocks.18.attn.proj.bias", "image_encoder.blocks.18.norm2.weight", "image_encoder.blocks.18.norm2.bias", "image_encoder.blocks.18.mlp.lin1.weight", "image_encoder.blocks.18.mlp.lin1.bias", "image_encoder.blocks.18.mlp.lin2.weight", "image_encoder.blocks.18.mlp.lin2.bias", "image_encoder.blocks.19.norm1.weight", "image_encoder.blocks.19.norm1.bias", "image_encoder.blocks.19.attn.rel_pos_h", "image_encoder.blocks.19.attn.rel_pos_w", "image_encoder.blocks.19.attn.qkv.weight", "image_encoder.blocks.19.attn.qkv.bias", "image_encoder.blocks.19.attn.proj.weight", "image_encoder.blocks.19.attn.proj.bias", "image_encoder.blocks.19.norm2.weight", "image_encoder.blocks.19.norm2.bias", "image_encoder.blocks.19.mlp.lin1.weight", "image_encoder.blocks.19.mlp.lin1.bias", "image_encoder.blocks.19.mlp.lin2.weight", "image_encoder.blocks.19.mlp.lin2.bias", "image_encoder.blocks.20.norm1.weight", "image_encoder.blocks.20.norm1.bias", "image_encoder.blocks.20.attn.rel_pos_h", "image_encoder.blocks.20.attn.rel_pos_w", "image_encoder.blocks.20.attn.qkv.weight", "image_encoder.blocks.20.attn.qkv.bias", "image_encoder.blocks.20.attn.proj.weight", "image_encoder.blocks.20.attn.proj.bias", 
"image_encoder.blocks.20.norm2.weight", "image_encoder.blocks.20.norm2.bias", "image_encoder.blocks.20.mlp.lin1.weight", "image_encoder.blocks.20.mlp.lin1.bias", "image_encoder.blocks.20.mlp.lin2.weight", "image_encoder.blocks.20.mlp.lin2.bias", "image_encoder.blocks.21.norm1.weight", "image_encoder.blocks.21.norm1.bias", "image_encoder.blocks.21.attn.rel_pos_h", "image_encoder.blocks.21.attn.rel_pos_w", "image_encoder.blocks.21.attn.qkv.weight", "image_encoder.blocks.21.attn.qkv.bias", "image_encoder.blocks.21.attn.proj.weight", "image_encoder.blocks.21.attn.proj.bias", "image_encoder.blocks.21.norm2.weight", "image_encoder.blocks.21.norm2.bias", "image_encoder.blocks.21.mlp.lin1.weight", "image_encoder.blocks.21.mlp.lin1.bias", "image_encoder.blocks.21.mlp.lin2.weight", "image_encoder.blocks.21.mlp.lin2.bias", "image_encoder.blocks.22.norm1.weight", "image_encoder.blocks.22.norm1.bias", "image_encoder.blocks.22.attn.rel_pos_h", "image_encoder.blocks.22.attn.rel_pos_w", "image_encoder.blocks.22.attn.qkv.weight", "image_encoder.blocks.22.attn.qkv.bias", "image_encoder.blocks.22.attn.proj.weight", "image_encoder.blocks.22.attn.proj.bias", "image_encoder.blocks.22.norm2.weight", "image_encoder.blocks.22.norm2.bias", "image_encoder.blocks.22.mlp.lin1.weight", "image_encoder.blocks.22.mlp.lin1.bias", "image_encoder.blocks.22.mlp.lin2.weight", "image_encoder.blocks.22.mlp.lin2.bias", "image_encoder.blocks.23.norm1.weight", "image_encoder.blocks.23.norm1.bias", "image_encoder.blocks.23.attn.rel_pos_h", "image_encoder.blocks.23.attn.rel_pos_w", "image_encoder.blocks.23.attn.qkv.weight", "image_encoder.blocks.23.attn.qkv.bias", "image_encoder.blocks.23.attn.proj.weight", "image_encoder.blocks.23.attn.proj.bias", "image_encoder.blocks.23.norm2.weight", "image_encoder.blocks.23.norm2.bias", "image_encoder.blocks.23.mlp.lin1.weight", "image_encoder.blocks.23.mlp.lin1.bias", "image_encoder.blocks.23.mlp.lin2.weight", "image_encoder.blocks.23.mlp.lin2.bias", 
"image_encoder.blocks.24.norm1.weight", "image_encoder.blocks.24.norm1.bias", "image_encoder.blocks.24.attn.rel_pos_h", "image_encoder.blocks.24.attn.rel_pos_w", "image_encoder.blocks.24.attn.qkv.weight", "image_encoder.blocks.24.attn.qkv.bias", "image_encoder.blocks.24.attn.proj.weight", "image_encoder.blocks.24.attn.proj.bias", "image_encoder.blocks.24.norm2.weight", "image_encoder.blocks.24.norm2.bias", "image_encoder.blocks.24.mlp.lin1.weight", "image_encoder.blocks.24.mlp.lin1.bias", "image_encoder.blocks.24.mlp.lin2.weight", "image_encoder.blocks.24.mlp.lin2.bias", "image_encoder.blocks.25.norm1.weight", "image_encoder.blocks.25.norm1.bias", "image_encoder.blocks.25.attn.rel_pos_h", "image_encoder.blocks.25.attn.rel_pos_w", "image_encoder.blocks.25.attn.qkv.weight", "image_encoder.blocks.25.attn.qkv.bias", "image_encoder.blocks.25.attn.proj.weight", "image_encoder.blocks.25.attn.proj.bias", "image_encoder.blocks.25.norm2.weight", "image_encoder.blocks.25.norm2.bias", "image_encoder.blocks.25.mlp.lin1.weight", "image_encoder.blocks.25.mlp.lin1.bias", "image_encoder.blocks.25.mlp.lin2.weight", "image_encoder.blocks.25.mlp.lin2.bias", "image_encoder.blocks.26.norm1.weight", "image_encoder.blocks.26.norm1.bias", "image_encoder.blocks.26.attn.rel_pos_h", "image_encoder.blocks.26.attn.rel_pos_w", "image_encoder.blocks.26.attn.qkv.weight", "image_encoder.blocks.26.attn.qkv.bias", "image_encoder.blocks.26.attn.proj.weight", "image_encoder.blocks.26.attn.proj.bias", "image_encoder.blocks.26.norm2.weight", "image_encoder.blocks.26.norm2.bias", "image_encoder.blocks.26.mlp.lin1.weight", "image_encoder.blocks.26.mlp.lin1.bias", "image_encoder.blocks.26.mlp.lin2.weight", "image_encoder.blocks.26.mlp.lin2.bias", "image_encoder.blocks.27.norm1.weight", "image_encoder.blocks.27.norm1.bias", "image_encoder.blocks.27.attn.rel_pos_h", "image_encoder.blocks.27.attn.rel_pos_w", "image_encoder.blocks.27.attn.qkv.weight", "image_encoder.blocks.27.attn.qkv.bias", 
"image_encoder.blocks.27.attn.proj.weight", "image_encoder.blocks.27.attn.proj.bias", "image_encoder.blocks.27.norm2.weight", "image_encoder.blocks.27.norm2.bias", "image_encoder.blocks.27.mlp.lin1.weight", "image_encoder.blocks.27.mlp.lin1.bias", "image_encoder.blocks.27.mlp.lin2.weight", "image_encoder.blocks.27.mlp.lin2.bias", "image_encoder.blocks.28.norm1.weight", "image_encoder.blocks.28.norm1.bias", "image_encoder.blocks.28.attn.rel_pos_h", "image_encoder.blocks.28.attn.rel_pos_w", "image_encoder.blocks.28.attn.qkv.weight", "image_encoder.blocks.28.attn.qkv.bias", "image_encoder.blocks.28.attn.proj.weight", "image_encoder.blocks.28.attn.proj.bias", "image_encoder.blocks.28.norm2.weight", "image_encoder.blocks.28.norm2.bias", "image_encoder.blocks.28.mlp.lin1.weight", "image_encoder.blocks.28.mlp.lin1.bias", "image_encoder.blocks.28.mlp.lin2.weight", "image_encoder.blocks.28.mlp.lin2.bias", "image_encoder.blocks.29.norm1.weight", "image_encoder.blocks.29.norm1.bias", "image_encoder.blocks.29.attn.rel_pos_h", "image_encoder.blocks.29.attn.rel_pos_w", "image_encoder.blocks.29.attn.qkv.weight", "image_encoder.blocks.29.attn.qkv.bias", "image_encoder.blocks.29.attn.proj.weight", "image_encoder.blocks.29.attn.proj.bias", "image_encoder.blocks.29.norm2.weight", "image_encoder.blocks.29.norm2.bias", "image_encoder.blocks.29.mlp.lin1.weight", "image_encoder.blocks.29.mlp.lin1.bias", "image_encoder.blocks.29.mlp.lin2.weight", "image_encoder.blocks.29.mlp.lin2.bias", "image_encoder.blocks.30.norm1.weight", "image_encoder.blocks.30.norm1.bias", "image_encoder.blocks.30.attn.rel_pos_h", "image_encoder.blocks.30.attn.rel_pos_w", "image_encoder.blocks.30.attn.qkv.weight", "image_encoder.blocks.30.attn.qkv.bias", "image_encoder.blocks.30.attn.proj.weight", "image_encoder.blocks.30.attn.proj.bias", "image_encoder.blocks.30.norm2.weight", "image_encoder.blocks.30.norm2.bias", "image_encoder.blocks.30.mlp.lin1.weight", "image_encoder.blocks.30.mlp.lin1.bias", 
"image_encoder.blocks.30.mlp.lin2.weight", "image_encoder.blocks.30.mlp.lin2.bias", "image_encoder.blocks.31.norm1.weight", "image_encoder.blocks.31.norm1.bias", "image_encoder.blocks.31.attn.rel_pos_h", "image_encoder.blocks.31.attn.rel_pos_w", "image_encoder.blocks.31.attn.qkv.weight", "image_encoder.blocks.31.attn.qkv.bias", "image_encoder.blocks.31.attn.proj.weight", "image_encoder.blocks.31.attn.proj.bias", "image_encoder.blocks.31.norm2.weight", "image_encoder.blocks.31.norm2.bias", "image_encoder.blocks.31.mlp.lin1.weight", "image_encoder.blocks.31.mlp.lin1.bias", "image_encoder.blocks.31.mlp.lin2.weight", "image_encoder.blocks.31.mlp.lin2.bias", "image_encoder.neck.0.weight", "image_encoder.neck.1.weight", "image_encoder.neck.1.bias", "image_encoder.neck.2.weight", "image_encoder.neck.3.weight", "image_encoder.neck.3.bias", "prompt_encoder.pe_layer.positional_encoding_gaussian_matrix", "prompt_encoder.point_embeddings.0.weight", "prompt_encoder.point_embeddings.1.weight", "prompt_encoder.point_embeddings.2.weight", "prompt_encoder.point_embeddings.3.weight", "prompt_encoder.not_a_point_embed.weight", "prompt_encoder.mask_downscaling.0.weight", "prompt_encoder.mask_downscaling.0.bias", "prompt_encoder.mask_downscaling.1.weight", "prompt_encoder.mask_downscaling.1.bias", "prompt_encoder.mask_downscaling.3.weight", "prompt_encoder.mask_downscaling.3.bias", "prompt_encoder.mask_downscaling.4.weight", "prompt_encoder.mask_downscaling.4.bias", "prompt_encoder.mask_downscaling.6.weight", "prompt_encoder.mask_downscaling.6.bias", "prompt_encoder.no_mask_embed.weight", "mask_decoder.transformer.layers.0.self_attn.q_proj.weight", "mask_decoder.transformer.layers.0.self_attn.q_proj.bias", "mask_decoder.transformer.layers.0.self_attn.k_proj.weight", "mask_decoder.transformer.layers.0.self_attn.k_proj.bias", "mask_decoder.transformer.layers.0.self_attn.v_proj.weight", "mask_decoder.transformer.layers.0.self_attn.v_proj.bias", 
"mask_decoder.transformer.layers.0.self_attn.out_proj.weight", "mask_decoder.transformer.layers.0.self_attn.out_proj.bias", "mask_decoder.transformer.layers.0.norm1.weight", "mask_decoder.transformer.layers.0.norm1.bias", "mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.weight", "mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.bias", "mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.weight", "mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.bias", "mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.weight", "mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.bias", "mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.weight", "mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.bias", "mask_decoder.transformer.layers.0.norm2.weight", "mask_decoder.transformer.layers.0.norm2.bias", "mask_decoder.transformer.layers.0.mlp.lin1.weight", "mask_decoder.transformer.layers.0.mlp.lin1.bias", "mask_decoder.transformer.layers.0.mlp.lin2.weight", "mask_decoder.transformer.layers.0.mlp.lin2.bias", "mask_decoder.transformer.layers.0.norm3.weight", "mask_decoder.transformer.layers.0.norm3.bias", "mask_decoder.transformer.layers.0.norm4.weight", "mask_decoder.transformer.layers.0.norm4.bias", "mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.weight", "mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.bias", "mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.weight", "mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.bias", "mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.weight", "mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.bias", "mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.weight", "mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.bias", 
"mask_decoder.transformer.layers.1.self_attn.q_proj.weight", "mask_decoder.transformer.layers.1.self_attn.q_proj.bias", "mask_decoder.transformer.layers.1.self_attn.k_proj.weight", "mask_decoder.transformer.layers.1.self_attn.k_proj.bias", "mask_decoder.transformer.layers.1.self_attn.v_proj.weight", "mask_decoder.transformer.layers.1.self_attn.v_proj.bias", "mask_decoder.transformer.layers.1.self_attn.out_proj.weight", "mask_decoder.transformer.layers.1.self_attn.out_proj.bias", "mask_decoder.transformer.layers.1.norm1.weight", "mask_decoder.transformer.layers.1.norm1.bias", "mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.weight", "mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.bias", "mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.weight", "mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.bias", "mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.weight", "mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.bias", "mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.weight", "mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.bias", "mask_decoder.transformer.layers.1.norm2.weight", "mask_decoder.transformer.layers.1.norm2.bias", "mask_decoder.transformer.layers.1.mlp.lin1.weight", "mask_decoder.transformer.layers.1.mlp.lin1.bias", "mask_decoder.transformer.layers.1.mlp.lin2.weight", "mask_decoder.transformer.layers.1.mlp.lin2.bias", "mask_decoder.transformer.layers.1.norm3.weight", "mask_decoder.transformer.layers.1.norm3.bias", "mask_decoder.transformer.layers.1.norm4.weight", "mask_decoder.transformer.layers.1.norm4.bias", "mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.weight", "mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.bias", "mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.weight", "mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.bias", 
"mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.weight", "mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.bias", "mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.weight", "mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.bias", "mask_decoder.transformer.final_attn_token_to_image.q_proj.weight", "mask_decoder.transformer.final_attn_token_to_image.q_proj.bias", "mask_decoder.transformer.final_attn_token_to_image.k_proj.weight", "mask_decoder.transformer.final_attn_token_to_image.k_proj.bias", "mask_decoder.transformer.final_attn_token_to_image.v_proj.weight", "mask_decoder.transformer.final_attn_token_to_image.v_proj.bias", "mask_decoder.transformer.final_attn_token_to_image.out_proj.weight", "mask_decoder.transformer.final_attn_token_to_image.out_proj.bias", "mask_decoder.transformer.norm_final_attn.weight", "mask_decoder.transformer.norm_final_attn.bias", "mask_decoder.iou_token.weight", "mask_decoder.mask_tokens.weight", "mask_decoder.output_upscaling.0.weight", "mask_decoder.output_upscaling.0.bias", "mask_decoder.output_upscaling.1.weight", "mask_decoder.output_upscaling.1.bias", "mask_decoder.output_upscaling.3.weight", "mask_decoder.output_upscaling.3.bias", "mask_decoder.output_hypernetworks_mlps.0.layers.0.weight", "mask_decoder.output_hypernetworks_mlps.0.layers.0.bias", "mask_decoder.output_hypernetworks_mlps.0.layers.1.weight", "mask_decoder.output_hypernetworks_mlps.0.layers.1.bias", "mask_decoder.output_hypernetworks_mlps.0.layers.2.weight", "mask_decoder.output_hypernetworks_mlps.0.layers.2.bias", "mask_decoder.output_hypernetworks_mlps.1.layers.0.weight", "mask_decoder.output_hypernetworks_mlps.1.layers.0.bias", "mask_decoder.output_hypernetworks_mlps.1.layers.1.weight", "mask_decoder.output_hypernetworks_mlps.1.layers.1.bias", "mask_decoder.output_hypernetworks_mlps.1.layers.2.weight", "mask_decoder.output_hypernetworks_mlps.1.layers.2.bias", 
"mask_decoder.output_hypernetworks_mlps.2.layers.0.weight", "mask_decoder.output_hypernetworks_mlps.2.layers.0.bias", "mask_decoder.output_hypernetworks_mlps.2.layers.1.weight", "mask_decoder.output_hypernetworks_mlps.2.layers.1.bias", "mask_decoder.output_hypernetworks_mlps.2.layers.2.weight", "mask_decoder.output_hypernetworks_mlps.2.layers.2.bias", "mask_decoder.output_hypernetworks_mlps.3.layers.0.weight", "mask_decoder.output_hypernetworks_mlps.3.layers.0.bias", "mask_decoder.output_hypernetworks_mlps.3.layers.1.weight", "mask_decoder.output_hypernetworks_mlps.3.layers.1.bias", "mask_decoder.output_hypernetworks_mlps.3.layers.2.weight", "mask_decoder.output_hypernetworks_mlps.3.layers.2.bias", "mask_decoder.iou_prediction_head.layers.0.weight", "mask_decoder.iou_prediction_head.layers.0.bias", "mask_decoder.iou_prediction_head.layers.1.weight", "mask_decoder.iou_prediction_head.layers.1.bias", "mask_decoder.iou_prediction_head.layers.2.weight", "mask_decoder.iou_prediction_head.layers.2.bias".
Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "loops".

@bhpfelix
Owner

Hi,

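Due to the SAMFinetuner wrapper around SAM, each key in the state_dict has an extra 'model.' prefix. You can remove it by, for example, looping through each key in the state_dict and replacing it with new_key = key.replace('model.', '', 1). Note that last.ckpt is a PyTorch Lightning checkpoint, so the weights themselves sit under its "state_dict" entry (which is why the error lists "epoch", "global_step", "state_dict", etc. as unexpected keys); take that entry first. After the prefix is removed, you can load the checkpoint as usual.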
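A minimal sketch of that suggestion, assuming (from the unexpected keys in the error) that the weights live under the checkpoint's "state_dict" entry and (from the comment above) that every SAM key starts with 'model.'. The helper names `lightning_to_sam_state_dict` and `load_finetuned_sam` are hypothetical; they are not part of segment-anything or the finetuner repo. Only `torch.load` and `sam_model_registry[...]()` (whose builders default to `checkpoint=None`, as the traceback shows) come from the real libraries.

```python
from collections import OrderedDict


def lightning_to_sam_state_dict(checkpoint, prefix="model."):
    """Return a plain SAM state dict from a loaded Lightning checkpoint dict.

    Assumes the weights live under checkpoint["state_dict"] and that every
    SAM key starts with `prefix` (added by the SAMFinetuner wrapper).
    """
    cleaned = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        # Assumption: keys without the prefix belong to the wrapper rather
        # than to SAM, so skip them to keep strict loading happy.
        if not key.startswith(prefix):
            continue
        # Drop only the leading prefix, like key.replace('model.', '', 1).
        cleaned[key[len(prefix):]] = value
    return cleaned


def load_finetuned_sam(ckpt_path, model_type="vit_h", device="cuda"):
    """Build an uninitialised SAM model and load the fine-tuned weights."""
    # Heavy imports are deferred so the helper above works without them.
    import torch
    from segment_anything import sam_model_registry

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    sam = sam_model_registry[model_type]()  # checkpoint=None: no weights yet
    sam.load_state_dict(lightning_to_sam_state_dict(checkpoint))
    return sam.to(device=device)
```

With this, the snippet from the question becomes `sam = load_finetuned_sam("last.ckpt")` followed by `SamAutomaticMaskGenerator(sam)`. Alternatively, `torch.save` the cleaned dict to a `.pth` file once and pass that path as `checkpoint=` to `sam_model_registry[model_type]`, which leaves the original loading code unchanged.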
