In [1]:
import numpy as np
import pandas as pd
import json
import tqdm
from queue import PriorityQueue

In [2]:
#from copy import copy

class TreeTools:
    def __init__(self):
        pass
                
    def _get_subtrees(self, tree):
        yield tree
        for subtree in tree:
            if type(subtree) == list:
                for x in self._get_subtrees(subtree):
                    yield x

    # Returns pairs of paths and leaves of a tree
    def _get_leaves_paths(self, tree):
        for i, subtree in enumerate(tree):
            if type(subtree) == list:
                for path, value in self._get_leaves_paths(subtree):
                    yield [i] + path, value
            else:
                yield [i], subtree
    
    # Returns the number of nodes in a tree (not including root)
    def _count_nodes(self, tree):
        size = 0
        for node in tree:
            if type(node) == list:
                size += 1 + self._count_nodes(node)
        return size


    # Returns all the nodes in a path
    def _get_nodes(self, tree, path):
        next_node = 0
        nodes = []
        for decision in path:
            nodes.append(next_node)
            next_node += 1 + self._count_nodes(tree[:decision])
            tree = tree[decision]
        return nodes

    def _value_to_path_nodes_dict(self, tree):
        value_to_path_nodes_dict = {}
        for path, value in tqdm.tqdm(self._get_leaves_paths(tree)):
            nodes = self._get_nodes(tree, path)
            #value_to_path_nodes_dict[value] = path, nodes
            value_to_path_nodes_dict[int(value)] = path, nodes
        return value_to_path_nodes_dict

In [3]:
class Node(object):
    '''
    Wrapping List with Node class Since pythn PriorityQueue cannot recognize nested items in its input tuple.
    '''
    def __init__(self, symbol, freq, symbol_2=None):
        if symbol_2:
            self.symbol = [symbol, symbol_2]
        else:
            self.symbol = symbol
        self.freq = freq

    def __lt__(self, target):
        if type(target) == Node:
            return self.freq < target.freq
        else:
            return self.freq < target
    
    def __gt__(self, target):
        return not self.__lt__(target)

In [4]:
def create_huffman_tree(input_dict):
    '''
    input : dictionary with {category:freq} pairs
    NOTE : input dict must be ASCENDING order
    '''
    q = PriorityQueue()
    nodes = list(input_dict.items())

    for node in nodes:
        q.put(Node(node[0], node[1]))

    while q.qsize() > 1:
        node_1 = q.get()
        node_2 = q.get()
        #print(node_1.symbol, node_1.freq)
        #print(node_2.symbol, node_2.freq)
        node = Node(node_1.symbol, node_1.freq + node_2.freq, node_2.symbol)
        q.put(node)
    return q.get().symbol

In [5]:
train_set = np.genfromtxt("data/{}_train_set.csv".format("SEG_Wavenet"), delimiter="\n", dtype=np.int64)
val_set = np.genfromtxt("data/{}_val_set.csv".format("SEG_Wavenet"), delimiter="\n", dtype=np.int64)

In [6]:
dataset = np.r_[train_set, val_set]
dataset

array([  0,   0,   0, ..., 897, 242, 961], dtype=int64)

In [7]:
dataset_freq = pd.Series(dataset).value_counts(ascending=True).to_dict()

In [8]:
pd.Series(dataset).value_counts()

0        45798
1        26424
2         6295
3         1849
4         1848
         ...  
15519        3
13470        3
15986        3
14042        3
15212        3
Length: 16293, dtype: int64

In [9]:
## Due to unknown Bug, change category type from int to str, and revert to int when saving
## Bug : Category 0 is removed during tree construction. (Priority is not the cause)

dataset_freq = {str(i):j for i, j in dataset_freq.items()}

In [10]:
tree = create_huffman_tree(dataset_freq)

In [11]:
tree_tools = TreeTools()    
'''
print('All subtrees:')      # Num of All subtrees : 55276
for subtree in tree_tools._get_subtrees(tree):
    print('\t {} (Len : {})'.format(subtree, len(subtree)))

print('All paths and leaves:')
for subtree in tree_tools._get_leaves_paths(tree):
    print('\t',subtree)
'''
num_nodes = tree_tools._count_nodes(tree)
print('Number of nodes in the tree:', num_nodes) # except root

print('all nodes in path [0, 0, 0, 0]:')
for nodes in tree_tools._get_nodes(tree, [1, 1, 0, 0]):
    print('\t',nodes)

print('all nodes in path [1, 0]:')
for nodes in tree_tools._get_nodes(tree, [1, 0]):
    print('\t',nodes)

Number of nodes in the tree: 16291
all nodes in path [0, 0, 0, 0]:
	 0
	 5436
	 8746
	 8747
all nodes in path [1, 0]:
	 0
	 5436


In [12]:
# mapping_dict structure : 
# category : [path, nodes]
tree_mapping = tree_tools._value_to_path_nodes_dict(tree)

16293it [00:44, 362.54it/s]


In [13]:
dict(list(tree_mapping.items())[:5])

{2192: ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]),
 2396: ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]),
 2647: ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14]),
 1579: ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14]),
 2247: ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16])}

In [14]:
len(tree_mapping)

16293

In [15]:
tree_mapping[0]

([0, 1], [0, 1])

## NOTE 

Number of all intermediate nodes = {num of leaves (=num of categories)} - 1  
intermediate nodes index = \[0 : {num of leaves (=num of categories)} - 1] (0 for root)

## NOTICE

JSON and Python Dictionary are different:
    JSON saves key values as string, whereas Python Dictionary can save key values as int.  
Thus, when load tree_mapping.json, need to convert its key to int.

import json

with open("outputs/tree_mapping.json", "w") as j:
    json.dump(tree_mapping, j)