Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Adds feature importance option to inference processor #52218

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/reference/ingest/processors/inference.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ include::common-options.asciidoc[]
Specifies the field to which the inference prediction is written. Defaults to
`predicted_value`.

`num_top_feature_importance_values`::::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values]

[discrete]
[[inference-processor-classification-opt]]
Expand All @@ -63,6 +66,9 @@ Specifies the number of top class predictions to return. Defaults to 0.
Specifies the field to which the top classes are written. Defaults to
`top_classes`.

`num_top_feature_importance_values`::::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values]

[discrete]
[[inference-processor-config-example]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -73,6 +74,7 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser,

private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private volatile Map<String, String> decoderMap;

private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
Expand Down Expand Up @@ -115,13 +117,30 @@ public List<PreProcessor> getPreProcessors() {
return preProcessors;
}

private void preProcess(Map<String, Object> fields) {
void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}

public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
return trainedModel.infer(fields, config);
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName());
}
return trainedModel.infer(fields,
config,
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}

private Map<String, String> getDecoderMap() {
if (decoderMap != null) {
return decoderMap;
}
this.decoderMap = preProcessors.stream()
.map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -235,6 +236,11 @@ public void process(Map<String, Object> fields) {
fields.put(destField, concatEmbeddings(processedFeatures));
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(destField, fieldName);
}

@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ public String getFeatureName() {
return featureName;
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* PreProcessor for one hot encoding a set of categorical values for a given field.
Expand Down Expand Up @@ -80,6 +82,11 @@ public Map<String, String> getHotMap() {
return hotMap;
}

@Override
public Map<String, String> reverseLookup() {
return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field));
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
* @param fields The fields and their values to process
*/
void process(Map<String, Object> fields);

/**
* @return Reverse lookup map to match resulting features to their original feature name
*/
Map<String, String> reverseLookup();
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public String getFeatureName() {
return featureName;
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}

@Override
public String getName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,25 @@ public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
super(value);
assert config instanceof ClassificationConfig;
ClassificationConfig classificationConfig = (ClassificationConfig)config;
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
}

public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}

private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
classificationConfig.getNumTopFeatureImportanceValues()));
this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
Expand Down Expand Up @@ -74,16 +90,17 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
ClassificationInferenceResults that = (ClassificationInferenceResults) object;
return Objects.equals(value(), that.value()) &&
Objects.equals(classificationLabel, that.classificationLabel) &&
Objects.equals(resultsField, that.resultsField) &&
Objects.equals(topNumClassesField, that.topNumClassesField) &&
Objects.equals(topClasses, that.topClasses);
return Objects.equals(value(), that.value())
&& Objects.equals(classificationLabel, that.classificationLabel)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(topNumClassesField, that.topNumClassesField)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField);
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
}

@Override
Expand All @@ -100,6 +117,9 @@ public void writeResult(IngestDocument document, String parentResultField) {
document.setFieldValue(parentResultField + "." + topNumClassesField,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@
import org.elasticsearch.ingest.IngestDocument;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class RawInferenceResults extends SingleValueInferenceResults {

public static final String NAME = "raw";

public RawInferenceResults(double value) {
super(value);
public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value, featureImportance);
}

public RawInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
}

@Override
Expand All @@ -34,12 +35,13 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value());
return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value());
return Objects.hash(value(), getFeatureImportance());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

public class RegressionInferenceResults extends SingleValueInferenceResults {
Expand All @@ -22,14 +24,22 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;

public RegressionInferenceResults(double value, InferenceConfig config) {
super(value);
assert config instanceof RegressionConfig;
RegressionConfig regressionConfig = (RegressionConfig)config;
this(value, (RegressionConfig) config, Collections.emptyMap());
}

public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}

private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
this.resultsField = regressionConfig.getResultsField();
}

public RegressionInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
this.resultsField = in.readString();
}

Expand All @@ -44,19 +54,24 @@ public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RegressionInferenceResults that = (RegressionInferenceResults) object;
return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField);
return Objects.equals(value(), that.value())
&& Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value(), resultsField);
return Objects.hash(value(), resultsField, getFeatureImportance());
}

@Override
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,61 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

public abstract class SingleValueInferenceResults implements InferenceResults {

private final double value;
private final Map<String, Double> featureImportance;

static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
}

SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
} else {
this.featureImportance = Collections.emptyMap();
}
}

SingleValueInferenceResults(double value) {
SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
}

public Double value() {
return value;
}

public Map<String, Double> getFeatureImportance() {
return featureImportance;
}

public String valueAsString() {
return String.valueOf(value);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
}
}

}
Loading