Skip to content

Commit

Permalink
Generalize costs to format
Browse files Browse the repository at this point in the history
  • Loading branch information
bentaculum committed Sep 8, 2023
1 parent 842d15d commit c256232
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 39 deletions.
19 changes: 17 additions & 2 deletions motile/costs/appear.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,30 @@ class Appear(Costs):
Args:
weight (float):
The weight to apply to the cost of each starting track.
attribute (string):
The name of the attribute to use to look up costs. Default is
``None``, which means that a constant cost is used.
constant (float):
A constant cost for each node that starts a track.
"""

def __init__(self, constant: float) -> None:
def __init__(
self, weight: float = 1, attribute: str | None = None, constant: float = 0
) -> None:
self.weight = Weight(weight)
self.constant = Weight(constant)
self.attribute = attribute

def apply(self, solver: Solver) -> None:
appear_indicators = solver.get_variables(NodeAppear)

for index in appear_indicators.values():
for node, index in appear_indicators.items():
if self.attribute is not None:
solver.add_variable_cost(
index, solver.graph.nodes[node][self.attribute], self.weight
)
solver.add_variable_cost(index, 1.0, self.constant)
12 changes: 10 additions & 2 deletions motile/costs/edge_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,20 @@ class EdgeDistance(Costs):
weight (float):
The weight to apply to the distance to convert it into a cost.
constant (float):
A constant cost to add to all edges.
"""

def __init__(
self, position_attributes: tuple[str, ...], weight: float = 1.0
self,
position_attributes: tuple[str, ...],
weight: float = 1.0,
constant: float = 0.0,
) -> None:
self.position_attributes = position_attributes
self.weight = Weight(weight)
self.constant = Weight(constant)
self.position_attributes = position_attributes

def apply(self, solver: Solver) -> None:
edge_variables = solver.get_variables(EdgeSelected)
Expand All @@ -46,3 +53,4 @@ def apply(self, solver: Solver) -> None:
feature = np.linalg.norm(pos_u - pos_v)

solver.add_variable_cost(index, feature, self.weight)
solver.add_variable_cost(index, 1.0, self.constant)
18 changes: 16 additions & 2 deletions motile/costs/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,31 @@ class Split(Costs):
"""Costs for :class:`motile.variables.NodeSplit` variables.
Args:
weight (float):
The weight to apply to the cost of each split.
attribute (string)
The name of the attribute to use to look up costs. Default is
``None``, which means that a constant cost is used.
constant (float):
A constant cost for each node that has more than one selected
child.
"""

def __init__(self, constant: float) -> None:
def __init__(
self, weight: float = 1, attribute: str | None = None, constant: float = 0
) -> None:
self.weight = Weight(weight)
self.constant = Weight(constant)
self.attribute = attribute

def apply(self, solver: Solver) -> None:
split_indicators = solver.get_variables(NodeSplit)

for index in split_indicators.values():
for node, index in split_indicators.items():
if self.attribute is not None:
solver.add_variable_cost(
index, solver.graph.nodes[node][self.attribute], self.weight
)
solver.add_variable_cost(index, 1.0, self.constant)
68 changes: 35 additions & 33 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import unittest

import motile
from motile.constraints import MaxChildren, MaxParents
from motile.costs import Appear, EdgeSelection, NodeSelection, Split
from motile.costs import Appear, EdgeDistance, EdgeSelection, NodeSelection, Split
from motile.data import (
arlo_graph,
arlo_graph_nx,
Expand All @@ -11,40 +9,44 @@
)


class TestAPI(unittest.TestCase):
def test_graph_creation_with_hyperedges(self):
graph = toy_hypergraph()
assert len(graph.nodes) == 7
assert len(graph.edges) == 10
def test_graph_creation_with_hyperedges():
graph = toy_hypergraph()
assert len(graph.nodes) == 7
assert len(graph.edges) == 10


def test_graph_creation_from_multiple_nx_graphs():
g1 = toy_hypergraph_nx()
g2 = arlo_graph_nx()
graph = motile.TrackGraph()

def test_graph_creation_from_multiple_nx_graphs(self):
g1 = toy_hypergraph_nx()
g2 = arlo_graph_nx()
graph = motile.TrackGraph()
graph.add_from_nx_graph(g1)
assert len(graph.nodes) == 7
assert len(graph.edges) == 10
assert graph.nodes[6]["x"] == 35
assert "prediction_distance" not in graph.edges[(0, 2)]

graph.add_from_nx_graph(g1)
assert len(graph.nodes) == 7
assert len(graph.edges) == 10
assert graph.nodes[6]["x"] == 35
assert "prediction_distance" not in graph.edges[(0, 2)]
graph.add_from_nx_graph(g2)
assert len(graph.nodes) == 7
assert len(graph.edges) == 11
assert graph.nodes[6]["x"] == 200
assert "prediction_distance" in graph.edges[(0, 2)]

graph.add_from_nx_graph(g2)
assert len(graph.nodes) == 7
assert len(graph.edges) == 11
assert graph.nodes[6]["x"] == 200
assert "prediction_distance" in graph.edges[(0, 2)]

def test_solver(self):
graph = arlo_graph()
def test_solver():
graph = arlo_graph()

solver = motile.Solver(graph)
solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
solver.add_costs(EdgeSelection(weight=1.0, attribute="prediction_distance"))
solver.add_costs(Appear(constant=200.0))
solver.add_costs(Split(constant=100.0))
solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(2))
solver = motile.Solver(graph)
solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
solver.add_costs(
EdgeSelection(weight=0.5, attribute="prediction_distance", constant=-1.0)
)
solver.add_costs(EdgeDistance(position_attributes=("x",), weight=0.5, constant=1.0))
solver.add_costs(Appear(constant=200.0, attribute="score", weight=-1.0))
solver.add_costs(Split(constant=100.0, attribute="score", weight=1.0))
solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(2))

solution = solver.solve()
solution = solver.solve()

assert solution.get_value() == -200
assert solution.get_value() == -202

0 comments on commit c256232

Please sign in to comment.