In [4]:
import numpy as np
import pandas as pd
from math import log

df=pd.read_csv('F:/decision_tree/example_data.csv')
df

Unnamed: 0,humility,outlook,temp,windy,play
0,high,sunny,hot,False,no
1,high,sunny,hot,True,no
2,high,overcast,hot,False,yes
3,high,rainy,mild,False,yes
4,normal,rainy,cool,False,yes
5,normal,rainy,cool,True,no
6,normal,overcast,cool,True,yes
7,high,sunny,mild,False,no
8,normal,sunny,cool,False,yes
9,normal,rainy,mild,False,yes


In [9]:
def entropy(ele):
    ele=ele.tolist()
    #计算每个类别的概率
    probs=[ele.count(i)/len(ele) for i in set(ele)]
    #计算交叉熵
    entropy=-sum([prob*log(prob,2) for prob in probs])
    return entropy

In [10]:
entropy(df['play'])

0.9402859586706309

In [11]:
def split_dataframe(data,col):
    unique_values=data[col].unique()
    result_dict={elem:pd.DataFrame for elem in unique_values}
    for key in result_dict.keys():
        result_dict[key]=data[:][data[col]==key]
    return result_dict

In [20]:
split_dataframe(df,'windy')

{False:    humility   outlook  temp  windy play
 0      high     sunny   hot  False   no
 2      high  overcast   hot  False  yes
 3      high     rainy  mild  False  yes
 4    normal     rainy  cool  False  yes
 7      high     sunny  mild  False   no
 8    normal     sunny  cool  False  yes
 9    normal     rainy  mild  False  yes
 12   normal  overcast   hot  False  yes,
 True:    humility   outlook  temp  windy play
 1      high     sunny   hot   True   no
 5    normal     rainy  cool   True   no
 6    normal  overcast  cool   True  yes
 10   normal     sunny  mild   True  yes
 11     high  overcast  mild   True  yes
 13     high     rainy  mild   True   no}

In [21]:
def choose_best_col(df,label):
    entropy_D=entropy(df[label])
    cols=[col for col in df.columns if col not in [label]]
    max_value,best_col=-9999,None
    max_splited=None
    for col in cols:
        splited_set=split_dataframe(df,col)
        entropy_DA=0
        for subset_col,subset in splited_set.items():
            entropy_Di=entropy(subset[label])
            entropy_DA+=len(subset)/len(df)*entropy_Di
        info_gain=entropy_D-entropy_DA
        if info_gain>max_value:
            max_value,best_col=info_gain,col
            max_splited=splited_set
    return max_value,best_col,max_splited

In [22]:
choose_best_col(df,'play')

(0.2467498197744391, 'outlook', {'sunny':    humility outlook  temp  windy play
  0      high   sunny   hot  False   no
  1      high   sunny   hot   True   no
  7      high   sunny  mild  False   no
  8    normal   sunny  cool  False  yes
  10   normal   sunny  mild   True  yes,
  'overcast':    humility   outlook  temp  windy play
  2      high  overcast   hot  False  yes
  6    normal  overcast  cool   True  yes
  11     high  overcast  mild   True  yes
  12   normal  overcast   hot  False  yes,
  'rainy':    humility outlook  temp  windy play
  3      high   rainy  mild  False  yes
  4    normal   rainy  cool  False  yes
  5    normal   rainy  cool   True   no
  9    normal   rainy  mild  False  yes
  13     high   rainy  mild   True   no})

In [33]:
class ID3Tree:
    class Node:
        def __init__(self,name):
            self.name=name
            self.connections={}
            
        def connect(self,label,node):
            self.connections[label]=node
            
    def __init__(self,data,label):
        self.columns=data.columns
        self.data=data
        self.label=label
        self.root=self.Node("Root")
        
    def construct(self,parent_node,parent_connection_label,input_data,columns):
        max_value,best_col,max_splited=choose_best_col(input_data[columns],self.label)
        if not best_col:
            node=self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label,node)
            return
    
        node=self.Node(best_col)
        parent_node.connect(parent_connection_label,node)
        
        new_columns=[col for col in columns if col!=best_col]
        for splited_value,splited_data in max_splited.items():
            self.construct(node,splited_value,splited_data,new_columns)
    
    def construct_tree(self):
        self.construct(self.root," ",self.data,self.columns)
        
    def print_tree(self,node,tabs):
        print(tabs+node.name)
        for connection,child_node in node.connections.items():
            print(tabs+"\t"+"("+str(connection)+")")
            self.print_tree(child_node,tabs+"\t\t")

In [34]:
tree1=ID3Tree(df,'play')
tree1.construct_tree()
tree1.print_tree(tree1.root,"")

Root
	( )
		outlook
			(sunny)
				humility
					(high)
						temp
							(hot)
								windy
									(False)
										no
									(True)
										no
							(mild)
								windy
									(False)
										no
					(normal)
						temp
							(cool)
								windy
									(False)
										yes
							(mild)
								windy
									(True)
										yes
			(overcast)
				humility
					(high)
						temp
							(hot)
								windy
									(False)
										yes
							(mild)
								windy
									(True)
										yes
					(normal)
						temp
							(cool)
								windy
									(True)
										yes
							(hot)
								windy
									(False)
										yes
			(rainy)
				windy
					(False)
						humility
							(high)
								temp
									(mild)
										yes
							(normal)
								temp
									(cool)
										yes
									(mild)
										yes
					(True)
						humility
							(normal)
								temp
									(cool)
										no
							(high)
								temp
									(mild)
										no
