In [None]:
from read_data import read_data_cont, get_numerical_attributes
import numpy as np
from collections import deque
import math
import copy

def predict(tree, data, numel_attr):
    """Predict a label for every row of ``data`` by walking ``tree``.

    Parameters
    ----------
    tree : dict
        Decision-tree node. ``tree['all']`` is the label predicted at this
        node. If the node splits, ``tree['index']`` is the column it splits
        on, ``tree['median']`` is the cut-off for numerical columns, and the
        integer-valued keys map to child nodes. For numerical splits the
        children live under ``1`` (value > median) and ``0`` (value <= median);
        for categorical splits each key is the attribute value itself.
    data : numpy.ndarray
        2-D array, one sample per row, indexed by column as ``data[:, index]``.
    numel_attr : container of int
        Column indices that are treated as numerical (median) splits.

    Returns
    -------
    numpy.ndarray
        1-D float array with one prediction per row of ``data``.
    """
    # Every row starts with this node's own label; children overwrite it below.
    pred = tree['all'] * np.ones(data.shape[0])

    if 'index' not in tree:
        return pred

    split_index = tree['index']
    column = data[:, split_index]

    if split_index in numel_attr:
        branches = (
            (1, np.nonzero(column > tree['median'])[0]),
            (0, np.nonzero(column <= tree['median'])[0]),
        )
        for key, rows in branches:
            if key in tree and len(rows) > 0:
                pred[rows] = predict(tree[key], data[rows, :], numel_attr)
    else:
        for key in tree:
            # String keys ('all', 'index', 'parent', ...) are node metadata.
            # Children may be keyed by NumPy or built-in integers; accept both,
            # just as the numerical branch's `key in tree` lookup does.
            if not isinstance(key, (int, np.integer)):
                continue

            rows = np.nonzero(column == key)[0]
            if len(rows) == 0:
                continue

            pred[rows] = predict(tree[key], data[rows, :], numel_attr)

    # Rows whose value matches no child keep this node's label.
    return pred

In [None]:
def get_accuracy(pred, truth_labels):
    """Return the fraction of entries in ``pred`` that equal ``truth_labels``.

    Both arguments are compared element-wise, so they are expected to be
    NumPy arrays of the same length. The result is ``matches / len(pred)``.
    """
    matching_positions = np.nonzero(pred == truth_labels)[0]
    return matching_positions.size / len(pred)

In [None]:
def get_entropy(vals):
    """Return the Shannon entropy, in bits, of the values in ``vals``.

    The entropy is computed from the relative frequency of each distinct
    value. An empty input yields ``0.0`` because there is nothing to sum.
    """
    total = len(vals)
    _, frequencies = np.unique(vals, return_counts=True)

    entropy = 0.0
    for frequency in frequencies:
        probability = frequency / total
        entropy -= probability * math.log(probability, 2)

    return entropy

In [None]:
def binarize(data, numel_attr):
    """Return an int64 copy of ``data`` with numerical columns thresholded.

    Each column listed in ``numel_attr`` is replaced by 1 where the original
    value is strictly greater than that column's median (computed on ``data``
    as passed in) and 0 elsewhere. All other columns are only cast to int64.
    ``data`` itself is not modified.
    """
    binary_data = np.array(data, dtype=np.int64)
    for column in numel_attr:
        values = data[:, column]
        # The boolean mask is stored as 0/1 in the integer array.
        binary_data[:, column] = values > np.median(values)

    return binary_data

In [None]:
def get_split_index(data, done, numel_attr):
    """Pick the column with the largest positive information gain.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array whose column 0 holds the labels; columns 1..n-1 are the
        candidate attributes.
    done : container of int
        Columns already used for a split. Such a column is skipped unless it
        is also listed in ``numel_attr``.
    numel_attr : container of int
        Columns treated as numerical, which may be split on repeatedly.

    Returns
    -------
    int
        Index of the best column, or ``-1`` when no candidate has a gain
        strictly greater than zero.
    """
    m, n = data.shape
    labels = data[:, 0]
    base_entropy = get_entropy(labels)

    best_index, best_gain = -1, 0.0
    for column_index in range(1, n):
        if column_index in done and column_index not in numel_attr:
            continue

        column = data[:, column_index]

        # Gain = entropy before the split minus the size-weighted entropy
        # of each partition induced by the column's distinct values.
        gain = base_entropy
        for value in np.unique(column):
            rows = np.nonzero(column == value)[0]
            gain -= (len(rows) / m) * get_entropy(labels[rows])

        if gain > best_gain:
            best_gain, best_index = gain, column_index

    return best_index

In [None]:
def build_tree(data, tree, done, threshold, numel_attr):
    """Recursively grow the decision tree rooted at ``tree`` from ``data``.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array of the samples reaching this node; column 0 is the label.
    tree : dict
        Node to populate in place. Receives ``'all'`` (most frequent label),
        and, if the node splits, ``'index'`` (split column), ``'median'``
        (numerical splits only) and one child dict per observed value of the
        binarized split column. Every child carries a ``'parent'`` back-link.
    done : dict
        Maps a column index to how many times it was split on along the path
        from the root to this node.
    threshold : dict
        Maps a numerical column index to the list of medians used for it on
        the path to this node. It is appended to in place when this node
        splits on a numerical column.
    numel_attr : container of int
        Columns treated as numerical.

    Returns
    -------
    dict
        At a node that does not split, ``threshold`` as received. Otherwise,
        for each column key reported by the children, the longest median
        list returned by any child.
    """
    binary_data = binarize(data, numel_attr)
    labels, label_counts = np.unique(binary_data[:, 0], return_counts=True)
    tree['all'] = labels[np.argmax(label_counts)]

    # A pure node needs no further splitting.
    if len(labels) == 1:
        return threshold

    split_index = get_split_index(binary_data, done, numel_attr)
    if split_index == -1:
        return threshold

    tree['index'] = split_index
    split_column = binary_data[:, split_index]

    if split_index in numel_attr:
        # The median is taken on the raw values reaching this node.
        cutoff = np.median(data[:, split_index])
        tree['median'] = cutoff
        threshold.setdefault(split_index, []).append(cutoff)

    max_threshold = {}
    for value in np.unique(split_column):
        rows = np.nonzero(split_column == value)[0]
        if len(rows) == 0:
            continue

        child = {'parent' : tree}
        tree[value] = child

        updated_done = dict(done)
        updated_done[split_index] = updated_done.get(split_index, 0) + 1

        # Each child works on its own copy so sibling paths stay independent.
        thresh = build_tree(data[rows, :], child, updated_done, copy.deepcopy(threshold), numel_attr)
        for key, medians in thresh.items():
            if key not in max_threshold or len(max_threshold[key]) < len(medians):
                max_threshold[key] = medians

    return max_threshold

In [None]:
def get_count(tree, data, truth_labels, count_tree, numel_attr):
    """Predict with ``tree`` and record per-node correctness counts.

    Walks ``tree`` like ``predict`` does, and fills ``count_tree`` in place
    with a structure that mirrors the tree.

    Parameters
    ----------
    tree : dict
        Decision-tree node (``'all'``, optional ``'index'`` / ``'median'``,
        integer-valued keys for children).
    data : numpy.ndarray
        2-D array of the samples routed to this node.
    truth_labels : numpy.ndarray
        True label for each row of ``data``.
    count_tree : dict
        Node of the count tree to populate. This function writes:
        ``'total'`` (rows reaching the node), ``'count'`` (rows the node's own
        label gets right), ``'index'`` (split column, if any), one child dict
        per tree child, ``'node_count'`` (nodes in this subtree) and
        ``'sub_count'`` (rows the whole subtree gets right). Each child dict
        is created with ``'par_count'`` (rows routed to that child which this
        node's label already gets right) and a ``'parent'`` back-link.
    numel_attr : container of int
        Columns treated as numerical.

    Returns
    -------
    numpy.ndarray
        1-D float array with the subtree's prediction for each row.
    """
    pred = tree['all'] * np.ones(data.shape[0])
    count_tree['total'] = data.shape[0]
    count_tree['count'] = len(np.nonzero(pred == truth_labels)[0])

    if 'index' not in tree:
        count_tree['sub_count'] = count_tree['count']
        count_tree['node_count'] = 1
        return pred

    split_index = tree['index']
    count_tree['index'] = split_index
    column = data[:, split_index]

    # Collect (child key, row indices) pairs. Row selection depends only on
    # `data`, so it can be done up front for every branch.
    branches = []
    if split_index in numel_attr:
        if 1 in tree:
            branches.append((1, np.nonzero(column > tree['median'])[0]))
        if 0 in tree:
            branches.append((0, np.nonzero(column <= tree['median'])[0]))
    else:
        for key in tree:
            # String keys are node metadata. Children may be keyed by NumPy
            # or built-in integers; accept both, as the numerical branch and
            # get_accuracy_values already do.
            if isinstance(key, (int, np.integer)):
                branches.append((key, np.nonzero(column == key)[0]))

    node_count = 1
    for key, rows in branches:
        # Branches are recorded even when no rows reach them, so the count
        # tree has the same shape as the decision tree for any data set.
        child_counts = {
            # Must be taken before pred[rows] is overwritten by the child.
            'par_count': len(np.nonzero(pred[rows] == truth_labels[rows])[0]),
            'parent': count_tree,
        }
        count_tree[key] = child_counts
        pred[rows] = get_count(tree[key], data[rows, :], truth_labels[rows], child_counts, numel_attr)
        node_count += child_counts['node_count']

    count_tree['node_count'] = node_count
    count_tree['sub_count'] = len(np.nonzero(pred == truth_labels)[0])
    return pred

In [None]:
def get_accuracy_values(count_tree):
    """Return the accuracy after each node is added, in breadth-first order.

    Starting from the root of ``count_tree``, nodes are visited level by
    level. Visiting a node adds ``node['count'] - node['par_count']`` to the
    running number of correct rows, and the running total divided by the
    root's ``'total'`` is recorded. Children are the values stored under
    integer keys of nodes that have an ``'index'`` entry.

    Returns a 1-D NumPy array with one accuracy per visited node.
    """
    total_rows = count_tree['total']
    pending = deque([count_tree])
    correct = 0
    accuracy_values = []

    while pending:
        current = pending.popleft()
        correct += current['count'] - current['par_count']
        accuracy_values.append(correct / total_rows)

        if 'index' not in current:
            continue

        pending.extend(
            child for key, child in current.items()
            if isinstance(key, (np.int64, int))
        )

    return np.array(accuracy_values)

In [None]:
def store_accuracy_data(tree, count_tree):
    train_acc = get_accuracy_values(count_tree[0])
    valid_acc = get_accuracy_values(count_tree[1])
    test_acc = get_accuracy_values(count_tree[2])
    
    with open('acc2c.txt', 'w') as file:
        for i in range(len(train_acc)):
            file.write(str(i + 1) + ', ' + str(train_acc[i]) + ', ' + str(valid_acc[i]) + ', ' + str(test_acc[i]) + '\n')

    return

In [None]:
# Read the train / validation / test CSV files with the project helper
# read_data_cont. The rest of the notebook indexes each result as a 2-D
# array whose column 0 holds the label.
train = read_data_cont("dtree_data/train.csv")
valid = read_data_cont("dtree_data/valid.csv")
test = read_data_cont("dtree_data/test.csv")

# Used below only in `in` membership tests against column indices;
# presumably the indices of the numerical columns - confirm in read_data.
numel_attr = get_numerical_attributes()

In [None]:
# Grow the tree on the training data. `dtree` is filled in place; the root
# has no parent, no column has been split on yet, and no medians are recorded.
dtree = {'parent': None}
done = {}
# build_tree returns, per numerical split column, the longest list of
# medians reported by any branch.
threshold = build_tree(train, dtree, done, {}, numel_attr)

In [None]:
# Show the medians collected by build_tree for each numerical split column.
print(threshold)

In [None]:
# One count-tree root per data set (train, validation, test). The root has no
# parent, so nothing is pre-counted for it.
count_tree = [{'par_count': 0, 'parent': None} for _ in range(3)]

# Fill each count tree and keep the tree's predictions for each data set.
train_labels, valid_labels, test_labels = (
    get_count(dtree, samples, samples[:, 0], counts, numel_attr)
    for samples, counts in zip((train, valid, test), count_tree)
)

In [None]:
# Total number of nodes in the tree, as counted by get_count on the training set.
print(count_tree[0]['node_count'])

In [None]:
# Write the per-node train / validation / test accuracy curves to acc2c.txt.
store_accuracy_data(dtree, count_tree)

In [None]:
# Print the accuracy of the tree on the train, validation and test sets,
# in that order. Column 0 of each data set holds the true labels.
for split_data in (train, valid, test):
    labels = predict(dtree, split_data, numel_attr)
    print(get_accuracy(labels, split_data[:, 0]))