In [1]:
import numpy as np
import pandas as pd

df=pd.read_csv('F:/decision_tree/example_data.csv')
df

Unnamed: 0,humility,outlook,temp,windy,play
0,high,sunny,hot,False,no
1,high,sunny,hot,True,no
2,high,overcast,hot,False,yes
3,high,rainy,mild,False,yes
4,normal,rainy,cool,False,yes
5,normal,rainy,cool,True,no
6,normal,overcast,cool,True,yes
7,high,sunny,mild,False,no
8,normal,sunny,cool,False,yes
9,normal,rainy,mild,False,yes


In [2]:
def gini(nums):
    nums=nums.tolist()
    probs=[nums.count(i)/len(nums) for i in set(nums)]
    gini=sum([p*(1-p) for p in probs])
    return gini

In [3]:
gini(df['play'])

0.4591836734693877

In [4]:
def split_dataframe(data,col):
    unique_values=data[col].unique()
    result_dict={elem:pd.DataFrame for elem in unique_values}
    for key in result_dict.keys():
        result_dict[key]=data[:][data[col]==key]
    return result_dict

In [5]:
split_dataframe(df,'temp')

{'hot':    humility   outlook temp  windy play
 0      high     sunny  hot  False   no
 1      high     sunny  hot   True   no
 2      high  overcast  hot  False  yes
 12   normal  overcast  hot  False  yes,
 'mild':    humility   outlook  temp  windy play
 3      high     rainy  mild  False  yes
 7      high     sunny  mild  False   no
 9    normal     rainy  mild  False  yes
 10   normal     sunny  mild   True  yes
 11     high  overcast  mild   True  yes
 13     high     rainy  mild   True   no,
 'cool':   humility   outlook  temp  windy play
 4   normal     rainy  cool  False  yes
 5   normal     rainy  cool   True   no
 6   normal  overcast  cool   True  yes
 8   normal     sunny  cool  False  yes}

In [6]:
def choose_best_col(df,label):
    gini_D=gini(df[label])
    cols=[col for col in df.columns if col not in [label]]
    min_value,best_col=9999,None
    min_splited=None
    for col in cols:
        splited_set=split_dataframe(df,col)
        gini_DA=0
        for subset_col,subset in splited_set.items():
            gini_Di=gini(subset[label])
            gini_DA+=len(subset)/len(df)*gini_Di
        if gini_DA<min_value:
            min_value,best_col=gini_DA,col
            min_splited=splited_set
    return min_value,best_col,min_splited

In [7]:
choose_best_col(df,'play')

(0.34285714285714286,
 'outlook',
 {'sunny':    humility outlook  temp  windy play
  0      high   sunny   hot  False   no
  1      high   sunny   hot   True   no
  7      high   sunny  mild  False   no
  8    normal   sunny  cool  False  yes
  10   normal   sunny  mild   True  yes,
  'overcast':    humility   outlook  temp  windy play
  2      high  overcast   hot  False  yes
  6    normal  overcast  cool   True  yes
  11     high  overcast  mild   True  yes
  12   normal  overcast   hot  False  yes,
  'rainy':    humility outlook  temp  windy play
  3      high   rainy  mild  False  yes
  4    normal   rainy  cool  False  yes
  5    normal   rainy  cool   True   no
  9    normal   rainy  mild  False  yes
  13     high   rainy  mild   True   no})

In [8]:
class CARTTree:
    class Node:
        def __init__(self,name):
            self.name=name
            self.connections={}
            
        def connect(self,label,node):
            self.connections[label]=node
            
    def __init__(self,data,label):
        self.columns=data.columns
        self.data=data
        self.label=label
        self.root=self.Node("Root")
        
    def construct(self,parent_node,parent_connection_label,input_data,columns):
        max_value,best_col,max_splited=choose_best_col(input_data[columns],self.label)
        if not best_col:
            node=self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label,node)
            return
    
        node=self.Node(best_col)
        parent_node.connect(parent_connection_label,node)
        
        new_columns=[col for col in columns if col!=best_col]
        for splited_value,splited_data in max_splited.items():
            self.construct(node,splited_value,splited_data,new_columns)
    
    def construct_tree(self):
        self.construct(self.root," ",self.data,self.columns)
        
    def print_tree(self,node,tabs):
        print(tabs+node.name)
        for connection,child_node in node.connections.items():
            print(tabs+"\t"+"("+str(connection)+")")
            self.print_tree(child_node,tabs+"\t\t")

In [9]:
tree2=CARTTree(df,'play')
tree2.construct_tree()
tree2.print_tree(tree2.root,"")

Root
	( )
		outlook
			(sunny)
				humility
					(high)
						temp
							(hot)
								windy
									(False)
										no
									(True)
										no
							(mild)
								windy
									(False)
										no
					(normal)
						temp
							(cool)
								windy
									(False)
										yes
							(mild)
								windy
									(True)
										yes
			(overcast)
				humility
					(high)
						temp
							(hot)
								windy
									(False)
										yes
							(mild)
								windy
									(True)
										yes
					(normal)
						temp
							(cool)
								windy
									(True)
										yes
							(hot)
								windy
									(False)
										yes
			(rainy)
				windy
					(False)
						humility
							(high)
								temp
									(mild)
										yes
							(normal)
								temp
									(cool)
										yes
									(mild)
										yes
					(True)
						humility
							(normal)
								temp
									(cool)
										no
							(high)
								temp
									(mild)
										no
