In [1]:
training_data = [["Green",3,"Mango"],["Yellow",3,"Mango"],['Red',1,'Grape'],['Red',1,'Grape'],['Yellow',3,'Lemon']]

In [2]:
#column labels 
header = ['Color','Diameter','Label']

In [3]:
def unique_vals(rows, col): 
    return set([row[col] for row in rows])

In [4]:
def class_counts(rows): 
    """Counts the number of each type of example in a dataset"""
    counts = {} 
    for row in rows: 
        #in our dataset format, the label is the last column 
        label = row[-1]
        if label not in counts: 
            counts[label] = 0
        counts[label] +=1
    return counts

In [5]:
def is_numeric(value): 
    """test if a value is numeric or not """
    return isinstance(value, int) or isinstance(value,float)

In [6]:
class Question: 
    """A question is used to partition a dataset
    This class just records a column number e.g. 0 for color and 'column value' e.g. Green. 
    The 'match' method is used to compare the feature value in an example to the feature 
    value stored in the question."""
    def __init__(self, column, value): 
        self.column = column 
        self.value = value 
    def match(self, example): 
        #compare the feature value in an example to the feature value in this question 
        val = example[self.column]
        if is_numeric(val): 
            return val>=self.value
        else: 
            return val == self.value
    def __repr__(self): 
        #helper function to print the question to a readable format
        condition = "=="
        if is_numeric(self.value): 
            condition = ">="
        return "Is %s %s %s" %(header[self.column],condition, str(self.value))
    
    
    

In [7]:
def partition(rows, question): 
    """used to partition a dataset 
    
    For each row in the dataset, check if it matches the question. If so,
    add it to 'true rows'; otherwise, add it to 'false rows'....."""
    
    true_rows, false_rows = [],[]
    for row in rows: 
        if question.match(row): 
            true_rows.append(row)
        else: 
            false_rows.append(row)
    return true_rows, false_rows

In [8]:
def gini(rows): 
    """calculate the Gini impurity for a list of rows"""
    counts = class_counts(rows)
    impurity = 1
    for lbl in counts: 
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl ** 2
    return impurity 



In [9]:
def info_gain(left, right, current_uncertainity): 
    """Information Gain. 
    The uncertainity of the starting node, minus the weighted impurity of the two child nodes"""
    p = float(len(left)) / (len(left)+ len(right))
    return current_uncertainity - p * gini(left) -  (1-p) * gini(right)

In [10]:
def find_best_split(rows): 
    """Find the best question to ask by iterating over every feature / value and calculating information gain """
    best_gain = 0 
    best_question = None 
    current_uncertainity = gini(rows)
    n_features = len(rows[0])
    
    for col in range(n_features): 
        values = set([row[col] for row in rows])
        for val in values: 
            question = Question(col, val)
            #try splitting the dataset 
            true_rows, false_rows = partition(rows, question)
            #skip this split if it doesn't divide the dataset 
            if len(true_rows) == 0 or len(false_rows) == 0: 
                continue 
                
            #calculate the information gain from this split 
            gain = info_gain(true_rows, false_rows, current_uncertainity)
            
            if gain >= best_gain: 
                best_gain, best_question = gain, question 
    return best_gain, best_question
            

In [11]:
class Leaf: 
    """A leaf node classifies the data
    This holds a dictionary of class (e.g. "Mango") --> number of times it appears in rows 
    from training data that reach this leaf"""
    def __init__(self,rows): 
        self.predictions = class_counts(rows)
class Decision_Node(): 
    """A decision Node asks a question. This holds a reference to the quesiton, and to the child nodes"""
    def __init__(self, question,true_branch,false_branch):
        self.question = question 
        self.true_branch = true_branch 
        self.false_branch = false_branch 

            
    
    

In [12]:
def build_tree(rows): 
    """Bulids the tree """
    #Try partitioning the dataset on each of the unique attribute 
    #calculate the information gain 
    #and return the question that produces the highest gain 
    
    gain, question = find_best_split(rows)
    
    #base case: no further info gain 
    #since we can ask no further question, 
    #we'll return a leaft 
    if gain == 0 : 
        return Leaf(rows)
    # if we reach here, we have found a useful feature/ value to partition on 
    true_rows, false_rows = partition(rows, question)
    
    #recursively build the true branch 
    true_branch = build_tree(true_rows)
    
    #recursively build the false branch 
    false_branch = build_tree(false_rows)
    
    #return a question node 
    #this records the best feature / value to ask at this point, 
    #as well as the branches to follow 
    #depending on the answer 
    return Decision_Node(question, true_branch, false_branch)
    

In [13]:
def print_tree(node , spacing = " "): 
    
    #Base case: we've reached a leaf 
    if isinstance(node, Leaf): 
        print(spacing + "Predict", node.predictions)
        return 
    #print the question at this node 
    print(spacing + str(node.question))
    
    #call this function recursively on the true branch 
    print(spacing + '--> True:')
    print_tree(node.true_branch, spacing + "  ")
    
    #call this function recursively on the false branch 
    print(spacing + '--> False:')
    print_tree(node.false_branch, spacing + "  ")
    

In [14]:
def classify(row, node): 
    #base case: we have reached a leaf 
    if isinstance(node, Leaf): 
        return node.predictions 
    
    #decide whether to follow the true branch or the false branch 
    #compare the feature/ value stored in the node to the example we're considering 
    if node.question.match(row): 
        return classify(row, node.true_branch)
    else: 
        return classify(row, node.false_branch)

In [15]:
def print_leaf(counts): 
    """prints the predictions at a leaf"""
    total = sum(counts.values()) * 1.0
    probs = {}
    for lbl in counts.keys(): 
        probs[lbl] = str(int(counts[lbl]/total * 100)) + '%'
    return probs

In [16]:
my_tree = build_tree(training_data)

In [17]:
print_tree(my_tree)

 Is Label == Grape
 --> True:
   Predict {'Grape': 2}
 --> False:
   Is Label == Mango
   --> True:
     Predict {'Mango': 2}
   --> False:
     Predict {'Lemon': 1}


In [18]:
testing_data = [['Green',3,'Mango'],['Yellow',4,'Mango'],['Red',2,'Grape'],['Red',1,'Grape'],['Yellow',3,'Lemon']]

In [19]:
for row in testing_data: 
    print("Actual : %s, Predicted : %s"%(row[-1],print_leaf(classify(row,my_tree))))

Actual : Mango, Predicted : {'Mango': '100%'}
Actual : Mango, Predicted : {'Mango': '100%'}
Actual : Grape, Predicted : {'Grape': '100%'}
Actual : Grape, Predicted : {'Grape': '100%'}
Actual : Lemon, Predicted : {'Lemon': '100%'}
