In [220]:
import math
import sys
import pandas as pd
import numpy as np

In [240]:
def entropy(elements, base):
    length = float(len(elements))
    probs = [elements.count(element)/length for element in set(elements)]
    return -sum([p * math.log(p, base) for p in probs])

In [400]:
def split_dataframe(data, header):
    unique_values = data[header].unique()
    result_dict = {elem : pd.DataFrame for elem in unique_values}
    for key in result_dict.keys():
        result_dict[key] = data[:][data[header] == key]
    return result_dict

In [408]:
raw_data = pd.DataFrame(pd.read_csv("data.tsv", sep="\t"))
headers = list(raw_data)[1:]
target_header = "PlayTennis"
base = 2

In [409]:
min_value, min_header = sys.maxint, None
for header in headers[1:-1]:
    h = entropy(data[header].tolist(), base)
    if h < min_value:
        min_value, min_header = h, header
print min_value, min_header

0.985228136034 Wind


In [410]:
def tree_split(data):
    length = float(len(data))
    max_value, max_header = -sys.maxint, None
    max_splited = None
    H = entropy(data[target_header].tolist(), base)
    for header in list(data)[:-1]:
        splited_set = split_dataframe(data, header)
        print splited_set
        IS = 0
        for subset_header, subset in splited_set.items():
            subset_length = float(len(subset))
            subset_h = entropy(subset[target_header].tolist(), base)
            print header, subset_header, subset_h
            IS += subset_length/length * subset_h
        IG = H - IS
        print header, H, IS
        if IG > max_value:
            max_value, max_header = IG, header
            max_splited = splited

    return max_value, max_header, max_splited


In [411]:
max_value, max_header, max_splited = tree_split(data[headers])
new_headers = [header for header in headers if header != max_header]

for split_value, split_data in max_splited.items():
    max_value, max_header, max_splited = tree_split(split_data[new_headers])

{'Overcast':      Outlook Temperature Humidity    Wind PlayTennis
2   Overcast         Hot     High    Weak        Yes
6   Overcast        Cool   Normal  Strong        Yes
11  Overcast        Mild     High  Strong        Yes
12  Overcast         Hot   Normal    Weak        Yes, 'Sunny':    Outlook Temperature Humidity    Wind PlayTennis
0    Sunny         Hot     High    Weak         No
1    Sunny         Hot     High  Strong         No
7    Sunny        Mild     High    Weak         No
8    Sunny        Cool   Normal    Weak        Yes
10   Sunny        Mild   Normal  Strong        Yes, 'Rain':    Outlook Temperature Humidity    Wind PlayTennis
3     Rain        Mild     High    Weak        Yes
4     Rain        Cool   Normal    Weak        Yes
5     Rain        Cool   Normal  Strong         No
9     Rain        Mild   Normal    Weak        Yes
13    Rain        Mild     High  Strong         No}
Outlook Overcast -0.0
Outlook Sunny 0.970950594455
Outlook Rain 0.970950594455
Outlook 0.9

In [412]:
class ID3Tree(object):
    class Node(object):
        def __init__(self, name):
            self.name = name
            self.connections = {}
        
        def connect(self, label, node):
            self.connections[label] = node
        
    
    def __init__(self, data, target_header, base=2):
        self.headers = list(data)[1:]
        self.data = data
        self.target_header = target_header
        self.base = base
        self.root = self.Node("Root")
        
    def build(self):
        self.step(self.root, "", self.data, self.headers)
        
        
    def step(self, parent_node, parent_connection_label, input_data, headers):
        max_value, max_header, max_splited = tree_split(input_data[headers])
        
        if not max_header:
            return

        node = self.Node(max_header)
        parent_node.connect(parent_connection_label, node)
        
        new_headers = [header for header in headers if header != max_header]
         
        for splited_value, splited_data in max_splited.items():
            self.step(node, splited_value, splited_data, new_headers)
        
        
        

In [413]:
tree = ID3Tree(raw_data, "PlayTennis")
tree.build()

{'Overcast':      Outlook Temperature Humidity    Wind PlayTennis
2   Overcast         Hot     High    Weak        Yes
6   Overcast        Cool   Normal  Strong        Yes
11  Overcast        Mild     High  Strong        Yes
12  Overcast         Hot   Normal    Weak        Yes, 'Sunny':    Outlook Temperature Humidity    Wind PlayTennis
0    Sunny         Hot     High    Weak         No
1    Sunny         Hot     High  Strong         No
7    Sunny        Mild     High    Weak         No
8    Sunny        Cool   Normal    Weak        Yes
10   Sunny        Mild   Normal  Strong        Yes, 'Rain':    Outlook Temperature Humidity    Wind PlayTennis
3     Rain        Mild     High    Weak        Yes
4     Rain        Cool   Normal    Weak        Yes
5     Rain        Cool   Normal  Strong         No
9     Rain        Mild   Normal    Weak        Yes
13    Rain        Mild     High  Strong         No}
Outlook Overcast -0.0
Outlook Sunny 0.970950594455
Outlook Rain 0.970950594455
Outlook 0.9

In [414]:
def print_tree(node, tabs):
    print tabs + node.name
    for connection, child_node in node.connections.items():
        print_tree(child_node, tabs+"\t")

In [415]:
print_tree(tree.root, "")

Root
	Outlook
		Temperature
			Humidity
				Wind
				Wind
			Humidity
				Wind
				Wind
		Humidity
			Temperature
				Wind
				Wind
			Temperature
				Wind
				Wind
