Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generic expression constraint (generalization of Pin) #31

Merged
merged 9 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions motile/_types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from typing import Any, Hashable, TypeAlias, Union
from typing import Any, Mapping, TypeAlias, Union

# Nodes are represented as integers, or a "meta-node" tuple of integers.
NodeId: TypeAlias = Union[int, tuple[int, ...]]

# objects in the graph are represented as dicts
# eg. { "id": 1, "x": 0.5, "y": 0.5, "t": 0 }
GraphObject: TypeAlias = dict[Hashable, Any]
GraphObject: TypeAlias = Mapping[str, Any]

# Edges are represented as tuples of NodeId.
# (0, 1) is an edge from node 0 to node 1.
Expand Down
2 changes: 2 additions & 0 deletions motile/constraints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .constraint import Constraint
from .expression import ExpressionConstraint
from .max_children import MaxChildren
from .max_parents import MaxParents
from .pin import Pin
from .select_edge_nodes import SelectEdgeNodes

__all__ = [
"Constraint",
"ExpressionConstraint",
"MaxChildren",
"MaxParents",
"Pin",
Expand Down
113 changes: 113 additions & 0 deletions motile/constraints/expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import ast
import contextlib
from typing import TYPE_CHECKING, Union

import ilpy

from ..variables import EdgeSelected, NodeSelected, Variable
from .constraint import Constraint

if TYPE_CHECKING:
from motile._types import EdgeId, GraphObject, NodeId
from motile.solver import Solver

NodesOrEdges = Union[dict[NodeId, GraphObject], dict[EdgeId, GraphObject]]


class ExpressionConstraint(Constraint):
"""Enforces the selection of nodes/edges based on an expression evaluated
with the node/edge dict as a namespace.

This is a powerful general constraint that allows you to select nodes/edges based on
any combination of node/edge attributes. The `expression` string is evaluated for
each node/edge (assuming eval_nodes/eval_edges is True) using the actual node object
as a namespace to populate any variables names used in the provided expression. If
the expression evaluates to True, the node/edge is selected; otherwise, it is
excluded.

This takes advantaged of python's `eval` function, like this:

```python
my_expression = "some_attribute == True"
eval(my_expression, None, {"some_attribute": True}) # returns True (select)
eval(my_expression, None, {"some_attribute": False}) # returns False (exclude)
eval(my_expression, None, {}) # raises NameError (do nothing)
```

Args:
expression (string):
An expression to evaluate for each node/edge. The expression must
evaluate to a boolean value. The expression can use any names of
node/edge attributes as variables.
eval_nodes (bool):
Whether to evaluate the expression for nodes. By default, True.
eval_edges (bool):
Whether to evaluate the expression for edges. By default, True.

Example:

If the nodes of a graph are:
cells = [
{"id": 0, "t": 0, "color": "red", "score": 1.0},
{"id": 1, "t": 0, "color": "green", "score": 1.0},
{"id": 2, "t": 1, "color": "blue", "score": 1.0},
]

Then the following constraint will select node 0:
>>> expr = "t == 0 and color != 'green'"
>>> solver.add_constraints(ExpressionConstraint(expr))
"""

def __init__(
self, expression: str, eval_nodes: bool = True, eval_edges: bool = True
) -> None:
try:
tree = ast.parse(expression, mode="eval")
if not isinstance(tree, ast.Expression):
raise SyntaxError
except SyntaxError:
raise ValueError(f"Invalid expression: {expression}") from None

self._expression = compile(expression, "<string>", "eval")
self.eval_nodes = eval_nodes
self.eval_edges = eval_edges

def instantiate(self, solver: Solver) -> list[ilpy.Constraint]:
# create two constraints: one to select nodes/edges, and one to exclude
select = ilpy.Constraint()
exclude = ilpy.Constraint()
n_selected = 0 # number of nodes/edges selected

to_evaluate: list[tuple[NodesOrEdges, type[Variable]]] = []
if self.eval_nodes:
to_evaluate.append((solver.graph.nodes, NodeSelected))
if self.eval_edges:
to_evaluate.append((solver.graph.edges, EdgeSelected))

for nodes_or_edges, VariableType in to_evaluate:
indicator_variables = solver.get_variables(VariableType)
for id_, node_or_edge in nodes_or_edges.items():
with contextlib.suppress(NameError):
# Here is where the expression string is evaluated.
# We use the node/edge dict as a namespace to look up variables.
# if the expression uses a variable name that is not in the dict,
# a NameError will be raised.
# contextlib.suppress (above) will just skip it and move on...
if eval(self._expression, None, node_or_edge):
# if the expression evaluates to True, we select the node/edge
select.set_coefficient(indicator_variables[id_], 1)
n_selected += 1
else:
# Otherwise, we exclude it.
exclude.set_coefficient(indicator_variables[id_], 1)

# finally, apply the relation and value to the constraints
select.set_relation(ilpy.Relation.Equal)
select.set_value(n_selected)

exclude.set_relation(ilpy.Relation.Equal)
exclude.set_value(0)

return [select, exclude]
54 changes: 3 additions & 51 deletions motile/constraints/pin.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from .expression import ExpressionConstraint

import ilpy

from ..variables import EdgeSelected, NodeSelected
from .constraint import Constraint

if TYPE_CHECKING:
from motile.solver import Solver


class Pin(Constraint):
class Pin(ExpressionConstraint):
"""Enforces the selection of certain nodes and edges based on the value of
a given attribute.

Expand All @@ -31,44 +23,4 @@ class Pin(Constraint):
"""

def __init__(self, attribute: str) -> None:
self.attribute = attribute

def instantiate(self, solver: Solver) -> list[ilpy.Constraint]:
node_indicators = solver.get_variables(NodeSelected)
edge_indicators = solver.get_variables(EdgeSelected)

must_select = [
node_indicators[node]
for node, attributes in solver.graph.nodes.items()
if self.attribute in attributes and attributes[self.attribute]
] + [
edge_indicators[(u, v)]
for (u, v), attributes in solver.graph.edges.items()
if self.attribute in attributes and attributes[self.attribute]
]

must_not_select = [
node_indicators[node]
for node, attributes in solver.graph.nodes.items()
if self.attribute in attributes and not attributes[self.attribute]
] + [
edge_indicators[(u, v)]
for (u, v), attributes in solver.graph.edges.items()
if self.attribute in attributes and not attributes[self.attribute]
]

must_select_constraint = ilpy.Constraint()
must_not_select_constraint = ilpy.Constraint()

for index in must_select:
must_select_constraint.set_coefficient(index, 1)
for index in must_not_select:
must_not_select_constraint.set_coefficient(index, 1)

must_select_constraint.set_relation(ilpy.Relation.Equal)
must_not_select_constraint.set_relation(ilpy.Relation.Equal)

must_select_constraint.set_value(len(must_select))
must_not_select_constraint.set_value(0)

return [must_select_constraint, must_not_select_constraint]
super().__init__(f"{attribute} == True", eval_nodes=True, eval_edges=True)
14 changes: 7 additions & 7 deletions motile/plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, overload
from typing import TYPE_CHECKING, Any, Callable, Mapping, overload

import numpy as np

Expand All @@ -24,7 +24,7 @@
PURPLE = (127, 30, 121)


def _attr_hover_text(attrs: dict) -> str:
def _attr_hover_text(attrs: Mapping) -> str:
return "<br>".join([f"{name}: {value}" for name, value in attrs.items()])


Expand Down Expand Up @@ -99,7 +99,7 @@ def draw_track_graph(
if position_func is None:

def position_func(node: NodeId) -> float:
return float(graph.nodes[node][position_attribute])
return float(graph.nodes[node][position_attribute]) # type: ignore

alpha_node_func: ReturnsFloat
alpha_edge_func: ReturnsFloat
Expand All @@ -109,10 +109,10 @@ def position_func(node: NodeId) -> float:
if alpha_attribute is not None:

def alpha_node_func(node):
return graph.nodes[node].get(alpha_attribute, 1.0)
return graph.nodes[node].get(alpha_attribute, 1.0) # type: ignore

def alpha_edge_func(edge):
return graph.edges[edge].get(alpha_attribute, 1.0)
return graph.edges[edge].get(alpha_attribute, 1.0) # type: ignore

elif alpha_func is None:

Expand All @@ -131,10 +131,10 @@ def alpha_edge_func(_):
if label_attribute is not None:

def label_node_func(node):
return graph.nodes[node].get(label_attribute, "")
return graph.nodes[node].get(label_attribute, "") # type: ignore

def label_edge_func(edge):
return graph.edges[edge].get(label_attribute, "")
return graph.edges[edge].get(label_attribute, "") # type: ignore

elif label_func is None:

Expand Down
23 changes: 22 additions & 1 deletion tests/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import unittest

import motile
from motile.constraints import MaxChildren, MaxParents, Pin
from motile.constraints import ExpressionConstraint, MaxChildren, MaxParents, Pin
from motile.costs import Appear, EdgeSelection, NodeSelection, Split
from motile.data import arlo_graph
from motile.variables import EdgeSelected
from motile.variables.node_selected import NodeSelected


class TestConstraints(unittest.TestCase):
Expand Down Expand Up @@ -34,3 +35,23 @@ def test_pin(self):

assert (0, 2) not in selected_edges
assert (3, 6) in selected_edges

def test_complex_expression(self):
graph = arlo_graph()
graph.nodes[5]["color"] = "red"

solver = motile.Solver(graph)
solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
solver.add_costs(EdgeSelection(weight=1.0, attribute="prediction_distance"))

# constrain solver based on attributes of nodes/edges
expr = "x > 140 and t != 1 and color != 'red'"
solver.add_constraints(ExpressionConstraint(expr))

solution = solver.solve()
node_indicators = solver.get_variables(NodeSelected)
selected_nodes = [
node for node, index in node_indicators.items() if solution[index] > 0.5
]

assert selected_nodes == [1, 6]