Commit

remove pe for cls token
gau-nernst committed Aug 20, 2023
1 parent c99f7eb commit 8971381
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions vision_toolbox/backbones/vit.py
@@ -112,11 +112,7 @@ def __init__(
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None
-
-        pe_size = (img_size // patch_size) ** 2
-        if cls_token:
-            pe_size += 1
-        self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
+        self.pe = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2, d_model))
         nn.init.normal_(self.pe, 0, 0.02)
 
         self.layers = nn.Sequential()
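Note: after this hunk the positional embedding covers exactly the patch grid, independent of whether a class token is used. A quick sanity check of the size arithmetic (the 224/16 configuration below is only an illustrative assumption, not a value fixed by this file):

    img_size, patch_size = 224, 16  # assumed example configuration
    num_patches = (img_size // patch_size) ** 2
    print(num_patches)  # 196 -> self.pe has shape (1, 196, d_model), with or without a cls token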
@@ -127,24 +123,19 @@ def __init__(
         self.norm = norm(d_model)
 
     def forward(self, imgs: Tensor) -> Tensor:
-        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe  # (N, C, H, W) -> (N, H*W, C)
         if self.cls_token is not None:
             out = torch.cat([self.cls_token, out], 1)
-        out = self.layers(out + self.pe)
+        out = self.layers(out)
         return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1)
 
     @torch.no_grad()
     def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
-        pe = self.pe if self.cls_token is None else self.pe[:, 1:]
-
-        old_size = int(pe.shape[1] ** 0.5)
+        old_size = int(self.pe.shape[1] ** 0.5)
         new_size = size // self.patch_embed.weight.shape[2]
-        pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
+        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
         pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
         pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
-
-        if self.cls_token is not None:
-            pe = torch.cat((self.pe[:, :1], pe), 1)
         self.pe = nn.Parameter(pe)
 
     @staticmethod
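For reference, the interpolation in resize_pe can be read as the following standalone sketch (assumed helper name and test shapes, not code from this repo); it only has to handle the bare patch grid now that the cls entry is no longer stored in self.pe:

    import torch
    import torch.nn.functional as F

    def resize_patch_pe(pe: torch.Tensor, new_grid: int, mode: str = "bicubic") -> torch.Tensor:
        # pe: (1, H*W, C) positional embedding for a square patch grid, no cls entry
        old_grid = int(pe.shape[1] ** 0.5)
        pe = pe.unflatten(1, (old_grid, old_grid)).permute(0, 3, 1, 2)  # (1, C, H, W)
        pe = F.interpolate(pe, (new_grid, new_grid), mode=mode)         # resample the grid
        return pe.permute(0, 2, 3, 1).flatten(1, 2)                     # back to (1, H'*W', C)

    print(resize_patch_pe(torch.randn(1, 196, 768), 24).shape)  # torch.Size([1, 576, 768])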
@@ -186,9 +177,11 @@ def get_w(key: str) -> Tensor:
             return torch.from_numpy(jax_weights[key])
 
         self.cls_token.copy_(get_w("cls"))
+        pe = get_w("Transformer/posembed_input/pos_embedding")
+        self.cls_token.add_(pe[:, 0])
+        self.pe.copy_(pe[:, 1:])
         self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
         self.patch_embed.bias.copy_(get_w("embedding/bias"))
-        self.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))
 
         for idx, layer in enumerate(self.layers):
             layer: ViTBlock
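The weight-loading change is what keeps pretrained JAX checkpoints (which store a PE entry for the cls token) compatible: since the cls token and its positional embedding are both constants added before the first block, folding pe[:, 0] into cls_token does not change the network's output. A minimal sketch of that equivalence, using toy shapes and random tensors rather than repo code:

    import torch

    d_model, n_patches, batch = 768, 196, 2           # assumed toy sizes
    cls_token = torch.randn(1, 1, d_model)            # learned class token
    pe = torch.randn(1, n_patches + 1, d_model)       # pretrained PE with a cls entry at index 0
    patches = torch.randn(batch, n_patches, d_model)  # dummy patch embeddings

    # old layout: prepend the cls token, then add a PE that also covers the cls slot
    old = torch.cat([cls_token.expand(batch, -1, -1), patches], 1) + pe

    # new layout: fold pe[:, 0] into the cls token and keep only the patch-grid PE
    cls_folded = cls_token + pe[:, 0]
    new = torch.cat([cls_folded.expand(batch, -1, -1), patches + pe[:, 1:]], 1)

    print(torch.allclose(old, new))  # True -> the two parameterizations match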
