Skip to content

Commit

Permalink
v1.4.3 refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed Apr 9, 2023
1 parent 85bb7d8 commit 2b613f7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion hyperbox/mutables/ops/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def __init__(
padding_mode: str,
device=None,
dtype=None,
auto_padding: bool = False,
transposed: bool = False,
output_padding: Tuple[int, ...] = 0,
auto_padding: bool = False,
**kwargs
):
'''Base Conv Module
Expand Down
31 changes: 20 additions & 11 deletions hyperbox/networks/vit/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self.heads_list = keepPositiveList([int(r*heads) for r in search_ratio])
self.dim_head_list = keepPositiveList([int(r*dim_head) for r in search_ratio])
self.scale_list = [dh**-0.5 for dh in self.dim_head_list]

self.inner_dim_list = []
self.heads_idx_map = {}
self.dim_head_idx_map = {}
Expand Down Expand Up @@ -196,6 +196,15 @@ def forward(self, x):
return x


class ResidualBlock(nn.Module):
def __init__(self, block):
super().__init__()
self.block = block

def forward(self, x):
return self.block(x) + x


class VitBlock(nn.Module):
def __init__(
self,
Expand All @@ -214,16 +223,16 @@ def __init__(
self.suffix = suffix
attKey = f"att_{suffix}"
ffKey = f"ff_{suffix}"
self.attn = PreNorm(dim, Attention(
self.attn = ResidualBlock(PreNorm(dim, Attention(
dim, heads = heads, dim_head = dim_head, search_ratio=self.search_ratio,
dropout = dropout, suffix = attKey, mask = self.mask))
self.ff = PreNorm(dim, FeedForward(
dropout = dropout, suffix = attKey, mask = self.mask)))
self.ff = ResidualBlock(PreNorm(dim, FeedForward(
dim, mlp_dim, search_ratio=self.search_ratio, dropout = dropout,
suffix = ffKey, mask = self.mask))
suffix = ffKey, mask = self.mask)))

def forward(self, x):
x = self.attn(x) + x
x = self.ff(x) + x
x = self.attn(x)
x = self.ff(x)
return x


Expand Down Expand Up @@ -276,10 +285,6 @@ def __init__(

self.vit_embed = VitEmbedding(image_size, patch_size, channels, dim, emb_dropout)

if self.to_search_path:
runtime_depth = [v for v in range(1, depth + 1)]
self.run_depth = spaces.ValueSpace(runtime_depth, key='run_depth', mask=self.mask)

vit_blocks = [
VitBlock(
dim=dim, heads=heads, dim_head=dim_head, mlp_dim=mlp_dim,
Expand All @@ -289,6 +294,10 @@ def __init__(

self.vit_cls_head = VitClsHead(pool, dim, num_classes)

if self.to_search_path:
runtime_depth = [v for v in range(1, depth + 1)]
self.run_depth = spaces.ValueSpace(runtime_depth, key='run_depth', mask=self.mask)

def forward(self, x):
out = self.vit_embed(x)
if self.to_search_path:
Expand Down

0 comments on commit 2b613f7

Please sign in to comment.