Skip to content

Commit

Permalink
[7.x][ML] Remove top level importance from classification inference r…
Browse files Browse the repository at this point in the history
…esults (#62486) (#62964)

As we have decided that top-level importance for classification is not useful,
it has been removed from the results of the training job. This commit
also removes it from the inference results.

Backport of #62486
  • Loading branch information
dimitris-athanasiou committed Sep 29, 2020
1 parent cc33df8 commit 7f6c1ff
Show file tree
Hide file tree
Showing 20 changed files with 649 additions and 264 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public class FeatureImportance implements ToXContentObject {

static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareDouble(optionalConstructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareObjectArray(optionalConstructorArg(),
(p, c) -> ClassImportance.fromXContent(p),
new ParseField(FeatureImportance.CLASSES));
Expand All @@ -58,10 +58,10 @@ public static FeatureImportance fromXContent(XContentParser parser) {
}

private final List<ClassImportance> classImportance;
private final double importance;
private final Double importance;
private final String featureName;

public FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
public FeatureImportance(String featureName, Double importance, List<ClassImportance> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
Expand All @@ -71,7 +71,7 @@ public List<ClassImportance> getClassImportance() {
return classImportance;
}

public double getImportance() {
public Double getImportance() {
return importance;
}

Expand All @@ -83,7 +83,9 @@ public String getFeatureName() {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME, featureName);
builder.field(IMPORTANCE, importance);
if (importance != null) {
builder.field(IMPORTANCE, importance);
}
if (classImportance != null && classImportance.isEmpty() == false) {
builder.field(CLASSES, classImportance);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImpo
protected FeatureImportance createTestInstance() {
return new FeatureImportance(
randomAlphaOfLength(10),
randomDoubleBetween(-10.0, 10.0, false),
randomBoolean() ? null : randomDoubleBetween(-10.0, 10.0, false),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;

/**
 * Common base for feature importance results. Subclasses supply the feature
 * name and a map representation of themselves; the XContent output is derived
 * from that map so the two representations cannot drift apart.
 */
abstract class AbstractFeatureImportance implements Writeable, ToXContentObject {

    /**
     * @return the name of the feature described by this object
     */
    public abstract String getFeatureName();

    /**
     * @return this object rendered as a map of field names to values
     */
    public abstract Map<String, Object> toMap();

    /**
     * Writes the map produced by {@link #toMap()} to the given builder.
     * Declared {@code final} so subclasses cannot override the XContent
     * rendering separately from {@link #toMap()}.
     *
     * @param builder the builder to write to
     * @param params  serialization parameters; not consulted here
     * @return the result of {@code builder.map(...)}
     * @throws IOException if the builder fails to write
     */
    @Override
    public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        final Map<String, Object> fields = toMap();
        return builder.map(fields);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -26,157 +25,101 @@
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class FeatureImportance implements Writeable, ToXContentObject {
public class ClassificationFeatureImportance extends AbstractFeatureImportance {

private final List<ClassImportance> classImportance;
private final double importance;
private final String featureName;
static final String IMPORTANCE = "importance";

static final String FEATURE_NAME = "feature_name";
static final String CLASSES = "classes";

public static FeatureImportance forRegression(String featureName, double importance) {
return new FeatureImportance(featureName, importance, null);
}

public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
importance,
classImportance);
}

public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
classImportance);
}

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance",
a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
private static final ConstructingObjectParser<ClassificationFeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("classification_feature_importance",
a -> new ClassificationFeatureImportance((String) a[0], (List<ClassImportance>) a[1])
);

static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareString(constructorArg(), new ParseField(ClassificationFeatureImportance.FEATURE_NAME));
PARSER.declareObjectArray(optionalConstructorArg(),
(p, c) -> ClassImportance.fromXContent(p),
new ParseField(FeatureImportance.CLASSES));
new ParseField(ClassificationFeatureImportance.CLASSES));
}

public static FeatureImportance fromXContent(XContentParser parser) {
public static ClassificationFeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
public ClassificationFeatureImportance(String featureName, List<ClassImportance> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
this.classImportance = classImportance == null ? Collections.emptyList() : Collections.unmodifiableList(classImportance);
}

public FeatureImportance(StreamInput in) throws IOException {
public ClassificationFeatureImportance(StreamInput in) throws IOException {
this.featureName = in.readString();
this.importance = in.readDouble();
if (in.readBoolean()) {
if (in.getVersion().before(Version.V_7_10_0)) {
Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
this.classImportance = ClassImportance.fromMap(classImportance);
} else {
this.classImportance = in.readList(ClassImportance::new);
}
} else {
this.classImportance = null;
}
this.classImportance = in.readList(ClassImportance::new);
}

/**
 * Returns the list of per-class importance entries held by this instance.
 */
public List<ClassImportance> getClassImportance() {
return classImportance;
}

public double getImportance() {
return importance;
}

/**
 * Returns the feature name stored in this instance.
 */
@Override
public String getFeatureName() {
return featureName;
}

public double getTotalImportance() {
if (classImportance.size() == 2) {
// Binary classification. We can return the first class importance here
return Math.abs(classImportance.get(0).getImportance());
}
return classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.featureName);
out.writeDouble(this.importance);
out.writeBoolean(this.classImportance != null);
if (this.classImportance != null) {
if (out.getVersion().before(Version.V_7_10_0)) {
out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
} else {
out.writeList(this.classImportance);
}
}
out.writeString(featureName);
out.writeList(classImportance);
}

@Override
public Map<String, Object> toMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(FEATURE_NAME, featureName);
map.put(IMPORTANCE, importance);
if (classImportance != null) {
if (classImportance.isEmpty() == false) {
map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
}
return map;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME, featureName);
builder.field(IMPORTANCE, importance);
if (classImportance != null && classImportance.isEmpty() == false) {
builder.field(CLASSES, classImportance);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
FeatureImportance that = (FeatureImportance) object;
ClassificationFeatureImportance that = (ClassificationFeatureImportance) object;
return Objects.equals(featureName, that.featureName)
&& Objects.equals(importance, that.importance)
&& Objects.equals(classImportance, that.classImportance);
}

@Override
public int hashCode() {
return Objects.hash(featureName, importance, classImportance);
return Objects.hash(featureName, classImportance);
}

public static class ClassImportance implements Writeable, ToXContentObject {

static final String CLASS_NAME = "class_name";
static final String IMPORTANCE = "importance";

private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance_class_importance",
a -> new ClassImportance((String) a[0], (Double) a[1])
new ConstructingObjectParser<>("classification_feature_importance_class_importance",
a -> new ClassImportance(a[0], (Double) a[1])
);

static {
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
}

private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
return new ClassImportance(entry.getKey(), entry.getValue());
}

private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
}

private static Map<String, Double> toMap(List<ClassImportance> importances) {
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
}

public static ClassImportance fromXContent(XContentParser parser) {
Expand Down Expand Up @@ -219,11 +162,7 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME, className);
builder.field(IMPORTANCE, importance);
builder.endObject();
return builder;
return builder.map(toMap());
}

@Override
Expand Down

0 comments on commit 7f6c1ff

Please sign in to comment.