In [1]:
import numpy as np
import pandas as pd

In [2]:
lst = ['a', 'b', 'c', 'd', 'b', 'c', 'a', 'b', 'c', 'd', 'a']
def gini(nums):
    probs = [nums.count(i)/len(nums) for i in set(nums)]
    gini = sum([p*(1-p) for p in probs])
    return gini

gini(lst)

0.743801652892562

In [3]:
df = pd.read_csv('./example_data.csv', dtype={'windy': 'str'})
df

Unnamed: 0,humility,outlook,temp,windy,play
0,high,sunny,hot,False,no
1,high,sunny,hot,True,no
2,high,overcast,hot,False,yes
3,high,rainy,mild,False,yes
4,normal,rainy,cool,False,yes
5,normal,rainy,cool,True,no
6,normal,overcast,cool,True,yes
7,high,sunny,mild,False,no
8,normal,sunny,cool,False,yes
9,normal,rainy,mild,False,yes


In [4]:
gini(df['play'].tolist())

0.4591836734693877

In [5]:
def split_dataframe(data, col):
    '''
    function: split pandas dataframe to sub-df based on data and column.
    input: dataframe, column name.
    output: a dict of splited dataframe.
    '''
    # unique value of column
    unique_values = data[col].unique()
    # empty dict of dataframe
    result_dict = {elem : pd.DataFrame for elem in unique_values}
    # split dataframe based on column value
    for key in result_dict.keys():
        result_dict[key] = data[:][data[col] == key]
    return result_dict

In [6]:
split_dataframe(df, 'temp')

{'hot':    humility   outlook temp  windy play
 0      high     sunny  hot  FALSE   no
 1      high     sunny  hot   TRUE   no
 2      high  overcast  hot  FALSE  yes
 12   normal  overcast  hot  FALSE  yes,
 'mild':    humility   outlook  temp  windy play
 3      high     rainy  mild  FALSE  yes
 7      high     sunny  mild  FALSE   no
 9    normal     rainy  mild  FALSE  yes
 10   normal     sunny  mild   TRUE  yes
 11     high  overcast  mild   TRUE  yes
 13     high     rainy  mild   TRUE   no,
 'cool':   humility   outlook  temp  windy play
 4   normal     rainy  cool  FALSE  yes
 5   normal     rainy  cool   TRUE   no
 6   normal  overcast  cool   TRUE  yes
 8   normal     sunny  cool  FALSE  yes}

In [7]:
def choose_best_col(df, label):
    '''
    funtion: choose the best column based on infomation gain.
    input: datafram, label
    output: max infomation gain, best column, 
            splited dataframe dict based on best column.
    '''
    # Calculating label's gini index
    gini_D = gini(df[label].tolist())
    # columns list except label
    cols = [col for col in df.columns if col not in [label]]
    # initialize the max infomation gain, best column and best splited dict
    min_value, best_col = 999, None
    min_splited = None
    # split data based on different column
    for col in cols:
        splited_set = split_dataframe(df, col)
        gini_DA = 0
        for subset_col, subset in splited_set.items():
            # calculating splited dataframe label's gini index
            gini_Di = gini(subset[label].tolist())
            # calculating gini index of current feature
            gini_DA += len(subset)/len(df) * gini_Di
        
        if gini_DA < min_value:
            min_value, best_col = gini_DA, col
            min_splited = splited_set
    return min_value, best_col, min_splited
    
choose_best_col(df, 'play')

(0.34285714285714286,
 'outlook',
 {'sunny':    humility outlook  temp  windy play
  0      high   sunny   hot  FALSE   no
  1      high   sunny   hot   TRUE   no
  7      high   sunny  mild  FALSE   no
  8    normal   sunny  cool  FALSE  yes
  10   normal   sunny  mild   TRUE  yes,
  'overcast':    humility   outlook  temp  windy play
  2      high  overcast   hot  FALSE  yes
  6    normal  overcast  cool   TRUE  yes
  11     high  overcast  mild   TRUE  yes
  12   normal  overcast   hot  FALSE  yes,
  'rainy':    humility outlook  temp  windy play
  3      high   rainy  mild  FALSE  yes
  4    normal   rainy  cool  FALSE  yes
  5    normal   rainy  cool   TRUE   no
  9    normal   rainy  mild  FALSE  yes
  13     high   rainy  mild   TRUE   no})

In [15]:
class CartTree:    
    # define a Node class
    class Node:        
        def __init__(self, name):
            self.name = name
            self.connections = {}    
            
        def connect(self, label, node):
            self.connections[label] = node    
        
    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")    
    
    # print tree method
    def print_tree(self, node, tabs):
        print(tabs + node.name)        
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + connection + ")")
            self.print_tree(child_node, tabs + "\t\t")    
    
    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)    
    
    # construct tree
    def construct(self, parent_node, parent_connection_label, input_data, columns):
        min_value, best_col, min_splited = choose_best_col(input_data[columns], self.label)   
        if not best_col:
            node = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, node)            
            return

        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)

        new_columns = [col for col in columns if col != best_col]        
        # Recursively constructing decision trees
        for splited_value, splited_data in min_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)

In [16]:
tree1 = CartTree(df, 'play')
tree1.construct_tree()
tree1.print_tree(tree1.root, "")

Root
	()
		outlook
			(sunny)
				humility
					(high)
						temp
							(hot)
								windy
									(FALSE)
										no
									(TRUE)
										no
							(mild)
								windy
									(FALSE)
										no
					(normal)
						temp
							(cool)
								windy
									(FALSE)
										yes
							(mild)
								windy
									(TRUE)
										yes
			(overcast)
				humility
					(high)
						temp
							(hot)
								windy
									(FALSE)
										yes
							(mild)
								windy
									(TRUE)
										yes
					(normal)
						temp
							(cool)
								windy
									(TRUE)
										yes
							(hot)
								windy
									(FALSE)
										yes
			(rainy)
				windy
					(FALSE)
						humility
							(high)
								temp
									(mild)
										yes
							(normal)
								temp
									(cool)
										yes
									(mild)
										yes
					(TRUE)
						humility
							(normal)
								temp
									(cool)
										no
							(high)
								temp
									(mild)
										no
