In [13]:
import numpy as np
import pandas as pd
from collections import Counter

In [14]:
class Node:
    """A single node of a binary decision tree.

    Internal nodes carry ``feature``/``threshold`` and two children; leaf
    nodes carry only ``value`` (majority class for classification, mean
    target for regression).

    Args:
        feature: Column index the node splits on (internal nodes only).
        threshold: Samples with ``x[feature] <= threshold`` go to ``left``.
        left: Child node for the ``<= threshold`` side.
        right: Child node for the ``> threshold`` side.
        value: Prediction stored in a leaf; ``None`` for internal nodes.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature
        self.threshold = threshold
        self.right = right
        self.left = left
        # Store the value passed in; hard-coding None here would erase every
        # leaf's prediction.
        self.value = value

    def is_leaf(self):
        """Return True when the node has no children."""
        return self.left is None and self.right is None

In [35]:
class DecisionTree:
    """Binary decision tree for classification or regression.

    Splits are axis-aligned: at each internal node one feature is compared
    against a threshold placed midway between two consecutive distinct
    values of that feature.

    Args:
        task: ``'classification'`` (leaf = majority class) or
            ``'regression'`` (leaf = mean target).
        criterion: Impurity used for classification splits. ``'entropy'``
            selects Shannon entropy; any other value uses Gini impurity.
            Regression always uses the variance of the targets.
        max_depth: Maximum depth of the tree; ``None`` means unlimited.
        min_samples_split: A node with fewer samples than this is not split.
        min_samples_leaf: A split is only considered if both children keep
            at least this many samples.

    Raises:
        ValueError: If ``task`` is not one of the two supported values.
    """

    def __init__(self, task='classification', criterion='gini', max_depth=None, min_samples_split=2, min_samples_leaf=1):
        if task not in ("classification", "regression"):
            raise ValueError(
                f"task must be 'classification' or 'regression', got {task!r}"
            )
        self.task = task
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.root = None  # set by fit()

    def fit(self, X, y):
        """Grow the tree on the training data.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: 1-D array-like of targets, one per row of ``X``.

        Raises:
            ValueError: If ``X`` is not 2-D, is empty, or its row count does
                not match ``y``.
        """
        # Accept lists / DataFrames as well as ndarrays.
        X = np.asarray(X)
        y = np.ravel(y)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        if X.shape[0] == 0:
            raise ValueError("Cannot fit a tree on an empty dataset")
        if X.shape[0] != y.shape[0]:
            raise ValueError("X and y must contain the same number of samples")
        self.X = X
        self.y = y
        self.n_features = X.shape[1]
        self.root = self.build_tree(self.X, self.y, depth=0)

    def build_tree(self, X, y, depth):
        """Recursively build the subtree for the samples ``(X, y)``.

        Returns a leaf when the depth limit is reached, the node is too small
        to split, the targets are already pure, or no admissible split exists.
        """
        if (
            (self.max_depth is not None and depth >= self.max_depth)
            or len(y) < self.min_samples_split
            or np.unique(y).size <= 1
        ):
            return self._make_leaf(y)

        best_feature, best_threshold = self.find_best_split(X, y)
        if best_feature is None:
            # Every feature is constant, or min_samples_leaf rules out all
            # candidate thresholds.
            return self._make_leaf(y)

        X_left, y_left, X_right, y_right = self.split_data(X, y, best_feature, best_threshold)
        if len(y_left) == 0 or len(y_right) == 0:
            # Defensive: never recurse into an empty child.
            return self._make_leaf(y)

        left_subtree = self.build_tree(X_left, y_left, depth + 1)
        right_subtree = self.build_tree(X_right, y_right, depth + 1)
        return Node(feature=best_feature, threshold=best_threshold, left=left_subtree, right=right_subtree)

    def _leaf_value(self, y):
        """Return the prediction for a leaf holding targets ``y``."""
        if len(y) == 0:
            return None
        if self.task == "regression":
            return float(np.mean(y))
        return Counter(y).most_common(1)[0][0]

    def _make_leaf(self, y):
        """Create a leaf node predicting the aggregate of ``y``."""
        value = self._leaf_value(y)
        leaf = Node(value=value)
        # Assign explicitly too, so the leaf carries its value even if
        # Node.__init__ does not store the ``value`` argument.
        leaf.value = value
        return leaf

    def _entropy(self, y):
        """Shannon entropy (base 2) of the labels in ``y``; 0.0 when empty."""
        if len(y) == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / counts.sum()
        return float(-np.sum(probabilities * np.log2(probabilities)))

    def _impurity(self, y):
        """Impurity of ``y`` according to the configured task and criterion."""
        if self.task == "regression":
            return float(np.var(y)) if len(y) else 0.0
        if self.criterion == "entropy":
            return self._entropy(y)
        return self.Gini(y)

    def find_best_split(self, X, y):
        """Search all features and midpoints for the lowest weighted impurity.

        Returns:
            ``(feature_index, threshold)`` of the best split, or
            ``(None, None)`` when no split leaves at least
            ``min_samples_leaf`` samples on both sides.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        n_samples = len(y)
        min_leaf = max(1, self.min_samples_leaf)
        best_impurity = float('inf')
        best_feature = None
        best_threshold = None
        for feature in range(X.shape[1]):
            column = X[:, feature]
            # np.unique already returns sorted values.
            unique_values = np.unique(column)
            # Midpoints between consecutive distinct values are the candidates.
            midpoints = (unique_values[:-1] + unique_values[1:]) / 2
            for threshold in midpoints:
                left = column <= threshold
                n_left = int(np.count_nonzero(left))
                n_right = n_samples - n_left
                if n_left < min_leaf or n_right < min_leaf:
                    continue
                # Only y is needed to score a split; avoid copying X here.
                weighted = ((n_left / n_samples) * self._impurity(y[left])
                            + (n_right / n_samples) * self._impurity(y[~left]))
                if weighted < best_impurity:
                    best_impurity = weighted
                    best_feature = feature
                    best_threshold = threshold
        return best_feature, best_threshold

    def split_data(self, X, y, feature, threshold):
        """Partition ``(X, y)`` by ``X[:, feature] <= threshold``.

        Returns:
            ``(X_left, y_left, X_right, y_right)``. Rows that fail the
            comparison (including NaN) go right, matching ``predict``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        left = X[:, feature] <= threshold
        right = ~left
        return X[left], y[left], X[right], y[right]

    def Gini(self, y):
        """Gini impurity of the labels in ``y``; 0.0 for an empty array."""
        if len(y) == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / counts.sum()
        return 1 - np.sum(probabilities ** 2)

    def predict(self, X):
        """Predict a target for every row of ``X``.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).

        Returns:
            1-D array of predictions: ``dtype=object`` for classification
            (labels may be of any type), ``float`` for regression.

        Raises:
            AttributeError: If the tree has not been fitted yet.
        """
        if getattr(self, "root", None) is None:
            raise AttributeError("DecisionTree is not fitted yet; call fit() first")
        X = np.asarray(X)
        dtype = float if self.task == "regression" else object
        predictions = np.empty(X.shape[0], dtype=dtype)
        for i, row in enumerate(X):
            node = self.root
            # Walk down until a node without children (a leaf) is reached.
            while node.left is not None and node.right is not None:
                if row[node.feature] <= node.threshold:
                    node = node.left
                else:
                    node = node.right
            predictions[i] = node.value
        return predictions
