In [1]:
# encoding:utf-8

# 决策树分类器

from enum import Enum, unique
import random, math
from collections import Counter

In [2]:
@unique
class AttrType(Enum):
    category = 0
    numeric = 1

In [3]:
def data_loader():
    insts = [['sunny','85','85','FALSE'],
             ['sunny','80','90','TRUE'],
             ['overcast','83','86','FALSE'],
             ['rainy','70','96','FALSE'],
             ['rainy','68','80','FALSE'],
             ['rainy','65','70','TRUE'],
             ['overcast','64','65','TRUE'],
             ['sunny','72','95','FALSE'],
             ['sunny','69','70','FALSE'],
             ['rainy','75','80','FALSE'],
             ['sunny','75','70','TRUE'],
             ['overcast','72','90','TRUE'],
             ['overcast','81','75','FALSE'],
             ['rainy','71','91','TRUE']]
    attrs = [('outlook', AttrType.category),
             ('temperature', AttrType.category),
             ('humidity', AttrType.category),
             ('windy', AttrType.category)]
    labels = ['no', 'no', 'yes', 'yes', 'yes', 'no', 'yes',
              'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'no']
    total_instances = zip(insts, labels)
    random.shuffle(total_instances)
    inst_num_part = int(len(total_instances) * 0.1)
    train_insts = total_instances[:7 * inst_num_part]
    test_insts = total_instances[7 * inst_num_part:]

    return attrs, train_insts, test_insts

In [4]:
# 计算熵
def calc_shannon_ent(insts):
    labels = [x[1] for x in insts]
    label_frqs = dict()
    for label in labels:
        label_frqs[label] = label_frqs.get(label, 0) + 1
    total_num = len(labels)
    shannon_ent = 0.0
    for key in label_frqs:
        prob = float(label_frqs[key]) / total_num
        shannon_ent -= prob * math.log(prob, 2)
    return shannon_ent

In [5]:
def extract_values(insts, attr_id):
    return list(set([x[0][attr_id] for x in insts]))

In [6]:
def split_instances(insts, attr_id, value):
    ext_insts = list()
    for inst in insts:
        if inst[0][attr_id] == value:
            ext_insts.append(inst)
    return ext_insts

In [7]:
def train(train_insts, attr_ids, tree=None, attr_basic_ents=None):
    if tree is None:
        tree = dict()
    labels = [x[1] for x in train_insts]
    insts = [x[0] for x in train_insts]
    if len(set(labels)) == 1: # 类别相同，停止划分
        return labels[0]
    if len(attr_ids) == 0:
        label_counter = Counter()
        label_counter.update(labels)
        return label_counter.most_common(1)[0][0]
    # 选择决策的属性
    basic_ent = calc_shannon_ent(train_insts)
    best_info_gain, best_attr = -100000, -1
    for attr in attr_ids:
        attr_values = [x[attr] for x in insts]
        uni_values = set(attr_values)
        new_ent = 0
        for value in uni_values:
            sub_insts = split_instances(train_insts, attr, value)
            prob = len(sub_insts) * 1.0 / len(insts)
            new_ent += prob * calc_shannon_ent(sub_insts)
        info_gain = (basic_ent - new_ent) / attr_basic_ents[attr]
        if info_gain > best_info_gain:
            best_info_gain, best_attr = info_gain, attr
    tree[best_attr] = dict()
    for value in set([x[best_attr] for x in insts]):
        sub_insts = split_instances(train_insts, best_attr, value)
        new_attr_ids = [x for x in attr_ids]
        new_attr_ids.remove(best_attr)
        tree[best_attr][value] = train(sub_insts, new_attr_ids,
                                       attr_basic_ents=attr_basic_ents)
    return tree

In [8]:
def classify(inst, attrs, tree):
    if isinstance(tree, str): # 叶子节点
        return tree
    curr_attr_id = list(tree.keys())[0] # 当前分类的属性ID
    inst_attr_value = inst[curr_attr_id] # 当前样本的属性值
    if inst_attr_value not in tree[curr_attr_id]: # 随机猜测
        inst_attr_value = random.choice(list(tree[curr_attr_id].keys()))
    sub_tree = tree[curr_attr_id][inst_attr_value]
    if isinstance(sub_tree, dict):
        return classify(inst, attrs, sub_tree)
    else:
        return sub_tree

In [9]:
attrs, train_insts, test_insts = data_loader()
# 计算不同属性的固有熵
attr_basic_ents = list()
for id, attr in enumerate(attrs):
    attr_values = [x[0][id] for x in train_insts]
    attr_basic_ents.append(calc_shannon_ent(attr_values))
# classifier = train(train_insts, list(range(len(attrs))), attr_basic_ents=attr_basic_ents)
classifier = train(train_insts, [0, 1, 2, 3], attr_basic_ents=attr_basic_ents)
print classifier
# {2: {'90': {0: {'overcast': 'yes', 'sunny': 'no'}}, '80': 'yes', '70': {3: {'FALSE': 'yes', 'TRUE': 'no'}}, '95': 'no', '85': 'no'}}
accs = list()
for inst, real_label in test_insts:
    pred_label = classify(inst, list(range(len(attrs))), classifier)
    if pred_label == real_label:
        accs.append(1)
    else:
        accs.append(0)
print "accuracy: ", sum(accs) * 100.0 / len(accs)

{3: {'FALSE': {0: {'rainy': 'yes', 'overcast': 'yes', 'sunny': 'no'}}, 'TRUE': 'no'}}
accuracy:  42.8571428571
