Skip to content

Commit

Permalink
v1.4.4 refactor vit
Browse files (browse the repository at this point in the history)
  • Loading branch information
marsggbo committed Apr 11, 2023
1 parent 7d12050 commit 5be40c8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
41 changes: 29 additions & 12 deletions hyperbox/networks/vit/vit.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,6 +19,7 @@
'ViT',
'ViT_S',
'ViT_B',
'ViT_L',
'ViT_H',
'ViT_G',
'ViT_10B',
Expand Down Expand Up @@ -58,10 +59,10 @@ def __init__(
hidden_dim_list = keepPositiveList([int(r*hidden_dim) for r in search_ratio])
hidden_dim_list = spaces.ValueSpace(hidden_dim_list, key=f"{suffix}_hidden_dim", mask=self.mask) if len(hidden_dim_list) > 1 else hidden_dim_list[0]
self.net = nn.Sequential(
ops.Linear(dim, hidden_dim_list),
ops.Linear(dim, hidden_dim_list, bias=True),
nn.GELU(),
nn.Dropout(dropout),
ops.Linear(hidden_dim_list, dim),
ops.Linear(hidden_dim_list, dim, bias=True),
nn.Dropout(dropout)
)

Expand Down Expand Up @@ -108,10 +109,10 @@ def __init__(
qkv_dim_list = spaces.ValueSpace(qkv_dim_list, key=f"{suffix}_inner_dim", mask=self.mask) # coupled with self.inner_dim_list
else:
qkv_dim_list = self.inner_dim_list * 3
self.to_qkv = ops.Linear(dim, qkv_dim_list, bias = False)
self.to_qkv = ops.Linear(dim, qkv_dim_list, bias=True)

self.to_out = nn.Sequential(
ops.Linear(self.inner_dim_list, dim),
ops.Linear(self.inner_dim_list, dim, bias=True),
nn.Dropout(dropout)
)

Expand Down Expand Up @@ -177,21 +178,20 @@ def __init__(
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_embeddings = nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size)
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)

def forward(self, x):
x = self.to_patch_embedding(x)
b, n, _ = x.shape
x = self.patch_embeddings(x)
x = x.flatten(2)
x = x.transpose(-1, -2)
b, n = x.shape[:2]

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.position_embeddings[:, :(n + 1)]
x = self.dropout(x)
return x

Expand Down Expand Up @@ -344,6 +344,22 @@ def forward(self, x):
emb_dropout=0.1,
)

# Keyword-argument preset for the "L" variant; it is bound into the
# VisionTransformer constructor below to form `ViT_L`.
_vit_l = dict(
image_size=224,
patch_size=16,
num_classes=1000,
dim=1024,
depth=24,
heads=16,
dim_head=64,
mlp_dim=4096,
# Candidate ratios for the search space. Elsewhere in this file the ratios
# are multiplied by a hidden dimension and truncated to int to build the
# list of selectable sizes -- presumably the same happens for every
# searchable width; confirm against the sub-module constructors.
search_ratio=[0.5, 0.75, 1],
pool='cls',
channels=3,
dropout=0.1,
emb_dropout=0.1,
)

_vit_h = dict(
image_size=224,
patch_size=16,
Expand Down Expand Up @@ -395,6 +411,7 @@ def forward(self, x):
# Ready-made model constructors: each name is `VisionTransformer` with one of
# the preset config dicts pre-bound as keyword arguments via `partial`.
# `ViT` is bound to the same preset as `ViT_B`.
ViT = partial(VisionTransformer, **_vit_b)
ViT_S = partial(VisionTransformer, **_vit_s)
ViT_B = partial(VisionTransformer, **_vit_b)
ViT_L = partial(VisionTransformer, **_vit_l)
ViT_H = partial(VisionTransformer, **_vit_h)
ViT_G = partial(VisionTransformer, **_vit_g)
ViT_10B = partial(VisionTransformer, **_vit_10b)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,7 @@

setup(
name="hyperbox", # you should change "src" to your project name
version="1.4.3",
version="1.4.4",
description="Hyperbox: An easy-to-use NAS framework.",
author="marsggbo",
url="https://github.com/marsggbo/hyperbox",
Expand Down

0 comments on commit 5be40c8

Please sign in to comment.