<a href="https://colab.research.google.com/github/lapa19/DL_basics/blob/main/DecisionTree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# NOTE(review): this class uses ``np`` (numpy) but no import is visible in this
# cell; presumably ``import numpy as np`` runs in another cell — confirm.
class DecisionTree:
    """Binary classification tree grown greedily by minimising Gini impurity.

    The fitted tree is stored in ``self.tree_`` as nested dicts:

    * internal node: ``{'feature', 'threshold', 'left', 'right'}`` — a sample
      ``x`` goes to ``'left'`` when ``x[feature] < threshold``, otherwise to
      ``'right'``;
    * leaf: ``{'value': label}`` — the majority label of the training samples
      that reached the leaf.

    Parameters
    ----------
    min_samples_split : int, default 2
        Minimum number of samples that EACH child of a candidate split must
        contain; candidates leaving fewer on either side are rejected.
        (Children are additionally never allowed to be empty.)
    max_depth : int or None, default None
        Maximum depth of the tree, the root being depth 0. ``None`` grows the
        tree until nodes are pure or no admissible split remains.
    """

    def __init__(self, min_samples_split=2, max_depth=None):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split

    def fit(self, X, y):
        """Grow the tree on ``X``/``y`` and store it in ``self.tree_``.

        Parameters
        ----------
        X : array-like, 2-D, one row per sample and one column per feature.
        y : array-like, 1-D, one label per row of ``X``.

        Returns
        -------
        DecisionTree
            ``self``, so calls can be chained: ``DecisionTree().fit(X, y).predict(X)``.

        Raises
        ------
        ValueError
            If ``y`` is empty (no label could ever be predicted).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if len(y) == 0:
            raise ValueError("cannot fit a DecisionTree on an empty dataset")
        self.tree_ = self._build_tree(X, y)
        return self

    def _build_tree(self, X, y, depth=0):
        """Recursively build the subtree for the samples ``X``/``y``.

        Returns a leaf when the depth limit is reached, the node is pure, or
        no admissible split exists; otherwise an internal split node.
        """
        reached_max_depth = self.max_depth is not None and depth >= self.max_depth
        # A pure node cannot be improved by splitting, so stop early.
        if reached_max_depth or np.unique(y).size == 1:
            return self._leaf(y)
        best_feature, best_threshold = self._find_best_split(X, y)
        if best_feature is None:
            return self._leaf(y)
        left_indices = X[:, best_feature] < best_threshold
        right_indices = ~left_indices
        return {
            'feature': best_feature,
            'threshold': best_threshold,
            'left': self._build_tree(X[left_indices], y[left_indices], depth + 1),
            'right': self._build_tree(X[right_indices], y[right_indices], depth + 1),
        }

    def _leaf(self, y):
        """Return a leaf node predicting the most frequent label in ``y``."""
        labels, counts = np.unique(y, return_counts=True)
        # np.unique returns labels sorted, so ties go to the smallest label.
        return {'value': labels[np.argmax(counts)]}

    def _find_best_split(self, X, y):
        """Return the ``(feature, threshold)`` minimising weighted child Gini.

        Every distinct value of every feature is tried as a threshold. Returns
        ``(None, None)`` when no candidate leaves at least
        ``max(min_samples_split, 1)`` samples on each side.
        """
        n_samples = len(y)
        min_child = max(self.min_samples_split, 1)  # never accept an empty child
        best_gini = np.inf  # any admissible split beats "no split yet"
        best_feature, best_threshold = None, None
        for feature in range(X.shape[1]):
            column = X[:, feature]
            for threshold in np.unique(column):
                left_indices = column < threshold
                n_left = np.count_nonzero(left_indices)
                n_right = n_samples - n_left
                if n_left < min_child or n_right < min_child:
                    continue
                weighted_gini = (
                    n_left * self._gini(y[left_indices])
                    + n_right * self._gini(y[~left_indices])
                ) / n_samples
                # Strict '<' keeps the first-found candidate on ties.
                if weighted_gini < best_gini:
                    best_gini, best_feature, best_threshold = weighted_gini, feature, threshold
        return best_feature, best_threshold

    def _gini(self, y):
        """Return the Gini impurity ``1 - sum_k p_k**2`` of the labels ``y``."""
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return 1 - np.sum(probabilities ** 2)

    def predict(self, X):
        """Return an array with one predicted label per row of ``X``."""
        return np.array([self._traverse_tree(x, self.tree_) for x in np.asarray(X)])

    def _traverse_tree(self, x, tree):
        """Walk from ``tree`` down to a leaf for sample ``x`` and return its label."""
        node = tree
        while 'value' not in node:
            branch = 'left' if x[node['feature']] < node['threshold'] else 'right'
            node = node[branch]
        return node['value']
