In [1]:
import pandas as pd
import numpy as np
from typing import Union
import unittest
import numpy as np

In [69]:
class LeafNode:
    """A Leaf node classifies data.
    This holds a dictionary of class (e.g., "Apple") -> number of times
    it appears in the rows from the training data that reach this leaf.
    """

    def __init__(self, counts):
        self._predictions = counts

class Decision_Node:
    """A Decision Node asks a question.
    This holds a reference to the question, and to the two child nodes.
    """

    def __init__(self, question, branch1, branch2):
        self._question = question
        self._true_branch = branch1
        self._false_branch = branch2

class DescisionTree:

    def __init__(self, data: pd.DataFrame, target_col: str) -> None:
        self._df = data
        self._target_col = target_col
        self._tree = build_tree(data, target_col)

    def _predict(self, node, observation):

        # If we have a leaf node, check preds and assign on majority
        if isinstance(node, LeafNode):
            # List of labels, take most common
            preds = node._predictions
            print(f"Found Leaf node with preds: {preds}")
            counts = _count_labels(preds)
            most_common = max(counts.items(), key=lambda x: x[1])
            return most_common

        # Check which branch to follow
        column = node._question.column
        value = node._question.value

        print(f"Checking question of {column} >= {value}")

        # If we have some gain, push forward with another tree
        if isinstance(node._question.value, str):
            if observation[column] == value:
                return self._predict(node._true_branch, observation)
            else:
                return self._predict(node._false_branch, observation)
        else:
            if observation[column] >= value:
                return self._predict(node._true_branch, observation)
            else:
                return self._predict(node._false_branch, observation)


    def predict(self, observation: pd.Series):
        # Base case, reached a leaf
        print(f"Getting Pred for {observation}")
        pred_label = self._predict(self._tree, observation)
        return pred_label

def build_tree(data: pd.DataFrame, target_col: str):
    # 1. Find best split that gives lowest impurity
    # 2. Split data on that
    # 3. Call again on splits, recursively
    # 4. Add leaf nodes when totally pure
    feats = data.drop(target_col, axis=1)
    target = data[target_col]
    gain, question, total_counts = _find_best_split(feats, target)

    if gain==0:
        leaf = LeafNode(total_counts)
        print(f"No gain found, returing Leaf Node of {total_counts}")
        return leaf

    # If we have some gain, push forward with another tree
    if isinstance(question.value, str):
        data_split1 = data[data[question.column] == question.value]
        data_split2 = data[data[question.column] != question.value]
    else:
        data_split1 = data[data[question.column] >= question.value]
        data_split2 = data[data[question.column] < question.value]

    print(f"Split1: {data_split1}")
    print(f"Split2: {data_split2}")

    branch1 = build_tree(data_split1, target_col)
    branch2 = build_tree(data_split2, target_col)

    return Decision_Node(question, branch1, branch2)


def _find_best_split(data: pd.DataFrame, target: pd.Series):
    print("Finding the best split")

    # Calculate each split at each column
    # For each unique value, split the data and calculate info gain
    best_gain = 0.0
    best_q = Question()
    total_counts = None
    for col_name in data.columns:

        # Get the column we are looking at
        column: pd.Series = data[col_name]
        # Find Unqiue Values
        uniques = pd.unique(column)

        for unique in uniques:

            print(f"Iteration: Col=='{col_name}'; Critera >= {unique}")
            # Split data
            q = Question()
            split = column.apply(lambda x: q._question(x, unique, col_name))

            group1 = data[split]
            group1_target = target[split]
            group2 = data[np.invert(split)]
            group2_target = target[np.invert(split)]

            # Has to split the data somehow or doesn't help
            if not all([len(group1), len(group2)]):
                continue

            gain = _info_gain(target, group1_target, group2_target)
            if gain > best_gain:
                print(f"New Best Found with gain @ {gain}\nGreater than last of {best_gain}")
                best_gain = gain
                best_q = q
                total_counts = list(group1_target)+list(group2_target)

    if not total_counts:
        total_counts = list(group1_target)+list(group2_target)

    print(f"Found best question at column '{best_q.column}' and value '{best_q.value}'")
    print(f"With Info Gain of {best_gain}")
    print(f"And Target Counts: {total_counts}")
    return best_gain, best_q, total_counts

def _count_labels(labels: np.array):
    counts = dict()
    for label in labels:
        if not counts.get(label):
            counts[label] = 1
        else:
            counts[label] += 1
    return counts

def _gini_impurity(labels: np.array):
    """Calculates Gini Impurity
    Get Unique Counts of values
    Sum the squared probabilities of getting those unique values
    """
    counts = _count_labels(labels)
    impurity = 1.0
    for _, count in counts.items():
        impurity -= (count/len(labels))**2
    return impurity

def _info_gain(current: pd.Series, group1: pd.Series, group2: pd.Series):

    current_impurity = _gini_impurity(labels=current)
    group1_impurity = _gini_impurity(labels=group1)
    group2_impurity = _gini_impurity(labels=group2)

    # print(group1)
    # print(group2)
    # print(f"Current Impurity = {current_impurity}")
    # print(f"Group1 Impurity = {group1_impurity}")
    # print(f"Group2 Impurity = {group2_impurity}")

    p = float(len(group1) / (len(group1) + len(group2)))
    weighted = current_impurity - (p*group1_impurity) - ((1-p)*group2_impurity)

    print(f"Total Info Gain = {weighted}\n")
    return weighted

def print_tree(node):

    # Base case, reached a leaf
    if isinstance(node, LeafNode):
        print ("Predict", node._predictions)
        return

    # Print the question at this node
    print (f"{node._question.column} >= {node._question.value}")

    # Call this function recursively on the true branch
    print ('--> True:')
    print_tree(node._true_branch)

    # Call this function recursively on the false branch
    print ('--> False:')
    print_tree(node._false_branch)


class Question:

    def __init__(self) -> None:
        self.column = None
        self.value = None

    def _question(self, data: Union[str, int, float], gt_equal: Union[str, int, float], column: str) -> bool:
        """Method to ask a question about a certain point.
            Either if strings match or if a number is gt or equal to another

        Args:
            data (Union[str, int, float]): Data we want to ask a question about
            gt_equal (Union[str, int, float]): Condition to check equivalence for

        Raises:
            ValueError: Mismatching Types

        Returns:
            bool: If data matched or was greater than condition
        """
        self.value = gt_equal
        self.column = column

        if isinstance(data, np.generic):
            data = data.item()
        if isinstance(gt_equal, np.generic):
            gt_equal = gt_equal.item()

        if isinstance(data, (int, float)) and isinstance(gt_equal, (int, float)):
            return data >= gt_equal
        elif isinstance(data, str) and isinstance(gt_equal, str):
            return data.lower() == gt_equal.lower()
        else:
            raise ValueError(f"Data and check not of same type; data={type(data)}; gt_equal={type(gt_equal)}")

In [70]:
fruits = pd.DataFrame(
            {
                "fruit": ["grape"]*5 + ["orange"]*3 + ["pear"]*4,
                "size": [1,2,1,3,4] + [5,6,7] + [4,5,6,5],
                "color": (["green"]*2 + ["red"]*3) + (["orange"]*3) + (["green"]*3 + ["yellow"])
            }
        )
dt = DescisionTree(data=fruits, target_col="fruit")
# dt.build_tree()


Finding the best split
Iteration: Col=='size'; Critera >= 1
Iteration: Col=='size'; Critera >= 2
Total Info Gain = 0.10277777777777763

New Best Found with gain @ 0.10277777777777763
Greater than last of 0.0
Iteration: Col=='size'; Critera >= 3
Total Info Gain = 0.17129629629629622

New Best Found with gain @ 0.17129629629629622
Greater than last of 0.10277777777777763
Iteration: Col=='size'; Critera >= 4
Total Info Gain = 0.25694444444444436

New Best Found with gain @ 0.25694444444444436
Greater than last of 0.17129629629629622
Iteration: Col=='size'; Critera >= 5
Total Info Gain = 0.26388888888888884

New Best Found with gain @ 0.26388888888888884
Greater than last of 0.25694444444444436
Iteration: Col=='size'; Critera >= 6
Total Info Gain = 0.1157407407407407

Iteration: Col=='size'; Critera >= 7
Total Info Gain = 0.07702020202020188

Iteration: Col=='color'; Critera >= green
Total Info Gain = 0.09563492063492057

Iteration: Col=='color'; Critera >= red
Total Info Gain = 0.17129629

In [73]:
print_tree(dt._tree)
dt.predict(pd.Series({"size": 1, "color": "yellow"}))

color >= orange
--> True:
Predict ['orange', 'orange', 'orange']
--> False:
size >= 4
--> True:
color >= red
--> True:
Predict ['grape']
--> False:
Predict ['pear', 'pear', 'pear', 'pear']
--> False:
Predict ['grape', 'grape', 'grape', 'grape']
Getting Pred for size          1
color    yellow
dtype: object
Checking question of color >= orange
Checking question of size >= 4
Found Leaf node with preds: ['grape', 'grape', 'grape', 'grape']


('grape', 4)

In [5]:
class TestDecisionTree(unittest.TestCase):

    def setUp(self) -> None:
        self.df = pd.DataFrame()
        self.tree = DescisionTree(df, np.array([1,2]))

    def test_question_number(self):

        self.assertTrue(_question(2, gt_equal=1))
        self.assertTrue(_question(2, gt_equal=2))
        self.assertFalse(_question(2, gt_equal=3))
        self.assertFalse(_question(2.0, gt_equal=3))

    def test_question_string(self):

        self.assertTrue(_question("apple", gt_equal="apple"))
        self.assertTrue(_question("apple", gt_equal="Apple"))
        self.assertFalse(_question("apple", gt_equal="Orange"))

    def test_question_mismatch(self):

        with self.assertRaises(ValueError):
            self.assertTrue(_question("apple", gt_equal=1))
            self.assertTrue(_question(1, gt_equal="apple"))

class TestGiniImpurity(unittest.TestCase):

    def setUp(self) -> None:
        fruits = pd.DataFrame(
            {
                "fruit": ["grape"]*5 + ["orange"]*3 + ["pear"]*4,
                "size": [1,2,1,3,4] + [5,6,7] + [4,5,6,5],
                "color": (["green"]*2 + ["red"]*3) + (["orange"]*3) + (["green"]*3 + ["yellow"])
            }
        )
        self.tree = DescisionTree(fruits, target_col="fruit")

    def test_impurity(self):
        self.assertEqual(_gini_impurity(["apple"]), 0.0)
        self.assertEqual(_gini_impurity(["apple", "orange"]), 0.5)
        self.assertEqual(_gini_impurity(["apple", "orange", "pear", "grape"]), 0.75)


unittest.main(argv=[''], verbosity=10, exit=False)

test_question_mismatch (__main__.TestDecisionTree) ... ERROR
test_question_number (__main__.TestDecisionTree) ... ERROR
test_question_string (__main__.TestDecisionTree) ... ERROR
test_impurity (__main__.TestGiniImpurity) ... 

Finding the best split
Iteration: Col=='size'; Critera >= 1
Iteration: Col=='size'; Critera >= 2
1      grape
3      grape
4      grape
5     orange
6     orange
7     orange
8       pear
9       pear
10      pear
11      pear
Name: fruit, dtype: object
0    grape
2    grape
Name: fruit, dtype: object
Current Impurity = 0.6527777777777777
Group1 Impurity = 0.66
Group2 Impurity = 0.0
Total Info Gain = 0.10277777777777763

Iteration: Col=='size'; Critera >= 3
3      grape
4      grape
5     orange
6     orange
7     orange
8       pear
9       pear
10      pear
11      pear
Name: fruit, dtype: object
0    grape
1    grape
2    grape
Name: fruit, dtype: object
Current Impurity = 0.6527777777777777
Group1 Impurity = 0.6419753086419753
Group2 Impurity = 0.0
Total Info Gain = 0.17129629629629622

Iteration: Col=='size'; Critera >= 4
4      grape
5     orange
6     orange
7     orange
8       pear
9       pear
10      pear
11      pear
Name: fruit, dtype: object
0    grape
1    grape
2    gra

ERROR

ERROR: test_question_mismatch (__main__.TestDecisionTree)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\Matt\AppData\Local\Temp\ipykernel_28268\1517668327.py", line 5, in setUp
    self.tree = DescisionTree(df, np.array([1,2]))
NameError: name 'df' is not defined

ERROR: test_question_number (__main__.TestDecisionTree)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\Matt\AppData\Local\Temp\ipykernel_28268\1517668327.py", line 5, in setUp
    self.tree = DescisionTree(df, np.array([1,2]))
NameError: name 'df' is not defined

ERROR: test_question_string (__main__.TestDecisionTree)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\Matt\AppData\Local\Temp\ipykernel_28268\1517668327.py", line 5, in setUp
    self.tree = DescisionTree(df, np.array([1,2]))
NameErro

<unittest.main.TestProgram at 0x1907f2ffa00>

In [223]:
s= 1
s >= 2

False