From 40241c5b6d174bfe86b6cce3fe47c98521c062a8 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 20 Aug 2023 12:04:50 +0800 Subject: [PATCH] fix weight loading for Swin --- vision_toolbox/backbones/swin.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py index 6de5beb..5b64d1f 100644 --- a/vision_toolbox/backbones/swin.py +++ b/vision_toolbox/backbones/swin.py @@ -240,7 +240,14 @@ def rearrange(p): block.mha.relative_pe_index, state_dict.pop(prefix + "attn.relative_position_index") ) copy_(block.norm1, prefix + "norm1") - copy_(block.mha.in_proj, prefix + "attn.qkv") + q_w, k_w, v_w = state_dict.pop(prefix + "attn.qkv.weight").chunk(3, 0) + block.mha.q_proj.weight.copy_(q_w) + block.mha.k_proj.weight.copy_(k_w) + block.mha.v_proj.weight.copy_(v_w) + q_b, k_b, v_b = state_dict.pop(prefix + "attn.qkv.bias").chunk(3, 0) + block.mha.q_proj.bias.copy_(q_b) + block.mha.k_proj.bias.copy_(k_b) + block.mha.v_proj.bias.copy_(v_b) copy_(block.mha.out_proj, prefix + "attn.proj") block.mha.relative_pe_table.copy_(state_dict.pop(prefix + "attn.relative_position_bias_table").T) copy_(block.norm2, prefix + "norm2")