Skip to content

Commit

Permalink
[ML] Hybrid retrieval for Semantic search. (#91348)
Browse files Browse the repository at this point in the history
Adds the query option to the _semantic_search endpoint for hybrid retrieval. 
Scoring is controlled by the boost fields of the knn search and the query.
  • Loading branch information
davidkyle committed Nov 9, 2022
1 parent 8a4a8ba commit b46ee9c
Show file tree
Hide file tree
Showing 8 changed files with 347 additions and 69 deletions.
77 changes: 56 additions & 21 deletions docs/reference/search/semantic-search.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,6 @@ The resulting dense vector is then used in a <<knn-search,k-nearest neighbor (kn
created with the same text embedding model. The search results are semantically similar as learned
by the model.

////
[source,console]
----
PUT my-index
{
"mappings": {
"properties": {
"text_embedding": {
"type": "dense_vector",
"dims": 512,
"index": true,
"similarity": "cosine"
}
}
}
}
----
////

[source,console]
----
GET my-index/_semantic_search
Expand Down Expand Up @@ -110,15 +91,69 @@ value must be less than `num_candidates`.
shard. Cannot exceed 10,000. {es} collects `num_candidates` results from each
shard, then merges them to find the top `k` results. Increasing
`num_candidates` tends to improve the accuracy of the final `k` results.
====
`filter`::
(Optional, <<query-dsl,Query DSL object>>) Query to filter the documents that
can match. The kNN search will return the top `k` documents that also match
this filter. The value can be a single query or a list of queries. If `filter`
is not provided, all documents are allowed to match.
====

`query`::
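(Optional, <<query-dsl,query object>>) Defines the search definition using the
<<query-dsl,Query DSL>>. When specified together with `knn`, the results of the
query and of the kNN search are combined (hybrid retrieval); the `boost` values
of the query and of the `knn` search control how each contributes to the score.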

`text_embedding_config`::
(Optional, object) Overrides certain settings of the text embedding model's configuration.
.Properties of text_embedding inference
[%collapsible%open]
=====
`results_field`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
`tokenization`::::
(Optional, object)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
+
.Properties of tokenization
[%collapsible%open]
======
`bert`::::
(Optional, object)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert]
+
.Properties of bert
[%collapsible%open]
=======
`truncate`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
=======
`roberta`::::
(Optional, object)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
+
.Properties of roberta
[%collapsible%open]
=======
`truncate`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
=======
`mpnet`::::
(Optional, object)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
+
.Properties of mpnet
[%collapsible%open]
=======
`truncate`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
=======
======
=====

include::{es-repo-dir}/search/search.asciidoc[tag=docvalue-fields-def]
include::{es-repo-dir}/search/search.asciidoc[tag=fields-param-def]
Expand All @@ -129,5 +164,5 @@ include::{es-repo-dir}/search/search.asciidoc[tag=stored-fields-def]
[[semantic-search-api-response-body]]
==== {api-response-body-title}

A sementic search response has the same structure as a kNN search response.
The semantic search response has the same structure as a kNN search response.

Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ private SemanticSearchAction(String name) {
public static class Request extends ActionRequest implements IndicesRequest.Replaceable {

public static final ParseField QUERY_STRING = new ParseField("query_string"); // TODO a better name and update docs when changed
public static final ParseField TEXT_EMBEDDING_CONFIG = new ParseField("text_embedding_config");

static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME);

Expand All @@ -67,15 +68,14 @@ public static class Request extends ActionRequest implements IndicesRequest.Repl
PARSER.declareObject(
Request.Builder::setUpdate,
(p, c) -> TextEmbeddingConfigUpdate.fromXContentStrict(p),
InferTrainedModelDeploymentAction.Request.INFERENCE_CONFIG
TEXT_EMBEDDING_CONFIG
);
PARSER.declareObject(Request.Builder::setKnnSearch, (p, c) -> KnnQueryOptions.fromXContent(p), SearchSourceBuilder.KNN_FIELD);
PARSER.declareFieldArray(
Request.Builder::setFilters,
PARSER.declareObject(
Request.Builder::setQueryBuilder,
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
KnnSearchBuilder.FILTER_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY
SearchSourceBuilder.QUERY_FIELD
);
PARSER.declareObject(Request.Builder::setKnnSearch, (p, c) -> KnnQueryOptions.fromXContent(p), SearchSourceBuilder.KNN_FIELD);
PARSER.declareField(
(p, request, c) -> request.setFetchSource(FetchSourceContext.fromXContent(p)),
SearchSourceBuilder._SOURCE_FIELD,
Expand All @@ -99,16 +99,21 @@ public static class Request extends ActionRequest implements IndicesRequest.Repl
SearchSourceBuilder.STORED_FIELDS_FIELD,
ObjectParser.ValueType.STRING_ARRAY
);
PARSER.declareInt(Request.Builder::setSize, SearchSourceBuilder.SIZE_FIELD);
}

public static Request parseRestRequest(RestRequest restRequest) throws IOException {
Builder builder = new Builder(Strings.splitStringByCommaToArray(restRequest.param("index")));
builder.setRouting(restRequest.param("routing"));
if (restRequest.hasContentOrSourceParam()) {
try (XContentParser contentParser = restRequest.contentOrSourceParamParser()) {
PARSER.parse(contentParser, builder, null);
}
}
// Query parameters are preferred to body parameters.
if (restRequest.hasParam("size")) {
builder.setSize(restRequest.paramAsInt("size", -1));
}
builder.setRouting(restRequest.param("routing"));
return builder.build();
}

Expand All @@ -117,13 +122,14 @@ public static Request parseRestRequest(RestRequest restRequest) throws IOExcepti
private final String queryString;
private final String modelId;
private final TimeValue inferenceTimeout;
private final QueryBuilder query;
private final KnnQueryOptions knnQueryOptions;
private final TextEmbeddingConfigUpdate embeddingConfig;
private final List<QueryBuilder> filters;
private final FetchSourceContext fetchSource;
private final List<FieldAndFormat> fields;
private final List<FieldAndFormat> docValueFields;
private final StoredFieldsContext storedFields;
private final int size;

public Request(StreamInput in) throws IOException {
super(in);
Expand All @@ -132,45 +138,44 @@ public Request(StreamInput in) throws IOException {
queryString = in.readString();
modelId = in.readString();
inferenceTimeout = in.readOptionalTimeValue();
query = in.readOptionalNamedWriteable(QueryBuilder.class);
knnQueryOptions = new KnnQueryOptions(in);
embeddingConfig = in.readOptionalWriteable(TextEmbeddingConfigUpdate::new);
if (in.readBoolean()) {
filters = in.readNamedWriteableList(QueryBuilder.class);
} else {
filters = null;
}
fetchSource = in.readOptionalWriteable(FetchSourceContext::readFrom);
fields = in.readOptionalList(FieldAndFormat::new);
docValueFields = in.readOptionalList(FieldAndFormat::new);
storedFields = in.readOptionalWriteable(StoredFieldsContext::new);
size = in.readInt();
}

Request(
String[] indices,
String routing,
String queryString,
String modelId,
QueryBuilder query,
KnnQueryOptions knnQueryOptions,
TextEmbeddingConfigUpdate embeddingConfig,
TimeValue inferenceTimeout,
List<QueryBuilder> filters,
FetchSourceContext fetchSource,
List<FieldAndFormat> fields,
List<FieldAndFormat> docValueFields,
StoredFieldsContext storedFields
StoredFieldsContext storedFields,
int size
) {
this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
this.routing = routing;
this.queryString = queryString;
this.modelId = modelId;
this.query = query;
this.knnQueryOptions = knnQueryOptions;
this.embeddingConfig = embeddingConfig;
this.inferenceTimeout = inferenceTimeout;
this.filters = filters;
this.fetchSource = fetchSource;
this.fields = fields;
this.docValueFields = docValueFields;
this.storedFields = storedFields;
this.size = size;
}

@Override
Expand All @@ -181,18 +186,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(queryString);
out.writeString(modelId);
out.writeOptionalTimeValue(inferenceTimeout);
out.writeOptionalNamedWriteable(query);
knnQueryOptions.writeTo(out);
out.writeOptionalWriteable(embeddingConfig);
if (filters != null) {
out.writeBoolean(true);
out.writeNamedWriteableList(filters);
} else {
out.writeBoolean(false);
}
out.writeOptionalWriteable(fetchSource);
out.writeOptionalCollection(fields);
out.writeOptionalCollection(docValueFields);
out.writeOptionalWriteable(storedFields);
out.writeInt(size);
}

@Override
Expand Down Expand Up @@ -231,6 +232,10 @@ public TimeValue getInferenceTimeout() {
return inferenceTimeout;
}

/**
 * Returns the query parsed from the request body's query field.
 *
 * @return the query, or {@code null} when none was supplied; the field is
 *         read and written as an optional named writeable
 */
public QueryBuilder getQuery() {
return query;
}

/**
 * Returns the kNN options parsed from the request body.
 *
 * NOTE(review): serialization reads and writes this field unconditionally,
 * so it is presumably expected to be non-null by then - confirm where that is validated.
 *
 * @return the kNN query options of this request
 */
public KnnQueryOptions getKnnQueryOptions() {
return knnQueryOptions;
}
Expand All @@ -239,10 +244,6 @@ public TextEmbeddingConfigUpdate getEmbeddingConfig() {
return embeddingConfig;
}

public List<QueryBuilder> getFilters() {
return filters;
}

/**
 * Returns the source fetching context parsed from the request body.
 *
 * @return the fetch source context, or {@code null} when none was supplied;
 *         the field is read and written as an optional writeable
 */
public FetchSourceContext getFetchSource() {
return fetchSource;
}
Expand All @@ -259,6 +260,10 @@ public StoredFieldsContext getStoredFields() {
return storedFields;
}

/**
 * Returns the {@code size} value of this request.
 *
 * The builder initializes the value to {@code -1}; it is overwritten by the
 * size field of the request body and then by the {@code size} URL parameter
 * when either is present.
 * NOTE(review): presumably {@code -1} means "not set" and the consumer applies
 * a default - confirm against the transport action.
 *
 * @return the configured size, or {@code -1} if it was never set
 */
public int getSize() {
return size;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand All @@ -269,13 +274,14 @@ public boolean equals(Object o) {
&& Objects.equals(queryString, request.queryString)
&& Objects.equals(modelId, request.modelId)
&& Objects.equals(inferenceTimeout, request.inferenceTimeout)
&& Objects.equals(query, request.query)
&& Objects.equals(knnQueryOptions, request.knnQueryOptions)
&& Objects.equals(embeddingConfig, request.embeddingConfig)
&& Objects.equals(filters, request.filters)
&& Objects.equals(fetchSource, request.fetchSource)
&& Objects.equals(fields, request.fields)
&& Objects.equals(docValueFields, request.docValueFields)
&& Objects.equals(storedFields, request.storedFields);
&& Objects.equals(storedFields, request.storedFields)
&& size == request.size;
}

@Override
Expand All @@ -285,13 +291,14 @@ public int hashCode() {
queryString,
modelId,
inferenceTimeout,
query,
knnQueryOptions,
embeddingConfig,
filters,
fetchSource,
fields,
docValueFields,
storedFields
storedFields,
size
);
result = 31 * result + Arrays.hashCode(indices);
return result;
Expand Down Expand Up @@ -321,12 +328,13 @@ public static class Builder {
private String queryString;
private TimeValue timeout;
private TextEmbeddingConfigUpdate update;
private QueryBuilder queryBuilder;
private KnnQueryOptions knnSearchBuilder;
private List<QueryBuilder> filters;
private FetchSourceContext fetchSource;
private List<FieldAndFormat> fields;
private List<FieldAndFormat> docValueFields;
private StoredFieldsContext storedFields;
private int size = -1;

Builder(String[] indices) {
this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
Expand Down Expand Up @@ -360,8 +368,8 @@ void setKnnSearch(KnnQueryOptions knnSearchBuilder) {
this.knnSearchBuilder = knnSearchBuilder;
}

private void setFilters(List<QueryBuilder> filters) {
this.filters = filters;
void setQueryBuilder(QueryBuilder queryBuilder) {
this.queryBuilder = queryBuilder;
}

private void setFetchSource(FetchSourceContext fetchSource) {
Expand All @@ -380,20 +388,25 @@ private void setStoredFields(StoredFieldsContext storedFields) {
this.storedFields = storedFields;
}

/**
 * Sets the {@code size} value that is passed to the built request.
 *
 * Called by the object parser for the size field of the request body and by
 * {@code parseRestRequest} when the {@code size} URL parameter is present.
 *
 * @param size the size to use for the request
 */
private void setSize(int size) {
this.size = size;
}

Request build() {
return new Request(
indices,
routing,
queryString,
modelId,
queryBuilder,
knnSearchBuilder,
update,
timeout,
filters,
fetchSource,
fields,
docValueFields,
storedFields
storedFields,
size
);
}
}
Expand Down Expand Up @@ -528,7 +541,12 @@ public KnnSearchBuilder toKnnSearchBuilder(float[] queryVector) {
if (queryVector == null) {
throw new IllegalStateException("[query_vector] not set on the Knn query");
}
return new KnnSearchBuilder(field, queryVector, k, numCands);
var builder = new KnnSearchBuilder(field, queryVector, k, numCands);
builder.boost(boost);
if (filterQueries.isEmpty() == false) {
builder.addFilterQueries(filterQueries);
}
return builder;
}

@Override
Expand Down

0 comments on commit b46ee9c

Please sign in to comment.