Commit

add weight loading. fix layer norm eps
gau-nernst committed Aug 20, 2023
1 parent adda4ff commit 59d487c
Showing 1 changed file with 55 additions and 8 deletions.
vision_toolbox/backbones/cait.py: 63 changes (55 additions & 8 deletions)
@@ -62,7 +62,7 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = nn.LayerNorm,
norm: _norm = partial(nn.LayerNorm, eps=1e-6),
act: _act = nn.GELU,
) -> None:
# fmt: off
@@ -89,7 +89,7 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = nn.LayerNorm,
norm: _norm = partial(nn.LayerNorm, eps=1e-6),
act: _act = nn.GELU,
) -> None:
# fmt: off
@@ -115,7 +115,7 @@ def __init__(
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
- norm: _norm = nn.LayerNorm,
norm: _norm = partial(nn.LayerNorm, eps=1e-6),
act: _act = nn.GELU,
) -> None:
assert img_size % patch_size == 0
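
The three hunks above make the same fix in each constructor: the default norm factory now bakes in eps=1e-6, presumably to match the official CaiT/DeiT code, instead of inheriting nn.LayerNorm's default of 1e-5. A minimal sketch of the factory pattern; the embedding dimension is an assumption for illustration, not taken from this diff:

from functools import partial

import torch.nn as nn

norm_factory = partial(nn.LayerNorm, eps=1e-6)   # the new default for the `norm` argument
norm = norm_factory(192)                         # 192 is an illustrative embed dim
assert norm.eps == 1e-6                          # nn.LayerNorm alone would default to 1e-5
assert norm.normalized_shape == (192,)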
@@ -143,14 +143,12 @@ def __init__(

def forward(self, imgs: Tensor) -> Tensor:
patches = self.patch_embed(imgs).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
- print(patches.shape)
- print(self.pe.shape)
patches = self.sa_layers(patches + self.pe)

cls_token = self.cls_token
for block in self.ca_layers:
cls_token = block(patches, cls_token)
- return self.norm(cls_token)
return self.norm(cls_token.squeeze(1))

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
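
In the forward hunk, the two debug prints are dropped and the class token is squeezed before the final norm: after the class-attention blocks it has shape (N, 1, C), and squeeze(1) removes the singleton token dimension so callers receive a plain (N, C) feature. A small shape sketch; batch size and embed dim are illustrative assumptions:

import torch

cls_token = torch.randn(2, 1, 192)   # assumed N=2, C=192; shape after the class-attention blocks
pooled = cls_token.squeeze(1)        # (2, 192), which is what self.norm(...) now receives
assert pooled.shape == (2, 192)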
@@ -177,11 +175,60 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
m_48_448="M48_448.pth",
)[f"{variant}_{sa_depth}_{img_size}"]
base_url = "https://dl.fbaipublicfiles.com/deit/"
- state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)
state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
m.load_official_ckpt(state_dict)

return m
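
A note on the two lines added before m.load_official_ckpt(state_dict): the downloaded DeiT release file nests the weights under a "model" key, and some checkpoints carry a "module." prefix on parameter names (typical of DataParallel-style saving), so both are stripped before loading. A toy illustration with a stand-in dict, not the real checkpoint:

import torch

raw = {"model": {"module.cls_token": torch.zeros(1, 1, 192)}}   # stand-in for the downloaded file
state_dict = {k.replace("module.", ""): v for k, v in raw["model"].items()}
assert "cls_token" in state_dict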

@torch.no_grad()
def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
- raise NotImplementedError()
def copy_(m: nn.Linear | nn.LayerNorm, prefix: str):
m.weight.copy_(state_dict.pop(prefix + ".weight").view(m.weight.shape))
m.bias.copy_(state_dict.pop(prefix + ".bias"))

copy_(self.patch_embed, "patch_embed.proj")
self.cls_token.copy_(state_dict.pop("cls_token"))
self.pe.copy_(state_dict.pop("pos_embed"))

for i, sa_block in enumerate(self.sa_layers):
sa_block: CaiTSABlock
prefix = f"blocks.{i}."

copy_(sa_block.mha[0], prefix + "norm1")
q_w, k_w, v_w = state_dict.pop(prefix + "attn.qkv.weight").chunk(3, 0)
sa_block.mha[1].q_proj.weight.copy_(q_w)
sa_block.mha[1].k_proj.weight.copy_(k_w)
sa_block.mha[1].v_proj.weight.copy_(v_w)
q_b, k_b, v_b = state_dict.pop(prefix + "attn.qkv.bias").chunk(3, 0)
sa_block.mha[1].q_proj.bias.copy_(q_b)
sa_block.mha[1].k_proj.bias.copy_(k_b)
sa_block.mha[1].v_proj.bias.copy_(v_b)
copy_(sa_block.mha[1].out_proj, prefix + "attn.proj")
copy_(sa_block.mha[1].talking_head_proj[0], prefix + "attn.proj_l")
copy_(sa_block.mha[1].talking_head_proj[2], prefix + "attn.proj_w")
sa_block.mha[2].gamma.copy_(state_dict.pop(prefix + "gamma_1"))

copy_(sa_block.mlp[0], prefix + "norm2")
copy_(sa_block.mlp[1].linear1, prefix + "mlp.fc1")
copy_(sa_block.mlp[1].linear2, prefix + "mlp.fc2")
sa_block.mlp[2].gamma.copy_(state_dict.pop(prefix + "gamma_2"))

for i, ca_block in enumerate(self.ca_layers):
ca_block: CaiTCABlock
prefix = f"blocks_token_only.{i}."

copy_(ca_block.mha[0], prefix + "norm1")
copy_(ca_block.mha[1].q_proj, prefix + "attn.q")
copy_(ca_block.mha[1].k_proj, prefix + "attn.k")
copy_(ca_block.mha[1].v_proj, prefix + "attn.v")
copy_(ca_block.mha[1].out_proj, prefix + "attn.proj")
ca_block.mha[2].gamma.copy_(state_dict.pop(prefix + "gamma_1"))

copy_(ca_block.mlp[0], prefix + "norm2")
copy_(ca_block.mlp[1].linear1, prefix + "mlp.fc1")
copy_(ca_block.mlp[1].linear2, prefix + "mlp.fc2")
ca_block.mlp[2].gamma.copy_(state_dict.pop(prefix + "gamma_2"))

copy_(self.norm, "norm")
assert len(state_dict) == 2
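
Two remarks on the loader above. First, the official checkpoint stores each self-attention block's query/key/value projection as a single fused qkv tensor, while this backbone keeps separate q_proj/k_proj/v_proj modules, so the fused weight and bias are split with chunk(3, 0). Second, the closing assert expects exactly two unused entries to remain, presumably the classification head's weight and bias, which this backbone does not use. A standalone sketch of the qkv split, with an assumed embedding dimension:

import torch
import torch.nn as nn

embed_dim = 192                                       # illustrative, not taken from this diff
qkv_weight = torch.randn(3 * embed_dim, embed_dim)    # fused (q, k, v) stacked along dim 0
qkv_bias = torch.randn(3 * embed_dim)

q_proj, k_proj, v_proj = (nn.Linear(embed_dim, embed_dim) for _ in range(3))
with torch.no_grad():
    for proj, w, b in zip((q_proj, k_proj, v_proj), qkv_weight.chunk(3, 0), qkv_bias.chunk(3, 0)):
        proj.weight.copy_(w)   # each weight chunk is (embed_dim, embed_dim)
        proj.bias.copy_(b)     # each bias chunk is (embed_dim,)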

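For reference, a hedged usage sketch of the new pretrained path. The variant string below is hypothetical: it is not shown in this diff and must resolve, together with the depth and image size, to a checkpoint key such as "m_48_448" above.

import torch

from vision_toolbox.backbones.cait import CaiT

model = CaiT.from_config("m", img_size=448, pretrained=True)   # hypothetical variant; would fetch M48_448.pth from the DeiT release
with torch.no_grad():
    feats = model(torch.randn(1, 3, 448, 448))                 # (1, embed_dim), thanks to the squeeze(1) change above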