diff --git a/lab/configs/__init__.py b/lab/configs/__init__.py index e8fbaa56e..91d5c2a9f 100644 --- a/lab/configs/__init__.py +++ b/lab/configs/__init__.py @@ -13,6 +13,7 @@ _CALCULATORS = '_calculators' _EVALUATORS = '_evaluators' +_HYPERPARAMS = '_hyperparams' _CONFIG_PRINT_LEN = 40 @@ -119,6 +120,14 @@ def wrapper(func: Callable): def list(cls, name: str = None): return cls.calc(name, f"_{util.random_string()}", is_append=True) + @classmethod + def set_hyperparams(cls, *args: ConfigItem, is_hyperparam=True): + if _HYPERPARAMS not in cls.__dict__: + cls._hyperparams = {} + + for h in args: + cls._hyperparams[h.key] = is_hyperparam + class ConfigProcessor: def __init__(self, configs, values: Dict[str, any] = None): @@ -182,7 +191,9 @@ def save(self, configs_path: PurePath): 'value': self.__to_yaml(self.parser.values.get(k, None)), 'order': orders.get(k, -1), 'options': list(self.parser.options.get(k, {}).keys()), - 'computed': self.__to_yaml(getattr(self.calculator.configs, k, None)) + 'computed': self.__to_yaml(getattr(self.calculator.configs, k, None)), + 'is_hyperparam': self.parser.hyperparams.get(k, None), + 'is_explicitly_specified': (k in self.parser.explicitly_specified) } with open(str(configs_path), "w") as file: @@ -196,8 +207,22 @@ def __default_repr(value): hex(id(value)) ) - @staticmethod - def __print_config(key, *, value=None, option=None, + def get_hyperparams(self): + order = self.calculator.topological_order.copy() + + hyperparams = {} + for key in order: + if (self.parser.hyperparams.get(key, False) or + key in self.parser.explicitly_specified): + value = getattr(self.calculator.configs, key, None) + if key in self.parser.options: + value = self.parser.values[key] + + hyperparams[key] = value + + return hyperparams + + def __print_config(self, key, *, value=None, option=None, other_options=None, is_ignored=False, is_list=False): parts = ['\t'] @@ -205,7 +230,10 @@ def __print_config(key, *, value=None, option=None, parts.append((key, Text.subtle)) return parts - parts.append((key, Text.key)) + if self.parser.hyperparams.get(key, False) or key in self.parser.explicitly_specified: + parts.append((key, [Text.key, Text.highlight])) + else: + parts.append((key, Text.key)) if is_list: parts.append(('[]', Text.subtle)) diff --git a/lab/configs/parser.py b/lab/configs/parser.py index 434c6e98b..6ed80c803 100644 --- a/lab/configs/parser.py +++ b/lab/configs/parser.py @@ -11,6 +11,7 @@ _CALCULATORS = '_calculators' _EVALUATORS = '_evaluators' +_HYPERPARAMS = '_hyperparams' def _get_base_classes(class_: Type['Configs']) -> List[Type['Configs']]: @@ -40,7 +41,7 @@ def _get_base_classes(class_: Type['Configs']) -> List[Type['Configs']]: return unique_classes -RESERVED = {'calc', 'list'} +RESERVED = {'calc', 'list', 'set_hyperparams'} _STANDARD_TYPES = {int, str, bool, Dict, List} @@ -51,6 +52,8 @@ class Parser: types: Dict[str, Type] values: Dict[str, any] list_appends: Dict[str, List[ConfigFunction]] + explicitly_specified: Set[str] + hyperparams: Dict[str, bool] def __init__(self, configs: 'Configs', values: Dict[str, any] = None): classes = _get_base_classes(type(configs)) @@ -62,6 +65,8 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None): self.list_appends = {} self.config_items = {} self.configs = configs + self.explicitly_specified = set() + self.hyperparams = {} for c in classes: # for k, v in c.__annotations__.items(): @@ -86,6 +91,11 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None): for v in evals: self.__collect_evaluator(k, v) + for c in classes: + if _HYPERPARAMS in c.__dict__: + for k, is_hyperparam in c.__dict__[_HYPERPARAMS].items(): + self.hyperparams[k] = is_hyperparam + for k, v in configs.__dict__.items(): assert k in self.types self.__collect_value(k, v) @@ -126,6 +136,8 @@ def __collect_value(self, k, v): if not self.is_valid(k): return + self.explicitly_specified.add(k) + self.values[k] = v if k not in self.types: self.types[k] = type(v) diff --git a/lab/experiment/__init__.py b/lab/experiment/__init__.py index 443fa1520..1de30f36f 100644 --- a/lab/experiment/__init__.py +++ b/lab/experiment/__init__.py @@ -165,10 +165,6 @@ def calc_configs(self, self.configs_processor = ConfigProcessor(configs, configs_dict) self.configs_processor(run_order) - # May be we should write these in `start` - if configs_dict: - logger.internal().write_h_parameters(configs_dict) - logger.new_line() def __start_from_checkpoint(self, run_uuid: str, checkpoint: Optional[int]): @@ -210,3 +206,5 @@ def start(self, *, logger.internal().save_indicators(self.run.indicators_path) logger.internal().save_artifacts(self.run.artifacts_path) + logger.internal().write_h_parameters(self.configs_processor.get_hyperparams()) + diff --git a/samples/hyperparameter_tunining.py b/samples/hyperparameter_tunining.py index d602f78c7..c09b2eabf 100644 --- a/samples/hyperparameter_tunining.py +++ b/samples/hyperparameter_tunining.py @@ -176,11 +176,11 @@ def loop_step(c: Configs): return len(c.train_loader) -def run(run_name: str, hparams: dict): +def run(hparams: dict): logger.set_global_step(0) conf = Configs() - experiment = Experiment(name=run_name, writers={'sqlite', 'tensorboard'}) + experiment = Experiment(name='mnist_hyperparam_tuning', writers={'sqlite', 'tensorboard'}) experiment.calc_configs(conf, hparams, ['set_seed', 'main']) @@ -191,7 +191,6 @@ def run(run_name: str, hparams: dict): def main(): - session_num = 1 for conv1_kernal in [3, 5]: for conv2_kernal in [3, 5]: hparams = { @@ -199,11 +198,7 @@ def main(): 'conv2_kernal': conv2_kernal, } - run_name = "mnist_run-%d" % session_num - - run(run_name, hparams) - - session_num += 1 + run(hparams) if __name__ == '__main__':