In [32]:
from bitarray import bitarray
from src.core.tree import HoffmanTree
from src.lossless.encoders.hoffman_encoder import HuffmanEncoder
from src.lossless.decoders.base_decoder import Decoder, Symbol
from src.core.dist import Dist
from typing import Tuple

class HuffmanDecodingException(Exception):
    
    def __init__(self):
        super().__init__("Unable to decode sequence of bits.")

class HuffmanDecoder(Decoder):

    def __init__(self, tree:HoffmanTree):
        self.tree = tree

    def __call__(self, bits:bitarray, start:int=0) -> Tuple[Symbol, int]:
        node = self.tree.root
        if node is None:
            raise ValueError("Cannot decode with empty Huffman Tree")

        pos = start
        while not node.is_leaf_node():
            bit = bits[pos]
            node = node.left if bit == 0 else node.right
            pos += 1
            if node is None:
                raise HuffmanDecodingException()
        return node.symbol, pos

In [18]:
from typing import Iterable, Dict
from src.core.node import Node
from src.core.dist import Dist
from bitarray import bitarray

class HoffmanTree:

    def __init__(self, dist:Dist):
        self.root = self.create_tree(dist)
        self._encoding_table = {}

    @staticmethod
    def create_tree(dist:Iterable) -> Node:
        if not dist:
            return
        list_of_nodes = [Node(symbol=symbol, count=count) for symbol,count in dist.items()]
        print(sorted(list_of_nodes, reverse=True))
        while len(list_of_nodes) > 1:
            list_of_nodes = sorted(list_of_nodes, reverse=True)
            print(list_of_nodes)
            n1 = list_of_nodes.pop()
            n2 = list_of_nodes.pop()
            merged_node = Node(symbol=".", count=n1.count + n2.count)
            merged_node.left = n1
            merged_node.right = n2
            list_of_nodes.append(merged_node)
            print(list_of_nodes)
        return list_of_nodes.pop()

    def get_encoding_table(self) -> Dict[str, bitarray]:
        encoding_table = {}

        def _dfs(tree:Node, code:str) -> None:
            if tree.is_leaf_node():
                # Must handle edge case where tree is only root node
                code = code if code else "0"
                encoding_table[tree.symbol] = bitarray(code)
                return
            if tree.left is not None:
                _dfs(tree.left, code + "0")
            if tree.right is not None:
                _dfs(tree.right, code + "1")

        if self.root is not None:
            _dfs(self.root, code="")
        return encoding_table

    def max_depth(self) -> int:
        global max_depth
        max_depth = 0
        def _dfs(tree:Node, level:int) -> None:
            if tree.is_leaf_node():
                global max_depth
                max_depth = max(level, max_depth)
                return 
            if tree.left is not None:
                _dfs(tree.left, level + 1)
            if tree.right is not None:
                _dfs(tree.right, level + 1)
        if self.root is not None:
            _dfs(self.root, 0)
        return max_depth

In [27]:
tree = HoffmanTree(Dist("AAAAAAABBC"))

[Node(count=0.7, symbol='A'), Node(count=0.2, symbol='B'), Node(count=0.1, symbol='C')]
[Node(count=0.7, symbol='A'), Node(count=0.2, symbol='B'), Node(count=0.1, symbol='C')]
[Node(count=0.7, symbol='A'), Node(count=0.30000000000000004, symbol='.')]
[Node(count=0.7, symbol='A'), Node(count=0.30000000000000004, symbol='.')]
[Node(count=1.0, symbol='.')]


In [36]:
decoder = HuffmanDecoder(tree)
encoder = HuffmanEncoder(Dist("AAAAAAABBC"))

encoder("A")
encoder("C")

bitarray('00')

In [30]:
tree.get_encoding_table()

{'C': bitarray('00'), 'B': bitarray('01'), 'A': bitarray('1')}