Merged
7 changes: 7 additions & 0 deletions .idea/copyright/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions .idea/libraries/Maven__org_slf4j_slf4j_api_1_7_36.xml

13 changes: 13 additions & 0 deletions .idea/libraries/Maven__org_slf4j_slf4j_simple_1_7_25.xml

1 change: 0 additions & 1 deletion .idea/vcs.xml

85 changes: 67 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,29 @@

This project aims to take ML models exported as text by scikit-learn and make them usable in Java.

For example, a DecisionTreeClassifier model trained on the Iris dataset and exported using `sklearn.tree`
export_text() as shown below:
## Support
* The tree.DecisionTreeClassifier is supported
  * Supports `predict()`
  * Supports `predict_proba()` when `export_text()` is configured with `show_weights=True`
* The ensemble.RandomForestClassifier is supported
  * Supports `predict()`
  * Supports `predict_proba()` when `export_text()` is configured with `show_weights=True`
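
To illustrate why `predict_proba()` depends on `show_weights=True`: with that option each leaf line of the exported text carries per-class weights, which can be normalized into probabilities. The following is a minimal sketch under that assumption, not this library's actual implementation; the class `LeafWeightsSketch` and its method are hypothetical names used only for illustration.

```java
import java.util.Arrays;

// Hypothetical helper, not part of scikit-learn-2-java.
class LeafWeightsSketch {

    /** Reads the numbers between '[' and ']' on a leaf line and scales them to sum to 1. */
    static double[] probabilities(String leafLine) {
        int open = leafLine.indexOf('[');
        int close = leafLine.indexOf(']', open);
        if (open < 0 || close < 0) {
            throw new IllegalArgumentException("no weights on line: " + leafLine);
        }
        String[] parts = leafLine.substring(open + 1, close).split(",");
        double[] weights = new double[parts.length];
        double total = 0.0;
        for (int i = 0; i < parts.length; i++) {
            weights[i] = Double.parseDouble(parts[i].trim());
            total += weights[i];
        }
        for (int i = 0; i < weights.length; i++) {
            // Guard against an all-zero leaf instead of dividing by zero.
            weights[i] = total == 0.0 ? 0.0 : weights[i] / total;
        }
        return weights;
    }

    public static void main(String[] args) {
        // A leaf line in the shape export_text prints when show_weights=True.
        String leaf = "|   |   |--- weights: [0.00, 49.00, 5.00] class: 1";
        System.out.println(Arrays.toString(probabilities(leaf)));
    }
}
```

Without the weights, a leaf only names the winning class, so there is nothing to normalize.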

## Installing

### Importing Maven Dependency
```xml
<dependency>
<groupId>rocks.vilaverde</groupId>
<artifactId>scikit-learn-2-java</artifactId>
<version>1.0.0</version>
</dependency>
```

## DecisionTreeClassifier

As an example, a DecisionTreeClassifier model trained on the Iris dataset and exported using `sklearn.tree`
`export_text()` as shown below:

```
>>> from sklearn.datasets import load_iris
Expand All @@ -26,17 +47,8 @@ export_text() as shown below:
| | |--- class: 2
```

can be executed in Java Maven. Note that when calling `export_text` it is recommended that `max_depth` be set
to sys.maxsize so that the tree isn't truncated.

### Importing Maven Dependency
```xml
<dependency>
<groupId>rocks.vilaverde</groupId>
<artifactId>scikit-learn-2-java</artifactId>
<version>1.0.0</version>
</dependency>
```
The exported text can then be executed in Java. Note that when calling `export_text` it is
recommended that `max_depth` be set to `sys.maxsize` so that the tree isn't truncated.

### Java Example
In this example the iris model exported using `export_text` is parsed, features are created as a Java Map
Expand All @@ -45,7 +57,7 @@ and the decision tree is asked to predict the class.
```
Reader tree = getTrainedModel("iris.model");
final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,
new PredictionFactory.IntegerPredictionFactory());
PredictionFactory.INTEGER);

Map<String, Double> features = new HashMap<>();
features.put("sepal length (cm)", 3.0);
Expand All @@ -57,10 +69,47 @@ and the decision tree is asked to predict the class.
System.out.println(prediction.toString());
```
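
For readers curious how an `export_text` dump can be evaluated at all, here is a self-contained sketch that walks the indented lines directly. It is an illustration only, not the parser used by `DecisionTreeClassifier`; the class name, the tiny one-feature tree, and the regular expression are assumptions, and it only handles plain `class:` leaves (no weights, no truncated branches).

```java
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical sketch, not part of scikit-learn-2-java.
class ExportTextWalkSketch {

    // Matches "<feature> <= <threshold>" or "<feature> >  <threshold>".
    private static final Pattern CONDITION =
            Pattern.compile("^(.+?)\\s+(<=|>)\\s+(\\S+)$");

    // A small hand-written tree in export_text layout (illustrative data).
    static final List<String> IRIS_STUMP = List.of(
            "|--- petal width (cm) <= 0.80",
            "|   |--- class: 0",
            "|--- petal width (cm) >  0.80",
            "|   |--- petal width (cm) <= 1.75",
            "|   |   |--- class: 1",
            "|   |--- petal width (cm) >  1.75",
            "|   |   |--- class: 2");

    static String predict(List<String> lines, Map<String, Double> features) {
        // Depth of the most recent condition that failed; deeper lines are skipped.
        int failedDepth = Integer.MAX_VALUE;
        for (String line : lines) {
            int marker = line.indexOf("|---");
            if (marker < 0) {
                continue; // not a tree line
            }
            int depth = marker / 4; // each nesting level is 4 characters wide
            if (depth > failedDepth) {
                continue; // inside a branch whose condition did not hold
            }
            failedDepth = Integer.MAX_VALUE;
            String body = line.substring(marker + 4).trim();
            if (body.startsWith("class:")) {
                return body.substring("class:".length()).trim();
            }
            Matcher m = CONDITION.matcher(body);
            if (!m.matches()) {
                throw new IllegalArgumentException("unrecognized line: " + line);
            }
            Double value = features.get(m.group(1));
            if (value == null) {
                throw new IllegalArgumentException("missing feature: " + m.group(1));
            }
            double threshold = Double.parseDouble(m.group(3));
            boolean holds = m.group(2).equals("<=") ? value <= threshold : value > threshold;
            if (!holds) {
                failedDepth = depth;
            }
        }
        throw new IllegalStateException("no leaf reached");
    }

    public static void main(String[] args) {
        System.out.println(predict(IRIS_STUMP, Map.of("petal width (cm)", 1.0)));
    }
}
```

Building an explicit node structure once, as a real parser would, avoids rescanning the text on every prediction; the line walk above just keeps the idea visible.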

## Support
* The tree.DecisionTreeClassifier is supported
* Supports `predict()`,
* Supports `predict_proba()` when `export_text()` configured with `show_weights=True`
## RandomForestClassifier

To use a RandomForestClassifier that has been trained on the Iris dataset, each of the estimators in the
classifier's `estimators_` list needs to be exported using `export_text` from `sklearn.tree`, as shown below:

```
>>> from sklearn import datasets
>>> from sklearn import tree
>>> from sklearn.ensemble import RandomForestClassifier
>>>
>>> import os
>>> import sys
>>>
>>> iris = datasets.load_iris()
>>> X = iris.data
>>> y = iris.target
>>>
>>> clf = RandomForestClassifier(n_estimators = 50, n_jobs=8)
>>> model = clf.fit(X, y)
>>>
>>> os.makedirs('/tmp/estimators', exist_ok=True)  # make sure the output directory exists
>>> for i, t in enumerate(clf.estimators_):
...     with open(os.path.join('/tmp/estimators', "iris-" + str(i) + ".txt"), "w") as file1:
...         text_representation = tree.export_text(t, feature_names=iris.feature_names, show_weights=True, decimals=4, max_depth=sys.maxsize)
...         file1.write(text_representation)
```

Once all the estimators are exported into `/tmp/estimators`, you can create a gzip-compressed TAR archive, for example:
```bash
cd /tmp/estimators
tar -czvf /tmp/iris.tgz .
```

Then you can use the RandomForestClassifier class to parse the TAR archive.

```
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
...

TarArchiveInputStream tree = getArchive("iris.tgz");
final Classifier<Double> randomForest = RandomForestClassifier.parse(tree,
        PredictionFactory.DOUBLE);
```
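
As background on how a forest can turn many tree outputs into one answer: scikit-learn's random forest averages the per-tree class probabilities and picks the class with the highest mean. The sketch below shows that averaging step in isolation. Whether this library's `RandomForestClassifier` combines its trees the same way is an assumption; `ForestVoteSketch` and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of soft voting, not part of scikit-learn-2-java.
class ForestVoteSketch {

    /** Element-wise mean of the per-tree probability vectors. */
    static double[] averageProbabilities(List<double[]> perTree) {
        if (perTree.isEmpty()) {
            throw new IllegalArgumentException("at least one tree is required");
        }
        double[] mean = new double[perTree.get(0).length];
        for (double[] p : perTree) {
            if (p.length != mean.length) {
                throw new IllegalArgumentException("trees disagree on the number of classes");
            }
            for (int c = 0; c < mean.length; c++) {
                mean[c] += p[c] / perTree.size();
            }
        }
        return mean;
    }

    /** Index of the largest value; ties resolve to the lowest index. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up probabilities from three trees over three classes.
        List<double[]> perTree = List.of(
                new double[] {1.0, 0.0, 0.0},
                new double[] {0.0, 1.0, 0.0},
                new double[] {0.0, 0.5, 0.5});
        double[] mean = averageProbabilities(perTree);
        System.out.println(Arrays.toString(mean) + " -> class index " + argMax(mean));
    }
}
```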

## Testing
Testing was done using scikit-learn 1.1.3.
19 changes: 18 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>rocks.vilaverde</groupId>
<artifactId>scikit-learn-2-java</artifactId>
<version>1.0.0-SNAPSHOT</version>
<version>1.0.1-SNAPSHOT</version>

<name>${project.groupId}:${project.artifactId}</name>
<description>A sklearn exported_text models parser for executing in the Java runtime.</description>
Expand Down Expand Up @@ -54,11 +54,28 @@
</dependencyManagement>

<dependencies>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>
<version>1.21</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.36</version>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>1.7.25</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
3 changes: 3 additions & 0 deletions scikit-learn-2-java.iml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.21" level="project" />
<orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.36" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter:5.9.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-api:5.9.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.opentest4j:opentest4j:1.2.0" level="project" />
Expand All @@ -20,5 +22,6 @@
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-params:5.9.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-engine:5.9.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-engine:1.9.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.slf4j:slf4j-simple:1.7.25" level="project" />
</component>
</module>
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package rocks.vilaverde.classifier.dt;

import rocks.vilaverde.classifier.Prediction;
import rocks.vilaverde.classifier.Classifier;
import rocks.vilaverde.classifier.Operator;
import rocks.vilaverde.classifier.Prediction;

import java.io.BufferedReader;
import java.io.Reader;
import java.util.Map;
import java.util.Set;
import java.util.Stack;

public class DecisionTreeClassifier<T> implements Classifier<T> {
public class DecisionTreeClassifier<T> implements TreeClassifier<T> {

/**
* Factory method to create the classifier from the {@link Reader}.
Expand All @@ -19,7 +19,7 @@ public class DecisionTreeClassifier<T> implements Classifier<T> {
* @param <T> class
* @throws Exception when the model could not be parsed
*/
public static <T> Classifier<T> parse(Reader reader, PredictionFactory<T> factory) throws Exception {
public static <T> DecisionTreeClassifier<T> parse(Reader reader, PredictionFactory<T> factory) throws Exception {

try (reader) {
DecisionTreeClassifier<T> classifier = new DecisionTreeClassifier<>(factory);
Expand Down Expand Up @@ -53,7 +53,7 @@ private DecisionTreeClassifier(PredictionFactory<T> predictionFactory) {
* @return predicted class
*/
public T predict(Map<String, Double> features) {
return findClassification(features).get();
return getClassification(features).get();
}

/**
Expand All @@ -64,13 +64,13 @@ public T predict(Map<String, Double> features) {
*/
@Override
public double[] predict_proba(Map<String, Double> features) {
return findClassification(features).getProbability();
return getClassification(features).getProbability();
}

/**
* Find the {@link Prediction} in the decision tree.
*/
private Prediction<T> findClassification(Map<String, Double> features) {
public Prediction<T> getClassification(Map<String, Double> features) {
validateFeature(features);

TreeNode currentNode = root;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
/////////////////////////////////////////////////////////////////////////////
// PROPRIETARY RIGHTS STATEMENT
// The contents of this file represent confidential information that is the
// proprietary property of Edge2Web, Inc. Viewing or use of
// this information is prohibited without the express written consent of
// Edge2Web, Inc. Removal of this PROPRIETARY RIGHTS STATEMENT
// is strictly forbidden. Copyright (c) 2016 All rights reserved.
/////////////////////////////////////////////////////////////////////////////
package rocks.vilaverde.classifier.dt;

/**
Expand All @@ -16,6 +8,7 @@ public interface PredictionFactory<T> {

PredictionFactory<Boolean> BOOLEAN = value -> Boolean.valueOf(value.toLowerCase());
PredictionFactory<Integer> INTEGER = Integer::valueOf;
PredictionFactory<Double> DOUBLE = Double::parseDouble;

T create(String value);
}
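
The factory constants above convert the label text of a leaf into a typed value. The sketch below uses a local copy of the interface (so it compiles on its own) to show one reason a `DOUBLE` factory can be useful: a label written as `1.0` cannot be read by `Integer.valueOf`. The wrapper class `PredictionFactorySketch` is a hypothetical name, and the label strings are illustrative.

```java
// Hypothetical wrapper; the nested interface mirrors the one shown in the diff.
class PredictionFactorySketch {

    interface PredictionFactory<T> {
        PredictionFactory<Boolean> BOOLEAN = value -> Boolean.valueOf(value.toLowerCase());
        PredictionFactory<Integer> INTEGER = Integer::valueOf;
        PredictionFactory<Double> DOUBLE = Double::parseDouble;

        T create(String value);
    }

    /** Returns true when the factory can convert the given label text. */
    static boolean canParse(PredictionFactory<?> factory, String label) {
        try {
            factory.create(label);
            return true;
        } catch (NumberFormatException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println(PredictionFactory.INTEGER.create("2"));
        System.out.println(PredictionFactory.DOUBLE.create("1.0"));
        // An integer factory rejects a label that is written as a float.
        System.out.println(canParse(PredictionFactory.INTEGER, "1.0"));
    }
}
```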
14 changes: 14 additions & 0 deletions src/main/java/rocks/vilaverde/classifier/dt/TreeClassifier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package rocks.vilaverde.classifier.dt;

import rocks.vilaverde.classifier.Classifier;
import rocks.vilaverde.classifier.Prediction;

import java.util.Map;

/**
* Implemented by Tree classifiers.
*/
public interface TreeClassifier<T> extends Classifier<T> {

Prediction<T> getClassification(Map<String, Double> features);
}