In [1]:
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

This will be a playground for exploring sklearn's Decision Tree Classifier.

In [2]:
iris = load_iris()

In [3]:
iris.keys()

dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])

In [4]:
iris['feature_names']

['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

In [5]:
iris_petal_dims = iris['data'][:, 2:]

In [6]:
type(iris_petal_dims)

numpy.ndarray

In [7]:
iris_dict = {'petal length (cm)': iris_petal_dims[:, 0],
             'petal width (cm)': iris_petal_dims[:, 1],
             'labels': iris['target']}

In [8]:
iris_df = pd.DataFrame(iris_dict)

In [9]:
iris_df

Unnamed: 0,labels,petal length (cm),petal width (cm)
0,0,1.4,0.2
1,0,1.4,0.2
2,0,1.3,0.2
3,0,1.5,0.2
4,0,1.4,0.2
5,0,1.7,0.4
6,0,1.4,0.3
7,0,1.5,0.2
8,0,1.4,0.2
9,0,1.5,0.1


In [10]:
from collections import Counter
def _gini(labels):
    """
    Calculates the gini impurity for a set of labels.
    """
    total = len(labels)
    label_counts = Counter(labels).values()
    return 1 - sum((p / total)**2
                   for p in label_counts
                   if p)

In [11]:
labels = _gini(iris_df['labels'])

In [12]:
labels

0.6666666666666667

In [17]:
newy = []
minim = []
cost_min = float('inf')
for x in np.arange(.2, 7, .1):
    left = iris_df[iris_df['petal length (cm)'] <= x]['labels']
    right = iris_df[~(iris_df['petal length (cm)'] <= x)]['labels']
    cost = len(left) * _gini(left) / 150 + len(right) * _gini(right) / 150
    if cost < cost_min:
        cost_min = cost
        minim = []
    elif cost == cost_min:
        minim.append(x)
    newy.append(cost)

In [18]:
cost_min

0.3333333333333333

In [19]:
np.mean(minim)

2.4500000000000011

In [20]:
def _gini_cost(labeled_data, feature_name):
    """
    Calculate the cost function as part of the CART algorithm.
    """
    minim = []
    cost_min = float('inf')
    len_labeled_data = len(labeled_data)
    feature_data = [row[0][feature_name] for row in labeled_data]
    max_feature = int(max(feature_data)) * 100
    min_feature = int(min(feature_data)) * 100
    feature_range = max_feature - min_feature
    span = (i / 100 for i in range(min_feature, max_feature + 1))
    for x in span:
        left = []
        right = []
        for i, num in enumerate(feature_data):
            if num <= x:
                left.append(i)
            else:
                right.append(i)
        left_labels = [labeled_data[idx][1] for idx in left]
        right_labels = [labeled_data[idx][1] for idx in right]
        cost = len(left_labels) * _gini(left_labels) / len_labeled_data + len(right_labels) * _gini(right_labels) / len_labeled_data
        if cost < cost_min:
            cost_min = cost
            minim = []
            minim.append(x)
            left_node = [labeled_data[idx] for idx in left]
            right_node = [labeled_data[idx] for idx in right]
        elif cost == cost_min:
            minim.append(x)
    avg_minimums = sum(minim) / len(minim)
    return avg_minimums, left_node, right_node

In [21]:
labels = iris_df['labels'].tolist()

In [22]:
zipped_columns = zip(iris_df['petal length (cm)'].tolist(), iris_df['petal width (cm)'].tolist())

In [23]:
features = [{'petal length (cm)': length,
             'petal width (cm)': width} for length, width in zipped_columns]

In [24]:
zipped_iris_data = list(zip(features, labels))

In [25]:
avgs, left, right =_gini_cost(zipped_iris_data, 'petal length (cm)')

In [26]:
class_labels = [i[1] for i in left]

In [27]:
from collections import Counter

In [28]:
Counter(class_labels)

Counter({0: 50})

In [29]:
listy = [0 if i < 50 else 1 for i in range(100)]

In [30]:
for i in range(50):
    listy.append(2)

In [31]:
_gini(listy)

0.6666666666666667

In [34]:
zero_dtc = DecisionTreeClassifier(max_depth=1)

In [35]:
zero_dtc.fit(iris_petal_dims, iris.target)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=1,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

In [36]:
zero_dtc.predict([[5, 1.5]])

array([1])

In [None]:
import sys
sys.path.remove('c:\\users\\kurtrm\\projects\\phone_network_graph')

In [38]:
zipped_iris_data

[({'petal length (cm)': 1.4, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.4, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.3, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.5, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.4, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.7, 'petal width (cm)': 0.4}, 0),
 ({'petal length (cm)': 1.4, 'petal width (cm)': 0.3}, 0),
 ({'petal length (cm)': 1.5, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.4, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.5, 'petal width (cm)': 0.1}, 0),
 ({'petal length (cm)': 1.5, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.6, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.4, 'petal width (cm)': 0.1}, 0),
 ({'petal length (cm)': 1.1, 'petal width (cm)': 0.1}, 0),
 ({'petal length (cm)': 1.2, 'petal width (cm)': 0.2}, 0),
 ({'petal length (cm)': 1.5, 'petal width (cm)': 0.4}, 0),
 ({'petal length (cm)': 1.3, 'petal width (cm)': 0.4}, 0

In [56]:
"""
Module containing the DecisionTree class.
"""
from collections import Counter
from math import log2


class Node:
    """
    Node that contains all information required in order to make predictions and pass
    information on for further evaluation to other nodes.
    """

    def __init__(self, threshold, samples, values, classification, gini=None):
        """
        """
        self.threshold = threshold
        self.samples = samples
        self.values = values
        self.classification = classification
        self.left = None
        self.right = None
        self.gini = None


class DecisionTree:
    """
    A crude implentation of a decision tree that can either use the Gini index
    or entropy (information gain).
    """

    def __init__(self, max_depth=1):
        """
        """
        self.root = None
        self.max_depth = max_depth

    def train(self, X, y, method='gini'):
        """
        """
        pass

    def predict():
        """
        """
        pass

    def _gini(self, labels):
        """
        Calculates the gini impurity for a set of labels.
        """
        total = len(labels)
        label_counts = Counter(labels).values()
        return 1 - sum((p / total)**2
                       for p in label_counts
                       if p)

    def _entropy(self, labels):
        """
        Calculates entropy for a set of labels.
        """
        total = len(labels)
        label_counts = Counter(labels).values()
        return -sum((p / total) * log2(p / total)
                    for p in label_counts
                    if p)

    def _id3():
        """
        Builds a decision tree using the ID3 algorithm.
        """
        pass

    def _cart(self, labeled_data):
        """
        Classification and Regression Tree implementation.
        """
        # Starts with all the data and runs the cost function on each feature to get the best starting split.
        # 

        self._depth = 0
        labels = [label[1] for label in labeled_data]
        lowest_cost = float('inf')
        for feature in labeled_data[0][0].keys():
            gini_calculations = self._gini_cost(labeled_data, feature)
            if gini_calculations[0] < lowest_cost:
                lowest_cost, threshold, left_samples, right_samples = gini_calculations

        if self.root is None:
            import pdb; pdb.set_trace()
            self.root = Node(threshold, len(labeled_data), labeled_data, left_samples[0][1], self._gini(labels))
            node = self.root
        else:
            node = Node(threshold, len(labeled_data), labeled_data, left_samples[0][1], self._gini(labels))

        if self._depth < self.max_depth:
            self._depth += 1
            node.left = self._cart(left_samples)
            node.right = self._cart(right_samples)
        else:
            return node

    def _gini_cost(self, labeled_data, feature_name):
        """
        Calculate the cost function as part of the CART algorithm.
        """
        minim = []
        cost_min = float('inf')
        len_labeled_data = len(labeled_data)
        feature_data = [row[0][feature_name] for row in labeled_data]
        max_feature = int(max(feature_data)) * 100
        min_feature = int(min(feature_data)) * 100
        span = (i / 100 for i in range(min_feature, max_feature + 1))
        for x in span:
            left = []
            right = []
            for i, num in enumerate(feature_data):
                if num <= x:
                    left.append(i)
                else:
                    right.append(i)
            left_labels = [labeled_data[idx][1] for idx in left]
            right_labels = [labeled_data[idx][1] for idx in right]
            cost = len(left_labels) * self._gini(left_labels) / len_labeled_data + len(right_labels) * self._gini(right_labels) / len_labeled_data
            if cost < cost_min:
                cost_min = cost
                minim = []
                minim.append(x)
                left_samples = [labeled_data[idx] for idx in left]
                right_samples = [labeled_data[idx] for idx in right]
            elif cost == cost_min:
                minim.append(x)
        avg_minimums = sum(minim) / len(minim)
        return cost_min, avg_minimums, left_samples, right_samples


'\nModule containing the DecisionTree class.\n'

In [57]:
dct = DecisionTree()

In [58]:
dct.max_depth = 2

In [59]:
dct._cart(zipped_iris_data)

> <ipython-input-56-c4dfdef07c23>(91)_cart()
-> self.root = Node(threshold, len(labeled_data), labeled_data, left_samples[0][1], self._gini(labels))
(Pdb) exit


BdbQuit: 