@@ -50,6 +50,8 @@
import org.elasticsearch.inference.telemetry.InferenceStats;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
@@ -92,6 +94,8 @@
*
*/
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
    private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class);

    private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);

    /**
@@ -325,127 +329,131 @@ private void executeChunkedInferenceAsync(
        final Releasable onFinish
    ) {
        if (inferenceProvider == null) {
            ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> {
                var service = inferenceServiceRegistry.getService(unparsedModel.service());
                if (service.isEmpty() == false) {
                    var provider = new InferenceProvider(
                        service.get(),
                        service.get()
                            .parsePersistedConfigWithSecrets(
                                inferenceId,
                                unparsedModel.taskType(),
                                unparsedModel.settings(),
                                unparsedModel.secrets()
                            )
                    );
                    executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
                } else {
                    try (onFinish) {
                        for (FieldInferenceRequest request : requests) {
                            inferenceResults.get(request.bulkItemIndex).failures.add(
                                new ResourceNotFoundException(
                                    "Inference service [{}] not found for field [{}]",
                                    unparsedModel.service(),
                                    request.field
                                )
                            );
                        }
                    }
                }
            }, exc -> {
                try (onFinish) {
                    for (FieldInferenceRequest request : requests) {
                        Exception failure;
                        if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
                            failure = new ResourceNotFoundException(
                                "Inference id [{}] not found for field [{}]",
                                inferenceId,
                                request.field
                            );
                        } else {
                            failure = new InferenceException(
                                "Error loading inference for inference id [{}] on field [{}]",
                                exc,
                                inferenceId,
                                request.field
                            );
                        }
                        inferenceResults.get(request.bulkItemIndex).failures.add(failure);
                    }

                    if (ExceptionsHelper.status(exc).getStatus() >= 500) {
                        List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
                        logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc);
                    }
                }
            });
            modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
            return;
        }
        final List<ChunkInferenceInput> inputs = requests.stream()
            .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings))
            .collect(Collectors.toList());

        ActionListener<List<ChunkedInference>> completionListener = ActionListener.wrap(results -> {
            try (onFinish) {
                var requestsIterator = requests.iterator();
                int success = 0;
                for (ChunkedInference result : results) {
                    var request = requestsIterator.next();
                    var acc = inferenceResults.get(request.bulkItemIndex);
                    if (result instanceof ChunkedInferenceError error) {
                        recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
                        acc.addFailure(
                            new InferenceException(
                                "Exception when running inference id [{}] on field [{}]",
                                error.exception(),
                                inferenceProvider.model.getInferenceEntityId(),
                                request.field
                            )
                        );
                    } else {
                        success++;
                        acc.addOrUpdateResponse(
                            new FieldInferenceResponse(
                                request.field(),
                                request.sourceField(),
                                useLegacyFormat ? request.input() : null,
                                request.inputOrder(),
                                request.offsetAdjustment(),
                                inferenceProvider.model,
                                result
                            )
                        );
                    }
                }
                if (success > 0) {
                    recordRequestCountMetrics(inferenceProvider.model, success, null);
                }
            }
        }, exc -> {
            try (onFinish) {
                recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
                for (FieldInferenceRequest request : requests) {
                    addInferenceResponseFailure(
                        request.bulkItemIndex,
                        new InferenceException(
                            "Exception when running inference id [{}] on field [{}]",
                            exc,
                            inferenceProvider.model.getInferenceEntityId(),
                            request.field
                        )
                    );
                }

                if (ExceptionsHelper.status(exc).getStatus() >= 500) {
                    List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
                    logger.error(
                        "Exception when running inference id ["
                            + inferenceProvider.model.getInferenceEntityId()
                            + "] on fields "
                            + fields,
                        exc
                    );
                }
            }
        });

        inferenceProvider.service()
            .chunkedInfer(
                inferenceProvider.model(),
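The refactor above replaces anonymous `ActionListener<>` implementations with `ActionListener.wrap(onResponse, onFailure)`, which builds the listener from two lambdas and gives the failure path a single place to add the new 5xx logging. The sketch below is a minimal, self-contained illustration of that pattern; the `Listener` interface and its `wrap` helper are hypothetical stand-ins for illustration only, not the Elasticsearch API.

```java
import java.util.function.Consumer;

// Hypothetical stand-in for org.elasticsearch.action.ActionListener, for illustration only.
interface Listener<T> {
    void onResponse(T response);

    void onFailure(Exception e);

    // Builds a listener from two lambdas, mirroring the wrap(onResponse, onFailure) style.
    static <T> Listener<T> wrap(Consumer<T> onResponse, Consumer<Exception> onFailure) {
        return new Listener<>() {
            @Override
            public void onResponse(T response) {
                onResponse.accept(response);
            }

            @Override
            public void onFailure(Exception e) {
                onFailure.accept(e);
            }
        };
    }
}

public class ListenerWrapExample {
    public static void main(String[] args) {
        // Anonymous-class style, as in the code before this change:
        Listener<String> verbose = new Listener<>() {
            @Override
            public void onResponse(String response) {
                System.out.println("got: " + response);
            }

            @Override
            public void onFailure(Exception e) {
                System.err.println("failed: " + e.getMessage());
            }
        };

        // wrap(...) style, as in the code after this change: same behavior, less
        // boilerplate, and a natural place for shared failure handling such as logging.
        Listener<String> compact = Listener.wrap(
            response -> System.out.println("got: " + response),
            e -> System.err.println("failed: " + e.getMessage())
        );

        verbose.onResponse("a");
        compact.onFailure(new RuntimeException("boom"));
    }
}
```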