### Learn Decision Tree from Scratch

In [2]:
# Import Dependencies
import numbers

import pandas as pd

Read Dataset

In [3]:
data = pd.read_csv('../data/dummy.csv')


In [4]:
# Column names, in the order they appear in the dataset
header = data.columns.tolist()

# Map every column name to its positional index
column_indices = {name: position for position, name in enumerate(header)}

In [5]:
header

['color', 'diameter', 'fruit_name']

In [6]:
column_indices

{'color': 0, 'diameter': 1, 'fruit_name': 2}

In [8]:
def unique_vals(rows, col):
    """Return the set of distinct values found in one column.

    `rows` is a DataFrame and `col` is the positional index of the column.
    """
    column = rows.iloc[:, col]
    return {value for value in column}

def class_counts(rows):
    """Tally how many rows carry each label.

    The label is read from the last column of `rows`. The result is a dict
    mapping label -> count, keyed in order of first appearance.
    """
    tally = {}
    for label in rows.iloc[:, -1]:
        tally[label] = tally.get(label, 0) + 1
    return tally

In [11]:
unique_vals(data, 0)

{'green', 'orange', 'red', 'yellow'}

In [12]:
class_counts(data)

{'grape': 4, 'lemon': 7, 'orange': 2, 'apple': 4}

In [None]:
# Unique values of the diameter column (positional index 1)
unique_vals(data, 1)

In [13]:
def is_numeric(value):
    """Return True when `value` is a real number.

    Besides the built-in `int` and `float`, this accepts any type registered
    as `numbers.Real`, which includes NumPy scalars such as `np.int64` and
    `np.float32` (the types pandas hands back from numeric Series).
    Strings, None and other non-numeric objects return False.
    """
    return isinstance(value, numbers.Real)

Question
- Question that we will use to partition the dataset

In [15]:
class Question:
    """A test on a single feature, used to partition a dataset.

    A question stores the positional index of a column and a reference value.
    Numeric features are tested with `>=`; every other feature is tested
    with `==`.
    """

    def __init__(self, column, value):
        self.column = column  # positional index of the feature being tested
        self.value = value    # reference value the feature is compared against

    def match(self, example):
        """Return whether `example` satisfies this question.

        `example` may be a pandas Series (one DataFrame row) or a plain
        sequence such as a list; the feature is always looked up by position.
        """
        # Integer keys passed to [] on a label-indexed Series rely on a
        # deprecated positional fallback, so use .iloc when it is available.
        if hasattr(example, "iloc"):
            val = example.iloc[self.column]
        else:
            val = example[self.column]

        # Only use an ordering comparison when both sides are numbers; this
        # keeps match() in step with the operator reported by __repr__.
        if is_numeric(val) and is_numeric(self.value):
            return val >= self.value
        return val == self.value

    def __repr__(self):
        # Human-readable form; the column name is looked up in the notebook's
        # module-level `header` list.
        condition = ">=" if is_numeric(self.value) else "=="
        return "Is %s %s %s?" % (header[self.column], condition, str(self.value))

In [38]:
# Question for color
Question(0, 'red')  # color is at column index 0

Is color == red?

In [39]:
# Question for diameter
Question(1, 2)  # diameter is at column index 1

Is diameter >= 2?

In [47]:
def partition(rows, question):
    """Split `rows` into the rows that match `question` and those that do not.

    Args:
        rows: DataFrame whose rows are passed one at a time to
            `question.match`.
        question: object exposing `match(row)` that returns a boolean.

    Returns:
        A `(true_rows, false_rows)` tuple of DataFrames. Each keeps the
        index labels of the rows it was taken from.
    """
    # Nothing to test: both halves are empty.
    if len(rows) == 0:
        return rows.copy(), rows.copy()

    # Evaluate the question once per row and reuse the mask for both halves.
    matches = rows.apply(question.match, axis=1)
    return rows[matches], rows[~matches]

In [48]:
true_rows_demo, false_rows_demo = partition(data, Question(0, 'red'))

In [49]:
# All rows containing 'red'
true_rows_demo

Unnamed: 0,color,diameter,fruit_name
0,red,1,grape
4,red,1,grape
5,red,3,apple
10,red,1,grape
11,red,1,grape
12,red,3,apple


In [50]:
false_rows_demo

Unnamed: 0,color,diameter,fruit_name
1,orange,1,lemon
2,orange,3,orange
3,green,3,apple
6,yellow,1,lemon
7,green,1,lemon
8,green,1,lemon
9,green,3,apple
13,yellow,1,lemon
14,green,1,lemon
15,orange,3,orange


In [52]:
def gini_calc(rows):
    """Return the Gini impurity of `rows`.

    Starts from 1 and subtracts the squared share of every label reported by
    `class_counts`.
    """
    label_counts = class_counts(rows)
    total = float(len(rows))
    impurity = 1
    for label_count in label_counts.values():
        share = label_count / total
        impurity -= share ** 2
    return impurity

In [59]:
# Calculate Gini impurity of the dataset
gini_impurity = gini_calc(data)
gini_impurity

0.7058823529411766

In [65]:
# Sample case of 0 impurity
dummy_apple = pd.DataFrame({'fruit_name': ['apple', 'apple']})
# dummy_apple
gini_calc(dummy_apple)


0.0

In [66]:
# Sample case of 0.5 impurity
dummy_sample = pd.DataFrame({'fruit_name': ['apple','orange']})
gini_calc(dummy_sample)


0.5

In [68]:
def info_gain(left, right, current_uncertainty):
    """Return the information gained by splitting a node into two children.

    The gain is the parent's impurity (`current_uncertainty`) minus the
    Gini impurity of `left` and `right`, each weighted by its share of
    the rows.
    """
    n_left = len(left)
    n_total = n_left + len(right)
    p = float(n_left) / n_total

    weighted_left = p * gini_calc(left)
    weighted_right = (1 - p) * gini_calc(right)
    return current_uncertainty - weighted_left - weighted_right

In [72]:
# Partition of 'Green'
true_rows, false_rows = partition(data, Question(0, 'green'))
info_gain(true_rows, false_rows, gini_impurity)

0.07843137254901983

In [73]:
true_rows

Unnamed: 0,color,diameter,fruit_name
3,green,3,apple
7,green,1,lemon
8,green,1,lemon
9,green,3,apple
14,green,1,lemon
16,green,1,lemon


In [75]:
# Partition of 'Red'
true_rows, false_rows = partition(data, Question(0, 'red'))
info_gain(true_rows, false_rows, gini_impurity)

0.2067736185383246

In [76]:
true_rows

Unnamed: 0,color,diameter,fruit_name
0,red,1,grape
4,red,1,grape
5,red,3,apple
10,red,1,grape
11,red,1,grape
12,red,3,apple


In [77]:
def find_best_split(rows):
    """Search every feature/value pair for the question with the highest gain.

    Every column except the last is treated as a feature. Returns a
    `(best_gain, best_question)` tuple; the question is None when no
    candidate produced two non-empty partitions.
    """
    baseline = gini_calc(rows)
    top_gain, top_question = 0, None
    feature_count = len(rows.columns) - 1

    for feature_idx in range(feature_count):
        for candidate in unique_vals(rows, feature_idx):
            probe = Question(feature_idx, candidate)
            matched, unmatched = partition(rows, probe)

            # Only splits that actually divide the rows are scored.
            if len(matched) != 0 and len(unmatched) != 0:
                score = info_gain(matched, unmatched, baseline)
                # Ties go to the most recently examined question.
                if score >= top_gain:
                    top_gain, top_question = score, probe

    return top_gain, top_question

In [78]:
best_gain, best_question = find_best_split(data)


In [79]:
best_gain

0.2495543672014262

In [80]:
best_question

Is diameter >= 3?

In [82]:
class Leaf:
    """Terminal node of the tree.

    Stores, in `predictions`, the label counts of the rows that reached it.
    """

    def __init__(self, rows):
        label_tally = class_counts(rows)
        self.predictions = label_tally

class Decision_Node:
    """Internal node of the tree.

    Holds the question asked at this node and the two subtrees followed when
    the answer is true or false.
    """

    def __init__(self, question, true_branch, false_branch):
        self.question, self.true_branch, self.false_branch = (
            question,
            true_branch,
            false_branch,
        )

In [84]:
def build_tree(rows):
    """Recursively grow a decision tree from `rows`.

    Returns a Leaf when the best available split yields no gain, otherwise a
    Decision_Node whose branches are built from the two partitions.
    """
    gain, question = find_best_split(rows)

    if gain != 0:
        matched, unmatched = partition(rows, question)
        # The true branch is grown first, then the false branch.
        return Decision_Node(question, build_tree(matched), build_tree(unmatched))

    # No question improves on the current node: stop splitting here.
    return Leaf(rows)

In [103]:
def print_tree(node, spacing="=="):
    """Print the tree rooted at `node` so its decisions can be read.

    Each level of depth extends the `spacing` prefix by "==". Leaves print
    their prediction counts.
    """
    if not isinstance(node, Leaf):
        print(spacing + str(node.question))
        child_spacing = spacing + "=="

        print(spacing + "==> True:")
        print_tree(node.true_branch, child_spacing)

        print(spacing + "==> False:")
        print_tree(node.false_branch, child_spacing)
        return

    print("Prediction:", node.predictions)

In [104]:
my_tree = build_tree(data)


In [105]:
print("DECISION TREE RULES:")
print_tree(my_tree)

DECISION TREE RULES:
==Is diameter >= 3?
====> True:
====Is color == orange?
Prediction: {'orange': 2}
Prediction: {'apple': 4}
====> False:
====Is color == red?
Prediction: {'grape': 4}
Prediction: {'lemon': 7}


In [108]:
def classify(row, node):
    """Walk the tree from `node` and return the predictions for `row`.

    At each decision node the row is tested against the node's question to
    pick the next branch; the label counts of the leaf reached are returned.
    """
    current = node
    while not isinstance(current, Leaf):
        if current.question.match(row):
            current = current.true_branch
        else:
            current = current.false_branch
    return current.predictions

In [129]:
classify(['yellow', 3], my_tree)

{'apple': 4}

In [130]:
def print_leaf(counts):
    """Convert a leaf's label counts into percentage strings.

    Returns a dict mapping each label to its share of the total count,
    rounded to two decimals and suffixed with "%". Nothing is printed.
    """
    total = sum(counts.values())
    return {
        lbl: str(round(count / total * 100, 2)) + "%"
        for lbl, count in counts.items()
    }

In [132]:
print_leaf(classify(['yellow', 3], my_tree))

{'apple': '100.0%'}

# FIN