In [1]:
# Calculate entropy of given dataset
from math import log


def calc_info_entropy(dataset):
    freq = dict() # Frequdncies of each probal situation
    for vect in dataset:
        label = vect[-1]
        if label not in freq:
            freq[label] = 0
        freq[label] += 1
    total = len(dataset)
    entropy = 0
    for k in freq:
        prob = freq[k] / total
        entropy -= prob * log(prob, 2)
    return entropy    

In [2]:
# Create sample dataset
def create_dataset():
    dataset = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no']
    ]
    labels = ['no surfacing', 'flippers']
    return dataset, labels

dataset, labels = create_dataset()
dataset

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

In [3]:
calc_info_entropy(dataset)

0.9709505944546686

In [4]:
# Split dataset per feature
def split_dataset(dataset, feat_axis, feat_val):
    result = list()
    for sample in dataset:
        if sample[feat_axis] == feat_val:
            rest = sample[ : feat_axis]
            rest.extend(sample[feat_axis + 1 : ])
            result.append(rest)
    return result

split_dataset(dataset, 0, 1)

[[1, 'yes'], [1, 'yes'], [0, 'no']]

In [5]:
# Find the best split way
def find_best_split(dataset):
    n_features = len(dataset[0]) - 1
    org_entropy = calc_info_entropy(dataset)
    best_split_feat = -1
    max_info_gain = 0
    for feat_axis in range(n_features):
        feat_vals = [x[feat_axis] for x in dataset]
        uq_feat_vals = set(feat_vals)
        for val in uq_feat_vals:
            split_try = split_dataset(dataset, feat_axis, val)
            new_entropy = calc_info_entropy(split_try)
            info_gain = org_entropy - new_entropy
            if info_gain > max_info_gain:
                max_info_gain = info_gain
                best_split_feat = feat_axis
    return best_split_feat

find_best_split(dataset)

0

In [6]:
# If all of the features have been introduced to split,
# but there are different labels still,
# we will use voting to decide the classify
import operator


def final_vote(labels):
    votes = dict()
    for label in labels:
        if label not in votes:
            votes[label] = 0
        votes[label] += 1
    sorted_votes = sorted(
        votes.items(),
        key=operator.itemgetter(1),
        reverse=True
    )
    return sorted_votes[0][0]

In [7]:
a = {'a': 1, 'b': 2}
b = sorted(
    a.items(),
    key=operator.itemgetter(1)
)
a.items()

dict_items([('a', 1), ('b', 2)])

In [8]:
# Create decision tree
def create_tree(dataset, labels):
    # copy labels to prevent changes on original
    feat_labels = labels[:]
    # Recurse base
    class_labels = [x[-1] for x in dataset]
    if len(set(class_labels)) == 1:
        return class_labels[0]
    if len(labels) == 0:
        return final_vote(class_labels)
    # Recurse body
    best_feat = find_best_split(dataset)
    best_feat_label = feat_labels[best_feat]
    del(feat_labels[best_feat])
    sub_labels = feat_labels[:]
    tree = {best_feat_label: dict()}
    for val in set([x[best_feat] for x in dataset]):
        tree[best_feat_label][val] = \
            create_tree(
                split_dataset(dataset, best_feat, val),
                sub_labels
            )
    return tree

In [9]:
# Test create_tree() function
create_tree(dataset, labels)

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

### Draw the Decision Tree

In [10]:
# Try to figure part of tree
import matplotlib.pyplot as plt


decision_node = dict(boxstyle="sawtooth", fc="white")
leaf_node = dict(boxstyle="round4", fc="white")
arrow_args = dict(arrowstyle="<-")

def plot_node(node_text, cent_pt, parent_pt, node_type):
    create_plot.ax1.annotate(
        node_text,
        xy=parent_pt,
        xycoords='axes fraction',
        xytext=cent_pt,
        textcoords="axes fraction",
        va="center",
        ha="center",
        bbox=node_type,
        arrowprops=arrow_args
    )
    
def create_plot():
    fig = plt.figure(1, facecolor="white")
    fig.clf()
    create_plot.ax1 = plt.subplot(111, frameon=False)
    plot_node("Decision Node", (0.5, 0.1), (0.1, 0.5), decision_node)
    plot_node("Leaf Node", (0.8, 0.1), (0.3, 0.8), leaf_node)
    plt.show()
        
    
create_plot()


<Figure size 640x480 with 1 Axes>

In [11]:
# Get num of leaf nodes
def get_leaf_num(tree):
    leaf_num = 0
    # recurse base
    if type(tree) is not dict:
        leaf_num = 1
    # recurse body
    else:
        root = list(tree.keys())[0]
        for k in tree[root]:
            leaf_num += get_leaf_num(tree[root][k])
    return leaf_num

get_leaf_num(create_tree(dataset, labels))

3

In [12]:
# Get depth of tree
def get_tree_depth(tree):
    tree_depth = 0
    # Recurse base
    if type(tree) is not dict:
        tree_depth = 0
    # Recurse body
    else:
        root = list(tree.keys())[0]
        max_depth = 0
        for k in tree[root]:
            depth = get_tree_depth(tree[root][k])
            if depth > max_depth:
                max_depth = depth
        tree_depth = max_depth + 1
    return tree_depth


get_tree_depth(create_tree(dataset, labels))

2

In [14]:
# Use decision tree to classify
def tree_classify(tree, x_in, x_labels):
    # Recurse base
    if type(tree) is not dict:
        return tree
    # Recurse body
    root = list(tree.keys())[0]
    feat_index = x_labels.index(root)
    for feat_val in tree[root]:
        if x_in[feat_index] == feat_val:
            return tree_classify(tree[root][feat_val], x_in, x_labels)

        
tree_classify(
    create_tree(dataset, labels),
    [1, 1],
    ['no surfacing', 'flippers']
)

'yes'

In [16]:
# Store decision tree
import pickle


def store_tree(tree, store_file):
    with open(store_file, 'wb') as f:
        pickle.dump(tree, f)
        
        
store_tree(
    create_tree(dataset, labels),
    'fish_decision_tree.pkl'
)

In [17]:
# Load decisiin tree from pickle fie
def load_tree(store_file):
    with open(store_file, 'rb') as f:
        return pickle.load(f)
    
tree = load_tree('fish_decision_tree.pkl')
print(tree)

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}


### Lenses

In [18]:
# Format lenses dataset
def get_lenses_dataset():
    dataset = list()
    labels = ['age', 'prescript', 'astigmatic', 'tearrate']
    with open('lenses.txt', 'r') as f:
        for line in f:
            dataset.append([x.strip() for x in line.split('\t')])
    return dataset, labels

get_lenses_dataset()
    

([['young', 'myope', 'no', 'reduced', 'no lenses'],
  ['young', 'myope', 'no', 'normal', 'soft'],
  ['young', 'myope', 'yes', 'reduced', 'no lenses'],
  ['young', 'myope', 'yes', 'normal', 'hard'],
  ['young', 'hyper', 'no', 'reduced', 'no lenses'],
  ['young', 'hyper', 'no', 'normal', 'soft'],
  ['young', 'hyper', 'yes', 'reduced', 'no lenses'],
  ['young', 'hyper', 'yes', 'normal', 'hard'],
  ['pre', 'myope', 'no', 'reduced', 'no lenses'],
  ['pre', 'myope', 'no', 'normal', 'soft'],
  ['pre', 'myope', 'yes', 'reduced', 'no lenses'],
  ['pre', 'myope', 'yes', 'normal', 'hard'],
  ['pre', 'hyper', 'no', 'reduced', 'no lenses'],
  ['pre', 'hyper', 'no', 'normal', 'soft'],
  ['pre', 'hyper', 'yes', 'reduced', 'no lenses'],
  ['pre', 'hyper', 'yes', 'normal', 'no lenses'],
  ['presbyopic', 'myope', 'no', 'reduced', 'no lenses'],
  ['presbyopic', 'myope', 'no', 'normal', 'no lenses'],
  ['presbyopic', 'myope', 'yes', 'reduced', 'no lenses'],
  ['presbyopic', 'myope', 'yes', 'normal', 'hard

In [20]:
dataset, labels = get_lenses_dataset()
create_tree(dataset, labels)

{'tearrate': {'normal': {'astigmatic': {'no': {'age': {'presbyopic': {'prescript': {'hyper': 'soft',
        'myope': 'no lenses'}},
      'pre': 'soft',
      'young': 'soft'}},
    'yes': {'age': {'presbyopic': {'prescript': {'hyper': 'no lenses',
        'myope': 'hard'}},
      'pre': {'prescript': {'hyper': 'no lenses', 'myope': 'hard'}},
      'young': 'hard'}}}},
  'reduced': 'no lenses'}}