From 2b88ff8b950e33ddf262c2dd6feafb7e35d1e5b1 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 24 Nov 2025 13:14:52 -0500 Subject: [PATCH] Use Wrapped Action Listeners in ShardBulkInferenceActionFilter (#138505) --- .../ShardBulkInferenceActionFilter.java | 202 +++++++++--------- 1 file changed, 105 insertions(+), 97 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index aef7b9089992b..a0be296f286b8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -50,6 +50,8 @@ import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.XContent; @@ -92,6 +94,8 @@ * */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1); /** @@ -325,127 +329,131 @@ private void executeChunkedInferenceAsync( final Releasable onFinish ) { if (inferenceProvider == null) { - ActionListener modelLoadingListener = new ActionListener<>() { - @Override - public void onResponse(UnparsedModel unparsedModel) { - var service = inferenceServiceRegistry.getService(unparsedModel.service()); - if (service.isEmpty() == false) { - var provider = new InferenceProvider( - service.get(), - service.get() - .parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ) - ); - executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); - } else { - try (onFinish) { - for (FieldInferenceRequest request : requests) { - inferenceResults.get(request.bulkItemIndex).failures.add( - new ResourceNotFoundException( - "Inference service [{}] not found for field [{}]", - unparsedModel.service(), - request.field - ) - ); - } - } - } - } - - @Override - public void onFailure(Exception exc) { + ActionListener modelLoadingListener = ActionListener.wrap(unparsedModel -> { + var service = inferenceServiceRegistry.getService(unparsedModel.service()); + if (service.isEmpty() == false) { + var provider = new InferenceProvider( + service.get(), + service.get() + .parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ) + ); + executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); + } else { try (onFinish) { for (FieldInferenceRequest request : requests) { - Exception failure; - if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) { - failure = new ResourceNotFoundException( - "Inference id [{}] not found for field [{}]", - inferenceId, + inferenceResults.get(request.bulkItemIndex).failures.add( + new ResourceNotFoundException( + "Inference service [{}] not found for field [{}]", + unparsedModel.service(), request.field - ); - } else { - failure = new InferenceException( - "Error loading inference for inference id [{}] on field [{}]", - exc, - inferenceId, - request.field - ); - } - inferenceResults.get(request.bulkItemIndex).failures.add(failure); + ) + ); } } } - }; - modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); - return; - } - final List inputs = requests.stream() - .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) - .collect(Collectors.toList()); - - ActionListener> completionListener = new ActionListener<>() { - - @Override - public void onResponse(List results) { + }, exc -> { try (onFinish) { - var requestsIterator = requests.iterator(); - int success = 0; - for (ChunkedInference result : results) { - var request = requestsIterator.next(); - var acc = inferenceResults.get(request.bulkItemIndex); - if (result instanceof ChunkedInferenceError error) { - recordRequestCountMetrics(inferenceProvider.model, 1, error.exception()); - acc.addFailure( - new InferenceException( - "Exception when running inference id [{}] on field [{}]", - error.exception(), - inferenceProvider.model.getInferenceEntityId(), - request.field - ) + for (FieldInferenceRequest request : requests) { + Exception failure; + if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) { + failure = new ResourceNotFoundException( + "Inference id [{}] not found for field [{}]", + inferenceId, + request.field ); } else { - success++; - acc.addOrUpdateResponse( - new FieldInferenceResponse( - request.field(), - request.sourceField(), - useLegacyFormat ? request.input() : null, - request.inputOrder(), - request.offsetAdjustment(), - inferenceProvider.model, - result - ) + failure = new InferenceException( + "Error loading inference for inference id [{}] on field [{}]", + exc, + inferenceId, + request.field ); } + inferenceResults.get(request.bulkItemIndex).failures.add(failure); } - if (success > 0) { - recordRequestCountMetrics(inferenceProvider.model, success, null); + + if (ExceptionsHelper.status(exc).getStatus() >= 500) { + List fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList(); + logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc); } } - } + }); + modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); + return; + } + final List inputs = requests.stream() + .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) + .collect(Collectors.toList()); - @Override - public void onFailure(Exception exc) { - try (onFinish) { - recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc); - for (FieldInferenceRequest request : requests) { - addInferenceResponseFailure( - request.bulkItemIndex, + ActionListener> completionListener = ActionListener.wrap(results -> { + try (onFinish) { + var requestsIterator = requests.iterator(); + int success = 0; + for (ChunkedInference result : results) { + var request = requestsIterator.next(); + var acc = inferenceResults.get(request.bulkItemIndex); + if (result instanceof ChunkedInferenceError error) { + recordRequestCountMetrics(inferenceProvider.model, 1, error.exception()); + acc.addFailure( new InferenceException( "Exception when running inference id [{}] on field [{}]", - exc, + error.exception(), inferenceProvider.model.getInferenceEntityId(), request.field ) ); + } else { + success++; + acc.addOrUpdateResponse( + new FieldInferenceResponse( + request.field(), + request.sourceField(), + useLegacyFormat ? request.input() : null, + request.inputOrder(), + request.offsetAdjustment(), + inferenceProvider.model, + result + ) + ); } } + if (success > 0) { + recordRequestCountMetrics(inferenceProvider.model, success, null); + } } - }; + }, exc -> { + try (onFinish) { + recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc); + for (FieldInferenceRequest request : requests) { + addInferenceResponseFailure( + request.bulkItemIndex, + new InferenceException( + "Exception when running inference id [{}] on field [{}]", + exc, + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } + + if (ExceptionsHelper.status(exc).getStatus() >= 500) { + List fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList(); + logger.error( + "Exception when running inference id [" + + inferenceProvider.model.getInferenceEntityId() + + "] on fields " + + fields, + exc + ); + } + } + }); + inferenceProvider.service() .chunkedInfer( inferenceProvider.model(),