Skip to content

Commit

Permalink
Minor fix to loading models with incompatible layers (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jun 26, 2024
1 parent 8e9750c commit 5e16964
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,11 @@ def _handle_checkpoint_loading(sam, model_state):
reference_state = sam.state_dict()

for k, v in model_state.items():
if reference_state[k].size() == v.size():
new_state_dict[k] = v
else:
mismatched_layers.append(k)
if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM.
if reference_state[k].size() == v.size():
new_state_dict[k] = v
else:
mismatched_layers.append(k)

reference_state.update(new_state_dict)

Expand Down

0 comments on commit 5e16964

Please sign in to comment.