# Decision Tree from scratch

## Import Libraries
Let's bring in the tools we need to build our decision tree!

In [None]:
import numpy as np
import pandas as pd
from collections import Counter

## Create Sample Data
Here's our fun dataset about people's preferences - we want to predict if they'll love the movie "Cool As Ice"!

In [None]:
# The toy survey, written one person per row (popcorn, soda, age, verdict)
# and then transposed into the column-oriented dict that pd.DataFrame expects.
data = {
    column: list(values)
    for column, values in zip(
        ['Loves Popcorn', 'Loves Soda', 'Age', 'Loves Cool As Ice'],
        zip(
            ('Yes', 'Yes', 7, 'No'),
            ('Yes', 'No', 12, 'No'),
            ('No', 'Yes', 18, 'Yes'),
            ('No', 'Yes', 35, 'Yes'),
            ('Yes', 'Yes', 38, 'Yes'),
            ('Yes', 'No', 50, 'No'),
            ('No', 'No', 83, 'No'),
        ),
    )
}


## Display the Data
Let's see what our data looks like in a nice table!

In [None]:
# Each key of `data` becomes a column; the bare `df` displays the table.
df = pd.DataFrame.from_dict(data)
df

Unnamed: 0,Loves Popcorn,Loves Soda,Age,Loves Cool As Ice
0,Yes,Yes,7,No
1,Yes,No,12,No
2,No,Yes,18,Yes
3,No,Yes,35,Yes
4,Yes,Yes,38,Yes
5,Yes,No,50,No
6,No,No,83,No


## Gini Impurity Function
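This measures how "mixed up" a group is: $G = 1 - \sum_k p_k^2$, where $p_k$ is the fraction of the group giving answer $k$. Pure groups (everyone gives the same answer) score 0; more mixed groups score higher (up to 0.5 when two answers are split evenly).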

In [None]:
def gini(labels):
    """Return the Gini impurity of a collection of class labels.

    Gini impurity is ``1 - sum(p_k ** 2)`` where ``p_k`` is the fraction of
    labels equal to class ``k``. It is 0 for a pure group and grows as the
    classes become more mixed.

    Args:
        labels: Any iterable of hashable labels (list, tuple, pandas Series,
            generator, ...).

    Returns:
        The impurity as a float, or 0 for an empty collection.
    """
    # Counter tallies every class in one pass. The previous version called
    # list.count() once per distinct label, which was O(n * k) and only
    # worked when `labels` was a Python list.
    counts = Counter(labels)
    total = sum(counts.values())
    if total == 0:
        return 0
    return 1 - sum((count / total) ** 2 for count in counts.values())

## Split Data Function
This divides our data into two groups based on a question (like "Do you love popcorn?").

In [None]:
def split_data_set(dataset, feature, value):
    """Partition the rows of *dataset* on an equality question.

    Args:
        dataset: DataFrame to split.
        feature: Column whose value is tested.
        value: Rows with ``row[feature] == value`` go to the first list.

    Returns:
        ``(matches, rest)``: two lists of row Series, in original row order.
    """
    matches, rest = [], []
    for _, row in dataset.iterrows():
        target = matches if row[feature] == value else rest
        target.append(row)
    return matches, rest



## Weighted Gini Score
This tells us how good a split is by looking at both groups together. Lower scores = better splits!

In [None]:
def weighted_gini(left, right, label_name):
    """Score a candidate split by the size-weighted Gini impurity of its sides.

    Lower is better; 0 means both sides are pure.

    Args:
        left: Rows on the "yes" side, either a DataFrame or a list of row
            Series (the form returned by split_data_set).
        right: Rows on the "no" side, in either of the same two forms.
        label_name: Name of the label column in each row.

    Returns:
        ``|L|/N * gini(L) + |R|/N * gini(R)`` rounded to 3 decimals, or 0.0
        when both sides are empty.
    """
    def labels_of(group):
        # Each side is converted independently, so a DataFrame may be paired
        # with a list of rows. Previously a mixed pair fell into the list
        # branch, where iterating a DataFrame yields column names, not rows.
        if isinstance(group, pd.DataFrame):
            return group[label_name].tolist()
        return [row[label_name] for row in group]

    left_labels = labels_of(left)
    right_labels = labels_of(right)
    total = len(left_labels) + len(right_labels)
    if total == 0:
        # Nothing to score; avoids ZeroDivisionError.
        return 0.0

    weighted = (
        len(left_labels) / total * gini(left_labels)
        + len(right_labels) / total * gini(right_labels)
    )
    # Rounding keeps printed scores tidy; it also makes near-equal scores
    # compare as exact ties.
    return round(weighted, 3)



## Helper Functions
These help us work with numbers (like age) by finding good cutoff points to split on.

In [None]:
def is_numeric(series):
    """Return whether pandas considers the dtype of *series* numeric.

    Numeric columns are split with ``<=`` thresholds; everything else is
    treated as categorical.
    """
    types_api = pd.api.types
    return types_api.is_numeric_dtype(series)

def get_thresholds(column):
    """Return candidate split points for a numeric column.

    Each threshold is the midpoint between two consecutive distinct values,
    so a ``<= threshold`` question always leaves at least one non-missing
    value on each side. Missing values (NaN/None) are ignored because they
    cannot be ordered.

    Args:
        column: A pandas Series (or any 1-D array-like) of numbers.

    Returns:
        A list of midpoints in ascending order; empty when the column has
        fewer than two distinct non-missing values.
    """
    # sorted() cannot order NaN reliably (every comparison with NaN is False),
    # which used to produce NaN thresholds, so drop missing values first.
    values = sorted(pd.Series(column).dropna().unique())
    return [(low + high) / 2 for low, high in zip(values, values[1:])]


# Finding the best split


## Find the Best Question
This tries all possible questions and picks the one that creates the purest groups!

In [None]:
def find_best_split(dataset, label_name, verbose=True):
    """Search every feature for the question that yields the purest split.

    Numeric features are tried with ``feature <= threshold`` questions at the
    midpoints returned by ``get_thresholds``. All other features are tried
    with ``feature == value`` questions, one per observed value. Splits that
    leave one side empty are skipped. On ties the first candidate found wins,
    scanning features in column order and values in order of appearance.

    Args:
        dataset: DataFrame containing the features and the label column.
        label_name: Name of the column to predict (never used as a feature).
        verbose: When True (the default), print the score of every numeric
            threshold tried.

    Returns:
        ``(best_feature, best_value, (left, right))`` where ``left`` holds the
        rows answering "yes" to the question and ``right`` all other rows, or
        ``(None, None, None)`` when no valid split exists.
    """
    best_gini = float('inf')
    best_feature = None
    best_value = None
    best_groups = None
    features = [col for col in dataset.columns if col != label_name]

    for feature in features:
        column = dataset[feature]
        numeric = is_numeric(column)
        # unique() keeps order of appearance. The old set() iteration order
        # for strings depends on PYTHONHASHSEED, so tied splits (e.g. 'Yes'
        # vs 'No' on a binary feature) produced a different tree per run.
        candidates = get_thresholds(column) if numeric else column.unique()

        for value in candidates:
            mask = (column <= value) if numeric else (column == value)
            left = dataset[mask]
            # ~mask instead of `> value` / `!= value`: with `> value`, rows
            # with a missing numeric value fell into neither group and were
            # silently lost. Now they go right, as predict() routes them.
            right = dataset[~mask]
            if left.empty or right.empty:
                continue

            gini_score = weighted_gini(left, right, label_name)
            if numeric and verbose:
                print(f"Feature: {feature}, Threshold: {value}, Gini: {gini_score}")
            if gini_score < best_gini:
                best_gini = gini_score
                best_feature = feature
                best_value = value
                best_groups = (left, right)

    return best_feature, best_value, best_groups


# Recursion for building the decision tree

## Build the Decision Tree
This is the main function that creates our tree by asking questions and making branches!

In [None]:
def build_tree(dataset, label_name, depth=0, max_depth=5):
    """Recursively grow a decision tree and return it as nested dicts.

    Leaves look like ``{'type': 'leaf', 'class': label}``. Internal nodes look
    like ``{'type': 'node', 'feature': ..., 'value': ..., 'left': ...,
    'right': ...}``, where ``left`` is the subtree for rows that answer "yes"
    to the node's question (``<= value`` for numeric features, ``== value``
    otherwise) and ``right`` is the subtree for all other rows.

    Args:
        dataset: DataFrame of features plus the label column.
        label_name: Name of the column to predict.
        depth: Depth of this node; callers normally leave it at 0.
        max_depth: Nodes at ``depth >= max_depth`` become leaves; ``None``
            means no depth limit.

    Returns:
        The root node of the (sub)tree. An empty dataset yields a leaf whose
        class is ``None``.
    """
    labels = dataset[label_name].tolist()

    # An empty dataset has no majority class. The old check routed it to
    # majority_class([]), which raised IndexError instead of making a leaf.
    if not labels:
        return {'type': 'leaf', 'class': None}

    # Stop when the group is already pure or the depth budget is spent.
    if len(set(labels)) == 1 or (max_depth is not None and depth >= max_depth):
        return {'type': 'leaf', 'class': majority_class(labels)}

    best_feature, best_value, best_groups = find_best_split(dataset, label_name)

    # No question separates the rows (e.g. every remaining feature is constant).
    if best_feature is None or best_groups is None:
        return {'type': 'leaf', 'class': majority_class(labels)}

    left_group, right_group = best_groups
    # Dict values are evaluated in order, so the left branch is still built
    # before the right one.
    return {
        'type': 'node',
        'feature': best_feature,
        'value': best_value,
        'left': build_tree(left_group, label_name, depth + 1, max_depth),
        'right': build_tree(right_group, label_name, depth + 1, max_depth),
    }


In [None]:
def majority_class(labels, default=None):
    """Return the most frequent label in *labels*.

    Ties go to the label seen first, because ``Counter.most_common`` orders
    equal counts by first insertion.

    Args:
        labels: Iterable of hashable labels.
        default: Value returned when *labels* is empty (the previous version
            raised IndexError instead).

    Returns:
        The most common label, or *default* if there are no labels.
    """
    ranked = Counter(labels).most_common(1)
    return ranked[0][0] if ranked else default


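## Find Most Common Answer, Then Build Our Tree!
`majority_class` (defined just above) picks the most common answer when we can't split anymore. With every piece in place, let's create our decision tree and see what it looks like!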

In [None]:
# Grow a tree at most three questions deep and show its raw dict form.
decision_tree = build_tree(
    df,
    label_name='Loves Cool As Ice',
    max_depth=3,
)
print(decision_tree)

[np.float64(9.5), np.float64(15.0), np.float64(26.5), np.float64(36.5), np.float64(44.0), np.float64(66.5)]
Feature: Age, Threshold: 9.5, Gini: 0.429
Feature: Age, Threshold: 15.0, Gini: 0.343
Feature: Age, Threshold: 26.5, Gini: 0.476
Feature: Age, Threshold: 36.5, Gini: 0.476
Feature: Age, Threshold: 44.0, Gini: 0.343
Feature: Age, Threshold: 66.5, Gini: 0.429
[np.float64(12.5), np.float64(26.5), np.float64(36.5)]
Feature: Age, Threshold: 12.5, Gini: 0.0
Feature: Age, Threshold: 26.5, Gini: 0.25
Feature: Age, Threshold: 36.5, Gini: 0.333
{'type': 'node', 'feature': 'Loves Soda', 'value': 'No', 'left': {'type': 'leaf', 'class': 'No'}, 'right': {'type': 'node', 'feature': 'Age', 'value': np.float64(12.5), 'left': {'type': 'leaf', 'class': 'No'}, 'right': {'type': 'leaf', 'class': 'Yes'}}}


## Make Predictions
This function follows the tree from top to bottom to predict new data!

In [None]:
def predict(tree, sample):
    """Walk *tree* from the root to a leaf and return that leaf's class.

    At each node the sample goes left when it answers "yes" to the node's
    question (``<= value`` for numeric thresholds, ``== value`` otherwise)
    and right in every other case.

    Args:
        tree: A node dict as produced by ``build_tree``.
        sample: Mapping from feature name to value (a dict or a pandas row);
            must contain every feature asked about along the path.

    Returns:
        The ``'class'`` stored in the leaf that *sample* reaches.
    """
    node = tree
    while node['type'] != 'leaf':
        feature = node['feature']
        value = node['value']
        # Real numbers, including NumPy integer scalars that the old
        # (int, float, np.float64) check missed, are thresholds. Booleans are
        # categories: numeric columns are split at float midpoints, so a bare
        # True/False value can only come from an equality split, and treating
        # it as a threshold (False <= True) routed rows the wrong way.
        is_threshold = (
            isinstance(value, (int, float, np.integer, np.floating))
            and not isinstance(value, (bool, np.bool_))
        )
        if is_threshold:
            goes_left = sample[feature] <= value
        else:
            goes_left = sample[feature] == value
        node = node['left'] if goes_left else node['right']
    return node['class']


## Predict for One Person
Let's ask our tree whether a new person will love "Cool As Ice"!

In [None]:
# One new person the tree has never seen.
test_sample = {'Loves Popcorn': 'No', 'Loves Soda': 'Yes', 'Age': 20}

result = predict(decision_tree, sample=test_sample)
print("Predicted:", result)


Predicted: Yes


## Test with Multiple People
Let's try predicting for several different people to see how our tree works!

In [None]:
# Several new people, given as (popcorn, soda, age) and turned into feature dicts.
test_people = [
    dict(zip(('Loves Popcorn', 'Loves Soda', 'Age'), answers))
    for answers in [
        ('Yes', 'Yes', 25),
        ('No', 'No', 60),
        ('Yes', 'No', 15),
        ('No', 'Yes', 40),
    ]
]

for i, person in enumerate(test_people, start=1):
    prediction = predict(decision_tree, person)
    print(f"Person {i}: {person} -> Prediction: {prediction}")

## Check Our Tree's Accuracy
Let's see how well our tree predicts the original data!

In [None]:
# Measure how many of the training rows the tree classifies correctly.
correct_predictions = 0
total_predictions = len(df)

print("Testing on original data:")
for position, (_, row) in enumerate(df.iterrows(), start=1):
    # Convert the row to a plain dict of features by dropping the label, so
    # whatever feature columns df has are used instead of a hard-coded list.
    sample = row.drop('Loves Cool As Ice').to_dict()

    predicted = predict(decision_tree, sample)
    actual = row['Loves Cool As Ice']

    if predicted == actual:
        correct_predictions += 1
        result = "✓ Correct"
    else:
        result = "✗ Wrong"

    # Number people by position rather than index label, so this also works
    # for DataFrames whose index is not 0..n-1.
    print(f"Person {position}: Predicted {predicted}, Actual {actual} - {result}")

# Guard against an empty DataFrame, which would otherwise divide by zero.
accuracy = (correct_predictions / total_predictions) * 100 if total_predictions else 0.0
print(f"\nAccuracy: {correct_predictions}/{total_predictions} = {accuracy:.1f}%")

## Understanding Our Tree
Let's create a simple function to visualize what questions our tree asks!

In [None]:
def print_tree(tree, depth=0):
    """Print the tree in an easy-to-read format, one question per line.

    Each internal node prints its question followed by its "Yes" (left) and
    "No" (right) subtrees, indented two spaces per level.

    Args:
        tree: A node dict as produced by ``build_tree``.
        depth: Indentation level of *tree*; callers normally leave it at 0.
    """
    indent = "  " * depth

    if tree['type'] == 'leaf':
        print(f"{indent}→ Predict: {tree['class']}")
        return

    feature = tree['feature']
    value = tree['value']
    # Real numbers (including NumPy integer scalars, which the old
    # (int, float) check missed) are "<=" thresholds. Booleans are shown as
    # "=" questions: numeric columns are split at float midpoints, so a bare
    # True/False value can only come from an equality split.
    is_threshold = (
        isinstance(value, (int, float, np.integer, np.floating))
        and not isinstance(value, (bool, np.bool_))
    )
    symbol = "<=" if is_threshold else "="

    print(f"{indent}Is {feature} {symbol} {value}?")
    print(f"{indent}├─ Yes:")
    print_tree(tree['left'], depth + 1)
    print(f"{indent}└─ No:")
    print_tree(tree['right'], depth + 1)

# Title and a 40-character rule on separate lines, then the tree itself.
print("Our Decision Tree Structure:", "=" * 40, sep="\n")
print_tree(decision_tree)