Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
216 lines (181 sloc) 6.2 KB
"""
This library contains useful functionality for working with Huffman trees
and codes.
Note:
There is a function for directly creating a Huffman code from a frequency
map: bitarray.util.huffman_code()
"""
from __future__ import print_function
from heapq import heappush, heappop
from bitarray import bitarray
class Node(object):
    """
    A node of a binary tree as used by the functions in this module.

    Attributes:
      child   list of the two children (index 0 and 1), None where absent
      symbol  the symbol stored in a leaf; None marks a parent node
      freq    the frequency of the node, or None when it is not known
    """
    def __init__(self):
        # a new node is empty; the code creating it fills in the attributes
        self.child = [None] * 2
        self.symbol = self.freq = None

    def __lt__(self, other):
        # heapq orders the nodes it holds using "<", so nodes are compared
        # by their frequency
        return self.freq < other.freq
def huff_tree(freq):
    """
    Build a Huffman tree from a dictionary mapping symbols to their
    frequency, and return the root node of the tree.
    """
    queue = []

    # start with one leaf per symbol, visiting the symbols in sorted order
    for symbol in sorted(freq):
        leaf = Node()
        leaf.symbol, leaf.freq = symbol, freq[symbol]
        heappush(queue, leaf)

    # keep merging until a single node is left on the queue
    while len(queue) > 1:
        # the two nodes with the smallest frequencies come off first
        first = heappop(queue)
        second = heappop(queue)

        # their new parent carries the combined frequency and goes back
        # onto the queue
        merged = Node()
        merged.child = [first, second]
        merged.freq = first.freq + second.freq
        heappush(queue, merged)

    # what remains is the root of the Huffman tree
    root = queue[0]
    return root
def huff_code(tree):
    """
    Given a Huffman tree, traverse the tree and return the Huffman code, i.e.
    a dictionary mapping symbols to bitarrays.

    Symbols are added to the result in depth-first order, child 0 before
    child 1.  Missing children (None) are skipped, so trees which are not
    full binary trees are handled as well.  When the root itself is a leaf,
    its symbol is mapped to an empty bitarray.
    """
    result = {}

    # Every call starts from its own empty prefix.  A shared default
    # argument object would end up in the result when the root is a leaf,
    # and a caller modifying it would then corrupt all later calls.
    #
    # An explicit stack is used instead of recursion, such that the height
    # of the tree is not limited by the interpreter's recursion limit.
    stack = [(tree, bitarray())]
    while stack:
        nd, prefix = stack.pop()
        if nd.symbol is not None:  # leaf
            result[nd.symbol] = prefix
            continue
        # parent: child 1 is pushed first, so that child 0 is popped (and
        # therefore visited) first
        for k in (1, 0):
            child = nd.child[k]
            if child is not None:
                stack.append((child, prefix + bitarray([k])))

    return result
def insert_symbol(tree, ba, sym):
    """
    Store the symbol in the tree at the position given by the bitarray.
    Each bit selects a child (0 or 1) on the way down from the root, and
    nodes missing along this path are created.

    ValueError is raised when the symbol is None, when the path runs through
    a node which already holds a symbol, or when the node at the end of the
    path already holds a symbol or has children.
    """
    if sym is None:
        raise ValueError("symbol cannot be None")

    current = tree
    for bit in ba:
        parent = current
        current = parent.child[bit]
        if current:
            # an existing symbol on the path would be a prefix of the new code
            if current.symbol is not None:
                raise ValueError("ambiguity")
        else:
            current = Node()
            parent.child[bit] = current

    # the final node has to be unused: no symbol and no children
    occupied = current.symbol is not None
    if occupied or current.child[0] or current.child[1]:
        raise ValueError("ambiguity")

    current.symbol = sym
def make_tree(codedict):
    """
    Build a tree from a code dictionary (mapping symbols to bitarrays) and
    return the root node.  In contrast to the trees huff_tree creates, the
    .freq attribute of every node is None.
    """
    root = Node()
    for symbol, code in codedict.items():
        # each symbol is placed at the position its code describes
        insert_symbol(root, code, symbol)
    return root
def traverse(tree, it):
    """
    Traverse tree until a leaf node is reached, and return its symbol.
    This function consumes an iterator on which next() is called during each
    step of traversing.

    Raises StopIteration when the iterator is already exhausted before the
    first step, i.e. when there is nothing left to decode.

    Raises ValueError when the bits lead to a child which does not exist,
    or when the iterator is exhausted after leaving the root but before a
    leaf has been reached.
    """
    nd = tree
    while True:
        try:
            bit = next(it)
        except StopIteration:
            # Running out of bits at the root is the regular end of the
            # data, which is passed on to the caller.  Running out of bits
            # anywhere else means the last code word is incomplete.
            if nd is not tree:
                raise ValueError("decoding not terminated")
            raise
        nd = nd.child[bit]
        if not nd:
            raise ValueError("prefix code does not match data in bitarray")
        if nd.symbol is not None:
            return nd.symbol
def iterdecode(tree, bitsequence):
    """
    Decode the bitsequence using the tree, generating one symbol at a time.
    A single iterator over the bitsequence is shared by all calls to
    traverse, and the generator finishes as soon as StopIteration is raised.
    """
    bits = iter(bitsequence)
    try:
        while True:
            yield traverse(tree, bits)
    except StopIteration:
        return
def write_dot(tree, fn, binary=False):
    """
    Write a graphviz '.dot' file named fn with a visual representation of
    the tree.  The nodes of the tree may or may not have frequencies.

    Leaf symbols are shown as hexadecimal numbers when binary is true;
    otherwise chr() of the symbol is shown, with a few characters replaced
    by names or escaped forms.
    """
    replacements = {' ': 'SPACE', '\n': 'LF', '\r': 'CR', '\t': 'TAB',
                    '\\': r'\\', '"': r'\"'}

    def symbol_text(sym):
        # text displayed for the symbol of a leaf
        if binary:
            return '0x%02x' % sym
        char = chr(sym)
        text = replacements.get(char, char)
        assert text.strip(), repr(char)
        return text

    def freq_text(f):
        # text displayed for a frequency; empty when there is none
        return '' if f is None else '%d' % f

    with open(fn, 'w') as fo:  # dot -Tpng tree.dot -O
        fo.write('digraph BT {\n')
        fo.write(' node [shape=box, fontsize=20, fontname="Arial"];\n')

        # Nodes are written in depth-first order: the node itself, the edges
        # to its children, then the subtree of child 0 followed by the
        # subtree of child 1.
        pending = [tree]
        while pending:
            nd = pending.pop()
            if nd.symbol is None:  # parent node
                fo.write(' %d [shape=circle, style=filled, '
                         'fillcolor=grey, label="%s"];\n' %
                         (id(nd), freq_text(nd.freq)))
            else:  # leaf node
                f_part = freq_text(nd.freq)
                s_part = symbol_text(nd.symbol)
                sep = ': ' if f_part and s_part else ''
                fo.write(' %d [label="%s%s%s"];\n' %
                         (id(nd), f_part, sep, s_part))

            children = [nd.child[k] for k in range(2) if nd.child[k]]
            for child in children:
                fo.write(' %d->%d;\n' % (id(nd), id(child)))
            # reversed, so that child 0 is taken off the stack first
            pending.extend(reversed(children))

        fo.write('}\n')
def print_code(freq, codedict):
    """
    Print a table with one row per symbol in codedict, showing the symbol,
    its character, its hexadecimal value, its frequency (taken from the
    frequency map freq) and its Huffman code.

    The rows are ordered by descending frequency; symbols having the same
    frequency are ordered by descending symbol.
    """
    named = {0: 'NUL', 9: 'TAB', 10: 'LF', 13: 'CR', 127: 'DEL'}

    def disp_char(i):
        # repr of the character for values 32..126, otherwise the name
        # from the table above (or nothing when there is none)
        return repr(chr(i)) if 32 <= i < 127 else named.get(i, '')

    print(' symbol char hex frequency Huffman code')
    print(70 * '-')
    ordered = sorted(codedict, key=lambda c: (freq[c], c), reverse=True)
    for sym in ordered:
        row = (sym, disp_char(sym), sym, freq[sym], codedict[sym].to01())
        print('%7r %-4s 0x%02x %10i %s' % row)
def test():
    """
    Self-check: the code lengths of a Huffman code created from a small
    frequency map, and an encode / decode round trip using a fixed code.
    """
    # the most frequent symbol gets the shortest code
    code = huff_code(huff_tree({'a': 10, 'b': 2, 'c': 1}))
    assert len(code['a']) == 1
    assert len(code['b']) == len(code['c']) == 2

    # encode a short text with a fixed prefix code, then decode it again
    # using the tree made from the same code
    code = {'a': bitarray('0'), 'b': bitarray('10'), 'c': bitarray('11')}
    tree = make_tree(code)
    encoded = bitarray()
    encoded.encode(code, 'abca')
    assert encoded == bitarray('010110')
    assert list(iterdecode(tree, encoded)) == ['a', 'b', 'c', 'a']
if __name__ == '__main__':
    # run the self-check when this file is executed as a script
    test()
You can’t perform that action at this time.