Skip to content

Commit

Permalink
Merge pull request #45 from Miguellissimo/master
Browse files Browse the repository at this point in the history
Export tree to the dot format
  • Loading branch information
Xiaming committed Jul 18, 2015
2 parents c7d4804 + a7deb98 commit 90d61a7
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 4 deletions.
83 changes: 83 additions & 0 deletions tests/test_plugins.py
@@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from treelib import Tree
from treelib.plugins import *
import os
import unittest

class DotExportCase(unittest.TestCase):
"""Test class for the export to dot format function"""

def setUp(self):
tree = Tree()
tree.create_node("Hárry", "hárry")
tree.create_node("Jane", "jane", parent="hárry")
tree.create_node("Bill", "bill", parent="hárry")
tree.create_node("Diane", "diane", parent="jane")
tree.create_node("George", "george", parent="bill")
self.tree = tree

def read_generated_output(self, filename):
output = codecs.open(filename, 'r', 'utf-8')
generated = output.read()
output.close()

return generated

def test_export_to_dot(self):
export_to_dot(self.tree, 'tree.dot')
expected = """\
digraph tree {
\thárry [label="Hárry", shape=circle]
\tbill [label="Bill", shape=circle]
\tjane [label="Jane", shape=circle]
\tgeorge [label="George", shape=circle]
\tdiane [label="Diane", shape=circle]
\thárry -> jane
\thárry -> bill
\tbill -> george
\tjane -> diane
}"""

self.assertTrue(os.path.isfile('tree.dot'), "The file tree.dot could not be found.")
generated = self.read_generated_output('tree.dot')

self.assertEqual(generated, expected, "Generated dot tree is not the expected one")
os.remove('tree.dot')

def test_export_to_dot_empty_tree(self):
empty_tree = Tree()
export_to_dot(empty_tree, 'tree.dot')

expected = """\
digraph tree {
}"""
self.assertTrue(os.path.isfile('tree.dot'), "The file tree.dot could not be found.")
generated = self.read_generated_output('tree.dot')

self.assertEqual(expected, generated, 'The generated output for an empty tree is not empty')
os.remove('tree.dot')

def test_unicode_filename(self):
tree = Tree()
tree.create_node('Node 1', 'node_1')
export_to_dot(tree, 'ŕʩϢ.dot')

expected = """\
digraph tree {
\tnode_1 [label="Node 1", shape=circle]
}"""
self.assertTrue(os.path.isfile('ŕʩϢ.dot'), "The file ŕʩϢ.dot could not be found.")
generated = self.read_generated_output('ŕʩϢ.dot')
self.assertEqual(expected, generated, "The generated file content is not the expected one")
os.remove('ŕʩϢ.dot')

def tearDown(self):
self.tree = None

if __name__ == "__main__":
unittest.main()
6 changes: 3 additions & 3 deletions tests/test_treelib.py
Expand Up @@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import sys
import os
import codecs
try:
from StringIO import StringIO as BytesIO
except ImportError:
Expand Down Expand Up @@ -275,8 +277,7 @@ def test_show(self):
def tearDown(self):
self.tree = None
self.copytree = None



def suite():
suites = [NodeCase, TreeCase]
suite = unittest.TestSuite()
Expand All @@ -285,6 +286,5 @@ def suite():
return suite



if __name__ == '__main__':
unittest.main(defaultTest='suite')
35 changes: 34 additions & 1 deletion treelib/plugins.py
@@ -1,5 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""This is a public location to maintain contributed
utlities to extend the basic Tree class.
utilities to extend the basic Tree class.
"""
from __future__ import unicode_literals
import codecs

def export_to_dot(tree, filename, shape='circle', graph='digraph'):
"""Exports the tree in the dot format of the graphviz software"""

nodes, connections = [], []
if tree.nodes:

for n in tree.expand_tree(mode=tree.WIDTH):
nid = tree[n].identifier
state = nid + ' [label="' + tree[n].tag + '", shape=' + shape + ']'
nodes.append(state)

for c in tree.children(nid):
cid = c.identifier

connections.append(nid + ' -> ' + cid)

# write nodes and connections to dot format
with codecs.open(filename, 'w', 'utf-8') as f:
f.write(graph +' tree {\n')
for n in nodes:
f.write('\t' + n + '\n')

f.write('\n')
for c in connections:
f.write('\t' + c + '\n')

f.write('}')

if __name__ == '__main__':
pass

0 comments on commit 90d61a7

Please sign in to comment.