# Decision Tree in Python from scratch

In [None]:
# Training data. Each row is [color, numeric feature, label]; the helpers in
# this notebook read the class label from the LAST element of a row.
training_data = [
    ['Red',    10,  'Apple'],
    ['Purple', 1,   'Grape'],
    ['Yellow', 8,   'Lemon'],
    ['Red',    9,   'Apple'],
    ['Purple', 1.5, 'Grape'],
]

def gini_impurity(rows):
    """Return the Gini impurity of the labels (last element) in *rows*.

    Computed as ``1 - sum(p ** 2)`` over the proportion ``p`` of each distinct
    label. A set with a single label scores 0; an empty *rows* scores 1
    because nothing is subtracted.
    """
    tally = {}
    for row in rows:
        tally[row[-1]] = tally.get(row[-1], 0) + 1

    n_rows = len(rows)
    result = 1
    # Subtract one squared proportion at a time (label first-seen order).
    for count in tally.values():
        result -= (count / n_rows) ** 2
    return result

# Test the Gini Impurity function. The five rows hold 2 Apple, 2 Grape and
# 1 Lemon, so the expected value is 1 - (0.4**2 + 0.4**2 + 0.2**2) = 0.64
# (up to floating-point rounding).
gini_impurity(training_data)

In [None]:
def split(rows, feature, value):
    """Partition *rows* by the question "does row[feature] match *value*?".

    The test is the same one ``predict`` applies when walking the tree:

    * if *value* is a string, a row matches when ``row[feature] == value``;
    * otherwise a row matches when ``row[feature] >= value``.

    Every row lands in exactly one of the two lists, and the input order is
    preserved within each.

    Returns:
        (true_rows, false_rows): the matching rows and the non-matching rows.
    """
    is_categorical = isinstance(value, str)
    true_rows, false_rows = [], []
    for row in rows:
        cell = row[feature]
        # Choose the comparison with an explicit conditional: a one-line
        # ``is_str and a == v or a >= v`` lets a non-equal string fall through
        # to a lexicographic ``>=``, putting some rows in BOTH branches.
        matches = (cell == value) if is_categorical else (cell >= value)
        (true_rows if matches else false_rows).append(row)
    return true_rows, false_rows

def find_best_split(rows):
    """Find the (feature, value) question that best partitions *rows*.

    Every feature column (all columns except the last, which holds the label)
    is tried against each distinct value occurring in it. A candidate split is
    scored by the size-weighted Gini impurity of its two halves (lower is
    better); candidates that leave either half empty are skipped.

    Returns:
        (feature_index, value) of the best split, or (None, None) when *rows*
        is empty, already contains a single label, or admits no split.
    """
    if not rows:
        return None, None

    # A pure node cannot be improved; splitting it would only create more
    # leaves that all predict the same label.
    if gini_impurity(rows) == 0:
        return None, None

    best_gini = float('inf')
    best_feature = None
    best_value = None
    n_rows = len(rows)
    n_features = len(rows[0]) - 1

    for feature in range(n_features):
        # Sorted so ties between equally good splits resolve the same way on
        # every run; a set of strings iterates in hash order, which varies
        # between interpreter processes.
        values = sorted({row[feature] for row in rows})
        for value in values:
            true_rows, false_rows = split(rows, feature, value)
            if not true_rows or not false_rows:
                continue
            weight_true = len(true_rows) / n_rows
            weight_false = len(false_rows) / n_rows
            gini = weight_true * gini_impurity(true_rows) + weight_false * gini_impurity(false_rows)
            if gini < best_gini:
                best_gini = gini
                best_feature = feature
                best_value = value

    return best_feature, best_value

# Test the find_best_split function. The result is a (feature_index, value)
# pair; (None, None) means no split with two non-empty sides was found.
find_best_split(training_data)

In [None]:
class DecisionNode:
    """One node of a decision tree.

    A node whose ``prediction`` is not None is a leaf. Any other node asks
    "row[feature] == value?" (when ``value`` is a string) or
    "row[feature] >= value?" (otherwise) and sends the row to ``true_branch``
    or ``false_branch`` accordingly.
    """

    def __init__(self, feature=None, value=None, true_branch=None, false_branch=None, prediction=None):
        self.feature = feature            # index of the column the question tests
        self.value = value                # value that column is compared against
        self.true_branch = true_branch    # subtree for rows answering "yes"
        self.false_branch = false_branch  # subtree for rows answering "no"
        self.prediction = prediction      # label returned at a leaf; None for decision nodes

    def __repr__(self):
        # Without a repr, displaying a tree in the notebook shows only an
        # object address; this prints the whole (nested) structure instead.
        name = type(self).__name__
        if self.prediction is not None:
            return f"{name}(prediction={self.prediction!r})"
        return (
            f"{name}(feature={self.feature!r}, value={self.value!r}, "
            f"true_branch={self.true_branch!r}, false_branch={self.false_branch!r})"
        )

def build_tree(rows):
    # Find the best feature and value to split on
    feature, value = find_best_split(rows)

    # If we couldn't find a split, this is a leaf node
    if feature is None:
        label_counts = {}
        for row in rows:
            label = row[-1]
            if label not in label_counts:
                label_counts[label] = 0
            label_counts[label] += 1
        prediction = max(label_counts, key=label_counts.get)
        return DecisionNode(prediction=prediction)

    # If we found a split, create the true and false branches
    true_rows, false_rows = split(rows, feature, value)
    true_branch = build_tree(true_rows)
    false_branch = build_tree(false_rows)

    return DecisionNode(feature, value, true_branch, false_branch)

# Build the decision tree from the five training rows; the bare `tree` on the
# last line is what the notebook displays for this cell.
tree = build_tree(training_data)
tree

In [None]:
def predict(node, row):
    """Walk the tree from *node* and return the label of the leaf *row* reaches.

    At each decision node, a string ``value`` is tested with ``==`` and any
    other value with ``>=``; a true answer follows ``true_branch`` and a
    false answer follows ``false_branch``. Only ``row[node.feature]`` is
    read, so *row* need not carry a label.
    """
    current = node
    while current.prediction is None:
        cell = row[current.feature]
        if isinstance(current.value, str):
            goes_true = cell == current.value
        else:
            goes_true = cell >= current.value
        current = current.true_branch if goes_true else current.false_branch
    return current.prediction

# Test the prediction function on a new data point. The row holds only the
# feature values [color, numeric feature]; predict() reads row[node.feature]
# and never a label column, so no label is needed.
new_data_point = ['Red', 7]
prediction = predict(tree, new_data_point)
prediction

In [None]:
# For simplicity the "test" set is the first three training rows, so the
# accuracy computed on it measures fit to seen data, not generalization.
test_data = [
    ['Red',    10, 'Apple'],
    ['Purple', 1,  'Grape'],
    ['Yellow', 8,  'Lemon'],
]

def calculate_accuracy(tree, test_data):
    """Return the fraction of rows in *test_data* whose label the tree predicts.

    Args:
        tree: root node passed to ``predict``.
        test_data: rows whose last element is the true label. The full row is
            passed to ``predict``, which only looks up the feature indices
            stored in the tree.

    Returns:
        A float in [0, 1].

    Raises:
        ValueError: if *test_data* is empty (accuracy is undefined).
    """
    if not test_data:
        raise ValueError("calculate_accuracy() needs at least one test row")
    correct_predictions = sum(1 for row in test_data if predict(tree, row) == row[-1])
    return correct_predictions / len(test_data)

# Calculate the accuracy. NOTE: every row of test_data is also a row of
# training_data, so this reflects how well the tree fits data it has already
# seen rather than how it generalizes.
accuracy = calculate_accuracy(tree, test_data)
accuracy

In [None]:
import random

# Generate a larger dataset by repeating the five base examples 20 times.
# NOTE: every row is an exact copy of one of these five examples, so after the
# split below the test rows are duplicates of rows that (almost certainly) also
# appear in the training set; accuracy measured on them reflects memorization,
# not generalization.
fruits = [
    ['Red', 10, 'Apple'],
    ['Purple', 1, 'Grape'],
    ['Yellow', 8, 'Lemon'],
    ['Red', 9, 'Apple'],
    ['Purple', 1.5, 'Grape'],
]

# 20 back-to-back copies of the five rows (100 rows in total).
larger_dataset = fruits * 20

# Shuffle with a dedicated, seeded generator so the train/test split (and the
# tree built from it) is reproducible across runs, without touching the global
# `random` state.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(larger_dataset)

# Split the dataset into 70% training and 30% testing
split_index = int(0.7 * len(larger_dataset))
training_data_larger = larger_dataset[:split_index]
test_data_larger = larger_dataset[split_index:]

len(training_data_larger), len(test_data_larger)

In [None]:
# Build the decision tree using the larger training data
tree_larger = build_tree(training_data_larger)

# Calculate the accuracy on the test data.
# NOTE: larger_dataset is 20 copies of the same five rows, so the test rows
# duplicate rows the tree was (almost certainly) trained on; this is not a
# true held-out accuracy.
accuracy_larger = calculate_accuracy(tree_larger, test_data_larger)
accuracy_larger

In [None]:
def print_tree(node, spacing=""):
    """Print an indented, human-readable dump of the tree rooted at *node*.

    Leaves print as "Predict <label>". Decision nodes print their question
    ("Is <feature> == <value>?" for string values, ">=" otherwise), followed
    by the "--> True:" and "--> False:" subtrees, each indented two more
    spaces.
    """
    if node.prediction is not None:
        print(spacing + "Predict", node.prediction)
        return

    operator = "==" if isinstance(node.value, str) else ">="
    print(f"{spacing}Is {node.feature} {operator} {node.value}?")

    child_spacing = spacing + "  "
    for outcome, child in (("True", node.true_branch), ("False", node.false_branch)):
        print(f"{spacing}--> {outcome}:")
        print_tree(child, child_spacing)

# Print the tree built from the larger training set
print_tree(tree_larger)

In [None]:
def _dot_quote(text):
    """Return *text* as a double-quoted DOT string with backslashes and quotes escaped."""
    return '"' + str(text).replace('\\', '\\\\').replace('"', '\\"') + '"'


def tree_to_dot(node, dot_list=None, parent_name=None, decision=None):
    """Convert the decision tree rooted at *node* to Graphviz DOT statements.

    Every tree node becomes its own DOT node with a unique ID (``n<k>``) and a
    human-readable ``label``, so leaves or questions with identical text are
    still drawn as separate nodes instead of being merged.

    Args:
        node: root of the (sub)tree to convert.
        dot_list: list of DOT lines to append to; a new list starting with
            'digraph Tree {' is created when omitted.
        parent_name: DOT ID of the parent node, or None for the root.
        decision: label for the edge from *parent_name* to *node*
            ('True' / 'False').

    Returns:
        *dot_list*. The closing '}' is NOT appended; the caller must add it.
    """
    if dot_list is None:
        dot_list = ['digraph Tree {']

    # The list only ever grows, so its current length is unique for every node
    # declared here: stable, collision-free IDs that don't depend on the label.
    node_name = f'n{len(dot_list)}'

    if node.prediction is not None:
        label, shape = f'Leaf: {node.prediction}', 'ellipse'
    elif isinstance(node.value, str):
        label, shape = f'{node.feature} == {node.value}?', 'box'
    else:
        label, shape = f'{node.feature} >= {node.value}?', 'box'
    dot_list.append(f'{node_name} [label={_dot_quote(label)}, shape={shape}];')

    # The root has no parent, so no incoming edge (avoids "None -> ...").
    if parent_name is not None:
        edge_attrs = f' [label={_dot_quote(decision)}]' if decision is not None else ''
        dot_list.append(f'{parent_name} -> {node_name}{edge_attrs};')

    if node.prediction is None:
        tree_to_dot(node.true_branch, dot_list, node_name, 'True')
        tree_to_dot(node.false_branch, dot_list, node_name, 'False')

    return dot_list

# Convert the tree to DOT format. tree_to_dot() returns the lines without the
# closing brace of the digraph, so it is appended here.
dot_representation = '\n'.join(tree_to_dot(tree_larger)) + '\n}'
dot_representation

In [None]:
!pip install -q graphviz

In [None]:
from graphviz import Digraph

def visualize_tree(node, dot=None):
    """Draw the decision tree rooted at *node* into a Graphviz Digraph.

    Args:
        node: root of the (sub)tree to draw.
        dot: graph to add nodes/edges to; a new ``Digraph`` when omitted.

    Returns:
        (dot, node_name): the graph and the Graphviz node name used for *node*.
    """
    if dot is None:
        dot = Digraph()

    # Graphviz identifies nodes by name. Using the label text as the name
    # merges every leaf/question that shares the same text, and graphviz's
    # edge() parses ':' in an endpoint as node:port, so an edge to
    # "Leaf: Apple" would target a node called "Leaf". The object's id() is
    # unique per tree node and contains no ':'.
    node_name = str(id(node))

    # Base case: leaf node
    if node.prediction is not None:
        dot.node(node_name, f'Leaf: {node.prediction}', shape='ellipse', color='lightgreen')
        return dot, node_name

    # Decision node
    if isinstance(node.value, str):
        label = f'{node.feature} == {node.value}?'
    else:
        label = f'{node.feature} >= {node.value}?'
    dot.node(node_name, label, shape='box')

    # Recursively process true and false branches
    dot, true_name = visualize_tree(node.true_branch, dot)
    dot, false_name = visualize_tree(node.false_branch, dot)
    dot.edge(node_name, true_name, label='True')
    dot.edge(node_name, false_name, label='False')

    return dot, node_name

# Visualize the tree. visualize_tree() returns (graph, root_node_name); the
# graph is kept and shown as the cell's last expression.
dot, _ = visualize_tree(tree_larger)
dot

In [None]:
def visualize_tree_fixed(node, dot=None, parent_name=None, decision=None):
    """Add the tree rooted at *node* to a Graphviz graph using id()-based node names.

    Each tree node is named ``str(id(node))`` and carries its text as a
    separate label. Identical labels therefore stay distinct nodes.

    Args:
        node: root of the (sub)tree to draw.
        dot: graph to draw into. Pass your own ``Digraph``: only the node name
            is returned, so a graph created here when *dot* is None is not
            handed back to the caller.
        parent_name, decision: forwarded on recursion but never read.

    Returns:
        The Graphviz node name of *node*.
    """
    graph = Digraph() if dot is None else dot
    own_id = str(id(node))

    if node.prediction is not None:
        graph.node(own_id, f'Leaf: {node.prediction}', shape='ellipse', color='lightgreen')
        return own_id

    comparison = '==' if isinstance(node.value, str) else '>='
    graph.node(own_id, f'{node.feature} {comparison} {node.value}?', shape='box')

    # Draw both subtrees first, then connect them, matching the emit order:
    # all of the true subtree, all of the false subtree, then the two edges.
    branches = (('True', node.true_branch), ('False', node.false_branch))
    child_ids = [visualize_tree_fixed(child, graph, own_id, outcome) for outcome, child in branches]
    for (outcome, _), child_id in zip(branches, child_ids):
        graph.edge(own_id, child_id, label=outcome)

    return own_id

# Visualize the tree with fixed node names.
# visualize_tree_fixed() returns only the root's node name, so keep a
# reference to the Digraph it draws into and make that the cell's last
# expression; otherwise the notebook displays the name string, not the graph.
fixed_dot = Digraph()
visualize_tree_fixed(tree_larger, dot=fixed_dot)
fixed_dot