In [1]:
header = ['色', '直径', 'ラベル']

training_data = [
    ['緑', 3, 'リンゴ'],
    ['黄色', 3, 'リンゴ'],
    ['紫', 1, 'ぶどう'],
    ['紫', 1, 'ぶどう'],
    ['黄色', 3, 'レモン'],
]

In [2]:
def is_numeric(value):
    return isinstance(value, int) or isinstance(value, float)

In [3]:
class Question:
    
    def __init__(self, column, value):
        self.column = column
        self.value = value
        
    def match(self, exampled):
        val = exampled[self.column]
        if is_numeric(val):
            return val >= self.value
        else:
            return val == self.value
        
    def __repr__(self):
        condition = ">=" if is_numeric(self.value) else "=="
        return "{0}は{1}{2}ですか？".format(header[self.column], condition, self.value)

In [4]:
question = Question(0, '緑')
question

色は==緑ですか？

In [5]:
question.match(training_data[0])

True

In [6]:
def partition(rows, question):
    true_rows = []
    false_rows = []
    
    for row in rows:
        if question.match(row):
            true_rows.append(row)
        else:
            false_rows.append(row)
            
    return true_rows, false_rows

In [7]:
good, bad = partition(training_data, question)
good

[['緑', 3, 'リンゴ']]

In [8]:
bad

[['黄色', 3, 'リンゴ'], ['紫', 1, 'ぶどう'], ['紫', 1, 'ぶどう'], ['黄色', 3, 'レモン']]

In [9]:
def count_labels(rows):
    counts = {}
    for row in rows:
        label = row[-1]
        if label not in counts:
            counts[label] = 0
            
        counts[label] += 1
    
    return counts

In [10]:
count_results = count_labels(training_data)
count_results

{'ぶどう': 2, 'リンゴ': 2, 'レモン': 1}

In [11]:
count_results['ぶどう']

2

In [12]:
def gini_impurity(rows):
    # https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity
    counts = count_labels(rows)
    impurity = 1
    for label in counts:
        prob_of_label = counts[label] / float(len(rows))
        impurity -= prob_of_label**2
    return impurity

In [13]:
no_mixing = [['猫'], ['猫'], ['猫']]
gini_impurity(no_mixing)

0.0

In [14]:
mixing = [['犬'], ['猫'], ['魚']]
gini_impurity(mixing)

0.6666666666666665

In [15]:
def info_gain(left, right, current_uncertainty):
    p_left = float(len(left)) / (len(left) + len(right))
    p_right = 1 - p_left
    
    return current_uncertainty - p_left * gini_impurity(left) - p_right * gini_impurity(right)

In [16]:
uncertainty = gini_impurity(training_data)
uncertainty

0.6399999999999999

In [17]:
true_rows, false_rows = partition(training_data, Question(0, '赤'))
info_gain(true_rows, false_rows, uncertainty)

0.0

In [18]:
true_rows, false_rows = partition(training_data, Question(0, '緑'))
info_gain(true_rows, false_rows, uncertainty)

0.1399999999999999

In [19]:
true_rows, false_rows = partition(training_data, Question(0, '黄色'))
info_gain(true_rows, false_rows, uncertainty)

0.17333333333333323

In [20]:
def find_best_split(rows):
    best_info_gain = 0
    best_question = None
    
    current_uncertainty = gini_impurity(rows)
    n_features = len(rows[0]) - 1 # columns minus 1 for the label
    
    for col in range(n_features):
        values = set([row[col] for row in rows]) # unique values in the column
        
        for val in values:
            question = Question(col, val)
            
            # Split
            true_rows, false_rows = partition(rows, question)

            # Skip if it doesn't result in dividing the dataset
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue
                
            gain = info_gain(true_rows, false_rows, current_uncertainty)
            
            # > is okay too
            if gain >= best_info_gain:
                best_question = question
                best_info_gain = gain
                
        return best_info_gain, best_question

In [21]:
best_gain, best_question = find_best_split(training_data)
best_gain, best_question

(0.37333333333333324, 色は==紫ですか？)

In [22]:
class Leaf:
    def __init__(self, rows):
        self.predictions = count_labels(rows)

In [23]:
class DecisionNode:
    def __init__(self, question, true_branch, false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch

In [24]:
def build_tree(rows):
    gain, question = find_best_split(rows)
    
    if gain == 0:
        return Leaf(rows)
    
    true_rows, false_rows = partition(rows, question)
    
    true_branch = build_tree(true_rows)
    
    false_branch = build_tree(false_rows)
    
    return DecisionNode(question, true_branch, false_branch)

In [25]:
def print_tree(node, spacing=""):
    
    # Base case: reached a Leaf
    if isinstance(node, Leaf):
        print(spacing + "Predict", node.predictions)
        return
    
    # Print question at this node
    print (spacing + str(node.question))
    
    # Call this function recursively on the true branch
    print(spacing + "--> True:")
    print_tree(node.true_branch, spacing + "     ")
    
    # Call this function recursively on the false branch
    print(spacing + "--> False:")
    print_tree(node.false_branch, spacing + "     ")

In [26]:
def print_leaf(counts):
    total = sum(counts.values()) * 1.0
    probs = {}
    for label in counts.keys():
        probs[label] = str(int(counts[label] / total * 100)) + '%'
    return probs

def classify(row, node):
    
    # Reached a leaf
    if isinstance(node, Leaf):
        return node.predictions
    
    # Decide to follow true or false branch
    if node.question.match(row):
        return classify(row, node.true_branch)
    else:
        return classify(row, node.false_branch)

In [27]:
my_tree = build_tree(training_data)

In [28]:
print_tree(my_tree)

色は==紫ですか？
--> True:
     Predict {'ぶどう': 2}
--> False:
     色は==黄色ですか？
     --> True:
          Predict {'リンゴ': 1, 'レモン': 1}
     --> False:
          Predict {'リンゴ': 1}


In [31]:
testing_data = [
    ['緑', 3, 'リンゴ'],
    ['黄色', 4, 'リンゴ'],
    ['紫', 2, 'ぶどう'],
    ['紫', 1, 'ぶどう'],
    ['黄色', 3, 'レモン'],
]

In [32]:
for row in testing_data:
    result = classify(row, my_tree)
    print("Actual: {}, Predicted: {}".format(row[-1], print_leaf(result)))

Actual: リンゴ, Predicted: {'リンゴ': '100%'}
Actual: リンゴ, Predicted: {'リンゴ': '50%', 'レモン': '50%'}
Actual: ぶどう, Predicted: {'ぶどう': '100%'}
Actual: ぶどう, Predicted: {'ぶどう': '100%'}
Actual: レモン, Predicted: {'リンゴ': '50%', 'レモン': '50%'}
