In [84]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import math
from collections import Counter
import numpy as np

iris = load_iris()

x = iris.data
y = iris.target

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=123)
print(type(x_test[0])) #numpy.ndarray

<class 'numpy.ndarray'>


In [85]:
def entropy_func(class_count, num_samples):
    return (-1) * sum([count/num_samples * math.log(count/num_samples) for count in class_count.values()])


class Group:
    def __init__(self, group_classes):
        self.group_classes = group_classes
        self.entropy = self.group_entropy()

    def __len__(self):
        return self.group_classes.size

    def group_entropy(self):
        # Type Counter creates dict with a count of every occurence of an object
        return entropy_func(Counter(self.group_classes), len(self.group_classes))


class Node:
    def __init__(self, split_feature, split_val, depth=None, child_node_a=None, child_node_b=None, val=None):
        self.split_feature = split_feature
        self.split_val = split_val
        self.depth = depth
        self.child_node_a = child_node_a
        self.child_node_b = child_node_b
        self.val = val

    def predict(self, data):
        if self.val is not None:
            return self.val
        elif data[self.split_feature] <= self.split_val:
            return self.child_node_a.predict(data)
        else:
            return self.child_node_b.predict(data)


class DecisionTreeClassifier(object):
    def __init__(self, max_depth):
        self.depth = 0
        self.max_depth = max_depth
        self.tree = None

    @staticmethod
    def get_split_entropy(group_a, group_b):
        num_samples = len(group_a) + len(group_b)
        return group_a.entropy * (len(group_a) / num_samples) + group_b.entropy * (len(group_b) / num_samples)

    def get_information_gain(self, parent_group, child_group_a, child_group_b):
        return parent_group.entropy - self.get_split_entropy(child_group_a, child_group_b)

    def get_best_feature_split(self, feature_values, classes):
        parent_group = Group(classes)
        best_gain = 0
        best_split_value = None
        best_feature_index = None
        # print(feature_values[0])
        for feature_index in range(len(feature_values[0])):

            chosen_feature_values = feature_values[:, feature_index]
            for chosen_value in chosen_feature_values:
                group_a, group_b = self.separate_by_value(chosen_feature_values, chosen_value, feature_index)
            group_a_classes, group_b_classes = self.separate_groups_by_index(classes, group_a, group_b)

            inf_gain = self.get_information_gain(parent_group, Group(np.array(group_a_classes)), Group(np.array(group_b_classes)))
            if inf_gain > best_gain:
                best_split_value = chosen_value
                best_feature_index = feature_index
                best_gain = inf_gain
        return best_split_value, best_feature_index

    def get_best_split(self, data, classes):
        if self.depth >= self.max_depth or len(np.unique(classes)) == 1:
            return Node(split_feature=None, split_val=None, val=Counter(classes).most_common(1)[0][0])
        print(f'{data}\n\n')

        best_split_value, best_feature_index = self.get_best_feature_split(data, classes)
        data_a, data_b = self.separate_by_value(data, best_split_value, best_feature_index)
        print(data_a)
        print(data_b)
        classes_a, classes_b = self.separate_by_value(classes, best_split_value, -1)

        child_group_a = self.get_best_split(data_a, classes_a)
        child_group_b = self.get_best_split(data_b, classes_b)

        self.depth += 1
        return Node(best_feature_index, best_split_value, self.depth, child_group_a, child_group_b)


    def build_tree(self, data, classes, depth=0):
        self.tree = self.get_best_split(data, classes)

    def predict(self, data):
        return self.tree.predict(data)

    @staticmethod
    def separate_groups_by_index(parent_classes, group_a_indexes, group_b_indexes):
        group_a_classes = []
        group_b_classes = []
        for id in range(len(parent_classes)):
            if id in group_a_indexes:
                group_a_classes.append(parent_classes[id])
            if id in group_b_indexes:
                group_b_classes.append(parent_classes[id])
        return np.asarray(group_a_classes), np.asarray(group_b_classes)

    @staticmethod
    def separate_by_value(parent_values: np.ndarray, chosen_value, feature_index):
        group_a = []
        group_b = []
        for index, value in np.ndenumerate(parent_values):
            if feature_index == -1:
                if value >= chosen_value:
                    group_a.append(parent_values[index[0]])
                else:
                    group_b.append(parent_values[index[0]])
            else:
                if value >= chosen_value:
                    group_a.append(parent_values[index[0], feature_index])
                else:
                    group_b.append(parent_values[index[0], feature_index])
            return np.asarray(group_a), np.asarray(group_b)

In [86]:
dc = DecisionTreeClassifier(3)
dc.build_tree(x_train, y_train)
for sample, gt in zip(x_test, y_test):
    prediction = dc.predict(sample)

[[6.5 3.  5.8 2.2]
 [5.5 3.5 1.3 0.2]
 [4.3 3.  1.1 0.1]
 [6.1 2.9 4.7 1.4]
 [4.8 3.  1.4 0.3]
 [5.2 3.4 1.4 0.2]
 [6.3 2.8 5.1 1.5]
 [4.8 3.4 1.9 0.2]
 [6.1 3.  4.9 1.8]
 [5.1 3.8 1.6 0.2]
 [5.4 3.4 1.7 0.2]
 [5.4 3.4 1.5 0.4]
 [5.6 2.8 4.9 2. ]
 [7.7 3.8 6.7 2.2]
 [5.  3.6 1.4 0.2]
 [7.4 2.8 6.1 1.9]
 [6.  2.2 5.  1.5]
 [4.7 3.2 1.6 0.2]
 [5.1 3.5 1.4 0.2]
 [6.  2.2 4.  1. ]
 [5.  2.3 3.3 1. ]
 [7.9 3.8 6.4 2. ]
 [5.4 3.9 1.7 0.4]
 [5.4 3.9 1.3 0.4]
 [5.8 2.7 3.9 1.2]
 [5.  2.  3.5 1. ]
 [5.  3.2 1.2 0.2]
 [6.8 3.2 5.9 2.3]
 [6.7 3.  5.2 2.3]
 [5.8 2.7 5.1 1.9]
 [5.8 2.8 5.1 2.4]
 [6.3 3.4 5.6 2.4]
 [5.5 2.3 4.  1.3]
 [5.1 3.8 1.5 0.3]
 [4.4 3.  1.3 0.2]
 [6.5 3.2 5.1 2. ]
 [5.1 3.3 1.7 0.5]
 [4.9 3.1 1.5 0.1]
 [6.7 3.1 4.7 1.5]
 [6.1 3.  4.6 1.4]
 [5.5 2.5 4.  1.3]
 [5.7 2.6 3.5 1. ]
 [5.8 2.7 5.1 1.9]
 [6.7 3.1 4.4 1.4]
 [6.4 3.2 5.3 2.3]
 [4.5 2.3 1.3 0.3]
 [6.7 3.3 5.7 2.1]
 [5.7 3.  4.2 1.2]
 [5.1 3.7 1.5 0.4]
 [4.8 3.4 1.6 0.2]
 [6.3 2.9 5.6 1.8]
 [6.4 2.9 4.3 1.3]
 [7.7 2.8 6.

IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

In [None]:
# print(x_test)
# print(y_test)
# print(len(x_test))
# print(len(y_test))
# print(x_test[:,2])