Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/138632.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 138632
summary: Correctly handle empty inputs in `chunkedInfer()`
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,18 @@ public void chunkedInfer(
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
// Services that cannot chunk fail uniformly here, before doChunkedInfer is reached.
if (supportsChunkedInfer()) {
// Empty input short-circuits with an empty result instead of calling the provider.
if (input.isEmpty()) {
chunkedInferListener.onResponse(List.of());
} else {
// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
}
} else {
chunkedInferListener.onFailure(
new UnsupportedOperationException(Strings.format("%s service does not support chunked inference", name()))
);
}
}).addListener(listener);
}

Expand Down Expand Up @@ -183,6 +192,10 @@ protected abstract void doChunkedInfer(
ActionListener<List<ChunkedInference>> listener
);

/**
 * Whether this service supports chunked inference. When this returns {@code false},
 * {@code chunkedInfer} fails the listener with an {@link UnsupportedOperationException}
 * and {@code doChunkedInfer} is never invoked. Defaults to {@code true}; services that
 * do not support chunking override this to return {@code false}.
 */
protected boolean supportsChunkedInfer() {
return true;
}

public void start(Model model, ActionListener<Boolean> listener) {
SubscribableListener.newForked(this::init)
.<Boolean>andThen((doStartListener) -> doStart(model, doStartListener))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,15 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
// Should never be called
throw new UnsupportedOperationException("AI21 service does not support chunked inference");
}

// AI21 does not support chunked inference; the base class rejects the request
// before doChunkedInfer (which would throw) can be reached.
@Override
protected boolean supportsChunkedInfer() {
return false;
}

@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,15 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
// Should never be called
throw new UnsupportedOperationException("Anthropic service does not support chunked inference");
}

// Anthropic does not support chunked inference; the base class rejects the request
// before doChunkedInfer (which would throw) can be reached.
@Override
protected boolean supportsChunkedInfer() {
return false;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_15_0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,15 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
// Should never be called
listener.onFailure(new ElasticsearchStatusException("Chunked inference is not supported for rerank task", RestStatus.BAD_REQUEST));
}

// Chunked inference is not supported for the rerank task; the base class rejects
// the request before doChunkedInfer (which would fail the listener) can be reached.
@Override
protected boolean supportsChunkedInfer() {
return false;
}

@Override
protected void doUnifiedCompletionInfer(
Model model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,15 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
// Should never be called
listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME)));
}

// This service only supports unified completion, so chunked inference is rejected
// in the base class before doChunkedInfer can be reached.
@Override
protected boolean supportsChunkedInfer() {
return false;
}

@Override
public String name() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ public void chunkedInfer(
listener.onFailure(createInvalidModelException(model));
return;
}
// Short-circuit on empty input: answer with an empty result and stop here.
// Without the return, execution falls through to the EmbeddingRequestChunker
// below and the listener would be completed a second time (listeners must be
// completed exactly once).
if (input.isEmpty()) {
    listener.onResponse(List.of());
    return;
}
try {
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
var batchedRequests = new EmbeddingRequestChunker<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -491,6 +492,27 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx
testChunkedInfer(TaskType.SPARSE_EMBEDDING, null);
}

// Empty input to chunkedInfer must resolve immediately with an empty result list.
// The assertion is made while the service (and its sender) is still open, matching
// the structure of the sibling testChunkedInfer_noInputs tests; asserting after the
// try-with-resources closes the service only works by accident of the synchronous
// empty-input path.
public void testChunkedInfer_noInputs() throws IOException {
    var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

    try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
        var model = createModelForTaskType(randomFrom(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING), null);

        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
        service.chunkedInfer(
            model,
            null,
            List.of(),
            new HashMap<>(),
            InputTypeTests.randomWithIngestAndSearch(),
            InferenceAction.Request.DEFAULT_TIMEOUT,
            listener
        );

        assertThat(listener.actionGet(TIMEOUT), empty());
    }
}

private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException {
var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettingsTests.createEmbeddingsRequestSettingsMap;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -1323,6 +1324,50 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
testChunkedInfer(model);
}

// Verifies that chunkedInfer with an empty input list resolves immediately with an
// empty result and never dispatches a request to the (mock) Bedrock sender.
public void testChunkedInfer_noInputs() throws IOException {
var model = AmazonBedrockEmbeddingsModelTests.createModel(
"id",
"region",
"model",
AmazonBedrockProvider.AMAZONTITAN,
null,
"access",
"secret"
);

var sender = createMockSender();
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);

var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory(
ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY),
mockClusterServiceEmpty()
);

try (
var service = new AmazonBedrockService(
factory,
amazonBedrockFactory,
createWithEmptySettings(threadPool),
mockClusterServiceEmpty()
)
) {
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
service.chunkedInfer(
model,
null,
List.of(),
new HashMap<>(),
InputType.INTERNAL_INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

// Empty input must produce an empty result list.
var results = listener.actionGet(TIMEOUT);
assertThat(results, empty());
}
}

private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOException {
var sender = createMockSender();
var factory = mock(HttpRequestSender.Factory.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import static org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRequestFields.API_KEY_HEADER;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -1294,6 +1295,27 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
testChunkedInfer(model);
}

// Verifies that chunkedInfer with an empty input list resolves immediately with an
// empty result and sends no HTTP request to the test web server.
public void testChunkedInfer_noInputs() throws IOException {
var model = AzureAiStudioEmbeddingsModelTests.createModel(
"id",
getUrl(webServer),
AzureAiStudioProvider.OPENAI,
AzureAiStudioEndpointType.TOKEN,
"apikey"
);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
List<ChunkInferenceInput> input = List.of();
service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var results = listener.actionGet(TIMEOUT);
assertThat(results, empty());
// No provider call should have been made for empty input.
assertThat(webServer.requests(), empty());
}
}

private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import static org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -953,6 +954,32 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException, URISyn
testChunkedInfer(model);
}

// Verifies that chunkedInfer with an empty input list resolves immediately with an
// empty result and sends no HTTP request to the test web server.
public void testChunkedInfer_noInputs() throws IOException, URISyntaxException {
var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", null, "apikey", null, "id");

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {

model.setUri(new URI(getUrl(webServer)));
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
List<ChunkInferenceInput> input = List.of();
service.chunkedInfer(
model,
null,
input,
new HashMap<>(),
InputType.INTERNAL_INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

var results = listener.actionGet(TIMEOUT);
assertThat(results, empty());
// No provider call should have been made for empty input.
assertThat(webServer.requests(), empty());
}
}

private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOException, URISyntaxException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -1377,6 +1378,28 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
testChunkedInfer(model);
}

// Verifies that chunkedInfer with an empty input list resolves immediately with an
// empty result and sends no HTTP request to the test web server.
public void testChunkedInfer_noInputs() throws IOException {
var model = CohereEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1024, "model", null);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
service.chunkedInfer(
model,
null,
List.of(),
new HashMap<>(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

var results = listener.actionGet(TIMEOUT);
assertThat(results, empty());
// No provider call should have been made for empty input.
assertThat(webServer.requests(), empty());
}
}

private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE;
import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH;
import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -822,4 +823,29 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
assertThat(requestMap.get("input"), is(List.of("a")));
}
}

// Verifies that chunkedInfer with an empty input list resolves immediately with an
// empty result and sends no HTTP request to the test web server.
public void testChunkedInfer_noInputs() throws IOException {
var model = createInternalEmbeddingModel(
new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer)
);

try (var service = createService(threadPool, clientManager)) {

PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
service.chunkedInfer(
model,
null,
List.of(),
new HashMap<>(),
InputType.INTERNAL_INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

var results = listener.actionGet(TIMEOUT);
assertThat(results, empty());
// No provider call should have been made for empty input.
assertThat(webServer.requests(), empty());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
Expand Down Expand Up @@ -349,12 +351,21 @@ public void testUnifiedCompletionMalformedError() throws Exception {
}}""");
}

public void testDoChunkedInferAlwaysFails() throws IOException {
public void testChunkedInferFails() throws IOException {
try (var service = createService()) {
service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> {
assertThat(e, isA(UnsupportedOperationException.class));
assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion"));
}));
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
service.chunkedInfer(mock(), null, List.of(new ChunkInferenceInput("a")), Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
var exception = expectThrows(UnsupportedOperationException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), is("deepseek service does not support chunked inference"));
}
}

// Even with an empty input list, a service that does not support chunked inference
// must fail with UnsupportedOperationException — the unsupported check takes
// precedence over the empty-input short-circuit.
public void testChunkedInferFails_noInputs() throws IOException {
try (var service = createService()) {
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
service.chunkedInfer(mock(), null, List.of(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
var exception = expectThrows(UnsupportedOperationException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), is("deepseek service does not support chunked inference"));
}
}

Expand Down
Loading