Skip to content

Commit

Permalink
Move example graphs into package
Browse files Browse the repository at this point in the history
Moves the three example graphs `arlo_graph`, `toy_graph` and
`toy_hypergraph` into `motile.data`, similar to e.g. `skimage.data`.
It seems useful to have a one-liner to get some small `TrackGraph` when
prototyping in motile.

Creating and plotting a graph now works as follows:
```python
from motile import data, plot
graph = data.arlo_graph()
plot.draw_track_graph(graph)
```

Graphs specific to a certain test can live directly in that test module.
  • Loading branch information
bentaculum committed Mar 29, 2023
1 parent 8defb53 commit 2fbc861
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 58 deletions.
3 changes: 2 additions & 1 deletion motile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .solver import Solver
from .track_graph import TrackGraph
from . import data

__all__ = ["Solver", "TrackGraph"]
__all__ = ["Solver", "TrackGraph", "data"]
__version__ = "0.1.2"
73 changes: 30 additions & 43 deletions tests/data.py → motile/data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import motile
import networkx
import networkx as nx

from motile import TrackGraph

def create_arlo_nx_graph() -> networkx.DiGraph:

def arlo_nx_graph() -> nx.DiGraph:
"""Create the "Arlo graph", a simple toy graph for testing:
x
Expand Down Expand Up @@ -38,17 +39,32 @@ def create_arlo_nx_graph() -> networkx.DiGraph:
{"source": 3, "target": 6, "prediction_distance": 3.0},
]

nx_graph = networkx.DiGraph()
nx_graph = nx.DiGraph()
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])
nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])
return nx_graph


def create_arlo_trackgraph() -> motile.TrackGraph:
return motile.TrackGraph(create_arlo_nx_graph())
def arlo_graph() -> TrackGraph:
return TrackGraph(arlo_nx_graph())


def toy_example_nx_graph() -> nx.DiGraph:
"""Create variation of the "Arlo graph", with
- one simple edge modified.
- normalized node and edge scores.
- sparse ground truth annotations.
def create_toy_example_nx_graph() -> networkx.DiGraph:
x
|
| --- 6
| / /
| 1---3---5
| / x
| 0---2---4
------------------------------------ t
0 1 2
"""
cells = [
{"id": 0, "t": 0, "x": 1, "score": 0.8, "gt": 1},
{"id": 1, "t": 0, "x": 25, "score": 0.1},
Expand All @@ -69,17 +85,17 @@ def create_toy_example_nx_graph() -> networkx.DiGraph:
{"source": 3, "target": 4, "score": 0.3},
{"source": 3, "target": 6, "score": 0.8},
]
nx_graph = networkx.DiGraph()
nx_graph = nx.DiGraph()
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])
nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])
return nx_graph


def create_toy_example_trackgraph() -> motile.TrackGraph:
return motile.TrackGraph(create_toy_example_nx_graph())
def toy_graph() -> TrackGraph:
return TrackGraph(toy_example_nx_graph())


def create_toy_hyperedge_nx_graph() -> networkx.DiGraph:
def toy_hyperedge_nx_graph() -> nx.DiGraph:
"""Create variation of the "Arlo graph", with one simple
edge modified and one hyperedge added.
Expand Down Expand Up @@ -115,7 +131,7 @@ def create_toy_hyperedge_nx_graph() -> networkx.DiGraph:
{"source": 3, "target": 6, "score": 0.8, "gt": None},
]

nx_graph = networkx.DiGraph()
nx_graph = nx.DiGraph()
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])
nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])

Expand All @@ -129,34 +145,5 @@ def create_toy_hyperedge_nx_graph() -> networkx.DiGraph:
return nx_graph


def create_toy_hyperedge_trackgraph() -> motile.TrackGraph:
return motile.TrackGraph(create_toy_hyperedge_nx_graph())


def create_ssvm_noise_trackgraph() -> motile.TrackGraph:
cells = [
{"id": 0, "t": 0, "x": 1, "score": 0.8, "gt": 1, "noise": 0.5},
{"id": 1, "t": 0, "x": 25, "score": 0.9, "gt": 1, "noise": -0.5},
{"id": 2, "t": 1, "x": 0, "score": 0.9, "gt": 1, "noise": 0.5},
{"id": 3, "t": 1, "x": 26, "score": 0.8, "gt": 1, "noise": -0.5},
{"id": 4, "t": 2, "x": 2, "score": 0.9, "gt": 1, "noise": 0.5},
{"id": 5, "t": 2, "x": 24, "score": 0.1, "gt": 0, "noise": -0.5},
{"id": 6, "t": 2, "x": 35, "score": 0.7, "gt": 1, "noise": -0.5},
]

edges = [
{"source": 0, "target": 2, "score": 0.9, "gt": 1, "noise": 0.5},
{"source": 1, "target": 3, "score": 0.9, "gt": 1, "noise": -0.5},
{"source": 0, "target": 3, "score": 0.2, "gt": 0, "noise": 0.5},
{"source": 1, "target": 2, "score": 0.2, "gt": 0, "noise": -0.5},
{"source": 2, "target": 4, "score": 0.9, "gt": 1, "noise": 0.5},
{"source": 3, "target": 5, "score": 0.1, "gt": 0, "noise": -0.5},
{"source": 2, "target": 5, "score": 0.2, "gt": 0, "noise": 0.5},
{"source": 3, "target": 4, "score": 0.2, "gt": 0, "noise": -0.5},
{"source": 3, "target": 6, "score": 0.8, "gt": 1, "noise": -0.5},
]
graph = networkx.DiGraph()
graph.add_nodes_from([(cell["id"], cell) for cell in cells])
graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])

return motile.TrackGraph(graph)
def toy_hypergraph() -> TrackGraph:
return TrackGraph(toy_hyperedge_nx_graph())
18 changes: 9 additions & 9 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import unittest

import motile
from data import (
create_arlo_nx_graph,
create_arlo_trackgraph,
create_toy_hyperedge_nx_graph,
create_toy_hyperedge_trackgraph,
from motile.data import (
arlo_nx_graph,
arlo_graph,
toy_hyperedge_nx_graph,
toy_hypergraph,
)
from motile.constraints import MaxChildren, MaxParents
from motile.costs import Appear, EdgeSelection, NodeSelection, Split


class TestAPI(unittest.TestCase):
def test_graph_creation_with_hyperedges(self):
graph = create_toy_hyperedge_trackgraph()
graph = toy_hypergraph()
assert len(graph.nodes) == 7
assert len(graph.edges) == 10

def test_graph_creation_from_multiple_nx_graphs(self):
g1 = create_toy_hyperedge_nx_graph()
g2 = create_arlo_nx_graph()
g1 = toy_hyperedge_nx_graph()
g2 = arlo_nx_graph()
graph = motile.TrackGraph()

graph.add_from_nx_graph(g1)
Expand All @@ -35,7 +35,7 @@ def test_graph_creation_from_multiple_nx_graphs(self):
assert "prediction_distance" in graph.edges[(0, 2)]

def test_solver(self):
graph = create_arlo_trackgraph()
graph = arlo_graph()

solver = motile.Solver(graph)
solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import unittest

import motile
from data import create_arlo_trackgraph
from motile.data import arlo_graph
from motile.constraints import MaxChildren, MaxParents, Pin
from motile.costs import Appear, EdgeSelection, NodeSelection, Split
from motile.variables import EdgeSelected


class TestConstraints(unittest.TestCase):
def test_pin(self):
graph = create_arlo_trackgraph()
graph = arlo_graph()

# pin the value of two edges:
graph.edges[(0, 2)]["pin_to"] = False
Expand Down
36 changes: 33 additions & 3 deletions tests/test_structsvm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,45 @@
import logging

import motile
import numpy as np
from data import create_ssvm_noise_trackgraph, create_toy_example_trackgraph
import networkx
import motile
from motile.data import toy_graph
from motile.constraints import MaxChildren, MaxParents
from motile.costs import Appear, EdgeSelection, NodeSelection
from motile.variables import EdgeSelected, NodeSelected

logger = logging.getLogger(__name__)


def create_ssvm_noise_trackgraph() -> motile.TrackGraph:
cells = [
{"id": 0, "t": 0, "x": 1, "score": 0.8, "gt": 1, "noise": 0.5},
{"id": 1, "t": 0, "x": 25, "score": 0.9, "gt": 1, "noise": -0.5},
{"id": 2, "t": 1, "x": 0, "score": 0.9, "gt": 1, "noise": 0.5},
{"id": 3, "t": 1, "x": 26, "score": 0.8, "gt": 1, "noise": -0.5},
{"id": 4, "t": 2, "x": 2, "score": 0.9, "gt": 1, "noise": 0.5},
{"id": 5, "t": 2, "x": 24, "score": 0.1, "gt": 0, "noise": -0.5},
{"id": 6, "t": 2, "x": 35, "score": 0.7, "gt": 1, "noise": -0.5},
]

edges = [
{"source": 0, "target": 2, "score": 0.9, "gt": 1, "noise": 0.5},
{"source": 1, "target": 3, "score": 0.9, "gt": 1, "noise": -0.5},
{"source": 0, "target": 3, "score": 0.2, "gt": 0, "noise": 0.5},
{"source": 1, "target": 2, "score": 0.2, "gt": 0, "noise": -0.5},
{"source": 2, "target": 4, "score": 0.9, "gt": 1, "noise": 0.5},
{"source": 3, "target": 5, "score": 0.1, "gt": 0, "noise": -0.5},
{"source": 2, "target": 5, "score": 0.2, "gt": 0, "noise": 0.5},
{"source": 3, "target": 4, "score": 0.2, "gt": 0, "noise": -0.5},
{"source": 3, "target": 6, "score": 0.8, "gt": 1, "noise": -0.5},
]
graph = networkx.DiGraph()
graph.add_nodes_from([(cell["id"], cell) for cell in cells])
graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])

return motile.TrackGraph(graph)


def create_toy_solver(graph):
solver = motile.Solver(graph)

Expand All @@ -27,7 +57,7 @@ def create_toy_solver(graph):


def test_structsvm_common_toy_example():
graph = create_toy_example_trackgraph()
graph = toy_graph()

solver = create_toy_solver(graph)

Expand Down

0 comments on commit 2fbc861

Please sign in to comment.