From 131c4326052f69433875d7a16a2e461dc616b04d Mon Sep 17 00:00:00 2001
From: David Vilaverde
Date: Sat, 17 Dec 2022 09:29:53 -0500
Subject: [PATCH] refactor the Classifier interface to deprecate the Map-based
 method of providing feature values and allow multiple samples to be
 provided, producing multiple predictions in a single API call

---
 .idea/codeStyles/codeStyleConfig.xml          |   5 +
 README.md                                     |  21 ++-
 pom.xml                                       |   2 +-
 .../classifier/AbstractTreeClassifier.java    |  50 +++++
 .../vilaverde/classifier/BooleanFeature.java  |  19 --
 .../vilaverde/classifier/Classifier.java      |  28 ++-
 .../vilaverde/classifier/FeatureVector.java   |  53 ++++++
 .../rocks/vilaverde/classifier/Features.java  |  91 +++++++++
 .../vilaverde/classifier/dt/ChoiceNode.java   |   4 +-
 .../vilaverde/classifier/dt/DecisionNode.java |   7 +-
 .../classifier/dt/DecisionTreeClassifier.java |  78 ++++----
 .../vilaverde/classifier/dt/EndNode.java      |   9 +-
 .../classifier/dt/PredictionFactory.java      |   9 +-
 .../classifier/dt/TreeClassifier.java         |   5 +-
 .../vilaverde/classifier/dt/TreeNode.java     |   4 +
 .../dt/{ => visitors}/FeatureNameVisitor.java |  12 +-
 .../dt/visitors/PredictVisitor.java           |  84 +++++++++
 .../ensemble/RandomForestClassifier.java      | 176 ++++++++++++------
 .../DecisionTreeClassifierTest.java           | 127 +++++++------
 .../RandomForestClassifierTest.java           |  80 +++++---
 20 files changed, 631 insertions(+), 233 deletions(-)
 create mode 100644 .idea/codeStyles/codeStyleConfig.xml
 create mode 100644 src/main/java/rocks/vilaverde/classifier/AbstractTreeClassifier.java
 delete mode 100644 src/main/java/rocks/vilaverde/classifier/BooleanFeature.java
 create mode 100644 src/main/java/rocks/vilaverde/classifier/FeatureVector.java
 create mode 100644 src/main/java/rocks/vilaverde/classifier/Features.java
 rename src/main/java/rocks/vilaverde/classifier/dt/{ => visitors}/FeatureNameVisitor.java (62%)
 create mode 100644 src/main/java/rocks/vilaverde/classifier/dt/visitors/PredictVisitor.java

diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml
new file mode 100644
index 0000000..a55e7a1
--- /dev/null
+++ b/.idea/codeStyles/codeStyleConfig.xml
@@ -0,0 +1,5 @@
+<component name="ProjectCodeStyleConfiguration">
+  <state>
+    <option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
+  </state>
+</component>
\ No newline at end of file
diff --git a/README.md b/README.md
index b70b7f1..4c6c3c2 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,8 @@
 This project aims to use text-exported ML models generated by scikit-learn and make them usable in Java.
 
+[![javadoc](https://javadoc.io/badge2/rocks.vilaverde/scikit-learn-2-java/javadoc.svg)](https://javadoc.io/doc/rocks.vilaverde/scikit-learn-2-java)
+
 ## Support
 * The tree.DecisionTreeClassifier is supported
 * Supports `predict()`,
@@ -48,10 +50,10 @@ As an example, a DecisionTreeClassifier model trained on the Iris dataset and exported
 ```
 The exported text can then be executed in Java. Note that when calling `export_text` it is
-recommended that `max_depth` be set to sys.maxsize so that the tree isn't truncated.
+recommended that `max_depth` be set to `sys.maxsize` so that the tree isn't truncated.
 
 ### Java Example
-In this example the iris model exported using `export_tree` is parsed, features are created as a Java Map
+In this example the iris model exported using `export_text` is parsed, feature values are provided through a `FeatureVector`
 and the decision tree is asked to predict the class.
 
 ```
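+// `tree` is assumed to be a java.io.Reader over the export_text output (setup not shown)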
 final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.INTEGER);
 
-Map<String, Double> features = new HashMap<>();
-features.put("sepal length (cm)", 3.0);
-features.put("sepal width (cm)", 5.0);
-features.put("petal length (cm)", 4.0);
-features.put("petal width (cm)", 2.0);
+Features features = Features.of("sepal length (cm)",
+    "sepal width (cm)",
+    "petal length (cm)",
+    "petal width (cm)");
+FeatureVector fv = features.newSample();
+fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);
 
-Integer prediction = decisionTree.predict(features);
+Integer prediction = decisionTree.predict(fv).get(0);
 
 System.out.println(prediction.toString());
 ```
 
@@ -107,7 +110,7 @@ Then you can use the RandomForestClassifier class to parse the TAR archive.
 
 ...
 TarArchiveInputStream tree = getArchive("iris.tgz");
-final Classifier decisionTree = RandomForestClassifier.parse(tree,
+final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
     PredictionFactory.DOUBLE);
 ```
 
diff --git a/pom.xml b/pom.xml
index 0676c7c..abf95bd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
 
   <groupId>rocks.vilaverde</groupId>
   <artifactId>scikit-learn-2-java</artifactId>
-  <version>1.0.3-SNAPSHOT</version>
+  <version>1.1.0-SNAPSHOT</version>
 
   <name>${project.groupId}:${project.artifactId}</name>
   <description>A sklearn exported_text models parser for executing in the Java runtime.</description>
diff --git a/src/main/java/rocks/vilaverde/classifier/AbstractTreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/AbstractTreeClassifier.java
new file mode 100644
index 0000000..a3a7e65
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/AbstractTreeClassifier.java
@@ -0,0 +1,50 @@
+package rocks.vilaverde.classifier;
+
+import rocks.vilaverde.classifier.dt.TreeClassifier;
+
+import java.util.Map;
+
+/**
+ * Abstract base class for Tree classifiers.
+ */
+public abstract class AbstractTreeClassifier<T> implements TreeClassifier<T> {
+
+    /**
+     * Predict class or regression value for features.
+     *
+     * @param samples input samples
+     * @return the predicted class of the input sample
+     */
+    @Override
+    public T predict(Map<String, Double> samples) {
+        FeatureVector fv = toFeatureVector(samples);
+        return predict(fv).get(0);
+    }
+
+    /**
+     * Predict class probabilities of the input sample's features.
+     * The predicted class probability is the fraction of samples of the same class in a leaf.
+     *
+     * @param samples the input samples
+     * @return the class probabilities of the input sample
+     */
+    @Override
+    public double[] predict_proba(Map<String, Double> samples) {
+        FeatureVector fv = toFeatureVector(samples);
+        return predict_proba(fv)[0];
+    }
+
+    /**
+     * Convert a Map of features to a {@link FeatureVector}.
+     * @param samples a KV map of feature name to value
+     * @return FeatureVector
+     */
+    private FeatureVector toFeatureVector(Map<String, Double> samples) {
+        Features features = Features.fromSet(samples.keySet());
+        FeatureVector fv = features.newSample();
+        for (Map.Entry<String, Double> entry : samples.entrySet()) {
+            fv.add(entry.getKey(), entry.getValue());
+        }
+        return fv;
+    }
+}
diff --git a/src/main/java/rocks/vilaverde/classifier/BooleanFeature.java b/src/main/java/rocks/vilaverde/classifier/BooleanFeature.java
deleted file mode 100644
index 4b84dd4..0000000
--- a/src/main/java/rocks/vilaverde/classifier/BooleanFeature.java
+++ /dev/null
@@ -1,19 +0,0 @@
-package rocks.vilaverde.classifier;
-
-/**
- * Represents features that are Boolean as a Double `1.0` or `0.0`.
- */
-public enum BooleanFeature {
-
-    FALSE(0.0),
-    TRUE(1.0);
-
-    private final Double value;
-    BooleanFeature(double v) {
-        value = v;
-    }
-
-    public Double asDouble() {
-        return value;
-    }
-}
diff --git a/src/main/java/rocks/vilaverde/classifier/Classifier.java b/src/main/java/rocks/vilaverde/classifier/Classifier.java
index 158869e..4855c30 100644
--- a/src/main/java/rocks/vilaverde/classifier/Classifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/Classifier.java
@@ -1,24 +1,44 @@
 package rocks.vilaverde.classifier;
 
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
 public interface Classifier<T> {
 
+    /**
+     * Predict class or regression value for samples. Predictions will be
+     * returned at the same index as the sample provided.
+     * @param samples input samples
+     * @return the predicted classes of the input samples
+     */
+    List<T> predict(FeatureVector... samples);
+
+    /**
+     * Predict class probabilities of the input samples. Probabilities will be
+     * returned at the same index as the sample provided.
+     * The predicted class probability is the fraction of samples of the same class in a leaf.
+     * @param samples the input samples
+     * @return the class probabilities of the input samples
+     */
+    double[][] predict_proba(FeatureVector... samples);
+
     /**
      * Predict class or regression value for features.
-     * @param features input samples
+     * @param samples input samples
      * @return the predicted class of the input sample
      */
-    T predict(Map<String, Double> features);
+    @Deprecated
+    T predict(Map<String, Double> samples);
 
     /**
      * Predict class probabilities of the input samples features.
      * The predicted class probability is the fraction of samples of the same class in a leaf.
-     * @param features the input samples
+     * @param samples the input samples
      * @return the class probabilities of the input sample
      */
-    double[] predict_proba(Map<String, Double> features);
+    @Deprecated
+    double[] predict_proba(Map<String, Double> samples);
 
     /**
      * Get the names of all the features in the model.
diff --git a/src/main/java/rocks/vilaverde/classifier/FeatureVector.java b/src/main/java/rocks/vilaverde/classifier/FeatureVector.java
new file mode 100644
index 0000000..7cb839e
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/FeatureVector.java
@@ -0,0 +1,53 @@
+package rocks.vilaverde.classifier;
+
+/**
+ * A container for the values of each feature of a sample that will be predicted.
+ */
+public class FeatureVector {
+
+    private final Features features;
+    private final double[] vector;
+
+    public FeatureVector(Features features) {
+        this.features = features;
+        this.vector = new double[features.getLength()];
+    }
+
+    public FeatureVector add(String feature, boolean value) {
+        add(feature, value ? 1.0 : 0.0);
+        return this;
+    }
+
+    public FeatureVector add(int index, boolean value) {
+        add(index, value ? 1.0 : 0.0);
+        return this;
+    }
+
+    public FeatureVector add(int index, double value) {
+        this.vector[index] = value;
+        return this;
+    }
+
+    public FeatureVector add(String feature, double value) {
+        int index = this.features.getFeatureIndex(feature);
+        add(index, value);
+        return this;
+    }
+
+    public double get(int index) {
+        if (index >= vector.length) {
+            throw new IllegalArgumentException(String.format("index must be less than %d", vector.length));
+        }
+
+        return vector[index];
+    }
+
+    public double get(String feature) {
+        int index = features.getFeatureIndex(feature);
+        return get(index);
+    }
+
+    public boolean hasFeature(String feature) {
+        return this.features.getFeatureNames().contains(feature);
+    }
+}
diff --git a/src/main/java/rocks/vilaverde/classifier/Features.java b/src/main/java/rocks/vilaverde/classifier/Features.java
new file mode 100644
index 0000000..e51a484
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/Features.java
@@ -0,0 +1,91 @@
+package rocks.vilaverde.classifier;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.Set;
+
+/**
+ * Container class for the set of named features that will be provided for each sample.
+ * Call {@link Features#newSample()} to create a {@link FeatureVector} to provide the
+ * values for the sample.
+ */
+public class Features {
+
+    /* Map of feature name to index */
+    private final Map<String, Integer> features = new HashMap<>();
+    /* Features can be added as long as no samples have been created;
+       after the first sample this map is immutable */
+    private boolean allowFeatureAdd = true;
+
+    /**
+     * Convenience creation method.
+     * @param features the set of features.
+     * @return Features
+     */
+    public static Features of(String... features) {
+
+        // make sure all the features are unique
+        Set<String> featureSet = new HashSet<>(Arrays.asList(features));
+        if (featureSet.size() != features.length) {
+            throw new IllegalArgumentException("feature names are not unique");
+        }
+
+        return new Features(features);
+    }
+
+    public static Features fromSet(Set<String> features) {
+        return new Features(features.toArray(new String[0]));
+    }
+
+    /**
+     * Constructor
+     * @param features ordered list of features.
+     */
+    private Features(String... features) {
+        for (int i = 0; i < features.length; i++) {
+            this.features.put(features[i], i);
+        }
+    }
+
+    public FeatureVector newSample() {
+        allowFeatureAdd = false;
+        return new FeatureVector(this);
+    }
+
+    public void addFeature(String feature) {
+        if (!allowFeatureAdd) {
+            throw new IllegalStateException("features are immutable");
+        }
+
+        if (!this.features.containsKey(feature)) {
+            int next = 0;
+            OptionalInt optionalInt = features.values().stream().mapToInt(integer -> integer).max();
+            if (optionalInt.isPresent()) {
+                next = optionalInt.getAsInt() + 1;
+            }
+
+            this.features.put(feature, next);
+        }
+    }
+
+    public int getLength() {
+        return this.features.size();
+    }
+
+    public int getFeatureIndex(String feature) {
+        Integer index = this.features.get(feature);
+        if (index == null) {
+            throw new IllegalArgumentException(String.format("feature %s does not exist", feature));
+        }
+
+        return index;
+    }
+
+    public Set<String> getFeatureNames() {
+        return Collections.unmodifiableSet(this.features.keySet());
+    }
+}
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java b/src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java
index 5e3bf22..3cd494e 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java
@@ -6,7 +6,7 @@
  * Represents a Choice in the decision tree; when the expression evaluates to
  * true, the child node of the choice is selected.
  */
-class ChoiceNode extends TreeNode {
+public class ChoiceNode extends TreeNode {
 
     private final Operator op;
     private final Double value;
@@ -38,7 +38,7 @@ public String toString() {
         return String.format("%s %s", op.toString(), value.toString());
     }
 
-    public boolean eval(Double featureValue) {
+    public boolean eval(double featureValue) {
         return op.apply(featureValue, value);
     }
 }
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java b/src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java
index 3952d51..4746545 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java
@@ -1,14 +1,11 @@
 package rocks.vilaverde.classifier.dt;
 
-import java.util.ArrayList;
-import java.util.List;
-
 /**
  * Represents a decision in the DecisionTreeClassifier. The decision will have
  * a left and right hand {@link ChoiceNode} to be evaluated.
  * A {@link ChoiceNode} may have nested {@link DecisionNode} or {@link EndNode}.
  */
-class DecisionNode extends TreeNode {
+public class DecisionNode extends TreeNode {
 
     private final String featureName;
 
@@ -26,7 +23,7 @@ public static DecisionNode create(String feature) {
 
     /**
      * Private Constructor.
-     * @param featureName
+     * @param featureName the name of the feature used in this decision
      */
     private DecisionNode(String featureName) {
         this.featureName = featureName.intern();
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
index 66c5eed..a9c14cf 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
@@ -1,16 +1,27 @@
 package rocks.vilaverde.classifier.dt;
 
+import rocks.vilaverde.classifier.AbstractTreeClassifier;
+import rocks.vilaverde.classifier.FeatureVector;
 import rocks.vilaverde.classifier.Operator;
 import rocks.vilaverde.classifier.Prediction;
+import rocks.vilaverde.classifier.dt.visitors.FeatureNameVisitor;
+import rocks.vilaverde.classifier.dt.visitors.PredictVisitor;
 
-import java.awt.*;
 import java.io.BufferedReader;
 import java.io.Reader;
-import java.util.Map;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Set;
 import java.util.Stack;
+import java.util.stream.Collectors;
 
-public class DecisionTreeClassifier<T> implements TreeClassifier<T> {
+/**
+ * Represents a DecisionTreeClassifier trained in scikit-learn and exported using
+ * export_text.
+ * @param <T> the Prediction Class
+ */
+public class DecisionTreeClassifier<T> extends AbstractTreeClassifier<T>
+        implements TreeClassifier<T> {
 
     /**
      * Factory method to create the classifier from the {@link Reader}.
@@ -49,64 +60,51 @@ private DecisionTreeClassifier(PredictionFactory<T> predictionFactory) {
 
     /**
      * Predict class or regression value for features.
      * For the classification model, the predicted class for the features of the sample is returned.
-     * @param features Map of feature name to value
+     * @param samples the input samples
      * @return predicted class
      */
-    public T predict(Map<String, Double> features) {
-        return getClassification(features).get();
+    public List<T> predict(FeatureVector... samples) {
+        return Arrays.stream(samples)
+            .map(featureVector -> getClassification(featureVector).get())
+            .collect(Collectors.toList());
     }
 
     /**
      * Predict class probabilities of the input features.
      * The predicted class probability is the fraction of samples of the same class in a leaf.
-     * @param features Map of feature name to value
+     * @param samples the input samples
      * @return the class probabilities of the input sample
      */
     @Override
-    public double[] predict_proba(Map<String, Double> features) {
-        return getClassification(features).getProbability();
+    public double[][] predict_proba(FeatureVector... samples) {
+        double[][] probabilities = null;
+
+        for (int i = 0; i < samples.length; i++) {
+            double[] result = getClassification(samples[i]).getProbability();
+            if (probabilities == null) {
+                probabilities = new double[samples.length][result.length];
+            }
+
+            probabilities[i] = result;
+        }
+
+        return probabilities;
     }
 
     /**
      * Find the {@link Prediction} in the decision tree.
      */
-    public Prediction<T> getClassification(Map<String, Double> features) {
-        validateFeature(features);
-
-        TreeNode currentNode = root;
-
-        // traverse through the tree until the end node is reached.
-        while (!(currentNode instanceof EndNode)) {
-
-            if (currentNode != null) {
-                DecisionNode decisionNode = ((DecisionNode) currentNode);
-
-                Double featureValue = features.get(decisionNode.getFeatureName());
-                if (featureValue == null) {
-                    featureValue = Double.NaN;
-                }
-
-                if (decisionNode.getLeft().eval(featureValue)) {
-                    currentNode = decisionNode.getLeft().getChild();
-                } else if (decisionNode.getRight().eval(featureValue)) {
-                    currentNode = decisionNode.getRight().getChild();
-                } else {
-                    // if I get here something is wrong since none of the branches evaluated to true
-                    throw new RuntimeException(String.format("no branches evaluated to true for feature '%s'",
-                        decisionNode.getFeatureName()));
-                }
-            }
-        }
-
-        return (Prediction<T>) currentNode;
+    public Prediction<T> getClassification(FeatureVector sample) {
+        validateFeatures(sample);
+        return PredictVisitor.predict(sample, root);
     }
 
     /**
      * Validate the features provided are expected.
      */
-    private void validateFeature(Map<String, Double> features) throws IllegalArgumentException {
+    private void validateFeatures(FeatureVector sample) throws IllegalArgumentException {
         for (String f : featureNames) {
-            if (!features.containsKey(f)) {
+            if (!sample.hasFeature(f)) {
                 throw new IllegalArgumentException(String.format("expected feature named '%s' but none provided", f));
             }
         }
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/EndNode.java b/src/main/java/rocks/vilaverde/classifier/dt/EndNode.java
index 61ad6ae..d5d904c 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/EndNode.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/EndNode.java
@@ -8,13 +8,18 @@
  * Represents the end of the tree, where no further decisions can be made. The end node contains
  * the prediction.
  */
-class EndNode<T> extends TreeNode implements Prediction<T> {
+public class EndNode<T> extends TreeNode implements Prediction<T> {
 
     private static final MessageFormat CLASS_FORMAT = new MessageFormat("class: {0}");
     private final T prediction;
 
     /**
      * Factory method to create the appropriate {@link EndNode} from the
      * String in the exported tree model.
+     * @param endNodeString the serialized EndNode text from the sklearn export_text
+     * @param predictionFactory the {@link PredictionFactory} used to deserialize the value after "class:"
+     * @return the {@link EndNode}
+     * @throws Exception may be thrown when the end node string can't be parsed.
+     * @param <T> the java type of the classification once deserialized by the predictionFactory
      */
     public static <T> EndNode<T> create(String endNodeString, PredictionFactory<T> predictionFactory)
             throws Exception {
@@ -56,7 +61,7 @@ public String toString() {
      * {@link EndNode} that supports calculating the probability from the
      * weights in the exported tree model.
      */
-    static class WeightedEndNode<T> extends EndNode<T> {
+    public static class WeightedEndNode<T> extends EndNode<T> {
 
         private static final MessageFormat WEIGHTS_FORMAT = new MessageFormat("weights: {0} class: {1}");
 
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java b/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
index 8af8190..10c8b2e 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
@@ -1,7 +1,9 @@
 package rocks.vilaverde.classifier.dt;
 
 /**
- * A prediction from the classifier
+ * A prediction from the classifier can be of type Double, Integer, or Boolean. This
+ * factory is used while parsing the text export of the DecisionTreeClassifier to construct
+ * the correct Java type for the serialized value.
  */
 @FunctionalInterface
 public interface PredictionFactory<T> {
 
     PredictionFactory<Integer> INTEGER = Integer::valueOf;
     PredictionFactory<Double> DOUBLE = Double::parseDouble;
 
+    /**
+     * Convert a String value to the appropriate type for the model.
+     * @param value the serialized text value of the prediction from the exported decision tree.
+     * @return the deserialized type
+     */
     T create(String value);
 }
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
index a0d59e9..76f6f34 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
@@ -1,14 +1,13 @@
 package rocks.vilaverde.classifier.dt;
 
 import rocks.vilaverde.classifier.Classifier;
+import rocks.vilaverde.classifier.FeatureVector;
 import rocks.vilaverde.classifier.Prediction;
 
-import java.util.Map;
-
 /**
  * Implemented by Tree classifiers.
  */
 public interface TreeClassifier<T> extends Classifier<T> {
 
-    Prediction<T> getClassification(Map<String, Double> features);
+    Prediction<T> getClassification(FeatureVector features);
 }
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/TreeNode.java b/src/main/java/rocks/vilaverde/classifier/dt/TreeNode.java
index 0355b40..46f5e53 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/TreeNode.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/TreeNode.java
@@ -4,5 +4,9 @@
 
 public abstract class TreeNode implements Visitable {
 
+    /**
+     * Must be implemented by every implementation of TreeNode to call visit() on the visitor.
+     * @param visitor the visitor
+     */
     public abstract void accept(AbstractDecisionTreeVisitor visitor);
 }
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/FeatureNameVisitor.java b/src/main/java/rocks/vilaverde/classifier/dt/visitors/FeatureNameVisitor.java
similarity index 62%
rename from src/main/java/rocks/vilaverde/classifier/dt/FeatureNameVisitor.java
rename to src/main/java/rocks/vilaverde/classifier/dt/visitors/FeatureNameVisitor.java
index 71c67c8..332644f 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/FeatureNameVisitor.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/visitors/FeatureNameVisitor.java
@@ -1,4 +1,7 @@
-package rocks.vilaverde.classifier.dt;
+package rocks.vilaverde.classifier.dt.visitors;
+
+import rocks.vilaverde.classifier.dt.AbstractDecisionTreeVisitor;
+import rocks.vilaverde.classifier.dt.DecisionNode;
 
 import java.util.HashSet;
 import java.util.Set;
@@ -11,12 +14,13 @@ public class FeatureNameVisitor extends AbstractDecisionTreeVisitor {
 
     private final Set<String> featureNames = new HashSet<>();
 
-
+    /**
+     * Visit a {@link DecisionNode} and collect the feature name used in the decision.
+     * @param object the {@link DecisionNode} being visited.
+     */
     @Override
     public void visit(DecisionNode object) {
-
         featureNames.add(object.getFeatureName());
-
         super.visit(object);
     }
 
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/visitors/PredictVisitor.java b/src/main/java/rocks/vilaverde/classifier/dt/visitors/PredictVisitor.java
new file mode 100644
index 0000000..f750cb7
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/dt/visitors/PredictVisitor.java
@@ -0,0 +1,84 @@
+package rocks.vilaverde.classifier.dt.visitors;
+
+import rocks.vilaverde.classifier.FeatureVector;
+import rocks.vilaverde.classifier.Prediction;
+import rocks.vilaverde.classifier.dt.AbstractDecisionTreeVisitor;
+import rocks.vilaverde.classifier.dt.ChoiceNode;
+import rocks.vilaverde.classifier.dt.DecisionNode;
+import rocks.vilaverde.classifier.dt.EndNode;
+import rocks.vilaverde.classifier.dt.TreeNode;
+
+/**
+ * Visits the nodes of the {@link rocks.vilaverde.classifier.dt.TreeClassifier} looking for the
+ * {@link rocks.vilaverde.classifier.dt.EndNode} for the given {@link rocks.vilaverde.classifier.FeatureVector}.
+ */
+public class PredictVisitor<T> extends AbstractDecisionTreeVisitor {
+
+    private final FeatureVector sample;
+    private Prediction<T> prediction;
+
+    /**
+     * Convenience method to search a tree for a prediction.
+     * @param sample the sample {@link FeatureVector}
+     * @param root the root {@link TreeNode} of the Decision Tree
+     * @return the {@link Prediction}
+     * @param <T> the classification java type
+     */
+    public static <T> Prediction<T> predict(FeatureVector sample, TreeNode root) {
+        PredictVisitor<T> visitor = new PredictVisitor<>(sample);
+        root.accept(visitor);
+
+        if (visitor.getPrediction() == null) {
+            throw new RuntimeException("expected a prediction result from the tree, but none found");
+        }
+
+        return visitor.getPrediction();
+    }
+
+    /**
+     * Constructor.
+     * @param sample the {@link FeatureVector}
+     */
+    private PredictVisitor(FeatureVector sample) {
+        this.sample = sample;
+    }
+
+    /**
+     * When visiting a {@link DecisionNode} we need to test the left and right
+     * {@link ChoiceNode} and visit only the one that evaluates to true.
+     * @param object the {@link DecisionNode} being visited
+     */
+    @Override
+    public void visit(DecisionNode object) {
+
+        // don't call super otherwise both choice nodes are visited.
+
+        double featureValue = this.sample.get(object.getFeatureName());
+        if (object.getLeft().eval(featureValue)) {
+            object.getLeft().getChild().accept(this);
+        } else if (object.getRight().eval(featureValue)) {
+            object.getRight().getChild().accept(this);
+        } else {
+            throw new RuntimeException(String.format("no branches evaluated to true for feature '%s'",
+                object.getFeatureName()));
+        }
+    }
+
+    /**
+     * When visiting an {@link EndNode} we've found the prediction
+     * and no longer need to visit the tree.
+     * @param object the {@link EndNode} being visited
+     */
+    @Override
+    public void visit(EndNode object) {
+        this.prediction = object;
+    }
+
+    /**
+     * Get the prediction by searching the decision tree.
+     * @return the {@link Prediction}
+     */
+    public Prediction<T> getPrediction() {
+        return prediction;
+    }
+}
diff --git a/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java b/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
index 36f0e39..6e2380c 100644
--- a/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
@@ -1,10 +1,12 @@
 package rocks.vilaverde.classifier.ensemble;
 
-import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
-import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.archivers.ArchiveEntry;
+import org.apache.commons.compress.archivers.ArchiveInputStream;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import rocks.vilaverde.classifier.AbstractTreeClassifier;
 import rocks.vilaverde.classifier.Classifier;
+import rocks.vilaverde.classifier.FeatureVector;
 import rocks.vilaverde.classifier.Prediction;
 import rocks.vilaverde.classifier.dt.DecisionTreeClassifier;
 import rocks.vilaverde.classifier.dt.PredictionFactory;
@@ -15,26 +17,34 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
-import java.util.*;
-import java.util.concurrent.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
 import java.util.stream.Collectors;
 
 /**
  * A forest of DecisionTreeClassifiers.
  */
-public class RandomForestClassifier<T> implements Classifier<T> {
+public class RandomForestClassifier<T> extends AbstractTreeClassifier<T>
+        implements Classifier<T> {
 
     private static final Logger LOG = LoggerFactory.getLogger(RandomForestClassifier.class);
 
     /**
      * Accept a TAR of exported DecisionTreeClassifiers from sklearn and produce a
      * RandomForestClassifier. This defaults to running in a single (current) thread.
-     * @param tar the Tar Archive input stream
+     * @param tar the {@link ArchiveInputStream}
      * @param factory the factory for creating the prediction class
      * @return the {@link Classifier}
      * @param <T> the classifier type
     * @throws Exception when the model could not be parsed
      */
-    public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
+    public static <T> Classifier<T> parse(final ArchiveInputStream tar,
                                           PredictionFactory<T> factory) throws Exception {
         return RandomForestClassifier.parse(tar, factory, null);
     }
@@ -49,14 +59,14 @@ public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
      * @param <T> the classifier type
      * @throws Exception when the model could not be parsed
      */
-    public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
+    public static <T> Classifier<T> parse(final ArchiveInputStream tar,
                                           PredictionFactory<T> factory,
                                           ExecutorService executor) throws Exception {
         List<TreeClassifier<T>> forest = new ArrayList<>();
 
         try (tar) {
-            TarArchiveEntry exportedTree;
-            while ((exportedTree = tar.getNextTarEntry()) != null) {
+            ArchiveEntry exportedTree;
+            while ((exportedTree = tar.getNextEntry()) != null) {
                 if (!exportedTree.isDirectory()) {
                     LOG.debug("Parsing tree {}", exportedTree.getName());
                     final InputStream noCloseStream = new InputStream() {
@@ -95,75 +105,69 @@ private RandomForestClassifier(List<TreeClassifier<T>> forest, ExecutorService e
 
     /**
      * Predict class or regression value for features.
-     * @param features features of the sample
+     * @param samples the input samples
      * @return the predicted classes of the input samples
      */
     @Override
-    public T predict(Map<String, Double> features) {
-        List<Prediction<T>> predictions = getPredictions(features);
-        Map<T, Long> map = predictions.stream()
-            .collect(Collectors.groupingBy(Prediction::get, Collectors.counting()));
-
-        long max = map.values().stream().mapToLong(Long::longValue).max().orElse(0);
-        for (Map.Entry<T, Long> entry : map.entrySet()) {
-            if (entry.getValue() == max) {
-                return entry.getKey();
-            }
-        }
-
-        throw new IllegalStateException("no classification");
+    public List<T> predict(FeatureVector... samples) {
+        return Arrays.stream(samples)
+            .map(this::predictSingle)
+            .collect(Collectors.toList());
     }
 
     /**
      * Predict class probabilities of the input samples features.
      * The predicted class probability is the fraction of samples of the same class in a leaf.
-     * @param features the input samples
+     * @param samples the input samples
      * @return the class probabilities of the input sample
      */
     @Override
-    public double[] predict_proba(Map<String, Double> features) {
-        if (forest.size() == 1) {
-            return forest.get(0).getClassification(features).getProbability();
-        }
-
-        List<Prediction<T>> predictions = getPredictions(features);
+    public double[][] predict_proba(FeatureVector... samples) {
+        double[][] probabilities = null;
 
-        double[] result = null;
-        for (Prediction<T> prediction : predictions) {
-            double[] prob = prediction.getProbability();
+        for (int i = 0; i < samples.length; i++) {
+            FeatureVector fv = samples[i];
+            Prediction<T> prediction = getClassification(fv);
 
-            if (result == null) {
-                result = prob;
-            } else {
-                for (int i = 0; i < prob.length; i++) {
-                    result[i] += prob[i];
-                }
+            double[] sampleProb = prediction.getProbability();
+            if (probabilities == null) {
+                probabilities = new double[samples.length][sampleProb.length];
             }
-        }
 
-        if (result != null) {
-            int forestSize = forest.size();
-            for (int i = 0; i < result.length; i++) {
-                result[i] /= forestSize;
-            }
+            probabilities[i] = sampleProb;
         }
 
-        return result;
+        return probabilities;
+    }
+
+    protected T predictSingle(FeatureVector sample) {
+        return getClassification(sample).get();
+    }
+
+    /**
+     * Return a prediction from the forest for the sample.
+     * @param sample a sample's feature vector
+     * @return Prediction
+     */
+    @Override
+    public Prediction<T> getClassification(FeatureVector sample) {
+        final List<Prediction<T>> predictions = getPredictions(sample);
+        return new RandomForestPrediction<>(predictions, forest.size());
     }
 
     /**
      * Get all the predictions for the features of the sample
-     * @param features features of the sample
+     * @param sample features of the sample
      * @return a List of {@link Prediction} objects from the trees in the forest.
      */
-    protected List<Prediction<T>> getPredictions(final Map<String, Double> features) {
+    protected List<Prediction<T>> getPredictions(final FeatureVector sample) {
         List<Prediction<T>> predictions;
 
         if (executorService != null) {
             int jobs = Runtime.getRuntime().availableProcessors();
             List<ParallelPrediction<T>> parallel = new ArrayList<>(jobs);
             for (int i = 0; i < jobs; i++) {
-                ParallelPrediction<T> parallelPrediction = new ParallelPrediction<>(forest, features, i, jobs);
+                ParallelPrediction<T> parallelPrediction = new ParallelPrediction<>(forest, sample, i, jobs);
                 parallel.add(parallelPrediction);
             }
 
@@ -179,7 +183,7 @@ protected List<Prediction<T>> getPredictions(final Map<String, Double> features)
         } else {
             predictions = new ArrayList<>(forest.size());
             for (TreeClassifier<T> tree : forest) {
-                Prediction<T> prediction = tree.getClassification(features);
+                Prediction<T> prediction = tree.getClassification(sample);
                 predictions.add(prediction);
             }
         }
@@ -201,6 +205,66 @@ public Set<String> getFeatureNames() {
 
     }
 
+    static class RandomForestPrediction<T> implements Prediction<T> {
+        private final List<Prediction<T>> predictions;
+        private final int forestSize;
+
+        /**
+         * Constructor
+         * @param predictions the list of predictions that need to be merged
+         * @param forestSize the number of trees in the forest
+         */
+        public RandomForestPrediction(List<Prediction<T>> predictions, int forestSize) {
+            this.predictions = predictions;
+            this.forestSize = forestSize;
+        }
+
+        /**
+         * @return The class.
+         */
+        @Override
+        public T get() {
+            Map<T, Long> map = predictions.stream()
+                .collect(Collectors.groupingBy(Prediction::get, Collectors.counting()));
+
+            long max = map.values().stream().mapToLong(Long::longValue).max().orElse(0);
+            T prediction = null;
+            for (Map.Entry<T, Long> entry : map.entrySet()) {
+                if (entry.getValue() == max) {
+                    prediction = entry.getKey();
+                    break;
+                }
+            }
+
+            if (prediction == null) {
+                throw new IllegalStateException("no classification");
+            }
+
+            return prediction;
+        }
+
+        /**
+         * @return the probability
+         */
+        @Override
+        public double[] getProbability() {
+            int arraySize = predictions.get(0).getProbability().length;
+            final double[] result = new double[arraySize];
+            for (Prediction<T> prediction : predictions) {
+                double[] prob = prediction.getProbability();
+                for (int j = 0; j < prob.length; j++) {
+                    result[j] += prob[j];
+                }
+            }
+
+            for (int j = 0; j < result.length; j++) {
+                result[j] /= forestSize;
+            }
+            return result;
+        }
+    }
+
+
     /**
      * A job that will only provide {@link Prediction} results
      * from a subset of the trees in the forest.
@@ -210,23 +274,23 @@ private static class ParallelPrediction implements Callable> forest; - private final Map features; + private final FeatureVector sample; /** * Constructor * @param forest the random forest - * @param features the features of a sample + * @param sample the features of a sample * @param start the index of the tree in the forest this job will begin with * @param offset the offest between the current tree and the next tree to call predict on */ private ParallelPrediction(List> forest, - Map features, + FeatureVector sample, int start, int offset) { this.offset = offset; this.start = start; this.forest = forest; - this.features = features; + this.sample = sample; } /** @@ -239,7 +303,7 @@ public List> call() throws Exception { List> predictions = new ArrayList<>(); for (int i = start; i < forest.size(); i+=offset) { - predictions.add(forest.get(i).getClassification(features)); + predictions.add(forest.get(i).getClassification(sample)); } return predictions; diff --git a/src/test/java/rocks/vilaverde/classifier/DecisionTreeClassifierTest.java b/src/test/java/rocks/vilaverde/classifier/DecisionTreeClassifierTest.java index d782f33..e6f8c46 100644 --- a/src/test/java/rocks/vilaverde/classifier/DecisionTreeClassifierTest.java +++ b/src/test/java/rocks/vilaverde/classifier/DecisionTreeClassifierTest.java @@ -8,9 +8,6 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; public class DecisionTreeClassifierTest { @@ -21,8 +18,12 @@ public void parseSimpleTree() throws Exception { final Classifier decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN); Assertions.assertNotNull(decisionTree); - Assertions.assertFalse(decisionTree.predict(Collections.singletonMap("feature1", 1.2))); - Assertions.assertTrue(decisionTree.predict(Collections.singletonMap("feature1", 2.4))); + Features features = Features.of("feature1"); + FeatureVector sample1 = features.newSample().add(0, 1.2); + Assertions.assertFalse(decisionTree.predict(sample1).get(0)); + + FeatureVector sample2 = features.newSample().add(0, 2.4); + Assertions.assertTrue(decisionTree.predict(sample2).get(0)); Assertions.assertNotNull(decisionTree.getFeatureNames()); Assertions.assertEquals(1, decisionTree.getFeatureNames().size()); @@ -35,20 +36,22 @@ public void parseDeepTree() throws Exception { final Classifier decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN); Assertions.assertNotNull(decisionTree); - Map features = new HashMap<>(); - features.put("feature1", 0.0); - features.put("feature2", 1.0); - features.put("feature3", BooleanFeature.FALSE.asDouble()); - features.put("feature4", 0.0); - features.put("feature5", BooleanFeature.FALSE.asDouble()); - features.put("feature6", 1.0); - features.put("feature7", 1.0); - features.put("feature8", 0.0); + Features features = Features.of("feature1", "feature2", "feature3", "feature4", + "feature5", "feature6", "feature7", "feature8"); + FeatureVector fv = features.newSample(); + fv.add("feature1", 0.0) + .add("feature2", 1.0) + .add("feature3", false) + .add("feature4", 0.0) + .add("feature5", false) + .add("feature6", 1.0) + .add("feature7", 1.0) + .add("feature8", 0.0); - Assertions.assertFalse(decisionTree.predict(features)); + Assertions.assertFalse(decisionTree.predict(fv).get(0)); - features.put("feature5", BooleanFeature.TRUE.asDouble()); - Assertions.assertTrue(decisionTree.predict(features)); + fv.add("feature5", 
+        Assertions.assertTrue(decisionTree.predict(fv).get(0));
 
         Assertions.assertNotNull(decisionTree.getFeatureNames());
         Assertions.assertEquals(8, decisionTree.getFeatureNames().size());
@@ -61,17 +64,19 @@ public void invalidFeatureName() {
             final Classifier<Boolean> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN);
 
             Assertions.assertNotNull(decisionTree);
-            Map<String, Double> features = new HashMap<>();
-            features.put("feature11", 0.0);
-            features.put("feature2", 1.0);
-            features.put("feature3", BooleanFeature.FALSE.asDouble());
-            features.put("feature4", 0.0);
-            features.put("feature5", BooleanFeature.FALSE.asDouble());
-            features.put("feature6", 1.0);
-            features.put("feature7", 1.0);
-            features.put("feature8", 0.0);
-
-            decisionTree.predict(features);
+            Features features = Features.of("feature11", "feature2", "feature3", "feature4",
+                "feature5", "feature6", "feature7", "feature8");
+            FeatureVector fv = features.newSample();
+            fv.add(0, 0.0)
+                .add(1, 1.0)
+                .add(2, false)
+                .add(3, 0.0)
+                .add(4, false)
+                .add(5, 1.0)
+                .add(6, 1.0)
+                .add(7, 0.0);
+
+            decisionTree.predict(fv);
         });
 
         Assertions.assertEquals("expected feature named 'feature1' but none provided",
@@ -85,15 +90,17 @@ public void invalidFeatureCount() {
             final Classifier<Boolean> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN);
 
             Assertions.assertNotNull(decisionTree);
-            Map<String, Double> features = new HashMap<>();
-            features.put("feature1", 0.0);
-            features.put("feature2", 1.0);
-            features.put("feature3", BooleanFeature.FALSE.asDouble());
-            features.put("feature4", 0.0);
-            features.put("feature5", BooleanFeature.FALSE.asDouble());
-            features.put("feature6", 1.0);
-
-            decisionTree.predict(features);
+            Features features = Features.of("feature1", "feature2", "feature3", "feature4",
+                "feature5", "feature6");
+            FeatureVector fv = features.newSample();
+            fv.add(0, 0.0)
+                .add(1, 1.0)
+                .add(2, false)
+                .add(3, 0.0)
+                .add(4, false)
+                .add(5, 1.0);
+
+            decisionTree.predict(fv);
         });
 
         Assertions.assertEquals("expected feature named 'feature7' but none provided",
@@ -106,19 +113,21 @@ public void probability() throws Exception {
         final Classifier<Boolean> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN);
 
         Assertions.assertNotNull(decisionTree);
-        Map<String, Double> features = new HashMap<>();
-        features.put("feature1", 1.2);
-        features.put("feature2", 88.33);
-        features.put("feature3", BooleanFeature.FALSE.asDouble());
-        features.put("feature4", 1.727);
-        features.put("feature5", BooleanFeature.FALSE.asDouble());
-        features.put("feature6", 1.0);
-        features.put("feature7", 0.0048);
-        features.put("feature8", 0.0);
-
-        Assertions.assertFalse(decisionTree.predict(features));
-
-        double[] prediction = decisionTree.predict_proba(features);
+        Features features = Features.of("feature1", "feature2", "feature3", "feature4",
+            "feature5", "feature6", "feature7", "feature8");
+        FeatureVector fv = features.newSample();
+        fv.add("feature1", 1.2)
+            .add("feature2", 88.33)
+            .add("feature3", false)
+            .add("feature4", 1.727)
+            .add("feature5", false)
+            .add("feature6", 1.0)
+            .add("feature7", 0.0048)
+            .add("feature8", 0.0);
+
+        Assertions.assertFalse(decisionTree.predict(fv).get(0));
+
+        double[] prediction = decisionTree.predict_proba(fv)[0];
         Assertions.assertNotNull(prediction);
         Assertions.assertEquals(0.63636364, prediction[0], .00000001);
         Assertions.assertEquals(0.36363636, prediction[1], .00000001);
@@ -130,19 +139,19 @@ public void noWeights() throws Exception {
         final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,
             PredictionFactory.INTEGER);
 
         Assertions.assertNotNull(decisionTree);
-        Map<String, Double> features = new HashMap<>();
-        features.put("sepal length (cm)", 3.0);
-        features.put("sepal width (cm)", 5.0);
-        features.put("petal length (cm)", 4.0);
-        features.put("petal width (cm)", 2.0);
-
-        Integer prediction = decisionTree.predict(features);
+        Features features = Features.of("sepal length (cm)",
+            "sepal width (cm)",
+            "petal length (cm)",
+            "petal width (cm)");
+        FeatureVector fv = features.newSample();
+        fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);
 
+        Integer prediction = decisionTree.predict(fv).get(0);
         Assertions.assertNotNull(prediction);
         Assertions.assertEquals(1, prediction.intValue());
 
-        features.put("sepal length (cm)", 6.0);
-        prediction = decisionTree.predict(features);
+        fv.add("sepal length (cm)", 6.0);
+        prediction = decisionTree.predict(fv).get(0);
         Assertions.assertEquals(2, prediction.intValue());
     }
 
diff --git a/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java b/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
index cffffcc..9ea2acf 100644
--- a/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
+++ b/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
@@ -11,8 +11,8 @@
 
 import java.io.IOException;
 import java.io.InputStream;
-import java.util.HashMap;
-import java.util.Map;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 
@@ -40,37 +40,55 @@ public void randomForestParallel() throws Exception {
             PredictionFactory.DOUBLE, executorService);
 
         Assertions.assertNotNull(decisionTree);
-        double[] proba = decisionTree.predict_proba(getSample1());
+        double[] proba = decisionTree.predict_proba(getSample1())[0];
         assertSample(proba, .06, .62, .32);
 
-        Double prediction = decisionTree.predict(getSample1());
+        Double prediction = decisionTree.predict(getSample1()).get(0);
         Assertions.assertNotNull(prediction);
         Assertions.assertEquals(1.0, prediction, .0);
 
-        prediction = decisionTree.predict(getSample2());
+        prediction = decisionTree.predict(getSample2()).get(0);
         Assertions.assertEquals(2, prediction.intValue());
 
-        proba = decisionTree.predict_proba(getSample2());
+        proba = decisionTree.predict_proba(getSample2())[0];
         assertSample(proba, 0.0, .44, .56);
     }
 
+    @Test
+    public void randomForestParallel10000() throws Exception {
+        TarArchiveInputStream exported = getExportedModel("rf/iris.tgz");
+        final Classifier<Double> decisionTree = RandomForestClassifier.parse(exported,
+            PredictionFactory.DOUBLE, executorService);
+        Assertions.assertNotNull(decisionTree);
+
+        List<FeatureVector> vectorList = new ArrayList<>(10000);
+        for (int i = 0; i < 5000; i++) {
+            vectorList.add(getSample1());
+            vectorList.add(getSample2());
+        }
+
+        double[][] proba = decisionTree.predict_proba(vectorList.toArray(new FeatureVector[0]));
+        assertSample(proba[0], .06, .62, .32);
+        assertSample(proba[1], 0.0, .44, .56);
+    }
+
     @Test
     public void randomForest() throws Exception {
         TarArchiveInputStream exported = getExportedModel("rf/iris.tgz");
         final Classifier<Double> decisionTree = RandomForestClassifier.parse(exported,
             PredictionFactory.DOUBLE);
 
         Assertions.assertNotNull(decisionTree);
-        double[] proba = decisionTree.predict_proba(getSample1());
+        double[] proba = decisionTree.predict_proba(getSample1())[0];
         assertSample(proba, .06, .62, .32);
 
-        Double prediction = decisionTree.predict(getSample1());
+        Double prediction = decisionTree.predict(getSample1()).get(0);
         Assertions.assertNotNull(prediction);
         Assertions.assertEquals(1.0, prediction, .0);
 
-        prediction = decisionTree.predict(getSample2());
+        prediction = decisionTree.predict(getSample2()).get(0);
         Assertions.assertEquals(2, prediction.intValue());
 
-        proba = decisionTree.predict_proba(getSample2());
+        proba = decisionTree.predict_proba(getSample2())[0];
         assertSample(proba, 0.0, .44, .56);
     }
 
@@ -82,10 +100,10 @@ public void invalidFeatureCount() {
             Assertions.assertNotNull(decisionTree);
 
             // create invalid number of features.
-            Map<String, Double> features = getSample1();
-            features.remove("sepal length (cm)");
-
-            decisionTree.predict(features);
+            Features features = Features.of("sepal width (cm)", "petal length (cm)", "petal width (cm)");
+            FeatureVector fv = features.newSample();
+            fv.add(0, 3.0).add(1, 5.0).add(2, 4.0);
+            decisionTree.predict(fv);
         });
 
         Assertions.assertEquals("expected feature named 'sepal length (cm)' but none provided",
@@ -101,22 +119,28 @@ public void featureNames() throws Exception {
         Assertions.assertEquals(4, decisionTree.getFeatureNames().size());
     }
 
-    private Map<String, Double> getSample1() {
-        Map<String, Double> features = new HashMap<>();
-        features.put("sepal length (cm)", 3.0);
-        features.put("sepal width (cm)", 5.0);
-        features.put("petal length (cm)", 4.0);
-        features.put("petal width (cm)", 2.0);
-        return features;
+    private FeatureVector getSample1() {
+        Features features = Features.of("sepal length (cm)",
+            "sepal width (cm)",
+            "petal length (cm)",
+            "petal width (cm)");
+        FeatureVector fv = features.newSample();
+        return fv.add(0, 3.0)
+            .add(1, 5.0)
+            .add(2, 4.0)
+            .add(3, 2.0);
     }
 
-    private Map<String, Double> getSample2() {
-        Map<String, Double> features = new HashMap<>();
-        features.put("sepal length (cm)", 1.0);
-        features.put("sepal width (cm)", 2.0);
-        features.put("petal length (cm)", 3.0);
-        features.put("petal width (cm)", 4.0);
-        return features;
+    private FeatureVector getSample2() {
+        Features features = Features.of("sepal length (cm)",
+            "sepal width (cm)",
+            "petal length (cm)",
+            "petal width (cm)");
+        FeatureVector fv = features.newSample();
+        return fv.add(0, 1.0)
+            .add(1, 2.0)
+            .add(2, 3.0)
+            .add(3, 4.0);
    }
 
    private void assertSample(double[] proba, double expected, double expected1, double expected2) {
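
A minimal sketch of the multi-sample API introduced by this patch (assumes `decisionTree` is the `Classifier<Integer>` parsed in the README example above; sample values follow the iris tests):

```
Features features = Features.of("sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal width (cm)");

// one FeatureVector per sample; results come back at the same index as the sample
FeatureVector first = features.newSample().add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);
FeatureVector second = features.newSample().add(0, 6.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);

List<Integer> predictions = decisionTree.predict(first, second);      // one prediction per sample
double[][] probabilities = decisionTree.predict_proba(first, second); // one probability row per sample
```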