diff --git a/searchspaces/partialplus.py b/searchspaces/partialplus.py index 00292ad..1ff490c 100644 --- a/searchspaces/partialplus.py +++ b/searchspaces/partialplus.py @@ -17,6 +17,10 @@ # TODO: support o_len functionality from old Apply nodes +def is_variable_node(node): + return hasattr(node, 'func') and node.func is variable_node + + def is_tuple_node(node): return hasattr(node, 'func') and node.func is make_tuple @@ -49,6 +53,17 @@ def make_tuple(*args): return tuple(args) +def variable_node(*args, **kwargs): + """ + Marker function for variable nodes created by `variable()`. + + Notes + ----- + By convention we store everything in kwargs. + """ + assert len(args) == 0 + + def call_with_list_of_pos_args(f, *args): return f(args) @@ -696,7 +711,58 @@ def extend_args(self, args): self._args = self._args + tuple(args) -def evaluate(p, instantiate_call=None, bindings=None): +def variable(name, value_type, minimum=None, maximum=None, default=None, + log_scale=False, distribution=None, **kwargs): + """ + Create a special variable node to be replaced at evaluation time + of a `PartialPlus` graph. + + Parameters + ---------- + name : str + A unique string identifier. Must be a valid Python variable name. + TODO: validate this requirement. + value_type : type or iterable + One of `float`, `int`, or a sequence of possible values. + minimum : float or int, optional + If `value_type` is float or int, the minimum value this variable + can take. + maximum : float or int, optional + If `value_type` is float or int, the maximum value this variable + can take. + default : object, optional + A "default" value for this variable, used by some optimizers. + log_scale : bool, optional + Indicator used by some systems to determine whether a quantity + should be treated as if varying on a logarithmic scale. + distribution : callable(?), optional + A prior distribution on the support of this parameter, used by + some optimizers. + + Returns + ------- + variable_node : PartialPlus + A `PartialPlus` with `variable_node` as the function attribute. + """ + d = locals() + d.update(kwargs) # kwargs guaranteed not to have keys already in locals() + return partial(variable_node, **d) + + +def evaluate(p, **kwargs): + """ + Evaluate a nested tree of functools.partial objects, + used for deferred evaluation. + + Parameters + ---------- + p : object + + """ + return _evaluate(p, bindings=kwargs) + + +def _evaluate(p, instantiate_call=None, bindings=None): """ Evaluate a nested tree of functools.partial objects, used for deferred evaluation. @@ -738,6 +804,9 @@ def evaluate(p, instantiate_call=None, bindings=None): bindings[p] = p.value return bindings[p] + recurse = _partial(_evaluate, instantiate_call=instantiate_call, + bindings=bindings) + # When evaluating an expression of the form # `list(...)[item]` # only evaluate the element(s) of the list that we need. @@ -745,14 +814,15 @@ def evaluate(p, instantiate_call=None, bindings=None): obj, index = p.args if (isinstance(obj, _partial) and obj.func in (make_list, make_tuple)): - index_val = evaluate(index, instantiate_call, bindings) - # TODO: is_iterable + index_val = recurse(index) elem_val = obj.args[index_val] if isinstance(index_val, slice): # TODO: something more robust? - elem_val = obj.func(*[evaluate(e, instantiate_call, bindings) - for e in elem_val]) + # elem_val is a sliced out sublist, recurse on each element + # therein and call obj.func (make_list, make_tuple) on result. + elem_val = instantiate_call(obj.func, + *[recurse(e) for e in elem_val]) else: - elem_val = evaluate(elem_val, instantiate_call, bindings) + elem_val = recurse(elem_val) try: # bindings the value of this subexpression as int(index_val) @@ -761,10 +831,18 @@ def evaluate(p, instantiate_call=None, bindings=None): # TODO: is this even conceivably used? bindings[p] = instantiate_call(p.func, elem_val, index_val) return bindings[p] + args = [recurse(arg) for arg in p.args] + kw = (dict((kw, recurse(val)) for kw, val in p.keywords.iteritems()) + if p.keywords else {}) + + if is_variable_node(p): + assert 'name' in p.keywords + name = kw['name'] + try: + return bindings[name] + except KeyError: + raise KeyError("variable with name '%s' not bound" % name) - args = [evaluate(arg, instantiate_call, bindings) for arg in p.args] - kw = dict((kw, evaluate(val, instantiate_call, bindings)) - for kw, val in p.keywords.iteritems()) if p.keywords else {} # bindings the evaluated value (for subsequent calls that # will look at this bindings dictionary) and return. bindings[p] = instantiate_call(p.func, *args, **kw) diff --git a/searchspaces/tests/test_partialplus.py b/searchspaces/tests/test_partialplus.py index 2d7db01..86a8e7f 100644 --- a/searchspaces/tests/test_partialplus.py +++ b/searchspaces/tests/test_partialplus.py @@ -1,5 +1,5 @@ from searchspaces.partialplus import partial, Literal -from searchspaces.partialplus import evaluate +from searchspaces.partialplus import evaluate, variable from searchspaces.partialplus import depth_first_traversal, topological_sort from searchspaces.partialplus import as_partialplus as as_pp @@ -227,6 +227,20 @@ class Foo(object): assert r[0] is r[1][-1] assert r[0] is r[2][0][0] + +def test_variable_substitution(): + x = variable(name='x', value_type=int) + y = variable(name='y', value_type=float) + p = as_pp({3: x, x:[y, [y]], y:4}) + # Currently no type-checking. This will fail when we add it and need to be + # updated. + e = evaluate(p, x='hey', y=5) + assert e[3] == 'hey' + assert e['hey'] == [5, [5]] + assert e[5] == 4 + + + if __name__ == "__main__": test_switch() test_switch_range()