Skip to content

Commit

Permalink
[7.6] [ML][Inference] Fix weighted mode definition (#51648) (#51696)
Browse files Browse the repository at this point in the history
* [ML][Inference] Fix weighted mode definition (#51648)

Weighted mode inaccurately assumed that the "max value" of the input values would be the maximum class value. This does not make sense. 

Weighted Mode should know how many classes there are. Hence the new parameter `num_classes`, which indicates the maximum class value to be expected.
  • Loading branch information
benwtrent committed Jan 30, 2020
1 parent 1797be0 commit 7135ce2
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ public class WeightedMode implements OutputAggregator {

public static final String NAME = "weighted_mode";
public static final ParseField WEIGHTS = new ParseField("weights");
public static final ParseField NUM_CLASSES = new ParseField("num_classes");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<WeightedMode, Void> PARSER = new ConstructingObjectParser<>(
    NAME,
    true,
    // a[0] = required num_classes, a[1] = optional weights (order matches the declare* calls below)
    a -> new WeightedMode((Integer)a[0], (List<Double>)a[1]));
static {
    PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES);
    PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
}

Expand All @@ -49,9 +51,11 @@ public static WeightedMode fromXContent(XContentParser parser) {
}

private final List<Double> weights;
private final int numClasses;

/**
 * @param numClasses total number of classes the ensemble may predict
 * @param weights    optional per-model weights; null means every model counts equally
 */
public WeightedMode(int numClasses, List<Double> weights) {
    this.weights = weights;
    this.numClasses = numClasses;
}

@Override
Expand All @@ -65,6 +69,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (weights != null) {
builder.field(WEIGHTS.getPreferredName(), weights);
}
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
builder.endObject();
return builder;
}
Expand All @@ -74,11 +79,11 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedMode that = (WeightedMode) o;
return Objects.equals(weights, that.weights);
return Objects.equals(weights, that.weights) && numClasses == that.numClasses;
}

@Override
public int hashCode() {
    // Must stay consistent with equals(): both weights and numClasses participate.
    return Objects.hash(weights, numClasses);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
Expand Down Expand Up @@ -69,17 +68,17 @@ public static Ensemble createRandom(TargetType targetType) {
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
.limit(numberOfModels)
.collect(Collectors.toList());
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
List<OutputAggregator> possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights),
new LogisticRegression(weights)));
if (targetType.equals(TargetType.REGRESSION)) {
possibleAggregators.add(new WeightedSum(weights));
}
OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0]));
List<String> categoryLabels = null;
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
}
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
randomFrom(
new WeightedMode(
categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10),
weights),
new LogisticRegression(weights));
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
Stream.generate(ESTestCase::randomDouble)
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
public class WeightedModeTests extends AbstractXContentTestCase<WeightedMode> {

// Builds a WeightedMode with the given number of random weights and a random
// class count (>= 2, matching the num_classes > 1 contract of the constructor).
WeightedMode createTestInstance(int numberOfWeights) {
    return new WeightedMode(
        randomIntBetween(2, 10),
        Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
}

@Override
Expand All @@ -45,7 +47,7 @@ protected boolean supportsUnknownFields() {

// Exercises both the null-weights path and the populated-weights path at random.
@Override
protected WeightedMode createTestInstance() {
    return randomBoolean() ? new WeightedMode(randomIntBetween(2, 10), null) : createTestInstance(randomIntBetween(1, 100));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -29,6 +30,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class);
public static final ParseField NAME = new ParseField("weighted_mode");
public static final ParseField WEIGHTS = new ParseField("weights");
public static final ParseField NUM_CLASSES = new ParseField("num_classes");

private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = createParser(false);
Expand All @@ -38,7 +40,8 @@ private static ConstructingObjectParser<WeightedMode, Void> createParser(boolean
ConstructingObjectParser<WeightedMode, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new WeightedMode((List<Double>)a[0]));
a -> new WeightedMode((Integer) a[0], (List<Double>)a[1]));
parser.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES);
parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
return parser;
}
Expand All @@ -52,17 +55,23 @@ public static WeightedMode fromXContentLenient(XContentParser parser) {
}

private final double[] weights;
private final int numClasses;

// Package-private convenience constructor: num_classes only, equal weighting.
WeightedMode(int numClasses) {
    this(numClasses, null);
}

// Parser-facing constructor; converts the boxed weight list to a primitive array.
private WeightedMode(Integer numClasses, List<Double> weights) {
    this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray(), numClasses);
}

/**
 * @param weights    optional per-model weights; null means every model counts equally
 * @param numClasses required total number of classes; must be greater than 1
 */
public WeightedMode(double[] weights, Integer numClasses) {
    this.weights = weights;
    this.numClasses = ExceptionsHelper.requireNonNull(numClasses, NUM_CLASSES);
    if (this.numClasses <= 1) {
        throw new IllegalArgumentException("[" + NUM_CLASSES.getPreferredName() + "] must be greater than 1.");
    }
}

public WeightedMode(StreamInput in) throws IOException {
Expand All @@ -71,6 +80,7 @@ public WeightedMode(StreamInput in) throws IOException {
} else {
this.weights = null;
}
this.numClasses = in.readVInt();
}

@Override
Expand Down Expand Up @@ -99,7 +109,10 @@ public List<Double> processValues(List<Double> values) {
maxVal = integerValue;
}
}
List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
if (maxVal >= numClasses) {
throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]");
}
List<Double> frequencies = new ArrayList<>(Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY));
for (int i = 0; i < freqArray.size(); i++) {
Double weight = weights == null ? 1.0 : weights[i];
Integer value = freqArray.get(i);
Expand Down Expand Up @@ -133,7 +146,7 @@ public String getName() {

@Override
public boolean compatibleWith(TargetType targetType) {
    // Weighted mode selects a discrete class, so it is only valid for classification.
    return targetType.equals(TargetType.CLASSIFICATION);
}

@Override
Expand All @@ -147,6 +160,7 @@ public void writeTo(StreamOutput out) throws IOException {
if (weights != null) {
out.writeDoubleArray(weights);
}
out.writeVInt(numClasses);
}

@Override
Expand All @@ -155,6 +169,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (weights != null) {
builder.field(WEIGHTS.getPreferredName(), weights);
}
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
builder.endObject();
return builder;
}
Expand All @@ -164,12 +179,12 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedMode that = (WeightedMode) o;
return Arrays.equals(weights, that.weights);
return Arrays.equals(weights, that.weights) && numClasses == that.numClasses;
}

@Override
public int hashCode() {
    // Arrays.hashCode for the primitive weight array, combined with numClasses;
    // consistent with equals() above.
    return Objects.hash(Arrays.hashCode(weights), numClasses);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
Expand All @@ -26,7 +28,9 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -302,4 +306,70 @@ public void testRamUsageEstimation() {
assertThat(test.ramBytesUsed(), greaterThan(0L));
}

public void testMultiClassIrisInference() throws IOException {
// Fairly simple, random forest classification model built to fit in our format
// Trained on the well known Iris dataset
// NOTE(review): payload is gzip-compressed, base64-encoded model JSON; regenerate rather than hand-edit.
String compressedDef = "H4sIAPbiMl4C/+1b246bMBD9lVWet8jjG3b/oN9QVYgmToLEkghIL6r23wukl90" +
"YxRMGlt2WPKwEC/gYe2bOnBl+rOoyzQq3SR4OG5ev3t/9WLmicg+fc9cd1Gm5c3VSfz+2x6t1nlZVts3Wa" +
"Z0ditX93Wrr0vpUuqRIH1zVXPJxVbljmie5K3b1vr3ifPw125wPj65+9u/z8fnfn+4vh0jy9LPLzw/+UGb" +
"Vu8rVhyptb+wOv7iyytaH/FD+PZWVu6xo7u8e92x+3XOaSZVurtm1QydVXZ7W7XPPcIoGWpIVG/etOWbNR" +
"Ru3zqp28r+B5bVrH5a7bZ2s91m+aU5Cc6LMdvu/Z3gL55hndfILdnNOtGPuS1ftD901LDKs+wFYziy3j/d" +
"3FwjgKoJ0m3xJ81N7kvn3cix64aEH1gOfX8CXkVEtemFAahvz2IcgsBCkB0GhEMTKH1Ri3xn49yosYO0Bj" +
"hErDpGy3Y9JLbjSRvoQNAF+jIVvPPi2Bz67gK8iK1v0ptmsWoHoWXFDQG+x9/IeQ8Hbqm+swBGT15dr1wM" +
"CKDNA2yv0GKxE7b4+cwFBWDKQ+BlfDSgsat43tH94xD49diMtoeEVhgaN2mi6iwzMKqFjKUDPEBqCrmq6O" +
"HHd0PViMreajEEFJxlaccAi4B4CgdhzHBHdOcFqCSYTI14g2WS2z0007DfAe4Hy7DdkrI2I+9yGIhitJhh" +
"tTBjXYN+axcX1Ab7Oom2P+RgAtffDLj/A0a5vfkAbL/jWCwJHj9jT3afMzSQtQJYEhR6ibQ984+McsYQqg" +
"m4baTBKMB6LHhDo/Aj8BInDcI6q0ePG/rgMx+57hkXnU+AnVGBxCWH3zq3ijclwI/tW3lC2jSVsWM4oN1O" +
"SIc4XkjRGXjGEosylOUkUQ7AhhkBgSXYc1YvAksw4PG1kGWsAT5tOxbruOKbTnwIkSYxD1MbXsWAIUwMKz" +
"eGUeDUbRwI9Fkek5CiwqAM3Bz6NUgdUt+vBslhIo8UM6kDQac4kDiicpHfe+FwY2SQI5q3oadvnoQ3hMHE" +
"pCaHUgkqoVcRCG5aiKzCUCN03cUtJ4ikJxZTVlcWvDvarL626DiiVLH71pf0qG1y9H7mEPSQBNoTtQpFba" +
"NzfDFfXSNJqPFJBkFb/1iiNLxhSAW3u4Ns7qHHi+i1F9fmyj1vV0sDIZonP0wh+waxjLr1vOPcmxORe7n3" +
"pKOIIhVp9Rtb4+Owa3xCX/TpFPnrig6nKTNisNl8aNEKQRfQITh9kG/NhTzcvpwRZoARZvkh8S6h7Oz1zI" +
"atZeuYWk5nvC4TJ2aFFJXBCTkcO9UuQQ0qb3FXdx4xTPH6dBeApP0CQ43QejN8kd7l64jI1krMVgJfPEf7" +
"h3uq3o/K/ztZqP1QKFagz/G+t1XxwjeIFuqkRbXoTdlOTGnwCIoKZ6ku1AbrBoN6oCdX56w3UEOO0y2B9g" +
"aLbAYWcAdpeweKa2IfIT+jz5QzXxD6AoP+DrdXtxeluV7pdWrvkcKqPp7rjS19d+wp/fff/5Ez3FPjzFNy" +
"fdpTi9JB0sDp2JR7b309mn5HuPkEAAA==";

// Decompress and parse the payload into a usable trained-model definition.
TrainedModelDefinition definition = InferenceToXContentCompressor.inflate(compressedDef,
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
xContentRegistry());

// Canonical Iris-setosa measurements -> model should predict "Iris-setosa".
Map<String, Object> fields = new HashMap<String, Object>(){{
put("sepal_length", 5.1);
put("sepal_width", 3.5);
put("petal_length", 1.4);
put("petal_width", 0.2);
}};

assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-setosa"));

// Canonical Iris-versicolor measurements.
fields = new HashMap<String, Object>(){{
put("sepal_length", 7.0);
put("sepal_width", 3.2);
put("petal_length", 4.7);
put("petal_width", 1.4);
}};
assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-versicolor"));

// Canonical Iris-virginica measurements.
fields = new HashMap<String, Object>(){{
put("sepal_length", 6.5);
put("sepal_width", 3.0);
put("petal_length", 5.2);
put("petal_width", 2.0);
}};
assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-virginica"));
}

}

0 comments on commit 7135ce2

Please sign in to comment.