Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 20, 2023
1 parent 5d5a8f2 commit 057709f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
9 changes: 2 additions & 7 deletions vision_toolbox/backbones/cait.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch import Tensor, nn

from .base import _act, _norm
from .vit import MHA, ViTBlock
from .vit import MHA, ViT, ViTBlock


# basically attention pooling
Expand Down Expand Up @@ -152,12 +152,7 @@ def forward(self, imgs: Tensor) -> Tensor:

@torch.no_grad()
def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
old_size = int(self.pe.shape[1] ** 0.5)
new_size = size // self.patch_embed.weight.shape[2]
pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
self.pe = nn.Parameter(pe)
ViT.resize_pe(self, size, interpolation_mode)

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
Expand Down
6 changes: 3 additions & 3 deletions vision_toolbox/backbones/deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def __init__(
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))

def forward(self, imgs: Tensor) -> Tensor:
out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe # (N, C, H, W) -> (N, H*W, C)
out = torch.cat([self.cls_token, self.dist_token, out], 1)
out = self.patch_embed(imgs).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
out = torch.cat([self.cls_token, self.dist_token, out + self.pe], 1)
out = self.layers(out)
return self.norm(out[:, :2]).mean(1)

@staticmethod
def from_config(variant: str, img_size: int, version: bool = False, pretrained: bool = False) -> DeiT:
def from_config(variant: str, img_size: int, pretrained: bool = False) -> DeiT:
variant, patch_size = variant.split("_")

d_model, depth, n_heads = dict(
Expand Down

0 comments on commit 057709f

Please sign in to comment.