Skip to content

Commit

Permalink
Merged version 1.2.8. Fixes jpmml/jpmml-evaluator#107
Browse files (browse the repository at this point in the history)
  • Loading branch information
vruusmann committed Mar 31, 2018
2 parents a19492a + 629437c commit 82b7b7d
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 46 deletions.
44 changes: 15 additions & 29 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
<version>1.48</version>
<version>1.72</version>
</dependency>

<dependency>
Expand All @@ -65,25 +65,7 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-converter</artifactId>
<version>1.2.6</version>
<exclusions>
<exclusion>
<groupId>com.sun.xml.fastinfoset</groupId>
<artifactId>FastInfoset</artifactId>
</exclusion>
<exclusion>
<groupId>javax.xml.bind</groupId>
<artifactId>jaxb-api</artifactId>
</exclusion>
<exclusion>
<groupId>org.glassfish.jaxb</groupId>
<artifactId>txw2</artifactId>
</exclusion>
<exclusion>
<groupId>org.jvnet.staxex</groupId>
<artifactId>stax-ex</artifactId>
</exclusion>
</exclusions>
<version>1.3.0</version>
</dependency>

<dependency>
Expand All @@ -103,13 +85,13 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.3.10</version>
<version>1.4.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-test</artifactId>
<version>1.3.10</version>
<version>1.4.1</version>
<scope>test</scope>
</dependency>
</dependencies>
Expand Down Expand Up @@ -137,7 +119,7 @@
<configuration>
<rules>
<requireJavaVersion>
<version>1.7</version>
<version>1.8</version>
</requireJavaVersion>
</rules>
</configuration>
Expand All @@ -156,6 +138,14 @@
</archive>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>3.0.0</version>
<configuration>
<javadocVersion>1.8</javadocVersion>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-release-plugin</artifactId>
Expand All @@ -168,7 +158,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>2.4.3</version>
<version>3.1.0</version>
<executions>
<execution>
<phase>package</phase>
Expand Down Expand Up @@ -200,10 +190,6 @@
<pattern>org.jpmml.model</pattern>
<shadedPattern>org.shaded.jpmml.model</shadedPattern>
</relocation>
<relocation>
<pattern>org.jpmml.schema</pattern>
<shadedPattern>org.shaded.jpmml.schema</shadedPattern>
</relocation>
</relocations>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
Expand Down Expand Up @@ -239,7 +225,7 @@
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
<version>0.7.9</version>
<version>0.8.1</version>
<executions>
<execution>
<id>pre-unit-test</id>
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/jpmml/sparkml/DocumentFeature.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import java.util.Set;

import com.google.common.base.Objects.ToStringHelper;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.Field;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;

Expand All @@ -34,7 +34,7 @@ public class DocumentFeature extends Feature {
private Set<StopWordSet> stopWordSets = new LinkedHashSet<>();


public DocumentFeature(SparkMLEncoder encoder, TypeDefinitionField field, String wordSeparatorRE){
public DocumentFeature(SparkMLEncoder encoder, Field<?> field, String wordSeparatorRE){
super(encoder, field.getName(), field.getDataType());

setWordSeparatorRE(wordSeparatorRE);
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/org/jpmml/sparkml/ModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
Expand Down Expand Up @@ -99,7 +99,7 @@ public Schema encodeSchema(SparkMLEncoder encoder){
categories.add(String.valueOf(i));
}

TypeDefinitionField field = encoder.toCategorical(continuousFeature.getName(), categories);
Field<?> field = encoder.toCategorical(continuousFeature.getName(), categories);

encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, field, categories));

Expand All @@ -113,7 +113,7 @@ public Schema encodeSchema(SparkMLEncoder encoder){
break;
case REGRESSION:
{
TypeDefinitionField field = encoder.toContinuous(feature.getName());
Field<?> field = encoder.toContinuous(feature.getName());

field.setDataType(DataType.DOUBLE);

Expand Down
3 changes: 1 addition & 2 deletions src/main/java/org/jpmml/sparkml/TermFeature.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ public Apply createApply(){
Feature feature = getFeature();
String value = getValue();

Constant constant = PMMLUtil.createConstant(value)
.setDataType(DataType.STRING);
Constant constant = PMMLUtil.createConstant(value, DataType.STRING);

return PMMLUtil.createApply(defineFunction.getName(), feature.ref(), constant);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Field;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.dmg.pmml.TypeDefinitionField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueDecorator;
import org.jpmml.converter.ValueUtil;
Expand Down Expand Up @@ -71,7 +71,7 @@ public List<Feature> encodeFeatures(SparkMLEncoder encoder){

Feature feature = encoder.getOnlyFeature(inputCol);

TypeDefinitionField field = encoder.getField(feature.getName());
Field<?> field = encoder.getField(feature.getName());

if(field instanceof DataField){
DataField dataField = (DataField)field;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import org.apache.spark.ml.feature.RegexTokenizer;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.OpType;
import org.dmg.pmml.TypeDefinitionField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.PMMLUtil;
Expand Down Expand Up @@ -53,7 +53,7 @@ public List<Feature> encodeFeatures(SparkMLEncoder encoder){

Feature feature = encoder.getOnlyFeature(transformer.getInputCol());

TypeDefinitionField field = encoder.getField(feature.getName());
Field<?> field = encoder.getField(feature.getName());

if(transformer.getToLowercase()){
Apply apply = PMMLUtil.createApply("lowercase", feature.ref());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import org.apache.spark.ml.feature.StringIndexerModel;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.InvalidValueTreatmentMethod;
import org.dmg.pmml.OpType;
import org.dmg.pmml.TypeDefinitionField;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
Expand All @@ -55,7 +56,7 @@ public List<Feature> encodeFeatures(SparkMLEncoder encoder){

String handleInvalid = transformer.getHandleInvalid();

TypeDefinitionField field = encoder.toCategorical(feature.getName(), categories);
Field<?> field = encoder.toCategorical(feature.getName(), categories);

if(field instanceof DataField){
DataField dataField = (DataField)field;
Expand Down Expand Up @@ -92,12 +93,12 @@ public List<Feature> encodeFeatures(SparkMLEncoder encoder){
Apply setApply = PMMLUtil.createApply("isIn", feature.ref());

for(String category : categories){
setApply.addExpressions(PMMLUtil.createConstant(category));
setApply.addExpressions(PMMLUtil.createConstant(category, feature.getDataType()));
}

categories.add(StringIndexerModelConverter.LABEL_UNKNOWN);

Apply apply = PMMLUtil.createApply("if", setApply, feature.ref(), PMMLUtil.createConstant(StringIndexerModelConverter.LABEL_UNKNOWN));
Apply apply = PMMLUtil.createApply("if", setApply, feature.ref(), PMMLUtil.createConstant(StringIndexerModelConverter.LABEL_UNKNOWN, DataType.STRING));

field = encoder.createDerivedField(FeatureUtil.createName("handleInvalid", feature), OpType.CATEGORICAL, feature.getDataType(), apply);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public RegressionModel encodeModel(Schema schema){
Matrix coefficientMatrix = model.coefficientMatrix();
Vector interceptVector = model.interceptVector();

List<Feature> features = schema.getFeatures();
List<? extends Feature> features = schema.getFeatures();

List<RegressionTable> regressionTables = new ArrayList<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public NeuralNetwork encodeModel(Schema schema){
throw new IllegalArgumentException();
}

List<Feature> features = schema.getFeatures();
List<? extends Feature> features = schema.getFeatures();
if(features.size() != layers[0]){
throw new IllegalArgumentException();
}
Expand Down

0 comments on commit 82b7b7d

Please sign in to comment.