In [1]:
import tensorflow as tf
from tensorflow import keras
import kerastuner
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import os
import json
import datetime
import dill
import tqdm
from queue import PriorityQueue

plt.rcParams["figure.figsize"] = (20, 5)

physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)

In [2]:
def create_random_tree(output_dim):
    outputs = list(range(output_dim))
    #shuffle(outputs)

    while len(outputs) > 2:
        temp_outputs = []
        for i in range(0, len(outputs), 2):
            if len(outputs) - (i+1) > 0:
                temp_outputs.append([outputs[i], outputs[i+1]])
            else:
                temp_outputs.append(outputs[i])
        outputs = temp_outputs
    return outputs

In [3]:
create_random_tree(10)

[[[[0, 1], [2, 3]], [[4, 5], [6, 7]]], [8, 9]]

In [4]:
from random import shuffle
from copy import copy

class TreeTools:
    def __init__(self):
        #memoization for _count_nodes functions
        self._count_nodes_dict = {}
                
    def _get_subtrees(self, tree):
        yield tree
        for subtree in tree:
            if type(subtree) == list:
                for x in self._get_subtrees(subtree):
                    yield x

    # Returns pairs of paths and leaves of a tree
    def _get_leaves_paths(self, tree):
        for i, subtree in enumerate(tree):
            if type(subtree) == list:
                for path, value in self._get_leaves_paths(subtree):
                    yield [i] + path, value
            else:
                yield [i], subtree
    
    # Returns the number of nodes in a tree (not including root)
    def _count_nodes(self, tree):
        if id(tree) in self._count_nodes_dict:
            return self._count_nodes_dict[id(tree)]
        size = 0
        for node in tree:
            if type(node) == list:
                size += 1 + self._count_nodes(node)
        self._count_nodes_dict[id(self._count_nodes_dict)] = size
        return size


    # Returns all the nodes in a path
    def _get_nodes(self, tree, path):
        next_node = 0
        nodes = []
        for decision in path:
            nodes.append(next_node)
            next_node += 1 + self._count_nodes(tree[:decision])
            tree = tree[decision]
        return nodes


# turns a list to a binary tree
def random_binary_full_tree(outputs):
    outputs = copy(outputs)
    shuffle(outputs)

    while len(outputs) > 2:
        temp_outputs = []
        for i in range(0, len(outputs), 2):
            if len(outputs) - (i+1) > 0:
                temp_outputs.append([outputs[i], outputs[i+1]])
            else:
                temp_outputs.append(outputs[i])
        outputs = temp_outputs
    return outputs

In [5]:
tree = random_binary_full_tree(list(range(10)))
print('Our tree:',tree)

tree_tools = TreeTools()    

print('All subtrees:')
for subtree in tree_tools._get_subtrees(tree):
    print('\t {} (Len : {})'.format(subtree, len(subtree)))

print('All paths and leaves:')
for subtree in tree_tools._get_leaves_paths(tree):
    print('\t',subtree)
    
print('Number of nodes in the tree:',tree_tools._count_nodes(tree))

print('all nodes in path [0, 0, 0, 0]:')
for nodes in tree_tools._get_nodes(tree, [0, 0, 0, 0]):
    print('\t',nodes)

print('all nodes in path [1, 0]:')
for nodes in tree_tools._get_nodes(tree, [1, 0]):
    print('\t',nodes)

Our tree: [[[[9, 3], [4, 1]], [[8, 0], [6, 2]]], [5, 7]]
All subtrees:
	 [[[[9, 3], [4, 1]], [[8, 0], [6, 2]]], [5, 7]] (Len : 2)
	 [[[9, 3], [4, 1]], [[8, 0], [6, 2]]] (Len : 2)
	 [[9, 3], [4, 1]] (Len : 2)
	 [9, 3] (Len : 2)
	 [4, 1] (Len : 2)
	 [[8, 0], [6, 2]] (Len : 2)
	 [8, 0] (Len : 2)
	 [6, 2] (Len : 2)
	 [5, 7] (Len : 2)
All paths and leaves:
	 ([0, 0, 0, 0], 9)
	 ([0, 0, 0, 1], 3)
	 ([0, 0, 1, 0], 4)
	 ([0, 0, 1, 1], 1)
	 ([0, 1, 0, 0], 8)
	 ([0, 1, 0, 1], 0)
	 ([0, 1, 1, 0], 6)
	 ([0, 1, 1, 1], 2)
	 ([1, 0], 5)
	 ([1, 1], 7)
Number of nodes in the tree: 8
all nodes in path [0, 0, 0, 0]:
	 0
	 1
	 2
	 3
all nodes in path [1, 0]:
	 0
	 8


In [6]:
class hier_softmax:
    def __init__(self, tree, contex_size, model):
        self._tree_tools = TreeTools()
        self.str2weight = {}
        #create a weight matrix and bias vector for each node in the tree
        for i, subtree in enumerate(self._tree_tools._get_subtrees(tree)):
            self.str2weight["softmax_node_"+str(i)+"_w"] = model.add_parameters((len(subtree), contex_size))
            self.str2weight["softmax_node_" + str(i) + "_b"] = model.add_parameters(len(subtree))
        
        #create a dictionary from each value to its path
        value_to_path_and_nodes_dict = {}
        for path, value in self._tree_tools._get_leaves_paths(tree):
            nodes = self._tree_tools._get_nodes(tree, path)
            value_to_path_and_nodes_dict[data.char2int[value]] = path, nodes
        self.value_to_path_and_nodes_dict = value_to_path_and_nodes_dict
        self.model = model
        self.tree = tree
    
    #get the loss on a given value (for training)
    def get_loss(self, context, value):
        loss = []
        path, nodes = self.value_to_path_and_nodes_dict[value]
        for p, n in zip(path, nodes):
            w = dy.parameter(self.str2weight["softmax_node_"+str(n)+"_w"])
            b = dy.parameter(self.str2weight["softmax_node_" + str(n) + "_b"])
            probs = tf.nn.softmax(w*context+b)
            #loss.append(-tf.math.log(dy.pick(probs, p)))
            print(probs)
            print(p)
        #return dy.esum(loss)

    #get the most likely
    def generate(self, context):
        best_value = None
        best_loss = float(100000)
        for value in self.value_to_path_and_nodes_dict:
            loss = self.get_loss(context, value)
            if loss < best_loss:
                best_loss = loss
                best_value = value
        return best_value

In [7]:
a = [1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6]
#freq = pd.Series(a).value_counts(ascending=True).to_dict()
freq = {
    "A": 6,
    "B": 4,
    "C": 4,
    "D": 3,
    "E": 2,
    "F": 1
}
b = list(freq.keys())

In [8]:
b, freq

(['A', 'B', 'C', 'D', 'E', 'F'],
 {'A': 6, 'B': 4, 'C': 4, 'D': 3, 'E': 2, 'F': 1})

Huffman Encoding 

# Main function implementing huffman coding
def huffman_code_tree(node, path, left=True):
    if type(node) is not Node:
        return {node: path}
    (l, r) = node.children()
    d = dict()

    l_path = copy(path)
    l_path.append(0)
    r_path = copy(path)
    r_path.append(1)

    d.update(huffman_code_tree(l, l_path, True))
    d.update(huffman_code_tree(r, r_path, False))
    return d

nodes = list(freq.items())

while len(nodes) > 1:
    (key1, c1) = nodes[-1]
    (key2, c2) = nodes[-2]
    nodes = nodes[:-2]
    node = Node(key1, key2)
    nodes.append((node, c1 + c2))

    nodes = sorted(nodes, key=lambda x: x[1], reverse=True)

huffmanCode = huffman_code_tree(nodes[0][0], [])
huffmanCode

Huffman Binary Tree as Node (not as list)

class Node(object):
    node_id = 0
    def __init__(self, symbol, freq, left=None, right=None):
        self.left = left
        self.right = right
        self.symbol = symbol
        self.freq = freq

        if self.symbol == None:
            self.node_id = Node.node_id
            Node.node_id += 1

    def children(self):
        return (self.left, self.right)

    def __lt__(self, target):
        if type(target) == Node:
            return self.freq < target.freq
        else:
            return self.freq < target
    
    def __gt__(self, target):
        return not self.__lt__(target)

nodes = list(freq.items())
q = PriorityQueue()

for node in nodes:
    q.put(Node(node[0], node[1]))

while q.qsize() > 1:
    node_1 = q.get()
    node_2 = q.get()
    print(node_1.symbol, node_1.freq)
    print(node_2.symbol, node_2.freq)
    node = Node(None, node_1.freq + node_2.freq, node_1, node_2)
    q.put(node)

def traverse(tree):
    l, r = tree.children()
    if l == None or r == None:
        print(tree.symbol)
        return
    #print(tree.freq)    
    
    traverse(l)
    traverse(r)

traverse(tree)

In [9]:
class Node(object):
    def __init__(self, symbol, freq, symbol_2=None):
        if symbol_2:
            self.symbol = [symbol, symbol_2]
        else:
            self.symbol = symbol
        self.freq = freq

    def __lt__(self, target):
        if type(target) == Node:
            return self.freq < target.freq
        else:
            return self.freq < target
    
    def __gt__(self, target):
        return not self.__lt__(target)

In [10]:
q = PriorityQueue()
nodes = list(freq.items())

for node in nodes:
    q.put((node[1], [node[0]]))

In [11]:
nodes = list(freq.items())
q = PriorityQueue()

for node in nodes:
    q.put(Node(node[0], node[1]))

while q.qsize() > 1:
    node_1 = q.get()
    node_2 = q.get()
    print(node_1.symbol, node_1.freq)
    print(node_2.symbol, node_2.freq)
    node = Node(node_1.symbol, node_1.freq + node_2.freq, node_2.symbol)
    q.put(node)

F 1
E 2
D 3
['F', 'E'] 3
C 4
B 4
['D', ['F', 'E']] 6
A 6
['C', 'B'] 8
[['D', ['F', 'E']], 'A'] 12


In [13]:
tree = q.get().symbol
tree

[['C', 'B'], [['D', ['F', 'E']], 'A']]

In [37]:
tree_tools = TreeTools()    

print('All subtrees:')
for subtree in tree_tools._get_subtrees(tree):
    print('\t {} (Len : {})'.format(subtree, len(subtree)))

print('All paths and leaves:')
for subtree in tree_tools._get_leaves_paths(tree):
    print('\t',subtree)
    
print('Number of nodes in the tree:',tree_tools._count_nodes(tree))

print('all nodes in path [0, 0, 0, 0]:')
for nodes in tree_tools._get_nodes(tree, [1, 1, 0, 0]):
    print('\t',nodes)

print('all nodes in path [1, 0]:')
for nodes in tree_tools._get_nodes(tree, [1, 0]):
    print('\t',nodes)

All subtrees:
	 [['C', 'B'], ['A', [['F', 'E'], 'D']]] (Len : 2)
	 ['C', 'B'] (Len : 2)
	 ['A', [['F', 'E'], 'D']] (Len : 2)
	 [['F', 'E'], 'D'] (Len : 2)
	 ['F', 'E'] (Len : 2)
All paths and leaves:
	 ([0, 0], 'C')
	 ([0, 1], 'B')
	 ([1, 0], 'A')
	 ([1, 1, 0, 0], 'F')
	 ([1, 1, 0, 1], 'E')
	 ([1, 1, 1], 'D')
Number of nodes in the tree: 4
all nodes in path [0, 0, 0, 0]:
	 0
	 2
	 3
	 4
all nodes in path [1, 0]:
	 0
	 2
