In [3]:
from sklearn.base import BaseEstimator, ClassifierMixin

class DecisionTree(BaseEstimator, ClassifierMixin):
    """Binary decision-tree classifier built from recursive threshold splits.

    Every instance represents a single node of the tree. A node is either a
    leaf (``value`` holds the predicted label) or an internal node
    (``feature``/``threshold`` describe the test ``X[:, feature] <= threshold``
    and ``left``/``right`` hold the child nodes).

    Parameters
    ----------
    depth : int, default=0
        Depth of this node inside the tree; the root uses 0.
    max_depth : int, default=3
        Nodes whose ``depth`` is greater than or equal to this value become
        leaves instead of being split further.
    """

    def __init__(self, depth=0, max_depth=3):
        self.depth = depth
        self.max_depth = max_depth
        # Fitted state; populated by ``fit``.
        self.feature = None    # column index tested by an internal node
        self.threshold = None  # rows with value <= threshold go to ``left``
        self.left = None       # child node for rows passing the test
        self.right = None      # child node for the remaining rows
        self.value = None      # predicted label when this node is a leaf

    @staticmethod
    def _majority_label(y):
        """Return the most frequent label in ``y`` (first seen wins ties)."""
        return Counter(y).most_common(1)[0][0]

    def fit(self, X, y):
        """Grow the subtree rooted at this node from the training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.
        y : array-like of shape (n_samples,)
            Target labels.

        Returns
        -------
        self : DecisionTree
            The fitted node.

        Raises
        ------
        ValueError
            If ``y`` is empty.
        """
        # Work on ndarrays so positional indexing behaves the same no matter
        # which array-like the caller passed in.
        X = np.asarray(X)
        y = np.asarray(y)
        if len(y) == 0:
            raise ValueError("Cannot fit a DecisionTree on an empty target array.")

        # Forget any previous fit; otherwise a stale ``value`` or stale
        # children would leak into predictions made after refitting.
        self.feature = None
        self.threshold = None
        self.left = None
        self.right = None
        self.value = None

        # Pure node: nothing left to separate.
        if len(set(y)) == 1:
            self.value = y[0]
            return self

        # Depth budget exhausted: fall back to the majority label.
        if self.depth >= self.max_depth:
            self.value = self._majority_label(y)
            return self

        n_features = X.shape[1]
        best_gain = -1
        best_feature = None
        best_threshold = None
        best_split = None

        for feature in range(n_features):
            values = X[:, feature]
            for threshold in np.unique(values):
                left_indices = np.where(values <= threshold)[0]
                right_indices = np.where(values > threshold)[0]
                # A split that leaves one side empty separates nothing.
                if len(left_indices) == 0 or len(right_indices) == 0:
                    continue
                # ``information_gain`` is a helper defined elsewhere in the
                # file; it is given the labels and the two index groups.
                gain = information_gain(y, [left_indices, right_indices])
                # Strict comparison: the first candidate reaching the best
                # score is kept.
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold
                    best_split = (left_indices, right_indices)

        # No candidate produced two non-empty sides: make this node a leaf.
        if best_split is None:
            self.value = self._majority_label(y)
            return self

        # Commit the split only now, so the search never leaves a partially
        # updated node behind.
        self.feature = best_feature
        self.threshold = best_threshold
        left_indices, right_indices = best_split
        self.left = DecisionTree(self.depth + 1, self.max_depth)
        self.right = DecisionTree(self.depth + 1, self.max_depth)
        self.left.fit(X[left_indices, :], y[left_indices])
        self.right.fit(X[right_indices, :], y[right_indices])
        return self

    def predict(self, X):
        """Predict a label for every row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to classify.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Predicted labels. The dtype follows the stored leaf labels, so
            string and float labels are returned unchanged.

        Raises
        ------
        ValueError
            If this node has not been fitted.
        """
        X = np.asarray(X)

        if self.value is not None:
            # ``np.full`` infers the dtype from the label itself, even when
            # ``X`` has no rows.
            return np.full(len(X), self.value)

        if self.left is None or self.right is None:
            raise ValueError("This DecisionTree is not fitted yet; call fit first.")

        # Boolean mask and its complement: every row is routed to exactly one
        # child. Rows whose feature value is NaN compare False and therefore
        # go to the right child.
        goes_left = X[:, self.feature] <= self.threshold
        goes_right = ~goes_left
        left_pred = self.left.predict(X[goes_left])
        right_pred = self.right.predict(X[goes_right])

        # Pick a dtype able to hold both children's labels instead of forcing
        # integers.
        out_dtype = np.result_type(left_pred.dtype, right_pred.dtype)
        y_pred = np.empty(len(X), dtype=out_dtype)
        y_pred[goes_left] = left_pred
        y_pred[goes_right] = right_pred
        return y_pred

    def print_tree(self, feature_names, indent=""):
        """Print the subtree rooted at this node as indented text.

        Parameters
        ----------
        feature_names : sequence
            Display names indexed by feature column.
        indent : str, default=""
            Prefix written before each line of this node; children are
            printed with two extra spaces.
        """
        if self.value is not None:
            print(indent + "Predict:", self.value)
        else:
            print(indent + f"Feature {feature_names[self.feature]} <= {self.threshold}")
            print(indent + "Left:")
            self.left.print_tree(feature_names, indent + "  ")
            print(indent + "Right:")
            self.right.print_tree(feature_names, indent + "  ")