diff --git a/docs/reference/ingest/processors/inference.asciidoc b/docs/reference/ingest/processors/inference.asciidoc index 16d64385530f4..ec3df1fa96347 100644 --- a/docs/reference/ingest/processors/inference.asciidoc +++ b/docs/reference/ingest/processors/inference.asciidoc @@ -44,6 +44,12 @@ include::common-options.asciidoc[] Specifies the field to which the inference prediction is written. Defaults to `predicted_value`. +`num_top_feature_importance_values`:::: +(Optional, integer) +Specifies the maximum number of +{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature +importance] values per document. By default, it is zero and no feature importance +calculation occurs. [discrete] [[inference-processor-classification-opt]] @@ -63,6 +69,12 @@ Specifies the number of top class predictions to return. Defaults to 0. Specifies the field to which the top classes are written. Defaults to `top_classes`. +`num_top_feature_importance_values`:::: +(Optional, integer) +Specifies the maximum number of +{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature +importance] values per document. By default, it is zero and no feature importance +calculation occurs. [discrete] [[inference-processor-config-example]] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index cf7a8b7d224c5..f15d8296bd531 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -32,6 +32,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -73,6 +74,7 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser, private final TrainedModel trainedModel; private final List preProcessors; + private Map decoderMap; private TrainedModelDefinition(TrainedModel trainedModel, List preProcessors) { this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL); @@ -115,13 +117,35 @@ public List getPreProcessors() { return preProcessors; } - private void preProcess(Map fields) { + void preProcess(Map fields) { preProcessors.forEach(preProcessor -> preProcessor.process(fields)); } public InferenceResults infer(Map fields, InferenceConfig config) { preProcess(fields); - return trainedModel.infer(fields, config); + if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) { + throw ExceptionsHelper.badRequestException( + "Feature importance is not supported for the configured model of type [{}]", + trainedModel.getName()); + } + return trainedModel.infer(fields, + config, + config.requestingImportance() ? getDecoderMap() : Collections.emptyMap()); + } + + private Map getDecoderMap() { + if (decoderMap != null) { + return decoderMap; + } + synchronized (this) { + if (decoderMap != null) { + return decoderMap; + } + this.decoderMap = preProcessors.stream() + .map(PreProcessor::reverseLookup) + .collect(HashMap::new, Map::putAll, Map::putAll); + return decoderMap; + } } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java index 960fdbd86e2dc..c98a8f9c04d55 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; @@ -235,6 +236,11 @@ public void process(Map fields) { fields.put(destField, concatEmbeddings(processedFeatures)); } + @Override + public Map reverseLookup() { + return Collections.singletonMap(destField, fieldName); + } + @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index e9606d53ae27d..258b80dd7158e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -97,6 +97,11 @@ public String getFeatureName() { return featureName; } + @Override + public Map reverseLookup() { + return Collections.singletonMap(featureName, field); + } + @Override public String getName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index 9bb2537b61ed3..2e73da8a20913 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -18,8 +18,10 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; /** * PreProcessor for one hot encoding a set of categorical values for a given field. @@ -80,6 +82,11 @@ public Map getHotMap() { return hotMap; } + @Override + public Map reverseLookup() { + return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field)); + } + @Override public String getName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java index f5c2ff7398068..8a29875cc1e75 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java @@ -24,4 +24,9 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou * @param fields The fields and their values to process */ void process(Map fields); + + /** + * @return Reverse lookup map to match resulting features to their original feature name + */ + Map reverseLookup(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index 19c3cadbbef95..3902f3837b132 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -108,6 +108,11 @@ public String getFeatureName() { return featureName; } + @Override + public Map reverseLookup() { + return Collections.singletonMap(featureName, field); + } + @Override public String getName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index 39ae4057fd9ca..a354ae74bf23c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -35,9 +35,25 @@ public ClassificationInferenceResults(double value, String classificationLabel, List topClasses, InferenceConfig config) { - super(value); - assert config instanceof ClassificationConfig; - ClassificationConfig classificationConfig = (ClassificationConfig)config; + this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config); + } + + public ClassificationInferenceResults(double value, + String classificationLabel, + List topClasses, + Map featureImportance, + InferenceConfig config) { + this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config); + } + + private ClassificationInferenceResults(double value, + String classificationLabel, + List topClasses, + Map featureImportance, + ClassificationConfig classificationConfig) { + super(value, + SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, + classificationConfig.getNumTopFeatureImportanceValues())); this.classificationLabel = classificationLabel; this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses); this.topNumClassesField = classificationConfig.getTopClassesResultsField(); @@ -74,16 +90,17 @@ public boolean equals(Object object) { if (object == this) { return true; } if (object == null || getClass() != object.getClass()) { return false; } ClassificationInferenceResults that = (ClassificationInferenceResults) object; - return Objects.equals(value(), that.value()) && - Objects.equals(classificationLabel, that.classificationLabel) && - Objects.equals(resultsField, that.resultsField) && - Objects.equals(topNumClassesField, that.topNumClassesField) && - Objects.equals(topClasses, that.topClasses); + return Objects.equals(value(), that.value()) + && Objects.equals(classificationLabel, that.classificationLabel) + && Objects.equals(resultsField, that.resultsField) + && Objects.equals(topNumClassesField, that.topNumClassesField) + && Objects.equals(topClasses, that.topClasses) + && Objects.equals(getFeatureImportance(), that.getFeatureImportance()); } @Override public int hashCode() { - return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField); + return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance()); } @Override @@ -100,6 +117,9 @@ public void writeResult(IngestDocument document, String parentResultField) { document.setFieldValue(parentResultField + "." + topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())); } + if (getFeatureImportance().size() > 0) { + document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()); + } } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java index 6525908af3acd..add8399e89d00 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java @@ -10,18 +10,19 @@ import org.elasticsearch.ingest.IngestDocument; import java.io.IOException; +import java.util.Map; import java.util.Objects; public class RawInferenceResults extends SingleValueInferenceResults { public static final String NAME = "raw"; - public RawInferenceResults(double value) { - super(value); + public RawInferenceResults(double value, Map featureImportance) { + super(value, featureImportance); } public RawInferenceResults(StreamInput in) throws IOException { - super(in.readDouble()); + super(in); } @Override @@ -34,12 +35,13 @@ public boolean equals(Object object) { if (object == this) { return true; } if (object == null || getClass() != object.getClass()) { return false; } RawInferenceResults that = (RawInferenceResults) object; - return Objects.equals(value(), that.value()); + return Objects.equals(value(), that.value()) + && Objects.equals(getFeatureImportance(), that.getFeatureImportance()); } @Override public int hashCode() { - return Objects.hash(value()); + return Objects.hash(value(), getFeatureImportance()); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java index 8aade6337f5a6..a0647b8dffa2d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -13,6 +13,8 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Collections; +import java.util.Map; import java.util.Objects; public class RegressionInferenceResults extends SingleValueInferenceResults { @@ -22,14 +24,22 @@ public class RegressionInferenceResults extends SingleValueInferenceResults { private final String resultsField; public RegressionInferenceResults(double value, InferenceConfig config) { - super(value); - assert config instanceof RegressionConfig; - RegressionConfig regressionConfig = (RegressionConfig)config; + this(value, (RegressionConfig) config, Collections.emptyMap()); + } + + public RegressionInferenceResults(double value, InferenceConfig config, Map featureImportance) { + this(value, (RegressionConfig)config, featureImportance); + } + + private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map featureImportance) { + super(value, + SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, + regressionConfig.getNumTopFeatureImportanceValues())); this.resultsField = regressionConfig.getResultsField(); } public RegressionInferenceResults(StreamInput in) throws IOException { - super(in.readDouble()); + super(in); this.resultsField = in.readString(); } @@ -44,12 +54,14 @@ public boolean equals(Object object) { if (object == this) { return true; } if (object == null || getClass() != object.getClass()) { return false; } RegressionInferenceResults that = (RegressionInferenceResults) object; - return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField); + return Objects.equals(value(), that.value()) + && Objects.equals(this.resultsField, that.resultsField) + && Objects.equals(this.getFeatureImportance(), that.getFeatureImportance()); } @Override public int hashCode() { - return Objects.hash(value(), resultsField); + return Objects.hash(value(), resultsField, getFeatureImportance()); } @Override @@ -57,6 +69,9 @@ public void writeResult(IngestDocument document, String parentResultField) { ExceptionsHelper.requireNonNull(document, "document"); ExceptionsHelper.requireNonNull(parentResultField, "parentResultField"); document.setFieldValue(parentResultField + "." + this.resultsField, value()); + if (getFeatureImportance().size() > 0) { + document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()); + } } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java index a93f2b0f56b6a..cd739b557abea 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java @@ -5,27 +5,51 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; public abstract class SingleValueInferenceResults implements InferenceResults { private final double value; + private final Map featureImportance; + + static Map takeTopFeatureImportances(Map unsortedFeatureImportances, int numTopFeatures) { + return unsortedFeatureImportances.entrySet() + .stream() + .sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue()))) + .limit(numTopFeatures) + .collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll); + } SingleValueInferenceResults(StreamInput in) throws IOException { value = in.readDouble(); + if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble); + } else { + this.featureImportance = Collections.emptyMap(); + } } - SingleValueInferenceResults(double value) { + SingleValueInferenceResults(double value, Map featureImportance) { this.value = value; + this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance"); } public Double value() { return value; } + public Map getFeatureImportance() { + return featureImportance; + } + public String valueAsString() { return String.valueOf(value); } @@ -33,6 +57,9 @@ public String valueAsString() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeDouble(value); + if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 2bf8f825c21b8..1aa8c816ccba8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -31,33 +31,39 @@ public class ClassificationConfig implements InferenceConfig { public static final ParseField RESULTS_FIELD = new ParseField("results_field"); public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0; - public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD); + public static ClassificationConfig EMPTY_PARAMS = + new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null); private final int numTopClasses; private final String topClassesResultsField; private final String resultsField; + private final int numTopFeatureImportanceValues; public static ClassificationConfig fromMap(Map map) { Map options = new HashMap<>(map); Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName()); String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); + Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + if (options.isEmpty() == false) { throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); } - return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField); + return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, featureImportance); } private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig( - (Integer) args[0], (String) args[1], (String) args[2])); + (Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3])); static { PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD); + PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); } public static ClassificationConfig fromXContent(XContentParser parser) { @@ -65,19 +71,33 @@ public static ClassificationConfig fromXContent(XContentParser parser) { } public ClassificationConfig(Integer numTopClasses) { - this(numTopClasses, null, null); + this(numTopClasses, null, null, null); } public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) { + this(numTopClasses, resultsField, topClassesResultsField, 0); + } + + public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) { this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField; this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; + if (featureImportance != null && featureImportance < 0) { + throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() + + "] must be greater than or equal to 0"); + } + this.numTopFeatureImportanceValues = featureImportance == null ? 0 : featureImportance; } public ClassificationConfig(StreamInput in) throws IOException { this.numTopClasses = in.readInt(); this.topClassesResultsField = in.readString(); this.resultsField = in.readString(); + if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + this.numTopFeatureImportanceValues = in.readVInt(); + } else { + this.numTopFeatureImportanceValues = 0; + } } public int getNumTopClasses() { @@ -92,11 +112,23 @@ public String getResultsField() { return resultsField; } + public int getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + + @Override + public boolean requestingImportance() { + return numTopFeatureImportanceValues > 0; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeInt(numTopClasses); out.writeString(topClassesResultsField); out.writeString(resultsField); + if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + out.writeVInt(numTopFeatureImportanceValues); + } } @Override @@ -104,14 +136,15 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ClassificationConfig that = (ClassificationConfig) o; - return Objects.equals(numTopClasses, that.numTopClasses) && - Objects.equals(topClassesResultsField, that.topClassesResultsField) && - Objects.equals(resultsField, that.resultsField); + return Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(topClassesResultsField, that.topClassesResultsField) + && Objects.equals(resultsField, that.resultsField) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); } @Override public int hashCode() { - return Objects.hash(numTopClasses, topClassesResultsField, resultsField); + return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues); } @Override @@ -122,6 +155,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + if (numTopFeatureImportanceValues > 0) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } builder.endObject(); return builder; } @@ -143,7 +179,7 @@ public boolean isTargetTypeSupported(TargetType targetType) { @Override public Version getMinimalSupportedVersion() { - return MIN_SUPPORTED_VERSION; + return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java index 5d1dc7983ff3c..44985d5465182 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -18,4 +18,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable { * All nodes in the cluster must be at least this version */ Version getMinimalSupportedVersion(); + + default boolean requestingImportance() { + return false; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index ae5a4062a69dc..74790a693eb15 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -13,7 +13,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -98,4 +100,19 @@ public static Double toDouble(Object value) { } return null; } + + public static Map decodeFeatureImportances(Map processedFeatureToOriginalFeatureMap, + Map featureImportances) { + if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) { + return featureImportances; + } + + Map originalFeatureImportance = new HashMap<>(); + featureImportances.forEach((feature, importance) -> { + String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature); + originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance); + }); + + return originalFeatureImportance; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java index b7c4a71b3e79e..335961e4e4cb4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java @@ -16,9 +16,12 @@ */ public class NullInferenceConfig implements InferenceConfig { - public static final NullInferenceConfig INSTANCE = new NullInferenceConfig(); + private final boolean requestingFeatureImportance; - private NullInferenceConfig() { } + + public NullInferenceConfig(boolean requestingFeatureImportance) { + this.requestingFeatureImportance = requestingFeatureImportance; + } @Override public boolean isTargetTypeSupported(TargetType targetType) { @@ -37,6 +40,7 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("Unable to serialize NullInferenceConfig objects"); } @Override @@ -46,6 +50,11 @@ public String getName() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder; + throw new UnsupportedOperationException("Unable to write xcontent from NullInferenceConfig objects"); + } + + @Override + public boolean requestingImportance() { + return requestingFeatureImportance; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java index 6d128a23b05c1..4c8244c734cc3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -26,24 +26,27 @@ public class RegressionConfig implements InferenceConfig { public static final ParseField NAME = new ParseField("regression"); private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0; public static final ParseField RESULTS_FIELD = new ParseField("results_field"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); private static final String DEFAULT_RESULTS_FIELD = "predicted_value"; - public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD); + public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null); public static RegressionConfig fromMap(Map map) { Map options = new HashMap<>(map); String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); + Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); if (options.isEmpty() == false) { throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); } - return new RegressionConfig(resultsField); + return new RegressionConfig(resultsField, featureImportance); } private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0])); + new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0], (Integer)args[1])); static { PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); + PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); } public static RegressionConfig fromXContent(XContentParser parser) { @@ -51,19 +54,43 @@ public static RegressionConfig fromXContent(XContentParser parser) { } private final String resultsField; + private final int numTopFeatureImportanceValues; public RegressionConfig(String resultsField) { + this(resultsField, 0); + } + + public RegressionConfig(String resultsField, Integer numTopFeatureImportanceValues) { this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; + if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) { + throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() + + "] must be greater than or equal to 0"); + } + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues; } public RegressionConfig(StreamInput in) throws IOException { this.resultsField = in.readString(); + if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + this.numTopFeatureImportanceValues = in.readVInt(); + } else { + this.numTopFeatureImportanceValues = 0; + } + } + + public int getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; } public String getResultsField() { return resultsField; } + @Override + public boolean requestingImportance() { + return numTopFeatureImportanceValues > 0; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -72,6 +99,9 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(resultsField); + if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + out.writeVInt(numTopFeatureImportanceValues); + } } @Override @@ -83,6 +113,9 @@ public String getName() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + if (numTopFeatureImportanceValues > 0) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } builder.endObject(); return builder; } @@ -92,12 +125,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; RegressionConfig that = (RegressionConfig)o; - return Objects.equals(this.resultsField, that.resultsField); + return Objects.equals(this.resultsField, that.resultsField) + && Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); } @Override public int hashCode() { - return Objects.hash(resultsField); + return Objects.hash(resultsField, numTopFeatureImportanceValues); } @Override @@ -107,7 +141,7 @@ public boolean isTargetTypeSupported(TargetType targetType) { @Override public Version getMinimalSupportedVersion() { - return MIN_SUPPORTED_VERSION; + return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ShapPath.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ShapPath.java new file mode 100644 index 0000000000000..9b2b844c304ae --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ShapPath.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + + +/** + * Ported from https://github.com/elastic/ml-cpp/blob/master/include/maths/CTreeShapFeatureImportance.h Path struct + */ +public class ShapPath { + private static final double DBL_EPSILON = Double.MIN_VALUE; + + private final PathElement[] pathElements; + private final double[] scale; + private final int elementAndScaleOffset; + + public ShapPath(ShapPath parentPath, int nextIndex) { + this.elementAndScaleOffset = parentPath.elementAndScaleOffset + nextIndex; + this.pathElements = parentPath.pathElements; + this.scale = parentPath.scale; + for (int i = 0; i < nextIndex; i++) { + pathElements[elementAndScaleOffset + i].featureIndex = parentPath.getElement(i).featureIndex; + pathElements[elementAndScaleOffset + i].fractionZeros = parentPath.getElement(i).fractionZeros; + pathElements[elementAndScaleOffset + i].fractionOnes = parentPath.getElement(i).fractionOnes; + scale[elementAndScaleOffset + i] = parentPath.getScale(i); + } + } + + public ShapPath(PathElement[] elements, double[] scale) { + this.pathElements = elements; + this.scale = scale; + this.elementAndScaleOffset = 0; + } + + // Update binomial coefficients to be able to compute Equation (2) from the paper. In particular, + // we have in the line path.scale[i + 1] += fractionOne * path.scale[i] * (i + 1.0) / (pathDepth + + // 1.0) that if we're on the "one" path, i.e. if the last feature selects this path if we include that + // feature in S (then fractionOne is 1), and we need to consider all the additional ways we now have of + // constructing each S of each given cardinality i + 1. Each of these come by adding the last feature + // to sets of size i and we **also** need to scale by the difference in binomial coefficients as both M + // increases by one and i increases by one. So we get additive term 1{last feature selects path if in S} + // * scale(i) * (i+1)! (M+1-(i+1)-1)!/(M+1)! / (i! (M-i-1)!/ M!), whence += scale(i) * (i+1) / (M+1). + public int extend(double fractionZero, double fractionOne, int featureIndex, int nextIndex) { + setValues(nextIndex, fractionOne, fractionZero, featureIndex); + setScale(nextIndex, nextIndex == 0 ? 1.0 : 0.0); + double stepDown = fractionOne / (double)(nextIndex + 1); + double stepUp = fractionZero / (double)(nextIndex + 1); + double countDown = nextIndex * stepDown; + double countUp = stepUp; + for (int i = (nextIndex - 1); i >= 0; --i, countDown -= stepDown, countUp += stepUp) { + setScale(i + 1, getScale(i + 1) + getScale(i) * countDown); + setScale(i, getScale(i) * countUp); + } + return nextIndex + 1; + } + + public double sumUnwoundPath(int pathIndex, int nextIndex) { + double total = 0.0; + int pathDepth = nextIndex - 1; + double nextFractionOne = getScale(pathDepth); + double fractionOne = fractionOnes(pathIndex); + double fractionZero = fractionZeros(pathIndex); + if (fractionOne != 0) { + double pD = pathDepth + 1; + double stepUp = fractionZero / pD; + double stepDown = fractionOne / pD; + double countUp = stepUp; + double countDown = (pD - 1.0) * stepDown; + for (int i = pathDepth - 1; i >= 0; --i, countUp += stepUp, countDown -= stepDown) { + double tmp = nextFractionOne / countDown; + nextFractionOne = getScale(i) - tmp * countUp; + total += tmp; + } + } else { + double pD = pathDepth; + + for(int i = 0; i < pathDepth; i++) { + total += getScale(i) / pD--; + } + total *= (pathDepth + 1) / (fractionZero + DBL_EPSILON); + } + + return total; + } + + public int unwind(int pathIndex, int nextIndex) { + int pathDepth = nextIndex - 1; + double nextFractionOne = getScale(pathDepth); + double fractionOne = fractionOnes(pathIndex); + double fractionZero = fractionZeros(pathIndex); + + if (fractionOne != 0) { + double stepUp = fractionZero / (double)(pathDepth + 1); + double stepDown = fractionOne / (double)nextIndex; + double countUp = 0.0; + double countDown = nextIndex * stepDown; + for (int i = pathDepth; i >= 0; --i, countUp += stepUp, countDown -= stepDown) { + double tmp = nextFractionOne / countDown; + nextFractionOne = getScale(i) - tmp * countUp; + setScale(i, tmp); + } + } else { + double stepDown = (fractionZero + DBL_EPSILON) / (double)(pathDepth + 1); + double countDown = pathDepth * stepDown; + for (int i = 0; i <= pathDepth; ++i, countDown -= stepDown) { + setScale(i, getScale(i) / countDown); + } + } + for (int i = pathIndex; i < pathDepth; ++i) { + PathElement element = getElement(i + 1); + setValues(i, element.fractionOnes, element.fractionZeros, element.featureIndex); + } + return nextIndex - 1; + } + + private void setValues(int index, double fractionOnes, double fractionZeros, int featureIndex) { + pathElements[index + elementAndScaleOffset].fractionOnes = fractionOnes; + pathElements[index + elementAndScaleOffset].fractionZeros = fractionZeros; + pathElements[index + elementAndScaleOffset].featureIndex = featureIndex; + } + + private double getScale(int offset) { + return scale[offset + elementAndScaleOffset]; + } + + private void setScale(int offset, double value) { + scale[offset + elementAndScaleOffset] = value; + } + + public double fractionOnes(int pathIndex) { + return pathElements[pathIndex + elementAndScaleOffset].fractionOnes; + } + + public double fractionZeros(int pathIndex) { + return pathElements[pathIndex + elementAndScaleOffset].fractionZeros; + } + + public int findFeatureIndex(int splitFeature, int nextIndex) { + for (int i = elementAndScaleOffset; i < elementAndScaleOffset + nextIndex; i++) { + if (pathElements[i].featureIndex == splitFeature) { + return i - elementAndScaleOffset; + } + } + return -1; + } + + public int featureIndex(int pathIndex) { + return pathElements[pathIndex + elementAndScaleOffset].featureIndex; + } + + private PathElement getElement(int offset) { + return pathElements[offset + elementAndScaleOffset]; + } + + public static final class PathElement { + private double fractionOnes = 1.0; + private double fractionZeros = 1.0; + private int featureIndex = -1; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index 4bbca5ed0b1d5..6534766c65f5b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.apache.lucene.util.Accountable; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; @@ -17,12 +18,16 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou /** * Infer against the provided fields * + * NOTE: Must be thread safe + * * @param fields The fields and their values to infer against * @param config The configuration options for inference + * @param featureDecoderMap A map for decoding feature value names to their originating feature. + * Necessary for feature influence. * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0). * For regression this is continuous. */ - InferenceResults infer(Map fields, InferenceConfig config); + InferenceResults infer(Map fields, InferenceConfig config, @Nullable Map featureDecoderMap); /** * @return {@link TargetType} for the model. @@ -42,4 +47,19 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou * @return The estimated number of operations required at inference time */ long estimatedNumOperations(); + + /** + * @return Does the model support feature importance + */ + boolean supportsFeatureImportance(); + + /** + * Calculates the importance of each feature reference by the model for the passed in field values + * + * NOTE: Must be thread safe + * @param fields The fields inferring against + * @param featureDecoder A Map translating processed feature names to their original feature names + * @return A {@code Map} mapping each featureName to its importance + */ + Map featureImportance(Map fields, Map featureDecoder); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 8539d0fcc5f9a..0ff88ca1c3b13 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -37,6 +37,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -133,18 +134,25 @@ public Ensemble(StreamInput in) throws IOException { } @Override - public InferenceResults infer(Map fields, InferenceConfig config) { + public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) { if (config.isTargetTypeSupported(targetType) == false) { throw ExceptionsHelper.badRequestException( "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } - List inferenceResults = this.models.stream().map(model -> { - InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE); - assert results instanceof SingleValueInferenceResults; - return ((SingleValueInferenceResults)results).value(); - }).collect(Collectors.toList()); + List inferenceResults = new ArrayList<>(this.models.size()); + List> featureInfluence = new ArrayList<>(); + NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance()); + this.models.forEach(model -> { + InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap()); + assert result instanceof SingleValueInferenceResults; + SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result; + inferenceResults.add(inferenceResult.value()); + if (config.requestingImportance()) { + featureInfluence.add(inferenceResult.getFeatureImportance()); + } + }); List processed = outputAggregator.processValues(inferenceResults); - return buildResults(processed, config); + return buildResults(processed, featureInfluence, config, featureDecoderMap); } @Override @@ -152,14 +160,20 @@ public TargetType targetType() { return targetType; } - private InferenceResults buildResults(List processedInferences, InferenceConfig config) { + private InferenceResults buildResults(List processedInferences, + List> featureInfluence, + InferenceConfig config, + Map featureDecoderMap) { // Indicates that the config is useless and the caller just wants the raw value if (config instanceof NullInferenceConfig) { - return new RawInferenceResults(outputAggregator.aggregate(processedInferences)); + return new RawInferenceResults(outputAggregator.aggregate(processedInferences), + InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence))); } switch(targetType) { case REGRESSION: - return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config); + return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), + config, + InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence))); case CLASSIFICATION: ClassificationConfig classificationConfig = (ClassificationConfig) config; assert classificationWeights == null || processedInferences.size() == classificationWeights.length; @@ -172,6 +186,7 @@ private InferenceResults buildResults(List processedInferences, Inferenc return new ClassificationInferenceResults((double)topClasses.v1(), classificationLabel(topClasses.v1(), classificationLabels), topClasses.v2(), + InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)), config); default: throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model"); @@ -293,6 +308,27 @@ public long estimatedNumOperations() { return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1); } + @Override + public boolean supportsFeatureImportance() { + return models.stream().allMatch(TrainedModel::supportsFeatureImportance); + } + + Map featureImportance(Map fields) { + return featureImportance(fields, Collections.emptyMap()); + } + + @Override + public Map featureImportance(Map fields, Map featureDecoder) { + Map collapsed = mergeFeatureImportances(models.stream() + .map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap())) + .collect(Collectors.toList())); + return InferenceHelpers.decodeFeatureImportances(featureDecoder, collapsed); + } + + private static Map mergeFeatureImportances(List> featureImportances) { + return featureImportances.stream().collect(HashMap::new, (a, b) -> b.forEach((k, v) -> a.merge(k, v, Double::sum)), Map::putAll); + } + public static Builder builder() { return new Builder(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java index 7de2c8f060500..ebad13530df2c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; @@ -104,7 +105,11 @@ public LangIdentNeuralNetwork(StreamInput in) throws IOException { } @Override - public InferenceResults infer(Map fields, InferenceConfig config) { + public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) { + if (config.requestingImportance()) { + throw ExceptionsHelper.badRequestException("[{}] model does not supports feature importance", + NAME.getPreferredName()); + } if (config instanceof ClassificationConfig == false) { throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName()); @@ -138,6 +143,7 @@ public InferenceResults infer(Map fields, InferenceConfig config return new ClassificationInferenceResults(topClasses.v1(), LANGUAGE_NAMES.get(topClasses.v1()), topClasses.v2(), + Collections.emptyMap(), classificationConfig); } @@ -159,6 +165,16 @@ public long estimatedNumOperations() { return numOps; } + @Override + public boolean supportsFeatureImportance() { + return false; + } + + @Override + public Map featureImportance(Map fields, Map featureDecoder) { + throw new UnsupportedOperationException("[lang_ident] does not support feature importance"); + } + @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index ff7fbf813db67..db2dea9855bcf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -44,6 +45,7 @@ import java.util.Queue; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; @@ -86,6 +88,9 @@ public static Tree fromXContentLenient(XContentParser parser) { private final TargetType targetType; private final List classificationLabels; private final CachedSupplier highestOrderCategory; + // populated lazily when feature importance is calculated + private double[] nodeEstimates; + private Integer maxDepth; Tree(List featureNames, List nodes, TargetType targetType, List classificationLabels) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); @@ -120,7 +125,7 @@ public List getNodes() { } @Override - public InferenceResults infer(Map fields, InferenceConfig config) { + public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) { if (config.isTargetTypeSupported(targetType) == false) { throw ExceptionsHelper.badRequestException( "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); @@ -129,21 +134,23 @@ public InferenceResults infer(Map fields, InferenceConfig config List features = featureNames.stream() .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields))) .collect(Collectors.toList()); - return infer(features, config); - } - private InferenceResults infer(List features, InferenceConfig config) { + Map featureImportance = config.requestingImportance() ? + featureImportance(features, featureDecoderMap) : + Collections.emptyMap(); + TreeNode node = nodes.get(0); while(node.isLeaf() == false) { node = nodes.get(node.compare(features)); } - return buildResult(node.getLeafValue(), config); + + return buildResult(node.getLeafValue(), featureImportance, config); } - private InferenceResults buildResult(Double value, InferenceConfig config) { + private InferenceResults buildResult(Double value, Map featureImportance, InferenceConfig config) { // Indicates that the config is useless and the caller just wants the raw value if (config instanceof NullInferenceConfig) { - return new RawInferenceResults(value); + return new RawInferenceResults(value, featureImportance); } switch (targetType) { case CLASSIFICATION: @@ -156,9 +163,10 @@ private InferenceResults buildResult(Double value, InferenceConfig config) { return new ClassificationInferenceResults(value, classificationLabel(topClasses.v1(), classificationLabels), topClasses.v2(), + featureImportance, config); case REGRESSION: - return new RegressionInferenceResults(value, config); + return new RegressionInferenceResults(value, config, featureImportance); default: throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model"); } @@ -192,7 +200,6 @@ private List classificationProbability(double inferenceValue) { // If we are classification, we should assume that the largest leaf value is whole. assert maxCategory == Math.rint(maxCategory); List list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)); - // TODO, eventually have TreeNodes contain confidence levels list.set(Double.valueOf(inferenceValue).intValue(), 1.0); return list; } @@ -263,12 +270,138 @@ public void validate() { detectCycle(); } + @Override + public Map featureImportance(Map fields, Map featureDecoder) { + if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) { + throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance"); + } + List features = featureNames.stream() + .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields))) + .collect(Collectors.toList()); + return featureImportance(features, featureDecoder); + } + + private Map featureImportance(List fieldValues, Map featureDecoder) { + calculateNodeEstimatesIfNeeded(); + double[] featureImportance = new double[fieldValues.size()]; + int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2; + ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize]; + for (int i = 0; i < arrSize; i++) { + elements[i] = new ShapPath.PathElement(); + } + double[] scale = new double[arrSize]; + ShapPath initialPath = new ShapPath(elements, scale); + shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0); + return InferenceHelpers.decodeFeatureImportances(featureDecoder, + IntStream.range(0, featureImportance.length) + .boxed() + .collect(Collectors.toMap(featureNames::get, i -> featureImportance[i]))); + } + + private void calculateNodeEstimatesIfNeeded() { + if (this.nodeEstimates != null && this.maxDepth != null) { + return; + } + synchronized (this) { + if (this.nodeEstimates != null && this.maxDepth != null) { + return; + } + double[] estimates = new double[nodes.size()]; + this.maxDepth = fillNodeEstimates(estimates, 0, 0); + this.nodeEstimates = estimates; + } + } + + /** + * Note, this is a port from https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc + * + * If improvements in performance or accuracy have been found, it is probably best that the changes are implemented on the native + * side first and then ported to the Java side. + */ + private void shapRecursive(List processedFeatures, + double[] nodeValues, + ShapPath parentSplitPath, + int nodeIndex, + double parentFractionZero, + double parentFractionOne, + int parentFeatureIndex, + double[] featureImportance, + int nextIndex) { + ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex); + TreeNode currNode = nodes.get(nodeIndex); + nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex); + if (currNode.isLeaf()) { + // TODO multi-value???? + double leafValue = nodeValues[nodeIndex]; + for (int i = 1; i < nextIndex; ++i) { + double scale = splitPath.sumUnwoundPath(i, nextIndex); + int inputColumnIndex = splitPath.featureIndex(i); + featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue; + } + } else { + int hotIndex = currNode.compare(processedFeatures); + int coldIndex = hotIndex == currNode.getLeftChild() ? currNode.getRightChild() : currNode.getLeftChild(); + + double incomingFractionZero = 1.0; + double incomingFractionOne = 1.0; + int splitFeature = currNode.getSplitFeature(); + int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex); + if (pathIndex > -1) { + incomingFractionZero = splitPath.fractionZeros(pathIndex); + incomingFractionOne = splitPath.fractionOnes(pathIndex); + nextIndex = splitPath.unwind(pathIndex, nextIndex); + } + + double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples(); + double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples(); + shapRecursive(processedFeatures, nodeValues, splitPath, + hotIndex, incomingFractionZero * hotFractionZero, + incomingFractionOne, splitFeature, featureImportance, nextIndex); + shapRecursive(processedFeatures, nodeValues, splitPath, + coldIndex, incomingFractionZero * coldFractionZero, + 0.0, splitFeature, featureImportance, nextIndex); + } + } + + /** + * This recursively populates the provided {@code double[]} with the node estimated values + * + * Used when calculating feature importance. + * @param nodeEstimates Array to update in place with the node estimated values + * @param nodeIndex Current node index + * @param depth Current depth + * @return The current max depth + */ + private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) { + TreeNode node = nodes.get(nodeIndex); + if (node.isLeaf()) { + nodeEstimates[nodeIndex] = node.getLeafValue(); + return 0; + } + + int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1); + int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1); + long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples(); + long rightWeight = nodes.get(node.getRightChild()).getNumberSamples(); + long divisor = leftWeight + rightWeight; + double averageValue = divisor == 0 ? + 0.0 : + (leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor; + nodeEstimates[nodeIndex] = averageValue; + return Math.max(depthLeft, depthRight) + 1; + } + @Override public long estimatedNumOperations() { // Grabbing the features from the doc + the depth of the tree return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size(); } + @Override + public boolean supportsFeatureImportance() { + return true; + } + /** * The highest index of a feature used any of the nodes. * If no nodes use a feature return -1. This can only happen diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index 766fb3d56421d..188a2018ffdf2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -342,8 +342,7 @@ public void testMultiClassIrisInference() throws IOException { }}; assertThat( - ((ClassificationInferenceResults)definition.getTrainedModel() - .infer(fields, ClassificationConfig.EMPTY_PARAMS)) + ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS)) .getClassificationLabel(), equalTo("Iris-setosa")); @@ -354,8 +353,7 @@ public void testMultiClassIrisInference() throws IOException { put("petal_width", 1.4); }}; assertThat( - ((ClassificationInferenceResults)definition.getTrainedModel() - .infer(fields, ClassificationConfig.EMPTY_PARAMS)) + ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS)) .getClassificationLabel(), equalTo("Iris-versicolor")); @@ -366,10 +364,8 @@ public void testMultiClassIrisInference() throws IOException { put("petal_width", 2.0); }}; assertThat( - ((ClassificationInferenceResults)definition.getTrainedModel() - .infer(fields, ClassificationConfig.EMPTY_PARAMS)) + ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS)) .getClassificationLabel(), equalTo("Iris-virginica")); } - } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java index d9d4e9933b24d..1ebf009add7a7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java @@ -8,10 +8,12 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Collections; + public class RawInferenceResultsTests extends AbstractWireSerializingTestCase { public static RawInferenceResults createRandomResults() { - return new RawInferenceResults(randomDouble()); + return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08)); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java index 550e49cdfe016..327cb70e654a7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -30,11 +30,12 @@ public void testFromMap() { ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS; assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected)); - expected = new ClassificationConfig(3, "foo", "bar"); + expected = new ClassificationConfig(3, "foo", "bar", 2); Map configMap = new HashMap<>(); configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3); configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo"); configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar"); + configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2); assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java index 65210e7e2699e..2f9cb59040e59 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -24,9 +24,10 @@ public static RegressionConfig randomRegressionConfig() { } public void testFromMap() { - RegressionConfig expected = new RegressionConfig("foo"); + RegressionConfig expected = new RegressionConfig("foo", 3); Map config = new HashMap<>(){{ put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo"); + put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3); }}; assertThat(RegressionConfig.fromMap(config), equalTo(expected)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index dc54c049cf872..f9345fd6e78dd 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; +import org.elasticsearch.xpack.core.ml.job.config.Operator; import org.junit.Before; import java.io.IOException; import java.util.ArrayList; @@ -39,6 +40,7 @@ import static org.hamcrest.Matchers.equalTo; public class EnsembleTests extends AbstractSerializingTestCase { + private final double eps = 1.0E-8; private boolean lenient; @@ -267,7 +269,8 @@ public void testClassificationProbability() { List scores = Arrays.asList(0.230557435, 0.162032651); double eps = 0.000001; List probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap())) + .getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); @@ -278,7 +281,8 @@ public void testClassificationProbability() { expected = Arrays.asList(0.310025518, 0.6899744811); scores = Arrays.asList(0.217017863, 0.2069923443); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap())) + .getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); @@ -289,7 +293,8 @@ public void testClassificationProbability() { expected = Arrays.asList(0.768524783, 0.231475216); scores = Arrays.asList(0.230557435, 0.162032651); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap())) + .getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); @@ -303,7 +308,8 @@ public void testClassificationProbability() { expected = Arrays.asList(0.6899744811, 0.3100255188); scores = Arrays.asList(0.482982136, 0.0930076556); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap())) + .getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); @@ -361,24 +367,28 @@ public void testClassificationInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(), + 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(), + 0.00001)); featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(), + 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; assertThat(0.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(), + 0.00001)); } public void testMultiClassClassificationInference() { @@ -432,24 +442,28 @@ public void testMultiClassClassificationInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(2.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(), + 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(), + 0.00001)); featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(), + 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.6); put("bar", null); }}; assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(), + 0.00001)); } public void testRegressionInference() { @@ -489,12 +503,16 @@ public void testRegressionInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(0.9, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())) + .value(), + 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.5, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())) + .value(), + 0.00001)); // Test with NO aggregator supplied, verifies default behavior of non-weighted sum ensemble = Ensemble.builder() @@ -506,19 +524,25 @@ public void testRegressionInference() { featureVector = Arrays.asList(0.4, 0.0); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.8, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())) + .value(), + 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())) + .value(), + 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; assertThat(1.8, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())) + .value(), + 0.00001)); } public void testInferNestedFields() { @@ -564,7 +588,9 @@ public void testInferNestedFields() { }}); }}; assertThat(0.9, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())) + .value(), + 0.00001)); featureMap = new HashMap<>() {{ put("foo", new HashMap<>(){{ @@ -575,7 +601,9 @@ public void testInferNestedFields() { }}); }}; assertThat(0.5, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())) + .value(), + 0.00001)); } public void testOperationsEstimations() { @@ -590,6 +618,114 @@ public void testOperationsEstimations() { assertThat(ensemble.estimatedNumOperations(), equalTo(9L)); } + public void testFeatureImportance() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setNodes( + TreeNode.builder(0) + .setSplitFeature(0) + .setOperator(Operator.LT) + .setLeftChild(1) + .setRightChild(2) + .setThreshold(0.55) + .setNumberSamples(10L), + TreeNode.builder(1) + .setSplitFeature(0) + .setLeftChild(3) + .setRightChild(4) + .setOperator(Operator.LT) + .setThreshold(0.41) + .setNumberSamples(6L), + TreeNode.builder(2) + .setSplitFeature(1) + .setLeftChild(5) + .setRightChild(6) + .setOperator(Operator.LT) + .setThreshold(0.25) + .setNumberSamples(4L), + TreeNode.builder(3).setLeafValue(1.18230136).setNumberSamples(5L), + TreeNode.builder(4).setLeafValue(1.98006658).setNumberSamples(1L), + TreeNode.builder(5).setLeafValue(3.25350885).setNumberSamples(3L), + TreeNode.builder(6).setLeafValue(2.42384369).setNumberSamples(1L)).build(); + + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setNodes( + TreeNode.builder(0) + .setSplitFeature(0) + .setOperator(Operator.LT) + .setLeftChild(1) + .setRightChild(2) + .setThreshold(0.45) + .setNumberSamples(10L), + TreeNode.builder(1) + .setSplitFeature(0) + .setLeftChild(3) + .setRightChild(4) + .setOperator(Operator.LT) + .setThreshold(0.25) + .setNumberSamples(5L), + TreeNode.builder(2) + .setSplitFeature(0) + .setLeftChild(5) + .setRightChild(6) + .setOperator(Operator.LT) + .setThreshold(0.59) + .setNumberSamples(5L), + TreeNode.builder(3).setLeafValue(1.04476388).setNumberSamples(3L), + TreeNode.builder(4).setLeafValue(1.52799228).setNumberSamples(2L), + TreeNode.builder(5).setLeafValue(1.98006658).setNumberSamples(1L), + TreeNode.builder(6).setLeafValue(2.950216).setNumberSamples(4L)).build(); + + Ensemble ensemble = Ensemble.builder().setOutputAggregator(new WeightedSum()) + .setTrainedModels(Arrays.asList(tree1, tree2)) + .setFeatureNames(featureNames) + .build(); + + + Map featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9))); + assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps)); + assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8))); + assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps)); + assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7))); + assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps)); + assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6))); + assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps)); + assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5))); + assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps)); + assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4))); + assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps)); + assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3))); + assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps)); + assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2))); + assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps)); + assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1))); + assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps)); + assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps)); + + featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0))); + assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps)); + assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps)); + } + + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 32ca045d97e47..4e0fe560210da 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.job.config.Operator; import org.junit.Before; import java.io.IOException; @@ -35,6 +36,7 @@ public class TreeTests extends AbstractSerializingTestCase { + private final double eps = 1.0E-8; private boolean lenient; @Before @@ -118,7 +120,8 @@ public void testInferWithStump() { List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump assertThat(42.0, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); } public void testInfer() { @@ -138,27 +141,31 @@ public void testInfer() { List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(0.3, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.1, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); // This should hit the right child of the left child of the root node // i.e. it takes the path left, right featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.2, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); // This should still work if the internal values are strings List featureVectorStrings = Arrays.asList("0.3", "0.9"); featureMap = zipObjMap(featureNames, featureVectorStrings); assertThat(0.2, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ @@ -166,7 +173,8 @@ public void testInfer() { put("bar", null); }}; assertThat(0.1, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); } public void testInferNestedFields() { @@ -192,7 +200,8 @@ public void testInferNestedFields() { }}); }}; assertThat(0.3, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left @@ -205,7 +214,8 @@ public void testInferNestedFields() { }}); }}; assertThat(0.1, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); // This should hit the right child of the left child of the root node // i.e. it takes the path left, right @@ -218,7 +228,8 @@ public void testInferNestedFields() { }}); }}; assertThat(0.2, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(), + 0.00001)); } public void testTreeClassificationProbability() { @@ -241,7 +252,8 @@ public void testTreeClassificationProbability() { List expectedFields = Arrays.asList("dog", "cat"); Map featureMap = zipObjMap(featureNames, featureVector); List probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap())) + .getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); @@ -252,7 +264,8 @@ public void testTreeClassificationProbability() { featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap())) + .getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); @@ -264,7 +277,8 @@ public void testTreeClassificationProbability() { put("bar", null); }}; probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap())) + .getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); @@ -345,6 +359,55 @@ public void testOperationsEstimations() { assertThat(tree.estimatedNumOperations(), equalTo(7L)); } + public void testFeatureImportance() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree = Tree.builder() + .setFeatureNames(featureNames) + .setNodes( + TreeNode.builder(0) + .setSplitFeature(0) + .setOperator(Operator.LT) + .setLeftChild(1) + .setRightChild(2) + .setThreshold(0.5) + .setNumberSamples(4L), + TreeNode.builder(1) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4) + .setOperator(Operator.LT) + .setThreshold(0.5) + .setNumberSamples(2L), + TreeNode.builder(2) + .setSplitFeature(1) + .setLeftChild(5) + .setRightChild(6) + .setOperator(Operator.LT) + .setThreshold(0.5) + .setNumberSamples(2L), + TreeNode.builder(3).setLeafValue(3.0).setNumberSamples(1L), + TreeNode.builder(4).setLeafValue(8.0).setNumberSamples(1L), + TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L), + TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build(); + + Map featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)), + Collections.emptyMap()); + assertThat(featureImportance.get("foo"), closeTo(-5.0, eps)); + assertThat(featureImportance.get("bar"), closeTo(-2.5, eps)); + + featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap()); + assertThat(featureImportance.get("foo"), closeTo(-5.0, eps)); + assertThat(featureImportance.get("bar"), closeTo(2.5, eps)); + + featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap()); + assertThat(featureImportance.get("foo"), closeTo(5.0, eps)); + assertThat(featureImportance.get("bar"), closeTo(-2.5, eps)); + + featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap()); + assertThat(featureImportance.get("foo"), closeTo(5.0, eps)); + assertThat(featureImportance.get("bar"), closeTo(2.5, eps)); + } + public void testMaxFeatureIndex() { int numFeatures = randomIntBetween(1, 15); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 2a092df8d5932..422cb00cdb192 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -115,7 +115,10 @@ public void testSimulate() throws IOException { " \"inference\": {\n" + " \"target_field\": \"ml.classification\",\n" + " \"inference_config\": {\"classification\": " + - " {\"num_top_classes\":2, \"top_classes_results_field\": \"result_class_prob\"}},\n" + + " {\"num_top_classes\":2, " + + " \"top_classes_results_field\": \"result_class_prob\"," + + " \"num_top_feature_importance_values\": 2" + + " }},\n" + " \"model_id\": \"test_classification\",\n" + " \"field_mappings\": {\n" + " \"col1\": \"col1\",\n" + @@ -153,6 +156,8 @@ public void testSimulate() throws IOException { String responseString = EntityUtils.toString(response.getEntity()); assertThat(responseString, containsString("\"predicted_value\":\"second\"")); assertThat(responseString, containsString("\"predicted_value\":1.0")); + assertThat(responseString, containsString("\"col2\":0.944")); + assertThat(responseString, containsString("\"col1\":0.19999")); String sourceWithMissingModel = "{\n" + " \"pipeline\": {\n" + @@ -321,16 +326,19 @@ private Map generateSourceDoc() { " \"split_gain\": 12.0,\n" + " \"threshold\": 10.0,\n" + " \"decision_type\": \"lte\",\n" + + " \"number_samples\": 300,\n" + " \"default_left\": true,\n" + " \"left_child\": 1,\n" + " \"right_child\": 2\n" + " },\n" + " {\n" + " \"node_index\": 1,\n" + + " \"number_samples\": 100,\n" + " \"leaf_value\": 1\n" + " },\n" + " {\n" + " \"node_index\": 2,\n" + + " \"number_samples\": 200,\n" + " \"leaf_value\": 2\n" + " }\n" + " ],\n" + @@ -352,15 +360,18 @@ private Map generateSourceDoc() { " \"threshold\": 10.0,\n" + " \"decision_type\": \"lte\",\n" + " \"default_left\": true,\n" + + " \"number_samples\": 150,\n" + " \"left_child\": 1,\n" + " \"right_child\": 2\n" + " },\n" + " {\n" + " \"node_index\": 1,\n" + + " \"number_samples\": 50,\n" + " \"leaf_value\": 1\n" + " },\n" + " {\n" + " \"node_index\": 2,\n" + + " \"number_samples\": 100,\n" + " \"leaf_value\": 2\n" + " }\n" + " ],\n" + @@ -445,6 +456,7 @@ private Map generateSourceDoc() { " {\n" + " \"node_index\": 0,\n" + " \"split_feature\": 0,\n" + + " \"number_samples\": 100,\n" + " \"split_gain\": 12.0,\n" + " \"threshold\": 10.0,\n" + " \"decision_type\": \"lte\",\n" + @@ -454,10 +466,12 @@ private Map generateSourceDoc() { " },\n" + " {\n" + " \"node_index\": 1,\n" + + " \"number_samples\": 80,\n" + " \"leaf_value\": 1\n" + " },\n" + " {\n" + " \"node_index\": 2,\n" + + " \"number_samples\": 20,\n" + " \"leaf_value\": 0\n" + " }\n" + " ],\n" + @@ -476,6 +490,7 @@ private Map generateSourceDoc() { " \"node_index\": 0,\n" + " \"split_feature\": 0,\n" + " \"split_gain\": 12.0,\n" + + " \"number_samples\": 180,\n" + " \"threshold\": 10.0,\n" + " \"decision_type\": \"lte\",\n" + " \"default_left\": true,\n" + @@ -484,10 +499,12 @@ private Map generateSourceDoc() { " },\n" + " {\n" + " \"node_index\": 1,\n" + + " \"number_samples\": 10,\n" + " \"leaf_value\": 1\n" + " },\n" + " {\n" + " \"node_index\": 2,\n" + + " \"number_samples\": 170,\n" + " \"leaf_value\": 0\n" + " }\n" + " ],\n" + diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 57389b3660241..478621b8de880 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -102,6 +102,43 @@ public void testMutateDocumentClassificationTopNClasses() { assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo")); } + public void testMutateDocumentClassificationFeatureInfluence() { + ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2); + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + auditor, + "my_processor", + "ml.my_processor", + "classification_model", + classificationConfig, + Collections.emptyMap()); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + List classes = new ArrayList<>(2); + classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6)); + classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4)); + + Map featureInfluence = new HashMap<>(); + featureInfluence.put("feature_1", 1.13); + featureInfluence.put("feature_2", -42.0); + + InternalInferModelAction.Response response = new InternalInferModelAction.Response( + Collections.singletonList(new ClassificationInferenceResults(1.0, + "foo", + classes, + featureInfluence, + classificationConfig)), + true); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model")); + assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo")); + assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13)); + assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0)); + } + @SuppressWarnings("unchecked") public void testMutateDocumentClassificationTopNClassesWithSpecificField() { ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops"); @@ -154,6 +191,34 @@ public void testMutateDocumentRegression() { assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model")); } + public void testMutateDocumentRegressionWithTopFetures() { + RegressionConfig regressionConfig = new RegressionConfig("foo", 2); + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + auditor, + "my_processor", + "ml.my_processor", + "regression_model", + regressionConfig, + Collections.emptyMap()); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + Map featureInfluence = new HashMap<>(); + featureInfluence.put("feature_1", 1.13); + featureInfluence.put("feature_2", -42.0); + + InternalInferModelAction.Response response = new InternalInferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model")); + assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13)); + assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0)); + } + public void testGenerateRequestWithEmptyMapping() { String modelId = "model"; Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);