In [15]:
import numpy as np
import pandas as pd
import random

In [16]:
# Load the Iris CSV, discard the running "Id" column and call the label
# column "target" so later cells can rely on it being the last column.
# NOTE(review): the path below is absolute and machine-specific; the notebook
# only runs where this exact file exists.
df = (
    pd.read_csv('/Users/josephbell/Downloads/iris.csv')
    .drop(columns="Id")
    .rename(columns={"Species": "target"})
)
df.head()

Unnamed: 0,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,target
0,5.1,3.5,1.4,0.2,Iris-setosa
1,4.9,3.0,1.4,0.2,Iris-setosa
2,4.7,3.2,1.3,0.2,Iris-setosa
3,4.6,3.1,1.5,0.2,Iris-setosa
4,5.0,3.6,1.4,0.2,Iris-setosa


In [17]:
def train_test_split(df, test_size):
    """Randomly partition ``df`` into a training frame and a test frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to split; rows are selected by index label.
    test_size : int or float
        Number of rows to hold out. A float is treated as a proportion of
        ``len(df)`` and converted to a row count with ``round``.

    Returns
    -------
    (train_df, test_df)
        ``test_df`` holds the sampled rows (in sampled order); ``train_df``
        holds every remaining row.

    Rows are drawn with the stdlib ``random`` module, so call
    ``random.seed`` beforehand for a repeatable split.
    """
    n_test = round(test_size * len(df)) if isinstance(test_size, float) else test_size
    held_out = random.sample(population=df.index.tolist(), k=n_test)
    return df.drop(held_out), df.loc[held_out]

In [18]:
# Seed the stdlib RNG first so the same 30 rows are held out on every run.
random.seed(0)
train_df, test_df = train_test_split(df, test_size=30)
# Training rows as a plain array; the label sits in the last column.
data = train_df.values

In [19]:
class Node:
    """A single node of a binary tree, holding either a split or a target.

    Attributes
    ----------
    target : value stored by ``set_target`` (``None`` until set).
    attr, splitvalue : the pair stored by ``set_attr`` (``None`` until set).
    left, right : child nodes (``None`` until assigned).
    """

    def __init__(self, target=None, attr=None, splitvalue=None, left=None, right=None):
        self.target, self.attr, self.splitvalue = target, attr, splitvalue
        self.left, self.right = left, right

    def set_target(self, target):
        """Store ``target`` on this node."""
        self.target = target

    def set_attr(self, attr, splitvalue):
        """Store the split attribute and its split value on this node."""
        self.attr, self.splitvalue = attr, splitvalue

In [20]:
class DecisionTree():
    """Binary decision-tree classifier that splits on information gain.

    Training data is a 2-D NumPy array whose last column is the class label
    and whose other columns are numeric features. ``fit`` returns the root
    ``Node`` of the tree; pass that node to ``predict``.
    """

    def __init__(self, min_samples_split=2):
        # Subsets with fewer rows than this become leaves instead of splitting.
        self.min_samples_split = min_samples_split

    # is the data pure meaning does the split contain only 1 class?
    def check_purity(self, data):
        """Return True when the label column of ``data`` holds exactly one class."""
        return len(np.unique(data[:, -1])) == 1

    def calculate_entropy(self, data):
        """Return the base-2 Shannon entropy of the label column of ``data``."""
        _, counts = np.unique(data[:, -1], return_counts=True)
        # relative frequency of each class
        probabilities = counts / counts.sum()
        return sum(probabilities * -np.log2(probabilities))

    def info_gain(self, data, column_index, splitval):
        """Return the entropy reduction from splitting ``data`` on one feature.

        Rows with ``data[:, column_index] <= splitval`` form the left subset,
        rows with a larger value form the right subset.
        """
        split_column_values = data[:, column_index]
        data_left = data[split_column_values <= splitval]
        data_right = data[split_column_values > splitval]

        data_points = len(data_left) + len(data_right)
        p_data_left = len(data_left) / data_points
        p_data_right = len(data_right) / data_points

        weighted_child_entropy = (p_data_right * self.calculate_entropy(data_right)
                                  + p_data_left * self.calculate_entropy(data_left))
        return self.calculate_entropy(data) - weighted_child_entropy

    def find_best_split(self, data):
        """Return ``(column_index, threshold)`` of the highest-gain split.

        Candidate thresholds are the midpoints between consecutive unique
        values of each feature column. On equal gain the candidate examined
        last wins. Returns ``(None, None)`` when no feature column has two
        distinct values, i.e. there is nothing to split on.
        """
        # Start below any possible gain so a valid candidate is always kept,
        # even if rounding makes every computed gain slightly negative.
        bestgain = -np.inf
        bestattr, bestsplitval = None, None
        _, n_columns = data.shape
        for column_index in range(n_columns - 1):
            unique_values = np.unique(data[:, column_index])
            for lower, upper in zip(unique_values[:-1], unique_values[1:]):
                splitval = (lower + upper) / 2
                gain = self.info_gain(data, column_index, splitval)
                if gain >= bestgain:
                    bestgain = gain
                    bestattr = column_index
                    bestsplitval = splitval
        return bestattr, bestsplitval

    # classify data by majority class
    def classify_data(self, data):
        """Return the most frequent label in ``data``.

        Ties go to the class that ``np.unique`` lists first (sorted order).

        Raises
        ------
        ValueError
            If ``data`` has no rows.
        """
        if len(data) == 0:
            raise ValueError("cannot classify an empty data set")
        unique_classes, counts_unique_classes = np.unique(data[:, -1], return_counts=True)
        return unique_classes[counts_unique_classes.argmax()]

    def _make_leaf(self, data):
        """Build a leaf node that predicts the majority class of ``data``."""
        leaf = Node()
        leaf.set_target(self.classify_data(data))
        return leaf

    def fit(self, data):
        """Recursively build the tree for ``data`` and return its root node.

        A subset becomes a leaf when it has fewer than ``min_samples_split``
        rows, is pure, or cannot be split any further. Every leaf gets a
        target (the majority class), so prediction never meets an empty leaf.
        """
        if len(data) < self.min_samples_split or self.check_purity(data):
            return self._make_leaf(data)

        column_index, split = self.find_best_split(data)
        if column_index is None:
            # identical feature rows with different labels: nothing to split on
            return self._make_leaf(data)

        # Partition with the same <= / > rule that info_gain and get_target
        # use, so no training row is dropped at the threshold.
        column_values = data[:, column_index]
        left_rows = data[column_values <= split]
        right_rows = data[column_values > split]
        if len(left_rows) == 0 or len(right_rows) == 0:
            # degenerate threshold (midpoint rounded onto a data value)
            return self._make_leaf(data)

        node = Node()
        node.set_attr(attr=column_index, splitvalue=split)
        node.left = self.fit(left_rows)
        node.right = self.fit(right_rows)
        return node

    def get_target(self, row, n):
        """Walk from node ``n`` down to a leaf for ``row`` and return its target."""
        while n.target is None:
            if row[n.attr] <= n.splitvalue:
                n = n.left
            else:
                n = n.right
        return n.target

    def predict(self, tree, X_test):
        """Return a list with the predicted label for every row of ``X_test``.

        ``tree`` is the root node returned by ``fit``.
        """
        return [self.get_target(row, tree) for row in X_test]

    def accuracy_score(self, y_true, y_pred):
        """Return the fraction of positions where ``y_true`` equals ``y_pred``.

        Both arguments may be lists or arrays; they are converted to arrays so
        the comparison is element-wise rather than a single list equality.

        Raises
        ------
        ValueError
            If the two inputs do not have the same shape.
        """
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        if y_true.shape != y_pred.shape:
            raise ValueError("y_true and y_pred must have the same shape")
        return np.sum(y_true == y_pred) / len(y_pred)

In [24]:
# Split the held-out frame into features and labels here, so this cell does
# not depend on `y_test` having been defined by a cell further down.
X_test = test_df.values[:, :-1]
y_test = test_df['target'].values

tree = DecisionTree(min_samples_split=10)
root = tree.fit(data=data)
y_pred = tree.predict(root, X_test)
# Stored as `tree_accuracy` so the result does not share a name with the
# `accuracy_score` function that scikit-learn provides.
tree_accuracy = tree.accuracy_score(y_test, y_pred)
print(tree_accuracy)

0.9666666666666667


In [25]:
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

# Reuse the exact train/test rows from the hand-written tree, separated into
# feature columns and the label column (last column).
X_train, y_train = data[:, :-1], data[:, -1]
X_test = test_df.values[:, :-1]
y_test = test_df['target'].values

# scikit-learn shuffles the candidate features at every split, so equally good
# splits can be chosen differently from run to run; pin random_state to make
# the fitted tree and the printed accuracy reproducible.
clf = DecisionTreeClassifier(min_samples_split=10, random_state=0)
clf = clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(accuracy_score(y_test, y_pred))

0.9666666666666667
