Skip to content

Commit

Permalink
[ML][Inference] add tags url param to GET (#51330)
Browse files Browse the repository at this point in the history
Adds a new URL parameter, `tags` to the GET _ml/inference/<model_id> endpoint.

This parameter allows the list of models to be further reduced to those who contain all the provided tags.
  • Loading branch information
benwtrent committed Jan 24, 2020
1 parent c64f4d1 commit c9e285c
Show file tree
Hide file tree
Showing 16 changed files with 177 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,9 @@ static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest)
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
}
if (getTrainedModelsRequest.getTags() != null) {
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
}
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.addParameters(params.asMap());
return request;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ValidationException;
import org.elasticsearch.client.core.PageParams;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.common.Nullable;

import java.util.Arrays;
Expand All @@ -34,12 +35,14 @@ public class GetTrainedModelsRequest implements Validatable {
public static final String ALLOW_NO_MATCH = "allow_no_match";
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String TAGS = "tags";

private final List<String> ids;
private Boolean allowNoMatch;
private Boolean includeDefinition;
private Boolean decompressDefinition;
private PageParams pageParams;
private List<String> tags;

/**
* Helper method to create a request that will get ALL TrainedModelConfigs
Expand Down Expand Up @@ -111,6 +114,29 @@ public GetTrainedModelsRequest setDecompressDefinition(Boolean decompressDefinit
return this;
}

public List<String> getTags() {
return tags;
}

/**
* The tags that the trained model must match. These correspond to {@link TrainedModelConfig#getTags()}.
*
* The models returned will match ALL tags supplied.
* If none are provided, only the provided ids are used to find models
* @param tags The tags to match when finding models
*/
public GetTrainedModelsRequest setTags(List<String> tags) {
this.tags = tags;
return this;
}

/**
* See {@link GetTrainedModelsRequest#setTags(List)}
*/
public GetTrainedModelsRequest setTags(String... tags) {
return setTags(Arrays.asList(tags));
}

@Override
public Optional<ValidationException> validate() {
if (ids == null || ids.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ public void testGetTrainedModels() {
.setAllowNoMatch(false)
.setDecompressDefinition(true)
.setIncludeDefinition(false)
.setTags("tag1", "tag2")
.setPageParams(new PageParams(100, 300));

Request request = MLRequestConverters.getTrainedModels(getRequest);
Expand All @@ -845,6 +846,7 @@ public void testGetTrainedModels() {
hasEntry("size", "300"),
hasEntry("allow_no_match", "false"),
hasEntry("decompress_definition", "true"),
hasEntry("tags", "tag1,tag2"),
hasEntry("include_model_definition", "false")
));
assertNull(request.getEntity());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3587,8 +3587,10 @@ public void testGetTrainedModels() throws Exception {
.setPageParams(new PageParams(0, 1)) // <2>
.setIncludeDefinition(false) // <3>
.setDecompressDefinition(false) // <4>
.setAllowNoMatch(true); // <5>
.setAllowNoMatch(true) // <5>
.setTags("regression"); // <6>
// end::get-trained-models-request
request.setTags((List<String>)null);

// tag::get-trained-models-execute
GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);
Expand Down
3 changes: 3 additions & 0 deletions docs/java-rest/high-level/ml/get-trained-models.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ include-tagged::{doc-tests-file}[{api}-request]
<5> Allow empty response if no Trained Models match the provided ID patterns.
If false, an error will be thrown if no Trained Models match the
ID patterns.
<6> An optional list of tags used to narrow the model search. A Trained Model
can have many tags or none. The trained models in the response will
contain all the provided tags.

include::../execution.asciidoc[]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=include-model-definition]
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=size]

`tags`::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]

[[ml-get-inference-response-codes]]
==== {api-response-codes-title}
Expand All @@ -97,4 +100,4 @@ The following example gets configuration information for all the trained models:
--------------------------------------------------
GET _ml/inference/
--------------------------------------------------
// TEST[skip:TBD]
// TEST[skip:TBD]
6 changes: 6 additions & 0 deletions docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,12 @@ to `false`. When `true`, only a single model must match the ID patterns
provided, otherwise a bad request is returned.
end::include-model-definition[]

tag::tags[]
A comma delimited string of tags. A {infer} model can have many tags, or none.
When supplied, only {infer} models that contain all the supplied tags are
returned.
end::tags[]

tag::indices[]
An array of index names. Wildcards are supported. For example:
`["it_ops_metrics", "server*"]`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -33,18 +34,26 @@ public static class Request extends AbstractGetResourcesRequest {

public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField TAGS = new ParseField("tags");

private final boolean includeModelDefinition;
private final List<String> tags;

public Request(String id, boolean includeModelDefinition) {
public Request(String id, boolean includeModelDefinition, List<String> tags) {
setResourceId(id);
setAllowNoResources(true);
this.includeModelDefinition = includeModelDefinition;
this.tags = tags == null ? Collections.emptyList() : tags;
}

public Request(StreamInput in) throws IOException {
super(in);
this.includeModelDefinition = in.readBoolean();
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
this.tags = in.readStringList();
} else {
this.tags = Collections.emptyList();
}
}

@Override
Expand All @@ -56,15 +65,22 @@ public boolean isIncludeModelDefinition() {
return includeModelDefinition;
}

public List<String> getTags() {
return tags;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(includeModelDefinition);
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeStringCollection(tags);
}
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), includeModelDefinition);
return Objects.hash(super.hashCode(), includeModelDefinition, tags);
}

@Override
Expand All @@ -76,7 +92,7 @@ public boolean equals(Object obj) {
return false;
}
Request other = (Request) obj;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas

@Override
protected Request createTestInstance() {
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
Request request = new Request(randomAlphaOfLength(20),
randomBoolean(),
randomBoolean() ? null :
randomList(10, () -> randomAlphaOfLength(10)));
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
return request;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;


Expand Down Expand Up @@ -70,7 +71,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
listener::onFailure
);

provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener);
provider.expandIds(request.getResourceId(),
request.isAllowNoResources(),
request.getPageParams(),
new HashSet<>(request.getTags()),
idExpansionListener);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -94,7 +95,11 @@ protected void doExecute(Task task,
listener::onFailure
);

trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener);
trainedModelProvider.expandIds(request.getResourceId(),
request.isAllowNoResources(),
request.getPageParams(),
Collections.emptySet(),
idsListener);
}

static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
Expand Down Expand Up @@ -381,14 +382,15 @@ public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener)
public void expandIds(String idExpression,
boolean allowNoResources,
@Nullable PageParams pageParams,
Set<String> tags,
ActionListener<Tuple<Long, Set<String>>> idsListener) {
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
.sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
// If there are no resources, there might be no mapping for the id field.
// This makes sure we don't get an error if that happens.
.unmappedType("long"))
.query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
.query(buildExpandIdsQuery(tokens, tags));
if (pageParams != null) {
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
}
Expand All @@ -404,13 +406,23 @@ public void expandIds(String idExpression,
indicesOptions.expandWildcardsClosed(),
indicesOptions))
.source(sourceBuilder);
Set<String> foundResourceIds = new LinkedHashSet<>();
if (tags.isEmpty()) {
foundResourceIds.addAll(matchedResourceIds(tokens));
} else {
for(String resourceId : matchedResourceIds(tokens)) {
// Does the model as a resource have all the tags?
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
foundResourceIds.add(resourceId);
}
}
}

executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(
response -> {
Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSource = hit.getSourceAsMap();
Expand All @@ -433,7 +445,15 @@ public void expandIds(String idExpression,
idsListener::onFailure
),
client::search);
}

static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
.filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
for(String tag : tags) {
boolQueryBuilder.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), tag));
}
return QueryBuilders.constantScoreQuery(boolQueryBuilder);
}

TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
Expand Down Expand Up @@ -467,7 +487,7 @@ TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefiniti
}
}

private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import org.elasticsearch.xpack.ml.MachineLearning;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;

import static org.elasticsearch.rest.RestRequest.Method.GET;
Expand Down Expand Up @@ -47,7 +49,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
false
);
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition);
List<String> tags = Arrays.asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,24 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;

import java.util.Arrays;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.oneOf;
import static org.mockito.Mockito.mock;

public class TrainedModelProviderTests extends ESTestCase {
Expand Down Expand Up @@ -60,6 +68,24 @@ public void testGetModelThatExistsAsResource() throws Exception {
}
}

public void testExpandIdsQuery() {
QueryBuilder queryBuilder = TrainedModelProvider.buildExpandIdsQuery(new String[]{"model*", "trained_mode"},
Arrays.asList("tag1", "tag2"));
assertThat(queryBuilder, is(instanceOf(ConstantScoreQueryBuilder.class)));

QueryBuilder innerQuery = ((ConstantScoreQueryBuilder)queryBuilder).innerQuery();
assertThat(innerQuery, is(instanceOf(BoolQueryBuilder.class)));

((BoolQueryBuilder)innerQuery).filter().forEach(qb -> {
if (qb instanceof TermQueryBuilder) {
assertThat(((TermQueryBuilder)qb).fieldName(), equalTo(TrainedModelConfig.TAGS.getPreferredName()));
assertThat(((TermQueryBuilder)qb).value(), is(oneOf("tag1", "tag2")));
return;
}
assertThat(qb, is(instanceOf(BoolQueryBuilder.class)));
});
}

public void testGetModelThatExistsAsResourceButIsMissing() {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
Expand Down
Loading

0 comments on commit c9e285c

Please sign in to comment.