Compute Gini

In [6]:
import numpy as np

def gini_index(groups,classes):
    total_samples = float(sum([len(group) for group in groups]))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue
    score = 0.0
    for c in classes:
        p_i  = [row[-1] for row in group].count(c) / size
        score += p_i ** 2
    
    gini += (1 - score) *( size / total_samples)
    
    return gini

dataset = [
    [2.8, 'Yes'],
    [3.6, 'Yes'],
    [1.2, 'No'],    
    [4.5, 'No'],
    [5.1, 'Yes']
]

def split_data(dataset, feature_index, threshold):
    left = [row for row in dataset if row[feature_index] < threshold]
    right = [row for row in dataset if row[feature_index] >= threshold]
    return left, right

groups = split_data(dataset, 0, 3.0)
classes = ['Yes','No']
gini = gini_index(groups,classes)
print(f'Gini Index: {gini}')


Gini Index: 0.26666666666666666


Tree Node

In [11]:
class TreeNode:
    def __init__(self, feature_index = None, threshold = None, left = None, right = None, label = None ):
        self.feature_index = feature_index
        self.threshold = threshold
        self.right = right
        self.left = left
        self.label = label



Decision Tree

In [12]:
class DecisionTree:
    def __init__(self, max_depth=3):
        self.max_depth = max_depth
        self.root = None
    def gini_index(self,groups,classes):
        total_samples = float(sum([len(group) for group in groups]))
        gini = 0.0
        for group in groups:
            size = float(len(group))
            if size == 0:
                continue
        score = 0.0
        for c in classes:
            p_i  = [row[-1] for row in group].count(c) / size
            score += p_i ** 2
        
        gini += (1 - score) *( size / total_samples)
        
        return gini
    
    def split_data(self,dataset,feature_index,threshold):
        left = [row for row in dataset if row[feature_index] < threshold]
        right = [row for row in dataset if row[feature_index] >= threshold]
        return left, right

    def best_split(self,dataset):
        class_value = list(set(row[-1] for row in dataset))
        best_index, best_threshold, best_score, best_groups = None, None, float('inf'), None
        for index in range(len(dataset[0])-1):
            for row in dataset:
                groups = self.split_data(dataset, index, row[index])
                gini = self.gini_index(groups, class_value)
                if gini < best_score:
                    best_index, best_threshold, best_score, best_groups = index, row[index], gini, groups
                return best_index, best_threshold, best_groups
            
    def build_tree(self,dataset,depth=0):
        class_values = [row[1] for row in dataset]

        if len(set(class_values)) == 1 or depth == self.max_depth:
            return TreeNode(label=max(set(class_values), key=class_values.count))
        
        feature_index, threshold, groups = self.best_split(dataset)

        if not groups[0] or not groups[1]:
            return TreeNode(label=max(set(class_values), key=class_values.count))
        
        left_node = self.build_tree(groups[0], depth + 1)
        right_node = self.build_tree(groups[1], depth + 1)

        return TreeNode(feature_index, threshold, left_node, right_node)    
    
    def fit(self,dataset):
        self.root = self.build_tree(dataset)

    def  print_tree(self,node=None,depth=0):
        if node is None:
            node = self.root

        if node.label is not None:
            print(f'{"|  " * depth}[Leaf] Label: {node.label}')
        else:
            print(f'{"|  " * depth}[Node] Feature {node.feature_index} < {node.threshold}')
            self.print_tree(node.left, depth + 1)
            self.print_tree(node.right, depth + 1)


In [13]:
data = [
    [50,'Yes'],
    [20,'No'],
    [30,'No'],
    [70,'Yes'],
    [40,'No'],
    [60,'Yes']
]
tree = DecisionTree(max_depth=3)
tree.fit(data)
tree.print_tree()

[Node] Feature 0 < 50
|  [Leaf] Label: No
|  [Leaf] Label: Yes
