In [None]:
import os
import heapq

class SymbolProb():
    """Frequency table of the ' '-separated symbols found in text files."""

    def __init__(self):
        # symbol -> number of occurrences seen so far
        self.map = {}
        # Optionally, a count for the key 0 can be added here to reserve an
        # escape code for unknown symbols:  self.map[0] = 10

    def __len__(self):
        return len(self.map)

    def count(self, symbol):
        """Add one occurrence of ``symbol`` to the table."""
        self.map[symbol] = self.map.get(symbol, 0) + 1

    def dist(self):
        """Return the live symbol -> count dict (not a copy)."""
        return self.map

    def parseFile(self, filepath):
        """Count every token of ``filepath``, skipping lines that start with '<'.

        Tokens come from ``line.split(' ')``. Lines keep their trailing
        newline, so the last token of a line includes it (e.g. 'word\\n').
        """
        with open(filepath, 'r') as handle:
            lines = handle.readlines()
        for line in lines:
            if line.startswith('<'):
                continue
            for token in line.split(' '):
                self.count(token)

class HuffmanCoding:
    """Huffman codebook built from a {symbol: frequency} dict, with encode/decode.

    Codes are Python strings of '0'/'1' characters. If the frequency dict
    passed to make_codes() contains the key ``0``, that key's code acts as an
    escape marker. A symbol missing from the codebook is then encoded as:

    * the escape code,
    * a 32-bit character count,
    * one 32-bit code point per character.

    Every 32-bit field is written least-significant bit first.
    """

    # chr() accepts code points in range(0x110000); larger values are dropped
    # on decode. (The former limit 0x1100 silently discarded e.g. all CJK text.)
    _CODE_POINT_LIMIT = 0x110000

    def __init__(self):
        self.heap = []             # heapq priority queue of HeapNode
        self.codes = {}            # symbol -> bit string
        self.reverse_mapping = {}  # bit string -> symbol

    # Using priority queue. (heapq library)
    class HeapNode:
        """Huffman tree node ordered by frequency; leaves carry a symbol."""

        def __init__(self, sym, freq, l=None, r=None):
            self.sym = sym
            self.freq = freq
            self.left = l
            self.right = r

        def __lt__(self, other):
            # heapq only needs '<'; equal frequencies compare False both ways.
            return self.freq < other.freq

        def __eq__(self, other):
            # type(self) instead of the bare name ``HeapNode``: a nested class
            # is not in scope inside its own methods (that raised NameError).
            if not isinstance(other, type(self)):
                return False
            return self.freq == other.freq

        def is_leaf(self):
            """Return True if the node has no children (i.e. holds a symbol)."""
            return self.left is None and self.right is None

    # internal helper functions
    def _make_heap(self, dist):
        """Push one leaf node per (symbol, frequency) pair onto the heap."""
        for key, val in dist.items():
            heapq.heappush(self.heap, self.HeapNode(key, val))

    def _merge_nodes(self):
        """Merge the two least-frequent nodes until a single root remains."""
        while len(self.heap) > 1:
            node1 = heapq.heappop(self.heap)
            node2 = heapq.heappop(self.heap)
            merged = self.HeapNode(None, node1.freq + node2.freq, node1, node2)
            heapq.heappush(self.heap, merged)

    def _make_codes_helper(self, root, current_code):
        """Record the path to every leaf ('0' = left, '1' = right) as its code."""
        if root is None:
            return
        if root.is_leaf():
            self.codes[root.sym] = current_code
            self.reverse_mapping[current_code] = root.sym
            return
        self._make_codes_helper(root.left, current_code + '0')
        self._make_codes_helper(root.right, current_code + '1')

    def _int_to_bit_stream(self, n):
        """Return the low 32 bits of integer ``n`` as a '0'/'1' string, LSB first.

        Negative values yield their 32-bit two's-complement bits.

        Raises:
            TypeError: if ``n`` is not an int.
        """
        if not isinstance(n, int):
            raise TypeError('_int_to_bit_stream(): input is not integer')
        return format(n & 0xFFFFFFFF, '032b')[::-1]

    def _bit_stream_to_int(self, bit_stream):
        """Inverse of _int_to_bit_stream: read 32 LSB-first bits as an int.

        Any character other than '0' counts as a 1 bit.

        Raises:
            TypeError: if ``bit_stream`` is not a str.
            IndexError: if ``bit_stream`` is shorter than 32 characters.
        """
        if not isinstance(bit_stream, str):
            raise TypeError('_bit_stream_to_int(): input is not bitstream')
        return sum(1 << i for i in range(32) if bit_stream[i] != '0')

    def _decode_raw_word(self, bit_stream, start):
        """Decode an escaped word whose 32-bit length field begins at ``start``.

        Returns ``(word, consumed_bits)``. Returns ``(None, 0)`` if the stream
        is too short to hold the length field plus every announced character.
        Code points outside the Unicode range are dropped from the word.
        """
        header_end = start + 32
        if header_end > len(bit_stream):
            return None, 0
        n_chars = self._bit_stream_to_int(bit_stream[start:header_end])
        end = header_end + 32 * n_chars
        if end > len(bit_stream):
            return None, 0
        chars = []
        for pos in range(header_end, end, 32):
            code_point = self._bit_stream_to_int(bit_stream[pos:pos + 32])
            if code_point < self._CODE_POINT_LIMIT:
                chars.append(chr(code_point))
        return ''.join(chars), end - start

    def _decode_pieces(self, bit_stream):
        """Decode ``bit_stream`` into a list of symbols / escaped words.

        Huffman codes are prefix-free, so bits are accumulated until they match
        a code. A truncated escape sequence emits nothing; decoding then resumes
        right after the escape code. Trailing bits that never match a code are
        ignored.
        """
        pieces = []
        current_code = ''
        i = 0
        n_bits = len(bit_stream)
        while i < n_bits:
            current_code += bit_stream[i]
            if current_code in self.reverse_mapping:
                sym = self.reverse_mapping[current_code]
                current_code = ''
                if sym != 0:
                    pieces.append(sym)
                else:
                    word, consumed = self._decode_raw_word(bit_stream, i + 1)
                    if word is not None:
                        pieces.append(word)
                        i += consumed  # skip the length field and characters
            i += 1
        return pieces

    # use functions
    def make_codes(self, dist):
        """Build ``codes`` / ``reverse_mapping`` from a {symbol: frequency} dict.

        Any previous codebook is discarded first, so rebuilding never mixes
        stale codes into the table. An empty ``dist`` yields an empty codebook.
        A single-symbol ``dist`` gets the code '0', because the bare tree would
        assign it the empty code, which cannot be decoded.
        """
        self.heap.clear()
        self.codes.clear()
        self.reverse_mapping.clear()
        self._make_heap(dist)
        if not self.heap:
            return
        self._merge_nodes()
        root = heapq.heappop(self.heap)
        self._make_codes_helper(root, '0' if root.is_leaf() else '')

    def get_encoded_stream(self, symlist):
        """Concatenate the codes of every symbol in ``symlist`` into one bit string.

        Symbols missing from the codebook are written with the escape scheme
        described in the class docstring.

        Raises:
            KeyError: if a symbol is unknown and no escape code (key 0) exists.
        """
        parts = []
        for sym in symlist:
            if sym in self.codes:
                parts.append(self.codes[sym])
                continue
            if 0 not in self.codes:
                raise KeyError(
                    f'symbol {sym!r} is not in the codebook and no escape '
                    f'code (key 0) was defined')
            parts.append(self.codes[0])
            parts.append(self._int_to_bit_stream(len(sym)))
            parts.extend(self._int_to_bit_stream(ord(c)) for c in sym)
        return ''.join(parts)

    def get_decoded_symbols(self, bit_stream):
        """Decode ``bit_stream`` and concatenate the symbols with no separator."""
        return ''.join(self._decode_pieces(bit_stream))

    def get_decoded_txt_symbols(self, bit_stream):
        """Decode ``bit_stream`` and join the symbols with single spaces."""
        return ' '.join(self._decode_pieces(bit_stream))



In [None]:
## Use-case example

# target folder
target = './txt/en'

# os.listdir() returns entries in arbitrary, platform-dependent order; sort so
# that filelist[0] (used below and in later cells) names the same file every run.
filelist = [os.path.join(target, filename) for filename in sorted(os.listdir(target))]

# make a blank distribution as SP. (for training)
SP = SymbolProb()
print(SP.dist())

# parse a file, and save it to SP.
SP.parseFile(filelist[0])
print(SP.dist())

In [None]:
# Build a Huffman codebook from the symbol counts gathered in SP.
h = HuffmanCoding()
h.make_codes(SP.dist())


# Helper for the demo: all lines of a text file, trailing newlines kept.
def parseTXT(filepath):
    with open(filepath, 'r') as handle:
        return list(handle)


# Demo input: the third line of the first file, split on single spaces
# (the same tokenisation SymbolProb.parseFile uses, so the symbols match).
strlist = parseTXT(filelist[0])
org_str = strlist[2]
symlist = org_str.split(' ')

print(f'original string: {org_str}')
print(f'symbols: {symlist}')

# Round trip: encode to a '0'/'1' string, then decode back to spaced text.
bit_stream = h.get_encoded_stream(symlist)
print(f'encoded bit stream (length {len(bit_stream)}): {bit_stream}')

decoded_txt = h.get_decoded_txt_symbols(bit_stream)
print(f'decoded txt string: {decoded_txt}')






In [None]:
# Dump the finished codebook (symbol -> '0'/'1' code string).
print(str(h.codes))