diff --git a/.idea/copyright/profiles_settings.xml b/.idea/copyright/profiles_settings.xml
new file mode 100644
index 0000000..8295f31
--- /dev/null
+++ b/.idea/copyright/profiles_settings.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/libraries/Maven__org_apache_commons_commons_compress_1_21.xml b/.idea/libraries/Maven__org_apache_commons_commons_compress_1_21.xml
new file mode 100644
index 0000000..49cd123
--- /dev/null
+++ b/.idea/libraries/Maven__org_apache_commons_commons_compress_1_21.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/libraries/Maven__org_slf4j_slf4j_api_1_7_36.xml b/.idea/libraries/Maven__org_slf4j_slf4j_api_1_7_36.xml
new file mode 100644
index 0000000..2d759c1
--- /dev/null
+++ b/.idea/libraries/Maven__org_slf4j_slf4j_api_1_7_36.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/libraries/Maven__org_slf4j_slf4j_simple_1_7_25.xml b/.idea/libraries/Maven__org_slf4j_slf4j_simple_1_7_25.xml
new file mode 100644
index 0000000..8bc862b
--- /dev/null
+++ b/.idea/libraries/Maven__org_slf4j_slf4j_simple_1_7_25.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index 8306744..35eb1dd 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -2,6 +2,5 @@
 
 
-
 
 
\ No newline at end of file
diff --git a/README.md b/README.md
index c2ca859..2273253 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,29 @@
 
 This project aims to use text-exported ML models generated by scikit-learn and make them usable in Java.
 
-For example, a DecisionTreeClassifier model trained on the Iris dataset and exported using `sklearn.tree`
-export_text() as shown below:
+## Support
+* The tree.DecisionTreeClassifier is supported
+  * Supports `predict()`
+  * Supports `predict_proba()` when `export_text()` is configured with `show_weights=True`
+* The tree.RandomForestClassifier is supported
+  * Supports `predict()`
+  * Supports `predict_proba()` when `export_text()` is configured with `show_weights=True`
+
+## Installing
+
+### Importing Maven Dependency
+```xml
+<dependency>
+    <groupId>rocks.vilaverde</groupId>
+    <artifactId>scikit-learn-2-java</artifactId>
+    <version>1.0.0</version>
+</dependency>
+```
+
+## DecisionTreeClassifier
+
+As an example, consider a DecisionTreeClassifier model trained on the Iris dataset and exported using
+`sklearn.tree` `export_text()` as shown below:
 
 ```
 >>> from sklearn.datasets import load_iris
 >>> from sklearn import tree
 >>>
 >>> iris = load_iris()
 >>> X = iris.data
 >>> y = iris.target
 >>>
 >>> clf = tree.DecisionTreeClassifier()
 >>> clf = clf.fit(X, y)
 >>>
 >>> r = tree.export_text(clf, feature_names=iris.feature_names)
 >>> print(r)
 |--- petal width (cm) <= 0.80
 |   |--- class: 0
 |--- petal width (cm) >  0.80
 |   |--- petal width (cm) <= 1.75
 |   |   |--- class: 1
 |   |--- petal width (cm) >  1.75
 |   |   |--- class: 2
 ```
 
@@ -26,17 +47,8 @@
-can be executed in Java Maven. Note that when calling `export_text` it is recommended that `max_depth` be set
-to sys.maxsize so that the tree isn't truncated.
-
-### Importing Maven Dependency
-```xml
-<dependency>
-    <groupId>rocks.vilaverde</groupId>
-    <artifactId>scikit-learn-2-java</artifactId>
-    <version>1.0.0</version>
-</dependency>
-```
+The exported text can then be executed in Java. Note that when calling `export_text` it is
+recommended that `max_depth` be set to `sys.maxsize` so that the tree isn't truncated.
 
 ### Java Example
 In this example the iris model exported using `export_text` is parsed, features are created as a Java Map,
@@ -45,7 +57,7 @@
 and the decision tree is asked to predict the class.
 
 ```
 Reader tree = getTrainedModel("iris.model");
 final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,
-    new PredictionFactory.IntegerPredictionFactory());
+    PredictionFactory.INTEGER);
 
 Map<String, Double> features = new HashMap<>();
 features.put("sepal length (cm)", 3.0);
@@ -57,10 +69,47 @@
 features.put("sepal width (cm)", 5.0);
 features.put("petal length (cm)", 4.0);
 features.put("petal width (cm)", 2.0);
 
 double prediction = decisionTree.predict(features);
 System.out.println(prediction.toString());
 ```
 
-## Support
-* The tree.DecisionTreeClassifier is supported
-  * Supports `predict()`,
-  * Supports `predict_proba()` when `export_text()` configured with `show_weights=True`
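+
+If the model has been exported with `show_weights=True`, the same classifier can also return the
+per-class probabilities through `predict_proba()`. A minimal sketch, reusing `decisionTree` and
+`features` from the example above and assuming the model was exported with weights:
+
+```
+double[] probabilities = decisionTree.predict_proba(features);
+System.out.println(java.util.Arrays.toString(probabilities));
+```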
+## RandomForestClassifier
+
+To use a RandomForestClassifier that has been trained on the Iris dataset, each of the `estimators_`
+in the classifier needs to be exported using `from sklearn.tree import export_text`, as shown below:
+
+```
+>>> from sklearn import datasets
+>>> from sklearn import tree
+>>> from sklearn.ensemble import RandomForestClassifier
+>>>
+>>> import os
+>>> import sys
+>>>
+>>> iris = datasets.load_iris()
+>>> X = iris.data
+>>> y = iris.target
+>>>
+>>> clf = RandomForestClassifier(n_estimators=50, n_jobs=8)
+>>> model = clf.fit(X, y)
+>>>
+>>> for i, t in enumerate(clf.estimators_):
+>>>     with open(os.path.join('/tmp/estimators', "iris-" + str(i) + ".txt"), "w") as file1:
+>>>         text_representation = tree.export_text(t, feature_names=iris.feature_names, show_weights=True, decimals=4, max_depth=sys.maxsize)
+>>>         file1.write(text_representation)
+```
+
+Once all the estimators are exported into `/tmp/estimators`, you can create a TAR archive, for example:
+```bash
+cd /tmp/estimators
+tar -czvf /tmp/iris.tgz .
+```
+
+Then you can use the RandomForestClassifier class to parse the TAR archive:
+
+```
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+...
+
+TarArchiveInputStream tree = getArchive("iris.tgz");
+final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
+    PredictionFactory.DOUBLE);
+```
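+
+The parsed forest exposes the same `predict()` and `predict_proba()` methods as a single decision
+tree, and `parse` also accepts an optional third `jobs` argument that controls how many threads are
+used to evaluate the trees (`-1` uses all available processors). A short sketch; the feature values
+below are illustrative only:
+
+```
+// evaluate the trees of the forest on 4 threads
+final Classifier<Double> forest = RandomForestClassifier.parse(getArchive("iris.tgz"),
+    PredictionFactory.DOUBLE, 4);
+
+Map<String, Double> features = new HashMap<>();
+features.put("sepal length (cm)", 3.0);
+features.put("sepal width (cm)", 5.0);
+features.put("petal length (cm)", 4.0);
+features.put("petal width (cm)", 2.0);
+
+// majority vote across the trees in the forest
+Double prediction = forest.predict(features);
+// per-class probabilities averaged over the trees (requires show_weights=True at export time)
+double[] probabilities = forest.predict_proba(features);
+```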
 
 ## Testing
 Testing was done using scikit-learn 1.1.3.
diff --git a/pom.xml b/pom.xml
index 818978d..02d6863 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
     <groupId>rocks.vilaverde</groupId>
     <artifactId>scikit-learn-2-java</artifactId>
-    <version>1.0.0-SNAPSHOT</version>
+    <version>1.0.1-SNAPSHOT</version>
 
     <name>${project.groupId}:${project.artifactId}</name>
     <description>A sklearn exported_text models parser for executing in the Java runtime.</description>
@@ -54,11 +54,28 @@
 
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-compress</artifactId>
+            <version>1.21</version>
+        </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-api</artifactId>
+            <version>1.7.36</version>
+        </dependency>
         <dependency>
             <groupId>org.junit.jupiter</groupId>
             <artifactId>junit-jupiter</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-simple</artifactId>
+            <version>1.7.25</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
diff --git a/scikit-learn-2-java.iml b/scikit-learn-2-java.iml
index df974c0..3329a67 100644
--- a/scikit-learn-2-java.iml
+++ b/scikit-learn-2-java.iml
@@ -12,6 +12,8 @@
 
 
 
+
+
 
 
 
@@ -20,5 +22,6 @@
 
 
 
+
 
 
\ No newline at end of file
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
index faf8015..f439f01 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/DecisionTreeClassifier.java
@@ -1,15 +1,15 @@
 package rocks.vilaverde.classifier.dt;
 
-import rocks.vilaverde.classifier.Prediction;
-import rocks.vilaverde.classifier.Classifier;
 import rocks.vilaverde.classifier.Operator;
+import rocks.vilaverde.classifier.Prediction;
+
 import java.io.BufferedReader;
 import java.io.Reader;
 import java.util.Map;
 import java.util.Set;
 import java.util.Stack;
 
-public class DecisionTreeClassifier<T> implements Classifier<T> {
+public class DecisionTreeClassifier<T> implements TreeClassifier<T> {
 
     /**
      * Factory method to create the classifier from the {@link Reader}.
@@ -19,7 +19,7 @@ public class DecisionTreeClassifier<T> implements Classifier<T> {
      * @param <T> the prediction class
      * @throws Exception when the model could not be parsed
      */
-    public static <T> Classifier<T> parse(Reader reader, PredictionFactory<T> factory) throws Exception {
+    public static <T> DecisionTreeClassifier<T> parse(Reader reader, PredictionFactory<T> factory) throws Exception {
         try (reader) {
             DecisionTreeClassifier<T> classifier = new DecisionTreeClassifier<>(factory);
 
@@ -53,7 +53,7 @@ private DecisionTreeClassifier(PredictionFactory<T> predictionFactory) {
      * @return predicted class
      */
     public T predict(Map<String, Double> features) {
-        return findClassification(features).get();
+        return getClassification(features).get();
     }
 
     /**
@@ -64,13 +64,13 @@ public T predict(Map<String, Double> features) {
      */
     @Override
     public double[] predict_proba(Map<String, Double> features) {
-        return findClassification(features).getProbability();
+        return getClassification(features).getProbability();
     }
 
     /**
      * Find the {@link Prediction} in the decision tree.
      */
-    private Prediction<T> findClassification(Map<String, Double> features) {
+    public Prediction<T> getClassification(Map<String, Double> features) {
         validateFeature(features);
 
         TreeNode currentNode = root;
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java b/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
index 0588293..8af8190 100644
--- a/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
+++ b/src/main/java/rocks/vilaverde/classifier/dt/PredictionFactory.java
@@ -1,11 +1,3 @@
-/////////////////////////////////////////////////////////////////////////////
-// PROPRIETARY RIGHTS STATEMENT
-// The contents of this file represent confidential information that is the
-// proprietary property of Edge2Web, Inc. Viewing or use of
-// this information is prohibited without the express written consent of
-// Edge2Web, Inc. Removal of this PROPRIETARY RIGHTS STATEMENT
-// is strictly forbidden. Copyright (c) 2016 All rights reserved.
-/////////////////////////////////////////////////////////////////////////////
 package rocks.vilaverde.classifier.dt;
 
 /**
@@ -16,6 +8,7 @@ public interface PredictionFactory<T> {
 
     PredictionFactory<Boolean> BOOLEAN = value -> Boolean.valueOf(value.toLowerCase());
     PredictionFactory<Integer> INTEGER = Integer::valueOf;
+    PredictionFactory<Double> DOUBLE = Double::parseDouble;
 
     T create(String value);
 }
diff --git a/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java b/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
new file mode 100644
index 0000000..a0d59e9
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
@@ -0,0 +1,14 @@
+package rocks.vilaverde.classifier.dt;
+
+import rocks.vilaverde.classifier.Classifier;
+import rocks.vilaverde.classifier.Prediction;
+
+import java.util.Map;
+
+/**
+ * Implemented by tree-based classifiers.
+ */
+public interface TreeClassifier<T> extends Classifier<T> {
+
+    Prediction<T> getClassification(Map<String, Double> features);
+}
diff --git a/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java b/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
new file mode 100644
index 0000000..eda0bda
--- /dev/null
+++ b/src/main/java/rocks/vilaverde/classifier/ensemble/RandomForestClassifier.java
@@ -0,0 +1,190 @@
+package rocks.vilaverde.classifier.ensemble;
+
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import rocks.vilaverde.classifier.Classifier;
+import rocks.vilaverde.classifier.Prediction;
+import rocks.vilaverde.classifier.dt.DecisionTreeClassifier;
+import rocks.vilaverde.classifier.dt.PredictionFactory;
+import rocks.vilaverde.classifier.dt.TreeClassifier;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.*;
+import java.util.concurrent.*;
+import java.util.stream.Collectors;
+
+/**
+ * A forest of DecisionTreeClassifiers.
+ */
+public class RandomForestClassifier<T> implements Classifier<T> {
+    private static final Logger LOG = LoggerFactory.getLogger(RandomForestClassifier.class);
+
+    /**
+     * Accept a TAR of exported DecisionTreeClassifiers from sklearn and produce a
+     * RandomForestClassifier. This defaults to running in a single (the current) thread.
+     * @param tar the TAR archive input stream
+     * @param factory the factory for creating the prediction class
+     * @return the {@link Classifier}
+     * @param <T> the classifier type
+     * @throws Exception when the model could not be parsed
+     */
+    public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
+                                          PredictionFactory<T> factory) throws Exception {
+        return RandomForestClassifier.parse(tar, factory, 0);
+    }
+
+    /**
+     * Accept a TAR of exported DecisionTreeClassifiers from sklearn and produce a
+     * RandomForestClassifier. This can be run in parallel by providing a job count.
+     * @param tar the TAR archive input stream
+     * @param factory the factory for creating the prediction class
+     * @param jobs the number of threads to use to search the forest. Using -1 will use all
+     *             available threads.
+     * @return the {@link Classifier}
+     * @param <T> the classifier type
+     * @throws Exception when the model could not be parsed
+     */
+    public static <T> Classifier<T> parse(final TarArchiveInputStream tar,
+                                          PredictionFactory<T> factory,
+                                          int jobs) throws Exception {
+        List<TreeClassifier<T>> forest = new ArrayList<>();
+
+        try (tar) {
+            TarArchiveEntry exportedTree = null;
+            while ((exportedTree = tar.getNextTarEntry()) != null) {
+                if (!exportedTree.isDirectory()) {
+                    LOG.debug("Parsing tree {}", exportedTree.getName());
+                    final InputStream noCloseStream = new InputStream() {
+                        @Override
+                        public int read() throws IOException {
+                            return tar.read();
+                        }
+
+                        @Override
+                        public void close() throws IOException {
+                            // don't close, otherwise the next file in the tar won't be read.
+                        }
+                    };
+                    BufferedReader reader = new BufferedReader(new InputStreamReader(noCloseStream));
+                    TreeClassifier<T> tree = DecisionTreeClassifier.parse(reader, factory);
+                    forest.add(tree);
+                }
+            }
+        }
+
+        return new RandomForestClassifier<>(forest, jobs);
+    }
+
+    private int jobs;
+    private List<TreeClassifier<T>> forest;
+
+    /**
+     * Private constructor.
+     * @param forest the parsed decision trees
+     * @param jobs the number of threads used when evaluating the forest
+     */
+    private RandomForestClassifier(List<TreeClassifier<T>> forest, int jobs) {
+        this.forest = forest;
+        this.jobs = jobs;
+    }
+
+    @Override
+    public T predict(Map<String, Double> features) {
+
+        List<Prediction<T>> predictions = getPredictions(features);
+
+        // majority vote: count how often each class was predicted across the trees
+        Map<T, Long> map = predictions.stream().collect(
+                Collectors.groupingBy(Prediction::get, Collectors.counting())
+        );
+
+        long max = map.values().stream().mapToLong(Long::longValue).max().getAsLong();
+        for (Map.Entry<T, Long> entry : map.entrySet()) {
+            if (entry.getValue() == max) {
+                return entry.getKey();
+            }
+        }
+
+        throw new IllegalStateException("no classification");
+    }
+
+    @Override
+    public double[] predict_proba(Map<String, Double> features) {
+        if (forest.size() == 1) {
+            return forest.get(0).getClassification(features).getProbability();
+        }
+
+        List<Prediction<T>> predictions = getPredictions(features);
+
+        double[] result = null;
+
+        // sum the per-class probabilities of every tree, then average over the forest
+        for (Prediction<T> prediction : predictions) {
+            double[] prob = prediction.getProbability();
+
+            if (result == null) {
+                result = prob;
+            } else {
+                for (int i = 0; i < prob.length; i++) {
+                    result[i] += prob[i];
+                }
+            }
+        }
+
+        int forestSize = forest.size();
+        for (int i = 0; i < result.length; i++) {
+            result[i] /= forestSize;
+        }
+
+        return result;
+    }
+
+    protected List<Prediction<T>> getPredictions(Map<String, Double> features) {
+        final List<Prediction<T>> predictions = new ArrayList<>(forest.size());
+
+        if (jobs == -1) {
+            jobs = Runtime.getRuntime().availableProcessors();
+        }
+
+        if (jobs > 0) {
+            // evaluate each tree on a fixed-size thread pool and wait for all of them to finish
+            ExecutorService executor = Executors.newFixedThreadPool(jobs);
+
+            try {
+                for (TreeClassifier<T> tree : forest) {
+                    executor.submit(() -> {
+                        Prediction<T> pred = tree.getClassification(features);
+                        synchronized (predictions) {
+                            predictions.add(pred);
+                        }
+                    });
+                }
+            } finally {
+                executor.shutdown();
+                try {
+                    executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
+                } catch (Exception e) {
+                    LOG.error("interrupted while searching trees", e);
+                }
+            }
+        } else {
+            for (TreeClassifier<T> tree : forest) {
+                Prediction<T> prediction = tree.getClassification(features);
+                predictions.add(prediction);
+            }
+        }
+
+        return predictions;
+    }
+
+    @Override
+    public Set<String> getFeatureNames() {
+        Set<String> features = new HashSet<>();
+        for (Classifier<T> tree : forest) {
+            features.addAll(tree.getFeatureNames());
+        }
+        return features;
+    }
+}
diff --git a/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java b/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
new file mode 100644
index 0000000..ae971c9
--- /dev/null
+++ b/src/test/java/rocks/vilaverde/classifier/RandomForestClassifierTest.java
@@ -0,0 +1,94 @@
+package rocks.vilaverde.classifier;
+
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import rocks.vilaverde.classifier.dt.PredictionFactory;
+import rocks.vilaverde.classifier.ensemble.RandomForestClassifier;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Tests for the RandomForestClassifier.
+ */
+public class RandomForestClassifierTest {
+
+    @Test
+    public void randomForestParallel() throws Exception {
+        TarArchiveInputStream exported = getExportedModel("rf/iris.tgz");
+        final Classifier<Double> decisionTree = RandomForestClassifier.parse(exported,
+                PredictionFactory.DOUBLE, 4);
+        Assertions.assertNotNull(decisionTree);
+
+        double[] proba = decisionTree.predict_proba(getSample1());
+        assertSample(proba, .06, .62, .32);
+
+        Double prediction = decisionTree.predict(getSample1());
+        Assertions.assertNotNull(prediction);
+        Assertions.assertEquals(1.0, prediction.doubleValue(), .0);
+
+        prediction = decisionTree.predict(getSample2());
+        Assertions.assertEquals(2, prediction.intValue());
+
+        proba = decisionTree.predict_proba(getSample2());
+        assertSample(proba, 0.0, .44, .56);
+    }
+
+    @Test
+    public void randomForest() throws Exception {
+        TarArchiveInputStream exported = getExportedModel("rf/iris.tgz");
+        final Classifier<Double> decisionTree = RandomForestClassifier.parse(exported, PredictionFactory.DOUBLE);
+        Assertions.assertNotNull(decisionTree);
+
+        double[] proba = decisionTree.predict_proba(getSample1());
+        assertSample(proba, .06, .62, .32);
+
+        Double prediction = decisionTree.predict(getSample1());
+        Assertions.assertNotNull(prediction);
+        Assertions.assertEquals(1.0, prediction.doubleValue(), .0);
+
+        prediction = decisionTree.predict(getSample2());
+        Assertions.assertEquals(2, prediction.intValue());
+
+        proba = decisionTree.predict_proba(getSample2());
+        assertSample(proba, 0.0, .44, .56);
+    }
+
+    private Map<String, Double> getSample1() {
+        Map<String, Double> features = new HashMap<>();
+        features.put("sepal length (cm)", 3.0);
+        features.put("sepal width (cm)", 5.0);
+        features.put("petal length (cm)", 4.0);
+        features.put("petal width (cm)", 2.0);
+        return features;
+    }
+
+    private Map<String, Double> getSample2() {
+        Map<String, Double> features = new HashMap<>();
+        features.put("sepal length (cm)", 1.0);
+        features.put("sepal width (cm)", 2.0);
+        features.put("petal length (cm)", 3.0);
+        features.put("petal width (cm)", 4.0);
+        return features;
+    }
+
+    private void assertSample(double[] proba, double expected, double expected1, double expected2) {
+        Assertions.assertNotNull(proba);
+        Assertions.assertEquals(expected, proba[0], .0);
+        Assertions.assertEquals(expected1, proba[1], .0);
+        Assertions.assertEquals(expected2, proba[2], .0);
+    }
+
+    private TarArchiveInputStream getExportedModel(String fileName) throws IOException {
+        ClassLoader cl = DecisionTreeClassifierTest.class.getClassLoader();
+        InputStream stream = cl.getResourceAsStream(fileName);
+        if (stream == null) {
+            throw new RuntimeException(String.format("no archive found with name %s", fileName));
+        }
+        return new TarArchiveInputStream(new GzipCompressorInputStream(stream));
+    }
+}
diff --git a/src/test/resources/rf/iris.tgz b/src/test/resources/rf/iris.tgz
new file mode 100644
index 0000000..158dc96
Binary files /dev/null and b/src/test/resources/rf/iris.tgz differ
diff --git a/src/test/resources/simplelogger.properties b/src/test/resources/simplelogger.properties
new file mode 100644
index 0000000..40cfffe
--- /dev/null
+++ b/src/test/resources/simplelogger.properties
@@ -0,0 +1 @@
+org.slf4j.simpleLogger.defaultLogLevel=info
\ No newline at end of file