In [1]:
# Toy fruit dataset. Each row is [color, size, label]; the label (fruit type)
# sits in the rightmost column, which is where tally() reads it from.
training_data = [
    list(record)
    for record in (
        ('Green', 3, 'Apple'),
        ('Yellow', 3, 'Apple'),
        ('Red', 1, 'Grape'),
        ('Red', 1, 'Grape'),
        ('Yellow', 3, 'Lemon'),
    )
]

In [13]:
def distinct(rows, col):
    """Return the set of distinct values found in column ``col`` of ``rows``.

    ``rows`` is a sequence of indexable rows; every value in the column must
    be hashable.
    """
    return {row[col] for row in rows}

In [14]:
def tally(rows):
    """Count how many rows carry each label.

    The label of a row is its rightmost column.

    Returns:
        dict: ``{label: count}`` pairs, in order of each label's first appearance.
    """
    from collections import Counter  # local import keeps this cell self-contained

    # Convert back to a plain dict so callers get exactly the type they did before.
    return dict(Counter(row[-1] for row in rows))

In [2]:
class Tester:
    """A yes/no question that decides which of two lists a row belongs to.

    The question compares one column of a row against a stored value.
    Real numbers (anything registered as ``numbers.Real``: int, float,
    Fraction, NumPy integer/float scalars, ...) are compared with ``>=``;
    every other value is compared with ``==``.
    """

    def __init__(self, col, val):
        self.col = col  # index of a column in a matrix; identifies a 'variable'
        self.val = val  # the value that column is compared against

    @staticmethod
    def _is_numeric(value):
        """Return True if ``value`` should be compared with ``>=`` rather than ``==``."""
        import numbers  # local import keeps this cell self-contained

        return isinstance(value, numbers.Real)

    def passes(self, test_case):
        """Return True if ``test_case`` (a row) answers this question with "yes"."""
        test_val = test_case[self.col]
        if self._is_numeric(self.val):
            # Numeric features split on a threshold.
            return test_val >= self.val
        # Categorical features split on equality.
        return test_val == self.val

    def __repr__(self):
        # Describe the actual test being applied.
        test = ">=" if self._is_numeric(self.val) else "=="
        return f"Test whether {self.val!s} matches column {self.col}, using {test}"

In [4]:
def split(data_rows, test):
    """Partition ``data_rows`` into two 'child' datasets using ``test``.

    Every row for which ``test.passes(row)`` is truthy lands in the first
    list; every other row lands in the second. Relative row order is kept.

    Returns:
        tuple: ``(true_rows, false_rows)``.
    """
    matched, unmatched = [], []
    for candidate in data_rows:
        bucket = matched if test.passes(candidate) else unmatched
        bucket.append(candidate)
    return matched, unmatched

In [12]:
# Partition the toy rows on "column 0 == 'Yellow'" and echo both halves.
true_rows, false_rows = split(data_rows=training_data, test=Tester(col=0, val='Yellow'))
(true_rows, false_rows)

([['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']],
 [['Green', 3, 'Apple'], ['Red', 1, 'Grape'], ['Red', 1, 'Grape']])

In [19]:
def gini_impurity(rows):
    """Return the Gini impurity of the labels in ``rows``.

    Gini impurity is ``1 - sum(p_k ** 2)``, where ``p_k`` is the fraction of
    rows whose label (rightmost column, as counted by ``tally``) is ``k``.
    A pure set scores 0.0; the score rises as labels become more mixed.

    An empty ``rows`` returns 0.0: with no samples there is nothing to
    misclassify. Previously it returned 1, a value Gini impurity cannot reach.
    """
    total = len(rows)
    if total == 0:
        return 0.0

    label_counts = tally(rows)

    impurity = 1.0
    # Subtract the squared probability of each label.
    for count in label_counts.values():
        impurity -= (count / total) ** 2

    return impurity

# Sanity checks: a pure set has impurity 0.0; two different labels give 0.5.
# Only the value of the last expression is echoed by the notebook.
gini_impurity([['foo']] * 2)
gini_impurity([['baz'], ['bar']])


0.5

In [20]:
def _split_gain(true_rows, false_rows, uncertainty):
    """Return how much a two-way split lowers Gini impurity.

    Each child's impurity is weighted by its share of the combined rows and
    subtracted from the parent's ``uncertainty``.
    """
    total = len(true_rows) + len(false_rows)
    p_true = len(true_rows) / total
    return (uncertainty
            - p_true * gini_impurity(true_rows)
            - (1 - p_true) * gini_impurity(false_rows))


def best_split(rows):
    """Find the feature/value question that best splits ``rows``.

    Every (feature column, observed value) pair is tried as a ``Tester``.
    The pair whose split yields the largest reduction in Gini impurity
    (information gain) is kept. The label is the rightmost column and is
    never used as a feature.

    Returns:
        tuple: ``(best_gain, best_test_q)``. ``best_test_q`` is None (and
        ``best_gain`` is 0) when ``rows`` is empty or no split improves on
        the parent's impurity.
    """
    # Track the best gain and the question that produced it.
    best_gain = 0
    best_test_q = None  # keep track of the feature / value that produced it

    if not rows:
        return best_gain, best_test_q

    uncertainty = gini_impurity(rows)
    feature_count = len(rows[0]) - 1  # rightmost column is the label

    for col in range(feature_count):
        # First-seen order (rather than set order) keeps tie-breaking between
        # equally good questions deterministic across interpreter runs.
        candidate_vals = dict.fromkeys(row[col] for row in rows)

        for val in candidate_vals:
            test_question = Tester(col, val)
            true_rows, false_rows = split(rows, test_question)

            # A split that leaves one side empty separates nothing; skip it.
            if not true_rows or not false_rows:
                continue

            gain = _split_gain(true_rows, false_rows, uncertainty)
            if gain > best_gain:
                best_gain, best_test_q = gain, test_question

    return best_gain, best_test_q

In [6]:
import numpy as np

# File name of a banknote-authentication CSV dataset.
# NOTE(review): not read anywhere in the visible cells — confirm where it is used.
test_data_file = 'data_banknote_authentication.csv'

class DecisionTreeClassifier:
    """Skeleton of a decision-tree classifier (work in progress).

    Intended design: build the tree with Classification And Regression Tree
    (CART), measuring node uncertainty with Gini impurity. Input is a
    2-dimensional list of 'rows'; in each row every element except the
    rightmost is a 'feature' used to predict the 'class'.

    Current behaviour: the constructor only stores its arguments; ``fit`` and
    ``predict`` are placeholders that neither build nor query a tree yet.
    """

    def __init__(self, x, Y, labels):
        # Stored as given; nothing is validated or copied.
        self.x, self.Y, self.labels = x, Y, labels

    def fit(self, X_matrix, y_vector):
        """Placeholder for tree construction.

        ``y_vector`` is currently ignored and ``X_matrix`` is returned
        unchanged (the same object).
        """
        return X_matrix

    def predict(self):
        """Placeholder for classification; currently takes no input and returns None.

        TODO: accept an input vector and return the predicted class.
        """
        return None

In [7]:


# Separate the toy rows into feature columns (color, size) and the label column.
X_matrix = [sample[:2] for sample in training_data]
y_vector = [sample[2] for sample in training_data]

labels = ['color', 'size', 'fruit type']

foo = DecisionTreeClassifier(x=X_matrix, Y=y_vector, labels=labels)

foo.labels

['color', 'size', 'fruit type']

1.5