diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e71d603c..d097a218a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Added support for "script_fields" in multi search request ([#632](https://github.com/opensearch-project/opensearch-java/pull/632)) - Added size attribute to MultiTermsAggregation ([#627](https://github.com/opensearch-project/opensearch-java/pull/627)) - Added version increment workflow that executes after release ([#664](https://github.com/opensearch-project/opensearch-java/pull/664)) +- Added support for neural query type ([#674](https://github.com/opensearch-project/opensearch-java/pull/674)) ### Dependencies - Bumps `org.ajoberstar.grgit:grgit-gradle` from 5.0.0 to 5.2.0 diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java new file mode 100644 index 000000000..083200fe8 --- /dev/null +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java @@ -0,0 +1,200 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.client.opensearch._types.query_dsl; + +import jakarta.json.stream.JsonGenerator; +import java.util.function.Function; +import javax.annotation.Nullable; +import org.opensearch.client.json.JsonpDeserializable; +import org.opensearch.client.json.JsonpDeserializer; +import org.opensearch.client.json.JsonpMapper; +import org.opensearch.client.json.ObjectBuilderDeserializer; +import org.opensearch.client.json.ObjectDeserializer; +import org.opensearch.client.util.ApiTypeHelper; +import org.opensearch.client.util.ObjectBuilder; + +@JsonpDeserializable +public class NeuralQuery extends QueryBase implements QueryVariant { + + private final String field; + private final String queryText; + private final int k; + @Nullable + private final String modelId; + + private NeuralQuery(NeuralQuery.Builder builder) { + super(builder); + + this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field"); + this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText"); + this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k"); + this.modelId = builder.modelId; + } + + public static NeuralQuery of(Function> fn) { + return fn.apply(new NeuralQuery.Builder()).build(); + } + + /** + * Query variant kind. + * + * @return The query variant kind. + */ + @Override + public Query.Kind _queryKind() { + return Query.Kind.Neural; + } + + /** + * Required - The target field. + * + * @return The target field. + */ + public final String field() { + return this.field; + } + + /** + * Required - Search query text. + * + * @return Search query text. + */ + public final String queryText() { + return this.queryText; + } + + /** + * Required - The number of neighbors to return. + * + * @return The number of neighbors to return. + */ + public final int k() { + return this.k; + } + + /** + * Builder for {@link NeuralQuery}. + */ + + /** + * Optional - The model_id field if the default model for the index or field is set. + * Required - The model_id field if there is no default model set for the index or field. + * + * @return The model_id field. + */ + @Nullable + public final String modelId() { + return this.modelId; + } + + @Override + protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { + generator.writeStartObject(this.field); + + super.serializeInternal(generator, mapper); + + generator.write("query_text", this.queryText); + + if (this.modelId != null) { + generator.write("model_id", this.modelId); + } + + generator.write("k", this.k); + + generator.writeEnd(); + } + + /** + * Builder for {@link NeuralQuery}. + */ + public static class Builder extends QueryBase.AbstractBuilder implements ObjectBuilder { + private String field; + private String queryText; + private Integer k; + @Nullable + private String modelId; + + /** + * Required - The target field. + * + * @param field The target field. + * @return This builder. + */ + public NeuralQuery.Builder field(@Nullable String field) { + this.field = field; + return this; + } + + /** + * Required - Search query text. + * + * @param queryText Search query text. + * @return This builder. + */ + public NeuralQuery.Builder queryText(@Nullable String queryText) { + this.queryText = queryText; + return this; + } + + /** + * Optional - The model_id field if the default model for the index or field is set. + * Required - The model_id field if there is no default model set for the index or field. + * + * @param modelId The model_id field. + * @return This builder. + */ + public NeuralQuery.Builder modelId(@Nullable String modelId) { + this.modelId = modelId; + return this; + } + + /** + * Required - The number of neighbors to return. + * + * @param k The number of neighbors to return. + * @return This builder. + */ + public NeuralQuery.Builder k(@Nullable Integer k) { + this.k = k; + return this; + } + + @Override + protected NeuralQuery.Builder self() { + return this; + } + + /** + * Builds a {@link NeuralQuery}. + * + * @return The built {@link NeuralQuery}. + */ + @Override + public NeuralQuery build() { + _checkSingleUse(); + + return new NeuralQuery(this); + } + } + + public static final JsonpDeserializer _DESERIALIZER = ObjectBuilderDeserializer.lazy( + NeuralQuery.Builder::new, + NeuralQuery::setupNeuralQueryDeserializer + ); + + protected static void setupNeuralQueryDeserializer(ObjectDeserializer op) { + setupQueryBaseDeserializer(op); + + op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text"); + op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id"); + op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k"); + + op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer()); + } +} diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/Query.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/Query.java index b1d39da0a..7413df359 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/Query.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/Query.java @@ -118,6 +118,8 @@ public enum Kind implements JsonEnum { Nested("nested"), + Neural("neural"), + ParentId("parent_id"), Percolate("percolate"), @@ -706,6 +708,23 @@ public NestedQuery nested() { return TaggedUnionUtils.get(this, Kind.Nested); } + /** + * Is this variant instance of kind {@code neural}? + */ + public boolean isNeural() { + return _kind == Kind.Neural; + } + + /** + * Get the {@code neural} variant value. + * + * @throws IllegalStateException + * if the current variant is not of the {@code neural} kind. + */ + public NeuralQuery neural() { + return TaggedUnionUtils.get(this, Kind.Neural); + } + /** * Is this variant instance of kind {@code parent_id}? */ @@ -1450,6 +1469,16 @@ public ObjectBuilder nested(Function neural(NeuralQuery v) { + this._kind = Kind.Neural; + this._value = v; + return this; + } + + public ObjectBuilder neural(Function> fn) { + return this.neural(fn.apply(new NeuralQuery.Builder()).build()); + } + public ObjectBuilder parentId(ParentIdQuery v) { this._kind = Kind.ParentId; this._value = v; @@ -1747,6 +1776,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer op) { op.add(Builder::moreLikeThis, MoreLikeThisQuery._DESERIALIZER, "more_like_this"); op.add(Builder::multiMatch, MultiMatchQuery._DESERIALIZER, "multi_match"); op.add(Builder::nested, NestedQuery._DESERIALIZER, "nested"); + op.add(Builder::neural, NeuralQuery._DESERIALIZER, "neural"); op.add(Builder::parentId, ParentIdQuery._DESERIALIZER, "parent_id"); op.add(Builder::percolate, PercolateQuery._DESERIALIZER, "percolate"); op.add(Builder::pinned, PinnedQuery._DESERIALIZER, "pinned"); diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/QueryBuilders.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/QueryBuilders.java index dc6db5051..f165ac706 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/QueryBuilders.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/QueryBuilders.java @@ -254,6 +254,13 @@ public static NestedQuery.Builder nested() { return new NestedQuery.Builder(); } + /** + * Creates a builder for the {@link NeuralQuery nested} {@code Query} variant. + */ + public static NeuralQuery.Builder neural() { + return new NeuralQuery.Builder(); + } + /** * Creates a builder for the {@link ParentIdQuery parent_id} {@code Query} * variant. diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java index b511f4ec1..e504050c1 100644 --- a/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java +++ b/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java @@ -205,4 +205,42 @@ public void testNestedVariantsWithContainerProperties() { assertEquals("m1 value", search.aggregations().get("agg1").meta().get("m1").to(String.class)); assertEquals("m2 value", search.aggregations().get("agg1").meta().get("m2").to(String.class)); } + + @Test + public void testNeuralQuery() { + + SearchRequest searchRequest = SearchRequest.of( + s -> s.query(q -> q.neural(n -> n.field("passage_embedding").queryText("Hi world").modelId("bQ1J8ooBpBj3wT4HVUsb").k(100))) + ); + + assertEquals("passage_embedding", searchRequest.query().neural().field()); + assertEquals("Hi world", searchRequest.query().neural().queryText()); + assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId()); + assertEquals(100, searchRequest.query().neural().k()); + } + + @Test + public void testNeuralQueryFromJson() { + + String json = "{\n" + + " \"from\": 0,\n" + + " \"size\": 100,\n" + + " \"query\": {\n" + + " \"neural\": {\n" + + " \"passage_embedding\": {\n" + + " \"query_text\": \"Hi world!\",\n" + + " \"model_id\": \"bQ1J8ooBpBj3wT4HVUsb\",\n" + + " \"k\": 100\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + SearchRequest searchRequest = ModelTestCase.fromJson(json, SearchRequest.class, mapper); + + assertEquals("passage_embedding", searchRequest.query().neural().field()); + assertEquals("Hi world!", searchRequest.query().neural().queryText()); + assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId()); + assertEquals(100, searchRequest.query().neural().k()); + } }