diff --git a/motile/costs/appear.py b/motile/costs/appear.py index acac553..a0e36d8 100644 --- a/motile/costs/appear.py +++ b/motile/costs/appear.py @@ -15,15 +15,30 @@ class Appear(Costs): Args: + weight (float): + The weight to apply to the cost of each starting track. + + attribute (string): + The name of the attribute to use to look up costs. Default is + ``None``, which means that a constant cost is used. + constant (float): A constant cost for each node that starts a track. """ - def __init__(self, constant: float) -> None: + def __init__( + self, weight: float = 1, attribute: str | None = None, constant: float = 0 + ) -> None: + self.weight = Weight(weight) self.constant = Weight(constant) + self.attribute = attribute def apply(self, solver: Solver) -> None: appear_indicators = solver.get_variables(NodeAppear) - for index in appear_indicators.values(): + for node, index in appear_indicators.items(): + if self.attribute is not None: + solver.add_variable_cost( + index, solver.graph.nodes[node][self.attribute], self.weight + ) solver.add_variable_cost(index, 1.0, self.constant) diff --git a/motile/costs/edge_distance.py b/motile/costs/edge_distance.py index 733286e..6bb0cb7 100644 --- a/motile/costs/edge_distance.py +++ b/motile/costs/edge_distance.py @@ -24,13 +24,20 @@ class EdgeDistance(Costs): weight (float): The weight to apply to the distance to convert it into a cost. + + constant (float): + A constant cost to add to all edges. """ def __init__( - self, position_attributes: tuple[str, ...], weight: float = 1.0 + self, + position_attributes: tuple[str, ...], + weight: float = 1.0, + constant: float = 0.0, ) -> None: - self.position_attributes = position_attributes self.weight = Weight(weight) + self.constant = Weight(constant) + self.position_attributes = position_attributes def apply(self, solver: Solver) -> None: edge_variables = solver.get_variables(EdgeSelected) @@ -46,3 +53,4 @@ def apply(self, solver: Solver) -> None: feature = np.linalg.norm(pos_u - pos_v) solver.add_variable_cost(index, feature, self.weight) + solver.add_variable_cost(index, 1.0, self.constant) diff --git a/motile/costs/split.py b/motile/costs/split.py index f0332bc..d19aa06 100644 --- a/motile/costs/split.py +++ b/motile/costs/split.py @@ -14,17 +14,31 @@ class Split(Costs): """Costs for :class:`motile.variables.NodeSplit` variables. Args: + weight (float): + The weight to apply to the cost of each split. + + attribute (string) + The name of the attribute to use to look up costs. Default is + ``None``, which means that a constant cost is used. constant (float): A constant cost for each node that has more than one selected child. """ - def __init__(self, constant: float) -> None: + def __init__( + self, weight: float = 1, attribute: str | None = None, constant: float = 0 + ) -> None: + self.weight = Weight(weight) self.constant = Weight(constant) + self.attribute = attribute def apply(self, solver: Solver) -> None: split_indicators = solver.get_variables(NodeSplit) - for index in split_indicators.values(): + for node, index in split_indicators.items(): + if self.attribute is not None: + solver.add_variable_cost( + index, solver.graph.nodes[node][self.attribute], self.weight + ) solver.add_variable_cost(index, 1.0, self.constant) diff --git a/tests/test_api.py b/tests/test_api.py index c1a5cea..e8f244f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,8 +1,6 @@ -import unittest - import motile from motile.constraints import MaxChildren, MaxParents -from motile.costs import Appear, EdgeSelection, NodeSelection, Split +from motile.costs import Appear, EdgeDistance, EdgeSelection, NodeSelection, Split from motile.data import ( arlo_graph, arlo_graph_nx, @@ -11,40 +9,44 @@ ) -class TestAPI(unittest.TestCase): - def test_graph_creation_with_hyperedges(self): - graph = toy_hypergraph() - assert len(graph.nodes) == 7 - assert len(graph.edges) == 10 +def test_graph_creation_with_hyperedges(): + graph = toy_hypergraph() + assert len(graph.nodes) == 7 + assert len(graph.edges) == 10 + + +def test_graph_creation_from_multiple_nx_graphs(): + g1 = toy_hypergraph_nx() + g2 = arlo_graph_nx() + graph = motile.TrackGraph() - def test_graph_creation_from_multiple_nx_graphs(self): - g1 = toy_hypergraph_nx() - g2 = arlo_graph_nx() - graph = motile.TrackGraph() + graph.add_from_nx_graph(g1) + assert len(graph.nodes) == 7 + assert len(graph.edges) == 10 + assert graph.nodes[6]["x"] == 35 + assert "prediction_distance" not in graph.edges[(0, 2)] - graph.add_from_nx_graph(g1) - assert len(graph.nodes) == 7 - assert len(graph.edges) == 10 - assert graph.nodes[6]["x"] == 35 - assert "prediction_distance" not in graph.edges[(0, 2)] + graph.add_from_nx_graph(g2) + assert len(graph.nodes) == 7 + assert len(graph.edges) == 11 + assert graph.nodes[6]["x"] == 200 + assert "prediction_distance" in graph.edges[(0, 2)] - graph.add_from_nx_graph(g2) - assert len(graph.nodes) == 7 - assert len(graph.edges) == 11 - assert graph.nodes[6]["x"] == 200 - assert "prediction_distance" in graph.edges[(0, 2)] - def test_solver(self): - graph = arlo_graph() +def test_solver(): + graph = arlo_graph() - solver = motile.Solver(graph) - solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0)) - solver.add_costs(EdgeSelection(weight=1.0, attribute="prediction_distance")) - solver.add_costs(Appear(constant=200.0)) - solver.add_costs(Split(constant=100.0)) - solver.add_constraints(MaxParents(1)) - solver.add_constraints(MaxChildren(2)) + solver = motile.Solver(graph) + solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0)) + solver.add_costs( + EdgeSelection(weight=0.5, attribute="prediction_distance", constant=-1.0) + ) + solver.add_costs(EdgeDistance(position_attributes=("x",), weight=0.5, constant=1.0)) + solver.add_costs(Appear(constant=200.0, attribute="score", weight=-1.0)) + solver.add_costs(Split(constant=100.0, attribute="score", weight=1.0)) + solver.add_constraints(MaxParents(1)) + solver.add_constraints(MaxChildren(2)) - solution = solver.solve() + solution = solver.solve() - assert solution.get_value() == -200 + assert solution.get_value() == -202