Skip to content

Commit

Permalink
Add support for approximate k-NN queries (#548)
Browse files Browse the repository at this point in the history
* Add Knn query type

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

* Integration test

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

* Checkstyle fix

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

* Run unreleased test

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

* Fixes

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

* Assume knn plugin installed

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

* Changelog

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

* Add to QueryBuilders

Signed-off-by: Thomas Farr <tsfarr@amazon.com>

---------

Signed-off-by: Thomas Farr <tsfarr@amazon.com>
Signed-off-by: Thomas Farr <xtansia@xtansia.com>
(cherry picked from commit f427f27)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Jul 5, 2023
1 parent b540c98 commit 8feb0c8
Show file tree
Hide file tree
Showing 10 changed files with 404 additions and 2 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased]
### Added
- Add support for knn_vector field type ([#529](https://github.com/opensearch-project/opensearch-java/pull/524))
- Add support for knn_vector field type ([#524](https://github.com/opensearch-project/opensearch-java/pull/524))
- Add translog option object and missing translog sync interval option in index settings ([#518](https://github.com/opensearch-project/opensearch-java/pull/518))
- Adds the option to set slices=auto for UpdateByQueryRequest, DeleteByQueryRequest and ReindexRequest ([#538](https://github.com/opensearch-project/opensearch-java/pull/538))
- Add support for approximate k-NN queries ([#548](https://github.com/opensearch-project/opensearch-java/pull/548))

### Dependencies
- Bumps `com.github.jk1.dependency-license-report` from 2.2 to 2.4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ public final Builder parameters(@Nullable Map<String, JsonData> map) {
return this;
}

/**
* API name: {@code parameters}
*/
public final Builder parameters(String key, JsonData value) {
this.parameters = _mapPut(this.parameters, key, value);
return this;
}

/**
* Builds a {@link KnnVectorMethod}.
*
Expand Down Expand Up @@ -194,10 +202,12 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

if (this.parameters != null) {
generator.writeKey("parameters");
generator.writeStartObject();
for (Map.Entry<String, JsonData> item0 : this.parameters.entrySet()) {
generator.writeKey(item0.getKey());
item0.getValue().serialize(generator, mapper);
}
generator.writeEnd();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ public final Builder method(@Nullable KnnVectorMethod value) {
return this;
}

/**
* API name: {@code method}
*/
public final Builder method(Function<KnnVectorMethod.Builder, ObjectBuilder<KnnVectorMethod>> fn) {
return this.method(fn.apply(new KnnVectorMethod.Builder()).build());
}

/**
* Builds a {@link KnnVectorProperty}.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

@JsonpDeserializable
public class KnnQuery extends QueryBase implements QueryVariant {
private final String field;
private final float[] vector;
private final int k;
@Nullable
private final Query filter;

private KnnQuery(Builder builder) {
super(builder);

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.vector = ApiTypeHelper.requireNonNull(builder.vector, this, "vector");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.filter = builder.filter;
}

public static KnnQuery of(Function<Builder, ObjectBuilder<KnnQuery>> fn) {
return fn.apply(new Builder()).build();
}

/**
* Query variant kind.
* @return The query variant kind.
*/
@Override
public Query.Kind _queryKind() {
return Query.Kind.Knn;
}

/**
* Required - The target field.
* @return The target field.
*/
public final String field() {
return this.field;
}

/**
* Required - The vector to search for.
* @return The vector to search for.
*/
public final float[] vector() {
return this.vector;
}

/**
* Required - The number of neighbors the search of each graph will return.
* @return The number of neighbors to return.
*/
public final int k() {
return this.k;
}

/**
* Optional - A query to filter the results of the query.
* @return The filter query.
*/
@Nullable
public final Query filter() {
return this.filter;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject(this.field);

super.serializeInternal(generator, mapper);

// TODO: Implement the rest of the serialization.

generator.writeKey("vector");
generator.writeStartArray();
for (float value : this.vector) {
generator.write(value);
}
generator.writeEnd();

generator.write("k", this.k);

if (this.filter != null) {
generator.writeKey("filter");
this.filter.serialize(generator, mapper);
}

generator.writeEnd();
}

/**
* Builder for {@link KnnQuery}.
*/
public static class Builder extends QueryBase.AbstractBuilder<Builder> implements ObjectBuilder<KnnQuery> {
@Nullable
private String field;
@Nullable
private float[] vector;
@Nullable
private Integer k;
@Nullable
private Query filter;

/**
* Required - The target field.
* @param field The target field.
* @return This builder.
*/
public Builder field(@Nullable String field) {
this.field = field;
return this;
}

/**
* Required - The vector to search for.
*
* @param vector The vector to search for.
* @return This builder.
*/
public Builder vector(@Nullable float[] vector) {
this.vector = vector;
return this;
}

/**
* Required - The number of neighbors the search of each graph will return.
*
* @param k The number of neighbors to return.
* @return This builder.
*/
public Builder k(@Nullable Integer k) {
this.k = k;
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
* @param filter The filter query.
* @return This builder.
*/
public Builder filter(@Nullable Query filter) {
this.filter = filter;
return this;
}

@Override
protected Builder self() {
return this;
}

/**
* Builds a {@link KnnQuery}.
*
* @return The built {@link KnnQuery}.
*/
@Override
public KnnQuery build() {
_checkSingleUse();

return new KnnQuery(this);
}
}

public static final JsonpDeserializer<KnnQuery> _DESERIALIZER = ObjectBuilderDeserializer
.lazy(Builder::new, KnnQuery::setupKnnQueryDeserializer);

protected static void setupKnnQueryDeserializer(ObjectDeserializer<Builder> op) {
setupQueryBaseDeserializer(op);
op.add((b, v) -> {
float[] vector = new float[v.size()];
int i = 0;
for (Float value : v) {
vector[i++] = value;
}
b.vector(vector);
}, JsonpDeserializer.arrayDeserializer(JsonpDeserializer.floatDeserializer()), "vector");
op.add(Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(Builder::field, JsonpDeserializer.stringDeserializer());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ public enum Kind implements JsonEnum {

Intervals("intervals"),

Knn("knn"),

Match("match"),

MatchAll("match_all"),
Expand Down Expand Up @@ -535,6 +537,23 @@ public IntervalsQuery intervals() {
return TaggedUnionUtils.get(this, Kind.Intervals);
}

/**
* Is this variant instance of kind {@code knn}?
*/
public boolean isKnn() {
return _kind == Kind.Knn;
}

/**
* Get the {@code knn} variant value.
*
* @throws IllegalStateException
* if the current variant is not of the {@code knn} kind.
*/
public KnnQuery knn() {
return TaggedUnionUtils.get(this, Kind.Knn);
}

/**
* Is this variant instance of kind {@code match}?
*/
Expand Down Expand Up @@ -1340,6 +1359,16 @@ public ObjectBuilder<Query> intervals(Function<IntervalsQuery.Builder, ObjectBui
return this.intervals(fn.apply(new IntervalsQuery.Builder()).build());
}

public ObjectBuilder<Query> knn(KnnQuery v) {
this._kind = Kind.Knn;
this._value = v;
return this;
}

public ObjectBuilder<Query> knn(Function<KnnQuery.Builder, ObjectBuilder<KnnQuery>> fn) {
return this.knn(fn.apply(new KnnQuery.Builder()).build());
}

public ObjectBuilder<Query> match(MatchQuery v) {
this._kind = Kind.Match;
this._value = v;
Expand Down Expand Up @@ -1728,6 +1757,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::hasParent, HasParentQuery._DESERIALIZER, "has_parent");
op.add(Builder::ids, IdsQuery._DESERIALIZER, "ids");
op.add(Builder::intervals, IntervalsQuery._DESERIALIZER, "intervals");
op.add(Builder::knn, KnnQuery._DESERIALIZER, "knn");
op.add(Builder::match, MatchQuery._DESERIALIZER, "match");
op.add(Builder::matchAll, MatchAllQuery._DESERIALIZER, "match_all");
op.add(Builder::matchBoolPrefix, MatchBoolPrefixQuery._DESERIALIZER, "match_bool_prefix");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ public static IntervalsQuery.Builder intervals() {
return new IntervalsQuery.Builder();
}

/**
* Creates a builder for the {@link KnnQuery knn} {@code Query} variant.
*/
public static KnnQuery.Builder knn() {
return new KnnQuery.Builder();
}

/**
* Creates a builder for the {@link MatchQuery match} {@code Query} variant.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

}
if (this.knnAlgoParamEfSearch != null) {
generator.writeKey("knn.algo_param_ef_search");
generator.writeKey("knn.algo_param.ef_search");
generator.write(this.knnAlgoParamEfSearch);

}
Expand Down
Loading

0 comments on commit 8feb0c8

Please sign in to comment.