Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 87 additions & 9 deletions searchspaces/partialplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
# TODO: support o_len functionality from old Apply nodes


def is_variable_node(node):
return hasattr(node, 'func') and node.func is variable_node


def is_tuple_node(node):
return hasattr(node, 'func') and node.func is make_tuple

Expand Down Expand Up @@ -49,6 +53,17 @@ def make_tuple(*args):
return tuple(args)


def variable_node(*args, **kwargs):
"""
Marker function for variable nodes created by `variable()`.

Notes
-----
By convention we store everything in kwargs.
"""
assert len(args) == 0


def call_with_list_of_pos_args(f, *args):
return f(args)

Expand Down Expand Up @@ -696,7 +711,58 @@ def extend_args(self, args):
self._args = self._args + tuple(args)


def evaluate(p, instantiate_call=None, bindings=None):
def variable(name, value_type, minimum=None, maximum=None, default=None,
log_scale=False, distribution=None, **kwargs):
"""
Create a special variable node to be replaced at evaluation time
of a `PartialPlus` graph.

Parameters
----------
name : str
A unique string identifier. Must be a valid Python variable name.
TODO: validate this requirement.
value_type : type or iterable
One of `float`, `int`, or a sequence of possible values.
minimum : float or int, optional
If `value_type` is float or int, the minimum value this variable
can take.
maximum : float or int, optional
If `value_type` is float or int, the maximum value this variable
can take.
default : object, optional
A "default" value for this variable, used by some optimizers.
log_scale : bool, optional
Indicator used by some systems to determine whether a quantity
should be treated as if varying on a logarithmic scale.
distribution : callable(?), optional
A prior distribution on the support of this parameter, used by
some optimizers.

Returns
-------
variable_node : PartialPlus
A `PartialPlus` with `variable_node` as the function attribute.
"""
d = locals()
d.update(kwargs) # kwargs guaranteed not to have keys already in locals()
return partial(variable_node, **d)


def evaluate(p, **kwargs):
"""
Evaluate a nested tree of functools.partial objects,
used for deferred evaluation.

Parameters
----------
p : object

"""
return _evaluate(p, bindings=kwargs)


def _evaluate(p, instantiate_call=None, bindings=None):
"""
Evaluate a nested tree of functools.partial objects,
used for deferred evaluation.
Expand Down Expand Up @@ -738,21 +804,25 @@ def evaluate(p, instantiate_call=None, bindings=None):
bindings[p] = p.value
return bindings[p]

recurse = _partial(_evaluate, instantiate_call=instantiate_call,
bindings=bindings)

# When evaluating an expression of the form
# `list(...)[item]`
# only evaluate the element(s) of the list that we need.
if p.func == _getitem:
obj, index = p.args
if (isinstance(obj, _partial)
and obj.func in (make_list, make_tuple)):
index_val = evaluate(index, instantiate_call, bindings)
# TODO: is_iterable
index_val = recurse(index)
elem_val = obj.args[index_val]
if isinstance(index_val, slice): # TODO: something more robust?
elem_val = obj.func(*[evaluate(e, instantiate_call, bindings)
for e in elem_val])
# elem_val is a sliced out sublist, recurse on each element
# therein and call obj.func (make_list, make_tuple) on result.
elem_val = instantiate_call(obj.func,
*[recurse(e) for e in elem_val])
else:
elem_val = evaluate(elem_val, instantiate_call, bindings)
elem_val = recurse(elem_val)
try:
# bindings the value of this subexpression as
int(index_val)
Expand All @@ -761,10 +831,18 @@ def evaluate(p, instantiate_call=None, bindings=None):
# TODO: is this even conceivably used?
bindings[p] = instantiate_call(p.func, elem_val, index_val)
return bindings[p]
args = [recurse(arg) for arg in p.args]
kw = (dict((kw, recurse(val)) for kw, val in p.keywords.iteritems())
if p.keywords else {})

if is_variable_node(p):
assert 'name' in p.keywords
name = kw['name']
try:
return bindings[name]
except KeyError:
raise KeyError("variable with name '%s' not bound" % name)

args = [evaluate(arg, instantiate_call, bindings) for arg in p.args]
kw = dict((kw, evaluate(val, instantiate_call, bindings))
for kw, val in p.keywords.iteritems()) if p.keywords else {}
# bindings the evaluated value (for subsequent calls that
# will look at this bindings dictionary) and return.
bindings[p] = instantiate_call(p.func, *args, **kw)
Expand Down
16 changes: 15 additions & 1 deletion searchspaces/tests/test_partialplus.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from searchspaces.partialplus import partial, Literal
from searchspaces.partialplus import evaluate
from searchspaces.partialplus import evaluate, variable
from searchspaces.partialplus import depth_first_traversal, topological_sort
from searchspaces.partialplus import as_partialplus as as_pp

Expand Down Expand Up @@ -227,6 +227,20 @@ class Foo(object):
assert r[0] is r[1][-1]
assert r[0] is r[2][0][0]


def test_variable_substitution():
x = variable(name='x', value_type=int)
y = variable(name='y', value_type=float)
p = as_pp({3: x, x:[y, [y]], y:4})
# Currently no type-checking. This will fail when we add it and need to be
# updated.
e = evaluate(p, x='hey', y=5)
assert e[3] == 'hey'
assert e['hey'] == [5, [5]]
assert e[5] == 4



if __name__ == "__main__":
test_switch()
test_switch_range()
Expand Down