Skip to content

Commit

Permalink
Merge pull request #112 from funkelab/hyperedges
Browse files Browse the repository at this point in the history
Rename and tighten type definitions
  • Loading branch information
cmalinmayor authored Aug 13, 2024
2 parents 825a461 + fb955bc commit 8ad0ffa
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 72 deletions.
30 changes: 20 additions & 10 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,38 @@ API Reference
.. automodule:: motile
:noindex:

.. admonition:: A note on ``NodeId`` and ``EdgeId`` types
.. admonition:: A note on ``Node`` and ``Edge`` types
:class: note, dropdown

The following types are used throughout the docs

- All objects in a graph (both ``Nodes`` and ``Edges``) are represented as
dictionaries mapping string attribute names to value. For example, a node
might be ``{ "id": 1, "x": 0.5, "y": 0.5, "t": 0 }``
- Nodes are integers

``GraphObject: TypeAlias = Mapping[str, Any]``
``Node: TypeAlias = int``

- Node IDs may be integers, or a "meta-node" as a tuple of integers.
- Collections of nodes are tuples of ``Node``

``NodeId: TypeAlias = Union[int, tuple[int, ...]]``
``Nodes: TypeAlias = tuple[Node, ...]``

- Edges IDs are tuples of ``NodeId``.
- Edges are 2-tuples of ``Node``.

``EdgeId: TypeAlias = tuple[NodeId, ...]``
``Edge: TypeAlias = tuple[Node, Node]``

- Hyperedges are 2-tuples of ``Nodes``:

``HyperEdge: TypeAlias = tuple[Nodes, Nodes]``

Examples:

- ``(0, 1)`` is an edge from node 0 to node 1.
- ``((0, 1), 2)`` is a hyperedge from nodes 0 and 1 to node 2 (i.e. a merge).
- ``((0,), (1, 2))`` is a hyperedge from node 0 to nodes 1 and 2 (i.e. a split).
- ``((0, 1), 2)`` is a not a valid edge.

- All attributes in a graph (for both ``Node``s and ``(Hyper)Edge``s) are
dictionaries mapping string attribute names to values. For example, a
node's attributes might be ``{ "x": 0.5, "y": 0.5, "t": 0 }``

``Attributes: TypeAlias = Mapping[str, Any]``



Expand Down
22 changes: 13 additions & 9 deletions motile/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@

from typing import Any, Mapping, TypeAlias, Union

# Nodes are represented as integers, or a "meta-node" tuple of integers.
NodeId: TypeAlias = Union[int, tuple[int, ...]]
# Nodes are integers
Node: TypeAlias = int
# Collections of nodes (for hyperedges) are tuples
Nodes: TypeAlias = tuple[int, ...]

# objects in the graph are represented as dicts
# eg. { "id": 1, "x": 0.5, "y": 0.5, "t": 0 }
GraphObject: TypeAlias = Mapping[str, Any]

# Edges are represented as tuples of NodeId.
# Edges are tuples of Node or Nodes.
# (0, 1) is an edge from node 0 to node 1.
# ((0, 1), 2) is a hyperedge from nodes 0 and 1 to node 2 (i.e. a merge).
# ((0,), (1, 2)) is a hyperedge from node 0 to nodes 1 and 2 (i.e. a split).
EdgeId: TypeAlias = tuple[NodeId, ...]
# ((0, 1), 2) is not valid.
Edge: TypeAlias = tuple[Node, Node]
HyperEdge: TypeAlias = tuple[Nodes, Nodes]
GenericEdge: TypeAlias = Union[Edge, HyperEdge]

# objects in the graph are represented as dicts
# eg. { "id": 1, "x": 0.5, "y": 0.5, "t": 0 }
Attributes: TypeAlias = Mapping[str, Any]
4 changes: 2 additions & 2 deletions motile/constraints/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from .constraint import Constraint

if TYPE_CHECKING:
from motile._types import EdgeId, GraphObject, NodeId
from motile._types import Attributes, GenericEdge, Node
from motile.solver import Solver

NodesOrEdges = Union[dict[NodeId, GraphObject], dict[EdgeId, GraphObject]]
NodesOrEdges = Union[dict[Node, Attributes], dict[GenericEdge, Attributes]]


class ExpressionConstraint(Constraint):
Expand Down
66 changes: 29 additions & 37 deletions motile/track_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@

if TYPE_CHECKING:
import networkx
from typing_extensions import TypeGuard

from motile._types import EdgeId, GraphObject, NodeId
from motile._types import (
Attributes,
Edge,
GenericEdge,
HyperEdge,
Node,
Nodes,
)


class TrackGraph:
Expand Down Expand Up @@ -44,17 +52,17 @@ def __init__(
self.frame_attribute = frame_attribute
self._graph_changed = True

self.nodes: dict[NodeId, GraphObject] = {}
self.edges: dict[EdgeId, GraphObject] = {}
self.prev_edges: defaultdict[NodeId, list[EdgeId]] = DefaultDict(list)
self.next_edges: defaultdict[NodeId, list[EdgeId]] = DefaultDict(list)
self.nodes: dict[Node, Attributes] = {}
self.edges: dict[GenericEdge, Attributes] = {}
self.prev_edges: defaultdict[Node, list[GenericEdge]] = DefaultDict(list)
self.next_edges: defaultdict[Node, list[GenericEdge]] = DefaultDict(list)

if nx_graph:
self.add_from_nx_graph(nx_graph)

self._update_metadata()

def add_node(self, node_id: NodeId, data: GraphObject) -> None:
def add_node(self, node_id: Node, data: Attributes) -> None:
"""Adds a new node to this TrackGraph.
Args:
Expand All @@ -68,25 +76,25 @@ def add_node(self, node_id: NodeId, data: GraphObject) -> None:
self.nodes[node_id] = data
self._graph_changed = True

def add_edge(self, edge_id: EdgeId, data: GraphObject) -> None:
def add_edge(self, edge_id: GenericEdge, data: Attributes) -> None:
"""Adds an edge to this TrackGraph.
Args:
edge_id: an ``EdgeId`` (tuple of NodeIds) defining the edge
edge_id: an ``GenericEdge`` (tuple of Nodes) defining the edge
(or hyperedge) to be added.
data: all properties associated to the added edge.
"""
self.edges[edge_id] = data

if self.is_hyperedge(edge_id):
us, vs = cast("tuple[tuple[int], tuple[int]]", edge_id)
us, vs = edge_id
for v in vs:
self.prev_edges[v].append(edge_id)
for u in us:
self.next_edges[v].append(edge_id)
else:
# normal (u, v) edge
u, v = cast("tuple[int, int]", edge_id)
u, v = cast("Edge", edge_id)
self.prev_edges[v].append(edge_id)
self.next_edges[u].append(edge_id)

Expand Down Expand Up @@ -128,11 +136,8 @@ def add_from_nx_graph(self, nx_graph: networkx.DiGraph) -> None:
continue
# add hyperedge when nx_edge leads to hyperedge node
if self._is_hyperedge_nx_node(nx_graph, v):
(
edge,
in_nodes,
out_nodes,
) = self._hyperedge_nx_node_to_edge_tuple_and_neighbors(nx_graph, v)
edge = self._convert_nx_hypernode(nx_graph, v)
in_nodes, out_nodes = edge
# avoid adding duplicates
if edge not in self.edges:
self.edges[edge] = data
Expand All @@ -147,7 +152,7 @@ def add_from_nx_graph(self, nx_graph: networkx.DiGraph) -> None:
self.prev_edges[v].append((u, v))
self.next_edges[u].append((u, v))

def nodes_of(self, edge: EdgeId | int) -> Iterator[int]:
def nodes_of(self, edge: GenericEdge | Nodes | Node) -> Iterator[Node]:
"""Returns an ``Iterator`` of node id's that are incident to the given edge.
Args:
Expand All @@ -156,13 +161,15 @@ def nodes_of(self, edge: EdgeId | int) -> Iterator[int]:
Yields:
all nodes incident to the given edge.
"""
# recursively descent into tuples and yield their elements if they are
# not tuples
if isinstance(edge, tuple):
for x in edge:
yield from self.nodes_of(x)
else:
yield edge

def is_hyperedge(self, edge: EdgeId) -> bool:
def is_hyperedge(self, edge: GenericEdge) -> TypeGuard[HyperEdge]:
"""Test if the given edge is a hyperedge in this track graph."""
assert len(edge) == 2, "(Hyper)edges need to be 2-tuples"
num_tuples = sum(isinstance(x, tuple) for x in edge)
Expand Down Expand Up @@ -191,9 +198,9 @@ def _is_hyperedge_nx_node(self, nx_graph: networkx.DiGraph, nx_node: Any) -> boo
"""
return self.frame_attribute not in nx_graph.nodes[nx_node]

def _hyperedge_nx_node_to_edge_tuple_and_neighbors(
def _convert_nx_hypernode(
self, nx_graph: networkx.DiGraph, hyperedge_node: Any
) -> tuple[tuple[NodeId, ...], list[NodeId], list[NodeId]]:
) -> HyperEdge:
"""Creates a hyperedge tuple for hyperedge node in a given networkx ``DiGraph``.
Args:
Expand All @@ -208,25 +215,10 @@ def _hyperedge_nx_node_to_edge_tuple_and_neighbors(
"""
assert self._is_hyperedge_nx_node(nx_graph, hyperedge_node)

in_nodes = list(nx_graph.predecessors(hyperedge_node))
out_nodes = list(nx_graph.successors(hyperedge_node))
nx_nodes = in_nodes + out_nodes

frameset = {
nx_graph.nodes[nx_node][self.frame_attribute] for nx_node in nx_nodes
}
frames = list(sorted(frameset))

edge_tuple = tuple(
tuple(
node
for node in nx_nodes
if nx_graph.nodes[node][self.frame_attribute] == frame
)
for frame in frames
)
in_nodes = tuple(nx_graph.predecessors(hyperedge_node))
out_nodes = tuple(nx_graph.successors(hyperedge_node))

return edge_tuple, in_nodes, out_nodes
return (in_nodes, out_nodes)

def get_frames(self) -> tuple[int | None, int | None]:
"""Return tuple with first and last (exclusive) frame this graph has nodes for.
Expand Down
6 changes: 3 additions & 3 deletions motile/variables/edge_selected.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
if TYPE_CHECKING:
import ilpy

from motile._types import EdgeId
from motile._types import GenericEdge
from motile.solver import Solver


class EdgeSelected(Variable["EdgeId"]):
class EdgeSelected(Variable["GenericEdge"]):
"""Binary variable indicates whether an edge is part of the solution or not."""

@staticmethod
def instantiate(solver: Solver) -> Collection[EdgeId]:
def instantiate(solver: Solver) -> Collection[GenericEdge]:
return solver.graph.edges

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions motile/variables/node_appear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
if TYPE_CHECKING:
import ilpy

from motile._types import NodeId
from motile._types import Node
from motile.solver import Solver


class NodeAppear(Variable["NodeId"]):
class NodeAppear(Variable["Node"]):
r"""Binary variable indicating whether a node is the start of a track.
(i.e., the node is selected and has no selected incoming edges).
Expand All @@ -35,7 +35,7 @@ class NodeAppear(Variable["NodeId"]):
"""

@staticmethod
def instantiate(solver: Solver) -> Collection[NodeId]:
def instantiate(solver: Solver) -> Collection[Node]:
return solver.graph.nodes

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions motile/variables/node_disappear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
if TYPE_CHECKING:
import ilpy

from motile._types import NodeId
from motile._types import Node
from motile.solver import Solver


class NodeDisappear(Variable["NodeId"]):
class NodeDisappear(Variable["Node"]):
r"""Binary variable to indicate whether a node disappears.
This variable indicates whether the node is the end of a track (i.e., the node is
Expand All @@ -35,7 +35,7 @@ class NodeDisappear(Variable["NodeId"]):
"""

@staticmethod
def instantiate(solver: Solver) -> Collection[NodeId]:
def instantiate(solver: Solver) -> Collection[Node]:
return solver.graph.nodes

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions motile/variables/node_selected.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from .variable import Variable

if TYPE_CHECKING:
from motile._types import NodeId
from motile._types import Node
from motile.solver import Solver


class NodeSelected(Variable["NodeId"]):
class NodeSelected(Variable["Node"]):
"""Binary variable indicating whether a node is part of the solution or not."""

@staticmethod
def instantiate(solver: Solver) -> Collection[NodeId]:
def instantiate(solver: Solver) -> Collection[Node]:
return solver.graph.nodes
4 changes: 2 additions & 2 deletions motile/variables/node_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .variable import Variable

if TYPE_CHECKING:
from motile._types import NodeId
from motile._types import Node
from motile.solver import Solver


Expand All @@ -32,7 +32,7 @@ class NodeSplit(Variable):
"""

@staticmethod
def instantiate(solver: Solver) -> Collection[NodeId]:
def instantiate(solver: Solver) -> Collection[Node]:
return solver.graph.nodes

@staticmethod
Expand Down

0 comments on commit 8ad0ffa

Please sign in to comment.