Skip to content

Commit

Permalink
refactor v1.3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed Jan 17, 2023
1 parent 1e3e409 commit b6af5a4
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
File renamed without changes.
17 changes: 9 additions & 8 deletions hyperbox/mutables/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,18 +424,18 @@ def forward(self, optional_inputs: Union[list, dict]) -> Union[torch.Tensor, Tup
-------
tuple of torch.Tensor and torch.Tensor or torch.Tensor
"""
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), \
"Optional input list must be a list, not a {}.".format(type(optional_input_list))
assert len(optional_inputs) == self.n_candidates, \
"Length of the input list must be equal to number of candidates."
if self.is_search and hasattr(self, "mutator") and self.mutator._cache:
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), \
"Optional input list must be a list, not a {}.".format(type(optional_input_list))
assert len(optional_inputs) == self.n_candidates, \
"Length of the input list must be equal to number of candidates."
out, mask = self.mutator.on_forward_input_space(self, optional_input_list)
else:
mask = self.mask
out = self._select_with_mask(lambda x: x, [(t,) for t in optional_inputs], mask)
out = self._select_with_mask(lambda x: x, [(t,) for t in optional_input_list], mask)
out = self._tensor_reduction(self.reduction, out)
if self.return_mask:
return out, mask
Expand Down Expand Up @@ -496,6 +496,7 @@ def sortIdx(self, indices):
indices = torch.tensor(indices)
self._sortIdx = indices


if __name__ == '__main__':
# mask = {'test': torch.tensor([0.5,0.3,0.2,0.1])}
mask = {
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ seaborn
pudb
colorlog
graphviz>=0.20.1
peewee # for building nasbench-201
# peewee is for building nasbench-201
peewee
#numpy>=1.22.0
10 changes: 8 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from setuptools import setup, find_packages

required_modules = []
with open('requirements.txt') as f:
required = f.read().splitlines()
for module in required:
if not module.startswith('#'):
required_modules.append(module)

setup(
name="hyperbox", # you should change "src" to your project name
version="1.3.0",
version="1.3.1",
description="Hyperbox: An easy-to-use NAS framework.",
author="marsggbo",
url="https://github.com/marsggbo/hyperbox",
# replace with your own github project link
install_requires=["pytorch-lightning>=1.5", "hydra-core>=1.2"],
install_requires=required_modules,
packages=find_packages(),
include_package_data=True,
)

0 comments on commit b6af5a4

Please sign in to comment.