diff --git a/tests/data.py b/motile/data.py similarity index 64% rename from tests/data.py rename to motile/data.py index f66b4da..575d3b8 100644 --- a/tests/data.py +++ b/motile/data.py @@ -1,8 +1,9 @@ -import motile -import networkx +import networkx as nx +from motile import TrackGraph -def create_arlo_nx_graph() -> networkx.DiGraph: + +def arlo_graph_nx() -> nx.DiGraph: """Create the "Arlo graph", a simple toy graph for testing: x @@ -38,17 +39,32 @@ def create_arlo_nx_graph() -> networkx.DiGraph: {"source": 3, "target": 6, "prediction_distance": 3.0}, ] - nx_graph = networkx.DiGraph() + nx_graph = nx.DiGraph() nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells]) nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges]) return nx_graph -def create_arlo_trackgraph() -> motile.TrackGraph: - return motile.TrackGraph(create_arlo_nx_graph()) +def arlo_graph() -> TrackGraph: + return TrackGraph(arlo_graph_nx()) + +def toy_graph_nx() -> nx.DiGraph: + """Create variation of the "Arlo graph", with + - one simple edge modified. + - normalized node and edge scores. + - sparse ground truth annotations. -def create_toy_example_nx_graph() -> networkx.DiGraph: + x + | + | --- 6 + | / / + | 1---3---5 + | / x + | 0---2---4 + ------------------------------------ t + 0 1 2 + """ cells = [ {"id": 0, "t": 0, "x": 1, "score": 0.8, "gt": 1}, {"id": 1, "t": 0, "x": 25, "score": 0.1}, @@ -69,17 +85,17 @@ def create_toy_example_nx_graph() -> networkx.DiGraph: {"source": 3, "target": 4, "score": 0.3}, {"source": 3, "target": 6, "score": 0.8}, ] - nx_graph = networkx.DiGraph() + nx_graph = nx.DiGraph() nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells]) nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges]) return nx_graph -def create_toy_example_trackgraph() -> motile.TrackGraph: - return motile.TrackGraph(create_toy_example_nx_graph()) +def toy_graph() -> TrackGraph: + return TrackGraph(toy_graph_nx()) -def create_toy_hyperedge_nx_graph() -> networkx.DiGraph: +def toy_hypergraph_nx() -> nx.DiGraph: """Create variation of the "Arlo graph", with one simple edge modified and one hyperedge added. @@ -115,9 +131,11 @@ def create_toy_hyperedge_nx_graph() -> networkx.DiGraph: {"source": 3, "target": 6, "score": 0.8, "gt": None}, ] - nx_graph = networkx.DiGraph() - nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells]) - nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges]) + nx_graph = nx.DiGraph() + nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells]) # type: ignore + nx_graph.add_edges_from( + [(edge["source"], edge["target"], edge) for edge in edges] # type: ignore + ) # this is how to add a TrackGraph hyperedge into a nx_graph: nx_graph.add_node( @@ -129,34 +147,5 @@ def create_toy_hyperedge_nx_graph() -> networkx.DiGraph: return nx_graph -def create_toy_hyperedge_trackgraph() -> motile.TrackGraph: - return motile.TrackGraph(create_toy_hyperedge_nx_graph()) - - -def create_ssvm_noise_trackgraph() -> motile.TrackGraph: - cells = [ - {"id": 0, "t": 0, "x": 1, "score": 0.8, "gt": 1, "noise": 0.5}, - {"id": 1, "t": 0, "x": 25, "score": 0.9, "gt": 1, "noise": -0.5}, - {"id": 2, "t": 1, "x": 0, "score": 0.9, "gt": 1, "noise": 0.5}, - {"id": 3, "t": 1, "x": 26, "score": 0.8, "gt": 1, "noise": -0.5}, - {"id": 4, "t": 2, "x": 2, "score": 0.9, "gt": 1, "noise": 0.5}, - {"id": 5, "t": 2, "x": 24, "score": 0.1, "gt": 0, "noise": -0.5}, - {"id": 6, "t": 2, "x": 35, "score": 0.7, "gt": 1, "noise": -0.5}, - ] - - edges = [ - {"source": 0, "target": 2, "score": 0.9, "gt": 1, "noise": 0.5}, - {"source": 1, "target": 3, "score": 0.9, "gt": 1, "noise": -0.5}, - {"source": 0, "target": 3, "score": 0.2, "gt": 0, "noise": 0.5}, - {"source": 1, "target": 2, "score": 0.2, "gt": 0, "noise": -0.5}, - {"source": 2, "target": 4, "score": 0.9, "gt": 1, "noise": 0.5}, - {"source": 3, "target": 5, "score": 0.1, "gt": 0, "noise": -0.5}, - {"source": 2, "target": 5, "score": 0.2, "gt": 0, "noise": 0.5}, - {"source": 3, "target": 4, "score": 0.2, "gt": 0, "noise": -0.5}, - {"source": 3, "target": 6, "score": 0.8, "gt": 1, "noise": -0.5}, - ] - graph = networkx.DiGraph() - graph.add_nodes_from([(cell["id"], cell) for cell in cells]) - graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges]) - - return motile.TrackGraph(graph) +def toy_hypergraph() -> TrackGraph: + return TrackGraph(toy_hypergraph_nx()) diff --git a/tests/test_api.py b/tests/test_api.py index 5bc36fd..c1a5cea 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,25 +1,25 @@ import unittest import motile -from data import ( - create_arlo_nx_graph, - create_arlo_trackgraph, - create_toy_hyperedge_nx_graph, - create_toy_hyperedge_trackgraph, -) from motile.constraints import MaxChildren, MaxParents from motile.costs import Appear, EdgeSelection, NodeSelection, Split +from motile.data import ( + arlo_graph, + arlo_graph_nx, + toy_hypergraph, + toy_hypergraph_nx, +) class TestAPI(unittest.TestCase): def test_graph_creation_with_hyperedges(self): - graph = create_toy_hyperedge_trackgraph() + graph = toy_hypergraph() assert len(graph.nodes) == 7 assert len(graph.edges) == 10 def test_graph_creation_from_multiple_nx_graphs(self): - g1 = create_toy_hyperedge_nx_graph() - g2 = create_arlo_nx_graph() + g1 = toy_hypergraph_nx() + g2 = arlo_graph_nx() graph = motile.TrackGraph() graph.add_from_nx_graph(g1) @@ -35,7 +35,7 @@ def test_graph_creation_from_multiple_nx_graphs(self): assert "prediction_distance" in graph.edges[(0, 2)] def test_solver(self): - graph = create_arlo_trackgraph() + graph = arlo_graph() solver = motile.Solver(graph) solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0)) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 7baaf51..f0cdc77 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1,15 +1,15 @@ import unittest import motile -from data import create_arlo_trackgraph from motile.constraints import MaxChildren, MaxParents, Pin from motile.costs import Appear, EdgeSelection, NodeSelection, Split +from motile.data import arlo_graph from motile.variables import EdgeSelected class TestConstraints(unittest.TestCase): def test_pin(self): - graph = create_arlo_trackgraph() + graph = arlo_graph() # pin the value of two edges: graph.edges[(0, 2)]["pin_to"] = False diff --git a/tests/test_plot.py b/tests/test_plot.py index d940bd0..8a74b16 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,7 +1,7 @@ import motile import pytest -from data import create_arlo_trackgraph from motile.costs import Appear, EdgeSelection, NodeSelection, Split +from motile.data import arlo_graph from motile.plot import draw_solution, draw_track_graph try: @@ -12,7 +12,7 @@ @pytest.fixture def graph() -> motile.TrackGraph: - return create_arlo_trackgraph() + return arlo_graph() @pytest.fixture diff --git a/tests/test_structsvm.py b/tests/test_structsvm.py index c0e084f..c41b00e 100644 --- a/tests/test_structsvm.py +++ b/tests/test_structsvm.py @@ -1,15 +1,45 @@ import logging import motile +import networkx import numpy as np -from data import create_ssvm_noise_trackgraph, create_toy_example_trackgraph from motile.constraints import MaxChildren, MaxParents from motile.costs import Appear, EdgeSelection, NodeSelection +from motile.data import toy_graph from motile.variables import EdgeSelected, NodeSelected logger = logging.getLogger(__name__) +def create_ssvm_noise_trackgraph() -> motile.TrackGraph: + cells = [ + {"id": 0, "t": 0, "x": 1, "score": 0.8, "gt": 1, "noise": 0.5}, + {"id": 1, "t": 0, "x": 25, "score": 0.9, "gt": 1, "noise": -0.5}, + {"id": 2, "t": 1, "x": 0, "score": 0.9, "gt": 1, "noise": 0.5}, + {"id": 3, "t": 1, "x": 26, "score": 0.8, "gt": 1, "noise": -0.5}, + {"id": 4, "t": 2, "x": 2, "score": 0.9, "gt": 1, "noise": 0.5}, + {"id": 5, "t": 2, "x": 24, "score": 0.1, "gt": 0, "noise": -0.5}, + {"id": 6, "t": 2, "x": 35, "score": 0.7, "gt": 1, "noise": -0.5}, + ] + + edges = [ + {"source": 0, "target": 2, "score": 0.9, "gt": 1, "noise": 0.5}, + {"source": 1, "target": 3, "score": 0.9, "gt": 1, "noise": -0.5}, + {"source": 0, "target": 3, "score": 0.2, "gt": 0, "noise": 0.5}, + {"source": 1, "target": 2, "score": 0.2, "gt": 0, "noise": -0.5}, + {"source": 2, "target": 4, "score": 0.9, "gt": 1, "noise": 0.5}, + {"source": 3, "target": 5, "score": 0.1, "gt": 0, "noise": -0.5}, + {"source": 2, "target": 5, "score": 0.2, "gt": 0, "noise": 0.5}, + {"source": 3, "target": 4, "score": 0.2, "gt": 0, "noise": -0.5}, + {"source": 3, "target": 6, "score": 0.8, "gt": 1, "noise": -0.5}, + ] + graph = networkx.DiGraph() + graph.add_nodes_from([(cell["id"], cell) for cell in cells]) + graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges]) + + return motile.TrackGraph(graph) + + def create_toy_solver(graph): solver = motile.Solver(graph) @@ -27,7 +57,7 @@ def create_toy_solver(graph): def test_structsvm_common_toy_example(): - graph = create_toy_example_trackgraph() + graph = toy_graph() solver = create_toy_solver(graph)