Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Extremely Randomized Trees Algorithm and Min Impurity Decrease #51

Merged
merged 21 commits into from Oct 2, 2020

Conversation

samanthacampo
Copy link
Member

@samanthacampo samanthacampo commented Sep 29, 2020

Description

Adds Extremely Randomized Trees Algorithm and Min Impurity Decrease.

Motivation

Adds Extremely Randomized Trees Algorithm and Min Impurity Decrease.

Paper reference

https://link.springer.com/article/10.1007/s10994-006-6226-1

Copy link
Member

@Craigacp Craigacp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small changes, mostly formatting and javadoc. I would like consistent use of "features" rather than "attributes" as we don't use "attributes" to mean features anywhere else in the codebase.


lessThanOrEqual = new ClassifierTrainingNode(impurity, lessThanData, lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
greaterThan = new ClassifierTrainingNode(impurity, greaterThanData, numExamples - lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
List<AbstractTrainingNode<Label>> output = new ArrayList<>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should have sized this arraylist to 2 originally, but now you've moved it we should definitely do that. Ditto for the other places in Regression which create a small array. When we move up from Java 8 we can replace it with a List.of() which will be better.

public double getImpurity() {
return impurity.impurity(labelCounts);
}
public double getImpurity() { return impurityScore;}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be on a single line.

@@ -75,7 +75,8 @@ default public double impurity(double[] input) {
}

/**
* Calculates the impurity assuming the input are weighted counts, normalizing by their sum.
* Calculates the impurity assuming the input are weighted counts, normalizing by their sum. The resulting
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This javadoc isn't quite right. The counts are assumed to be weighted, they are converted into a probability distribution by dividing by their sum, and then the impurity is multiplied by the sum. It's missing the "probability distribution" bit.


public void testCART(Pair<Dataset<Label>,Dataset<Label>> p) {
TreeModel<Label> m = t.train(p.getA());
public void testCART(Pair<Dataset<Label>,Dataset<Label>> p, AbstractCARTTrainer<Label> trainer) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be sharply typed (i.e. CARTClassificationTrainer not AbstractCARTTrainer<Label>). I'd prefer nobody ever use the AbstractCARTTrainer type in user code, so we shouldn't do it in the tests unless it's strictly necessary.


public class TestCART {

private static final CARTClassificationTrainer t = new CARTClassificationTrainer();
private static final CARTClassificationTrainer randomt = new CARTClassificationTrainer(5, 2, 0.0f,1.0f, true,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there's some random whitespace in this line?

(node.getNumExamples() > minChildWeight)) {
if (numFeaturesInSplit != featureIDMap.size()) {
Util.randpermInPlace(originalIndices, localRNG);
System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
}
List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices);
List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices, localRNG,
getUseRandomSplitPoints(),getMinImpurityDecrease() * weightSum);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe precompute getMinImpurityDecrease()*weightSum rather than do it every time?

}

@Override
public double getImpurity() {
public double getImpurity() { return impurityScore;}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting.

* Calculates the impurity score of the node.
* @return the impurity score of the node.
*/
private double calcImpurity(){
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put a space between () and the open curly brace.


public void testJointRegressionTree(Pair<Dataset<Regressor>,Dataset<Regressor>> p) {
TreeModel<Regressor> m = t.train(p.getA());
public void testJointRegressionTree(Pair<Dataset<Regressor>,Dataset<Regressor>> p, AbstractCARTTrainer<Regressor> trainer) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the classification tests, I'd prefer it if the sharp CARTJointRegressionTrainer is used rather than AbstractCARTTrainer<Regressor> unless you're sharing the tests across both types of regression tree trainer.

public void testIndependentRegressionTree(Pair<Dataset<Regressor>,Dataset<Regressor>> p) {
Model<Regressor> m = t.train(p.getA());
public void testIndependentRegressionTree(Pair<Dataset<Regressor>,Dataset<Regressor>> p,
AbstractCARTTrainer<Regressor> trainer) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sharp type.

@samanthacampo samanthacampo changed the title Samanthacampo/extreme trees Adds Extremely Randomized Trees Algorithm and Min Impurity Decrease Sep 30, 2020
Copy link
Member

@Craigacp Craigacp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Three tiny changes to clean things up. Looks good otherwise.

@@ -126,8 +126,8 @@ public synchronized void postConfig() {
throw new IllegalArgumentException("maxDepth must be greater than or equal to 1");
}

if ((minChildWeight < 0.0f)) {
throw new IllegalArgumentException("minChildWeight must be greater than or equal to 0");
if ((minChildWeight <= 0.0f)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two sets of parentheses here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -121,20 +124,20 @@ public static void main(String[] args) throws IOException {
SparseTrainer<Regressor> trainer;
switch (o.treeType) {
case CART_INDEPENDENT:
if (o.fraction <= 0) {
trainer = new CARTRegressionTrainer(o.depth,o.minChildWeight,0.0f, 1, false, impurity,
if (o.fraction == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably better to fix the default value of fraction to be 1.0, and then remove this if clause entirely.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -69,10 +73,12 @@ public CARTClassificationTrainer getTrainer() {
CARTClassificationTrainer trainer;
switch (cartTreeAlgorithm) {
case CART:
if (cartSplitFraction <= 0) {
trainer = new CARTClassificationTrainer(cartMaxDepth, cartMinChildWeight, 1, impurity, cartSeed);
if (cartSplitFraction == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably best to set the default value of cartSplitFraction to 1.0 and then remove this if statement entirely.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Member

@Craigacp Craigacp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks.

@Craigacp Craigacp merged commit f072c2c into oracle:main Oct 2, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants