Skip to content

Commit

Permalink
Use numba to speed up decision tree fitting
Browse files Browse the repository at this point in the history
Signed-off-by: Niklas Koep <niklas.koep@gmail.com>
  • Loading branch information
nkoep committed Jul 15, 2020
1 parent 4316444 commit 7cd0cfb
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 67 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
numba
sklearn
141 changes: 74 additions & 67 deletions tree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,75 @@
import numba
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin


njit_cached = numba.njit(cache=True)


@njit_cached
def _score_split(y, partition):
"""Return the mean squared error of a potential split partition.
Parameters
----------
y : (num_samples,) ndarray
The vector of targets in the current node.
partition : tuple
A 2-tuple of boolean masks to index left and right samples in
``y``.
"""
mask_left, mask_right = partition
y_left = y[mask_left]
y_right = y[mask_right]
return (np.sum((y_left - y_left.mean()) ** 2) +
np.sum((y_right - y_right.mean()) ** 2))


@njit_cached
def _find_best_split(x, y):
"""Find the best split for a vector of samples.
Determine the threshold in `x` which optimizes the split score by
minimizing the mean squared error of the target.
Parameters
----------
x : (num_samples,) ndarray
The vector of observations of a particular feature.
y : (num_samples,) ndarray
The vector of targets.
Returns
-------
split_configuration : dict
A dictionary with the best `score`, `threshold` and `partition`.
"""
best_score = np.inf
best_threshold = None
best_partition = None
for threshold in x:
# Obtain binary masks for all samples whose feature values are
# below (left) or above (right) the split threshold.
mask_left = x < threshold
mask_right = x >= threshold

# If we can't split the samples based on `threshold', move on.
if not mask_left.any() or not mask_right.any():
continue

# Score the candidate split.
partition = (mask_left, mask_right)
score = _score_split(y, partition)

if score < best_score:
best_score = score
best_threshold = threshold
best_partition = partition

return best_score, best_threshold, best_partition


class Tree:
"""The fundamental data structure representing a binary decision tree.
Expand Down Expand Up @@ -33,72 +101,6 @@ def __init__(self, min_samples_split):
self.threshold = None
self.prediction = None

def _score_split(self, y, partition):
"""
Return the mean squared error of a potential split partition.
Parameters
----------
y : (num_samples,) ndarray
The vector of targets in the current node.
partition : tuple
A 2-tuple of boolean masks to index left and right samples in
``y``.
"""
mask_left, mask_right = partition
y_left = y[mask_left]
y_right = y[mask_right]
return (np.sum((y_left - y_left.mean()) ** 2) +
np.sum((y_right - y_right.mean()) ** 2))

def _find_best_split(self, x, y):
"""Find the best split for a vector of samples and the corresponding
target vector.
Determine the threshold in `x` which optimizes the split score by
minimizing the mean squared error of the target.
Parameters
----------
x : (num_samples,) ndarray
The vector of observations of a particular feature.
y : (num_samples,) ndarray
The vector of targets.
Returns
-------
split_configuration : dict
A dictionary with the best `score`, `threshold` and `partition`.
"""
best_score = np.inf
best_threshold = None
best_partition = None
for threshold in x:
# Obtain binary masks for all samples whose feature values are
# below (left) or above (right) the split threshold.
mask_left = x < threshold
mask_right = x >= threshold

# If we can't split the samples based on `threshold', move on.
if not mask_left.any() or not mask_right.any():
continue

# Score the candidate split.
partition = (mask_left, mask_right)
score = self._score_split(y, partition)

if score < best_score:
best_score = score
best_threshold = threshold
best_partition = partition

return {
"score": best_score,
"threshold": best_threshold,
"partition": best_partition
}

def construct_tree(self, X, y):
"""Construct the binary decision tree via recursive splitting.
Expand All @@ -120,7 +122,12 @@ def construct_tree(self, X, y):
feature_scores = {}
for feature_index in np.arange(num_features):
x = X[:, feature_index]
feature_scores[feature_index] = self._find_best_split(x, y)
score, threshold, partition = _find_best_split(x, y)
feature_scores[feature_index] = {
"score": score,
"threshold": threshold,
"partition": partition
}

# Retrieve the split configuration for the best (lowest) score.
feature_index = min(feature_scores,
Expand Down

0 comments on commit 7cd0cfb

Please sign in to comment.