Skip to content

Commit

Permalink
[Ml] Validate tree feature index is within range (#52460) (#52515)
Browse files Browse the repository at this point in the history
This changes the tree validation code to ensure no node in the tree has a
feature index that is beyond the bounds of the feature_names array.
Specifically this handles the situation where the C++ emits a tree containing
a single node and an empty feature_names list. This is valid tree used to
centre the data in the ensemble but the validation code would reject this
as feature_names is empty. This meant a broken workflow as you cannot GET
the model and PUT it back
  • Loading branch information
davidkyle committed Feb 20, 2020
1 parent 031d90d commit 1fe4fde
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ public static Tree createRandom() {
}

public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
int numFeatures = featureNames.size();
int maxFeatureIndex = featureNames.size() -1;
Tree.Builder builder = Tree.builder();
builder.setFeatureNames(featureNames);

TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());

for (int i = 0; i < depth -1; i++) {
Expand All @@ -76,7 +76,7 @@ public static Tree buildRandomTree(List<String> featureNames, int depth, TargetT
builder.addLeaf(nodeId, randomDouble());
} else {
TreeNode.Builder childNode =
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
nextNodes.add(childNode.getLeftChild());
nextNodes.add(childNode.getRightChild());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,10 @@ public static Builder builder() {

@Override
public void validate() {
if (featureNames.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName());
int maxFeatureIndex = maxFeatureIndex();
if (maxFeatureIndex >= featureNames.size()) {
throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array",
maxFeatureIndex, FEATURE_NAMES.getPreferredName());
}
checkTargetType();
detectMissingNodes();
Expand All @@ -267,6 +269,23 @@ public long estimatedNumOperations() {
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
}

/**
* The highest index of a feature used any of the nodes.
* If no nodes use a feature return -1. This can only happen
* if the tree contains a single leaf node.
*
* @return The max or -1
*/
int maxFeatureIndex() {
int maxFeatureIndex = -1;

for (TreeNode node : nodes) {
maxFeatureIndex = Math.max(maxFeatureIndex, node.getSplitFeature());
}

return maxFeatureIndex;
}

private void checkTargetType() {
if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
throw ExceptionsHelper.badRequestException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;


Expand Down Expand Up @@ -72,10 +73,10 @@ public static Tree createRandom() {

public static Tree buildRandomTree(List<String> featureNames, int depth) {
Tree.Builder builder = Tree.builder();
int numFeatures = featureNames.size() - 1;
int maxFeatureIndex = featureNames.size() - 1;
builder.setFeatureNames(featureNames);

TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());

for (int i = 0; i < depth -1; i++) {
Expand All @@ -86,7 +87,7 @@ public static Tree buildRandomTree(List<String> featureNames, int depth) {
builder.addLeaf(nodeId, randomDouble());
} else {
TreeNode.Builder childNode =
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
nextNodes.add(childNode.getLeftChild());
nextNodes.add(childNode.getRightChild());
}
Expand Down Expand Up @@ -339,26 +340,83 @@ public void testTreeWithTargetTypeAndLabelsMismatch() {
assertThat(ex.getMessage(), equalTo(msg));
}

public void testTreeWithEmptyFeatureNames() {
String msg = "[feature_names] must not be empty for tree model";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Tree.builder()
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(Collections.emptyList())
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
}

public void testOperationsEstimations() {
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
assertThat(tree.estimatedNumOperations(), equalTo(7L));
}

public void testMaxFeatureIndex() {

int numFeatures = randomIntBetween(1, 15);
// We need a tree where every feature is used, choose a depth big enough to
// accommodate those non-leave nodes (leaf nodes don't have a feature index)
int depth = (int) Math.ceil(Math.log(numFeatures +1) / Math.log(2)) + 1;
List<String> featureNames = new ArrayList<>(numFeatures);
for (int i=0; i<numFeatures; i++) {
featureNames.add("feature" + i);
}

Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);

// build a tree using feature indices 0..numFeatures -1
int featureIndex = 0;
TreeNode.Builder node = builder.addJunction(0, featureIndex++, true, randomDouble());
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());

for (int i = 0; i < depth -1; i++) {
List<Integer> nextNodes = new ArrayList<>();
for (int nodeId : childNodes) {
if (i == depth -2) {
builder.addLeaf(nodeId, randomDouble());
} else {
TreeNode.Builder childNode =
builder.addJunction(nodeId, featureIndex++ % numFeatures, true, randomDouble());
nextNodes.add(childNode.getLeftChild());
nextNodes.add(childNode.getRightChild());
}
}
childNodes = nextNodes;
}

Tree tree = builder.build();

assertEquals(numFeatures, tree.maxFeatureIndex() +1);
}

public void testMaxFeatureIndexSingleNodeTree() {
Tree tree = Tree.builder()
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
.setFeatureNames(Collections.emptyList())
.build();

assertEquals(-1, tree.maxFeatureIndex());
}

public void testValidateGivenMissingFeatures() {
List<String> featureNames = Arrays.asList("foo", "bar", "baz");

// build a tree referencing a feature at index 3 which is not in the featureNames list
Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);
builder.addJunction(0, 0, true, randomDouble());
builder.addJunction(1, 1, true, randomDouble());
builder.addJunction(2, 3, true, randomDouble());
builder.addLeaf(3, randomDouble());
builder.addLeaf(4, randomDouble());
builder.addLeaf(5, randomDouble());
builder.addLeaf(6, randomDouble());

ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> builder.build().validate());
assertThat(e.getDetailedMessage(), containsString("feature index [3] is out of bounds for the [feature_names] array"));
}

public void testValidateGivenTreeWithNoFeatures() {
Tree.builder()
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
.setFeatureNames(Collections.emptyList())
.build()
.validate();
}

private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
Expand Down
2 changes: 1 addition & 1 deletion x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ integTest.runner {
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
'ml/inference_crud/Test put ensemble with empty models',
'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names',
'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index',
'ml/inference_crud/Test put model with empty input.field_names',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,40 @@ setup:
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
---
"Test put ensemble with single node and empty feature_names":

- do:
ml.put_trained_model:
model_id: "ensemble_tree_empty_feature_names"
body: >
{
"input": {
"field_names": "fieldy_mc_fieldname"
},
"definition": {
"trained_model": {
"ensemble": {
"feature_names": [],
"trained_models": [
{
"tree": {
"feature_names": [],
"tree_structure": [
{
"node_index": 0,
"decision_type": "lte",
"leaf_value": 12.0,
"default_left": true
}]
}
}
]
}
}
}
}
---
"Test put ensemble with empty models":
- do:
catch: /\[trained_models\] must not be empty/
Expand All @@ -192,11 +226,11 @@ setup:
}
}
---
"Test put ensemble with tree where tree has empty feature-names":
"Test put ensemble with tree where tree has out of bounds feature_names index":
- do:
catch: /\[feature_names\] must not be empty/
catch: /feature index \[1\] is out of bounds for the \[feature_names\] array/
ml.put_trained_model:
model_id: "ensemble_tree_missing_feature_names"
model_id: "ensemble_tree_out_of_bounds_feature_names_index"
body: >
{
"input": {
Expand All @@ -213,7 +247,7 @@ setup:
"tree_structure": [
{
"node_index": 0,
"split_feature": 0,
"split_feature": 1,
"split_gain": 12.0,
"threshold": 10.0,
"decision_type": "lte",
Expand Down

0 comments on commit 1fe4fde

Please sign in to comment.