5 changes: 5 additions & 0 deletions .idea/codeStyles/codeStyleConfig.xml


21 changes: 12 additions & 9 deletions README.md
@@ -2,6 +2,8 @@

This project takes text-exported ML models generated by scikit-learn and makes them usable in Java.

[![javadoc](https://javadoc.io/badge2/rocks.vilaverde/scikit-learn-2-java/javadoc.svg)](https://javadoc.io/doc/rocks.vilaverde/scikit-learn-2-java)

## Support
* The tree.DecisionTreeClassifier is supported
* Supports `predict()`,
@@ -48,24 +50,25 @@ As an example, a DecisionTreeClassifier model trained on the Iris dataset and ex
```

The exported text can then be executed in Java. Note that when calling `export_text` it is
recommended that `max_depth` be set to sys.maxsize so that the tree isn't truncated.
recommended that `max_depth` be set to `sys.maxsize` so that the tree isn't truncated.

### Java Example
In this example the iris model exported using `export_tree` is parsed, features are created as a Java Map
In this example the iris model exported using `export_text` is parsed, features are created as a Java Map
and the decision tree is asked to predict the class.

```
Reader tree = getTrainedModel("iris.model");
final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,
PredictionFactory.INTEGER);

Map<String, Double> features = new HashMap<>();
features.put("sepal length (cm)", 3.0);
features.put("sepal width (cm)", 5.0);
features.put("petal length (cm)", 4.0);
features.put("petal width (cm)", 2.0);
Features features = Features.of("sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
"petal width (cm)");
FeatureVector fv = features.newSample();
fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);

Integer prediction = decisionTree.predict(features);
Integer prediction = decisionTree.predict(fv).get(0);
System.out.println(prediction.toString());
```
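
The class probabilities for the same sample can be retrieved with `predict_proba`; a minimal sketch continuing the example above (`java.util.Arrays` is assumed to be imported):

```
// one row of class probabilities per sample; a single FeatureVector yields a single row
double[] probabilities = decisionTree.predict_proba(fv)[0];
System.out.println(Arrays.toString(probabilities));
```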

@@ -107,7 +110,7 @@ Then you can use the RandomForestClassifier class to parse the TAR archive.
...

TarArchiveInputStream tree = getArchive("iris.tgz");
final Classifier<Integer> decisionTree = RandomForestClassifier.parse(tree,
final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
PredictionFactory.DOUBLE);
```
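
Once parsed, the forest is used through the same `Classifier` API as a single tree; a minimal sketch, assuming the iris feature names from the earlier example (the sample values are illustrative):

```
Features features = Features.of("sepal length (cm)", "sepal width (cm)",
"petal length (cm)", "petal width (cm)");
FeatureVector sample = features.newSample();
sample.add(0, 5.9).add(1, 3.0).add(2, 5.1).add(3, 1.8);

// predict() returns one prediction per sample; take the first for a single sample
Double prediction = decisionTree.predict(sample).get(0);
```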

2 changes: 1 addition & 1 deletion pom.xml
@@ -6,7 +6,7 @@

<groupId>rocks.vilaverde</groupId>
<artifactId>scikit-learn-2-java</artifactId>
<version>1.0.3-SNAPSHOT</version>
<version>1.1.0-SNAPSHOT</version>

<name>${project.groupId}:${project.artifactId}</name>
<description>A parser for scikit-learn export_text models, allowing them to be executed in the Java runtime.</description>
50 changes: 50 additions & 0 deletions src/main/java/rocks/vilaverde/classifier/AbstractTreeClassifier.java
@@ -0,0 +1,50 @@
package rocks.vilaverde.classifier;

import rocks.vilaverde.classifier.dt.TreeClassifier;

import java.util.Map;

/**
* Abstract base class for Tree classifiers.
*/
public abstract class AbstractTreeClassifier<T> implements TreeClassifier<T> {

/**
* Predict class or regression value for features.
*
* @param samples input samples
* @return the predicted class or regression value for the input sample
*/
@Override
public T predict(Map<String, Double> samples) {
FeatureVector fv = toFeatureVector(samples);
return predict(fv).get(0);
}

/**
* Predict class probabilities for the input sample's features.
* The predicted class probability is the fraction of samples of the same class in a leaf.
*
* @param samples the input samples
* @return the class probabilities of the input sample
*/
@Override
public double[] predict_proba(Map<String, Double> samples) {
FeatureVector fv = toFeatureVector(samples);
return predict_proba(fv)[0];
}

/**
* Convert a Map of features to a {@link FeatureVector}.
* @param samples a KV map of feature name to value
* @return FeatureVector
*/
private FeatureVector toFeatureVector(Map<String, Double> samples) {
Features features = Features.fromSet(samples.keySet());
FeatureVector fv = features.newSample();
for (Map.Entry<String, Double> entry : samples.entrySet()) {
fv.add(entry.getKey(), entry.getValue());
}
return fv;
}
}
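
For callers migrating off the deprecated Map-based methods, the two entry points above return the same result for a single sample. A minimal sketch, assuming the `decisionTree` classifier and iris feature names from the README example:

```
Map<String, Double> byName = new HashMap<>();
byName.put("sepal length (cm)", 3.0);
byName.put("sepal width (cm)", 5.0);
byName.put("petal length (cm)", 4.0);
byName.put("petal width (cm)", 2.0);

// deprecated path: the Map is converted to a FeatureVector internally
Integer viaMap = decisionTree.predict(byName);

// preferred path: build the FeatureVector directly and take the first (and only) prediction
FeatureVector fv = Features.fromSet(byName.keySet()).newSample();
byName.forEach(fv::add);
Integer viaVector = decisionTree.predict(fv).get(0);
```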
19 changes: 0 additions & 19 deletions src/main/java/rocks/vilaverde/classifier/BooleanFeature.java

This file was deleted.

28 changes: 24 additions & 4 deletions src/main/java/rocks/vilaverde/classifier/Classifier.java
@@ -1,24 +1,44 @@
package rocks.vilaverde.classifier;

import java.util.List;
import java.util.Map;
import java.util.Set;

public interface Classifier<T> {

/**
* Predict class or regression value for samples. Predictions will be
* returned at the same index as the sample provided.
* @param samples input samples
* @return the predicted class or regression value for each input sample
*/
List<T> predict(FeatureVector ... samples);

/**
* Predict class probabilities of the input samples. Probabilities will be
* returned at the same index as the sample provided.
* The predicted class probability is the fraction of samples of the same class in a leaf.
* @param samples the input samples
* @return the class probabilities of the input sample
*/
double[][] predict_proba(FeatureVector ... samples);

/**
* Predict class or regression value for features.
* @param features input samples
* @param samples input samples
* @return the predicted class or regression value for the input sample
*/
T predict(Map<String, Double> features);
@Deprecated
T predict(Map<String, Double> samples);

/**
* Predict class probabilities for the input sample's features.
* The predicted class probability is the fraction of samples of the same class in a leaf.
* @param features the input samples
* @param samples the input samples
* @return the class probabilities of the input sample
*/
double[] predict_proba(Map<String, Double> features);
@Deprecated
double[] predict_proba(Map<String, Double> samples);

/**
* Get the names of all the features in the model.
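
The varargs signatures above score several samples in one call, with results aligned by index. A minimal sketch, assuming a `Classifier<Integer>` named `classifier` parsed as in the README and two prepared `FeatureVector` samples `first` and `second`:

```
// predictions.get(i) is the prediction for the i-th sample passed in
List<Integer> predictions = classifier.predict(first, second);

// probabilities[i] is the row of class probabilities for the i-th sample
double[][] probabilities = classifier.predict_proba(first, second);
```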
53 changes: 53 additions & 0 deletions src/main/java/rocks/vilaverde/classifier/FeatureVector.java
@@ -0,0 +1,53 @@
package rocks.vilaverde.classifier;

/**
* A container for the values for each feature of a sample that will be predicted.
*/
public class FeatureVector {

private final Features features;
private final double[] vector;

public FeatureVector(Features features) {
this.features = features;
this.vector = new double[features.getLength()];
}

public FeatureVector add(String feature, boolean value) {
add(feature, value ? 1.0 : 0.0);
return this;
}

public FeatureVector add(int index, boolean value) {
add(index, value ? 1.0 : 0.0);
return this;
}

public FeatureVector add(int index, double value) {
this.vector[index] = value;
return this;
}

public FeatureVector add(String feature, double value) {
int index = this.features.getFeatureIndex(feature);
add(index, value);
return this;
}

public double get(int index) {
if (index < 0 || index >= vector.length) {
throw new IllegalArgumentException(String.format("index must be between 0 and %d", vector.length - 1));
}

return vector[index];
}

public double get(String feature) {
int index = features.getFeatureIndex(feature);
return get(index);
}

public boolean hasFeature(String feature) {
return this.features.getFeatureNames().contains(feature);
}
}
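
A brief sketch of how the overloads above compose; the feature names here are illustrative, not part of the library:

```
Features features = Features.of("age", "is_member");
FeatureVector fv = features.newSample();

fv.add("age", 42.0); // set a value by feature name
fv.add(1, true);     // set by index; booleans are stored as 1.0 / 0.0

double isMember = fv.get("is_member"); // 1.0
```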
91 changes: 91 additions & 0 deletions src/main/java/rocks/vilaverde/classifier/Features.java
@@ -0,0 +1,91 @@
package rocks.vilaverde.classifier;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Set;

/**
* Container class for the set of named features that will be provided for each sample.
* Call {@link Features#newSample()} to create a {@link FeatureVector} to provide the
* values for the sample.
*/
public class Features {

/* Map of feature name to index */
private final Map<String, Integer> features = new HashMap<>();
/* Features can be added as long as no samples have been created;
once a sample is created this becomes immutable */
private boolean allowFeatureAdd = true;

/**
* Convenience creation method.
* @param features the set of features.
* @return Features
*/
public static Features of(String ... features) {

// make sure all the features are unique
Set<String> featureSet = new HashSet<>(Arrays.asList(features));
if (featureSet.size() != features.length) {
throw new IllegalArgumentException("features names are not unique");
}

return new Features(features);
}

/** Create a {@link Features} instance from a set of feature names. */
public static Features fromSet(Set<String> features) {
return new Features(features.toArray(new String[0]));
}

/**
* Constructor
* @param features ordered list of features.
*/
private Features(String ... features) {
for (int i = 0; i < features.length; i++) {
this.features.put(features[i], i);
}
}

/**
* Create a new {@link FeatureVector} for this set of features.
* Once the first sample is created, no further features may be added.
*/
public FeatureVector newSample() {
allowFeatureAdd = false;
return new FeatureVector(this);
}

/** Add a feature to this set; only permitted before the first sample has been created. */
public void addFeature(String feature) {
if (!allowFeatureAdd) {
throw new IllegalStateException("features are immutable");
}

if (!this.features.containsKey(feature)) {
int next = 0;
OptionalInt optionalInt = features.values().stream().mapToInt(integer -> integer).max();
if (optionalInt.isPresent()) {
next = optionalInt.getAsInt() + 1;
}

this.features.put(feature, next);
}
}

public int getLength() {
return this.features.size();
}

public int getFeatureIndex(String feature) {
Integer index = this.features.get(feature);
if (index == null) {
throw new IllegalArgumentException(String.format("feature %s does not exist", feature));
}

return index;
}

public Set<String> getFeatureNames() {
return Collections.unmodifiableSet(this.features.keySet());
}
}
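
A brief sketch of the two construction paths; the feature names are illustrative, and `java.util.Arrays`/`java.util.HashSet` are assumed to be imported:

```
Features features = Features.fromSet(new HashSet<>(Arrays.asList("f1", "f2")));
features.addFeature("f3"); // allowed while no sample has been created yet

FeatureVector sample = features.newSample();
// from this point on, addFeature(...) throws IllegalStateException ("features are immutable")
```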
4 changes: 2 additions & 2 deletions src/main/java/rocks/vilaverde/classifier/dt/ChoiceNode.java
@@ -6,7 +6,7 @@
* Represents a Choice in the decision tree, where when the expression is evaluated,
* if true will result in the child node of the choice being selected.
*/
class ChoiceNode extends TreeNode {
public class ChoiceNode extends TreeNode {
private final Operator op;
private final Double value;

@@ -38,7 +38,7 @@ public String toString() {
return String.format("%s %s", op.toString(), value.toString());
}

public boolean eval(Double featureValue) {
public boolean eval(double featureValue) {
return op.apply(featureValue, value);
}
}
7 changes: 2 additions & 5 deletions src/main/java/rocks/vilaverde/classifier/dt/DecisionNode.java
@@ -1,14 +1,11 @@
package rocks.vilaverde.classifier.dt;

import java.util.ArrayList;
import java.util.List;

/**
* Represents a decision in the DecisionTreeClassifier. The decision will have
* a left and right hand {@link ChoiceNode} to be evaluated.
* A {@link ChoiceNode} may have nested {@link DecisionNode} or {@link EndNode}.
*/
class DecisionNode extends TreeNode {
public class DecisionNode extends TreeNode {

private final String featureName;

@@ -26,7 +23,7 @@ public static DecisionNode create(String feature) {

/**
* Private Constructor.
* @param featureName
* @param featureName the name of the feature used in this decision
*/
private DecisionNode(String featureName) {
this.featureName = featureName.intern();