In [None]:
import numpy as np

class DecisionTree:
    """Binary classification tree grown greedily on weighted Gini impurity.

    The fitted tree is stored in ``self.tree`` as nested dicts.  An internal
    node has the keys ``'feature_index'``, ``'value'``, ``'left'`` and
    ``'right'``: samples whose feature is strictly less than ``value`` go to
    ``'left'``, all others go to ``'right'``.  A leaf is ``{'class': label}``.
    """

    def __init__(self, max_depth=None):
        """Create an unfitted tree.

        Args:
            max_depth: Depth at which growth stops and a majority-class leaf
                is emitted, or ``None`` for no depth limit.
        """
        self.max_depth = max_depth
        # Root node of the fitted tree; stays None until fit() has run.
        self.tree = None

    def fit(self, X, y, depth=0):
        """Grow the tree on the training data and store it in ``self.tree``.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: 1-D array-like of n_samples class labels.
            depth: Depth assigned to the root node when checking
                ``max_depth``; callers normally leave this at 0.

        Returns:
            The root node dict (the same object stored in ``self.tree``).

        Raises:
            ValueError: If ``y`` is empty.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if y.size == 0:
            raise ValueError("cannot fit a decision tree on an empty training set")
        # Recursion lives in _build so that self.tree is assigned exactly
        # once, and is assigned even when the root turns out to be a leaf.
        self.tree = self._build(X, y, depth)
        return self.tree

    def _build(self, X, y, depth):
        """Recursively build and return the node for the samples (X, y)."""
        labels, counts = np.unique(y, return_counts=True)
        if len(labels) == 1:
            return {'class': labels[0]}

        # np.unique returns sorted labels, so ties go to the smallest label.
        majority = labels[np.argmax(counts)]
        if self.max_depth is not None and depth >= self.max_depth:
            return {'class': majority}

        n_samples = len(y)
        best_gini = float('inf')
        best_split = None
        for feature_index in range(X.shape[1]):
            column = X[:, feature_index]
            # The smallest value would send nothing left, so start from the
            # second distinct value.
            for value in np.unique(column)[1:]:
                left_mask = column < value
                n_left = np.count_nonzero(left_mask)
                # Guard against one-sided splits (e.g. a NaN threshold):
                # they make no progress and would recurse on an empty node.
                if n_left == 0 or n_left == n_samples:
                    continue
                right_mask = ~left_mask
                n_right = n_samples - n_left
                gini = (n_left * self._gini(y[left_mask])
                        + n_right * self._gini(y[right_mask])) / n_samples
                if gini < best_gini:
                    best_gini = gini
                    best_split = (feature_index, value, left_mask, right_mask)

        # No threshold separates the samples (all rows share the same
        # feature values): fall back to a majority-class leaf.
        if best_split is None:
            return {'class': majority}

        feature_index, value, left_mask, right_mask = best_split
        return {
            'feature_index': feature_index,
            'value': value,
            'left': self._build(X[left_mask], y[left_mask], depth + 1),
            'right': self._build(X[right_mask], y[right_mask], depth + 1),
        }

    def _gini(self, y):
        """Return the Gini impurity ``1 - sum(p_k ** 2)`` of the labels *y*.

        An empty label array is treated as having zero impurity.
        """
        y = np.asarray(y)
        if y.size == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        proportions = counts / y.size
        return 1.0 - np.sum(proportions ** 2)

    def predict(self, X):
        """Predict a class label for every row of *X*.

        Args:
            X: Iterable of samples, each indexable by feature index.

        Returns:
            A numpy array holding one predicted label per sample.

        Raises:
            AttributeError: If the tree has not been fitted yet.
        """
        tree = getattr(self, 'tree', None)
        if tree is None:
            raise AttributeError(
                "DecisionTree is not fitted yet; call fit() before predict()")
        return np.array([self._predict(x, tree) for x in X])

    def _predict(self, x, node):
        """Walk from *node* down to a leaf for sample *x*; return its class."""
        # Iterative descent: same path as the recursive form, without
        # consuming Python stack frames on deep trees.
        while 'class' not in node:
            if x[node['feature_index']] < node['value']:
                node = node['left']
            else:
                node = node['right']
        return node['class']

def train_decision_tree(X_train, y_train, max_depth=5):
    """Fit a ``DecisionTree`` on the training data and return it.

    Args:
        X_train: Training samples, passed straight to ``DecisionTree.fit``.
        y_train: Training labels, passed straight to ``DecisionTree.fit``.
        max_depth: Depth limit handed to the ``DecisionTree`` constructor.
            Defaults to 5, the value this function previously hard-coded.

    Returns:
        The fitted ``DecisionTree`` instance.
    """
    classifier = DecisionTree(max_depth=max_depth)
    classifier.fit(X_train, y_train)
    return classifier