Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Feature/min samples leaf option #253

Merged
merged 2 commits into from
Jan 8, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 73 additions & 33 deletions core/src/main/java/hivemall/smile/classification/DecisionTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ public class DecisionTree implements Classifier<double[]> {
* The number of instances in a node below which the tree will not split.
*/
private final int _minSplit;
/**
* The minimum number of samples in a leaf node
*/
private final int _minLeafSize;
/**
* The index of training values in ascending order. Note that only numeric attributes will be
* sorted.
Expand All @@ -177,7 +181,11 @@ public static enum SplitRule {
/**
* Used by the ID3, C4.5 and C5.0 tree generation algorithms.
*/
ENTROPY
ENTROPY,
/**
* Classification error.
*/
CLASSIFICATION_ERROR
}

/**
Expand Down Expand Up @@ -480,6 +488,11 @@ public boolean findBestSplit() {
}
}

// avoid split if the number of samples is less than threshold
if (n <= _minSplit) {
return false;
}

final double impurity = impurity(count, n, _rule);

final int p = _attributes.length;
Expand Down Expand Up @@ -516,7 +529,7 @@ public boolean findBestSplit() {
* @param impurity the impurity of this node.
* @param j the attribute index to split on.
*/
public Node findBestSplit(final int n, final int[] count, final int[] falseCount,
private Node findBestSplit(final int n, final int[] count, final int[] falseCount,
final double impurity, final int j) {
final int N = x.length;
final Node splitNode = new Node();
Expand All @@ -536,8 +549,8 @@ public Node findBestSplit(final int n, final int[] count, final int[] falseCount
final int tc = Math.sum(trueCount[l]);
final int fc = n - tc;

// If either side is empty, skip this feature.
if (tc == 0 || fc == 0) {
// skip splitting this feature.
if (tc < _minSplit || fc < _minSplit) {
continue;
}

Expand All @@ -550,15 +563,13 @@ public Node findBestSplit(final int n, final int[] count, final int[] falseCount
* impurity(falseCount, fc, _rule);

if (gain > splitNode.splitScore) {
int trueLabel = Math.whichMax(trueCount[l]);
int falseLabel = Math.whichMax(falseCount);
// new best split
splitNode.splitFeature = j;
splitNode.splitFeatureType = AttributeType.NOMINAL;
splitNode.splitValue = l;
splitNode.splitScore = gain;
splitNode.trueChildOutput = trueLabel;
splitNode.falseChildOutput = falseLabel;
splitNode.trueChildOutput = Math.whichMax(trueCount[l]);
splitNode.falseChildOutput = Math.whichMax(falseCount);
}
}
} else if (_attributes[j].type == AttributeType.NUMERIC) {
Expand All @@ -582,8 +593,8 @@ public Node findBestSplit(final int n, final int[] count, final int[] falseCount
final int tc = Math.sum(trueCount);
final int fc = n - tc;

// If either side is empty, continue.
if (tc == 0 || fc == 0) {
// skip splitting this feature.
if (tc < _minSplit || fc < _minSplit) {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += sample;
Expand All @@ -599,15 +610,13 @@ public Node findBestSplit(final int n, final int[] count, final int[] falseCount
* impurity(falseCount, fc, _rule);

if (gain > splitNode.splitScore) {
int trueLabel = Math.whichMax(trueCount);
int falseLabel = Math.whichMax(falseCount);
// new best split
splitNode.splitFeature = j;
splitNode.splitFeatureType = AttributeType.NUMERIC;
splitNode.splitValue = (x_ij + prevx) / 2.d;
splitNode.splitScore = gain;
splitNode.trueChildOutput = trueLabel;
splitNode.falseChildOutput = falseLabel;
splitNode.trueChildOutput = Math.whichMax(trueCount);
splitNode.falseChildOutput = Math.whichMax(falseCount);
}

prevx = x_ij;
Expand Down Expand Up @@ -668,7 +677,8 @@ public boolean split(@Nullable final PriorityQueue<TrainNode> nextSplits) {
+ node.splitFeatureType);
}

if (tc == 0 || fc == 0) {
if (tc < _minLeafSize || fc < _minLeafSize) {
// set the node as leaf
node.splitFeature = -1;
node.splitFeatureType = null;
node.splitValue = Double.NaN;
Expand Down Expand Up @@ -736,14 +746,24 @@ private static double impurity(@Nonnull final int[] count, final int n,
}
break;
}
case CLASSIFICATION_ERROR: {
impurity = 0.d;
for (int i = 0; i < count.length; i++) {
if (count[i] > 0) {
impurity = Math.max(impurity, count[i] / (double) n);
}
}
impurity = Math.abs(1.d - impurity);
break;
}
}

return impurity;
}

public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y,
int J) {
this(attributes, x, y, x[0].length, Integer.MAX_VALUE, J, 2, null, null, SplitRule.GINI, null);
int numLeafs) {
this(attributes, x, y, x[0].length, Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null);
}

/**
Expand All @@ -754,8 +774,9 @@ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @No
* @param y the response variable.
* @param numVars the number of input variables to pick to split on at each node. It seems that
* dim/3 give generally good performance, where dim is the number of variables.
* @param numLeafs the maximum number of leaf nodes in the tree.
* @param maxLeafs the maximum number of leaf nodes in the tree.
* @param minSplits the number of minimum elements in a node to split
* @param minLeafSize the minimum size of leaf nodes.
* @param order the index of training values in ascending order. Note that only numeric
* attributes need be sorted.
* @param samples the sample set of instances for stochastic learning. samples[i] is the number
Expand All @@ -764,19 +785,10 @@ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @No
* @param seed
*/
public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y,
int numVars, int maxDepth, int numLeafs, int minSplits, @Nullable int[] samples,
@Nullable int[][] order, @Nonnull SplitRule rule, @Nullable smile.math.Random rand) {
if (x.length != y.length) {
throw new IllegalArgumentException(String.format(
"The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
if (numVars <= 0 || numVars > x[0].length) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
}
if (numLeafs < 2) {
throw new IllegalArgumentException("Invalid maximum leaves: " + numLeafs);
}
int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
@Nullable int[] samples, @Nullable int[][] order, @Nonnull SplitRule rule,
@Nullable smile.math.Random rand) {
checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);

this._k = Math.max(y) + 1;
if (_k < 2) {
Expand All @@ -788,9 +800,11 @@ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @No
throw new IllegalArgumentException("-attrs option is invliad: "
+ Arrays.toString(attributes));
}

this._numVars = numVars;
this._maxDepth = maxDepth;
this._minSplit = minSplits;
this._minLeafSize = minLeafSize;
this._rule = rule;
this._order = (order == null) ? SmileExtUtils.sort(attributes, x) : order;
this._importance = new double[attributes.length];
Expand All @@ -813,7 +827,7 @@ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @No
this._root = new Node(Math.whichMax(count));

final TrainNode trainRoot = new TrainNode(_root, x, y, samples, 1);
if (numLeafs == Integer.MAX_VALUE) {
if (maxLeafs == Integer.MAX_VALUE) {
if (trainRoot.findBestSplit()) {
trainRoot.split(null);
}
Expand All @@ -826,7 +840,7 @@ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @No
}
// Pop best leaf from priority queue, split it, and push
// children nodes into the queue if possible.
for (int leaves = 1; leaves < numLeafs; leaves++) {
for (int leaves = 1; leaves < maxLeafs; leaves++) {
// parent is the leaf to split
TrainNode node = nextSplits.poll();
if (node == null) {
Expand All @@ -837,6 +851,32 @@ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @No
}
}

private static void checkArgument(@Nonnull double[][] x, @Nonnull int[] y, int numVars,
int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
if (x.length != y.length) {
throw new IllegalArgumentException(String.format(
"The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
if (numVars <= 0 || numVars > x[0].length) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
}
if (maxDepth < 2) {
throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
}
if (maxLeafs < 2) {
throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafs);
}
if (minSplits < 2) {
throw new IllegalArgumentException(
"Invalid minimum number of samples required to split an internal node: "
+ minSplits);
}
if (minLeafSize < 1) {
throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + minLeafSize);
}
}

/**
* Returns the variable importance. Every time a split of a node is made on variable the (GINI,
* information gain, etc.) impurity criterion for the two descendent nodes is less than the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
*/
private int _maxLeafNodes;
private int _minSamplesSplit;
private int _minSamplesLeaf;
private long _seed;
private Attribute[] _attributes;
private ModelType _outputType;
Expand All @@ -119,6 +120,8 @@ protected Options getOptions() {
"The maximum number of leaf nodes [default: Integer.MAX_VALUE]");
opts.addOption("splits", "min_split", true,
"A node that has greater than or equals to `min_split` examples will split [default: 5]");
opts.addOption("min_samples_leaf", true,
"The minimum number of samples in a leaf node [default: 1]");
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
Expand All @@ -131,7 +134,8 @@ protected Options getOptions() {

@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
int trees = 500, maxDepth = 8, maxLeafs = Integer.MAX_VALUE, minSplit = 5;
int trees = 500, maxDepth = 8;
int maxLeafs = Integer.MAX_VALUE, minSplit = 5, minSamplesLeaf = 1;
float numVars = -1.f;
double eta = 0.05d, subsample = 0.7d;
Attribute[] attrs = null;
Expand All @@ -154,6 +158,8 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen
maxDepth = Primitives.parseInt(cl.getOptionValue("max_depth"), maxDepth);
maxLeafs = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), maxLeafs);
minSplit = Primitives.parseInt(cl.getOptionValue("min_split"), minSplit);
minSamplesLeaf = Primitives.parseInt(cl.getOptionValue("min_samples_leaf"),
minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
output = cl.getOptionValue("output", output);
Expand All @@ -169,6 +175,7 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen
this._maxDepth = maxDepth;
this._maxLeafNodes = maxLeafs;
this._minSamplesSplit = minSplit;
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
this._attributes = attrs;
this._outputType = ModelType.resolve(output, compress);
Expand Down Expand Up @@ -345,7 +352,7 @@ private void train2(@Nonnull final double[][] x, @Nonnull final int[] y) throws
}

RegressionTree tree = new RegressionTree(_attributes, x, response, numVars, _maxDepth,
_maxLeafNodes, _minSamplesSplit, order, samples, output, rnd2);
_maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, order, samples, output, rnd2);

for (int i = 0; i < n; i++) {
h[i] += _eta * tree.predict(x[i]);
Expand Down Expand Up @@ -455,7 +462,8 @@ private void traink(final double[][] x, final int[] y, final int k) throws HiveE
}

RegressionTree tree = new RegressionTree(_attributes, x, response[j], numVars,
_maxDepth, _maxLeafNodes, _minSamplesSplit, order, samples, output[j], rnd2);
_maxDepth, _maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, order, samples,
output[j], rnd2);
trees[j] = tree;

for (int i = 0; i < n; i++) {
Expand Down
Loading