diff --git a/.idea/copyright/profiles_settings.xml b/.idea/copyright/profiles_settings.xml
new file mode 100644
index 0000000..8295f31
--- /dev/null
+++ b/.idea/copyright/profiles_settings.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/libraries/Maven__org_apache_commons_commons_compress_1_21.xml b/.idea/libraries/Maven__org_apache_commons_commons_compress_1_21.xml
new file mode 100644
index 0000000..49cd123
--- /dev/null
+++ b/.idea/libraries/Maven__org_apache_commons_commons_compress_1_21.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/libraries/Maven__org_slf4j_slf4j_api_1_7_36.xml b/.idea/libraries/Maven__org_slf4j_slf4j_api_1_7_36.xml
new file mode 100644
index 0000000..2d759c1
--- /dev/null
+++ b/.idea/libraries/Maven__org_slf4j_slf4j_api_1_7_36.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/libraries/Maven__org_slf4j_slf4j_simple_1_7_25.xml b/.idea/libraries/Maven__org_slf4j_slf4j_simple_1_7_25.xml
new file mode 100644
index 0000000..8bc862b
--- /dev/null
+++ b/.idea/libraries/Maven__org_slf4j_slf4j_simple_1_7_25.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index 8306744..35eb1dd 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -2,6 +2,5 @@
-
\ No newline at end of file
diff --git a/README.md b/README.md
index c2ca859..2273253 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,29 @@
This project aims to take ML models exported as text from scikit-learn and make them usable in Java.
-For example, a DecisionTreeClassifier model trained on the Iris dataset and exported using `sklearn.tree`
-export_text() as shown below:
+## Support
+* `tree.DecisionTreeClassifier` is supported
+  * Supports `predict()`
+  * Supports `predict_proba()` when `export_text()` is configured with `show_weights=True`
+* `ensemble.RandomForestClassifier` is supported
+  * Supports `predict()`
+  * Supports `predict_proba()` when `export_text()` is configured with `show_weights=True`
+
+## Installing
+
+### Importing Maven Dependency
+```xml
+<dependency>
+    <groupId>rocks.vilaverde</groupId>
+    <artifactId>scikit-learn-2-java</artifactId>
+    <version>1.0.0</version>
+</dependency>
+```
+
+## DecisionTreeClassifier
+
+As an example, a DecisionTreeClassifier model trained on the Iris dataset can be exported using
+`sklearn.tree.export_text()` as shown below:
```
>>> from sklearn.datasets import load_iris
@@ -26,17 +47,8 @@ export_text() as shown below:
| | |--- class: 2
```
-can be executed in Java Maven. Note that when calling `export_text` it is recommended that `max_depth` be set
-to sys.maxsize so that the tree isn't truncated.
-
-### Importing Maven Dependency
-```xml
-
- rocks.vilaverde
- scikit-learn-2-java
- 1.0.0
-
-```
+The exported text can then be executed in Java. Note that when calling `export_text` it is
+recommended that `max_depth` be set to `sys.maxsize` so that the tree isn't truncated.
### Java Example
In this example the iris model exported using `export_text` is parsed, features are created as a Java Map
@@ -45,7 +57,7 @@ and the decision tree is asked to predict the class.
```
Reader tree = getTrainedModel("iris.model");
final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,
- new PredictionFactory.IntegerPredictionFactory());
+ PredictionFactory.INTEGER);
Map<String, Double> features = new HashMap<>();
features.put("sepal length (cm)", 3.0);
@@ -57,10 +69,47 @@ and the decision tree is asked to predict the class.
System.out.println(prediction.toString());
```
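+
+If the model was exported with `show_weights=True`, class probabilities can also be requested.
+A minimal sketch, reusing the `decisionTree` and `features` from the example above
+(`java.util.Arrays` is only used for printing):
+```
+double[] probabilities = decisionTree.predict_proba(features);
+
+// one probability per class, in label order
+System.out.println(Arrays.toString(probabilities));
+```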
-## Support
-* The tree.DecisionTreeClassifier is supported
- * Supports `predict()`,
- * Supports `predict_proba()` when `export_text()` configured with `show_weights=True`
+## RandomForestClassifier
+
+To use a RandomForestClassifier that has been trained on the Iris dataset, each of the `estimators_`
+in the classifier needs to be exported using `sklearn.tree.export_text` as shown below:
+
+```
+>>> from sklearn import datasets
+>>> from sklearn import tree
+>>> from sklearn.ensemble import RandomForestClassifier
+>>>
+>>> import os
+>>> import sys
+>>>
+>>> iris = datasets.load_iris()
+>>> X = iris.data
+>>> y = iris.target
+>>>
+>>> clf = RandomForestClassifier(n_estimators = 50, n_jobs=8)
+>>> model = clf.fit(X, y)
+>>> os.makedirs('/tmp/estimators', exist_ok=True)
+>>>
+>>> for i, t in enumerate(clf.estimators_):
+...     with open(os.path.join('/tmp/estimators', "iris-" + str(i) + ".txt"), "w") as file1:
+...         text_representation = tree.export_text(t, feature_names=iris.feature_names, show_weights=True, decimals=4, max_depth=sys.maxsize)
+...         file1.write(text_representation)
+```
+
+Once all the estimators are exported into `/tmp/estimators`, you can create a gzipped TAR archive of them, for example:
+```bash
+cd /tmp/estimators
+tar -czvf /tmp/iris.tgz .
+```
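+
+If you prefer to create the archive from Java, the same commons-compress library this project
+already depends on can write it. The snippet below is only a sketch (it assumes the exported
+estimators are sitting in `/tmp/estimators` and omits error handling):
+```
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveOutputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorOutputStream;
+...
+
+// assumes the directory exists and contains only the exported estimator files
+File estimatorDir = new File("/tmp/estimators");
+try (TarArchiveOutputStream out = new TarArchiveOutputStream(
+        new GzipCompressorOutputStream(new FileOutputStream("/tmp/iris.tgz")))) {
+    for (File exported : estimatorDir.listFiles()) {
+        out.putArchiveEntry(new TarArchiveEntry(exported, exported.getName()));
+        Files.copy(exported.toPath(), out);
+        out.closeArchiveEntry();
+    }
+}
+```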
+
+Then you can use the RandomForestClassifier class to parse the TAR archive.
+
+```
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+...
+
+TarArchiveInputStream tree = getArchive("iris.tgz");
+final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
+    PredictionFactory.DOUBLE);
+```
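+
+`getArchive` above is just a placeholder for however you open the archive; for a gzipped TAR like the
+one created earlier it could, for example, be implemented with a `GzipCompressorInputStream`. Once
+parsed, predictions work the same way as with a single decision tree, and an overload of `parse`
+accepts a job count that controls how many threads search the forest (`-1` uses all available
+processors). A sketch, reusing the `features` map from the Java example above:
+```
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+...
+
+// a possible getArchive implementation for a gzip-compressed TAR
+TarArchiveInputStream getArchive(String path) throws IOException {
+    return new TarArchiveInputStream(new GzipCompressorInputStream(new FileInputStream(path)));
+}
+
+// parse with 4 threads; the two-argument parse searches the forest in the current thread
+final Classifier<Double> forest = RandomForestClassifier.parse(getArchive("/tmp/iris.tgz"),
+    PredictionFactory.DOUBLE, 4);
+
+Double predictedClass = forest.predict(features);
+double[] probabilities = forest.predict_proba(features);
+```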
## Testing
Testing was done using scikit-learn 1.1.3.
diff --git a/pom.xml b/pom.xml
index 818978d..02d6863 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
    <groupId>rocks.vilaverde</groupId>
    <artifactId>scikit-learn-2-java</artifactId>
-    <version>1.0.0-SNAPSHOT</version>
+    <version>1.0.1-SNAPSHOT</version>
    <name>${project.groupId}:${project.artifactId}</name>
    <description>A sklearn export_text models parser for executing in the Java runtime.</description>
@@ -54,11 +54,28 @@
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-compress</artifactId>
+            <version>1.21</version>
+        </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-api</artifactId>
+            <version>1.7.36</version>
+        </dependency>
+
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter</artifactId>
            <scope>test</scope>
        </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-simple</artifactId>
+            <version>1.7.25</version>
+            <scope>test</scope>
+        </dependency>
diff --git a/scikit-learn-2-java.iml b/scikit-learn-2-java.iml
index df974c0..3329a67 100644
--- a/scikit-learn-2-java.iml
+++ b/scikit-learn-2-java.iml
@@ -12,6 +12,8 @@
+
+
@@ -20,5 +22,6 @@
+
\ No newline at end of file
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
index faf8015..f439f01 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
@@ -1,15 +1,15 @@
package rocks.vilaverde.classifier.dt;
-import rocks.vilaverde.classifier.Prediction;
-import rocks.vilaverde.classifier.Classifier;
import rocks.vilaverde.classifier.Operator;
+import rocks.vilaverde.classifier.Prediction;
+
import java.io.BufferedReader;
import java.io.Reader;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
-public class DecisionTreeClassifier<T> implements Classifier<T> {
+public class DecisionTreeClassifier<T> implements TreeClassifier<T> {
/**
* Factory method to create the classifier from the {@link Reader}.
@@ -19,7 +19,7 @@ public class DecisionTreeClassifier implements Classifier {
* @param <T> the prediction class
* @throws Exception when the model could not be parsed
*/
- public static <T> Classifier<T> parse(Reader reader, PredictionFactory<T> factory) throws Exception {
+ public static <T> DecisionTreeClassifier<T> parse(Reader reader, PredictionFactory<T> factory) throws Exception {
try (reader) {
DecisionTreeClassifier<T> classifier = new DecisionTreeClassifier<>(factory);
@@ -53,7 +53,7 @@ private DecisionTreeClassifier(PredictionFactory predictionFactory) {
* @return predicted class
*/
public T predict(Map<String, Double> features) {
- return findClassification(features).get();
+ return getClassification(features).get();
}
/**
@@ -64,13 +64,13 @@ public T predict(Map features) {
*/
@Override
public double[] predict_proba(Map<String, Double> features) {
- return findClassification(features).getProbability();
+ return getClassification(features).getProbability();
}
/**
* Find the {@link Prediction} in the decision tree.
*/
- private Prediction<T> findClassification(Map<String, Double> features) {
+ public Prediction<T> getClassification(Map<String, Double> features) {
validateFeature(features);
TreeNode currentNode = root;
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java b/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
index 0588293..8af8190 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
@@ -1,11 +1,3 @@
-/////////////////////////////////////////////////////////////////////////////
-// PROPRIETARY RIGHTS STATEMENT
-// The contents of this file represent confidential information that is the
-// proprietary property of Edge2Web, Inc. Viewing or use of
-// this information is prohibited without the express written consent of
-// Edge2Web, Inc. Removal of this PROPRIETARY RIGHTS STATEMENT
-// is strictly forbidden. Copyright (c) 2016 All rights reserved.
-/////////////////////////////////////////////////////////////////////////////
package rocks.vilaverde.classifier.dt;
/**
@@ -16,6 +8,7 @@ public interface PredictionFactory {
PredictionFactory<Boolean> BOOLEAN = value -> Boolean.valueOf(value.toLowerCase());
PredictionFactory<Integer> INTEGER = Integer::valueOf;
+ PredictionFactory<Double> DOUBLE = Double::parseDouble;
T create(String value);
}
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
new file mode 100644
index 0000000..a0d59e9
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
@@ -0,0 +1,14 @@
+package rocks.vilaverde.classifier.dt;
+
+import rocks.vilaverde.classifier.Classifier;
+import rocks.vilaverde.classifier.Prediction;
+
+import java.util.Map;
+
+/**
+ * Implemented by Tree classifiers.
+ */
+public interface TreeClassifier<T> extends Classifier<T> {
+
+    Prediction<T> getClassification(Map<String, Double> features);
+}
diff --git a/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java b/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
new file mode 100644
index 0000000..eda0bda
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
@@ -0,0 +1,190 @@
+package rocks.vilaverde.classifier.ensemble;
+
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import rocks.vilaverde.classifier.Classifier;
+import rocks.vilaverde.classifier.Prediction;
+import rocks.vilaverde.classifier.dt.DecisionTreeClassifier;
+import rocks.vilaverde.classifier.dt.PredictionFactory;
+import rocks.vilaverde.classifier.dt.TreeClassifier;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.*;
+import java.util.concurrent.*;
+import java.util.stream.Collectors;
+
+/**
+ * A forest of DecisionTreeClassifiers.
+ */
+public class RandomForestClassifier<T> implements Classifier<T> {
+ private static final Logger LOG = LoggerFactory.getLogger(RandomForestClassifier.class);
+
+ /**
+ * Accept a TAR archive of exported DecisionTreeClassifiers from sklearn and produce a
+ * RandomForestClassifier. This defaults to running in a single (current) thread.
+ * @param tar the Tar Archive input stream
+ * @param factory the factory for creating the prediction class
+ * @return the {@link Classifier}
+ * @param <T> the classifier type
+ * @throws Exception when the model could not be parsed
+ */
+ public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
+ PredictionFactory<T> factory) throws Exception {
+ return RandomForestClassifier.parse(tar, factory, 0);
+ }
+
+ /**
+ * Accept a TAR archive of exported DecisionTreeClassifiers from sklearn and produce a
+ * RandomForestClassifier. Searching the forest can be run in parallel by specifying
+ * the number of jobs.
+ * @param tar the Tar Archive input stream
+ * @param factory the factory for creating the prediction class
+ * @param jobs the number of threads to use to search the forest. Using -1 will use all
+ * available threads.
+ * @return the {@link Classifier}
+ * @param <T> the classifier type
+ * @throws Exception when the model could not be parsed
+ */
+ public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
+ PredictionFactory<T> factory,
+ int jobs) throws Exception {
+ List<TreeClassifier<T>> forest = new ArrayList<>();
+
+ try (tar) {
+ TarArchiveEntry exportedTree = null;
+ while ((exportedTree = tar.getNextTarEntry()) != null) {
+ if (!exportedTree.isDirectory()) {
+ LOG.debug("Parsing tree {}", exportedTree.getName());
+ final InputStream noCloseStream = new InputStream() {
+ @Override
+ public int read() throws IOException {
+ return tar.read();
+ }
+
+ @Override
+ public void close() throws IOException {
+ // don't close otherwise next file in tar won't be read.
+ }
+ };
+ BufferedReader reader = new BufferedReader(new InputStreamReader(noCloseStream));
+ TreeClassifier<T> tree = DecisionTreeClassifier.parse(reader, factory);
+ forest.add(tree);
+ }
+ }
+ }
+
+ return new RandomForestClassifier<>(forest, jobs);
+ }
+
+ private int jobs;
+ private List<TreeClassifier<T>> forest;
+
+ /**
+ * Private Constructor
+ * @param forest the parsed decision trees that make up the forest
+ * @param jobs the number of threads to use when searching the forest
+ */
+ private RandomForestClassifier(List<TreeClassifier<T>> forest, int jobs) {
+ this.forest = forest;
+ this.jobs = jobs;
+ }
+
+ @Override
+ public T predict(Map<String, Double> features) {
+
+ List<Prediction<T>> predictions = getPredictions(features);
+
+ Map<T, Long> map = predictions.stream().collect(
+ Collectors.groupingBy(Prediction::get, Collectors.counting())
+ );
+
+ long max = map.values().stream().mapToLong(Long::longValue).max().getAsLong();
+ for (Map.Entry<T, Long> entry : map.entrySet()) {
+ if (entry.getValue() == max) {
+ return entry.getKey();
+ }
+ }
+
+ throw new IllegalStateException("no classification");
+ }
+
+ @Override
+ public double[] predict_proba(Map<String, Double> features) {
+ if (forest.size() == 1) {
+ return forest.get(0).getClassification(features).getProbability();
+ }
+
+ List<Prediction<T>> predictions = getPredictions(features);
+
+ double[] result = null;
+
+ for (Prediction<T> prediction : predictions) {
+ double[] prob = prediction.getProbability();
+
+ if (result == null) {
+ result = prob;
+ } else {
+ for (int i = 0; i < prob.length; i++) {
+ result[i] += prob[i];
+ }
+ }
+ }
+
+ int forestSize = forest.size();
+ for (int i = 0; i < result.length; i++) {
+ result[i] /= forestSize;
+ }
+
+ return result;
+ }
+
+ protected List<Prediction<T>> getPredictions(Map<String, Double> features) {
+ final List<Prediction<T>> predictions = new ArrayList<>(forest.size());
+
+ if (jobs == -1) {
+ jobs = Runtime.getRuntime().availableProcessors();
+ }
+
+ if (jobs > 0) {
+ ExecutorService executor = Executors.newFixedThreadPool(jobs);
+
+ try {
+ for (TreeClassifier<T> tree : forest) {
+ executor.submit(() -> {
+ Prediction<T> pred = tree.getClassification(features);
+ synchronized (predictions) {
+ predictions.add(pred);
+ }
+ });
+ }
+ } finally {
+ executor.shutdown();
+ try {
+ executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
+ } catch (Exception e) {
+ LOG.error("interrupted while searching trees", e);
+ }
+ }
+ } else {
+ for (TreeClassifier<T> tree : forest) {
+ Prediction<T> prediction = tree.getClassification(features);
+ predictions.add(prediction);
+ }
+ }
+
+ return predictions;
+ }
+
+ @Override
+ public Set<String> getFeatureNames() {
+ Set<String> features = new HashSet<>();
+ for (Classifier<T> tree : forest) {
+ features.addAll(tree.getFeatureNames());
+ }
+ return features;
+ }
+}
diff --git a/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java b/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
new file mode 100644
index 0000000..ae971c9
--- /dev/null
+++ b/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
@@ -0,0 +1,94 @@
+package rocks.vilaverde.classifier;
+
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import rocks.vilaverde.classifier.dt.PredictionFactory;
+import rocks.vilaverde.classifier.ensemble.RandomForestClassifier;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Tests for the RandomForestClassifier
+ */
+public class RandomForestClassifierTest {
+
+ @Test
+ public void randomForestParallel() throws Exception {
+ TarArchiveInputStream exported = getExportedModel("rf/iris.tgz");
+ final Classifier<Double> decisionTree = RandomForestClassifier.parse(exported,
+ PredictionFactory.DOUBLE, 4);
+ Assertions.assertNotNull(decisionTree);
+
+ double[] proba = decisionTree.predict_proba(getSample1());
+ assertSample(proba, .06, .62, .32);
+
+ Double prediction = decisionTree.predict(getSample1());
+ Assertions.assertNotNull(prediction);
+ Assertions.assertEquals(1.0, prediction.doubleValue(), .0);
+
+ prediction = decisionTree.predict(getSample2());
+ Assertions.assertEquals(2, prediction.intValue());
+
+ proba = decisionTree.predict_proba(getSample2());
+ assertSample(proba, 0.0, .44, .56);
+ }
+
+ @Test
+ public void randomForest() throws Exception {
+ TarArchiveInputStream exported = getExportedModel("rf/iris.tgz");
+ final Classifier<Double> decisionTree = RandomForestClassifier.parse(exported, PredictionFactory.DOUBLE);
+ Assertions.assertNotNull(decisionTree);
+
+ double[] proba = decisionTree.predict_proba(getSample1());
+ assertSample(proba, .06, .62, .32);
+
+ Double prediction = decisionTree.predict(getSample1());
+ Assertions.assertNotNull(prediction);
+ Assertions.assertEquals(1.0, prediction.doubleValue(), .0);
+
+ prediction = decisionTree.predict(getSample2());
+ Assertions.assertEquals(2, prediction.intValue());
+
+ proba = decisionTree.predict_proba(getSample2());
+ assertSample(proba, 0.0, .44, .56);
+ }
+
+ private Map<String, Double> getSample1() {
+ Map<String, Double> features = new HashMap<>();
+ features.put("sepal length (cm)", 3.0);
+ features.put("sepal width (cm)", 5.0);
+ features.put("petal length (cm)", 4.0);
+ features.put("petal width (cm)", 2.0);
+ return features;
+ }
+
+ private Map<String, Double> getSample2() {
+ Map<String, Double> features = new HashMap<>();
+ features.put("sepal length (cm)", 1.0);
+ features.put("sepal width (cm)", 2.0);
+ features.put("petal length (cm)", 3.0);
+ features.put("petal width (cm)", 4.0);
+ return features;
+ }
+
+ private void assertSample(double[] proba, double expected, double expected1, double expected2) {
+ Assertions.assertNotNull(proba);
+ Assertions.assertEquals(expected, proba[0], .0);
+ Assertions.assertEquals(expected1, proba[1], .0);
+ Assertions.assertEquals(expected2, proba[2], .0);
+ }
+
+ private TarArchiveInputStream getExportedModel(String fileName) throws IOException {
+ ClassLoader cl = DecisionTreeClassifierTest.class.getClassLoader();
+ InputStream stream = cl.getResourceAsStream(fileName);
+ if (stream == null) {
+ throw new RuntimeException(String.format("no archive found with name %s", fileName));
+ }
+ return new TarArchiveInputStream(new GzipCompressorInputStream(stream));
+ }
+}
diff --git a/src/test/resources/rf/iris.tgz b/src/test/resources/rf/iris.tgz
new file mode 100644
index 0000000..158dc96
Binary files /dev/null and b/src/test/resources/rf/iris.tgz differ
diff --git a/src/test/resources/simplelogger.properties b/src/test/resources/simplelogger.properties
new file mode 100644
index 0000000..40cfffe
--- /dev/null
+++ b/src/test/resources/simplelogger.properties
@@ -0,0 +1 @@
+org.slf4j.simpleLogger.defaultLogLevel=info
\ No newline at end of file