diff --git a/micro_sam/util.py b/micro_sam/util.py index d96ec0ec..75ebe724 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -409,10 +409,11 @@ def _handle_checkpoint_loading(sam, model_state): reference_state = sam.state_dict() for k, v in model_state.items(): - if reference_state[k].size() == v.size(): - new_state_dict[k] = v - else: - mismatched_layers.append(k) + if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM. + if reference_state[k].size() == v.size(): + new_state_dict[k] = v + else: + mismatched_layers.append(k) reference_state.update(new_state_dict)