In [None]:
from collections import Counter
from heapq import heappush, heappop
from typing import Dict

from bitarray import bitarray


class Node:
    """A node of a Huffman tree.

    Leaves carry a ``letter``; internal nodes (``letter is None``) carry two
    children. Nodes are ordered by ``weight`` so they can be kept in a heap.
    """

    def __init__(self, weight: int, letter: str = None, left: 'Node' = None, right: 'Node' = None) -> None:
        super().__init__()
        self.weight = weight  # letter frequency, or combined weight of a merged subtree
        self.letter = letter  # None for internal nodes
        self.left = left      # child reached by appending '0' to the code
        self.right = right    # child reached by appending '1' to the code

    def create_dict(self, dictionary: Dict, code: str = '') -> None:
        """Record the code of every leaf below this node in *dictionary*.

        :param dictionary: mapping filled in place with ``letter -> bitarray``.
        :param code: the path from the root to this node as a '0'/'1' string.
        """
        if self.letter is not None:
            # A leaf reached with an empty code is the root of a single-symbol
            # tree. Give it the one-bit code '0'; a zero-length code would make
            # every occurrence of the symbol vanish from the encoded output.
            dictionary[self.letter] = bitarray(code or '0')
        if self.left is not None:
            self.left.create_dict(dictionary, code + '0')
        if self.right is not None:
            self.right.create_dict(dictionary, code + '1')

    def __lt__(self, other):
        # Ordering is by weight only; this is all heapq needs.
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.weight < other.weight

    def __le__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.weight <= other.weight

    def __repr__(self):
        return f"{self.weight} : {self.letter if self.letter is not None else '_'}"


class HuffmanTree:
    """Huffman code tree built from a ``letter -> frequency`` mapping.

    After construction ``root`` is the tree's root node and ``dictionary``
    maps every letter to its code, as filled in by ``Node.create_dict``.
    An empty mapping raises ``IndexError`` (there is no node to become root).
    """

    def __init__(self, letters: Dict[str, int]) -> None:
        super().__init__()
        queue = []
        for symbol, freq in letters.items():
            heappush(queue, Node(freq, symbol))
        self.root = self._merge_lightest(queue)
        self.dictionary = {}
        self.root.create_dict(self.dictionary)

    @staticmethod
    def _merge_lightest(queue):
        """Merge the two lightest subtrees of heap *queue* until one remains; return it."""
        while len(queue) > 1:
            lightest = heappop(queue)
            runner_up = heappop(queue)
            merged_weight = lightest.weight + runner_up.weight
            heappush(queue, Node(merged_weight, left=lightest, right=runner_up))
        return queue[0]

def encode(text: str) -> bitarray:
    """Huffman-encode *text*.

    The result is the bitarray returned by ``encoding_table(tree)`` (defined
    elsewhere), extended with the code of every character of *text* in order.
    Empty *text* is not supported: ``HuffmanTree`` has no symbol to build a
    root from and raises ``IndexError``.
    """
    # Counter is a dict subclass and keeps first-occurrence order, exactly
    # like the hand-built frequency dict it replaces.
    tree = HuffmanTree(Counter(text))
    bits = encoding_table(tree)
    codes = tree.dictionary
    for letter in text:
        bits += codes[letter]
    return bits


def bitarray_to_int(bits: bitarray) -> int:
    """Interpret *bits* as an unsigned integer, most significant bit first.

    An empty sequence yields 0.
    """
    value = 0
    for b in bits:
        value <<= 1
        value |= b
    return value


def decode(bits: bitarray) -> str:
    """Decode a bitarray produced by ``encode`` back into text.

    Layout read here (the writer, ``encoding_table``, is defined elsewhere):

    * 8 bits  -- unsigned size, in bits, of the code table that follows;
    * table   -- repeated entries: 8 bits letter (one byte, decoded as UTF-8),
                 8 bits code length ``n``, then the ``n`` code bits;
    * payload -- the concatenated codes of the text.

    Codes are matched as bit strings, so codes differing only in leading
    zeros (e.g. ``'1'`` and ``'01'``) stay distinct.

    :raises ValueError: if the payload ends with bits that do not form a
        complete code (for example byte padding added by ``tobytes``).
    """
    # Read the header unsigned, like the per-entry code lengths below; a signed
    # read turned table sizes of 128..255 bits into negative numbers.
    table_length = bitarray_to_int(bits[:8])
    pos = 8
    table_end = pos + table_length
    table = {}
    while pos < table_end:
        letter = bits[pos:pos + 8].tobytes().decode()
        pos += 8
        length = bitarray_to_int(bits[pos:pos + 8])
        pos += 8
        # Key on the bit string, not its integer value: the integer drops
        # leading zeros and would merge distinct codes.
        table[bits[pos:pos + length].to01()] = letter
        pos += length

    # Single pass over the payload: grow the current code one bit at a time.
    # Huffman codes are prefix-free, so the first match is the right one.
    text = []
    current = ''
    for bit in bits[pos:].to01():
        current += bit
        letter = table.get(current)
        if letter is not None:
            text.append(letter)
            current = ''
    if current:
        raise ValueError(f"trailing bits {current!r} do not form a complete code")
    return ''.join(text)