diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml
new file mode 100644
index 0000000..a55e7a1
--- /dev/null
+++ b/.idea/codeStyles/codeStyleConfig.xml
@@ -0,0 +1,5 @@
+<component name="ProjectCodeStyleConfiguration">
+  <state>
+    <option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
+  </state>
+</component>
\ No newline at end of file
diff --git a/README.md b/README.md
index b70b7f1..4c6c3c2 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,8 @@
This project aims to use text-exported ML models generated by scikit-learn and make them usable in Java.
+[![javadoc](https://javadoc.io/badge2/rocks.vilaverde/scikit-learn-2-java/javadoc.svg)](https://javadoc.io/doc/rocks.vilaverde/scikit-learn-2-java)
+
## Support
* The tree.DecisionTreeClassifier is supported
* Supports `predict()`,
@@ -48,10 +50,10 @@ As an example, a DecisionTreeClassifier model trained on the Iris dataset and ex
```
The exported text can then be executed in Java. Note that when calling `export_text` it is
-recommended that `max_depth` be set to sys.maxsize so that the tree isn't truncated.
+recommended that `max_depth` be set to `sys.maxsize` so that the tree isn't truncated.
### Java Example
-In this example the iris model exported using `export_tree` is parsed, features are created as a Java Map
+In this example the iris model exported using `export_text` is parsed, features are created as a Java Map
and the decision tree is asked to predict the class.
```
@@ -59,13 +61,14 @@ and the decision tree is asked to predict the class.
final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,
PredictionFactory.INTEGER);
- Map<String, Double> features = new HashMap<>();
- features.put("sepal length (cm)", 3.0);
- features.put("sepal width (cm)", 5.0);
- features.put("petal length (cm)", 4.0);
- features.put("petal width (cm)", 2.0);
+ Features features = Features.of("sepal length (cm)",
+ "sepal width (cm)",
+ "petal length (cm)",
+ "petal width (cm)");
+ FeatureVector fv = features.newSample();
+ fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);
- Integer prediction = decisionTree.predict(features);
+ Integer prediction = decisionTree.predict(fv);
System.out.println(prediction.toString());
```
@@ -107,7 +110,7 @@ Then you can use the RandomForestClassifier class to parse the TAR archive.
...
TarArchiveInputStream tree = getArchive("iris.tgz");
- final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
+ final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
PredictionFactory.DOUBLE);
```
diff --git a/pom.xml b/pom.xml
index 0676c7c..abf95bd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
<groupId>rocks.vilaverde</groupId>
<artifactId>scikit-learn-2-java</artifactId>
- <version>1.0.3-SNAPSHOT</version>
+ <version>1.1.0-SNAPSHOT</version>
<name>${project.groupId}:${project.artifactId}</name>
<description>A sklearn exported_text models parser for executing in the Java runtime.</description>
diff --git a/src/main/java/rocks/vilaverde/classifier/AbstractTreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/AbstractTreeClassifier.java
new file mode 100644
index 0000000..a3a7e65
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/AbstractTreeClassifier.java
@@ -0,0 +1,50 @@
+package rocks.vilaverde.classifier;
+
+import rocks.vilaverde.classifier.dt.TreeClassifier;
+
+import java.util.Map;
+
+/**
+ * Abstract base class for Tree classifiers.
+ */
+public abstract class AbstractTreeClassifier<T> implements TreeClassifier<T> {
+
+ /**
+ * Predict class or regression value for features.
+ *
+ * @param samples input samples
+ * @return the predicted class of the input sample
+ */
+ @Override
+ public T predict(Map<String, Double> samples) {
+ FeatureVector fv = toFeatureVector(samples);
+ return predict(fv).get(0);
+ }
+
+ /**
+ * Predict class probabilities of the input samples features.
+ * The predicted class probability is the fraction of samples of the same class in a leaf.
+ *
+ * @param samples the input samples
+ * @return the class probabilities of the input sample
+ */
+ @Override
+ public double[] predict_proba(Map<String, Double> samples) {
+ FeatureVector fv = toFeatureVector(samples);
+ return predict_proba(fv)[0];
+ }
+
+ /**
+ * Convert a Map of features to a {@link FeatureVector}.
+ * @param samples a KV map of feature name to value
+ * @return FeatureVector
+ */
+ private FeatureVector toFeatureVector(Map<String, Double> samples) {
+ Features features = Features.fromSet(samples.keySet());
+ FeatureVector fv = features.newSample();
+ for (Map.Entry<String, Double> entry : samples.entrySet()) {
+ fv.add(entry.getKey(), entry.getValue());
+ }
+ return fv;
+ }
+}
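
A minimal sketch (not part of the patch) of what this bridge gives callers: the deprecated `Map` overloads should produce the same result as the new `FeatureVector` path, since both funnel through `toFeatureVector`. The `classifier` parameter is a placeholder for any parsed `Classifier<T>`.

```
import rocks.vilaverde.classifier.Classifier;
import rocks.vilaverde.classifier.Features;
import rocks.vilaverde.classifier.FeatureVector;

import java.util.HashMap;
import java.util.Map;

public class MapBridgeSketch {
    // Both calls should yield the same prediction: the deprecated Map overload
    // is converted to a FeatureVector internally via toFeatureVector().
    static <T> void compare(Classifier<T> classifier) {
        Map<String, Double> byName = new HashMap<>();
        byName.put("feature1", 1.2);

        T viaMap = classifier.predict(byName);               // deprecated path
        FeatureVector fv = Features.fromSet(byName.keySet())
                .newSample()
                .add("feature1", 1.2);
        T viaVector = classifier.predict(fv).get(0);         // new path

        assert viaMap.equals(viaVector);
    }
}
```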
diff --git a/src/main/java/rocks/vilaverde/classifier/BooleanFeature.java b/src/main/java/rocks/vilaverde/classifier/BooleanFeature.java
deleted file mode 100644
index 4b84dd4..0000000
--- a/src/main/java/rocks/vilaverde/classifier/BooleanFeature.java
+++ /dev/null
@@ -1,19 +0,0 @@
-package rocks.vilaverde.classifier;
-
-/**
- * Represents features that are Boolean as a Double `1.0` or `0.0`.
- */
-public enum BooleanFeature {
-
- FALSE(0.0),
- TRUE(1.0);
-
- private final Double value;
- BooleanFeature(double v) {
- value = v;
- }
-
- public Double asDouble() {
- return value;
- }
-}
diff --git a/src/main/java/rocks/vilaverde/classifier/Classifier.java b/src/main/java/rocks/vilaverde/classifier/Classifier.java
index 158869e..4855c30 100644
--- a/src/main/java/rocks/vilaverde/classifier/Classifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/Classifier.java
@@ -1,24 +1,44 @@
package rocks.vilaverde.classifier;
+import java.util.List;
import java.util.Map;
import java.util.Set;
public interface Classifier<T> {
+ /**
+ * Predict class or regression value for samples. Predictions are
+ * returned at the same index as the corresponding sample.
+ * @param samples input samples
+ * @return the predicted class for each input sample
+ */
+ List<T> predict(FeatureVector ... samples);
+
+ /**
+ * Predict class probabilities of the input samples. Probabilities are
+ * returned at the same index as the corresponding sample.
+ * The predicted class probability is the fraction of samples of the same class in a leaf.
+ * @param samples the input samples
+ * @return the class probabilities of the input sample
+ */
+ double[][] predict_proba(FeatureVector ... samples);
+
/**
* Predict class or regression value for features.
- * @param features input samples
+ * @param samples input samples
* @return the predicted class of the input sample
*/
- T predict(Map<String, Double> features);
+ @Deprecated
+ T predict(Map<String, Double> samples);
/**
* Predict class probabilities of the input samples features.
* The predicted class probability is the fraction of samples of the same class in a leaf.
- * @param features the input samples
+ * @param samples the input samples
* @return the class probabilities of the input sample
*/
- double[] predict_proba(Map<String, Double> features);
+ @Deprecated
+ double[] predict_proba(Map<String, Double> samples);
/**
* Get the names of all the features in the model.
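
A short usage sketch of the new batch API (assumptions: the `export_text` layout of the hand-written tree string below, and that `DecisionTreeClassifier.parse` accepts a `Reader`, as its javadoc states). One prediction is returned per sample, and `predict_proba` returns a `[samples][classes]` matrix:

```
import rocks.vilaverde.classifier.Classifier;
import rocks.vilaverde.classifier.Features;
import rocks.vilaverde.classifier.FeatureVector;
import rocks.vilaverde.classifier.dt.DecisionTreeClassifier;
import rocks.vilaverde.classifier.dt.PredictionFactory;

import java.io.StringReader;
import java.util.List;

public class BatchPredictSketch {
    public static void main(String[] args) throws Exception {
        // A tiny hand-written tree in sklearn export_text format (format assumed).
        String exported = String.join("\n",
                "|--- feature1 <= 2.00",
                "|   |--- class: false",
                "|--- feature1 >  2.00",
                "|   |--- class: true");

        Classifier<Boolean> clf = DecisionTreeClassifier.parse(
                new StringReader(exported), PredictionFactory.BOOLEAN);

        Features features = Features.of("feature1");
        FeatureVector low = features.newSample().add(0, 1.2);
        FeatureVector high = features.newSample().add(0, 2.4);

        List<Boolean> out = clf.predict(low, high);      // [false, true]
        double[][] proba = clf.predict_proba(low, high); // proba[i] pairs with out.get(i)
        System.out.println(out + ", samples=" + proba.length);
    }
}
```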
diff --git a/src/main/java/rocks/vilaverde/classifier/FeatureVector.java b/src/main/java/rocks/vilaverde/classifier/FeatureVector.java
new file mode 100644
index 0000000..7cb839e
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/FeatureVector.java
@@ -0,0 +1,53 @@
+package rocks.vilaverde.classifier;
+
+/**
+ * A container for the values for each feature of a sample that will be predicted.
+ */
+public class FeatureVector {
+
+ private final Features features;
+ private final double[] vector;
+
+ public FeatureVector(Features features) {
+ this.features = features;
+ this.vector = new double[features.getLength()];
+ }
+
+ public FeatureVector add(String feature, boolean value) {
+ add(feature, value ? 1.0 : 0.0);
+ return this;
+ }
+
+ public FeatureVector add(int index, boolean value) {
+ add(index, value ? 1.0 : 0.0);
+ return this;
+ }
+
+ public FeatureVector add(int index, double value) {
+ this.vector[index] = value;
+ return this;
+ }
+
+ public FeatureVector add(String feature, double value) {
+ int index = this.features.getFeatureIndex(feature);
+ add(index, value);
+ return this;
+ }
+
+ public double get(int index) {
+ if (index >= vector.length) {
+ throw new IllegalArgumentException(String.format("index must be less than %d", index));
+ }
+
+ return vector[index];
+ }
+
+ public double get(String feature) {
+ int index = features.getFeatureIndex(feature);
+ return get(index);
+ }
+
+ public boolean hasFeature(String feature) {
+ return this.features.getFeatureNames().contains(feature);
+ }
+}
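
The boolean `add` overloads replace the `BooleanFeature` enum deleted earlier in this patch; a sketch of the encoding (true is stored as `1.0`, false as `0.0`), using the API as it appears above:

```
import rocks.vilaverde.classifier.Features;
import rocks.vilaverde.classifier.FeatureVector;

public class BooleanEncodingSketch {
    public static void main(String[] args) {
        FeatureVector fv = Features.of("flag", "value").newSample()
                .add("flag", true)    // stored as 1.0, as BooleanFeature.TRUE was
                .add(1, 0.25);        // features are addressable by index or by name
        System.out.println(fv.get("flag") + " " + fv.get(1)); // 1.0 0.25
    }
}
```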
diff --git a/src/main/java/rocks/vilaverde/classifier/Features.java b/src/main/java/rocks/vilaverde/classifier/Features.java
new file mode 100644
index 0000000..e51a484
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/Features.java
@@ -0,0 +1,91 @@
+package rocks.vilaverde.classifier;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.Set;
+
+/**
+ * Container class for the set of named features that will be provided for each sample.
+ * Call {@link Features#newSample()} to create a {@link FeatureVector} to provide the
+ * values for the sample.
+ */
+public class Features {
+
+ /* Map of feature name to index */
+ private final Map<String, Integer> features = new HashMap<>();
+ /* Features can be added until the first sample is created;
+ after that the feature set is immutable */
+ private boolean allowFeatureAdd = true;
+
+ /**
+ * Convenience factory method.
+ * @param features the set of features.
+ * @return Features
+ */
+ public static Features of(String ... features) {
+
+ // make sure all the features are unique
+ Set<String> featureSet = new HashSet<>(Arrays.asList(features));
+ if (featureSet.size() != features.length) {
+ throw new IllegalArgumentException("features names are not unique");
+ }
+
+ return new Features(features);
+ }
+
+ public static Features fromSet(Set<String> features) {
+ return new Features(features.toArray(new String[0]));
+ }
+
+ /**
+ * Constructor
+ * @param features ordered list of feature names.
+ */
+ private Features(String ... features) {
+ for (int i = 0; i < features.length; i++) {
+ this.features.put(features[i], i);
+ }
+ }
+
+ public FeatureVector newSample() {
+ allowFeatureAdd = false;
+ return new FeatureVector(this);
+ }
+
+ public void addFeature(String feature) {
+ if (!allowFeatureAdd) {
+ throw new IllegalStateException("features are immutable");
+ }
+
+ if (!this.features.containsKey(feature)) {
+ int next = 0;
+ OptionalInt optionalInt = features.values().stream().mapToInt(integer -> integer).max();
+ if (optionalInt.isPresent()) {
+ next = optionalInt.getAsInt() + 1;
+ }
+
+ this.features.put(feature, next);
+ }
+ }
+
+ public int getLength() {
+ return this.features.size();
+ }
+
+ public int getFeatureIndex(String feature) {
+ Integer index = this.features.get(feature);
+ if (index == null) {
+ throw new IllegalArgumentException(String.format("feature %s does not exist", feature));
+ }
+
+ return index;
+ }
+
+ public Set<String> getFeatureNames() {
+ return Collections.unmodifiableSet(this.features.keySet());
+ }
+}
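
A sketch of the freeze-on-first-sample behavior seen above: `addFeature` works until `newSample()` is called, after which the set is frozen (the exception message matches the string in the diff):

```
import rocks.vilaverde.classifier.Features;

public class FeaturesLifecycleSketch {
    public static void main(String[] args) {
        Features features = Features.of("sepal length (cm)");
        features.addFeature("sepal width (cm)");  // fine: no sample created yet

        features.newSample();                     // freezes the feature set
        try {
            features.addFeature("petal length (cm)");
        } catch (IllegalStateException expected) {
            System.out.println(expected.getMessage()); // "features are immutable"
        }
    }
}
```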
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java b/src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java
index 5e3bf22..3cd494e 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java
@@ -6,7 +6,7 @@
* Represents a Choice in the decision tree, where when the expression is evaluated,
* if true will result in the child node of the choice being selected.
*/
-class ChoiceNode extends TreeNode {
+public class ChoiceNode extends TreeNode {
private final Operator op;
private final Double value;
@@ -38,7 +38,7 @@ public String toString() {
return String.format("%s %s", op.toString(), value.toString());
}
- public boolean eval(Double featureValue) {
+ public boolean eval(double featureValue) {
return op.apply(featureValue, value);
}
}
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java b/src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java
index 3952d51..4746545 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java
@@ -1,14 +1,11 @@
package rocks.vilaverde.classifier.dt;
-import java.util.ArrayList;
-import java.util.List;
-
/**
* Represents a decision in the DecisionTreeClassifier. The decision will have
* a left and right hand {@link ChoiceNode} to be evaluated.
* A {@link ChoiceNode} may have nested {@link DecisionNode} or {@link EndNode}.
*/
-class DecisionNode extends TreeNode {
+public class DecisionNode extends TreeNode {
private final String featureName;
@@ -26,7 +23,7 @@ public static DecisionNode create(String feature) {
/**
* Private Constructor.
- * @param featureName
+ * @param featureName the name of the feature used in this decision
*/
private DecisionNode(String featureName) {
this.featureName = featureName.intern();
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
index 66c5eed..a9c14cf 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
@@ -1,16 +1,27 @@
package rocks.vilaverde.classifier.dt;
+import rocks.vilaverde.classifier.AbstractTreeClassifier;
+import rocks.vilaverde.classifier.FeatureVector;
import rocks.vilaverde.classifier.Operator;
import rocks.vilaverde.classifier.Prediction;
+import rocks.vilaverde.classifier.dt.visitors.FeatureNameVisitor;
+import rocks.vilaverde.classifier.dt.visitors.PredictVisitor;
-import java.awt.*;
import java.io.BufferedReader;
import java.io.Reader;
-import java.util.Map;
+import java.util.Arrays;
+import java.util.List;
import java.util.Set;
import java.util.Stack;
+import java.util.stream.Collectors;
-public class DecisionTreeClassifier<T> implements TreeClassifier<T> {
+/**
+ * Represents a DecisionTreeClassifier trained in scikit-learn and exported using
+ * export_text.
+ * @param <T> the Prediction Class
+ */
+public class DecisionTreeClassifier<T> extends AbstractTreeClassifier<T>
+ implements TreeClassifier<T> {
/**
* Factory method to create the classifier from the {@link Reader}.
@@ -49,64 +60,51 @@ private DecisionTreeClassifier(PredictionFactory<T> predictionFactory) {
/**
* Predict class or regression value for features.
* For the classification model, the predicted class for the features of the sample is returned.
- * @param features Map of feature name to value
+ * @param samples one {@link FeatureVector} per sample to predict
* @return predicted class
*/
- public T predict(Map<String, Double> features) {
- return getClassification(features).get();
+ public List<T> predict(FeatureVector ... samples) {
+ return Arrays.stream(samples)
+ .map(featureVector -> getClassification(featureVector).get())
+ .collect(Collectors.toList());
}
/**
* Predict class probabilities of the input features.
* The predicted class probability is the fraction of samples of the same class in a leaf.
- * @param features Map of feature name to value
+ * @param samples one {@link FeatureVector} per sample
* @return the class probabilities of the input sample
*/
@Override
- public double[] predict_proba(Map<String, Double> features) {
- return getClassification(features).getProbability();
+ public double[][] predict_proba(FeatureVector ... samples) {
+ double[][] probabilities = null;
+
+ for (int i = 0; i < samples.length; i++) {
+ double[] result = getClassification(samples[i]).getProbability();
+ if (probabilities == null) {
+ probabilities = new double[samples.length][result.length];
+ }
+
+ probabilities[i] = result;
+ }
+
+ return probabilities;
}
/**
* Find the {@link Prediction} in the decision tree.
*/
- public Prediction<T> getClassification(Map<String, Double> features) {
- validateFeature(features);
-
- TreeNode currentNode = root;
-
- // traverse through the tree until the end node is reached.
- while (!(currentNode instanceof EndNode)) {
-
- if (currentNode != null) {
- DecisionNode decisionNode = ((DecisionNode) currentNode);
-
- Double featureValue = features.get(decisionNode.getFeatureName());
- if (featureValue == null) {
- featureValue = Double.NaN;
- }
-
- if (decisionNode.getLeft().eval(featureValue)) {
- currentNode = decisionNode.getLeft().getChild();
- } else if (decisionNode.getRight().eval(featureValue)) {
- currentNode = decisionNode.getRight().getChild();
- } else {
- // if I get here something is wrong since none of the branches evaluated to true
- throw new RuntimeException(String.format("no branches evaluated to true for feature '%s'",
- decisionNode.getFeatureName()));
- }
- }
- }
-
- return (Prediction<T>) currentNode;
+ public Prediction<T> getClassification(FeatureVector sample) {
+ validateFeatures(sample);
+ return PredictVisitor.predict(sample, root);
}
/**
* Validate the features provided are expected.
*/
- private void validateFeature(Map<String, Double> features) throws IllegalArgumentException {
+ private void validateFeatures(FeatureVector sample) throws IllegalArgumentException {
for (String f : featureNames) {
- if (!features.containsKey(f)) {
+ if (!sample.hasFeature(f)) {
throw new IllegalArgumentException(String.format("expected feature named '%s' but none provided", f));
}
}
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/EndNode.java b/src/main/java/rocks/vilaverde/classifier/dt/EndNode.java
index 61ad6ae..d5d904c 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/EndNode.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/EndNode.java
@@ -8,13 +8,18 @@
* Represents the end of the tree, where no further decisions can be made. The end node contains
* the prediction.
*/
-class EndNode<T> extends TreeNode implements Prediction<T> {
+public class EndNode<T> extends TreeNode implements Prediction<T> {
private static final MessageFormat CLASS_FORMAT = new MessageFormat("class: {0}");
private final T prediction;
/**
* Factory method to create the appropriate {@link EndNode} from the
* String in exported tree model.
+ * @param endNodeString the serialized EndNode text from the sklearn export_text
+ * @param predictionFactory the {@link PredictionFactory} used to deserialize the value after "class:"
+ * @return the {@link EndNode}
+ * @throws Exception may be thrown when the end node string can't be parsed.
+ * @param <T> the java type of the classification once deserialized by the predictionFactory
*/
public static <T> EndNode<T> create(String endNodeString,
PredictionFactory<T> predictionFactory) throws Exception {
@@ -56,7 +61,7 @@ public String toString() {
* {@link EndNode} that supports calculating the probability from the
* weights in the exported tree model.
*/
- static class WeightedEndNode<T> extends EndNode<T> {
+ public static class WeightedEndNode<T> extends EndNode<T> {
private static final MessageFormat WEIGHTS_FORMAT = new MessageFormat("weights: {0} class: {1}");
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java b/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
index 8af8190..10c8b2e 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
@@ -1,7 +1,9 @@
package rocks.vilaverde.classifier.dt;
/**
- * A prediction from the classifier
+ * A prediction from the classifier can be of type Double, Integer, or Boolean. This
+ * factory is used while parsing the text export of the DecisionTreeClassifier to construct
+ * the correct Java type for the serialized value.
*/
@FunctionalInterface
public interface PredictionFactory<T> {
@@ -10,5 +12,10 @@ public interface PredictionFactory<T> {
PredictionFactory<Integer> INTEGER = Integer::valueOf;
PredictionFactory<Double> DOUBLE = Double::parseDouble;
+ /**
+ * Convert a String value to the appropriate type for the model.
+ * @param value the serialized text value of the prediction from the exported decision tree.
+ * @return the deserialized type
+ */
T create(String value);
}
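
Because `PredictionFactory<T>` is a `@FunctionalInterface` with a single `create(String)` method, a custom prediction type is one lambda away; a sketch with a hypothetical label enum (the enum and its ordinal mapping are assumptions, not part of the library):

```
import rocks.vilaverde.classifier.dt.PredictionFactory;

public class CustomFactorySketch {
    // Hypothetical class labels for an iris-style model.
    enum Iris { SETOSA, VERSICOLOR, VIRGINICA }

    // Maps the serialized "class: 0|1|2" value to an enum constant.
    static final PredictionFactory<Iris> IRIS =
            value -> Iris.values()[Integer.parseInt(value)];

    public static void main(String[] args) {
        System.out.println(IRIS.create("1")); // VERSICOLOR
    }
}
```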
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
index a0d59e9..76f6f34 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
@@ -1,14 +1,13 @@
package rocks.vilaverde.classifier.dt;
import rocks.vilaverde.classifier.Classifier;
+import rocks.vilaverde.classifier.FeatureVector;
import rocks.vilaverde.classifier.Prediction;
-import java.util.Map;
-
/**
* Implemented by Tree classifiers.
*/
public interface TreeClassifier<T> extends Classifier<T> {
- Prediction<T> getClassification(Map<String, Double> features);
+ Prediction<T> getClassification(FeatureVector features);
}
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/TreeNode.java b/src/main/java/rocks/vilaverde/classifier/dt/TreeNode.java
index 0355b40..46f5e53 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/TreeNode.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/TreeNode.java
@@ -4,5 +4,9 @@
public abstract class TreeNode implements Visitable {
+ /**
+ * Every TreeNode implementation must override this to dispatch visit() on the visitor.
+ * @param visitor the visitor
+ */
public abstract void accept(AbstractDecisionTreeVisitor visitor);
}
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/FeatureNameVisitor.java b/src/main/java/rocks/vilaverde/classifier/dt/visitors/FeatureNameVisitor.java
similarity index 62%
rename from src/main/java/rocks/vilaverde/classifier/dt/FeatureNameVisitor.java
rename to src/main/java/rocks/vilaverde/classifier/dt/visitors/FeatureNameVisitor.java
index 71c67c8..332644f 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/FeatureNameVisitor.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/visitors/FeatureNameVisitor.java
@@ -1,4 +1,7 @@
-package rocks.vilaverde.classifier.dt;
+package rocks.vilaverde.classifier.dt.visitors;
+
+import rocks.vilaverde.classifier.dt.AbstractDecisionTreeVisitor;
+import rocks.vilaverde.classifier.dt.DecisionNode;
import java.util.HashSet;
import java.util.Set;
@@ -11,12 +14,13 @@ public class FeatureNameVisitor extends AbstractDecisionTreeVisitor {
private final Set<String> featureNames = new HashSet<>();
-
+ /**
+ * Visit a {@link DecisionNode} and collect the feature name used in the decision.
+ * @param object the {@link DecisionNode} being visited.
+ */
@Override
public void visit(DecisionNode object) {
-
featureNames.add(object.getFeatureName());
-
super.visit(object);
}
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/visitors/PredictVisitor.java b/src/main/java/rocks/vilaverde/classifier/dt/visitors/PredictVisitor.java
new file mode 100644
index 0000000..f750cb7
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/dt/visitors/PredictVisitor.java
@@ -0,0 +1,84 @@
+package rocks.vilaverde.classifier.dt.visitors;
+
+import rocks.vilaverde.classifier.FeatureVector;
+import rocks.vilaverde.classifier.Prediction;
+import rocks.vilaverde.classifier.dt.AbstractDecisionTreeVisitor;
+import rocks.vilaverde.classifier.dt.ChoiceNode;
+import rocks.vilaverde.classifier.dt.DecisionNode;
+import rocks.vilaverde.classifier.dt.EndNode;
+import rocks.vilaverde.classifier.dt.TreeNode;
+
+/**
+ * Visits the nodes of the {@link rocks.vilaverde.classifier.dt.TreeClassifier} looking for the
+ * {@link rocks.vilaverde.classifier.dt.EndNode} for the given {@link rocks.vilaverde.classifier.FeatureVector}.
+ */
+public class PredictVisitor<T> extends AbstractDecisionTreeVisitor {
+
+ private final FeatureVector sample;
+ private Prediction<T> prediction;
+
+ /**
+ * Convenience method to search a tree for a prediction.
+ * @param sample the sample {@link FeatureVector}
+ * @param root the root {@link TreeNode} of the Decision Tree
+ * @return the {@link Prediction}
+ * @param <T> the classification java type
+ */
+ public static <T> Prediction<T> predict(FeatureVector sample, TreeNode root) {
+ PredictVisitor<T> visitor = new PredictVisitor<>(sample);
+ root.accept(visitor);
+
+ if (visitor.getPrediction() == null) {
+ throw new RuntimeException("expected a prediction result from the tree, but none found");
+ }
+
+ return visitor.getPrediction();
+ }
+
+ /**
+ * Constructor.
+ * @param sample the {@link FeatureVector}
+ */
+ private PredictVisitor(FeatureVector sample) {
+ this.sample = sample;
+ }
+
+ /**
+ * When visiting a {@link DecisionNode} we need to test the left and right
+ * {@link ChoiceNode} and visit only the one that evaluates to true.
+ * @param object the {@link DecisionNode} being visited
+ */
+ @Override
+ public void visit(DecisionNode object) {
+
+ // don't call super otherwise both choice nodes are visited.
+
+ double featureValue = this.sample.get(object.getFeatureName());
+ if (object.getLeft().eval(featureValue)) {
+ object.getLeft().getChild().accept(this);
+ } else if (object.getRight().eval(featureValue)) {
+ object.getRight().getChild().accept(this);
+ } else {
+ throw new RuntimeException(String.format("no branches evaluated to true for feature '%s'",
+ object.getFeatureName()));
+ }
+ }
+
+ /**
+ * When visiting an {@link EndNode} we've found the prediction
+ * and no longer need to visit the tree.
+ * @param object the {@link EndNode} being visited
+ */
+ @Override
+ public void visit(EndNode object) {
+ this.prediction = object;
+ }
+
+ /**
+ * Get the prediction by searching the decision tree.
+ * @return the {@link Prediction}
+ */
+ public Prediction<T> getPrediction() {
+ return prediction;
+ }
+}
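
The traversal above is plain double dispatch; a self-contained illustration with hypothetical node types (names here are illustrative, not the library's API):

```
public class VisitorSketch {
    interface Visitor { void visit(Branch b); void visit(Leaf l); }
    abstract static class Node { abstract void accept(Visitor v); }

    static class Branch extends Node {
        final double threshold; final Node left, right;
        Branch(double t, Node l, Node r) { threshold = t; left = l; right = r; }
        void accept(Visitor v) { v.visit(this); }
    }

    static class Leaf extends Node {
        final String label;
        Leaf(String label) { this.label = label; }
        void accept(Visitor v) { v.visit(this); }
    }

    // Like PredictVisitor: descends into exactly one child per decision,
    // then records the leaf it lands on.
    static class FindLeaf implements Visitor {
        final double x; String result;
        FindLeaf(double x) { this.x = x; }
        public void visit(Branch b) { (x <= b.threshold ? b.left : b.right).accept(this); }
        public void visit(Leaf l) { result = l.label; }
    }

    public static void main(String[] args) {
        Node root = new Branch(2.0, new Leaf("false"), new Leaf("true"));
        FindLeaf visitor = new FindLeaf(2.4);
        root.accept(visitor);
        System.out.println(visitor.result); // true
    }
}
```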
diff --git a/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java b/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
index 36f0e39..6e2380c 100644
--- a/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
@@ -1,10 +1,12 @@
package rocks.vilaverde.classifier.ensemble;
-import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
-import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.archivers.ArchiveEntry;
+import org.apache.commons.compress.archivers.ArchiveInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import rocks.vilaverde.classifier.AbstractTreeClassifier;
import rocks.vilaverde.classifier.Classifier;
+import rocks.vilaverde.classifier.FeatureVector;
import rocks.vilaverde.classifier.Prediction;
import rocks.vilaverde.classifier.dt.DecisionTreeClassifier;
import rocks.vilaverde.classifier.dt.PredictionFactory;
@@ -15,26 +17,34 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
-import java.util.*;
-import java.util.concurrent.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import java.util.stream.Collectors;
/**
* A forest of DecisionTreeClassifiers.
*/
-public class RandomForestClassifier<T> implements Classifier<T> {
+public class RandomForestClassifier<T> extends AbstractTreeClassifier<T>
+ implements Classifier<T> {
private static final Logger LOG = LoggerFactory.getLogger(RandomForestClassifier.class);
/**
* Accept a TAR of exported DecisionTreeClassifiers from sklearn and produce a
* RandomForestClassifier. This defaults to running in a single (current) thread.
- * @param tar the Tar Archive input stream
+ * @param tar the {@link ArchiveInputStream}
* @param factory the factory for creating the prediction class
* @return the {@link Classifier}
* @param <T> the classifier type
* @throws Exception when the model could not be parsed
*/
- public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
+ public static <T> Classifier<T> parse(final ArchiveInputStream tar,
PredictionFactory<T> factory) throws Exception {
return RandomForestClassifier.parse(tar, factory, null);
}
@@ -49,14 +59,14 @@ public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
* @param <T> the classifier type
* @throws Exception when the model could not be parsed
*/
- public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
+ public static <T> Classifier<T> parse(final ArchiveInputStream tar,
PredictionFactory<T> factory,
ExecutorService executor) throws Exception {
List<TreeClassifier<T>> forest = new ArrayList<>();
try (tar) {
- TarArchiveEntry exportedTree;
- while ((exportedTree = tar.getNextTarEntry()) != null) {
+ ArchiveEntry exportedTree;
+ while ((exportedTree = tar.getNextEntry()) != null) {
if (!exportedTree.isDirectory()) {
LOG.debug("Parsing tree {}", exportedTree.getName());
final InputStream noCloseStream = new InputStream() {
@@ -95,75 +105,69 @@ private RandomForestClassifier(List<TreeClassifier<T>> forest, ExecutorService e
/**
* Predict class or regression value for features.
- * @param features features of the sample
+ * @param samples one {@link FeatureVector} per sample
* @return the predicted class for each input sample
*/
@Override
- public T predict(Map<String, Double> features) {
- List<Prediction<T>> predictions = getPredictions(features);
- Map<T, Long> map = predictions.stream()
- .collect(Collectors.groupingBy(Prediction::get, Collectors.counting()));
-
- long max = map.values().stream().mapToLong(Long::longValue).max().orElse(0);
- for (Map.Entry<T, Long> entry : map.entrySet()) {
- if (entry.getValue() == max) {
- return entry.getKey();
- }
- }
-
- throw new IllegalStateException("no classification");
+ public List<T> predict(FeatureVector ... samples) {
+ return Arrays.stream(samples)
+ .map(this::predictSingle)
+ .collect(Collectors.toList());
}
/**
* Predict class probabilities of the input samples features.
* The predicted class probability is the fraction of samples of the same class in a leaf.
- * @param features the input samples
+ * @param samples the input samples
* @return the class probabilities of the input sample
*/
@Override
- public double[] predict_proba(Map<String, Double> features) {
- if (forest.size() == 1) {
- return forest.get(0).getClassification(features).getProbability();
- }
-
- List<Prediction<T>> predictions = getPredictions(features);
+ public double[][] predict_proba(FeatureVector ... samples) {
+ double[][] probabilities = null;
- double[] result = null;
- for (Prediction<T> prediction : predictions) {
- double[] prob = prediction.getProbability();
+ for (int i = 0; i < samples.length; i++) {
+ FeatureVector fv = samples[i];
+ Prediction<T> prediction = getClassification(fv);
- if (result == null) {
- result = prob;
- } else {
- for (int i = 0; i < prob.length; i++) {
- result[i] += prob[i];
- }
+ double[] sampleProb = prediction.getProbability();
+ if (probabilities == null) {
+ probabilities = new double[samples.length][sampleProb.length];
}
- }
- if (result != null) {
- int forestSize = forest.size();
- for (int i = 0; i < result.length; i++) {
- result[i] /= forestSize;
- }
+ probabilities[i] = sampleProb;
}
- return result;
+ return probabilities;
+ }
+
+ protected T predictSingle(FeatureVector sample) {
+ return getClassification(sample).get();
+ }
+
+ /**
+ * Return a prediction from the forest for the sample.
+ * @param sample a sample's feature vector
+ * @return Prediction
+ */
+ @Override
+ public Prediction<T> getClassification(FeatureVector sample) {
+ final List<Prediction<T>> predictions = getPredictions(sample);
+ return new RandomForestPrediction<>(predictions, forest.size());
}
/**
* Get all the predictions for the features of the sample
- * @param features features of the sample
+ * @param sample features of the sample
* @return a List of {@link Prediction} objects from the trees in the forest.
*/
- protected List<Prediction<T>> getPredictions(final Map<String, Double> features) {
+ protected List<Prediction<T>> getPredictions(final FeatureVector sample) {
List<Prediction<T>> predictions;
if (executorService != null) {
int jobs = Runtime.getRuntime().availableProcessors();
List<ParallelPrediction<T>> parallel = new ArrayList<>(jobs);
for (int i = 0; i < jobs; i++) {
- ParallelPrediction<T> parallelPrediction = new ParallelPrediction<>(forest, features, i, jobs);
+ ParallelPrediction<T> parallelPrediction = new ParallelPrediction<>(forest, sample, i, jobs);
parallel.add(parallelPrediction);
}
@@ -179,7 +183,7 @@ protected List<Prediction<T>> getPredictions(final Map<String, Double> features)
} else {
predictions = new ArrayList<>(forest.size());
for (TreeClassifier<T> tree : forest) {
- Prediction<T> prediction = tree.getClassification(features);
+ Prediction<T> prediction = tree.getClassification(sample);
predictions.add(prediction);
}
}
@@ -201,6 +205,66 @@ public Set<String> getFeatureNames() {
}
+ static class RandomForestPrediction<T> implements Prediction<T> {
+ private final List<Prediction<T>> predictions;
+ private final int forestSize;
+
+ /**
+ * Constructor
+ * @param predictions the list of predictions that need to be merged
+ * @param forestSize the number of trees in the forest
+ */
+ public RandomForestPrediction(List<Prediction<T>> predictions, int forestSize) {
+ this.predictions = predictions;
+ this.forestSize = forestSize;
+ }
+
+ /**
+ * @return the majority-vote class across the forest.
+ */
+ @Override
+ public T get() {
+ Map<T, Long> map = predictions.stream()
+ .collect(Collectors.groupingBy(Prediction::get, Collectors.counting()));
+
+ long max = map.values().stream().mapToLong(Long::longValue).max().orElse(0);
+ T prediction = null;
+ for (Map.Entry<T, Long> entry : map.entrySet()) {
+ if (entry.getValue() == max) {
+ prediction = entry.getKey();
+ break;
+ }
+ }
+
+ if (prediction == null) {
+ throw new IllegalStateException("no classification");
+ }
+
+ return prediction;
+ }
+
+ /**
+ * @return the class probabilities averaged across the forest
+ */
+ @Override
+ public double[] getProbability() {
+ int arraySize = predictions.get(0).getProbability().length;
+ final double[] result = new double[arraySize];
+ for (Prediction<T> prediction : predictions) {
+ double[] prob = prediction.getProbability();
+ for (int j = 0; j < prob.length; j++) {
+ result[j] += prob[j];
+ }
+ }
+
+ for (int j = 0; j < result.length; j++) {
+ result[j] /= forestSize;
+ }
+ return result;
+ }
+ }
+
+
/**
* A job that will only provide {@link Prediction} results
* from a subset of the trees in the forest.
@@ -210,23 +274,23 @@ private static class ParallelPrediction<T> implements Callable<List<Prediction<T>>> {
private final int start;
private final int offset;
private final List<TreeClassifier<T>> forest;
- private final Map<String, Double> features;
+ private final FeatureVector sample;
/**
* Constructor
* @param forest the random forest
- * @param features the features of a sample
+ * @param sample the features of a sample
* @param start the index of the tree in the forest this job will begin with
* @param offset the offset between the current tree and the next tree to call predict on
*/
private ParallelPrediction(List<TreeClassifier<T>> forest,
- Map<String, Double> features,
+ FeatureVector sample,
int start,
int offset) {
this.offset = offset;
this.start = start;
this.forest = forest;
- this.features = features;
+ this.sample = sample;
}
/**
@@ -239,7 +303,7 @@ public List<Prediction<T>> call() throws Exception {
List<Prediction<T>> predictions = new ArrayList<>();
for (int i = start; i < forest.size(); i+=offset) {
- predictions.add(forest.get(i).getClassification(features));
+ predictions.add(forest.get(i).getClassification(sample));
}
return predictions;
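
A worked sketch (made-up numbers) of what `RandomForestPrediction` does per sample: majority vote for `get()`, and an element-wise sum divided by forest size for `getProbability()`:

```
import java.util.HashMap;
import java.util.Map;

public class ForestAggregationSketch {
    public static void main(String[] args) {
        // Made-up output of a three-tree forest for one sample.
        int[] votes = {1, 1, 2};
        double[][] perTree = { {0.9, 0.1}, {0.6, 0.4}, {0.2, 0.8} };

        // get(): majority vote, as in RandomForestPrediction.get().
        Map<Integer, Long> counts = new HashMap<>();
        for (int v : votes) counts.merge(v, 1L, Long::sum);
        int winner = counts.entrySet().stream()
                .max(Map.Entry.comparingByValue()).get().getKey();

        // getProbability(): element-wise sum divided by forest size.
        double[] avg = new double[perTree[0].length];
        for (double[] p : perTree)
            for (int j = 0; j < p.length; j++) avg[j] += p[j];
        for (int j = 0; j < avg.length; j++) avg[j] /= perTree.length;

        System.out.printf("winner=%d avg=[%.4f, %.4f]%n", winner, avg[0], avg[1]);
        // winner=1 avg=[0.5667, 0.4333]
    }
}
```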
diff --git a/src/test/java/rocks/vilaverde/classifier/DecisionTreeClassifierTest.java b/src/test/java/rocks/vilaverde/classifier/DecisionTreeClassifierTest.java
index d782f33..e6f8c46 100644
--- a/src/test/java/rocks/vilaverde/classifier/DecisionTreeClassifierTest.java
+++ b/src/test/java/rocks/vilaverde/classifier/DecisionTreeClassifierTest.java
@@ -8,9 +8,6 @@
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
public class DecisionTreeClassifierTest {
@@ -21,8 +18,12 @@ public void parseSimpleTree() throws Exception {
final Classifier<Boolean> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN);
Assertions.assertNotNull(decisionTree);
- Assertions.assertFalse(decisionTree.predict(Collections.singletonMap("feature1", 1.2)));
- Assertions.assertTrue(decisionTree.predict(Collections.singletonMap("feature1", 2.4)));
+ Features features = Features.of("feature1");
+ FeatureVector sample1 = features.newSample().add(0, 1.2);
+ Assertions.assertFalse(decisionTree.predict(sample1).get(0));
+
+ FeatureVector sample2 = features.newSample().add(0, 2.4);
+ Assertions.assertTrue(decisionTree.predict(sample2).get(0));
Assertions.assertNotNull(decisionTree.getFeatureNames());
Assertions.assertEquals(1, decisionTree.getFeatureNames().size());
@@ -35,20 +36,22 @@ public void parseDeepTree() throws Exception {
final Classifier<Boolean> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN);
Assertions.assertNotNull(decisionTree);
- Map<String, Double> features = new HashMap<>();
- features.put("feature1", 0.0);
- features.put("feature2", 1.0);
- features.put("feature3", BooleanFeature.FALSE.asDouble());
- features.put("feature4", 0.0);
- features.put("feature5", BooleanFeature.FALSE.asDouble());
- features.put("feature6", 1.0);
- features.put("feature7", 1.0);
- features.put("feature8", 0.0);
+ Features features = Features.of("feature1", "feature2", "feature3", "feature4",
+ "feature5", "feature6", "feature7", "feature8");
+ FeatureVector fv = features.newSample();
+ fv.add("feature1", 0.0)
+ .add("feature2", 1.0)
+ .add("feature3", false)
+ .add("feature4", 0.0)
+ .add("feature5", false)
+ .add("feature6", 1.0)
+ .add("feature7", 1.0)
+ .add("feature8", 0.0);
- Assertions.assertFalse(decisionTree.predict(features));
+ Assertions.assertFalse(decisionTree.predict(fv).get(0));
- features.put("feature5", BooleanFeature.TRUE.asDouble());
- Assertions.assertTrue(decisionTree.predict(features));
+ fv.add("feature5", true);
+ Assertions.assertTrue(decisionTree.predict(fv).get(0));
Assertions.assertNotNull(decisionTree.getFeatureNames());
Assertions.assertEquals(8, decisionTree.getFeatureNames().size());
@@ -61,17 +64,19 @@ public void invalidFeatureName() {
final Classifier<Boolean> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN);
Assertions.assertNotNull(decisionTree);
- Map<String, Double> features = new HashMap<>();
- features.put("feature11", 0.0);
- features.put("feature2", 1.0);
- features.put("feature3", BooleanFeature.FALSE.asDouble());
- features.put("feature4", 0.0);
- features.put("feature5", BooleanFeature.FALSE.asDouble());
- features.put("feature6", 1.0);
- features.put("feature7", 1.0);
- features.put("feature8", 0.0);
-
- decisionTree.predict(features);
+ Features features = Features.of("feature11", "feature2", "feature3", "feature4",
+ "feature5", "feature6", "feature7", "feature8");
+ FeatureVector fv = features.newSample();
+ fv.add(0, 0.0)
+ .add(1, 1.0)
+ .add(2, false)
+ .add(3, 0.0)
+ .add(4, false)
+ .add(5, 1.0)
+ .add(6, 1.0)
+ .add(7, 0.0);
+
+ decisionTree.predict(fv);
});
Assertions.assertEquals("expected feature named 'feature1' but none provided",
@@ -85,15 +90,17 @@ public void invalidFeatureCount() {
final Classifier<Boolean> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN);
Assertions.assertNotNull(decisionTree);
- Map<String, Double> features = new HashMap<>();
- features.put("feature1", 0.0);
- features.put("feature2", 1.0);
- features.put("feature3", BooleanFeature.FALSE.asDouble());
- features.put("feature4", 0.0);
- features.put("feature5", BooleanFeature.FALSE.asDouble());
- features.put("feature6", 1.0);
-
- decisionTree.predict(features);
+ Features features = Features.of("feature1", "feature2", "feature3", "feature4",
+ "feature5", "feature6");
+ FeatureVector fv = features.newSample();
+ fv.add(0, 0.0)
+ .add(1, 1.0)
+ .add(2, false)
+ .add(3, 0.0)
+ .add(4, false)
+ .add(5, 1.0);
+
+ decisionTree.predict(fv);
});
Assertions.assertEquals("expected feature named 'feature7' but none provided",
@@ -106,19 +113,21 @@ public void probability() throws Exception {
final Classifier<Boolean> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.BOOLEAN);
Assertions.assertNotNull(decisionTree);
- Map<String, Double> features = new HashMap<>();
- features.put("feature1", 1.2);
- features.put("feature2", 88.33);
- features.put("feature3", BooleanFeature.FALSE.asDouble());
- features.put("feature4", 1.727);
- features.put("feature5", BooleanFeature.FALSE.asDouble());
- features.put("feature6", 1.0);
- features.put("feature7", 0.0048);
- features.put("feature8", 0.0);
-
- Assertions.assertFalse(decisionTree.predict(features));
-
- double[] prediction = decisionTree.predict_proba(features);
+ Features features = Features.of("feature1", "feature2", "feature3", "feature4",
+ "feature5", "feature6", "feature7", "feature8");
+ FeatureVector fv = features.newSample();
+ fv.add("feature1", 1.2)
+ .add("feature2", 88.33)
+ .add("feature3", false)
+ .add("feature4", 1.727)
+ .add("feature5", false)
+ .add("feature6", 1.0)
+ .add("feature7", 0.0048)
+ .add("feature8", 0.0);
+
+ Assertions.assertFalse(decisionTree.predict(fv).get(0));
+
+ double[] prediction = decisionTree.predict_proba(fv)[0];
Assertions.assertNotNull(prediction);
Assertions.assertEquals(0.63636364, prediction[0], .00000001);
Assertions.assertEquals(0.36363636, prediction[1], .00000001);
@@ -130,19 +139,19 @@ public void noWeights() throws Exception {
final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.INTEGER);
Assertions.assertNotNull(decisionTree);
- Map<String, Double> features = new HashMap<>();
- features.put("sepal length (cm)", 3.0);
- features.put("sepal width (cm)", 5.0);
- features.put("petal length (cm)", 4.0);
- features.put("petal width (cm)", 2.0);
-
- Integer prediction = decisionTree.predict(features);
+ Features features = Features.of("sepal length (cm)",
+ "sepal width (cm)",
+ "petal length (cm)",
+ "petal width (cm)");
+ FeatureVector fv = features.newSample();
+ fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);
+ Integer prediction = decisionTree.predict(fv).get(0);
Assertions.assertNotNull(prediction);
Assertions.assertEquals(1, prediction.intValue());
- features.put("sepal length (cm)", 6.0);
- prediction = decisionTree.predict(features);
+ fv.add("sepal length (cm)", 6.0);
+ prediction = decisionTree.predict(fv).get(0);
Assertions.assertEquals(2, prediction.intValue());
}
diff --git a/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java b/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
index cffffcc..9ea2acf 100644
--- a/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
+++ b/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
@@ -11,8 +11,8 @@
import java.io.IOException;
import java.io.InputStream;
-import java.util.HashMap;
-import java.util.Map;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -40,37 +40,55 @@ public void randomForestParallel() throws Exception {
PredictionFactory.DOUBLE, executorService);
Assertions.assertNotNull(decisionTree);
- double[] proba = decisionTree.predict_proba(getSample1());
+ double[] proba = decisionTree.predict_proba(getSample1())[0];
assertSample(proba, .06, .62, .32);
- Double prediction = decisionTree.predict(getSample1());
+ Double prediction = decisionTree.predict(getSample1()).get(0);
Assertions.assertNotNull(prediction);
Assertions.assertEquals(1.0, prediction, .0);
- prediction = decisionTree.predict(getSample2());
+ prediction = decisionTree.predict(getSample2()).get(0);
Assertions.assertEquals(2, prediction.intValue());
- proba = decisionTree.predict_proba(getSample2());
+ proba = decisionTree.predict_proba(getSample2())[0];
assertSample(proba, 0.0, .44, .56);
}
+ @Test
+ public void randomForestParallel10000() throws Exception {
+ TarArchiveInputStream exported = getExportedModel("rf/iris.tgz");
+ final Classifier<Double> decisionTree = RandomForestClassifier.parse(exported,
+ PredictionFactory.DOUBLE, executorService);
+ Assertions.assertNotNull(decisionTree);
+
+ List<FeatureVector> vectorList = new ArrayList<>(10000);
+ for (int i = 0; i < 5000; i++) {
+ vectorList.add(getSample1());
+ vectorList.add(getSample2());
+ }
+
+ double[][] proba = decisionTree.predict_proba(vectorList.toArray(new FeatureVector[0]));
+ assertSample(proba[0], .06, .62, .32);
+ assertSample(proba[1], 0.0, .44, .56);
+ }
+
@Test
public void randomForest() throws Exception {
TarArchiveInputStream exported = getExportedModel("rf/iris.tgz");
final Classifier<Double> decisionTree = RandomForestClassifier.parse(exported, PredictionFactory.DOUBLE);
Assertions.assertNotNull(decisionTree);
- double[] proba = decisionTree.predict_proba(getSample1());
+ double[] proba = decisionTree.predict_proba(getSample1())[0];
assertSample(proba, .06, .62, .32);
- Double prediction = decisionTree.predict(getSample1());
+ Double prediction = decisionTree.predict(getSample1()).get(0);
Assertions.assertNotNull(prediction);
Assertions.assertEquals(1.0, prediction, .0);
- prediction = decisionTree.predict(getSample2());
+ prediction = decisionTree.predict(getSample2()).get(0);
Assertions.assertEquals(2, prediction.intValue());
- proba = decisionTree.predict_proba(getSample2());
+ proba = decisionTree.predict_proba(getSample2())[0];
assertSample(proba, 0.0, .44, .56);
}
@@ -82,10 +100,10 @@ public void invalidFeatureCount() {
Assertions.assertNotNull(decisionTree);
// create invalid number of features.
- Map<String, Double> features = getSample1();
- features.remove("sepal length (cm)");
-
- decisionTree.predict(features);
+ Features features = Features.of("sepal width (cm)", "petal length (cm)", "petal width (cm)");
+ FeatureVector fv = features.newSample();
+ fv.add(0, 3.0).add(1, 5.0).add(2, 4.0);
+ decisionTree.predict(fv);
});
Assertions.assertEquals("expected feature named 'sepal length (cm)' but none provided",
@@ -101,22 +119,28 @@ public void featureNames() throws Exception {
Assertions.assertEquals(4, decisionTree.getFeatureNames().size());
}
- private Map<String, Double> getSample1() {
- Map<String, Double> features = new HashMap<>();
- features.put("sepal length (cm)", 3.0);
- features.put("sepal width (cm)", 5.0);
- features.put("petal length (cm)", 4.0);
- features.put("petal width (cm)", 2.0);
- return features;
+ private FeatureVector getSample1() {
+ Features features = Features.of("sepal length (cm)",
+ "sepal width (cm)",
+ "petal length (cm)",
+ "petal width (cm)");
+ FeatureVector fv = features.newSample();
+ return fv.add(0, 3.0)
+ .add(1, 5.0)
+ .add(2, 4.0)
+ .add(3, 2.0);
}
- private Map<String, Double> getSample2() {
- Map<String, Double> features = new HashMap<>();
- features.put("sepal length (cm)", 1.0);
- features.put("sepal width (cm)", 2.0);
- features.put("petal length (cm)", 3.0);
- features.put("petal width (cm)", 4.0);
- return features;
+ private FeatureVector getSample2() {
+ Features features = Features.of("sepal length (cm)",
+ "sepal width (cm)",
+ "petal length (cm)",
+ "petal width (cm)");
+ FeatureVector fv = features.newSample();
+ return fv.add(0, 1.0)
+ .add(1, 2.0)
+ .add(2, 3.0)
+ .add(3, 4.0);
}
private void assertSample(double[] proba, double expected, double expected1, double expected2) {