In [328]:
import pandas as pd 
import numpy as np

In [329]:
# Toy fruit dataset, one row per sample: [color, diameter, label].
# dtype=object lets the integer diameters sit alongside the string columns.
training_data = np.array(
    [
        ["Green", 3, "Apple"],
        ["Yellow", 3, "Apple"],
        ["Red", 1, "Grape"],
        ["Red", 1, "Grape"],
        ["Yellow", 3, "Lemon"],
    ],
    dtype=object,
)

In [330]:
# Wrap the raw rows in a DataFrame and give the diameter column a real
# integer dtype (it is object-typed straight out of training_data).
header = np.array(["color", "diameter", "label"])
df = pd.DataFrame(training_data, columns=header)
df["diameter"] = df["diameter"].astype("int32")

In [335]:
def unique_vals(array):
    """Return the distinct elements of *array* as a ``set``.

    Elements must be hashable; ordering is whatever ``set`` iteration gives.
    """
    distinct = {*array}
    return distinct

In [337]:
def class_count(y_data):
    """Count how many times each distinct label occurs in *y_data*.

    Args:
        y_data: Iterable of hashable labels (the tree code passes the last
            column of the training array).

    Returns:
        dict mapping each distinct label to its number of occurrences. Keys
        are inserted in the iteration order of ``set(y_data)``.
    """
    # Pre-seed every label with zero so the tally loop can increment blindly.
    tally = dict.fromkeys(set(y_data), 0)
    for label in y_data:
        tally[label] += 1
    return tally

In [338]:
def is_numeric(value):
    """Return True if *value* is a real-valued number.

    Accepts Python ``int``/``float`` (and therefore ``bool``, which subclasses
    ``int``) as well as NumPy integer/floating scalars such as ``np.int32``.
    NumPy integer scalars are not subclasses of ``int``, so a plain
    ``isinstance(value, int)`` check reported them as non-numeric.
    """
    return isinstance(value, (int, float, np.integer, np.floating))


def is_col_numeric(types, col):
    """Return True if column *col* has an integer or floating-point dtype.

    Args:
        types: Sequence of per-column dtypes (NumPy dtypes or pandas extension
            dtypes), indexed by column position.
        col: Column position to look up.

    Returns:
        True for integer/float dtypes, including pandas nullable ``Int64`` /
        ``Float64``. Boolean, object, datetime and interval dtypes are not
        numeric. A substring test on ``str(dtype)`` wrongly accepted
        ``interval[int64, ...]`` and rejected the capitalised nullable dtypes.
    """
    dtype = types[col]
    return pd.api.types.is_integer_dtype(dtype) or pd.api.types.is_float_dtype(dtype)

In [339]:
class Question:
    """A yes/no test on one feature column, used to split a dataset.

    Attributes:
        name: Column name taken from ``header[column]`` (used only for display).
        column: Positional index of the feature being tested.
        value: Threshold (numeric columns) or category (other columns).
    """

    def __init__(self, header, column, value):
        self.name = header[column]
        self.column = column
        self.value = value

    def match(self, features, types):
        """Evaluate the question against one row.

        Columns that ``is_col_numeric(types, column)`` reports as numeric are
        compared with ``>=``; every other column is compared with ``==``.
        """
        candidate = features[self.column]
        if not is_col_numeric(types, self.column):
            return candidate == self.value
        return candidate >= self.value

    def __repr__(self):
        # The operator shown is picked from the value's Python type
        # (is_numeric), not the column dtype that match() uses, so the two
        # can disagree.
        op = ">=" if is_numeric(self.value) else "=="
        return f"Is {self.name} {op} {self.value}?"

In [340]:
def partition(data, question, types):
    """Split the rows of *data* by whether they satisfy *question*.

    Args:
        data: Iterable of rows (the tree code passes a 2-D object array).
        question: Object whose ``match(row, types)`` result is tested for
            truthiness.
        types: Per-column dtypes forwarded to ``question.match``.

    Returns:
        ``(matching, non_matching)``, each built with ``np.array`` from the
        collected rows. An empty side becomes ``np.array([])`` (1-D, length 0),
        so callers check ``len`` before 2-D indexing.
    """
    buckets = {True: [], False: []}
    for row in data:
        buckets[bool(question.match(row, types))].append(row)
    return np.array(buckets[True]), np.array(buckets[False])

In [342]:
def gini_impurity(y_data):
    """Return the Gini impurity ``1 - sum(p_k ** 2)`` of a label collection.

    ``p_k`` is the fraction of *y_data* belonging to class ``k``. An empty
    input has no classes to subtract, so the result is ``1``.
    """
    per_class = class_count(y_data)
    n_total = float(len(y_data))
    impurity = 1
    for occurrences in per_class.values():
        impurity -= (occurrences / n_total) ** 2
    return impurity

In [343]:
def info_gain(left, right, current_uncertainty):
    """Return the Gini reduction obtained by splitting a node in two.

    Args:
        left: Label collection for the "true" child.
        right: Label collection for the "false" child.
        current_uncertainty: Gini impurity of the parent node.

    Returns:
        ``current_uncertainty`` minus the size-weighted average of the two
        children's Gini impurities. Division by zero occurs if both children
        are empty.
    """
    size_left = float(len(left))
    size_right = float(len(right))
    share_left = size_left / (size_left + size_right)
    gini_left = gini_impurity(left)
    gini_right = gini_impurity(right)
    return current_uncertainty - (share_left * gini_left + (1 - share_left) * gini_right)

In [344]:
def get_best_split(header, data, types):
    """Find the question with the highest information gain on *data*.

    Every distinct value of every feature column (all columns except the last,
    which holds the label) is tried as a candidate ``Question``. Candidates
    that leave either side of the partition empty are skipped.

    Args:
        header: Column names, forwarded to ``Question`` for display.
        data: 2-D array whose last column is the label.
        types: Per-column dtypes forwarded to ``partition``/``Question.match``.

    Returns:
        ``(best_gain, best_question)``. ``best_question`` is ``None`` when no
        candidate splits the data. Ties, including zero gain, resolve to the
        candidate evaluated last.
    """
    best_gain = 0
    best_question = None
    current_uncertainty = gini_impurity(data[:, -1])

    n_features = data.shape[1] - 1
    for col in range(n_features):
        # Iterate candidates in first-occurrence order rather than set order.
        # Set iteration over strings depends on PYTHONHASHSEED, and with the
        # last-tie-wins rule below that made the chosen split (and therefore
        # the whole tree) change from run to run.
        values = dict.fromkeys(data[:, col])

        for val in values:
            q = Question(header, col, val)
            true_data, false_data = partition(data, q, types)

            if len(true_data) == 0 or len(false_data) == 0:
                continue  # not a real split

            gain = info_gain(true_data[:, -1], false_data[:, -1], current_uncertainty)

            if gain >= best_gain:
                best_gain, best_question = gain, q

    return best_gain, best_question

In [345]:
class Leaf:
    """Terminal tree node holding the label counts of the rows that reached it."""

    def __init__(self, data):
        # data is indexed as a 2-D array; its last column holds the labels.
        self.count = class_count(data[:, -1])

    def __repr__(self):
        # Render like a dict with unquoted keys, e.g. "{Apple: 2, Grape: 1}".
        # Joining also handles an empty count correctly ("{}"); the previous
        # strip-the-trailing-", " approach turned "{" into "}".
        body = ", ".join(f"{label}: {n}" for label, n in self.count.items())
        return "{" + body + "}"

In [346]:
class DecisionNode:
    """Internal tree node: a question plus the subtree for each answer.

    Attributes:
        question: ``Question`` evaluated at this node.
        true_branch: Subtree (``DecisionNode`` or ``Leaf``) for rows that match.
        false_branch: Subtree (``DecisionNode`` or ``Leaf``) for rows that don't.
        types: Per-column dtypes passed to ``Question.match``.
    """

    def __init__(self, question, true_branch, false_branch, types):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch
        self.types = types

    def classify(self, node, X_data):
        """Route one row down the subtree rooted at *node*.

        Returns the ``count`` dict of the ``Leaf`` the row ends up in. The
        subtree is taken from the *node* argument; ``self`` only supplies
        ``types``.
        """
        row = np.array(X_data)
        if isinstance(node, Leaf):
            return node.count
        next_node = node.true_branch if node.question.match(row, self.types) else node.false_branch
        return self.classify(next_node, row)

    def print_node(self, indent):
        """Print this subtree; nested decision nodes are indented 5 more spaces."""
        pad = " " * indent
        print(pad + str(self.question))
        for heading, branch in (
            ("=== TRUE ===", self.true_branch),
            ("=== FALSE ===", self.false_branch),
        ):
            print(pad + heading)
            if isinstance(branch, Leaf):
                # Leaves are printed at the heading's indent, not nested deeper.
                print(pad + str(branch))
            else:
                branch.print_node(indent + 5)

In [347]:
class DecisionTreeClassifier:
    """Binary classification tree that splits on Gini information gain.

    Args:
        max_depth: Maximum number of questions on any root-to-leaf path.
            ``None`` grows the tree until no split has positive gain.
    """

    def __init__(self, max_depth=5):
        self.tree = None
        self.types = None
        self.max_depth = max_depth

    def fit(self, X_data, y_data):
        """Build the tree from feature and label data.

        Args:
            X_data: Features; read via ``np.array`` and its ``.columns`` /
                ``.dtypes`` (presumably a ``pandas.DataFrame``).
            y_data: Labels; read via its ``.name`` / ``.dtype`` (presumably a
                ``pandas.Series``).
        """
        features = np.array(X_data)
        target = np.array([y_data]).T
        # Labels go in the last column; the split helpers treat column -1 as the label.
        data = np.concatenate((features, target), axis=1)

        header = list(X_data.columns)
        header.append(y_data.name)

        # One dtype per feature column, then the label dtype, aligned with ``data``.
        # Slicing the feature dtypes with ``[:-1]`` used to drop the last
        # feature's dtype, so that column was typed with the label's dtype
        # (e.g. an int column compared with ``==`` instead of ``>=``).
        self.types = list(X_data.dtypes) + [y_data.dtype]

        self.tree = self.grow_tree(data, header)

    def grow_tree(self, data, header, depth=0):
        """Recursively build the subtree for *data*.

        Args:
            data: 2-D array with the label in the last column.
            header: Column names, forwarded to ``Question``.
            depth: Depth of the node being built; the root is 0.

        Returns:
            A ``Leaf`` once ``max_depth`` is reached or when no split has
            positive gain, otherwise a ``DecisionNode``.
        """
        # Honor the depth limit; previously max_depth was stored but ignored.
        if self.max_depth is not None and depth >= self.max_depth:
            return Leaf(data)

        gain, question = get_best_split(header, data, self.types)
        if question is None or gain <= 0:
            return Leaf(data)

        true_data, false_data = partition(data, question, self.types)

        true_branch = self.grow_tree(true_data, header, depth + 1)
        false_branch = self.grow_tree(false_data, header, depth + 1)

        return DecisionNode(question, true_branch, false_branch, self.types)

    def print_tree(self):
        """Print the fitted tree.

        Raises:
            AttributeError: if ``fit`` has not been called (same exception type
                the unguarded ``None.print_node`` call used to raise).
        """
        if self.tree is None:
            raise AttributeError("DecisionTreeClassifier is not fitted; call fit() first.")
        if isinstance(self.tree, Leaf):
            # A tree with no useful split is a single Leaf, which has no print_node.
            print(self.tree)
        else:
            self.tree.print_node(0)