Skip to content

Commit

Permalink
remove patch_size argument for ViT
Browse files (browse the repository at this point in the history)
  • Loading branch information
gau-nernst committed Aug 20, 2023
1 parent 592d41b commit fc91fdb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions tests/test_vit.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,19 +5,19 @@


 def test_forward():
-    m = ViT.from_config("Ti", 16, 224)
+    m = ViT.from_config("Ti_16", 224)
     m(torch.randn(1, 3, 224, 224))


 def test_resize_pe():
-    m = ViT.from_config("Ti", 16, 224)
+    m = ViT.from_config("Ti_16", 224)
     m(torch.randn(1, 3, 224, 224))
     m.resize_pe(256)
     m(torch.randn(1, 3, 256, 256))


 def test_from_pretrained():
-    m = ViT.from_config("Ti", 16, 224, True).eval()
+    m = ViT.from_config("Ti_16", 224, True).eval()
     x = torch.randn(1, 3, 224, 224)
     out = m(x)

Expand Down
8 changes: 5 additions & 3 deletions vision_toolbox/backbones/vit.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -151,17 +151,21 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
self.pe = nn.Parameter(pe)

     @staticmethod
-    def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
+    def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
+        variant, patch_size = variant.split("_")
+
         d_model, depth, n_heads = dict(
             Ti=(192, 12, 3),
             S=(384, 12, 6),
             B=(768, 12, 12),
             L=(1024, 24, 16),
             H=(1280, 32, 16),
         )[variant]
+        patch_size = int(patch_size)
         m = ViT(d_model, depth, n_heads, patch_size, img_size)

         if pretrained:
+            assert img_size == 224
             ckpt = {
                 ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
                 ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
Expand All @@ -172,8 +176,6 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
             }[(variant, patch_size)]
             base_url = "https://storage.googleapis.com/vit_models/augreg/"
             m.load_jax_weights(torch_hub_download(base_url + ckpt))
-            if img_size != 224:
-                m.resize_pe(img_size)

         return m

Expand Down

0 comments on commit fc91fdb

Please sign in to comment.