Skip to content

Commit

Permalink
refactor vit and nasbench201
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed Apr 26, 2023
1 parent e84c39a commit 453d8bb
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion hyperbox/networks/nasbench201/nasbench201.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def __init__(
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, self.num_classes)
self.init_weights()

def forward(self, inputs):
out = self.stem(inputs)
Expand Down Expand Up @@ -479,5 +478,6 @@ def forward(self, inputs):
print(net.arch_size((2,3,64,64)))
arch_json = net.arch
acc = net.query_by_key()
print(acc)
for t in query_nb201_trial_stats(arch_json, 200, 'cifar10'):
pprint.pprint(t)
2 changes: 1 addition & 1 deletion hyperbox/networks/vit/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def forward(self, x):

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.position_embeddings[:, :(n + 1)]
x = x + self.position_embeddings
x = self.dropout(x)
return x

Expand Down

0 comments on commit 453d8bb

Please sign in to comment.