In [55]:
from read_data import read_data
import numpy as np
from collections import deque
import math

def predict(tree, data):
    """Predict a label for every row of ``data`` by walking a tree built by ``build_tree``.

    Parameters
    ----------
    tree : dict
        Tree node. ``tree['all']`` is the node's majority label. Internal nodes
        also carry ``tree['index']`` (the column they split on) and one child
        dict per observed value of that column, keyed by the value.
    data : np.ndarray
        2-D array laid out like the training data (column 0 is the label in
        this notebook; it is not read here).

    Returns
    -------
    np.ndarray
        Float array of length ``data.shape[0]``. Rows whose split value has no
        matching child keep this node's majority label.
    """
    pred = tree['all'] * np.ones(data.shape[0])

    if 'index' not in tree:
        return pred

    split_index = tree['index']
    for key, child in tree.items():
        # Branch keys are the numpy integer values produced by np.unique in
        # build_tree; string keys ('all', 'index', 'parent') are metadata.
        # Accept any numpy integer width (e.g. np.int32 for int32 data), not
        # only np.int64, otherwise the tree is silently never descended.
        if not isinstance(key, np.integer):
            continue

        ind = np.nonzero(data[:, split_index] == key)[0]
        if len(ind) == 0:
            continue

        pred[ind] = predict(child, data[ind, :])

    return pred

In [56]:
def get_accuracy(pred, truth_labels):
    """Return the fraction of positions where ``pred`` equals ``truth_labels``.

    Raises ZeroDivisionError when ``pred`` is empty.
    """
    matches = np.count_nonzero(pred == truth_labels)
    return matches / len(pred)

In [57]:
def get_entropy(vals):
    """Base-2 Shannon entropy of the empirical distribution of ``vals``.

    Each distinct value contributes ``-p * log2(p)``, where ``p`` is its
    relative frequency. An empty input yields 0.0.
    """
    total = len(vals)
    _, freqs = np.unique(vals, return_counts=True)

    result = 0.0
    for freq in freqs:
        p = freq / total
        result -= p * math.log(p, 2)
    return result

In [58]:
def get_split_index(data, done):
    """Pick the feature column with the highest information gain.

    Parameters
    ----------
    data : np.ndarray
        2-D array; column 0 holds the class label, columns 1.. are features.
    done : set
        Feature columns that must not be considered (already used above).

    Returns
    -------
    int
        Index of the column with the largest strictly positive information
        gain (the lowest such index on ties), or -1 if no column has gain > 0.
    """
    num_rows, num_cols = data.shape
    labels = data[:, 0]
    base_entropy = get_entropy(labels)

    best_gain, best_col = 0.0, -1
    for col in range(1, num_cols):
        if col in done:
            continue

        column = data[:, col]
        # Gain = H(labels) - sum over values v of P(v) * H(labels | column == v).
        gain = base_entropy
        for value in np.unique(column):
            rows = np.nonzero(column == value)[0]
            gain -= (len(rows) / num_rows) * get_entropy(labels[rows])

        if gain > best_gain:
            best_gain, best_col = gain, col

    return best_col

In [59]:
def build_tree(data, tree, done):
    """Recursively grow an information-gain decision tree into ``tree`` (in place).

    Parameters
    ----------
    data : np.ndarray
        2-D array; column 0 holds the class label, columns 1.. are features.
    tree : dict
        Node to populate. The caller creates it with a ``'parent'`` entry.
        This function adds:

        - ``'all'``: the majority label (smallest label on ties).
        - For internal nodes only, ``'index'``: the split column.
        - One child dict per distinct value of that column, keyed by the
          value. Each child holds a back-reference ``'parent'`` to ``tree``,
          which ``prune_tree`` uses to detach nodes.
    done : set
        Feature columns already used on the path from the root. It is not
        mutated; each child receives its own extended copy.

    Returns
    -------
    None
    """
    labels, label_counts = np.unique(data[:, 0], return_counts=True)
    tree['all'] = labels[np.argmax(label_counts)]

    # Stop when the node is pure or every feature column has been used.
    is_pure = len(labels) == 1
    features_exhausted = len(done) == data.shape[1] - 1
    if is_pure or features_exhausted:
        return

    split_index = get_split_index(data, done)
    if split_index == -1:
        # No remaining column gives a positive information gain.
        return

    column = data[:, split_index]
    tree['index'] = split_index

    for value in np.unique(column):
        rows = np.nonzero(column == value)[0]
        if len(rows) == 0:
            continue

        tree[value] = {'parent': tree}
        child_done = done.copy()
        child_done.add(split_index)
        build_tree(data[rows, :], tree[value], child_done)

    return

In [100]:
def get_count(tree, data, truth_labels, count_tree):
    """Mirror ``tree`` into ``count_tree`` with per-node correctness counts for one data split.

    Each count node receives:

    - ``'total'``: number of rows of ``data`` reaching the node.
    - ``'count'``: rows predicted correctly if this node predicted its own
      majority label ``tree['all']`` for all of them.
    - ``'sub_count'``: rows predicted correctly by the whole subtree rooted here.
    - ``'node_count'``: number of tree nodes in that subtree, including itself.
    - For internal nodes only, ``'index'``: the split column.
    - For internal nodes only, one child dict per branch key. Each child holds
      ``'par_count'`` (rows of that branch predicted correctly by this node's
      majority label) and ``'parent'`` (a back-reference to this count node).

    The root's own ``'par_count'`` / ``'parent'`` are not set here. The caller
    provides them, because pruning and accuracy-curve code read them.

    Parameters
    ----------
    tree : dict
        Tree node as built by ``build_tree``.
    data : np.ndarray
        2-D array of rows reaching this node.
    truth_labels : np.ndarray
        True labels for ``data``'s rows.
    count_tree : dict
        Count node to fill in (mutated in place).

    Returns
    -------
    np.ndarray
        Float predictions of the subtree for ``data``'s rows.
    """
    pred = tree['all'] * np.ones(data.shape[0])
    count_tree['total'] = data.shape[0]
    count_tree['count'] = np.count_nonzero(pred == truth_labels)

    if 'index' not in tree:
        count_tree['sub_count'] = count_tree['count']
        count_tree['node_count'] = 1
        return pred

    split_index = tree['index']
    count_tree['index'] = split_index
    node_count = 1
    for key, child in tree.items():
        # Accept any numpy integer width for branch keys (not only np.int64);
        # string keys are node metadata.
        if not isinstance(key, np.integer):
            continue

        ind = np.nonzero(data[:, split_index] == key)[0]

        # Unlike predict, empty branches are NOT skipped: every branch needs a
        # count node so get_best_prune_node can index it for all three splits.
        child_counts = {
            'par_count': np.count_nonzero(pred[ind] == truth_labels[ind]),
            'parent': count_tree,
        }
        count_tree[key] = child_counts
        pred[ind] = get_count(child, data[ind, :], truth_labels[ind], child_counts)
        node_count += child_counts['node_count']

    count_tree['node_count'] = node_count
    count_tree['sub_count'] = np.count_nonzero(pred == truth_labels)
    return pred

In [93]:
def get_accuracy_values(count_tree):
    """Accuracy curve of the tree as nodes are added one at a time in breadth-first order.

    Parameters
    ----------
    count_tree : dict
        Root count node as filled by ``get_count``. Its ``'par_count'`` must be
        set by the caller (0 for the root).

    Returns
    -------
    np.ndarray
        Element ``k`` is the accuracy (correct / ``count_tree['total']``) of
        the tree consisting of the first ``k + 1`` nodes in BFS order.
    """
    total = count_tree['total']
    accuracy_values = []
    queue = deque([count_tree])
    correct = 0
    while queue:
        node = queue.popleft()
        # BFS adds a node after its parent and before its own children. So
        # adding it switches its rows from the parent's majority label
        # (par_count correct) to its own majority label (count correct).
        correct += node['count'] - node['par_count']
        accuracy_values.append(correct / total)

        if 'index' in node:
            # Branch keys may be any numpy integer width, not only np.int64.
            queue.extend(child for key, child in node.items() if isinstance(key, np.integer))

    return np.array(accuracy_values)

In [116]:
def store_accuracy_data(tree, count_tree):
    """Write train/valid/test accuracy-vs-node-count curves to ``acc.txt``.

    Each line is ``k, train_acc, valid_acc, test_acc`` for the tree made of
    the first ``k`` nodes in BFS order (see ``get_accuracy_values``).

    Parameters
    ----------
    tree : dict
        Unused; kept for interface compatibility.
    count_tree : list of dict
        Root count nodes for the train (0), valid (1) and test (2) splits.
    """
    train_acc, valid_acc, test_acc = [get_accuracy_values(count_tree[x]) for x in range(3)]

    with open('acc.txt', 'w') as out:
        for i in range(len(train_acc)):
            fields = (i + 1, train_acc[i], valid_acc[i], test_acc[i])
            out.write(', '.join(map(str, fields)) + '\n')

    return

In [139]:
def get_best_prune_node(tree, count_tree, key):
    """Find the node in this subtree whose removal most improves validation accuracy.

    "Removing" a node means detaching it from its parent, so its rows fall back
    to the parent's majority label. The score of a node is its validation
    ``sub_count - par_count``. A negative score means removal gains correct
    validation predictions.

    Parameters
    ----------
    tree : dict
        Tree node (as built by ``build_tree``).
    count_tree : list of dict
        Matching count nodes for the train (0), valid (1) and test (2) splits.
    key : object
        Key of ``tree`` in its parent (the caller passes -1 for the root).

    Returns
    -------
    tuple
        ``(score, key, node, count_nodes)`` for the minimum-score node. On
        ties, descendants win over the node itself and later siblings win over
        earlier ones.
    """
    # Score of removing this node itself, measured on the validation split.
    best = (count_tree[1]['sub_count'] - count_tree[1]['par_count'], key, tree, count_tree)
    if 'index' not in tree:
        return best

    for child_key, child in tree.items():
        # Branch keys may be any numpy integer width, not only np.int64.
        if not isinstance(child_key, np.integer):
            continue

        candidate = get_best_prune_node(
            child, [count_tree[x][child_key] for x in range(3)], child_key)
        if candidate[0] <= best[0]:
            best = candidate

    return best

In [207]:
def _prune_progress_row(count_tree):
    """Format one progress row: node count, then train/valid/test accuracy of the current tree."""
    accuracies = [str(count_tree[x]['sub_count'] / count_tree[x]['total']) for x in range(3)]
    return ', '.join([str(count_tree[1]['node_count'])] + accuracies) + '\n'


def prune_tree(tree, count_tree, path='prune_acc.txt'):
    """Greedily prune ``tree`` in place while validation accuracy strictly improves.

    Each step detaches the node chosen by ``get_best_prune_node``. Its rows
    then fall back to the parent's majority label. The step then updates
    ``node_count`` / ``sub_count`` of every ancestor in all three count trees.

    Parameters
    ----------
    tree : dict
        Root of the tree built by ``build_tree``; mutated in place.
    count_tree : list of dict
        Root count nodes for the train (0), valid (1) and test (2) splits as
        filled by ``get_count``; mutated in place to stay in sync with ``tree``.
    path : str, optional
        Output file (overwritten). One row is written before pruning and one
        after each removal: ``node_count, train_acc, valid_acc, test_acc``.
    """
    # `with` guarantees the log is closed even if pruning raises part-way.
    with open(path, 'w') as out:
        out.write(_prune_progress_row(count_tree))

        while True:
            incr, key, node, ct_node = get_best_prune_node(tree, count_tree, -1)
            if incr >= 0:
                # No single removal increases correct validation predictions.
                break

            # Detach the chosen subtree from its parent in the real tree.
            node = node['parent']
            del node[key]

            change, node_count = [0] * 3, [0] * 3
            for i in range(3):
                removed = ct_node[i]
                parent = removed['parent']
                node_count[i] = removed['node_count']
                # Rows of the removed subtree are now scored by the parent's
                # majority label (par_count) instead of the subtree (sub_count).
                change[i] = removed['par_count'] - removed['sub_count']
                parent['node_count'] -= node_count[i]
                parent['sub_count'] += change[i]
                ct_node[i] = parent
                del parent[key]

            # Propagate the same deltas to every remaining ancestor up to the root.
            while node['parent'] is not None:
                node = node['parent']
                for i in range(3):
                    ancestor = ct_node[i]['parent']
                    ancestor['node_count'] -= node_count[i]
                    ancestor['sub_count'] += change[i]
                    ct_node[i] = ancestor

            out.write(_prune_progress_row(count_tree))

    return

In [192]:
# Load the train / validation / test splits. Later cells treat column 0 of
# each array as the class label.
train, valid, test = (
    read_data(path)
    for path in ("dtree_data/train.csv", "dtree_data/valid.csv", "dtree_data/test.csv")
)

In [205]:
# Grow the full, unpruned tree on the training split. The root has no parent,
# and no feature column has been used yet.
dtree = dict(parent=None)
no_features_used = set()
build_tree(train, dtree, no_features_used)

In [206]:
# One count tree per split (0 = train, 1 = valid, 2 = test). Each root gets
# par_count 0 because get_best_prune_node and get_accuracy_values read it.
count_tree = [{'par_count': 0, 'parent': None} for _ in range(3)]

train_labels, valid_labels, test_labels = (
    get_count(dtree, split, split[:, 0], counts)
    for split, counts in zip((train, valid, test), count_tree)
)

In [195]:
# Save the unpruned accuracy curves (acc.txt). This must run before the prune cell, which mutates dtree and count_tree in place.
store_accuracy_data(tree=dtree, count_tree=count_tree)

In [208]:
# Prunes dtree and count_tree in place and overwrites prune_acc.txt. Re-running continues from the already-pruned tree.
prune_tree(tree=dtree, count_tree=count_tree)

In [209]:
# Accuracy of the (pruned) tree on train, validation and test, in that order.
for split in (train, valid, test):
    labels = predict(dtree, split)
    print(get_accuracy(labels, split[:, 0]))

0.8641111111111112
0.8536666666666667
0.817
