Skip to content

Commit

Permalink
make AdjM generic
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Jul 11, 2019
1 parent ba06899 commit f3f03ec
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions WDL/_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pyre-strict
# misc utility functions...

from typing import Tuple, Dict, Set, Iterable, List
from typing import Tuple, Dict, Set, Iterable, List, TypeVar, Generic


def strip_leading_whitespace(txt: str) -> Tuple[int, str]:
Expand Down Expand Up @@ -30,38 +30,41 @@ def strip_leading_whitespace(txt: str) -> Tuple[int, str]:
return (to_strip, "\n".join(lines))


class AdjM:
T = TypeVar("T")


class AdjM(Generic[T]):
# A sparse adjacency matrix for topological sorting
# which we should not have implemented ourselves
_forward: Dict[int, Set[int]]
_reverse: Dict[int, Set[int]]
_unconstrained: Set[int]
_forward: Dict[T, Set[T]]
_reverse: Dict[T, Set[T]]
_unconstrained: Set[T]

def __init__(self) -> None:
self._forward = dict()
self._reverse = dict()
self._unconstrained = set()

def sinks(self, source: int) -> Iterable[int]:
def sinks(self, source: T) -> Iterable[T]:
for sink in self._forward.get(source, []):
yield sink

def sources(self, sink: int) -> Iterable[int]:
def sources(self, sink: T) -> Iterable[T]:
for source in self._reverse.get(sink, []):
yield source

@property
def nodes(self) -> Iterable[int]:
def nodes(self) -> Iterable[T]:
for node in self._forward:
yield node

@property
def unconstrained(self) -> Iterable[int]:
def unconstrained(self) -> Iterable[T]:
for n in self._unconstrained:
assert not self._reverse[n]
yield n

def add_node(self, node: int) -> None:
def add_node(self, node: T) -> None:
if node not in self._forward:
assert node not in self._reverse
self._forward[node] = set()
Expand All @@ -70,7 +73,7 @@ def add_node(self, node: int) -> None:
else:
assert node in self._reverse

def add_edge(self, source: int, sink: int) -> None:
def add_edge(self, source: T, sink: T) -> None:
self.add_node(source)
self.add_node(sink)
if sink not in self._forward[source]:
Expand All @@ -82,7 +85,7 @@ def add_edge(self, source: int, sink: int) -> None:
assert source in self._reverse[sink]
assert sink not in self._unconstrained

def remove_edge(self, source: int, sink: int) -> None:
def remove_edge(self, source: T, sink: T) -> None:
if source in self._forward and sink in self._forward[source]:
self._forward[source].remove(sink)
self._reverse[sink].remove(source)
Expand All @@ -91,7 +94,7 @@ def remove_edge(self, source: int, sink: int) -> None:
else:
assert not (sink in self._reverse and source in self._reverse[sink])

def remove_node(self, node: int) -> None:
def remove_node(self, node: T) -> None:
for source in list(self.sources(node)):
self.remove_edge(source, node)
for sink in list(self.sinks(node)):
Expand All @@ -101,7 +104,7 @@ def remove_node(self, node: int) -> None:
self._unconstrained.remove(node)


def topsort(adj: AdjM) -> List[int]:
def topsort(adj: AdjM[T]) -> List[T]:
# topsort node IDs in adj (destroys adj)
# if there's a cycle, raises err: StopIteration with err.node = ID of a
# node involved in a cycle.
Expand Down

0 comments on commit f3f03ec

Please sign in to comment.