In [1]:
from __future__ import print_function
import numbers

In [2]:
# Toy fruit dataset: each row is [colour, diameter, label].
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

# Features are every column except the last; the label is the last column.
# Deriving X and y from training_data (instead of hand-copying them) keeps the
# three structures from drifting out of sync when the dataset is edited.
X = [row[:-1] for row in training_data]
y = [row[-1] for row in training_data]

In [3]:
# Column names for the rows of training_data; Question.__repr__ indexes into
# this list to name the feature being tested.
header = [
    "Colour",
    "Diameter",
    "Label",
]

In [4]:
def unique_vals(X, col):
    """Return the set of distinct values found in column ``col`` of ``X``."""
    return {row[col] for row in X}

In [5]:
def class_counts(y):
    """Count how often each label occurs in ``y``.

    Returns a plain dict mapping label -> count; keys appear in the order in
    which each label is first seen.
    """
    tally = {}
    for label in y:
        tally[label] = tally.get(label, 0) + 1
    return tally

In [6]:
def is_numeric(value):
    """Return True if ``value`` should be treated as a numeric feature.

    Numeric features are split with a ``>=`` threshold; everything else is
    compared for equality (see ``Question``).

    Any real number counts: ``int``/``float`` and their subclasses, plus
    types registered with ``numbers.Real`` such as ``fractions.Fraction`` and
    numpy integer/float scalars (e.g. values pulled out of a pandas frame).
    ``bool`` is deliberately excluded, matching the previous exact-type check,
    so True/False stay categorical.
    """
    return isinstance(value, numbers.Real) and not isinstance(value, bool)

In [7]:
class Question:
    """A yes/no test on one feature column, used to split rows at a tree node.

    Numeric values are tested with ``>=`` (a threshold); any other value is
    tested with ``==`` (a category match).
    """

    def __init__(self, column, value):
        # Index of the feature within an example row, and the value to test against.
        self.column = column
        self.value = value

    def match(self, example):
        """Return the result of testing ``example[self.column]`` against ``self.value``."""
        candidate = example[self.column]
        return candidate >= self.value if is_numeric(candidate) else candidate == self.value

    def __repr__(self):
        # Feature names come from the module-level ``header`` list.
        op = '>=' if is_numeric(self.value) else '=='
        return f"Is {header[self.column]} {op} {str(self.value)}?"

In [8]:
def partition(X, y, question):
    """Split rows (and their labels) by whether ``question.match`` accepts them.

    Returns ``(true_X, false_X, true_y, false_y)``, preserving input order
    within each side. Rows and labels are paired with ``zip``, so surplus
    entries in the longer of ``X``/``y`` are ignored.
    """
    true_X, true_y = [], []
    false_X, false_y = [], []
    for features, label in zip(X, y):
        if question.match(features):
            dest_X, dest_y = true_X, true_y
        else:
            dest_X, dest_y = false_X, false_y
        dest_X.append(features)
        dest_y.append(label)
    return true_X, false_X, true_y, false_y

In [9]:
def gini(X, y):
    """Gini impurity of the labels ``y``: ``1 - sum_k p_k**2``.

    Args:
        X: feature rows matching ``y``. Kept for interface compatibility; the
            impurity depends only on the labels.
        y: sequence of class labels.

    Returns:
        The impurity in ``[0, 1)``; 0 for a pure node. An empty ``y`` has no
        impurity and returns 0.0.
    """
    counts = class_counts(y)
    # Normalise by the number of labels actually counted, so the class
    # probabilities always sum to 1 even if len(X) != len(y).
    n = len(y)
    if n == 0:
        return 0.0
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(n)
        impurity -= prob_of_lbl**2
    return impurity

In [10]:
def info_gain(left_X, left_y, right_X, right_y, current_uncertainty):
    """Impurity reduction obtained by splitting a node into left/right children.

    Each child's Gini impurity is weighted by its share of the rows, and the
    weighted sum is subtracted from the parent's ``current_uncertainty``.
    """
    n_left = len(left_X)
    left_weight = float(n_left) / (n_left + len(right_X))
    left_impurity = gini(left_X, left_y)
    right_impurity = gini(right_X, right_y)
    return current_uncertainty - left_weight * left_impurity - (1 - left_weight) * right_impurity

In [11]:
def find_best_split(X, y):
    """Find the question that most reduces Gini impurity on ``(X, y)``.

    Every distinct value of every feature column is tried as a ``Question``.

    Returns:
        ``(best_gain, best_question)``. ``best_gain`` is 0 and
        ``best_question`` is None when no split improves purity or ``X`` is
        empty.

    Candidate values are tried in order of first appearance in ``X`` rather
    than in ``set`` iteration order. Set order for strings depends on hash
    randomisation, so with a set, ties between equally good questions could be
    broken differently from one kernel session to the next.
    """
    best_gain = 0
    best_question = None
    if not X:
        return best_gain, best_question

    current_uncertainty = gini(X, y)
    n_features = len(X[0])

    for col in range(n_features):
        # dict.fromkeys de-duplicates while keeping first-appearance order.
        values = dict.fromkeys(row[col] for row in X)

        for val in values:
            question = Question(col, val)
            true_X, false_X, true_y, false_y = partition(X, y, question)
            # A question that sends every row the same way cannot help.
            if not true_X or not false_X:
                continue

            gain = info_gain(true_X, true_y, false_X, false_y, current_uncertainty)
            # Strict '>' keeps the earliest question among equal-gain ties.
            if gain > best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question

In [12]:
def percentage(counts):
    """Convert label counts into whole-number percentage strings.

    Each share is truncated (not rounded) to an int, e.g. 2/3 -> '66%', so
    the values need not sum to 100. An empty mapping yields ``{}``.
    """
    total = float(sum(counts.values()))
    return {lbl: f"{int(n / total * 100)}%" for lbl, n in counts.items()}

In [13]:
class Leaf:
    """Terminal tree node holding the label distribution of the rows that reach it."""

    def __init__(self, y):
        label_counts = class_counts(y)
        # Raw label -> count mapping, returned by classify().
        self.predictions = label_counts
        # The same distribution as truncated percentage strings.
        self.percentage = percentage(label_counts)

In [14]:
class Decision_Node:
    """Internal tree node: a question plus the subtrees for each answer."""

    def __init__(self, question, true_branch, false_branch):
        # Rows answering "yes" to ``question`` go to true_branch, the rest to false_branch.
        self.question, self.true_branch, self.false_branch = question, true_branch, false_branch

In [15]:
def build_tree(X, y):
    """Recursively grow a decision tree from feature rows ``X`` and labels ``y``.

    Splits on the best question until no question yields positive gain, at
    which point a ``Leaf`` summarising ``y`` is returned.
    """
    gain, question = find_best_split(X, y)
    if gain != 0:
        true_X, false_X, true_y, false_y = partition(X, y, question)
        # Arguments evaluate left to right: the true subtree is built first.
        return Decision_Node(
            question,
            build_tree(true_X, true_y),
            build_tree(false_X, false_y),
        )
    return Leaf(y)

In [16]:
def print_tree(node, spacing=""):
    """Print the tree rooted at ``node``, indenting two spaces per level."""
    if isinstance(node, Leaf):
        print(spacing + "Predict", node.predictions)
        return

    print(spacing + str(node.question))

    child_spacing = spacing + "  "
    branches = (
        ('--> True:', node.true_branch),
        ('--> False:', node.false_branch),
    )
    for caption, subtree in branches:
        print(spacing + caption)
        print_tree(subtree, child_spacing)

In [17]:
def classify(X, node):
    """Walk the tree for the single example row ``X`` and return its leaf's label counts.

    Despite the name, ``X`` here is one row, not the whole feature matrix.
    """
    current = node
    while not isinstance(current, Leaf):
        current = current.true_branch if current.question.match(X) else current.false_branch
    return current.predictions

In [18]:
# Grow the tree on the full toy dataset (X, y).
my_tree = build_tree(X, y)

In [19]:
# Show the learned questions and the label counts at each leaf.
print_tree(my_tree)

Is Colour == Red?
--> True:
  Predict {'Grape': 2}
--> False:
  Is Colour == Yellow?
  --> True:
    Predict {'Apple': 1, 'Lemon': 1}
  --> False:
    Predict {'Apple': 1}


In [20]:
# Classify one training row. The result is the leaf's raw label counts
# (class_counts of the rows that reached it), not a single predicted label.
value = X[1]
print(value," : ", classify(value, my_tree))

['Yellow', 3]  :  {'Apple': 1, 'Lemon': 1}


In [60]:
def print_accuracy(predictions, total=None):
    """Print the share of correct predictions as a percentage.

    Args:
        predictions: number of correctly classified examples (a count, not a
            list of predictions).
        total: number of examples evaluated. Defaults to ``len(X_test)`` so
            existing calls keep working; pass it explicitly to avoid depending
            on a notebook-global ``X_test``.

    The ``:.4`` format gives four *significant* digits (e.g. ``77.01%``,
    ``100.0%``), not four decimal places.

    Raises:
        ZeroDivisionError: if ``total`` is 0.
    """
    if total is None:
        total = len(X_test)
    if total == 0:
        raise ZeroDivisionError("cannot compute accuracy over zero examples")
    print(f"Accuracy: {predictions / total * 100:.4}%")

In [62]:
# NOTE(review): correct_predictions (and X_test, which print_accuracy reads) are
# not defined in any of the cells visible here; the execution counts jump from
# In [20] to In [60]/In [62]. This likely relies on hidden kernel state. Define
# both in an earlier cell so Restart & Run All reproduces this output.
print_accuracy(correct_predictions)

Accuracy: 77.01%
