Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize costs #54

Merged
merged 4 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions motile/costs/appear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,30 @@ class Appear(Costs):
"""Costs for :class:`~motile.variables.NodeAppear` variables.

Args:
weight:
The weight to apply to the cost of each starting track.

attribute:
The name of the attribute to use to look up costs. Default is
``None``, which means that a constant cost is used.

constant:
A constant cost for each node that starts a track.
"""

def __init__(self, constant: float) -> None:
def __init__(
self, weight: float = 1, attribute: str | None = None, constant: float = 0
) -> None:
self.weight = Weight(weight)
self.constant = Weight(constant)
self.attribute = attribute

def apply(self, solver: Solver) -> None:
appear_indicators = solver.get_variables(NodeAppear)

for index in appear_indicators.values():
for node, index in appear_indicators.items():
if self.attribute is not None:
solver.add_variable_cost(
index, solver.graph.nodes[node][self.attribute], self.weight
)
solver.add_variable_cost(index, 1.0, self.constant)
19 changes: 17 additions & 2 deletions motile/costs/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,31 @@ class Split(Costs):
"""Costs for :class:`~motile.variables.NodeSplit` variables.

Args:
weight:
The weight to apply to the cost of each split.

attribute:
The name of the attribute to use to look up costs. Default is
``None``, which means that a constant cost is used.

constant:
A constant cost for each node that has more than one selected
child.
"""

def __init__(self, constant: float) -> None:
def __init__(
self, weight: float = 1, attribute: str | None = None, constant: float = 0
) -> None:
self.weight = Weight(weight)
self.constant = Weight(constant)
self.attribute = attribute

def apply(self, solver: Solver) -> None:
split_indicators = solver.get_variables(NodeSplit)

for index in split_indicators.values():
for node, index in split_indicators.items():
if self.attribute is not None:
solver.add_variable_cost(
index, solver.graph.nodes[node][self.attribute], self.weight
)
solver.add_variable_cost(index, 1.0, self.constant)
Empty file added tests/__init__.py
Empty file.
106 changes: 65 additions & 41 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,75 @@
import unittest

import motile
from motile.constraints import MaxChildren, MaxParents
from motile.costs import Appear, Disappear, EdgeSelection, NodeSelection, Split
from motile.costs import (
Appear,
EdgeDistance,
EdgeSelection,
NodeSelection,
Split,
)
from motile.data import (
arlo_graph,
arlo_graph_nx,
toy_hypergraph,
toy_hypergraph_nx,
)
from motile.variables import EdgeSelected, NodeSelected


def _selected_nodes(solver: motile.Solver) -> list:
node_indicators = solver.get_variables(NodeSelected)
solution = solver.solve()
return sorted([n for n, i in node_indicators.items() if solution[i] > 0.5])


def _selected_edges(solver: motile.Solver) -> list:
edge_indicators = solver.get_variables(EdgeSelected)
solution = solver.solve()
return sorted([e for e, i in edge_indicators.items() if solution[i] > 0.5])


def test_graph_creation_with_hyperedges():
graph = toy_hypergraph()
assert len(graph.nodes) == 7
assert len(graph.edges) == 10


def test_graph_creation_from_multiple_nx_graphs():
g1 = toy_hypergraph_nx()
g2 = arlo_graph_nx()
graph = motile.TrackGraph()

graph.add_from_nx_graph(g1)
assert len(graph.nodes) == 7
assert len(graph.edges) == 10
assert graph.nodes[6]["x"] == 35
assert "prediction_distance" not in graph.edges[(0, 2)]

graph.add_from_nx_graph(g2)
assert len(graph.nodes) == 7
assert len(graph.edges) == 11
assert graph.nodes[6]["x"] == 200
assert "prediction_distance" in graph.edges[(0, 2)]


def test_solver():
graph = arlo_graph()

solver = motile.Solver(graph)
solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
solver.add_costs(
EdgeSelection(weight=0.5, attribute="prediction_distance", constant=-1.0)
)
solver.add_costs(EdgeDistance(position_attributes=("x",), weight=0.5))
solver.add_costs(Appear(constant=200.0, attribute="score", weight=-1.0))
solver.add_costs(Split(constant=100.0, attribute="score", weight=1.0))

solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(2))

solution = solver.solve()

class TestAPI(unittest.TestCase):
def test_graph_creation_with_hyperedges(self):
graph = toy_hypergraph()
assert len(graph.nodes) == 7
assert len(graph.edges) == 10

def test_graph_creation_from_multiple_nx_graphs(self):
g1 = toy_hypergraph_nx()
g2 = arlo_graph_nx()
graph = motile.TrackGraph()

graph.add_from_nx_graph(g1)
assert len(graph.nodes) == 7
assert len(graph.edges) == 10
assert graph.nodes[6]["x"] == 35
assert "prediction_distance" not in graph.edges[(0, 2)]

graph.add_from_nx_graph(g2)
assert len(graph.nodes) == 7
assert len(graph.edges) == 11
assert graph.nodes[6]["x"] == 200
assert "prediction_distance" in graph.edges[(0, 2)]

def test_solver(self):
graph = arlo_graph()

solver = motile.Solver(graph)
solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
solver.add_costs(EdgeSelection(weight=1.0, attribute="prediction_distance"))
solver.add_costs(Appear(constant=100.0))
solver.add_costs(Disappear(constant=100.0))
solver.add_costs(Split(constant=100.0))
solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(2))

solution = solver.solve()

assert solution.get_value() == -200
assert _selected_edges(solver) == [(0, 2), (1, 3), (2, 4), (3, 5)]
assert _selected_nodes(solver) == [0, 1, 2, 3, 4, 5]
cost = solution.get_value()
assert cost == -206.0, f"{cost=}"
16 changes: 2 additions & 14 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,15 @@
from motile.constraints import ExpressionConstraint, MaxChildren, MaxParents, Pin
from motile.costs import EdgeSelection, NodeSelection
from motile.data import arlo_graph
from motile.variables import EdgeSelected
from motile.variables.node_selected import NodeSelected

from .test_api import _selected_edges, _selected_nodes


@pytest.fixture
def solver():
return motile.Solver(arlo_graph())


def _selected_edges(solver: motile.Solver) -> list:
edge_indicators = solver.get_variables(EdgeSelected)
solution = solver.solve()
return [e for e, i in edge_indicators.items() if solution[i]]


def _selected_nodes(solver: motile.Solver) -> list:
node_indicators = solver.get_variables(NodeSelected)
solution = solver.solve()
return [e for e, i in node_indicators.items() if solution[i]]


def test_pin(solver: motile.Solver) -> None:
# pin the value of two edges:
solver.graph.edges[(0, 2)]["pin_to"] = False # type: ignore
Expand Down
Loading