In [495]:

class Condition:
    """A yes/no question asked about one column of a data row.

    When the reference value is an ``int`` or ``float`` the question is
    "is the column >= value"; for any other value it is "is the column
    == value".
    """

    def __init__(self, val, column):
        """Remember the reference value and the column index it applies to."""
        self.column = column
        self.val = val

    def _operator(self):
        """Return the comparison symbol implied by the type of ``self.val``."""
        return '>=' if isinstance(self.val, (int, float)) else '=='

    def __repr__(self):
        """Render the question, e.g. ``Is data[1] >= 3``."""
        return 'Is data[%d] %s %s' % (self.column, self._operator(), self.val)

    def match(self, test_data):
        """Return whether the row ``test_data`` answers the question with yes."""
        candidate = test_data[self.column]
        if self._operator() == '>=':
            return candidate >= self.val
        return candidate == self.val

class decision_node:
    """Internal tree node: a condition plus one subtree per answer."""

    def __init__(self, condition, true_branch=None, false_branch=None, p=None):
        """Store the split condition, both child branches and the parent ``p``."""
        self.condition = condition
        self.true_branch, self.false_branch = true_branch, false_branch
        self.parent = p

    def __repr__(self):
        """Show the condition, followed by one line for each truthy branch."""
        lines = ['condition: %s' % self.condition]
        for label in ('true_branch', 'false_branch'):
            if getattr(self, label):
                lines.append('has ' + label)
        return '\n'.join(lines)

        
class decision_leaf:
    """Terminal tree node holding the label counts of the rows that reached it."""

    def __init__(self, data, p=None):
        """Count how often each label occurs in ``data``.

        Args:
            data: Iterable of hashable labels. It is traversed exactly once,
                so one-shot iterators and generators are counted correctly.
            p: Optional parent node, stored as ``self.parent``.
        """
        self.parent = p
        count = {}
        for label in data:
            count[label] = count.get(label, 0) + 1
        # The attribute name shadows the builtin ``dict``; it is kept
        # unchanged because it is this class's public attribute.
        self.dict = count

    def __repr__(self):
        """Return the label-count mapping rendered as a string."""
        return str(self.dict)
class decision_tree:
    """Binary classification tree grown greedily by Gini-impurity gain.

    Rows are sequences of feature values. During training the class label
    is appended as the last element of every row.
    """

    def __init__(self):
        pass

    def fit(self, X, y):
        """Build the tree from feature rows ``X`` and their labels ``y``.

        The labelled rows are stored in ``self.data_with_label`` and the
        resulting tree in ``self.root``.
        """
        # Transpose to columns, add the labels as one more column, then
        # transpose back so that every row ends with its label.
        columns = list(zip(*X))
        columns.append(y)
        self.data_with_label = list(zip(*columns))
        self.root = self.build_tree(self.data_with_label)

    def build_tree(self, data=None):
        """Recursively grow a subtree for the labelled rows in ``data``.

        Returns:
            A ``decision_leaf`` when no split has positive gain, otherwise
            a ``decision_node``. ``None`` when ``data`` is ``None``.
        """
        if data is None:
            return None
        _, best_condition = self.find_bast_condition(data)
        # A condition is only recorded for a strictly positive gain, so a
        # missing condition means the rows cannot be separated any further.
        if best_condition is None:
            return decision_leaf([row[-1] for row in data])
        true_set, false_set = self.split_by_condition(data, best_condition)
        true_subtree = self.build_tree(true_set)
        false_subtree = self.build_tree(false_set)
        return decision_node(
            best_condition,
            true_branch=true_subtree,
            false_branch=false_subtree,
        )

    def find_bast_condition(self, data_with_label):
        """Try every (column, value) pair and keep the highest-gain split.

        The spelling of the method name is kept for backward compatibility.

        Returns:
            ``(max_gain, best_condition)``. The result is ``(0, None)`` when
            no split with two non-empty sides has positive gain.
        """
        current_uncertainty = self.gini([row[-1] for row in data_with_label])
        n_features = len(data_with_label[0]) - 1
        max_gain = 0
        best_condition = None
        for column in range(n_features):
            for value in {row[column] for row in data_with_label}:
                candidate = Condition(value, column)
                true_rows, false_rows = self.split_by_condition(
                    data_with_label, candidate)
                # A split that sends every row to the same side separates
                # nothing, so it is not a usable question.
                if not true_rows or not false_rows:
                    continue
                gain = self.gini_gain(true_rows, false_rows, current_uncertainty)
                if gain > max_gain:
                    max_gain = gain
                    best_condition = candidate
        return max_gain, best_condition

    def gini(self, data):
        """Return the Gini impurity of the labels in ``data``.

        Computed as 1 minus the sum of squared label frequencies; an empty
        ``data`` yields 1 because nothing is subtracted.
        """
        total = float(len(data))
        impurity = 1
        for occurrences in self.gen_count_dict(data).values():
            impurity -= (occurrences / total) ** 2
        return impurity

    def gini_gain(self, left_data, right_data, current_uncertainty):
        """Return the impurity reduction achieved by a two-way split.

        ``left_data`` and ``right_data`` hold the two sides of the split;
        each side's impurity is weighted by its share of the elements.
        """
        p = float(len(left_data)) / float(len(left_data) + len(right_data))
        return (current_uncertainty
                - p * self.gini(left_data)
                - (1 - p) * self.gini(right_data))

    def gen_count_dict(self, data):
        """Return a dict mapping each value in ``data`` to its occurrence count."""
        count = {}
        for val in data:
            count[val] = count.get(val, 0) + 1
        return count

    def split_by_condition(self, data, condition):
        """Partition ``data`` into rows that match ``condition`` and rows that do not.

        Returns:
            ``(true_set, false_set)`` as two lists, each preserving the
            original row order.
        """
        true_set = []
        false_set = []
        for row in data:
            if condition.match(row):
                true_set.append(row)
            else:
                false_set.append(row)
        return true_set, false_set

    def predict(self, test_data, node=None):
        """Return the label distribution of the leaf that ``test_data`` reaches.

        Args:
            test_data: One row of feature values.
            node: Node to start from; defaults to the root of the tree.

        Returns:
            A dict mapping each label in the leaf to its relative frequency.
        """
        if node is None:
            node = self.root
        if isinstance(node, decision_leaf):
            total_this_leaf = sum(node.dict.values())
            return {label: occurrences / total_this_leaf
                    for label, occurrences in node.dict.items()}
        if node.condition.match(test_data):
            return self.predict(test_data, node.true_branch)
        return self.predict(test_data, node.false_branch)

    def __repr__(self):
        return ''

# Toy fruit dataset: every row is [color, size, label].
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

# Separate the two feature columns from the trailing label column.
X_train = [row[:2] for row in training_data]
y_train = [row[-1] for row in training_data]

dt = decision_tree()
dt.fit(X_train, y_train)

# Held-out rows in the same [color, size, label] layout.
testing_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 4, 'Apple'],
    ['Red', 2, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
X_test = [row[:2] for row in testing_data]
y_test = [row[-1] for row in testing_data]

# Print the predicted label distribution for each test sample.
for sample in X_test:
    print(dt.predict(sample))


{'Apple': 1.0}
{'Apple': 0.5, 'Lemon': 0.5}
{'Grape': 1.0}
{'Grape': 1.0}
{'Apple': 0.5, 'Lemon': 0.5}
