Skip to content

Commit

Permalink
Graphviz and asciitree improvements (#790)
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed May 15, 2023
2 parents 73be2ab + eaabf17 commit 247a765
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 72 deletions.
36 changes: 18 additions & 18 deletions nutils/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __bool__(self) -> bool:
raise NotImplementedError # pragma: no cover

@abc.abstractmethod
def _generate_asciitree_nodes(self, cache: MutableMapping['Node[Metadata]', str], subgraph_ids: Mapping[Optional[Subgraph], str], id_gen: Iterator[str], select: str, bridge: str) -> Generator[str, None, None]:
def _generate_asciitree_nodes(self, cache: MutableMapping['Node[Metadata]', str], id_gen_map: Mapping[Optional[Subgraph], Iterator[str]], select: str, bridge: str) -> Generator[str, None, None]:
raise NotImplementedError # pragma: no cover

@abc.abstractmethod
Expand All @@ -41,12 +41,12 @@ def walk(self, seen: MutableSet['Node[Metadata]']) -> Iterator['Node[Metadata]']
def generate_asciitree(self, richoutput: bool = False) -> str:
subgraph_children = _collect_subgraphs(self)
if len(subgraph_children) > 1:
subgraph_ids = {} # type: Dict[Optional[Subgraph], str]
parts = ['SUBGRAPHS\n'], _generate_asciitree_subgraphs(subgraph_children, subgraph_ids, None, '', ''), ['NODES\n'] # type: Sequence[Iterable[str]]
id_gen_map = {} # type: Dict[Optional[Subgraph], Iterator[str]]
parts = ['SUBGRAPHS\n'], _generate_asciitree_subgraphs(subgraph_children, id_gen_map, None, '', ''), ['NODES\n'] # type: Sequence[Iterable[str]]
else:
subgraph_ids = {None: ''}
id_gen_map = {None: (f'%{i}' for i in itertools.count())}
parts = []
asciitree = ''.join(itertools.chain(*parts, self._generate_asciitree_nodes({}, subgraph_ids, map(str, itertools.count()), '', '')))
asciitree = ''.join(itertools.chain(*parts, self._generate_asciitree_nodes({}, id_gen_map, '', '')))
if not richoutput:
asciitree = asciitree.replace('├', ':').replace('└', ':').replace('│', '|')
return asciitree
Expand All @@ -57,9 +57,9 @@ def generate_graphviz_source(self, *, fill_color: Optional[GraphvizColorCallback
subgraph_children = _collect_subgraphs(self)
id_gen = map(str, itertools.count())
self._collect_graphviz_nodes_edges({}, id_gen, nodes, edges, None, fill_color)
return ''.join(itertools.chain(['digraph {graph [dpi=72];'], _generate_graphviz_subgraphs(subgraph_children, nodes, None, id_gen), edges, ['}']))
return ''.join(itertools.chain(['digraph {bgcolor="darkgray";'], _generate_graphviz_subgraphs(subgraph_children, nodes, None, id_gen), edges, ['}']))

def export_graphviz(self, *, fill_color: Optional[GraphvizColorCallback] = None, dot_path: str = 'dot', image_type: str = 'png') -> None:
def export_graphviz(self, *, fill_color: Optional[GraphvizColorCallback] = None, dot_path: str = 'dot', image_type: str = 'svg') -> None:
src = self.generate_graphviz_source(fill_color=fill_color)
with treelog.infofile('dot.'+image_type, 'wb') as img:
src = src.replace(';', ';\n')
Expand All @@ -82,16 +82,15 @@ def __init__(self, label: str, args: Sequence[Node[Metadata]], kwargs: Mapping[s
def __bool__(self) -> bool:
return True

def _generate_asciitree_nodes(self, cache: MutableMapping[Node[Metadata], str], subgraph_ids: Mapping[Optional[Subgraph], str], id_gen: Iterator[str], select: str, bridge: str) -> Generator[str, None, None]:
def _generate_asciitree_nodes(self, cache: MutableMapping[Node[Metadata], str], id_gen_map: Mapping[Optional[Subgraph], str], select: str, bridge: str) -> Generator[str, None, None]:
if self in cache:
yield '{}{}\n'.format(select, cache[self])
else:
subgraph_id = subgraph_ids[self.subgraph]
cache[self] = id = '%{}{}'.format(subgraph_id, next(id_gen))
cache[self] = id = next(id_gen_map[self.subgraph])
yield '{}{} = {}\n'.format(select, id, self._label.replace('\n', '; '))
args = tuple(('', arg) for arg in self._args if arg) + tuple(('{} = '.format(name), arg) for name, arg in self._kwargs.items())
for i, (prefix, arg) in enumerate(args, 1-len(args)):
yield from arg._generate_asciitree_nodes(cache, subgraph_ids, id_gen, bridge+('├ ' if i else '└ ')+prefix, bridge+('│ ' if i else ' '))
yield from arg._generate_asciitree_nodes(cache, id_gen_map, bridge+('├ ' if i else '└ ')+prefix, bridge+('│ ' if i else ' '))

def _collect_graphviz_nodes_edges(self, cache: MutableMapping[Node[Metadata], str], id_gen: Iterator[str], nodes: MutableMapping[Optional[Subgraph], List[str]], edges: List[str], parent_subgraph: Optional[Subgraph], fill_color: Optional[GraphvizColorCallback] = None) -> Optional[str]:
if self in cache:
Expand Down Expand Up @@ -137,7 +136,7 @@ def __init__(self, label: str, metadata: Metadata) -> None:
def __bool__(self) -> bool:
return True

def _generate_asciitree_nodes(self, cache: MutableMapping[Node[Metadata], str], subgraph_ids: Mapping[Optional[Subgraph], str], id_gen: Iterator[str], select: str, bridge: str) -> Generator[str, None, None]:
def _generate_asciitree_nodes(self, cache: MutableMapping[Node[Metadata], str], id_gen_map: Mapping[Optional[Subgraph], str], select: str, bridge: str) -> Generator[str, None, None]:
yield '{}{}\n'.format(select, self._label.replace('\n', '; '))

def _collect_graphviz_nodes_edges(self, cache: MutableMapping[Node[Metadata], str], id_gen: Iterator[str], nodes: MutableMapping[Optional[Subgraph], List[str]], edges: List[str], parent_subgraph: Optional[Subgraph], fill_color: Optional[GraphvizColorCallback] = None) -> Optional[str]:
Expand All @@ -161,7 +160,7 @@ def __init__(self, metadata: Metadata) -> None:
def __bool__(self) -> bool:
return False

def _generate_asciitree_nodes(self, cache: MutableMapping[Node[Metadata], str], subgraph_ids: Mapping[Optional[Subgraph], str], id_gen: Iterator[str], select: str, bridge: str) -> Generator[str, None, None]:
def _generate_asciitree_nodes(self, cache: MutableMapping[Node[Metadata], str], id_gen_map: Mapping[Optional[Subgraph], str], select: str, bridge: str) -> Generator[str, None, None]:
yield '{}\n'.format(select)

def _collect_graphviz_nodes_edges(self, cache: MutableMapping[Node[Metadata], str], id_gen: Iterator[str], nodes: MutableMapping[Optional[Subgraph], List[str]], edges: List[str], parent_subgraph: Optional[Subgraph], fill_color: Optional[GraphvizColorCallback] = None) -> Optional[str]:
Expand Down Expand Up @@ -197,20 +196,21 @@ def _collect_subgraphs(node: Node[Metadata]) -> Dict[Optional[Subgraph], List[Su
return children


def _generate_asciitree_subgraphs(children: Mapping[Optional[Subgraph], Sequence[Subgraph]], subgraph_ids: MutableMapping[Optional[Subgraph], str], subgraph: Optional[Subgraph], select: str, bridge: str) -> Iterator[str]:
assert subgraph not in subgraph_ids
subgraph_ids[subgraph] = id = chr(ord('A') + len(subgraph_ids))
def _generate_asciitree_subgraphs(children: Mapping[Optional[Subgraph], Sequence[Subgraph]], id_gen_map: MutableMapping[Optional[Subgraph], Iterator[str]], subgraph: Optional[Subgraph], select: str, bridge: str) -> Iterator[str]:
assert subgraph not in id_gen_map
id = chr(ord('A') + len(id_gen_map))
id_gen_map[subgraph] = (f'%{id}{i}' for i in itertools.count())
if subgraph:
yield '{}{} = {}\n'.format(select, id, subgraph.label.replace('\n', '; '))
else:
yield '{}{}\n'.format(select, id)
for i, child in enumerate(children[subgraph], 1-len(children[subgraph])):
yield from _generate_asciitree_subgraphs(children, subgraph_ids, child, bridge+('├ ' if i else '└ '), bridge+('│ ' if i else ' '))
yield from _generate_asciitree_subgraphs(children, id_gen_map, child, bridge+('├ ' if i else '└ '), bridge+('│ ' if i else ' '))


def _generate_graphviz_subgraphs(children: Mapping[Optional[Subgraph], Sequence[Subgraph]], nodes: Mapping[Optional[Subgraph], Sequence[str]], subgraph: Optional[Subgraph], id_gen: Iterator[str]) -> Iterator[str]:
for child in children[subgraph]:
yield 'subgraph cluster{} {{'.format(next(id_gen))
yield 'subgraph cluster{} {{bgcolor="lightgray";color="none";'.format(next(id_gen))
yield from _generate_graphviz_subgraphs(children, nodes, child, id_gen)
yield '}'
yield from nodes.get(subgraph, ())
15 changes: 13 additions & 2 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,18 @@ def __index__(self):
__mod__ = lambda self, other: mod(self, other)
__int__ = __index__
__str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, self._shape_str(form=str))
_shape_str = lambda self, form: '{}:{}'.format(self.dtype.__name__[0] if hasattr(self, 'dtype') else '?', ','.join(str(int(length)) if length.isconstant else '?' for length in self.shape) if hasattr(self, 'shape') else '?')

def _shape_str(self, form):
dtype = self.dtype.__name__[0] if hasattr(self, 'dtype') else '?'
shape = [str(n.__index__()) if n.isconstant else '?' for n in self.shape]
for i in set(range(self.ndim)) - set(self._unaligned[1]):
shape[i] = f'({shape[i]})'
for i, _ in self._inflations:
shape[i] = f'~{shape[i]}'
for axes in self._diagonals:
for i in axes:
shape[i] = f'{shape[i]}/'
return f'{dtype}:{",".join(shape)}'

sum = sum
prod = product
Expand Down Expand Up @@ -4990,7 +5001,7 @@ def einsum(fmt, *args, **dims):
>>> a45 = ones(tuple(map(constant, [4,5]))) # 4x5 matrix
>>> einsum('ij->ji', a45)
nutils.evaluable.Transpose<f:5,4>
nutils.evaluable.Transpose<f:(5),(4)>
Axis labels that do not occur in the return value are summed. For example,
the following performs a matrix-vector product:
Expand Down
61 changes: 31 additions & 30 deletions tests/test_evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,17 +858,18 @@ class asciitree(TestCase):

@unittest.skipIf(sys.version_info < (3, 6), 'test requires dicts maintaining insertion order')
def test_asciitree(self):
f = evaluable.Sin((evaluable.Zeros((), int))**evaluable.Diagonalize(evaluable.Argument('arg', (evaluable.constant(2),))))
n = evaluable.constant(2)
f = evaluable.Sin(evaluable.InsertAxis(evaluable.Inflate(evaluable.constant(1.), evaluable.constant(1), n), n)**evaluable.Diagonalize(evaluable.Argument('arg', (n,))))
self.assertEqual(f.asciitree(richoutput=True),
'%0 = Sin; f:2,2\n'
'└ %1 = Power; f:2,2\n'
' ├ %2 = InsertAxis; f:2,2\n'
' │ ├ %3 = InsertAxis; f:2\n'
' │ │ ├ %4 = IntToFloat; f:\n'
' │ │ │ └ 0\n'
' ├ %2 = InsertAxis; f:~2,(2)\n'
' │ ├ %3 = Inflate; f:~2\n'
' │ │ ├ 1.0\n'
' │ │ ├ 1\n'
' │ │ └ 2\n'
' │ └ 2\n'
' └ %5 = Diagonalize; f:2,2\n'
' └ %4 = Diagonalize; f:2/,2/\n'
' └ Argument; arg; f:2\n')

@unittest.skipIf(sys.version_info < (3, 6), 'test requires dicts maintaining insertion order')
Expand All @@ -894,23 +895,23 @@ def test_loop_concatenate(self):
'└ B = Loop\n'
'NODES\n'
'%B0 = LoopConcatenate\n'
'├ shape[0] = %A1 = Take; i:; [2,2]\n'
'│ ├ %A2 = _SizesToOffsets; i:3; [0,2]\n'
'│ │ └ %A3 = InsertAxis; i:2; [1,1]\n'
'├ shape[0] = %A0 = Take; i:; [2,2]\n'
'│ ├ %A1 = _SizesToOffsets; i:3; [0,2]\n'
'│ │ └ %A2 = InsertAxis; i:(2); [1,1]\n'
'│ │ ├ 1\n'
'│ │ └ 2\n'
'│ └ 2\n'
'├ start = %B4 = Take; i:; [0,2]\n'
'│ ├ %A2\n'
'│ └ %B5 = LoopIndex\n'
'├ start = %B1 = Take; i:; [0,2]\n'
'│ ├ %A1\n'
'│ └ %B2 = LoopIndex\n'
'│ └ length = 2\n'
'├ stop = %B6 = Take; i:; [0,2]\n'
'│ ├ %A2\n'
'│ └ %B7 = Add; i:; [1,2]\n'
'│ ├ %B5\n'
'├ stop = %B3 = Take; i:; [0,2]\n'
'│ ├ %A1\n'
'│ └ %B4 = Add; i:; [1,2]\n'
'│ ├ %B2\n'
'│ └ 1\n'
'└ func = %B8 = InsertAxis; i:1; [0,1]\n'
' ├ %B5\n'
'└ func = %B5 = InsertAxis; i:(1); [0,1]\n'
' ├ %B2\n'
' └ 1\n')

@unittest.skipIf(sys.version_info < (3, 6), 'test requires dicts maintaining insertion order')
Expand All @@ -923,23 +924,23 @@ def test_loop_concatenatecombined(self):
'└ B = Loop\n'
'NODES\n'
'%B0 = LoopConcatenate\n'
'├ shape[0] = %A1 = Take; i:; [2,2]\n'
'│ ├ %A2 = _SizesToOffsets; i:3; [0,2]\n'
'│ │ └ %A3 = InsertAxis; i:2; [1,1]\n'
'├ shape[0] = %A0 = Take; i:; [2,2]\n'
'│ ├ %A1 = _SizesToOffsets; i:3; [0,2]\n'
'│ │ └ %A2 = InsertAxis; i:(2); [1,1]\n'
'│ │ ├ 1\n'
'│ │ └ 2\n'
'│ └ 2\n'
'├ start = %B4 = Take; i:; [0,2]\n'
'│ ├ %A2\n'
'│ └ %B5 = LoopIndex\n'
'├ start = %B1 = Take; i:; [0,2]\n'
'│ ├ %A1\n'
'│ └ %B2 = LoopIndex\n'
'│ └ length = 2\n'
'├ stop = %B6 = Take; i:; [0,2]\n'
'│ ├ %A2\n'
'│ └ %B7 = Add; i:; [1,2]\n'
'│ ├ %B5\n'
'├ stop = %B3 = Take; i:; [0,2]\n'
'│ ├ %A1\n'
'│ └ %B4 = Add; i:; [1,2]\n'
'│ ├ %B2\n'
'│ └ 1\n'
'└ func = %B8 = InsertAxis; i:1; [0,1]\n'
' ├ %B5\n'
'└ func = %B5 = InsertAxis; i:(1); [0,1]\n'
' ├ %B2\n'
' └ 1\n')


Expand Down

0 comments on commit 247a765

Please sign in to comment.