Skip to content

Commit

Permalink
integrate ground truth (json file) of NASBenchMB
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed May 13, 2022
1 parent 1d585d8 commit d0b147c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion hyperbox/networks/nasbench_mbnet/network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch.nn as nn
import torch
import inspect
from collections import OrderedDict

from hyperbox.mutables.spaces import OperationSpace
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
assert len(arch_list) == sum(stages)
self.arch_list = arch_list
self.arch_info = None
self.QUERY_FILE_PATH = inspect.getfile(self.__class__).replace('network.py', 'nasbench_mbnet_cifar10.json')

self.stem = nn.Sequential(
nn.Conv2d(3, init_channels, 3, padding=1, bias=False),
Expand Down Expand Up @@ -143,9 +145,11 @@ def arch(self):
arch_list.append(block.mask.cpu().detach().numpy().argmax())
return ''.join([str(x) for x in arch_list])

def query_by_key(self, query_file_path, key='mean_acc', arch=None):
def query_by_key(self, query_file_path=None, key='mean_acc', arch=None):
if arch is None:
arch = self.arch
if query_file_path is None:
query_file_path = self.QUERY_FILE_PATH
if self.arch_info is None:
import json
with open(query_file_path, 'r') as f:
Expand Down

0 comments on commit d0b147c

Please sign in to comment.