In [1]:
# https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/

In [72]:
import numpy as np
import time
import math
from sklearn.model_selection import train_test_split

In [9]:
math.inf

inf

In [250]:
def gini_index(groups, labels):
    gini = 0.0
    instances_len = float(sum([len(group) for group in groups]))
    print(instances_len)
    for group in groups:
        group_size = float(len(group))
        if group_size == 0:
            continue
        score = 0.0
        for class_label in labels:
            p = [row[-1] for row in group].count(class_label) / group_size
            score += p * p 
        gini += (1.0 - score) * (group_size / instances_len)
        
    return gini

In [251]:
assert gini_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1]) == 0.5

4.0


In [252]:
def test_split(attribute_index, value, dataset):
    left, right = list(), list()
#     print('test_split')
    for row in dataset:
        if row[attribute_index] < value:
            left.append(row)
        else: 
            right.append(row)
    return left, right

In [253]:
def get_split(dataset):
    
    class_values = list(set([row[-1] for row in dataset]))
    attribute_index, value, score, groups = 999,999,999, None
    for index in range(len(dataset[0]) - 1):
        start_time = time.time()
        for row in dataset:
            split_groups = test_split(index, row[index], dataset)
            gini = gini_index(split_groups, class_values)
            if gini < score:
                attribute_index, value, score, groups = index, row[index], gini, split_groups
        print(index, "--- get_split index %s seconds ---" % (time.time() - start_time))
    return {'attribute_index': attribute_index, 'value': value, 'groups': groups}

In [254]:
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

In [255]:
def split(node, max_depth, min_size, depth):
	left, right = node['groups']
	del(node['groups'])
	if not left or not right:
		node['left'] = node['right'] = to_terminal(left + right)
		return
	if depth >= max_depth:
		node['left'], node['right'] = to_terminal(left), to_terminal(right)
		return
	if len(left) <= min_size:
		node['left'] = to_terminal(left)
	else:
		node['left'] = get_split(left)
		split(node['left'], max_depth, min_size, depth+1)
	if len(right) <= min_size:
		node['right'] = to_terminal(right)
	else:
		node['right'] = get_split(right)
		split(node['right'], max_depth, min_size, depth+1)

In [256]:
# Build a decision tree
def build_tree(train, max_depth, min_size):
	root = get_split(train)
	split(root, max_depth, min_size, 1)
	return root

In [257]:
def predict(node, row):
	if row[node['attribute_index']] < node['value']:
		if isinstance(node['left'], dict):
			return predict(node['left'], row)
		else:
			return node['left']
	else:
		if isinstance(node['right'], dict):
			return predict(node['right'], row)
		else:
			return node['right']

In [258]:
def decision_tree(train, test, max_depth, min_size):
    tree = build_tree(train, max_depth, min_size)
    predictions = []
    for row in test:
        prediction = predict(tree, row)
        predictions.append(prediction)
    return predictions

In [259]:
bank_data = np.genfromtxt('data_banknote_authentication.txt', delimiter=',')

In [260]:
bank_data[-1]

array([-2.5419 , -0.65804,  2.6842 ,  1.1952 ,  1.     ])

In [261]:
a = get_split(bank_data)

1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0
1372.0

In [249]:
# pred = decision_tree(bank_data, bank_data, 5, 10)

In [65]:
len(pred) == len(bank_data)

True

In [66]:
correct = 0
for i in range(len(pred)):
    if pred[i] == bank_data[i][-1]:
        correct+=1
    
correct / len(bank_data)

0.9839650145772595