<a href="https://colab.research.google.com/github/chandini2595/DecisionTrees_Ensemble_Methods/blob/main/DecisionTree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from collections import Counter

In [2]:
def gini_impurity(y):
    """Return the Gini impurity of a label array: ``1 - sum_k p_k**2``.

    ``p_k`` is the fraction of samples carrying label ``k``. Labels may be any
    sortable values (negative ints, strings, ...). The original
    ``np.bincount``-based version only accepted non-negative integers.

    Args:
        y: 1-D array-like of class labels.

    Returns:
        Impurity in ``[0, 1)``. 0 means every sample has the same label. An
        empty array is defined to have impurity 0.0, since there is nothing to mix.
    """
    y = np.asarray(y)
    if len(y) == 0:
        return 0.0
    # Counts of each distinct label present; absent classes contribute 0 anyway.
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    return 1 - np.sum(probabilities ** 2)

In [3]:
def entropy(y):
    """Return the Shannon entropy, in bits, of a label array.

    Computes ``-sum_k p_k * log2(p_k)`` over the labels actually present, so
    zero-probability classes never reach ``log2``. Labels may be any sortable
    values, not only the non-negative integers ``np.bincount`` accepts.

    Args:
        y: 1-D array-like of class labels.

    Returns:
        Entropy in bits. 0 for a pure array, and 0.0 for an empty array.
    """
    y = np.asarray(y)
    if len(y) == 0:
        return 0.0
    # np.unique only reports labels that occur, so every count is > 0.
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    return -np.sum(probabilities * np.log2(probabilities))

In [4]:
def information_gain(y, y_left, y_right, impurity_func):
    """Return the impurity reduction achieved by splitting ``y`` into two parts.

    gain = impurity(y) - w_left * impurity(y_left) - w_right * impurity(y_right),
    where each weight is that child's share of the parent's samples.

    Args:
        y: labels at the parent node (must be non-empty).
        y_left, y_right: labels routed to each child.
        impurity_func: callable mapping a label array to an impurity score.
    """
    parent_size = len(y)
    left_weight = len(y_left) / parent_size
    right_weight = 1 - left_weight
    parent_impurity = impurity_func(y)
    return (
        parent_impurity
        - left_weight * impurity_func(y_left)
        - right_weight * impurity_func(y_right)
    )

In [5]:
def split_dataset(X, y, feature_index, threshold):
    """Partition the rows of ``X`` (and matching labels ``y``) on one feature.

    Rows with ``X[:, feature_index] <= threshold`` go left. Every other row goes
    right, including rows whose feature value is NaN. This mirrors the
    prediction rule, where any value failing the ``<=`` test follows the right
    branch. Previously such rows were silently dropped from both sides.

    Args:
        X: 2-D feature array (rows are samples).
        y: 1-D label array aligned with the rows of ``X``.
        feature_index: column of ``X`` to test.
        threshold: split value.

    Returns:
        ``(X_left, X_right, y_left, y_right)``.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    left_mask = X[:, feature_index] <= threshold
    # Complement of the left mask, so every row lands on exactly one side.
    right_mask = ~left_mask
    return X[left_mask], X[right_mask], y[left_mask], y[right_mask]

In [6]:
def find_best_split(X, y, impurity_func):
    """Exhaustively search (feature, observed value) pairs for the best split.

    Every distinct value of every feature column is tried as a ``<=`` threshold.
    Candidates that leave either side empty are skipped. On ties the first
    candidate wins (lowest feature index, then lowest threshold, since
    ``np.unique`` returns sorted values).

    Returns:
        ``(feature_index, threshold)`` with the highest information gain, or
        ``None`` when no candidate yields two non-empty sides.
    """
    best = None
    top_gain = -1
    for col in range(X.shape[1]):
        for candidate in np.unique(X[:, col]):
            _, _, labels_left, labels_right = split_dataset(X, y, col, candidate)
            if not (len(labels_left) and len(labels_right)):
                continue
            score = information_gain(y, labels_left, labels_right, impurity_func)
            if score > top_gain:
                top_gain, best = score, (col, candidate)
    return best

In [7]:
class DecisionTree:
    """Binary decision-tree classifier grown by greedy recursive splitting.

    After ``fit``, ``self.tree`` is either a leaf (a class label) or a nested
    dict with keys ``"feature_index"``, ``"threshold"``, ``"left"`` and
    ``"right"``. A sample goes left when ``x[feature_index] <= threshold``.

    Args:
        max_depth: maximum depth of the tree. ``None`` means no depth limit.
        min_samples_split: nodes with fewer samples than this become leaves.
        impurity_func: callable mapping a label array to an impurity score,
            e.g. ``gini_impurity`` or ``entropy``.
    """

    def __init__(self, max_depth=None, min_samples_split=2, impurity_func=gini_impurity):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.impurity_func = impurity_func
        self.tree = None

    def fit(self, X, y, depth=0):
        """Grow a tree on ``(X, y)``, store it in ``self.tree`` and return it.

        Args:
            X: 2-D feature array (rows are samples).
            y: 1-D label array aligned with ``X``.
            depth: depth assigned to the root (normally 0).

        Returns:
            The fitted tree: a nested dict, or a single label if the root is a leaf.

        Raises:
            ValueError: if ``y`` is empty.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if len(y) == 0:
            raise ValueError("cannot fit a DecisionTree on an empty dataset")
        # Assign the root exactly once. Previously self.tree was only set when a
        # split happened, so a single-leaf tree left it None (or stale from an
        # earlier fit), and predict() then returned None.
        self.tree = self._build_tree(X, y, depth)
        return self.tree

    def _build_tree(self, X, y, depth):
        """Recursively build and return the subtree for ``(X, y)`` at ``depth``."""
        depth_reached = self.max_depth is not None and depth >= self.max_depth
        if depth_reached or len(y) < self.min_samples_split or len(np.unique(y)) == 1:
            return self._majority_label(y)

        split = find_best_split(X, y, self.impurity_func)
        # find_best_split returns None (not a tuple) when no threshold separates
        # the samples, e.g. identical feature rows with different labels.
        if split is None:
            return self._majority_label(y)

        feature_index, threshold = split
        X_left, X_right, y_left, y_right = split_dataset(X, y, feature_index, threshold)
        return {
            "feature_index": feature_index,
            "threshold": threshold,
            "left": self._build_tree(X_left, y_left, depth + 1),
            "right": self._build_tree(X_right, y_right, depth + 1),
        }

    @staticmethod
    def _majority_label(y):
        """Return the most frequent label in ``y``; ties go to the first one seen."""
        return Counter(y).most_common(1)[0][0]

    def predict_single(self, x, tree):
        """Route one sample ``x`` down ``tree`` and return the leaf label it reaches."""
        node = tree
        # Walk iteratively instead of recursing, so deep trees cannot hit the recursion limit.
        while isinstance(node, dict):
            branch = "left" if x[node["feature_index"]] <= node["threshold"] else "right"
            node = node[branch]
        return node

    def predict(self, X):
        """Return an array with one predicted label per row of ``X``."""
        return np.array([self.predict_single(x, self.tree) for x in X])

In [8]:
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Load Iris: `data.data` is the feature matrix, `data.target` the class labels.
data = load_iris()
X, y = data.data, data.target

# Hold out 20% of the rows for testing. random_state pins the shuffle, so the
# split (and therefore the accuracy printed below) is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Grow a depth-limited tree on the training rows only.
tree = DecisionTree(max_depth=5)
tree.fit(X_train, y_train)

# Score predictions on the held-out rows.
y_pred = tree.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

Accuracy: 1.0


In [9]:
def print_tree(tree, depth=0):
    """Print a nested-dict decision tree as an indented outline.

    Each internal node prints as ``Feature i <= t``, followed by its left
    subtree and then its right subtree. Each depth level adds one ``'|   '``
    of indentation. Leaves print as ``Predict: <label>``.
    """
    prefix = "|   " * depth
    if not isinstance(tree, dict):
        print(f"{prefix}Predict: {tree}")
        return
    print(f"{prefix}Feature {tree['feature_index']} <= {tree['threshold']}")
    for child in ("left", "right"):
        print_tree(tree[child], depth + 1)

# Render the fitted tree from its root, with no indentation at the top level.
print_tree(tree.tree, depth=0)

Feature 2 <= 1.9
|   Predict: 0
|   Feature 2 <= 4.7
|   |   Feature 3 <= 1.6
|   |   |   Predict: 1
|   |   |   Predict: 2
|   |   Feature 3 <= 1.7
|   |   |   Feature 2 <= 4.9
|   |   |   |   Predict: 1
|   |   |   |   Feature 3 <= 1.5
|   |   |   |   |   Predict: 2
|   |   |   |   |   Predict: 1
|   |   |   Feature 2 <= 4.8
|   |   |   |   Feature 0 <= 5.9
|   |   |   |   |   Predict: 1
|   |   |   |   |   Predict: 2
|   |   |   |   Predict: 2
