Skip to content

Commit

Permalink
✨ configs hyper-parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
vpj committed Apr 21, 2020
1 parent bb87c0c commit 4c3b1e8
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 17 deletions.
36 changes: 32 additions & 4 deletions lab/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

_CALCULATORS = '_calculators'
_EVALUATORS = '_evaluators'
_HYPERPARAMS = '_hyperparams'
_CONFIG_PRINT_LEN = 40


Expand Down Expand Up @@ -119,6 +120,14 @@ def wrapper(func: Callable):
def list(cls, name: str = None):
return cls.calc(name, f"_{util.random_string()}", is_append=True)

@classmethod
def set_hyperparams(cls, *args: ConfigItem, is_hyperparam=True):
if _HYPERPARAMS not in cls.__dict__:
cls._hyperparams = {}

for h in args:
cls._hyperparams[h.key] = is_hyperparam


class ConfigProcessor:
def __init__(self, configs, values: Dict[str, any] = None):
Expand Down Expand Up @@ -182,7 +191,9 @@ def save(self, configs_path: PurePath):
'value': self.__to_yaml(self.parser.values.get(k, None)),
'order': orders.get(k, -1),
'options': list(self.parser.options.get(k, {}).keys()),
'computed': self.__to_yaml(getattr(self.calculator.configs, k, None))
'computed': self.__to_yaml(getattr(self.calculator.configs, k, None)),
'is_hyperparam': self.parser.hyperparams.get(k, None),
'is_explicitly_specified': (k in self.parser.explicitly_specified)
}

with open(str(configs_path), "w") as file:
Expand All @@ -196,16 +207,33 @@ def __default_repr(value):
hex(id(value))
)

@staticmethod
def __print_config(key, *, value=None, option=None,
def get_hyperparams(self):
order = self.calculator.topological_order.copy()

hyperparams = {}
for key in order:
if (self.parser.hyperparams.get(key, False) or
key in self.parser.explicitly_specified):
value = getattr(self.calculator.configs, key, None)
if key in self.parser.options:
value = self.parser.values[key]

hyperparams[key] = value

return hyperparams

def __print_config(self, key, *, value=None, option=None,
other_options=None, is_ignored=False, is_list=False):
parts = ['\t']

if is_ignored:
parts.append((key, Text.subtle))
return parts

parts.append((key, Text.key))
if self.parser.hyperparams.get(key, False) or key in self.parser.explicitly_specified:
parts.append((key, [Text.key, Text.highlight]))
else:
parts.append((key, Text.key))

if is_list:
parts.append(('[]', Text.subtle))
Expand Down
14 changes: 13 additions & 1 deletion lab/configs/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

_CALCULATORS = '_calculators'
_EVALUATORS = '_evaluators'
_HYPERPARAMS = '_hyperparams'


def _get_base_classes(class_: Type['Configs']) -> List[Type['Configs']]:
Expand Down Expand Up @@ -40,7 +41,7 @@ def _get_base_classes(class_: Type['Configs']) -> List[Type['Configs']]:
return unique_classes


RESERVED = {'calc', 'list'}
RESERVED = {'calc', 'list', 'set_hyperparams'}
_STANDARD_TYPES = {int, str, bool, Dict, List}


Expand All @@ -51,6 +52,8 @@ class Parser:
types: Dict[str, Type]
values: Dict[str, any]
list_appends: Dict[str, List[ConfigFunction]]
explicitly_specified: Set[str]
hyperparams: Dict[str, bool]

def __init__(self, configs: 'Configs', values: Dict[str, any] = None):
classes = _get_base_classes(type(configs))
Expand All @@ -62,6 +65,8 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None):
self.list_appends = {}
self.config_items = {}
self.configs = configs
self.explicitly_specified = set()
self.hyperparams = {}

for c in classes:
# for k, v in c.__annotations__.items():
Expand All @@ -86,6 +91,11 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None):
for v in evals:
self.__collect_evaluator(k, v)

for c in classes:
if _HYPERPARAMS in c.__dict__:
for k, is_hyperparam in c.__dict__[_HYPERPARAMS].items():
self.hyperparams[k] = is_hyperparam

for k, v in configs.__dict__.items():
assert k in self.types
self.__collect_value(k, v)
Expand Down Expand Up @@ -126,6 +136,8 @@ def __collect_value(self, k, v):
if not self.is_valid(k):
return

self.explicitly_specified.add(k)

self.values[k] = v
if k not in self.types:
self.types[k] = type(v)
Expand Down
6 changes: 2 additions & 4 deletions lab/experiment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,6 @@ def calc_configs(self,
self.configs_processor = ConfigProcessor(configs, configs_dict)
self.configs_processor(run_order)

# May be we should write these in `start`
if configs_dict:
logger.internal().write_h_parameters(configs_dict)

logger.new_line()

def __start_from_checkpoint(self, run_uuid: str, checkpoint: Optional[int]):
Expand Down Expand Up @@ -210,3 +206,5 @@ def start(self, *,

logger.internal().save_indicators(self.run.indicators_path)
logger.internal().save_artifacts(self.run.artifacts_path)
logger.internal().write_h_parameters(self.configs_processor.get_hyperparams())

11 changes: 3 additions & 8 deletions samples/hyperparameter_tunining.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ def loop_step(c: Configs):
return len(c.train_loader)


def run(run_name: str, hparams: dict):
def run(hparams: dict):
logger.set_global_step(0)

conf = Configs()
experiment = Experiment(name=run_name, writers={'sqlite', 'tensorboard'})
experiment = Experiment(name='mnist_hyperparam_tuning', writers={'sqlite', 'tensorboard'})
experiment.calc_configs(conf,
hparams,
['set_seed', 'main'])
Expand All @@ -191,19 +191,14 @@ def run(run_name: str, hparams: dict):


def main():
session_num = 1
for conv1_kernal in [3, 5]:
for conv2_kernal in [3, 5]:
hparams = {
'conv1_kernal': conv1_kernal,
'conv2_kernal': conv2_kernal,
}

run_name = "mnist_run-%d" % session_num

run(run_name, hparams)

session_num += 1
run(hparams)


if __name__ == '__main__':
Expand Down

0 comments on commit 4c3b1e8

Please sign in to comment.