[ML][Inference] Add support for multi-value leaves to the tree model (#52531)

This adds support for multi-value leaves. This is a prerequisite for multi-class boosted tree classification.
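As a sketch of what this changes for model definitions (class and builder names are taken from the diffs below; the three-class values are invented for illustration), a leaf now carries one raw score per class instead of a single scalar:

import java.util.Arrays;

// A leaf for a hypothetical 3-class problem: one value per class.
TreeNode leaf = TreeNode.builder(3)
    .setLeafValue(Arrays.asList(0.1, 0.2, 0.7))
    .setNumberSamples(42L)
    .build();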
benwtrent committed Feb 27, 2020
1 parent 2d85e41 commit e39eade
Showing 26 changed files with 575 additions and 197 deletions.
Tree.java (client)
@@ -30,6 +30,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@@ -225,7 +226,7 @@ public Builder addLeaf(int nodeIndex, double value) {
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
nodes.add(null);
}
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(Collections.singletonList(value)));
return this;
}
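Single-value trees keep the scalar convenience; the builder now just wraps the value in a singleton list. A minimal usage sketch, assuming the surrounding Tree.Builder API (feature name and values invented):

Tree.Builder tree = Tree.builder()
    .setFeatureNames(Arrays.asList("f0"));
tree.addLeaf(0, 0.5); // stored as setLeafValue(Collections.singletonList(0.5))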

TreeNode.java (client)
@@ -27,6 +27,7 @@
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public class TreeNode implements ToXContentObject {
@@ -61,7 +62,7 @@ public class TreeNode implements ToXContentObject {
PARSER.declareInt(Builder::setSplitFeature, SPLIT_FEATURE);
PARSER.declareInt(Builder::setNodeIndex, NODE_INDEX);
PARSER.declareDouble(Builder::setSplitGain, SPLIT_GAIN);
PARSER.declareDouble(Builder::setLeafValue, LEAF_VALUE);
PARSER.declareDoubleArray(Builder::setLeafValue, LEAF_VALUE);
PARSER.declareLong(Builder::setNumberSamples, NUMBER_SAMPLES);
}
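With declareDoubleArray, leaf_value is now parsed as an array of doubles; ObjectParser's DOUBLE_ARRAY value type also tolerates a bare number, which should keep older single-value documents readable. An invented illustration of the node JSON this accepts:

// {
//   "node_index": 2,
//   "number_samples": 100,
//   "leaf_value": [0.1, 0.9]   // previously a single double, e.g. 0.9
// }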

@@ -74,7 +75,7 @@ public static Builder fromXContent(XContentParser parser) {
private final Integer splitFeature;
private final int nodeIndex;
private final Double splitGain;
private final Double leafValue;
private final List<Double> leafValue;
private final Boolean defaultLeft;
private final Integer leftChild;
private final Integer rightChild;
@@ -86,7 +87,7 @@ public static Builder fromXContent(XContentParser parser) {
Integer splitFeature,
int nodeIndex,
Double splitGain,
Double leafValue,
List<Double> leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild,
@@ -123,7 +124,7 @@ public Double getSplitGain() {
return splitGain;
}

public Double getLeafValue() {
public List<Double> getLeafValue() {
return leafValue;
}

@@ -212,7 +213,7 @@ public static class Builder {
private Integer splitFeature;
private int nodeIndex;
private Double splitGain;
private Double leafValue;
private List<Double> leafValue;
private Boolean defaultLeft;
private Integer leftChild;
private Integer rightChild;
@@ -250,7 +251,7 @@ public Builder setSplitGain(Double splitGain) {
return this;
}

public Builder setLeafValue(Double leafValue) {
public Builder setLeafValue(List<Double> leafValue) {
this.leafValue = leafValue;
return this;
}
TreeNodeTests.java (client)
@@ -23,6 +23,7 @@
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.Collections;

public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {

@@ -48,7 +49,7 @@ protected TreeNode createTestInstance() {
public static TreeNode createRandomLeafNode(double internalValue) {
return TreeNode.builder(randomInt(100))
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
.setLeafValue(internalValue)
.setLeafValue(Collections.singletonList(internalValue))
.setNumberSamples(randomNonNegativeLong())
.build();
}
@@ -60,7 +61,7 @@ public static TreeNode.Builder createRandom(int nodeIndex,
Integer featureIndex,
Operator operator) {
return TreeNode.builder(nodeIndex)
.setLeafValue(left == null ? randomDouble() : null)
.setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
.setLeftChild(left)
.setRightChild(right)
RawInferenceResults.java
@@ -5,43 +5,51 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.ingest.IngestDocument;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;

public class RawInferenceResults extends SingleValueInferenceResults {
public class RawInferenceResults implements InferenceResults {

public static final String NAME = "raw";

public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value, featureImportance);
private final double[] value;
private final Map<String, Double> featureImportance;

public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
this.value = value;
this.featureImportance = featureImportance;
}

public double[] getValue() {
return value;
}

public RawInferenceResults(StreamInput in) throws IOException {
super(in);
public Map<String, Double> getFeatureImportance() {
return featureImportance;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
throw new UnsupportedOperationException("[raw] does not support wire serialization");
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
return Arrays.equals(value, that.value)
&& Objects.equals(featureImportance, that.featureImportance);
}

@Override
public int hashCode() {
return Objects.hash(value(), getFeatureImportance());
return Objects.hash(Arrays.hashCode(value), featureImportance);
}

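RawInferenceResults is now a plain in-memory carrier for the per-class raw scores handed from submodels to the ensemble, and it explicitly refuses wire serialization. A small usage sketch (scores invented):

RawInferenceResults raw = new RawInferenceResults(
    new double[] {1.2, -0.4, 0.7},   // one raw score per class
    Collections.emptyMap());         // no feature importance requested
double[] perClassScores = raw.getValue();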
InferenceHelpers.java
@@ -26,30 +26,29 @@ private InferenceHelpers() { }
/**
* @return Tuple of the highest scored index and the top classes
*/
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude) {

if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
throw ExceptionsHelper
.serverError(
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
null,
probabilities.size(),
probabilities.length,
classificationLabels.size());
}

List<Double> scores = classificationWeights == null ?
double[] scores = classificationWeights == null ?
probabilities :
IntStream.range(0, probabilities.size())
.mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
.boxed()
.collect(Collectors.toList());
IntStream.range(0, probabilities.length)
.mapToDouble(i -> probabilities[i] * classificationWeights[i])
.toArray();

int[] sortedIndices = IntStream.range(0, probabilities.size())
int[] sortedIndices = IntStream.range(0, scores.length)
.boxed()
.sorted(Comparator.comparing(scores::get).reversed())
.sorted(Comparator.comparing(i -> scores[(Integer)i]).reversed())
.mapToInt(i -> i)
.toArray();

@@ -59,14 +58,14 @@ public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>>

List<String> labels = classificationLabels == null ?
// If we don't have the labels we should return the top classification values anyways, they will just be numeric
IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
IntStream.range(0, probabilities.length).boxed().map(String::valueOf).collect(Collectors.toList()) :
classificationLabels;

int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities[idx], scores[idx]));
}

return Tuple.tuple(sortedIndices[0], topClassEntries);
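A worked sketch of the array-based topClasses (labels and probabilities invented); with no classification weights the scores equal the probabilities, so the method simply sorts indices by probability:

double[] probabilities = new double[] {0.2, 0.7, 0.1};
List<String> labels = Arrays.asList("ant", "bee", "cat");
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> top =
    InferenceHelpers.topClasses(probabilities, labels, null, 2);
// top.v1() == 1, the index of the highest probability ("bee")
// top.v2() holds the top two entries: ("bee", 0.7) then ("ant", 0.2)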
TrainedModel.java
@@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
@@ -62,4 +63,8 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* @return A {@code Map<String, Double>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);

default Version getMinimalCompatibilityVersion() {
return Version.V_7_6_0;
}
}
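The default pins existing implementations to 7.6.0 (the release that introduced inference); a model that emits the new multi-value format can override it to declare a newer floor. A hypothetical override, with an illustrative version constant:

@Override
public Version getMinimalCompatibilityVersion() {
    // first version that understands multi-value leaves (illustrative):
    return Version.V_7_7_0;
}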
Ensemble.java
@@ -8,6 +8,7 @@
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
@@ -20,7 +21,6 @@
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@@ -139,19 +139,20 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> inferenceResults = new ArrayList<>(this.models.size());
double[][] inferenceResults = new double[this.models.size()][];
List<Map<String, Double>> featureInfluence = new ArrayList<>();
int i = 0;
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
this.models.forEach(model -> {
for (TrainedModel model : models) {
InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
assert result instanceof SingleValueInferenceResults;
SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
inferenceResults.add(inferenceResult.value());
assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result;
inferenceResults[i++] = inferenceResult.getValue();
if (config.requestingImportance()) {
featureInfluence.add(inferenceResult.getFeatureImportance());
}
});
List<Double> processed = outputAggregator.processValues(inferenceResults);
}
double[] processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, featureInfluence, config, featureDecoderMap);
}
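The loop now collects one double[] per submodel, so the aggregator receives a [models][classes] matrix instead of a flat list of scalars. A schematic with invented numbers, reusing the outputAggregator field from the method above:

double[][] inferenceResults = new double[][] {
    {0.2, 0.8},   // submodel 0: raw score per class
    {0.4, 0.6},   // submodel 1
    {0.1, 0.9}    // submodel 2
};
double[] processed = outputAggregator.processValues(inferenceResults);
// 'processed' has one entry per class, ready for topClasses(...)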

@@ -160,13 +161,13 @@ public TargetType targetType() {
return targetType;
}

private InferenceResults buildResults(List<Double> processedInferences,
private InferenceResults buildResults(double[] processedInferences,
List<Map<String, Double>> featureInfluence,
InferenceConfig config,
Map<String, String> featureDecoderMap) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
}
switch(targetType) {
@@ -176,7 +177,7 @@ private InferenceResults buildResults(List<Double> processedInferences,
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
// Adjust the probabilities according to the thresholds
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
processedInferences,
@@ -356,6 +357,11 @@ public Collection<Accountable> getChildResources() {
return Collections.unmodifiableCollection(accountables);
}

@Override
public Version getMinimalCompatibilityVersion() {
return models.stream().map(TrainedModel::getMinimalCompatibilityVersion).max(Version::compareTo).orElse(Version.V_7_6_0);
}

public static class Builder {
private List<String> featureNames;
private List<TrainedModel> trainedModels;
LogisticRegression.java
@@ -19,9 +19,9 @@
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

@@ -78,31 +78,39 @@ public Integer expectedValueSize() {
}

@Override
public List<Double> processValues(List<Double> values) {
public double[] processValues(double[][] values) {
Objects.requireNonNull(values, "values must not be null");
if (weights != null && values.size() != weights.length) {
if (weights != null && values.length != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
double summation = weights == null ?
values.stream().mapToDouble(Double::valueOf).sum() :
IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum();
double probOfClassOne = sigmoid(summation);
double[] sumOnAxis1 = new double[values[0].length];
for (int j = 0; j < values.length; j++) {
double[] value = values[j];
double weight = weights == null ? 1.0 : weights[j];
for(int i = 0; i < value.length; i++) {
if (i >= sumOnAxis1.length) {
throw new IllegalArgumentException("value entries must have the same dimensions");
}
sumOnAxis1[i] += (value[i] * weight);
}
}
if (sumOnAxis1.length > 1) {
return softMax(sumOnAxis1);
}

double probOfClassOne = sigmoid(sumOnAxis1[0]);
assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0;
return Arrays.asList(1.0 - probOfClassOne, probOfClassOne);
return new double[] {1.0 - probOfClassOne, probOfClassOne};
}
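// A worked pass through the new processValues (weights and values invented):
//   values = {{1.0, 2.0},    // submodel 0's per-class scores
//             {3.0, 1.0}}    // submodel 1's per-class scores
// With unit weights, sumOnAxis1 = {4.0, 3.0}. Since there is more than one
// column, the result is softMax({4.0, 3.0}) ~= {0.731, 0.269}.
// A one-column input such as {{0.3}, {0.9}} sums to {1.2} and returns
// {1 - sigmoid(1.2), sigmoid(1.2)} ~= {0.231, 0.769}.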

@Override
public double aggregate(List<Double> values) {
public double aggregate(double[] values) {
Objects.requireNonNull(values, "values must not be null");
assert values.size() == 2;
int bestValue = 0;
double bestProb = Double.NEGATIVE_INFINITY;
for (int i = 0; i < values.size(); i++) {
if (values.get(i) == null) {
throw new IllegalArgumentException("values must not contain null values");
}
if (values.get(i) > bestProb) {
bestProb = values.get(i);
for (int i = 0; i < values.length; i++) {
if (values[i] > bestProb) {
bestProb = values[i];
bestValue = i;
}
}
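// Continuing the example: aggregate(new double[] {0.231, 0.769}) walks the
// array and returns 1.0, the index of the highest probability, i.e. the
// ensemble's predicted class.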