Skip to content

Commit

Permalink
fix weight loading for Swin
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 20, 2023
1 parent a363553 commit 40241c5
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion vision_toolbox/backbones/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,14 @@ def rearrange(p):
block.mha.relative_pe_index, state_dict.pop(prefix + "attn.relative_position_index")
)
copy_(block.norm1, prefix + "norm1")
copy_(block.mha.in_proj, prefix + "attn.qkv")
q_w, k_w, v_w = state_dict.pop(prefix + "attn.qkv.weight").chunk(3, 0)
block.mha.q_proj.weight.copy_(q_w)
block.mha.k_proj.weight.copy_(k_w)
block.mha.v_proj.weight.copy_(v_w)
q_b, k_b, v_b = state_dict.pop(prefix + "attn.qkv.bias").chunk(3, 0)
block.mha.q_proj.bias.copy_(q_b)
block.mha.k_proj.bias.copy_(k_b)
block.mha.v_proj.bias.copy_(v_b)
copy_(block.mha.out_proj, prefix + "attn.proj")
block.mha.relative_pe_table.copy_(state_dict.pop(prefix + "attn.relative_position_bias_table").T)
copy_(block.norm2, prefix + "norm2")
Expand Down

0 comments on commit 40241c5

Please sign in to comment.