From bb87c0c6561b28297fc2f581ac855abfd85c6653 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Tue, 21 Apr 2020 11:43:42 +0530 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20eval=20methods=20in=20configs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lab/configs/__init__.py | 80 +++++++++++++++++++++----- lab/configs/calculator.py | 19 ++++++- lab/configs/config_function.py | 8 ++- lab/configs/parser.py | 19 +++++++ samples/mnist_loop.py | 101 +++++++++++++-------------------- 5 files changed, 150 insertions(+), 77 deletions(-) diff --git a/lab/configs/__init__.py b/lab/configs/__init__.py index ff9a1a63b..e8fbaa56e 100644 --- a/lab/configs/__init__.py +++ b/lab/configs/__init__.py @@ -1,3 +1,4 @@ +import inspect from pathlib import PurePath from typing import List, Dict, Callable, Optional, \ Union @@ -11,11 +12,28 @@ from ..logger.colors import Text _CALCULATORS = '_calculators' +_EVALUATORS = '_evaluators' _CONFIG_PRINT_LEN = 40 +def _is_class_method(func: Callable): + if not callable(func): + return False + + spec: inspect.Signature = inspect.signature(func) + params: List[inspect.Parameter] = list(spec.parameters.values()) + if len(params) != 1: + return False + p = params[0] + if p.kind != p.POSITIONAL_OR_KEYWORD: + return False + + return p.name == 'self' + + class Configs: _calculators: Dict[str, List[ConfigFunction]] = {} + _evaluators: Dict[str, List[ConfigFunction]] = {} def __init_subclass__(cls, **kwargs): configs = {} @@ -28,35 +46,70 @@ def __init_subclass__(cls, **kwargs): True, v, k in cls.__dict__, cls.__dict__.get(k, None)) + evals = [] for k, v in cls.__dict__.items(): if not Parser.is_valid(k): continue + if _is_class_method(v): + evals.append((k, v)) + continue + configs[k] = ConfigItem(k, k in cls.__annotations__, cls.__annotations__.get(k, None), True, v) + for e in evals: + cls._add_eval_function(e[1], e[0], 'default') + for k, v in configs.items(): setattr(cls, k, v) @classmethod - def calc(cls, name: Union[ConfigItem, List[ConfigItem]] = None, - option: str = None, *, - is_append: bool = False): + def _add_config_function(cls, + func: Callable, + name: Union[ConfigItem, List[ConfigItem]], + option: str, *, + is_append: bool + ): if _CALCULATORS not in cls.__dict__: cls._calculators = {} - def wrapper(func: Callable): - calc = ConfigFunction(func, config_names=name, option_name=option, is_append=is_append) - if type(calc.config_names) == str: - config_names = [calc.config_names] - else: - config_names = calc.config_names + calc = ConfigFunction(func, config_names=name, option_name=option, is_append=is_append) + if type(calc.config_names) == str: + config_names = [calc.config_names] + else: + config_names = calc.config_names + + for n in config_names: + if n not in cls._calculators: + cls._calculators[n] = [] + cls._calculators[n].append(calc) + + @classmethod + def _add_eval_function(cls, + func: Callable, + name: str, + option: str): + if _EVALUATORS not in cls.__dict__: + cls._evaluators = {} + + calc = ConfigFunction(func, + config_names=name, + option_name=option, + is_append=False, + check_string_names=False) + + if name not in cls._evaluators: + cls._evaluators[name] = [] + cls._evaluators[name].append(calc) - for n in config_names: - if n not in cls._calculators: - cls._calculators[n] = [] - cls._calculators[n].append(calc) + @classmethod + def calc(cls, name: Union[ConfigItem, List[ConfigItem]] = None, + option: str = None, *, + is_append: bool = False): + def wrapper(func: Callable): + cls._add_config_function(func, name, option, is_append=is_append) return func @@ -72,6 +125,7 @@ def __init__(self, configs, values: Dict[str, any] = None): self.parser = Parser(configs, values) self.calculator = Calculator(configs=configs, options=self.parser.options, + evals=self.parser.evals, types=self.parser.types, values=self.parser.values, list_appends=self.parser.list_appends) diff --git a/lab/configs/calculator.py b/lab/configs/calculator.py index 7335bf5bb..7ef1464fd 100644 --- a/lab/configs/calculator.py +++ b/lab/configs/calculator.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from typing import List, Dict, Type, Set, Optional, \ OrderedDict as OrderedDictType, Union, Any, Tuple from typing import TYPE_CHECKING @@ -6,10 +7,11 @@ from .. import logger if TYPE_CHECKING: - from . import Configs + from . import Configs, ConfigFunction class Calculator: + evals: Dict[str, OrderedDictType[str, ConfigFunction]] options: Dict[str, OrderedDictType[str, ConfigFunction]] types: Dict[str, Type] values: Dict[str, any] @@ -27,9 +29,11 @@ class Calculator: def __init__(self, *, configs: 'Configs', options: Dict[str, OrderedDictType[str, ConfigFunction]], + evals: Dict[str, OrderedDictType[str, ConfigFunction]], types: Dict[str, Type], values: Dict[str, any], list_appends: Dict[str, List[ConfigFunction]]): + self.evals = evals self.configs = configs self.options = options self.types = types @@ -71,6 +75,14 @@ def __get_dependencies(self, key) -> Set[str]: return dep + if key in self.evals: + value = self.values.get(key, None) + if not value: + value = 'default' + if value not in self.evals[key]: + return set() + return self.evals[key][value].dependencies + assert key in self.values, f"Cannot compute {key}" # assert self.values[key] is not None, f"Cannot compute {key}" @@ -80,6 +92,8 @@ def __create_graph(self): self.dependencies = {} for k in self.types: self.dependencies[k] = self.__get_dependencies(k) + for k in self.evals: + self.dependencies[k] = self.__get_dependencies(k) def __add_to_topological_order(self, key): assert self.stack.pop() == key @@ -125,6 +139,9 @@ def __compute(self, key): if key in self.is_computed: return + if key in self.evals: + return + value, funcs = self.__get_property(key) if funcs is None: self.__set_configs(key, value) diff --git a/lab/configs/config_function.py b/lab/configs/config_function.py index 398823a93..9e2e9a1b5 100644 --- a/lab/configs/config_function.py +++ b/lab/configs/config_function.py @@ -38,6 +38,7 @@ def __init__(self, func: Callable): self.required = set() self.is_referenced = False + source = textwrap.dedent(source) parsed = ast.parse(source) self.visit(parsed) @@ -118,7 +119,8 @@ def __get_config_names(self, config_names: Union[str, ConfigItem, List[ConfigIte warnings.warn("Use @Config.[name]", FutureWarning, 4) return self.func.__name__ elif type(config_names) == str: - warnings.warn("Use @Config.[name] instead of '[name]'", FutureWarning, 4) + if self.check_string_names: + warnings.warn("Use @Config.[name] instead of '[name]'", FutureWarning, 4) return config_names elif type(config_names) == ConfigItem: return config_names.key @@ -151,8 +153,10 @@ def __get_params(self): def __init__(self, func, *, config_names: Union[str, ConfigItem, List[ConfigItem], List[str]], option_name: str, - is_append: bool): + is_append: bool, + check_string_names: bool = True): self.func = func + self.check_string_names = check_string_names self.config_names = self.__get_config_names(config_names) self.is_append = is_append assert not (self.is_append and type(self.config_names) != str) diff --git a/lab/configs/parser.py b/lab/configs/parser.py index b8d149156..434c6e98b 100644 --- a/lab/configs/parser.py +++ b/lab/configs/parser.py @@ -10,6 +10,7 @@ from . import Configs _CALCULATORS = '_calculators' +_EVALUATORS = '_evaluators' def _get_base_classes(class_: Type['Configs']) -> List[Type['Configs']]: @@ -46,6 +47,7 @@ def _get_base_classes(class_: Type['Configs']) -> List[Type['Configs']]: class Parser: config_items: Dict[str, ConfigItem] options: Dict[str, OrderedDictType[str, ConfigFunction]] + evals: Dict[str, OrderedDictType[str, ConfigFunction]] types: Dict[str, Type] values: Dict[str, any] list_appends: Dict[str, List[ConfigFunction]] @@ -56,6 +58,7 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None): self.values = {} self.types = {} self.options = {} + self.evals = {} self.list_appends = {} self.config_items = {} self.configs = configs @@ -65,6 +68,8 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None): # self.__collect_annotation(k, v) # for k, v in c.__dict__.items(): + if _EVALUATORS in c.__dict__ and k in c.__dict__[_EVALUATORS]: + continue self.__collect_config_item(k, v) for c in classes: @@ -75,6 +80,12 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None): for v in calcs: self.__collect_calculator(k, v) + for c in classes: + if _EVALUATORS in c.__dict__: + for k, evals in c.__dict__[_EVALUATORS].items(): + for v in evals: + self.__collect_evaluator(k, v) + for k, v in configs.__dict__.items(): assert k in self.types self.__collect_value(k, v) @@ -140,6 +151,14 @@ def __collect_calculator(self, k, v: ConfigFunction): self.options[k][v.option_name] = v + def __collect_evaluator(self, k, v: ConfigFunction): + assert not v.is_append + + if k not in self.evals: + self.evals[k] = OrderedDict() + + self.evals[k][v.option_name] = v + def __calculate_missing_values(self): for k in self.types: if k in self.values: diff --git a/samples/mnist_loop.py b/samples/mnist_loop.py index 6afa2ded0..3568199ef 100644 --- a/samples/mnist_loop.py +++ b/samples/mnist_loop.py @@ -5,7 +5,7 @@ import torch.utils.data from torchvision import datasets, transforms -from lab import logger, configs +from lab import logger from lab import training_loop from lab.experiment.pytorch import Experiment from lab.logger.indicators import Queue, Histogram @@ -30,18 +30,37 @@ def forward(self, x): return self.fc2(x) -class MNIST: - def __init__(self, c: 'Configs'): - self.model = c.model - self.device = c.device - self.train_loader = c.train_loader - self.test_loader = c.test_loader - self.optimizer = c.optimizer - self.train_log_interval = c.train_log_interval - self.loop = c.training_loop - self.__is_log_parameters = c.is_log_parameters +class Configs(training_loop.TrainingLoopConfigs): + epochs: int = 10 + + loop_step = 'loop_step' + loop_count = 'loop_count' + + is_save_models = True + batch_size: int = 64 + test_batch_size: int = 1000 - def _train(self): + use_cuda: bool = True + cuda_device: int = 0 + seed: int = 5 + train_log_interval: int = 10 + + is_log_parameters: bool = True + + device: any + + train_loader: torch.utils.data.DataLoader + test_loader: torch.utils.data.DataLoader + + model: nn.Module + + learning_rate: float = 0.01 + momentum: float = 0.5 + optimizer: optim.SGD + + set_seed = 'set_seed' + + def train(self): self.model.train() for i, (data, target) in logger.enum("Train", self.train_loader): data, target = data.to(self.device), target.to(self.device) @@ -58,7 +77,7 @@ def _train(self): if i % self.train_log_interval == 0: logger.write() - def _test(self): + def test(self): self.model.eval() test_loss = 0 correct = 0 @@ -73,58 +92,18 @@ def _test(self): logger.store(test_loss=test_loss / len(self.test_loader.dataset)) logger.store(accuracy=correct / len(self.test_loader.dataset)) - def __log_model_params(self): - if not self.__is_log_parameters: - return - - logger_util.store_model_indicators(self.model) - - def __call__(self): + def run(self): logger_util.add_model_indicators(self.model) logger.add_indicator(Queue("train_loss", 20, True)) logger.add_indicator(Histogram("test_loss", True)) logger.add_indicator(Histogram("accuracy", True)) - for _ in self.loop: - self._train() - self._test() - self.__log_model_params() - - -class LoaderConfigs(configs.Configs): - train_loader: torch.utils.data.DataLoader - test_loader: torch.utils.data.DataLoader - - -class Configs(training_loop.TrainingLoopConfigs, LoaderConfigs): - epochs: int = 10 - - loop_step = 'loop_step' - loop_count = 'loop_count' - - is_save_models = True - batch_size: int = 64 - test_batch_size: int = 1000 - - use_cuda: bool = True - cuda_device: int = 0 - seed: int = 5 - train_log_interval: int = 10 - - is_log_parameters: bool = True - - device: any - - model: nn.Module - - learning_rate: float = 0.01 - momentum: float = 0.5 - optimizer: optim.SGD - - set_seed = 'set_seed' - - main: MNIST + for _ in self.training_loop: + self.train() + self.test() + if self.is_log_parameters: + logger_util.store_model_indicators(self.model) @Configs.calc(Configs.device) @@ -191,10 +170,10 @@ def main(): experiment = Experiment(writers={'sqlite', 'tensorboard'}) experiment.calc_configs(conf, {'optimizer': 'adam_optimizer'}, - ['set_seed', 'main']) + ['set_seed', 'run']) experiment.add_models(dict(model=conf.model)) experiment.start() - conf.main() + conf.run() if __name__ == '__main__':