Skip to content

Commit

Permalink
Added support for the 'WrappedModel' model type. Fixes #46
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Mar 3, 2020
1 parent bf5703d commit 496248c
Show file tree
Hide file tree
Showing 18 changed files with 4,449 additions and 25 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ Java library and command-line application for converting [R](https://www.r-proje
* `cv.glmnet` - Cross-validated GLMNet regression and calculation
* [`IsolationForest`](https://r-forge.r-project.org/R/?group_id=479) package:
* `iForest` - Isolation Forest (IF) anomaly detection
* [`mlr`](https://cran.r-project.org/package=mlr) package:
* `WrappedModel` - Selected JPMML-R model types.
* [`neuralnet`](https://cran.r-project.org/package=neuralnet) package:
* `nn` - Neural Network (NN) regression
* [`nnet`](https://cran.r-project.org/package=nnet) package:
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/jpmml/rexp/ConverterFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public ConverterFactory newInstance(){
converters.put("scorecard", ScorecardConverter.class);
converters.put("svm", SVMConverter.class);
converters.put("train", TrainConverter.class);
converters.put("WrappedModel", WrappedModelConverter.class);
converters.put("xgb.Booster", XGBoostConverter.class);
}
}
10 changes: 10 additions & 0 deletions src/main/java/org/jpmml/rexp/DecorationUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ public RGenericVector getGenericElement(RGenericVector model, String name){
}
}

/**
 * Looks up the boolean vector element <code>name</code> of the specified model object.
 *
 * <p>
 * The lookup is delegated to {@link RGenericVector#getBooleanElement(String, boolean)} with the flag set to <code>true</code>
 * (presumably "required", i.e. a missing element is reported as an error rather than returned as <code>null</code> - to be confirmed against {@link RGenericVector}).
 * An {@link IllegalArgumentException} raised by the lookup is not propagated as-is;
 * it is replaced with the exception built by <code>toDecorationException(model, name, iae)</code>.
 * </p>
 *
 * @param model The model object to query.
 * @param name The name of the element.
 *
 * @return The boolean vector element.
 */
static
public RBooleanVector getBooleanElement(RGenericVector model, String name){
	RBooleanVector element;

	try {
		element = model.getBooleanElement(name, true);
	} catch(IllegalArgumentException iae){
		throw toDecorationException(model, name, iae);
	}

	return element;
}

static
public RNumberVector<?> getNumericElement(RGenericVector model, String name){

Expand Down
10 changes: 6 additions & 4 deletions src/main/java/org/jpmml/rexp/GLMConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public void encodeSchema(RExpEncoder encoder){
RGenericVector glm = getObject();

RGenericVector family = glm.getGenericElement("family");
RGenericVector model = glm.getGenericElement("model");
RGenericVector model = glm.getGenericElement("model", false);

RStringVector familyFamily = family.getStringElement("family");

Expand All @@ -55,11 +55,13 @@ public void encodeSchema(RExpEncoder encoder){
case CLASSIFICATION:
Label label = encoder.getLabel();

RIntegerVector variable = model.getFactorElement((label.getName()).getValue());
if(model != null){
RIntegerVector variable = model.getFactorElement((label.getName()).getValue());

DataField dataField = (DataField)encoder.toCategorical(label.getName(), RExpUtil.getFactorLevels(variable));
DataField dataField = (DataField)encoder.toCategorical(label.getName(), RExpUtil.getFactorLevels(variable));

encoder.setLabel(dataField);
encoder.setLabel(dataField);
}
break;
default:
break;
Expand Down
69 changes: 52 additions & 17 deletions src/main/java/org/jpmml/rexp/LMConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,35 +41,70 @@ public void encodeSchema(RExpEncoder encoder){
RGenericVector lm = getObject();

RGenericVector xlevels = lm.getGenericElement("xlevels", false);
RGenericVector model = lm.getGenericElement("model");
RGenericVector model = lm.getGenericElement("model", false);
RGenericVector data = lm.getGenericElement("data", false);

RExp terms = model.getAttribute("terms");
RExp terms;

FormulaContext context = new ModelFrameFormulaContext(model){
FormulaContext context;

@Override
public List<String> getCategories(String variable){
if(model != null){
terms = model.getAttribute("terms");

if(xlevels != null && xlevels.hasElement(variable)){
RStringVector levels = xlevels.getStringElement(variable);
context = new ModelFrameFormulaContext(model){

return levels.getValues();
@Override
public List<String> getCategories(String variable){

if(xlevels != null && xlevels.hasElement(variable)){
RStringVector levels = xlevels.getStringElement(variable);

return levels.getValues();
}

return super.getCategories(variable);
}

@Override
public RVector<?> getData(String variable){

if(data != null && data.hasElement(variable)){
return data.getVectorElement(variable);
}

return super.getData(variable);
}
};
} else

return super.getCategories(variable);
}
{
terms = lm.getElement("terms");

@Override
public RVector<?> getData(String variable){
context = new FormulaContext(){

if(data != null && data.hasElement(variable)){
return data.getVectorElement(variable);
@Override
public List<String> getCategories(String variable){

if(xlevels != null && xlevels.hasElement(variable)){
RStringVector levels = xlevels.getStringElement(variable);

return levels.getValues();
}

return null;
}

return super.getData(variable);
}
};
@Override
public RVector<?> getData(String variable){

if(data != null && data.hasElement(variable)){
return data.getVectorElement(variable);
}

return null;
}
};
}

encodeSchema(terms, context, encoder);
}
Expand Down
111 changes: 111 additions & 0 deletions src/main/java/org/jpmml/rexp/WrappedModelConverter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright (c) 2020 Villu Ruusmann
*
* This file is part of JPMML-R
*
* JPMML-R is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-R is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-R. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.rexp;

import java.util.List;

import com.google.common.collect.Lists;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;

/**
 * Converter for objects of R class <code>WrappedModel</code>.
 *
 * <p>
 * The conversion of the model itself is delegated to the converter of the <code>learner.model</code> element (see {@link #createConverter()}).
 * This class only post-processes the schema, so that the label agrees with the <code>task.desc</code> element of the wrapped model
 * (elements <code>type</code>, <code>target</code> and, for classification tasks, <code>class.levels</code>).
 * </p>
 */
public class WrappedModelConverter extends FilterModelConverter<RGenericVector, RExp> {

	public WrappedModelConverter(RGenericVector wrappedModel){
		super(wrappedModel);
	}

	/**
	 * Encodes the schema via the superclass, and then reconciles the label with the task description.
	 *
	 * <ul>
	 *   <li><code>regr</code> - if the target field has not been defined, defines it as a continuous double field and makes it the label.</li>
	 *   <li><code>classif</code> - if the target field has not been defined, defines it as a categorical field over <code>class.levels</code>;
	 *   if the target field is not categorical, converts it to categorical.
	 *   For two-level tasks, the <code>invert_levels</code> decoration is read (via {@link DecorationUtil#getBooleanElement(RGenericVector, String)}),
	 *   and when set, the label is replaced with one whose class levels are in reversed order.</li>
	 * </ul>
	 *
	 * @throws IllegalArgumentException If the task type is neither <code>regr</code> nor <code>classif</code>.
	 */
	@Override
	public void encodeSchema(RExpEncoder encoder){
		RGenericVector wrappedModel = getObject();

		RGenericVector taskDesc = wrappedModel.getGenericElement("task.desc");

		RStringVector type = taskDesc.getStringElement("type");
		RStringVector target = taskDesc.getStringElement("target");

		// The superclass goes first; the code below only fills in or adjusts what it left behind
		super.encodeSchema(encoder);

		FieldName targetName = FieldName.create(target.asScalar());

		// May be null, if the target field has not been defined by the superclass call
		DataField dataField = encoder.getDataField(targetName);

		String taskType = type.asScalar();

		switch(taskType){
			case "regr":
				{
					if(dataField == null){
						dataField = encoder.createDataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE);

						encoder.setLabel(dataField);
					}
				}
				break;
			case "classif":
				{
					RVector<?> classLevels = taskDesc.getVectorElement("class.levels");

					List<?> values = classLevels.getValues();

					if(dataField == null){
						dataField = encoder.createDataField(targetName, OpType.CATEGORICAL, null, values);

						encoder.setLabel(dataField);
					} // End if

					if(!(OpType.CATEGORICAL).equals(dataField.getOpType())){
						dataField = (DataField)encoder.toCategorical(targetName, values);

						encoder.setLabel(dataField);
					} // End if

					// The decoration is only consulted for binary classification tasks
					if(classLevels.size() == 2){
						RBooleanVector invertLevels = DecorationUtil.getBooleanElement(wrappedModel, "invert_levels");

						if(invertLevels.asScalar()){
							Label label = new CategoricalLabel(dataField.getName(), dataField.getDataType(), Lists.reverse(values));

							encoder.setLabel(label);
						}
					}
				}
				break;
			default:
				// Name the offending value, so that an unsupported task type can be diagnosed from the stack trace alone
				throw new IllegalArgumentException("Unsupported task type \'" + taskType + "\'");
		}
	}

	/**
	 * Encodes the model by deferring to the superclass without any extra processing.
	 */
	@Override
	public Model encodeModel(Schema schema){
		return super.encodeModel(schema);
	}

	/**
	 * Creates the converter for the <code>learner.model</code> element of the wrapped model.
	 */
	@Override
	public ModelConverter<RExp> createConverter(){
		RGenericVector wrappedModel = getObject();

		RExp learnerModel = wrappedModel.getElement("learner.model");

		return (ModelConverter<RExp>)newConverter(learnerModel);
	}
}
23 changes: 21 additions & 2 deletions src/test/R/gbm.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
library("caret")
library("dplyr")
library("gbm")
library("mlr")
library("r2pmml")

source("util.R")
Expand Down Expand Up @@ -43,6 +44,24 @@ set.seed(42)
generateGBMAdaBoostAuditNA()
generateGBMBernoulliAuditNA()

# Trains a mlr-wrapped GBM (AdaBoost distribution) classifier for the "Adjusted" column of the `audit` dataset,
# decorates it with `invert_levels = TRUE`, and stores the model (RDS) and its predictions (CSV).
# NOTE(review): `audit` is a script-level variable defined outside of this function; the "NA" suffix of the output name suggests the AuditNA dataset - confirm.
generateWrappedGBMAdaBoostAudit = function(){
	task = makeClassifTask(data = audit, target = "Adjusted")
	learner = makeLearner("classif.gbm", distribution = "adaboost", shrinkage = 0.1, n.trees = 100, predict.type = "prob")

	# Namespace-qualified, because caret (also attached by this script) exports a train() function as well
	audit.mlr = decorate(mlr::train(learner, task), invert_levels = TRUE)
	print(audit.mlr)

	prediction = predict(audit.mlr, newdata = audit)
	adjusted = as.data.frame(prediction)

	storeRds(audit.mlr, "WrappedGBMAdaBoostAuditNA")

	# check.names = FALSE keeps the parentheses in the probability column names
	result = data.frame("Adjusted" = adjusted$response, "probability(0)" = adjusted$prob.0, "probability(1)" = adjusted$prob.1, check.names = FALSE)
	storeCsv(result, "WrappedGBMAdaBoostAuditNA")
}

set.seed(42)

generateWrappedGBMAdaBoostAudit()

iris = loadIrisCsv("Iris")

iris_x = iris[, -ncol(iris)]
Expand Down Expand Up @@ -82,7 +101,7 @@ generateGBMFormulaIris()
generateGBMIris()

generateTrainGBMFormulaIris = function(){
iris.train = train(Species ~ ., data = iris, method = "gbm", response.name = "Species")
iris.train = caret::train(Species ~ ., data = iris, method = "gbm", response.name = "Species")
iris.train = verify(iris.train, newdata = sample_n(iris, 10))
print(iris.train)

Expand Down Expand Up @@ -131,7 +150,7 @@ auto.caret = auto
auto.caret$origin = as.integer(auto.caret$origin)

generateTrainGBMFormulaAutoNA = function(){
auto.train = train(mpg ~ ., data = auto.caret, method = "gbm", na.action = na.pass, response.name = "mpg")
auto.train = caret::train(mpg ~ ., data = auto.caret, method = "gbm", na.action = na.pass, response.name = "mpg")
auto.train = verify(auto.train, newdata = sample_n(auto.caret, 50), na.action = na.pass)
print(auto.train)

Expand Down
23 changes: 21 additions & 2 deletions src/test/R/glm.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
library("caret")
library("mlr")
library("plyr")
library("dplyr")
library("recipes")
Expand Down Expand Up @@ -38,7 +39,7 @@ generateGLMFormulaAudit()
generateGLMCustFormulaAudit()

generateTrainGLMFormulaAudit = function(){
audit.train = train(audit.recipe, data = audit, method = "glm")
audit.train = caret::train(audit.recipe, data = audit, method = "glm")
audit.train = verify(audit.train, newdata = sample_n(audit, 100))
print(audit.train)

Expand All @@ -51,6 +52,24 @@ generateTrainGLMFormulaAudit = function(){

generateTrainGLMFormulaAudit()

audit$Deductions = NULL

# Trains a mlr-wrapped binomial GLM classifier for the "Adjusted" column of the `audit` dataset,
# decorates it with `invert_levels = FALSE`, and stores the model (RDS) and its predictions (CSV).
# `audit` is a script-level variable defined outside of this function.
generateWrappedGLMFormulaAudit = function(){
	task = makeClassifTask(data = audit, target = "Adjusted")
	learner = makeLearner("classif.binomial", predict.type = "prob")

	# Namespace-qualified, because caret (also attached by this script) exports a train() function as well
	audit.mlr = decorate(mlr::train(learner, task), invert_levels = FALSE)
	print(audit.mlr)

	prediction = predict(audit.mlr, newdata = audit)
	adjusted = as.data.frame(prediction)

	storeRds(audit.mlr, "WrappedGLMFormulaAudit")

	# check.names = FALSE keeps the parentheses in the probability column names
	result = data.frame("Adjusted" = adjusted$response, "probability(0)" = adjusted$prob.0, "probability(1)" = adjusted$prob.1, check.names = FALSE)
	storeCsv(result, "WrappedGLMFormulaAudit")
}

generateWrappedGLMFormulaAudit()

auto = loadAutoCsv("Auto")

auto.recipe = recipe(mpg ~ ., data = auto)
Expand Down Expand Up @@ -79,7 +98,7 @@ generateGLMFormulaAuto()
generateGLMCustFormulaAuto()

generateTrainGLMFormulaAuto = function(){
auto.train = train(auto.recipe, data = auto, method = "glm")
auto.train = caret::train(auto.recipe, data = auto, method = "glm")
auto.train = verify(auto.train, newdata = sample_n(auto, 50))
print(auto.train)

Expand Down
16 changes: 16 additions & 0 deletions src/test/R/lm.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
library("mlr")
library("plyr")

source("util.R")
Expand Down Expand Up @@ -27,6 +28,21 @@ generateLMCustFormulaAuto = function(){
generateLMFormulaAuto()
generateLMCustFormulaAuto()

# Trains a mlr-wrapped linear regression for the "mpg" column of the `auto` dataset,
# and stores the model (RDS) and its predictions (CSV).
# `auto` is a script-level variable defined outside of this function.
generateWrappedLMFormulaAuto = function(){
	auto.task = makeRegrTask(data = auto, target = "mpg")
	regr.lm = makeLearner("regr.lm")

	# Namespace-qualified like in the sibling test scripts, so that the call keeps resolving to mlr
	# even if another attached package (eg. caret) exports a train() function
	auto.mlr = mlr::train(regr.lm, auto.task)
	print(auto.mlr)

	mpg = as.data.frame(predict(auto.mlr, newdata = auto))

	storeRds(auto.mlr, "WrappedLMFormulaAuto")
	storeCsv(data.frame("mpg" = mpg$response), "WrappedLMFormulaAuto")
}

generateWrappedLMFormulaAuto()

wine_quality = loadWineQualityCsv("WineQuality")

generateLMFormulaWineQuality = function(){
Expand Down

0 comments on commit 496248c

Please sign in to comment.