Skip to content

Commit

Permalink
✨ eval methods in configs
Browse files Browse the repository at this point in the history
  • Loading branch information
vpj committed Apr 21, 2020
1 parent 195826e commit bb87c0c
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 77 deletions.
80 changes: 67 additions & 13 deletions lab/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from pathlib import PurePath
from typing import List, Dict, Callable, Optional, \
Union
Expand All @@ -11,11 +12,28 @@
from ..logger.colors import Text

_CALCULATORS = '_calculators'
_EVALUATORS = '_evaluators'
_CONFIG_PRINT_LEN = 40


def _is_class_method(func: Callable):
if not callable(func):
return False

spec: inspect.Signature = inspect.signature(func)
params: List[inspect.Parameter] = list(spec.parameters.values())
if len(params) != 1:
return False
p = params[0]
if p.kind != p.POSITIONAL_OR_KEYWORD:
return False

return p.name == 'self'


class Configs:
_calculators: Dict[str, List[ConfigFunction]] = {}
_evaluators: Dict[str, List[ConfigFunction]] = {}

def __init_subclass__(cls, **kwargs):
configs = {}
Expand All @@ -28,35 +46,70 @@ def __init_subclass__(cls, **kwargs):
True, v,
k in cls.__dict__, cls.__dict__.get(k, None))

evals = []
for k, v in cls.__dict__.items():
if not Parser.is_valid(k):
continue

if _is_class_method(v):
evals.append((k, v))
continue

configs[k] = ConfigItem(k,
k in cls.__annotations__, cls.__annotations__.get(k, None),
True, v)

for e in evals:
cls._add_eval_function(e[1], e[0], 'default')

for k, v in configs.items():
setattr(cls, k, v)

@classmethod
def calc(cls, name: Union[ConfigItem, List[ConfigItem]] = None,
option: str = None, *,
is_append: bool = False):
def _add_config_function(cls,
func: Callable,
name: Union[ConfigItem, List[ConfigItem]],
option: str, *,
is_append: bool
):
if _CALCULATORS not in cls.__dict__:
cls._calculators = {}

def wrapper(func: Callable):
calc = ConfigFunction(func, config_names=name, option_name=option, is_append=is_append)
if type(calc.config_names) == str:
config_names = [calc.config_names]
else:
config_names = calc.config_names
calc = ConfigFunction(func, config_names=name, option_name=option, is_append=is_append)
if type(calc.config_names) == str:
config_names = [calc.config_names]
else:
config_names = calc.config_names

for n in config_names:
if n not in cls._calculators:
cls._calculators[n] = []
cls._calculators[n].append(calc)

@classmethod
def _add_eval_function(cls,
func: Callable,
name: str,
option: str):
if _EVALUATORS not in cls.__dict__:
cls._evaluators = {}

calc = ConfigFunction(func,
config_names=name,
option_name=option,
is_append=False,
check_string_names=False)

if name not in cls._evaluators:
cls._evaluators[name] = []
cls._evaluators[name].append(calc)

for n in config_names:
if n not in cls._calculators:
cls._calculators[n] = []
cls._calculators[n].append(calc)
@classmethod
def calc(cls, name: Union[ConfigItem, List[ConfigItem]] = None,
option: str = None, *,
is_append: bool = False):
def wrapper(func: Callable):
cls._add_config_function(func, name, option, is_append=is_append)

return func

Expand All @@ -72,6 +125,7 @@ def __init__(self, configs, values: Dict[str, any] = None):
self.parser = Parser(configs, values)
self.calculator = Calculator(configs=configs,
options=self.parser.options,
evals=self.parser.evals,
types=self.parser.types,
values=self.parser.values,
list_appends=self.parser.list_appends)
Expand Down
19 changes: 18 additions & 1 deletion lab/configs/calculator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import OrderedDict
from typing import List, Dict, Type, Set, Optional, \
OrderedDict as OrderedDictType, Union, Any, Tuple
from typing import TYPE_CHECKING
Expand All @@ -6,10 +7,11 @@
from .. import logger

if TYPE_CHECKING:
from . import Configs
from . import Configs, ConfigFunction


class Calculator:
evals: Dict[str, OrderedDictType[str, ConfigFunction]]
options: Dict[str, OrderedDictType[str, ConfigFunction]]
types: Dict[str, Type]
values: Dict[str, any]
Expand All @@ -27,9 +29,11 @@ class Calculator:
def __init__(self, *,
configs: 'Configs',
options: Dict[str, OrderedDictType[str, ConfigFunction]],
evals: Dict[str, OrderedDictType[str, ConfigFunction]],
types: Dict[str, Type],
values: Dict[str, any],
list_appends: Dict[str, List[ConfigFunction]]):
self.evals = evals
self.configs = configs
self.options = options
self.types = types
Expand Down Expand Up @@ -71,6 +75,14 @@ def __get_dependencies(self, key) -> Set[str]:

return dep

if key in self.evals:
value = self.values.get(key, None)
if not value:
value = 'default'
if value not in self.evals[key]:
return set()
return self.evals[key][value].dependencies

assert key in self.values, f"Cannot compute {key}"
# assert self.values[key] is not None, f"Cannot compute {key}"

Expand All @@ -80,6 +92,8 @@ def __create_graph(self):
self.dependencies = {}
for k in self.types:
self.dependencies[k] = self.__get_dependencies(k)
for k in self.evals:
self.dependencies[k] = self.__get_dependencies(k)

def __add_to_topological_order(self, key):
assert self.stack.pop() == key
Expand Down Expand Up @@ -125,6 +139,9 @@ def __compute(self, key):
if key in self.is_computed:
return

if key in self.evals:
return

value, funcs = self.__get_property(key)
if funcs is None:
self.__set_configs(key, value)
Expand Down
8 changes: 6 additions & 2 deletions lab/configs/config_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, func: Callable):
self.required = set()
self.is_referenced = False

source = textwrap.dedent(source)
parsed = ast.parse(source)
self.visit(parsed)

Expand Down Expand Up @@ -118,7 +119,8 @@ def __get_config_names(self, config_names: Union[str, ConfigItem, List[ConfigIte
warnings.warn("Use @Config.[name]", FutureWarning, 4)
return self.func.__name__
elif type(config_names) == str:
warnings.warn("Use @Config.[name] instead of '[name]'", FutureWarning, 4)
if self.check_string_names:
warnings.warn("Use @Config.[name] instead of '[name]'", FutureWarning, 4)
return config_names
elif type(config_names) == ConfigItem:
return config_names.key
Expand Down Expand Up @@ -151,8 +153,10 @@ def __get_params(self):
def __init__(self, func, *,
config_names: Union[str, ConfigItem, List[ConfigItem], List[str]],
option_name: str,
is_append: bool):
is_append: bool,
check_string_names: bool = True):
self.func = func
self.check_string_names = check_string_names
self.config_names = self.__get_config_names(config_names)
self.is_append = is_append
assert not (self.is_append and type(self.config_names) != str)
Expand Down
19 changes: 19 additions & 0 deletions lab/configs/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from . import Configs

_CALCULATORS = '_calculators'
_EVALUATORS = '_evaluators'


def _get_base_classes(class_: Type['Configs']) -> List[Type['Configs']]:
Expand Down Expand Up @@ -46,6 +47,7 @@ def _get_base_classes(class_: Type['Configs']) -> List[Type['Configs']]:
class Parser:
config_items: Dict[str, ConfigItem]
options: Dict[str, OrderedDictType[str, ConfigFunction]]
evals: Dict[str, OrderedDictType[str, ConfigFunction]]
types: Dict[str, Type]
values: Dict[str, any]
list_appends: Dict[str, List[ConfigFunction]]
Expand All @@ -56,6 +58,7 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None):
self.values = {}
self.types = {}
self.options = {}
self.evals = {}
self.list_appends = {}
self.config_items = {}
self.configs = configs
Expand All @@ -65,6 +68,8 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None):
# self.__collect_annotation(k, v)
#
for k, v in c.__dict__.items():
if _EVALUATORS in c.__dict__ and k in c.__dict__[_EVALUATORS]:
continue
self.__collect_config_item(k, v)

for c in classes:
Expand All @@ -75,6 +80,12 @@ def __init__(self, configs: 'Configs', values: Dict[str, any] = None):
for v in calcs:
self.__collect_calculator(k, v)

for c in classes:
if _EVALUATORS in c.__dict__:
for k, evals in c.__dict__[_EVALUATORS].items():
for v in evals:
self.__collect_evaluator(k, v)

for k, v in configs.__dict__.items():
assert k in self.types
self.__collect_value(k, v)
Expand Down Expand Up @@ -140,6 +151,14 @@ def __collect_calculator(self, k, v: ConfigFunction):

self.options[k][v.option_name] = v

def __collect_evaluator(self, k, v: ConfigFunction):
assert not v.is_append

if k not in self.evals:
self.evals[k] = OrderedDict()

self.evals[k][v.option_name] = v

def __calculate_missing_values(self):
for k in self.types:
if k in self.values:
Expand Down
Loading

0 comments on commit bb87c0c

Please sign in to comment.