Skip to content

Commit

Permalink
Add support neural query type
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill_Ostanin authored and Kirill_Ostanin committed Oct 18, 2023
1 parent 978a021 commit 0195446
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

import javax.annotation.Nullable;
import java.util.function.Function;

@JsonpDeserializable
public class NeuralQuery extends QueryBase implements QueryVariant {

private final String field;
private final String queryText;
private final int k;
@Nullable
private final String modelId;


private NeuralQuery(NeuralQuery.Builder builder) {
super(builder);

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.modelId = builder.modelId;
}

public static NeuralQuery of(Function<NeuralQuery.Builder, ObjectBuilder<NeuralQuery>> fn) {
return fn.apply(new NeuralQuery.Builder()).build();
}

/**
* Query variant kind.
*
* @return The query variant kind.
*/
@Override
public Query.Kind _queryKind() {
return null;
}

/**
* Required - The target field.
*
* @return The target field.
*/
public final String field() {
return this.field;
}

/**
* Required - Search query text.
*
* @return Search query text.
*/
public final String queryText() {
return this.queryText;
}

/**
* Required - The number of neighbors to return.
*
* @return The number of neighbors to return.
*/
public final int k() {
return this.k;
}

/**
* Builder for {@link NeuralQuery}.
*/

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
*
* @return The model_id field.
*/
@Nullable
public final String modelId() {
return this.modelId;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject(this.field);

super.serializeInternal(generator, mapper);

generator.write("query_text", this.queryText);

if (this.modelId != null) {
generator.write("model_id", this.modelId);
}

generator.write("k", this.k);

generator.writeEnd();
}

/**
* Builder for {@link NeuralQuery}.
*/
public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builder> implements ObjectBuilder<NeuralQuery> {
@Nullable
private String field;
@Nullable
private String queryText;
@Nullable
private String modelId;
@Nullable
private Integer k;

/**
* Required - The target field.
*
* @param field The target field.
* @return This builder.
*/
public NeuralQuery.Builder field(@Nullable String field) {
this.field = field;
return this;
}

/**
* Required - Search query text.
*
* @param queryText Search query text.
* @return This builder.
*/
public NeuralQuery.Builder queryText(@Nullable String queryText) {
this.queryText = queryText;
return this;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
*
* @param modelId The model_id field.
* @return This builder.
*/
public NeuralQuery.Builder modelId(@Nullable String modelId) {
this.modelId = modelId;
return this;
}

/**
* Required - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
*/
public NeuralQuery.Builder k(@Nullable Integer k) {
this.k = k;
return this;
}

@Override
protected NeuralQuery.Builder self() {
return this;
}

/**
* Builds a {@link NeuralQuery}.
*
* @return The built {@link NeuralQuery}.
*/
@Override
public NeuralQuery build() {
_checkSingleUse();

return new NeuralQuery(this);
}
}

public static final JsonpDeserializer<NeuralQuery> _DESERIALIZER = ObjectBuilderDeserializer.lazy(
NeuralQuery.Builder::new,
NeuralQuery::setupNeuralQueryDeserializer
);

protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuery.Builder> op) {
setupQueryBaseDeserializer(op);

op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");

op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ public enum Kind implements JsonEnum {

Nested("nested"),

Neural("neural"),

ParentId("parent_id"),

Percolate("percolate"),
Expand Down Expand Up @@ -706,6 +708,23 @@ public NestedQuery nested() {
return TaggedUnionUtils.get(this, Kind.Nested);
}

/**
* Is this variant instance of kind {@code neural}?
*/
public boolean isNeural() {
return _kind == Kind.Neural;
}

/**
* Get the {@code neural} variant value.
*
* @throws IllegalStateException
* if the current variant is not of the {@code neural} kind.
*/
public NeuralQuery neural() {
return TaggedUnionUtils.get(this, Kind.Neural);
}

/**
* Is this variant instance of kind {@code parent_id}?
*/
Expand Down Expand Up @@ -1450,6 +1469,16 @@ public ObjectBuilder<Query> nested(Function<NestedQuery.Builder, ObjectBuilder<N
return this.nested(fn.apply(new NestedQuery.Builder()).build());
}

public ObjectBuilder<Query> neural(NeuralQuery v) {
this._kind = Kind.Neural;
this._value = v;
return this;
}

public ObjectBuilder<Query> neural(Function<NeuralQuery.Builder, ObjectBuilder<NeuralQuery>> fn) {
return this.neural(fn.apply(new NeuralQuery.Builder()).build());
}

public ObjectBuilder<Query> parentId(ParentIdQuery v) {
this._kind = Kind.ParentId;
this._value = v;
Expand Down Expand Up @@ -1747,6 +1776,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::moreLikeThis, MoreLikeThisQuery._DESERIALIZER, "more_like_this");
op.add(Builder::multiMatch, MultiMatchQuery._DESERIALIZER, "multi_match");
op.add(Builder::nested, NestedQuery._DESERIALIZER, "nested");
op.add(Builder::neural, NeuralQuery._DESERIALIZER, "neural");
op.add(Builder::parentId, ParentIdQuery._DESERIALIZER, "parent_id");
op.add(Builder::percolate, PercolateQuery._DESERIALIZER, "percolate");
op.add(Builder::pinned, PinnedQuery._DESERIALIZER, "pinned");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ public static NestedQuery.Builder nested() {
return new NestedQuery.Builder();
}

/**
* Creates a builder for the {@link NeuralQuery nested} {@code Query} variant.
*/
public static NeuralQuery.Builder neural() {
return new NeuralQuery.Builder();
}

/**
* Creates a builder for the {@link ParentIdQuery parent_id} {@code Query}
* variant.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,48 @@ public void testNestedVariantsWithContainerProperties() {
assertEquals("m1 value", search.aggregations().get("agg1").meta().get("m1").to(String.class));
assertEquals("m2 value", search.aggregations().get("agg1").meta().get("m2").to(String.class));
}

@Test
public void testNeuralQuery() {

SearchRequest searchRequest = SearchRequest.of(
s -> s.query(
q -> q.neural(n -> n.field("passage_embedding")
.queryText("Hi world")
.modelId("bQ1J8ooBpBj3wT4HVUsb")
.k(100)
)
)
);

assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world", searchRequest.query().neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}

@Test
public void testNeuralQueryFromJson() {

String json = "{\n" +
" \"from\": 0,\n" +
" \"size\": 100,\n" +
" \"query\": {\n" +
" \"neural\": {\n" +
" \"passage_embedding\": {\n" +
" \"query_text\": \"Hi world\",\n" +
" \"model_id\": \"bQ1J8ooBpBj3wT4HVUsb\",\n" +
" \"k\": 100\n" +
" }\n" +
" }\n" +
" }\n" +
"}";

SearchRequest searchRequest = ModelTestCase.fromJson(json, SearchRequest.class, mapper);

assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world", searchRequest.query().neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}
}

0 comments on commit 0195446

Please sign in to comment.