In [14]:
import numpy as np
import pandas as pd
from collections import Counter

In [17]:
def entropy(s):
    """Return the Shannon entropy, in bits, of the class labels in ``s``.

    ``s`` is handed straight to ``np.bincount``, so it must be a non-empty
    sequence of non-negative integers. Classes that never occur contribute
    nothing to the sum, which avoids evaluating ``log2(0)``.
    """
    class_shares = np.bincount(s) / len(s)

    # Accumulate p * log2(p) over the classes that actually occur.
    weighted_logs = (share * np.log2(share) for share in class_shares if share > 0)

    return -sum(weighted_logs)



In [18]:
# Sanity check: ten labels, seven of class 0 and three of class 1 (shares 0.7 / 0.3).
s = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]

entropy(s)

0.8812908992306927

In [21]:
def information_gain(parent, left_child, right_child):
    """Return how much a binary split reduces entropy.

    The result is the entropy of ``parent`` minus the entropies of the two
    children, each weighted by the fraction of ``parent``'s elements it holds.
    """
    parent_entropy = entropy(parent)

    # Each child's weight is its share of the parent's elements.
    left_weight = len(left_child) / len(parent)
    right_weight = len(right_child) / len(parent)

    left_term = left_weight * entropy(left_child)
    right_term = right_weight * entropy(right_child)

    return parent_entropy - left_term - right_term



In [22]:
# Worked example: 20 parent labels (12 zeros, 8 ones) divided into a 12-element
# left child (10 zeros, 2 ones) and an 8-element right child (4 zeros, 4 ones).
parent = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
left_child = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
right_child = [0, 0, 0, 0, 1, 1, 1, 1]


information_gain(parent,left_child,right_child)

0.1809371414656561

In [23]:
class Node:
    """Plain record describing one node of a decision tree.

    Every field defaults to ``None``. ``DecisionTree`` creates internal nodes
    with ``feature``, ``threshold``, ``data_left``, ``data_right`` and ``gain``
    set, and leaves with only ``value`` set. Despite their names,
    ``data_left`` and ``data_right`` hold the child ``Node`` objects that
    ``DecisionTree`` stores there.
    """

    def __init__(self, feature=None, threshold=None, data_left=None, data_right=None, gain=None, value=None):
        # Split rule: column index and the cut-off it is compared against.
        self.feature, self.threshold = feature, threshold
        # Subtrees for the two outcomes of the split rule.
        self.data_left, self.data_right = data_left, data_right
        # Gain recorded for the split, and the payload carried by a leaf.
        self.gain, self.value = gain, value
        

In [52]:
class DecisionTree:
    """Binary decision-tree classifier grown greedily by information gain.

    Splits have the form ``x[feature] <= threshold``. Labels must be
    non-negative and integer-valued, because entropy is computed with
    ``np.bincount`` after casting to ``int64``. During training the labels
    are stacked next to the features in a single array, so predictions that
    come from a split tree carry that array's dtype (e.g. ``1.0`` rather than
    ``1``).

    Parameters
    ----------
    min_samples_split : int
        A node holding fewer rows than this becomes a leaf.
    max_depth : int
        Maximum number of split levels between the root and any leaf.
    """

    def __init__(self, min_samples_split=2, max_depth=5):
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.root = None

    @staticmethod
    def _entropy(s):
        """Return the Shannon entropy, in bits, of the label sequence ``s``."""
        counts = np.bincount(np.array(s, dtype=np.int64))
        percentage = counts / len(s)

        entropy = 0
        for pct in percentage:
            # Classes that never occur are skipped so log2(0) is never evaluated.
            if pct > 0:
                entropy += pct * np.log2(pct)

        return -entropy

    def _information_gain(self, parent, left_child, right_child):
        """Return the parent's entropy minus the size-weighted child entropies."""
        weight_left = len(left_child) / len(parent)
        weight_right = len(right_child) / len(parent)

        return (self._entropy(parent)
                - weight_left * self._entropy(left_child)
                - weight_right * self._entropy(right_child))

    def _best_split(self, X, y):
        """Search every feature and observed value for the highest-gain split.

        Returns a dict with keys ``feature_index``, ``threshold``,
        ``df_right``, ``df_left`` and ``gain``. ``df_left`` and ``df_right``
        are the rows of ``X`` with the label appended as the last column.
        Returns an empty dict when no threshold leaves rows on both sides.
        On equal gain the first candidate found is kept.
        """
        X = np.asarray(X)
        best_split = {}
        best_info_gain = -1
        n_cols = X.shape[1]

        # The stacked array does not depend on the candidate, so build it once.
        df = np.concatenate((X, np.asarray(y).reshape(-1, 1)), axis=1)
        labels = df[:, -1]

        for f_idx in range(n_cols):
            X_curr = X[:, f_idx]

            for threshold in np.unique(X_curr):
                goes_left = X_curr <= threshold
                goes_right = X_curr > threshold

                # A split that leaves one side empty separates nothing.
                if not goes_left.any() or not goes_right.any():
                    continue

                df_left = df[goes_left]
                df_right = df[goes_right]

                gain = self._information_gain(labels, df_left[:, -1], df_right[:, -1])

                if gain > best_info_gain:
                    best_split = {'feature_index': f_idx,
                                  'threshold': threshold,
                                  'df_right': df_right,
                                  'df_left': df_left,
                                  'gain': gain
                                  }
                    best_info_gain = gain

        return best_split

    def _build(self, X, y, depth=0):
        """Recursively grow the subtree for the rows ``X`` / labels ``y``.

        A node becomes a leaf holding the most common label when it has fewer
        than ``min_samples_split`` rows, when ``depth`` has reached
        ``max_depth``, or when no split with positive gain exists.
        """
        n_rows = X.shape[0]

        # Only pay for the split search when this node is allowed to split.
        # ``depth < max_depth`` keeps every leaf at most max_depth levels
        # below the root.
        if n_rows >= self.min_samples_split and depth < self.max_depth:
            best = self._best_split(X, y)

            # ``best`` is empty when no threshold separates the rows
            # (e.g. every feature is constant); treat that as "no useful split".
            if best and best['gain'] > 0:

                left = self._build(
                    X=best['df_left'][:, :-1],
                    y=best['df_left'][:, -1],
                    depth=depth + 1)

                right = self._build(
                    X=best['df_right'][:, :-1],
                    y=best['df_right'][:, -1],
                    depth=depth + 1)

                return Node(
                    feature=best['feature_index'],
                    threshold=best['threshold'],
                    data_left=left,
                    data_right=right,
                    gain=best['gain'])

        return Node(value=Counter(y).most_common(1)[0][0])

    def fit(self, X, y):
        """Grow the tree from feature matrix ``X`` (2-D) and labels ``y``.

        Raises
        ------
        ValueError
            If the training set is empty or ``X`` and ``y`` differ in length.
        """
        X = np.asarray(X)
        y = np.asarray(y).ravel()

        if len(X) == 0:
            raise ValueError("Cannot fit a decision tree on an empty training set.")
        if len(X) != len(y):
            raise ValueError(
                "X and y must contain the same number of rows "
                "(got {} and {}).".format(len(X), len(y)))

        self.root = self._build(X, y)

    def _predict(self, x, tree):
        """Walk from ``tree`` down to a leaf for the sample ``x`` and return its value."""
        node = tree

        # ``is None`` rather than ``== None``: leaf values may be numpy scalars.
        while node.value is None:
            # Anything not <= threshold goes right, so a NaN feature value
            # still reaches a leaf instead of yielding None.
            if x[node.feature] <= node.threshold:
                node = node.data_left
            else:
                node = node.data_right

        return node.value

    def predict(self, X):
        """Return a list with one predicted label per row of ``X``."""
        return [self._predict(x, self.root) for x in np.asarray(X)]
        
            
                    
        
        
                    
        
        

In [53]:
from sklearn.datasets import load_iris

# Load scikit-learn's bundled iris dataset and pull out its 'data' and
# 'target' entries as the inputs X and the labels y.
iris = load_iris()

X = iris['data']
y = iris['target']

In [54]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for testing; the fixed random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [55]:
# Fit the from-scratch tree with its default hyperparameters
# (min_samples_split=2, max_depth=5) and predict the held-out rows.
model = DecisionTree()
model.fit(X_train, y_train)
preds = model.predict(X_test)

In [57]:
# Inspect the split chosen on the full training set: feature index, threshold,
# gain, and the two row partitions (label appended as the last column).
model._best_split(X_train,y_train)

{'feature_index': 2,
 'threshold': 1.9,
 'df_right': array([[6.7, 3.1, 4.4, 1.4, 1. ],
        [6.3, 2.5, 5. , 1.9, 2. ],
        [6.4, 3.2, 4.5, 1.5, 1. ],
        [5.8, 2.7, 5.1, 1.9, 2. ],
        [6. , 3.4, 4.5, 1.6, 1. ],
        [6.7, 3.1, 4.7, 1.5, 1. ],
        [5.5, 2.4, 3.7, 1. , 1. ],
        [6.3, 2.8, 5.1, 1.5, 2. ],
        [6.4, 3.1, 5.5, 1.8, 2. ],
        [6.6, 3. , 4.4, 1.4, 1. ],
        [7.2, 3.6, 6.1, 2.5, 2. ],
        [5.7, 2.9, 4.2, 1.3, 1. ],
        [7.6, 3. , 6.6, 2.1, 2. ],
        [5.6, 3. , 4.5, 1.5, 1. ],
        [7.7, 2.8, 6.7, 2. , 2. ],
        [5.8, 2.7, 4.1, 1. , 1. ],
        [5. , 2. , 3.5, 1. , 1. ],
        [6.3, 2.7, 4.9, 1.8, 2. ],
        [5.6, 2.7, 4.2, 1.3, 1. ],
        [5.7, 3. , 4.2, 1.2, 1. ],
        [7.7, 3.8, 6.7, 2.2, 2. ],
        [6.2, 2.9, 4.3, 1.3, 1. ],
        [5.7, 2.5, 5. , 2. , 2. ],
        [6. , 3. , 4.8, 1.8, 2. ],
        [5.8, 2.7, 5.1, 1.9, 2. ],
        [6. , 2.2, 4. , 1. , 1. ],
        [5.4, 3. , 4.5, 1.5, 1. ],
   