Add random splitting for numeric features for decision trees. #2883

Merged: 36 commits, May 9, 2021
Changes from 3 commits
Commits (36)
6cd97a7
Defined the random split class
RishabhGarg108 Mar 19, 2021
0ad73bc
Added implementation of random split
RishabhGarg108 Mar 20, 2021
a140aba
Added tests
RishabhGarg108 Mar 20, 2021
d8ca63f
Make new file for random split.
RishabhGarg108 Apr 20, 2021
ccebfcb
Changed imports
RishabhGarg108 Apr 20, 2021
25ce215
Use math::Random which essentially does the same thing
RishabhGarg108 Apr 20, 2021
374cb2d
Fixed typo
RishabhGarg108 Apr 20, 2021
8d63ae3
Add citation.
RishabhGarg108 Apr 20, 2021
f2bc627
Move citation into class and added newline at EOF
RishabhGarg108 Apr 22, 2021
ab6ae18
Removed best found gain check, as discussed with @rcurtin
RishabhGarg108 Apr 22, 2021
4704a25
Removed test where no split was made if there was no gain.
RishabhGarg108 Apr 22, 2021
bf98ea5
Add test for different splits under best and random settings.
RishabhGarg108 Apr 22, 2021
3c3660f
Merge branch 'master' of https://github.com/mlpack/mlpack into random…
RishabhGarg108 Apr 22, 2021
43097c9
Add UseBootstrap template parameter to random forest
RishabhGarg108 Apr 22, 2021
1ff5e61
Changed train function to use UseBootstrap
RishabhGarg108 Apr 22, 2021
c049d49
Removed ElemType from RandomBinaryNumericSplit
RishabhGarg108 Apr 22, 2021
5c3b60e
Add splitIfBetterGain to RandomBinaryNumericSplit
RishabhGarg108 Apr 24, 2021
fdfdc92
Shifter UseBootstrap to end of template list
RishabhGarg108 Apr 24, 2021
6ed9384
Add typedef for ExtraTrees
RishabhGarg108 Apr 24, 2021
185d795
Add test ensuring high accuracy on iris classification task.
RishabhGarg108 Apr 24, 2021
9102ac8
Lined up spacing in ExtraTrees typedef
RishabhGarg108 Apr 24, 2021
be8893d
Removed <random> include from BestBinaryNumericSplit
RishabhGarg108 Apr 24, 2021
8984315
Reduced min accurary to 91%
RishabhGarg108 Apr 25, 2021
82be58a
Added to HISTORY.md
RishabhGarg108 Apr 25, 2021
dd9383d
Update HISTORY.md
RishabhGarg108 Apr 26, 2021
bff45dc
Add documentation for parameters of NumChildren
RishabhGarg108 May 1, 2021
f4e5bf7
Add documentation for splitIfBetterGain
RishabhGarg108 May 1, 2021
0eefde3
Fixed dataset name in test file
RishabhGarg108 May 1, 2021
347607d
Merge branch 'random-split' of ssh://github.com/RishabhGarg108/mlpack…
RishabhGarg108 May 1, 2021
5aece9b
Changed and to && to fix windows build
RishabhGarg108 May 1, 2021
9453469
Changed Randomised -> Randomized
RishabhGarg108 May 3, 2021
27227f5
Add random_split to CMakeLists.txt
RishabhGarg108 May 3, 2021
0cb6a13
Reverted unnecessary bracket changes
RishabhGarg108 May 7, 2021
0475778
Apply suggestions from code review
RishabhGarg108 May 9, 2021
b2a5fe6
Merge branch 'master' into random-split
RishabhGarg108 May 9, 2021
2a6152d
Update HISTORY.md
RishabhGarg108 May 9, 2021
71 changes: 71 additions & 0 deletions src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp
@@ -87,6 +87,77 @@ class BestBinaryNumericSplit
const AuxiliarySplitInfo<ElemType>& /* aux */);
};

/**
* The RandomBinaryNumericSplit is a splitting function for decision trees that
* will split based on a randomly selected point between the minimum
* and maximum value of the numerical dimension.
*
* @tparam FitnessFunction Fitness function to use to calculate gain.
*/
template<typename FitnessFunction>
class RandomBinaryNumericSplit
{
public:
// No extra info needed for split.
template<typename ElemType>
class AuxiliarySplitInfo { };

/**
* Check if we can split a node. If we can split a node in a way that
* improves on 'bestGain', then we return the improved gain. Otherwise we
* return the value 'bestGain'. If a split is made, then classProbabilities
* and aux may be modified.
*
* @param bestGain Best gain seen so far (we'll only split if we find gain
* better than this).
* @param data The dimension of data points to check for a split in.
* @param labels Labels for each point.
* @param numClasses Number of classes in the dataset.
* @param weights Weights associated with labels.
* @param minimumLeafSize Minimum number of points in a leaf node for
* splitting.
* @param minimumGainSplit Minimum gain split.
* @param classProbabilities Class probabilities vector, which may be filled
* with split information on a successful split.
* @param aux Auxiliary split information, which may be modified on a
* successful split.
*/
template<bool UseWeights, typename VecType, typename WeightVecType>
static double SplitIfBetter(
const double bestGain,
const VecType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const WeightVecType& weights,
const size_t minimumLeafSize,
const double minimumGainSplit,
arma::Col<typename VecType::elem_type>& classProbabilities,
AuxiliarySplitInfo<typename VecType::elem_type>& aux);

/**
* Returns 2, since the binary split always has two children.
*/
template<typename ElemType>
static size_t NumChildren(const arma::Col<ElemType>& /* classProbabilities */,
const AuxiliarySplitInfo<ElemType>& /* aux */)
{
return 2;
}

/**
* Given a point, calculate which child it should go to (left or right).
*
* @param point Point to calculate direction of.
* @param classProbabilities Auxiliary information for the split.
* @param * (aux) Auxiliary information for the split (Unused).
*/
template<typename ElemType>
static size_t CalculateDirection(
const ElemType& point,
const arma::Col<ElemType>& classProbabilities,
const AuxiliarySplitInfo<ElemType>& /* aux */);
};

} // namespace tree
} // namespace mlpack

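The RandomBinaryNumericSplit class declared above boils down to one idea: instead of scanning every candidate threshold the way BestBinaryNumericSplit does, pick a single pivot uniformly at random between the feature's minimum and maximum, and send points below the pivot to the left child and everything else to the right. A minimal standalone C++ sketch of that idea follows; it is independent of mlpack's API, and the names RandomPivotSplit, feature, leftIdx, and rightIdx are illustrative only.

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

// Pick a pivot uniformly at random between the minimum and maximum of one
// numeric feature, then send each point left (< pivot) or right (>= pivot).
// This mirrors the idea behind RandomBinaryNumericSplit; it is not mlpack code.
double RandomPivotSplit(const std::vector<double>& feature,
                        std::vector<size_t>& leftIdx,
                        std::vector<size_t>& rightIdx,
                        std::mt19937& gen)
{
  const auto minmax = std::minmax_element(feature.begin(), feature.end());
  std::uniform_real_distribution<double> dist(*minmax.first, *minmax.second);
  const double pivot = dist(gen);

  for (size_t i = 0; i < feature.size(); ++i)
  {
    if (feature[i] < pivot)
      leftIdx.push_back(i);
    else
      rightIdx.push_back(i);
  }

  return pivot;
}

int main()
{
  const std::vector<double> feature = { 0.1, 0.4, 0.35, 0.8, 0.9, 0.05 };
  std::mt19937 gen(std::random_device{}());

  std::vector<size_t> leftIdx, rightIdx;
  const double pivot = RandomPivotSplit(feature, leftIdx, rightIdx, gen);

  std::cout << "pivot = " << pivot << ", left = " << leftIdx.size()
            << " points, right = " << rightIdx.size() << " points\n";
}
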
183 changes: 162 additions & 21 deletions src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
@@ -12,6 +12,8 @@
#ifndef MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_IMPL_HPP
#define MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_IMPL_HPP

#include <random>

namespace mlpack {
namespace tree {

@@ -39,11 +41,11 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
arma::Row<size_t> sortedLabels(labels.n_elem);
arma::rowvec sortedWeights;
for (size_t i = 0; i < sortedLabels.n_elem; ++i)
sortedLabels[i] = labels[sortedIndices[i]];
sortedLabels(i) = labels(sortedIndices(i));

// Sanity check: if the first element is the same as the last, we can't split
// in this dimension.
if (data[sortedIndices[0]] == data[sortedIndices[sortedIndices.n_elem - 1]])
if (data(sortedIndices(0)) == data(sortedIndices(sortedIndices.n_elem - 1)))
return DBL_MAX;

// Only initialize if we are using weights.
@@ -52,7 +54,7 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
sortedWeights.set_size(sortedLabels.n_elem);
// The weights must keep the same order as the labels.
for (size_t i = 0; i < sortedLabels.n_elem; ++i)
sortedWeights[i] = weights[sortedIndices[i]];
sortedWeights(i) = weights(sortedIndices(i));
}

// Loop through all possible split points, choosing the best one. Also, force
@@ -77,15 +79,15 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
// These points have to be on the left.
for (size_t i = 0; i < minimum - 1; ++i)
{
classWeightSums(sortedLabels[i], 0) += sortedWeights[i];
totalLeftWeight += sortedWeights[i];
classWeightSums(sortedLabels(i), 0) += sortedWeights(i);
totalLeftWeight += sortedWeights(i);
}

// These points have to be on the right.
for (size_t i = minimum - 1; i < data.n_elem; ++i)
{
classWeightSums(sortedLabels[i], 1) += sortedWeights[i];
totalRightWeight += sortedWeights[i];
classWeightSums(sortedLabels(i), 1) += sortedWeights(i);
totalRightWeight += sortedWeights(i);
}
}
else
@@ -96,31 +98,31 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
// Initialize the counts.
// These points have to be on the left.
for (size_t i = 0; i < minimum - 1; ++i)
++classCounts(sortedLabels[i], 0);
++classCounts(sortedLabels(i), 0);

// These points have to be on the right.
for (size_t i = minimum - 1; i < data.n_elem; ++i)
++classCounts(sortedLabels[i], 1);
++classCounts(sortedLabels(i), 1);
}

for (size_t index = minimum; index < data.n_elem - minimum; ++index)
{
// Update class weight sums or counts.
if (UseWeights)
{
classWeightSums(sortedLabels[index - 1], 1) -= sortedWeights[index - 1];
classWeightSums(sortedLabels[index - 1], 0) += sortedWeights[index - 1];
totalLeftWeight += sortedWeights[index - 1];
totalRightWeight -= sortedWeights[index - 1];
classWeightSums(sortedLabels(index - 1), 1) -= sortedWeights(index - 1);
classWeightSums(sortedLabels(index - 1), 0) += sortedWeights(index - 1);
totalLeftWeight += sortedWeights(index - 1);
totalRightWeight -= sortedWeights(index - 1);
}
else
{
--classCounts(sortedLabels[index - 1], 1);
++classCounts(sortedLabels[index - 1], 0);
--classCounts(sortedLabels(index - 1), 1);
++classCounts(sortedLabels(index - 1), 0);
}

// Make sure that the value has changed.
if (data[sortedIndices[index]] == data[sortedIndices[index - 1]])
if (data(sortedIndices(index)) == data(sortedIndices(index - 1)))
continue;

// Calculate the gain for the left and right child. Only use weights if
@@ -156,8 +158,8 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
classProbabilities.set_size(1);
// The actual split value will be halfway between the value at index - 1
// and index.
classProbabilities[0] = (data[sortedIndices[index - 1]] +
data[sortedIndices[index]]) / 2.0;
classProbabilities(0) = (data(sortedIndices(index - 1)) +
data(sortedIndices(index))) / 2.0;

return gain;
}
@@ -166,8 +168,8 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
// We still have a better split.
bestFoundGain = gain;
classProbabilities.set_size(1);
classProbabilities[0] = (data[sortedIndices[index - 1]] +
data[sortedIndices[index]]) / 2.0;
classProbabilities(0) = (data(sortedIndices(index - 1)) +
data(sortedIndices(index))) / 2.0;
improved = true;
}
}
@@ -192,7 +194,146 @@ size_t BestBinaryNumericSplit<FitnessFunction>::CalculateDirection(
const arma::Col<ElemType>& classProbabilities,
const AuxiliarySplitInfo<ElemType>& /* aux */)
{
if (point <= classProbabilities[0])
if (point <= classProbabilities(0))
return 0; // Go left.
else
return 1; // Go right.
}

template<typename FitnessFunction>
template<bool UseWeights, typename VecType, typename WeightVecType>
double RandomBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
const double bestGain,
const VecType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const WeightVecType& weights,
const size_t minimumLeafSize,
const double minimumGainSplit,
arma::Col<typename VecType::elem_type>& classProbabilities,
AuxiliarySplitInfo<typename VecType::elem_type>& /* aux */)
{
double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
// Forcing a minimum leaf size of 1 (empty children don't make sense).
const size_t minimum = std::max(minimumLeafSize, (size_t) 1);

// First sanity check: if we don't have enough points, we can't split.
if (data.n_elem < (minimum * 2))
return DBL_MAX;
if (bestGain == 0.0)
return DBL_MAX; // It can't be outperformed.

typename VecType::elem_type maxValue = arma::max(data);
typename VecType::elem_type minValue = arma::min(data);

// Sanity check: if the maximum element is the same as the minimum, we
// can't split in this dimension.
if (maxValue == minValue)
return DBL_MAX;

/*
Just to make review easier, the following bit of code is taken directly from
https://en.cppreference.com/w/cpp/numeric/random/uniform_real_distribution
to generate a random number. (To be removed before merge)
*/
// Picking a random pivot to split the dimension.
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<> distribution(minValue, maxValue);
double randomPivot = distribution(gen);

// We need to count the number of points for each class.
arma::Mat<size_t> classCounts;
arma::mat classWeightSums;
double totalWeight = 0.0;
double totalLeftWeight = 0.0;
double totalRightWeight = 0.0;
size_t leftLeafSize = 0;
size_t rightLeafSize = 0;
if (UseWeights)
{
classWeightSums.zeros(numClasses, 2);
totalWeight = arma::accu(weights);
bestFoundGain *= totalWeight;

for (size_t i = 0; i < data.n_elem; ++i)
{
if (data(i) < randomPivot)
{
++leftLeafSize;
classWeightSums(labels(i), 0) += weights(i);
totalLeftWeight += weights(i);
}
else
{
++rightLeafSize;
classWeightSums(labels(i), 1) += weights(i);
totalRightWeight += weights(i);
}
}
}
else
{
classCounts.zeros(numClasses, 2);
bestFoundGain *= data.n_elem;

for (size_t i = 0; i < data.n_elem; i++)
{
if (data(i) < randomPivot)
{
++leftLeafSize;
++classCounts(labels(i), 0);
}
else
{
++rightLeafSize;
++classCounts(labels(i), 1);
}
}
}

// Calculate the gain for the left and right child. Only use weights if
// needed.
const double leftGain = UseWeights ?
FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(0),
numClasses, totalLeftWeight) :
FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(0),
numClasses, leftLeafSize);
const double rightGain = UseWeights ?
FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(1),
numClasses, totalRightWeight) :
FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(1),
numClasses, rightLeafSize);

double gain;
if (UseWeights)
gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
else
// Calculate the gain at this split point.
gain = double(leftLeafSize) * leftGain + double(rightLeafSize) * rightGain;

if (gain < bestFoundGain)
return DBL_MAX;

classProbabilities.set_size(1);
classProbabilities(0) = randomPivot;

if (UseWeights)
gain /= totalWeight;
else
gain /= labels.n_elem;

return gain;
}
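
The gain returned above is simply the size-weighted (or, with UseWeights, weight-weighted) sum of the two children's gains, rescaled by the total point count or total weight so that the result is comparable with bestGain. A tiny self-contained example of the unweighted case follows; the GiniGain helper and the hard-coded class counts are illustrative assumptions, not mlpack code, but they use the same "0 means pure, more negative means more impure, larger is better" convention that the checks above (such as `if (bestGain == 0.0)`) rely on.

#include <cstddef>
#include <iostream>
#include <vector>

// Gini-style gain of one node: -(1 - sum_c p_c^2), so 0 is a pure node and
// more negative values are more impure.  (Illustrative helper only.)
double GiniGain(const std::vector<size_t>& counts, const size_t total)
{
  double sumSquares = 0.0;
  for (const size_t c : counts)
  {
    const double p = double(c) / double(total);
    sumSquares += p * p;
  }

  return -(1.0 - sumSquares);
}

int main()
{
  // Hypothetical class counts in the two children after a random split.
  const std::vector<size_t> leftCounts  = { 8, 2 };  // 10 points on the left.
  const std::vector<size_t> rightCounts = { 1, 9 };  // 10 points on the right.
  const size_t leftSize = 10, rightSize = 10;

  const double leftGain  = GiniGain(leftCounts, leftSize);    // -0.32
  const double rightGain = GiniGain(rightCounts, rightSize);  // -0.18

  // Size-weighted combination, then normalization by the total point count,
  // mirroring `leftLeafSize * leftGain + rightLeafSize * rightGain` followed
  // by `gain /= labels.n_elem` above.
  const double gain = (leftSize * leftGain + rightSize * rightGain) /
      double(leftSize + rightSize);

  std::cout << "combined gain = " << gain << "\n";  // Prints -0.25.
}
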

template<typename FitnessFunction>
template<typename ElemType>
size_t RandomBinaryNumericSplit<FitnessFunction>::CalculateDirection(
const ElemType& point,
const arma::Col<ElemType>& classProbabilities,
const AuxiliarySplitInfo<ElemType>& /* aux */)
{
if (point <= classProbabilities(0))
return 0; // Go left.
else
return 1; // Go right.