Skip to content

Commit

Permalink
refactor vit_pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed Feb 22, 2023
1 parent ef65dc8 commit 59b4935
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions hyperbox/networks/vit/vit_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,32 +259,37 @@ def __init__(
dim_head: int = 64, # dimension of each attention head
dropout: float = 0., # dropout rate
emb_dropout: float = 0., # embedding dropout rate
to_search_depth: bool = False,
mask: dict = None, # mask for the search space (mutables)
):
super().__init__(mask = mask)
self.to_search_path = to_search_depth

self.vit_embed = VitEmbedding(image_size, patch_size, channels, dim, emb_dropout)

runtime_depth = [v for v in range(1, depth + 1)]
self.run_depth = spaces.ValueSpace(runtime_depth, key='run_depth', mask=self.mask)
if self.to_search_path:
runtime_depth = [v for v in range(1, depth + 1)]
self.run_depth = spaces.ValueSpace(runtime_depth, key='run_depth', mask=self.mask)

vit_blocks = [
VitBlock(
dim=dim, heads=heads, dim_head=dim_head, mlp_dim=mlp_dim,
search_ratio=search_ratio, dropout=dropout, suffix=i, mask=self.mask
) for i in range(depth)]
self.vit_blocks = nn.Sequential(*vit_blocks)

self.vit_cls_head = VitClsHead(pool, dim, num_classes)
layers = [self.vit_embed] + vit_blocks + [self.vit_cls_head]
self.layers = nn.Sequential(*layers)

def forward(self, x):
runtime_depth = self.run_depth.value
layers = list(self.layers.children())
layers = layers[:runtime_depth+1] + [layers[-1]]
layers = nn.Sequential(*layers)
return layers(x)
if self.to_search_path:
layers = list(self.layers.children())
runtime_depth = self.run_depth.value
layers = layers[:runtime_depth+1] + [layers[-1]]
layers = nn.Sequential(*layers)
return layers(x)
else:
return self.layers(x)


_vit_s = dict(
Expand Down

0 comments on commit 59b4935

Please sign in to comment.