In [66]:
from sklearn import datasets
import numpy as np

In [67]:
#  dataset을 분할. 각 행에 걸쳐 반복해서 속성 값이 분할 값보다 작거나 큰지 확인하고 각각 왼쪽 또는 오른쪽 그룹에 할당
def test_split(index, value, dataset): # 데이터 분할
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

In [68]:
def gini_index(groups, classes): # 지니지수 계산
    # split point 에서 모든 sample 카운트
    n_instances = float(sum([len(group) for group in groups]))
    # 각각의 그룹에 대해 gini index 가중치 sum
    gini = 0.0
    for group in groups:
        size = float(len(group))
        # 0 나누기 피하기
        if size == 0:
            continue
        score = 0.0
        # 각각의 class score를 기반으로 그룹을 점수 매김
        for class_val in classes:
            p = [row[-1] for row in group].count(class_val)/size
            score += p * p
        # 가중치 매김
        gini += (1.0 - score) * ( size/ n_instances)
    return gini

In [69]:
# 데이터셋의 최적 split point 선택
def get_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    for index in range(len(dataset[0])-1):
        for row in dataset:
            groups = test_split(index, row[index],dataset)
            gini = gini_index(groups,class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index':b_index, 'value':b_value, 'groups':b_groups}

In [70]:
# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
    left, right = node['groups']
    del(node['groups'])
    # check for a no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left),to_terminal(right)
        return
    # process left child
    if len(left) <=min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth +1)
    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth +1)

In [71]:
# Create a terminal node value
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

In [72]:
# Build a decision tree
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return  root

In [73]:
def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
        print_tree(node['left'], depth +1)
        print_tree(node['right'], depth + 1)
    else:
        print('%s[%s]' % ((depth*' ', node)))

In [74]:
iris = datasets.load_iris()
dataset = np.c_[iris.data, iris.target]

tree = build_tree(dataset, 3, 1)
print_tree(tree)

[X3 < 3.000]
 [X1 < 5.100]
  [X1 < 4.900]
   [0.0]
   [0.0]
  [X1 < 5.100]
   [0.0]
   [0.0]
 [X4 < 1.800]
  [X3 < 5.000]
   [1.0]
   [2.0]
  [X3 < 4.900]
   [2.0]
   [2.0]
