Skip to content

Commit

Permalink
DROOLS-4164 DMN invoke PMML noname NN (and DROOLS-4157 ) (apache#2384)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarilabs authored and mariofusco committed Jun 13, 2019
1 parent 5535072 commit b916d7d
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 11 deletions.
Expand Up @@ -275,7 +275,7 @@ protected static URL pmmlImportURL(ClassLoader classLoader, DMNModelImpl model,
URL pmmlURL = null;
try {
URI resolveRelativeURI = DMNCompilerImpl.resolveRelativeURI(model, locationURI);
pmmlURL = resolveRelativeURI.isAbsolute() ? resolveRelativeURI.toURL() : classLoader.getResource(resolveRelativeURI.toString());
pmmlURL = resolveRelativeURI.isAbsolute() ? resolveRelativeURI.toURL() : classLoader.getResource(resolveRelativeURI.getPath());
} catch (URISyntaxException | IOException e) {
new PMMLImportErrConsumer(model, i, node).accept(e);
}
Expand All @@ -284,13 +284,14 @@ protected static URL pmmlImportURL(ClassLoader classLoader, DMNModelImpl model,
}

protected static URI resolveRelativeURI(DMNModelImpl model, String relative) throws URISyntaxException, IOException {
URI relativeAsURI = new URI(null, null, relative, null);
if (model.getResource() instanceof FileSystemResource) {
FileSystemResource fsr = (FileSystemResource) model.getResource();
URI resolve = fsr.getURL().toURI().resolve(relative);
URI resolve = fsr.getURL().toURI().resolve(relativeAsURI);
return resolve;
} else {
URI dmnModelURI = new URI(model.getResource().getSourcePath());
URI relativeURI = dmnModelURI.resolve(relative);
URI dmnModelURI = new URI(null, null, model.getResource().getSourcePath(), null);
URI relativeURI = dmnModelURI.resolve(relativeAsURI);
return relativeURI;
}
}
Expand Down
Expand Up @@ -463,10 +463,15 @@ private DMNExpressionEvaluator compileFunctionDefinitionPMML(DMNCompilerContext
String pmmlModel = null;
for (ContextEntry ce : context.getContextEntry()) {
if (ce.getVariable() != null && ce.getVariable().getName() != null && ce.getExpression() instanceof LiteralExpression) {
LiteralExpression ceLitExpr = (LiteralExpression) ce.getExpression();
if (ce.getVariable().getName().equals("document")) {
pmmlDocument = stripQuotes(((LiteralExpression) ce.getExpression()).getText().trim());
if (ceLitExpr.getText() != null) {
pmmlDocument = stripQuotes(ceLitExpr.getText().trim());
}
} else if (ce.getVariable().getName().equals("model")) {
pmmlModel = stripQuotes(((LiteralExpression) ce.getExpression()).getText().trim());
if (ceLitExpr.getText() != null) {
pmmlModel = stripQuotes(ceLitExpr.getText().trim());
}
}
}
}
Expand Down
Expand Up @@ -28,8 +28,8 @@ public class DMNPMMLModelInfo extends PMMLModelInfo {

private final Map<String, DMNType> inputFields;

public DMNPMMLModelInfo(String name, Map<String, DMNType> inputFields, Collection<String> outputFields) {
super(name, inputFields.keySet(), outputFields);
public DMNPMMLModelInfo(String name, Map<String, DMNType> inputFields, Collection<String> targetFields, Collection<String> outputFields) {
super(name, inputFields.keySet(), targetFields, outputFields);
this.inputFields = Collections.unmodifiableMap(new HashMap<>(inputFields));
}

Expand All @@ -38,7 +38,7 @@ public static DMNPMMLModelInfo from(PMMLModelInfo info, DMNModelImpl model) {
for (String name : info.inputFieldNames) {
inputFields.put(name, model.getTypeRegistry().unknown());
}
return new DMNPMMLModelInfo(info.name, inputFields, info.outputFieldNames);
return new DMNPMMLModelInfo(info.name, inputFields, info.targetFieldNames, info.outputFieldNames);
}

public Map<String, DMNType> getInputFields() {
Expand Down
Expand Up @@ -53,9 +53,14 @@ public static PMMLInfo<PMMLModelInfo> from(InputStream is) throws SAXException,
.stream()
.filter(mf -> mf.getUsageType() == UsageType.ACTIVE)
.forEach(fn -> inputFields.add(fn.getName().getValue()));
Collection<String> targetFields = new ArrayList<>();
miningSchema.getMiningFields()
.stream()
.filter(mf -> mf.getUsageType() == UsageType.PREDICTED)
.forEach(fn -> targetFields.add(fn.getName().getValue()));
Collection<String> outputFields = new ArrayList<>();
pm.getOutput().getOutputFields().forEach(of -> outputFields.add(of.getName().getValue()));
models.add(new PMMLModelInfo(pm.getModelName(), inputFields, outputFields));
models.add(new PMMLModelInfo(pm.getModelName(), inputFields, targetFields, outputFields));
}
Map<String, String> headerExtensions = new HashMap<>();
for (Extension ex : pmml.getHeader().getExtensions()) {
Expand Down
Expand Up @@ -25,10 +25,12 @@ public class PMMLModelInfo {
protected final String name;
protected final Collection<String> inputFieldNames;
protected final Collection<String> outputFieldNames;
protected final Collection<String> targetFieldNames;

public PMMLModelInfo(String name, Collection<String> inputFieldNames, Collection<String> outputFieldNames) {
public PMMLModelInfo(String name, Collection<String> inputFieldNames, Collection<String> targetFieldNames, Collection<String> outputFieldNames) {
this.name = name;
this.inputFieldNames = Collections.unmodifiableList(new ArrayList<>(inputFieldNames));
this.targetFieldNames = Collections.unmodifiableList(new ArrayList<>(targetFieldNames));
this.outputFieldNames = Collections.unmodifiableList(new ArrayList<>(outputFieldNames));
}

Expand All @@ -44,4 +46,8 @@ public Collection<String> getOutputFieldNames() {
return outputFieldNames;
}

public Collection<String> getTargetFieldNames() {
return targetFieldNames;
}

}
Expand Up @@ -43,5 +43,7 @@ public void testPMMLInfo() throws Exception {
is("occupation"),
is("residenceState"),
is("validLicense")));
assertThat(m0.getTargetFieldNames(), containsInAnyOrder(is("overallScore")));
assertThat(m0.getOutputFieldNames(), containsInAnyOrder(is("calculatedScore")));
}
}

0 comments on commit b916d7d

Please sign in to comment.