<a href="https://colab.research.google.com/github/myrondza10/jax_experiments/blob/main/decision_tree_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
# Import required libraries

from jax import grad, jit, vmap
import jax.numpy as jnp

# Decision tree classifier (Gini-impurity splits) built on jax.numpy

class DecisionTreeJAX:
    """
    Decision tree classifier using JAX.

    At each node the tree greedily picks the (feature, threshold) split that
    minimises the size-weighted Gini impurity of the two children. Candidate
    thresholds are the midpoints between consecutive distinct values of a
    feature. Growing and prediction are plain Python recursion/iteration over
    concrete ``jax.numpy`` arrays; nothing is jit-compiled or vmapped, because
    the tree is a nested dict with data-dependent branching.

    Parameters
    ----------
    max_depth : int or None
        Maximum depth of the tree. ``None`` grows until every leaf is pure
        or its samples can no longer be separated by any feature.

    Attributes
    ----------
    tree : dict
        Set by ``fit``. Internal nodes are
        ``{'feature': int, 'threshold': float, 'left': node, 'right': node}``;
        leaves are ``{'class_counts': {label: count}}``.
    """

    def __init__(
        self,
        max_depth=None
        ):
        self.max_depth = max_depth

    def _build_tree(self, X, y, depth):
        """Recursively grow the subtree for samples ``(X, y)`` at ``depth``."""
        if self.max_depth is not None and depth >= self.max_depth:
            return self._create_leaf_node(y)

        # A pure node cannot be improved by splitting; stop instead of growing
        # redundant subtrees whose leaves would all predict the same class.
        if jnp.unique(y).size <= 1:
            return self._create_leaf_node(y)

        best_feature, best_threshold = self._find_best_split(X, y)
        if best_feature is None:
            # Every feature is constant on these samples: no split exists.
            return self._create_leaf_node(y)

        left_mask = X[:, best_feature] <= best_threshold
        right_mask = ~left_mask
        return {
            'feature': best_feature,
            'threshold': best_threshold,
            'left': self._build_tree(X[left_mask], y[left_mask], depth + 1),
            'right': self._build_tree(X[right_mask], y[right_mask], depth + 1),
        }

    def _find_best_split(self, X, y):
        """Return the ``(feature, threshold)`` minimising weighted Gini impurity.

        Returns ``(None, None)`` when no feature has two distinct values.
        On ties the first candidate found wins (lowest feature index, then
        lowest threshold).
        """
        n_samples = len(y)
        best_gini = float('inf')
        best_feature = None
        best_threshold = None

        for feature in range(X.shape[1]):
            column = X[:, feature]
            unique_values = jnp.unique(column)
            # Midpoints between consecutive distinct values guarantee that
            # both children of every candidate split are non-empty.
            thresholds = (unique_values[:-1] + unique_values[1:]) / 2

            for threshold in thresholds:
                left_mask = column <= threshold
                y_left = y[left_mask]
                y_right = y[~left_mask]
                gini = float(
                    self._gini_index(y_left) * len(y_left) / n_samples
                    + self._gini_index(y_right) * len(y_right) / n_samples
                )
                if gini < best_gini:
                    best_gini = gini
                    best_feature = feature
                    best_threshold = float(threshold)

        return best_feature, best_threshold

    def _gini_index(self, y):
        """Gini impurity ``1 - sum(p_k ** 2)`` of the labels ``y`` (0 if empty)."""
        if len(y) == 0:
            return 0.0  # avoid 0/0 -> nan; an empty side contributes nothing
        _, counts = jnp.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return 1 - jnp.sum(probabilities ** 2)

    def _create_leaf_node(self, y):
        """Build a leaf mapping each label in ``y`` to its (int) count."""
        unique_classes, counts = jnp.unique(y, return_counts=True)
        return {'class_counts': dict(zip(unique_classes.tolist(), counts.tolist()))}

    def _traverse_tree(self, x, node):
        """Route sample ``x`` from ``node`` to a leaf and return its majority class.

        Ties between classes resolve to the smallest label, because
        ``jnp.unique`` yields the labels in sorted order.
        """
        while 'class_counts' not in node:
            branch = 'left' if x[node['feature']] <= node['threshold'] else 'right'
            node = node[branch]
        class_counts = node['class_counts']
        return max(class_counts, key=class_counts.get)

    def fit(self, X, y):
        """Build the tree from training data and store it in ``self.tree``.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
        y : array-like, shape (n_samples,)
            Class labels.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, ``y`` is not 1-D, the sample counts differ,
            or there are no samples.
        """
        X = jnp.asarray(X)
        y = jnp.asarray(y)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D (n_samples, n_features), got shape {X.shape}")
        if y.ndim != 1:
            raise ValueError(f"y must be 1-D (n_samples,), got shape {y.shape}")
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X has {X.shape[0]} samples but y has {y.shape[0]}"
            )
        if y.shape[0] == 0:
            raise ValueError("cannot fit a decision tree on zero samples")
        self.tree = self._build_tree(X, y, 0)

    def predict(self, X):
        """Predict a class label for each row of ``X``.

        ``jax.vmap`` cannot be used here: the tree is a dict of Python scalars
        and leaf dicts (which vmap would try to map over), and traversal
        branches on concrete comparison results. Samples are therefore routed
        one at a time.
        """
        X = jnp.asarray(X)
        return jnp.array([self._traverse_tree(x, self.tree) for x in X])

In [23]:
# Toy training set: five samples with two features (the second feature is
# always the first plus one) and binary labels that switch from 0 to 1
# after the second sample.
X = jnp.array([[row, row + 1] for row in range(1, 6)])
y = jnp.array([0] * 2 + [1] * 3)

In [28]:
# Train a tree capped at depth 5 on the toy data defined above.
dtree_jax = DecisionTreeJAX(max_depth=5)
dtree_jax.fit(X=X, y=y)