From bef0079740173bc564c7ffc1ae33ed64253ce91e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 10 Oct 2025 17:26:53 -0400 Subject: [PATCH 01/18] Adding bulk storage of multiple models --- .../inference/registry/ModelRegistry.java | 303 ++++++++++------ .../registry/ModelRegistryMetadata.java | 25 ++ ...nferenceServiceAuthorizationHandlerV2.java | 339 ++++++++++++++++++ 3 files changed, 553 insertions(+), 114 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 7cd1cf5999d11..8858522df20d0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -18,9 +18,9 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.GroupedActionListener; @@ -45,6 +45,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.gateway.GatewayService; @@ -654,154 +655,197 @@ public void storeModel(Model model, ActionListener listener, TimeValue } private void storeModel(Model model, boolean updateClusterState, ActionListener listener, TimeValue timeout) { - ActionListener bulkResponseActionListener = getStoreIndexListener(model, updateClusterState, listener, timeout); + storeModels(List.of(model), updateClusterState, listener.delegateFailureAndWrap((delegate, responses) -> { + var firstFailureResponse = responses.stream().filter(ModelResponse::failed).findFirst(); + if (firstFailureResponse.isPresent() == false) { + delegate.onResponse(Boolean.TRUE); + return; + } - IndexRequest configRequest = createIndexRequest( - Model.documentId(model.getConfigurations().getInferenceEntityId()), - InferenceIndex.INDEX_NAME, - model.getConfigurations(), - false - ); + var failureItem = firstFailureResponse.get(); + if (ExceptionsHelper.unwrapCause(failureItem.failureCause()) instanceof VersionConflictEngineException) { + // TODO do we want to include the cause? 
+ delegate.onFailure( + new ElasticsearchStatusException( + "Inference endpoint [{}] already exists", + RestStatus.BAD_REQUEST, + failureItem.failureCause, + failureItem.inferenceId + ) + ); + return; + } - IndexRequest secretsRequest = createIndexRequest( - Model.documentId(model.getConfigurations().getInferenceEntityId()), - InferenceSecretsIndex.INDEX_NAME, - model.getSecrets(), - false - ); + delegate.onFailure( + new ElasticsearchStatusException( + format("Failed to store inference endpoint [%s]", failureItem.inferenceId), + RestStatus.INTERNAL_SERVER_ERROR, + failureItem.failureCause() + ) + ); + }), timeout); + } - client.prepareBulk() - .add(configRequest) - .add(secretsRequest) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .execute(bulkResponseActionListener); + // TODO rename + public record ModelResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) { + public boolean failed() { + return failureCause != null; + } } - private ActionListener getStoreIndexListener( - Model model, + public void storeModels( + List models, boolean updateClusterState, - ActionListener listener, + ActionListener> listener, TimeValue timeout ) { - return ActionListener.wrap(bulkItemResponses -> { - var inferenceEntityId = model.getConfigurations().getInferenceEntityId(); + var bulkRequestBuilder = client.prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (var model : models) { + bulkRequestBuilder.add( + createIndexRequestBuilder(model.getInferenceEntityId(), InferenceIndex.INDEX_NAME, model.getConfigurations(), false) + ); + + bulkRequestBuilder.add( + createIndexRequestBuilder(model.getInferenceEntityId(), InferenceSecretsIndex.INDEX_NAME, model.getSecrets(), false) + ); + } + bulkRequestBuilder.execute(getStoreMultipleModelsListener(models, updateClusterState, listener, timeout)); + } + + private ActionListener getStoreMultipleModelsListener( + List models, + boolean updateClusterState, + ActionListener> listener, + TimeValue timeout + ) { + var docIdToInferenceId = models.stream() + .collect(Collectors.toMap(m -> Model.documentId(m.getInferenceEntityId()), Model::getInferenceEntityId)); + var inferenceIdToModel = models.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity())); + + return ActionListener.wrap(bulkItemResponses -> { + var inferenceEntityIds = String.join(", ", models.stream().map(Model::getInferenceEntityId).toList()); if (bulkItemResponses.getItems().length == 0) { - logger.warn( - format("Storing inference endpoint [%s] failed, no items were received from the bulk response", inferenceEntityId) - ); + logger.warn("Storing inference endpoints [{}] failed, no items were received from the bulk response", inferenceEntityIds); listener.onFailure( new ElasticsearchStatusException( - format( - "Failed to store inference endpoint [%s], invalid bulk response received. 
Try reinitializing the service", - inferenceEntityId - ), - RestStatus.INTERNAL_SERVER_ERROR + "Failed to store inference endpoints [{}], empty bulk response received.", + RestStatus.INTERNAL_SERVER_ERROR, + inferenceEntityIds ) ); return; } - BulkItemResponse.Failure failure = getFirstBulkFailure(bulkItemResponses); - - if (failure == null) { - if (updateClusterState) { - var storeListener = getStoreMetadataListener(inferenceEntityId, listener); - try { - metadataTaskQueue.submitTask( - "add model [" + inferenceEntityId + "]", - new AddModelMetadataTask( - ProjectId.DEFAULT, - inferenceEntityId, - new MinimalServiceSettings(model), - storeListener - ), - timeout - ); - } catch (Exception exc) { - storeListener.onFailure(exc); - } - } else { - listener.onResponse(Boolean.TRUE); - } - return; - } + var responseInfo = getResponseInfo(bulkItemResponses, docIdToInferenceId, inferenceIdToModel); - logBulkFailures(model.getConfigurations().getInferenceEntityId(), bulkItemResponses); - - if (ExceptionsHelper.unwrapCause(failure.getCause()) instanceof VersionConflictEngineException) { - listener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId)); - return; + if (updateClusterState) { + updateClusterState( + responseInfo.successfullyStoredModels, + listener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.response)), + timeout + ); + } else { + listener.onResponse(responseInfo.response); } - - listener.onFailure( - new ElasticsearchStatusException( - format("Failed to store inference endpoint [%s]", inferenceEntityId), - RestStatus.INTERNAL_SERVER_ERROR, - failure.getCause() - ) - ); }, e -> { - String errorMessage = format("Failed to store inference endpoint [%s]", model.getConfigurations().getInferenceEntityId()); + String errorMessage = format( + "Failed to store inference endpoints [%s]", + models.stream().map(Model::getInferenceEntityId).collect(Collectors.joining(", ")) + ); logger.warn(errorMessage, e); listener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e)); }); } - private ActionListener getStoreMetadataListener(String inferenceEntityId, ActionListener listener) { - return new ActionListener<>() { - @Override - public void onResponse(AcknowledgedResponse resp) { - listener.onResponse(true); - } + // TODO rename + private record ResponseInfo(List response, List successfullyStoredModels) {} - @Override - public void onFailure(Exception exc) { - logger.warn( - format("Failed to add inference endpoint [%s] minimal service settings to cluster state", inferenceEntityId), - exc - ); - deleteModel(inferenceEntityId, ActionListener.running(() -> { - listener.onFailure( - new ElasticsearchStatusException( - format( - "Failed to add the inference endpoint [%s]. The service may be in an " - + "inconsistent state. 
Please try deleting and re-adding the endpoint.", - inferenceEntityId - ), - RestStatus.INTERNAL_SERVER_ERROR, - exc - ) - ); - })); + // TODO rename + private static ResponseInfo getResponseInfo( + BulkResponse bulkResponse, + Map docIdToInferenceId, + Map inferenceIdToModel + ) { + var response = new ArrayList(); + var modelsSuccessfullyStored = new ArrayList(); + + for (var item : bulkResponse.getItems()) { + var failure = item.getFailure(); + + String inferenceIdOrUnknown = "unknown"; + var inferenceIdMaybeNull = docIdToInferenceId.get(item.getId()); + if (inferenceIdMaybeNull == null) { + logger.warn("Failed to find inference id for document id [{}]", item.getId()); + } else { + inferenceIdOrUnknown = inferenceIdMaybeNull; } - }; - } - private static void logBulkFailures(String inferenceEntityId, BulkResponse bulkResponse) { - for (BulkItemResponse item : bulkResponse.getItems()) { - if (item.isFailed()) { + if (item.isFailed() && failure != null) { + response.add(new ModelResponse(inferenceIdOrUnknown, item.status(), failure.getCause())); logger.warn( format( - "Failed to store inference endpoint [%s] index: [%s] bulk failure message [%s]", - inferenceEntityId, + "Failed to store document id: [%s] inference id: [%s] index: [%s] bulk failure message [%s]", + item.getId(), + inferenceIdOrUnknown, item.getIndex(), item.getFailureMessage() ) ); + } else { + response.add(new ModelResponse(inferenceIdOrUnknown, item.status(), null)); + + if (inferenceIdMaybeNull != null) { + var modelForResponseItem = inferenceIdToModel.get(inferenceIdMaybeNull); + if (modelForResponseItem != null) { + modelsSuccessfullyStored.add(modelForResponseItem); + } + } } } + + return new ResponseInfo(response, modelsSuccessfullyStored); } - private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkResponse) { - for (BulkItemResponse item : bulkResponse.getItems()) { - if (item.isFailed()) { - return item.getFailure(); - } - } + private void updateClusterState(List models, ActionListener listener, TimeValue timeout) { + var inferenceIdsSet = models.stream().map(Model::getInferenceEntityId).collect(Collectors.toSet()); + var storeListener = listener.delegateResponse((delegate, exc) -> { + logger.warn(format("Failed to add inference endpoint %s minimal service settings to cluster state", inferenceIdsSet), exc); + deleteModels( + inferenceIdsSet, + ActionListener.running( + () -> delegate.onFailure( + new ElasticsearchStatusException( + format( + "Failed to add the inference endpoints %s. The service may be in an " + + "inconsistent state. 
Please try deleting and re-adding the endpoint.", + inferenceIdsSet + ), + RestStatus.INTERNAL_SERVER_ERROR, + exc + ) + ) + ) + ); + }); - return null; + try { + metadataTaskQueue.submitTask( + format("add models %s", inferenceIdsSet), + new AddModelMetadataTask( + ProjectId.DEFAULT, + models.stream() + .map(model -> new ModelAndSettings(model.getInferenceEntityId(), new MinimalServiceSettings(model))) + .toList(), + storeListener + ), + timeout + ); + } catch (Exception exc) { + storeListener.onFailure(exc); + } } public synchronized void removeDefaultConfigs(Set inferenceEntityIds, ActionListener listener) { @@ -922,6 +966,32 @@ private static DeleteByQueryRequest createDeleteRequest(Set inferenceEnt return request; } + private IndexRequestBuilder createIndexRequestBuilder( + String inferenceId, + String indexName, + ToXContentObject body, + boolean allowOverwriting + ) { + try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) { + XContentBuilder source = body.toXContent( + xContentBuilder, + new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString())) + ); + + return new IndexRequestBuilder(client).setIndex(indexName) + .setCreate(allowOverwriting == false) + .setId(Model.documentId(inferenceId)) + .setSource(source); + } catch (IOException ex) { + throw new ElasticsearchException( + "Unexpected serialization exception for index [{}] inference ID [{}]", + ex, + indexName, + inferenceId + ); + } + } + private static IndexRequest createIndexRequest(String docId, String indexName, ToXContentObject body, boolean allowOverwriting) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { var request = new IndexRequest(indexName); @@ -1082,9 +1152,10 @@ ModelRegistryMetadata executeTask(ModelRegistryMetadata current) { } } + public record ModelAndSettings(String inferenceEntityId, MinimalServiceSettings settings) {} + private static class AddModelMetadataTask extends MetadataTask { - private final String inferenceEntityId; - private final MinimalServiceSettings settings; + private final List models = new ArrayList<>(); AddModelMetadataTask( ProjectId projectId, @@ -1093,13 +1164,17 @@ private static class AddModelMetadataTask extends MetadataTask { ActionListener listener ) { super(projectId, listener); - this.inferenceEntityId = inferenceEntityId; - this.settings = settings; + this.models.add(new ModelAndSettings(inferenceEntityId, settings)); + } + + AddModelMetadataTask(ProjectId projectId, List models, ActionListener listener) { + super(projectId, listener); + this.models.addAll(models); } @Override ModelRegistryMetadata executeTask(ModelRegistryMetadata current) { - return current.withAddedModel(inferenceEntityId, settings); + return current.withAddedModels(models); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java index 7f67b876bcad5..6b13f8a2e30cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java @@ -107,6 +107,31 @@ public ModelRegistryMetadata withAddedModel(String inferenceEntityId, MinimalSer return new ModelRegistryMetadata(settingsBuilder.build(), newTombstone); } + public ModelRegistryMetadata withAddedModels(List models) { + var modifiedMap = false; + 
ImmutableOpenMap.Builder settingsBuilder = ImmutableOpenMap.builder(modelMap); + + for (var model : models) { + if (model.settings().equals(modelMap.get(model.inferenceEntityId())) == false) { + modifiedMap = true; + + settingsBuilder.fPut(model.inferenceEntityId(), model.settings()); + } + } + + if (modifiedMap == false) { + return this; + } + + if (isUpgraded) { + return new ModelRegistryMetadata(settingsBuilder.build()); + } + + var newTombstone = new HashSet<>(tombstones); + models.forEach(existing -> newTombstone.remove(existing.inferenceEntityId())); + return new ModelRegistryMetadata(settingsBuilder.build(), newTombstone); + } + public ModelRegistryMetadata withRemovedModel(Set inferenceEntityIds) { var mapBuilder = ImmutableOpenMap.builder(modelMap); for (var toDelete : inferenceEntityIds) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java new file mode 100644 index 0000000000000..59501fdc45170 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java @@ -0,0 +1,339 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.threadpool.Scheduler; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeSet; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + +public class ElasticInferenceServiceAuthorizationHandlerV2 implements Closeable { + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandlerV2.class); + + private record AuthorizedContent( + ElasticInferenceServiceAuthorizationModel taskTypesAndModels, + List configIds, + List defaultModelConfigs + ) { + static AuthorizedContent empty() { 
+ return new AuthorizedContent(ElasticInferenceServiceAuthorizationModel.newDisabledService(), List.of(), List.of()); + } + } + + private final ServiceComponents serviceComponents; + private final AtomicReference authorizedContent = new AtomicReference<>(AuthorizedContent.empty()); + private final ModelRegistry modelRegistry; + private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; + private final Map defaultModelsConfigs; + private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); + private final EnumSet implementedTaskTypes; + private final InferenceService inferenceService; + private final Sender sender; + private final Runnable callback; + private final AtomicReference lastAuthTask = new AtomicReference<>(null); + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; + + public ElasticInferenceServiceAuthorizationHandlerV2( + ServiceComponents serviceComponents, + ModelRegistry modelRegistry, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + Map defaultModelsConfigs, + EnumSet implementedTaskTypes, + InferenceService inferenceService, + Sender sender, + ElasticInferenceServiceSettings elasticInferenceServiceSettings + ) { + this( + serviceComponents, + modelRegistry, + authorizationRequestHandler, + defaultModelsConfigs, + implementedTaskTypes, + Objects.requireNonNull(inferenceService), + sender, + elasticInferenceServiceSettings, + null + ); + } + + // default for testing + ElasticInferenceServiceAuthorizationHandlerV2( + ServiceComponents serviceComponents, + ModelRegistry modelRegistry, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + Map defaultModelsConfigs, + EnumSet implementedTaskTypes, + InferenceService inferenceService, + Sender sender, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + // this is a hack to facilitate testing + Runnable callback + ) { + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.modelRegistry = Objects.requireNonNull(modelRegistry); + this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); + this.defaultModelsConfigs = Objects.requireNonNull(defaultModelsConfigs); + this.implementedTaskTypes = Objects.requireNonNull(implementedTaskTypes); + // allow the service to be null for testing + this.inferenceService = inferenceService; + this.sender = Objects.requireNonNull(sender); + this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); + this.callback = callback; + } + + /** + * Initializes the authorization handler by scheduling the first authorization request. + */ + public void init() { + logger.debug("Initializing authorization logic"); + serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); + } + + /** + * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. 
+ * @param waitTime the max time to wait + * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} + */ + public void waitForAuthorizationToComplete(TimeValue waitTime) { + try { + if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { + throw new IllegalStateException("The wait time has expired for authorization to complete."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Waiting for authorization to complete was interrupted"); + } + } + + public synchronized Set supportedStreamingTasks() { + var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); + authorizedStreamingTaskTypes.retainAll(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()); + + return authorizedStreamingTaskTypes; + } + + public synchronized List defaultConfigIds() { + return authorizedContent.get().configIds; + } + + public synchronized void defaultConfigs(ActionListener> defaultsListener) { + var models = authorizedContent.get().defaultModelConfigs.stream().map(DefaultModelConfig::model).toList(); + defaultsListener.onResponse(models); + } + + public synchronized EnumSet supportedTaskTypes() { + return authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes(); + } + + public synchronized boolean hideFromConfigurationApi() { + return authorizedContent.get().taskTypesAndModels.isAuthorized() == false; + } + + @Override + public void close() throws IOException { + shutdown.set(true); + if (lastAuthTask.get() != null) { + lastAuthTask.get().cancel(); + } + } + + private void scheduleAuthorizationRequest() { + try { + if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { + return; + } + + // this call has to be on the individual thread otherwise we get an exception + var random = Randomness.get(); + var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble()); + var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); + + logger.debug( + () -> Strings.format( + "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", + elasticInferenceServiceSettings.getAuthRequestInterval().millis(), + jitter + ) + ); + logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); + + lastAuthTask.set( + serviceComponents.threadPool() + .schedule( + this::scheduleAndSendAuthorizationRequest, + waitTime, + serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) + ) + ); + } catch (Exception e) { + logger.warn("Failed scheduling authorization request", e); + } + } + + private void scheduleAndSendAuthorizationRequest() { + if (shutdown.get()) { + return; + } + + scheduleAuthorizationRequest(); + sendAuthorizationRequest(); + } + + private void sendAuthorizationRequest() { + try { + ActionListener listener = ActionListener.wrap((model) -> { + setAuthorizedContent(model); + if (callback != null) { + callback.run(); + } + }, e -> { + // we don't need to do anything if there was a failure, everything is disabled by default + firstAuthorizationCompletedLatch.countDown(); + }); + + authorizationHandler.getAuthorization(listener, sender); + } catch (Exception e) { + logger.warn("Failure while sending the request to retrieve authorization", e); + // we don't need to do anything if there was a failure, everything is disabled by default + 
firstAuthorizationCompletedLatch.countDown(); + } + } + + private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { + logger.debug(() -> Strings.format("Received authorization response, %s", auth)); + + var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); + logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels)); + + // recalculate which default config ids and models are authorized now + var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels); + + var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels); + var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); + authorizedContent.set( + new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) + ); + + authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); + handleRevokedDefaultConfigs(authorizedDefaultModelIds); + } + + private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) { + var authorizedModels = auth.getAuthorizedModelIds(); + var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet()); + authorizedDefaultModelIds.retainAll(authorizedModels); + + return authorizedDefaultModelIds; + } + + private List getAuthorizedDefaultConfigIds( + Set authorizedDefaultModelIds, + ElasticInferenceServiceAuthorizationModel auth + ) { + var authorizedConfigIds = new ArrayList(); + for (var id : authorizedDefaultModelIds) { + var modelConfig = defaultModelsConfigs.get(id); + if (modelConfig != null) { + if (auth.getAuthorizedTaskTypes().contains(modelConfig.model().getTaskType()) == false) { + logger.warn( + Strings.format( + "The authorization response included the default model: %s, " + + "but did not authorize the assumed task type of the model: %s. 
Enabling model.", + id, + modelConfig.model().getTaskType() + ) + ); + } + authorizedConfigIds.add( + new InferenceService.DefaultConfigId( + modelConfig.model().getInferenceEntityId(), + modelConfig.settings(), + inferenceService + ) + ); + } + } + + authorizedConfigIds.sort(Comparator.comparing(InferenceService.DefaultConfigId::inferenceId)); + return authorizedConfigIds; + } + + private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { + var authorizedModels = new ArrayList(); + for (var id : authorizedDefaultModelIds) { + var modelConfig = defaultModelsConfigs.get(id); + if (modelConfig != null) { + authorizedModels.add(modelConfig); + } + } + + authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId())); + return authorizedModels; + } + + private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { + // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked + var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); + unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); + + // get all the default inference endpoint ids for the unauthorized model ids + var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() + .map(defaultModelsConfigs::get) // get all the model configs + .filter(Objects::nonNull) // limit to only non-null + .map(modelConfig -> modelConfig.model().getInferenceEntityId()) // get the inference ids + .collect(Collectors.toSet()); + + var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { + logger.debug(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); + firstAuthorizationCompletedLatch.countDown(); + }, e -> { + logger.warn( + Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) + ); + firstAuthorizationCompletedLatch.countDown(); + }); + + logger.debug( + () -> Strings.format( + "Synchronizing default inference endpoints, attempting to remove ids: %s", + unauthorizedDefaultInferenceEndpointIds + ) + ); + modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); + } +} From feb96ba0b52d2df98a435adf63ec31eb435499ac Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 14 Oct 2025 15:47:46 -0400 Subject: [PATCH 02/18] Adding tests --- .../inference/registry/ModelRegistry.java | 183 +++++++++------- .../registry/ModelRegistryTests.java | 204 +++++++++++++++++- 2 files changed, 303 insertions(+), 84 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 8858522df20d0..c4b41b874f8fb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -17,9 +17,8 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.index.IndexRequest; import 
org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; @@ -532,11 +531,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi SubscribableListener.newForked((subListener) -> { // in this block, we try to update the stored model configurations - IndexRequest configRequest = createIndexRequest( - Model.documentId(inferenceEntityId), + var requestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceIndex.INDEX_NAME, newModel.getConfigurations(), - true + true, + client ); ActionListener storeConfigListener = subListener.delegateResponse((l, e) -> { @@ -545,7 +545,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); - client.prepareBulk().add(configRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener); + client.prepareBulk().add(requestBuilder).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener); }).andThen((subListener, configResponse) -> { // in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets @@ -570,11 +570,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi ); } else { // Since the model configurations were successfully updated, we can now try to store the new secrets - IndexRequest secretsRequest = createIndexRequest( - Model.documentId(newModel.getConfigurations().getInferenceEntityId()), + var requestBuilder = createIndexRequestBuilder( + newModel.getConfigurations().getInferenceEntityId(), InferenceSecretsIndex.INDEX_NAME, newModel.getSecrets(), - true + true, + client ); ActionListener storeSecretsListener = subListener.delegateResponse((l, e) -> { @@ -584,7 +585,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi }); client.prepareBulk() - .add(secretsRequest) + .add(requestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(storeSecretsListener); } @@ -592,12 +593,14 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi // in this block, we respond to the success or failure of updating the model secrets if (secretsResponse.hasFailures()) { // since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations - IndexRequest configRequest = createIndexRequest( - Model.documentId(inferenceEntityId), + var requestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceIndex.INDEX_NAME, existingModel.getConfigurations(), - true + true, + client ); + logger.error( "Failed to update inference endpoint secrets [{}], attempting rolling back to previous state", inferenceEntityId @@ -609,7 +612,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); client.prepareBulk() - .add(configRequest) + .add(requestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(rollbackConfigListener); } else { @@ -656,7 +659,7 @@ public void storeModel(Model model, ActionListener listener, TimeValue private void storeModel(Model model, boolean updateClusterState, ActionListener listener, TimeValue timeout) { storeModels(List.of(model), updateClusterState, listener.delegateFailureAndWrap((delegate, responses) -> { - var firstFailureResponse = responses.stream().filter(ModelResponse::failed).findFirst(); + var firstFailureResponse = 
responses.stream().filter(ModelStoreResponse::failed).findFirst(); if (firstFailureResponse.isPresent() == false) { delegate.onResponse(Boolean.TRUE); return; @@ -664,7 +667,6 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< var failureItem = firstFailureResponse.get(); if (ExceptionsHelper.unwrapCause(failureItem.failureCause()) instanceof VersionConflictEngineException) { - // TODO do we want to include the cause? delegate.onFailure( new ElasticsearchStatusException( "Inference endpoint [{}] already exists", @@ -686,28 +688,31 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< }), timeout); } - // TODO rename - public record ModelResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) { + public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) { public boolean failed() { return failureCause != null; } } - public void storeModels( + public void storeModels(List models, ActionListener> listener, TimeValue timeout) { + storeModels(models, true, listener, timeout); + } + + private void storeModels( List models, boolean updateClusterState, - ActionListener> listener, + ActionListener> listener, TimeValue timeout ) { var bulkRequestBuilder = client.prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (var model : models) { bulkRequestBuilder.add( - createIndexRequestBuilder(model.getInferenceEntityId(), InferenceIndex.INDEX_NAME, model.getConfigurations(), false) + createIndexRequestBuilder(model.getInferenceEntityId(), InferenceIndex.INDEX_NAME, model.getConfigurations(), false, client) ); bulkRequestBuilder.add( - createIndexRequestBuilder(model.getInferenceEntityId(), InferenceSecretsIndex.INDEX_NAME, model.getSecrets(), false) + createIndexRequestBuilder(model.getInferenceEntityId(), InferenceSecretsIndex.INDEX_NAME, model.getSecrets(), false, client) ); } @@ -717,14 +722,15 @@ public void storeModels( private ActionListener getStoreMultipleModelsListener( List models, boolean updateClusterState, - ActionListener> listener, + ActionListener> listener, TimeValue timeout ) { - var docIdToInferenceId = models.stream() - .collect(Collectors.toMap(m -> Model.documentId(m.getInferenceEntityId()), Model::getInferenceEntityId)); - var inferenceIdToModel = models.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity())); - return ActionListener.wrap(bulkItemResponses -> { + var docIdToInferenceId = models.stream() + .collect(Collectors.toMap(m -> Model.documentId(m.getInferenceEntityId()), Model::getInferenceEntityId, (id1, id2) -> id1)); + var inferenceIdToModel = models.stream() + .collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity(), (id1, id2) -> id1)); + var inferenceEntityIds = String.join(", ", models.stream().map(Model::getInferenceEntityId).toList()); if (bulkItemResponses.getItems().length == 0) { logger.warn("Storing inference endpoints [{}] failed, no items were received from the bulk response", inferenceEntityIds); @@ -743,12 +749,12 @@ private ActionListener getStoreMultipleModelsListener( if (updateClusterState) { updateClusterState( - responseInfo.successfullyStoredModels, - listener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.response)), + responseInfo.v2(), + listener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.v1())), timeout ); } else { - listener.onResponse(responseInfo.response); + 
listener.onResponse(responseInfo.v1()); } }, e -> { String errorMessage = format( @@ -760,53 +766,81 @@ private ActionListener getStoreMultipleModelsListener( }); } - // TODO rename - private record ResponseInfo(List response, List successfullyStoredModels) {} - - // TODO rename - private static ResponseInfo getResponseInfo( + private static Tuple, List> getResponseInfo( BulkResponse bulkResponse, Map docIdToInferenceId, Map inferenceIdToModel ) { - var response = new ArrayList(); + var responses = new ArrayList(); var modelsSuccessfullyStored = new ArrayList(); - for (var item : bulkResponse.getItems()) { - var failure = item.getFailure(); - - String inferenceIdOrUnknown = "unknown"; - var inferenceIdMaybeNull = docIdToInferenceId.get(item.getId()); - if (inferenceIdMaybeNull == null) { - logger.warn("Failed to find inference id for document id [{}]", item.getId()); - } else { - inferenceIdOrUnknown = inferenceIdMaybeNull; + var bulkItems = bulkResponse.getItems(); + for (int i = 0; i < bulkItems.length; i += 2) { + var configurationItem = bulkItems[i]; + var configStoreResponse = createModelStoreResponse(configurationItem, docIdToInferenceId); + var modelFromBulkItem = getModelFromMap(docIdToInferenceId.get(configurationItem.getId()), inferenceIdToModel); + + if (i + 1 >= bulkResponse.getItems().length) { + logger.error("Expected an even number of bulk response items, got [{}]", bulkResponse.getItems().length); + responses.add(configStoreResponse); + if (configStoreResponse.failed() == false && modelFromBulkItem != null) { + modelsSuccessfullyStored.add(modelFromBulkItem); + } + return new Tuple<>(responses, modelsSuccessfullyStored); } - if (item.isFailed() && failure != null) { - response.add(new ModelResponse(inferenceIdOrUnknown, item.status(), failure.getCause())); - logger.warn( - format( - "Failed to store document id: [%s] inference id: [%s] index: [%s] bulk failure message [%s]", - item.getId(), - inferenceIdOrUnknown, - item.getIndex(), - item.getFailureMessage() - ) - ); - } else { - response.add(new ModelResponse(inferenceIdOrUnknown, item.status(), null)); + var secretsItem = bulkItems[i + 1]; + var secretsStoreResponse = createModelStoreResponse(secretsItem, docIdToInferenceId); - if (inferenceIdMaybeNull != null) { - var modelForResponseItem = inferenceIdToModel.get(inferenceIdMaybeNull); - if (modelForResponseItem != null) { - modelsSuccessfullyStored.add(modelForResponseItem); - } + if (configStoreResponse.failed()) { + responses.add(configStoreResponse); + } else if (secretsStoreResponse.failed()) { + responses.add(secretsStoreResponse); + } else { + responses.add(configStoreResponse); + if (modelFromBulkItem != null) { + modelsSuccessfullyStored.add(modelFromBulkItem); } } } - return new ResponseInfo(response, modelsSuccessfullyStored); + return new Tuple<>(responses, modelsSuccessfullyStored); + } + + private static ModelStoreResponse createModelStoreResponse(BulkItemResponse item, Map docIdToInferenceId) { + var failure = item.getFailure(); + + String inferenceIdOrUnknown = "unknown"; + var inferenceIdMaybeNull = docIdToInferenceId.get(item.getId()); + if (inferenceIdMaybeNull == null) { + logger.warn("Failed to find inference id for document id [{}]", item.getId()); + } else { + inferenceIdOrUnknown = inferenceIdMaybeNull; + } + + if (item.isFailed() && failure != null) { + logger.warn( + format( + "Failed to store document id: [%s] inference id: [%s] index: [%s] bulk failure message [%s]", + item.getId(), + inferenceIdOrUnknown, + item.getIndex(), + 
item.getFailureMessage() + ) + ); + + return new ModelStoreResponse(inferenceIdOrUnknown, item.status(), failure.getCause()); + } else { + return new ModelStoreResponse(inferenceIdOrUnknown, item.status(), null); + } + } + + private static Model getModelFromMap(@Nullable String inferenceId, Map inferenceIdToModel) { + if (inferenceId != null) { + return inferenceIdToModel.get(inferenceId); + } + + return null; } private void updateClusterState(List models, ActionListener listener, TimeValue timeout) { @@ -966,11 +1000,13 @@ private static DeleteByQueryRequest createDeleteRequest(Set inferenceEnt return request; } - private IndexRequestBuilder createIndexRequestBuilder( + // default for testing + static IndexRequestBuilder createIndexRequestBuilder( String inferenceId, String indexName, ToXContentObject body, - boolean allowOverwriting + boolean allowOverwriting, + Client client ) { try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) { XContentBuilder source = body.toXContent( @@ -992,21 +1028,6 @@ private IndexRequestBuilder createIndexRequestBuilder( } } - private static IndexRequest createIndexRequest(String docId, String indexName, ToXContentObject body, boolean allowOverwriting) { - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - var request = new IndexRequest(indexName); - XContentBuilder source = body.toXContent( - builder, - new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString())) - ); - var operation = allowOverwriting ? DocWriteRequest.OpType.INDEX : DocWriteRequest.OpType.CREATE; - - return request.opType(operation).id(docId).source(source); - } catch (IOException ex) { - throw new ElasticsearchException(format("Unexpected serialization exception for index [%s] doc [%s]", indexName, docId), ex); - } - } - private static UnparsedModel modelToUnparsedModel(Model model) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { model.getConfigurations() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index eee8550ec6524..2478ddf62c2cc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -8,10 +8,13 @@ package org.elasticsearch.xpack.inference.registry; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.MinimalServiceSettings; @@ -23,6 +26,7 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import 
org.junit.Before; @@ -96,6 +100,195 @@ public void testGetModelWithSecrets() { assertThat(secretSettings.get("api_key"), equalTo("secret")); } + public void testStoreModels_StoresSingleInferenceEndpoint() { + var inferenceId = "1"; + var secrets = "secret"; + + var model = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + registry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), is(1)); + assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null))); + + assertMinimalServiceSettings(registry, model); + + var listener = new PlainActionFuture(); + registry.getModelWithSecrets(inferenceId, listener); + + var returnedModel = listener.actionGet(TIMEOUT); + assertThat(returnedModel.inferenceEntityId(), is(model.getInferenceEntityId())); + assertThat(returnedModel.service(), is(model.getConfigurations().getService())); + assertThat(returnedModel.taskType(), is(model.getConfigurations().getTaskType())); + assertEquals(TaskType.SPARSE_EMBEDDING, returnedModel.taskType()); + assertThat(returnedModel.secrets().keySet(), hasSize(1)); + assertThat(returnedModel.secrets().get("secret_settings"), instanceOf(Map.class)); + @SuppressWarnings("unchecked") + var secretSettings = (Map) returnedModel.secrets().get("secret_settings"); + assertThat(secretSettings.get("api_key"), equalTo(secrets)); + } + + public void testStoreModels_StoresMultipleInferenceEndpoints() { + var secrets = "secret"; + + var model1 = new TestModel( + "1", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + "2", + TaskType.TEXT_EMBEDDING, + "foo", + new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + registry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), is(2)); + assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null))); + assertThat(response.get(1), is(new ModelRegistry.ModelStoreResponse("2", RestStatus.CREATED, null))); + + assertModelAndMinimalSettingsWithSecrets(registry, model1, secrets); + assertModelAndMinimalSettingsWithSecrets(registry, model2, secrets); + } + + private static void assertModelAndMinimalSettingsWithSecrets(ModelRegistry registry, Model model, String secrets) { + assertMinimalServiceSettings(registry, model); + + var listener1 = new PlainActionFuture(); + registry.getModelWithSecrets(model.getInferenceEntityId(), listener1); + + var storedModel1 = listener1.actionGet(TIMEOUT); + assertModel(storedModel1, model, secrets); + } + + private static void assertModel(UnparsedModel model, Model expected, String secrets) { + assertThat(model.inferenceEntityId(), is(expected.getInferenceEntityId())); + assertThat(model.service(), 
is(expected.getConfigurations().getService())); + assertThat(model.taskType(), is(expected.getConfigurations().getTaskType())); + assertThat(model.taskType(), is(expected.getConfigurations().getTaskType())); + assertThat(model.secrets().keySet(), hasSize(1)); + assertThat(model.secrets().get("secret_settings"), instanceOf(Map.class)); + @SuppressWarnings("unchecked") + var secretSettings = (Map) model.secrets().get("secret_settings"); + assertThat(secretSettings.get("api_key"), is(secrets)); + } + + public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflictExists() { + var secrets = "secret"; + + var model1 = new TestModel( + "1", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + // using the same inference id as model1 to cause a failure + "1", + TaskType.TEXT_EMBEDDING, + "foo", + new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + registry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), is(2)); + assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null))); + assertThat(response.get(1).inferenceId(), is(model2.getInferenceEntityId())); + assertThat(response.get(1).status(), is(RestStatus.CONFLICT)); + assertTrue(response.get(1).failed()); + + var cause = response.get(1).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + + assertModelAndMinimalSettingsWithSecrets(registry, model1, secrets); + } + + public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyExists() { + var secrets = "secret"; + + var model1 = new TestModel( + "1", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + storeCorruptedModel(model1); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + registry.storeModels(List.of(model1), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), is(1)); + assertThat(response.get(0).inferenceId(), is(model1.getInferenceEntityId())); + assertThat(response.get(0).status(), is(RestStatus.CONFLICT)); + assertTrue(response.get(0).failed()); + + var cause = response.get(0).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + } + + private void storeCorruptedModel(Model model) { + var bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + bulkRequestBuilder.add( + ModelRegistry.createIndexRequestBuilder( + model.getInferenceEntityId(), + InferenceIndex.INDEX_NAME, + model.getConfigurations(), + false, + client() + ) + ); + + var listener = new 
PlainActionFuture(); + bulkRequestBuilder.execute(listener); + + var bulkResponse = listener.actionGet(TIMEOUT); + if (bulkResponse.hasFailures()) { + fail( + Strings.format( + "Failed to store model inference id: %s, for test. Error: %s", + model.getInferenceEntityId(), + bulkResponse.buildFailureMessage() + ) + ); + } + } + public void testGetModelNoSecrets() { assertStoreModel( registry, @@ -129,10 +322,11 @@ public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVe var model = TestModel.createRandomInstance(); assertStoreModel(registry, model); - ResourceAlreadyExistsException exception = expectThrows( - ResourceAlreadyExistsException.class, + var exception = expectThrows( + ElasticsearchStatusException.class, () -> assertStoreModel(registry, model) ); + assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); assertThat( exception.getMessage(), is(format("Inference endpoint [%s] already exists", model.getConfigurations().getInferenceEntityId())) @@ -242,6 +436,10 @@ public static void assertStoreModel(ModelRegistry registry, Model model) { registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); assertTrue(storeListener.actionGet(TimeValue.THIRTY_SECONDS)); + assertMinimalServiceSettings(registry, model); + } + + private static void assertMinimalServiceSettings(ModelRegistry registry, Model model) { var settings = registry.getMinimalServiceSettings(model.getInferenceEntityId()); assertNotNull(settings); assertThat(settings.taskType(), equalTo(model.getTaskType())); From 5f36d450eff540ca889ef30038100672f21c5992 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 14 Oct 2025 15:50:47 -0400 Subject: [PATCH 03/18] Adding log for duplicate ids --- .../xpack/inference/registry/ModelRegistry.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index c4b41b874f8fb..bfc75db22c2aa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -727,7 +727,10 @@ private ActionListener getStoreMultipleModelsListener( ) { return ActionListener.wrap(bulkItemResponses -> { var docIdToInferenceId = models.stream() - .collect(Collectors.toMap(m -> Model.documentId(m.getInferenceEntityId()), Model::getInferenceEntityId, (id1, id2) -> id1)); + .collect(Collectors.toMap(m -> Model.documentId(m.getInferenceEntityId()), Model::getInferenceEntityId, (id1, id2) -> { + logger.warn("Encountered duplicate inference ids when storing endpoints: [{}]", id1); + return id1; + })); var inferenceIdToModel = models.stream() .collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity(), (id1, id2) -> id1)); From ebb64760f0768dbadc37a76b451a4163d101862b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 14 Oct 2025 20:06:08 +0000 Subject: [PATCH 04/18] [CI] Auto commit changes from spotless --- .../xpack/inference/registry/ModelRegistryTests.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 2478ddf62c2cc..e1de6d3eaaaeb 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -322,10 +322,7 @@ public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVe var model = TestModel.createRandomInstance(); assertStoreModel(registry, model); - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> assertStoreModel(registry, model) - ); + var exception = expectThrows(ElasticsearchStatusException.class, () -> assertStoreModel(registry, model)); assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); assertThat( exception.getMessage(), From c30df2b3725e4d9e06df122ca1093096008e61d4 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 14 Oct 2025 16:44:23 -0400 Subject: [PATCH 05/18] Removing unused code --- .../inference/registry/ModelRegistry.java | 22 +- ...nferenceServiceAuthorizationHandlerV2.java | 339 ------------------ 2 files changed, 12 insertions(+), 349 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index bfc75db22c2aa..3e49235f617ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -671,7 +671,7 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< new ElasticsearchStatusException( "Inference endpoint [{}] already exists", RestStatus.BAD_REQUEST, - failureItem.failureCause, + failureItem.failureCause(), failureItem.inferenceId ) ); @@ -752,12 +752,12 @@ private ActionListener getStoreMultipleModelsListener( if (updateClusterState) { updateClusterState( - responseInfo.v2(), - listener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.v1())), + responseInfo.successfullyStoredModels(), + listener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.responses())), timeout ); } else { - listener.onResponse(responseInfo.v1()); + listener.onResponse(responseInfo.responses()); } }, e -> { String errorMessage = format( @@ -769,13 +769,15 @@ private ActionListener getStoreMultipleModelsListener( }); } - private static Tuple, List> getResponseInfo( + private record ResponseInfo(List responses, List successfullyStoredModels) {} + + private static ResponseInfo getResponseInfo( BulkResponse bulkResponse, Map docIdToInferenceId, Map inferenceIdToModel ) { var responses = new ArrayList(); - var modelsSuccessfullyStored = new ArrayList(); + var successfullyStoredModels = new ArrayList(); var bulkItems = bulkResponse.getItems(); for (int i = 0; i < bulkItems.length; i += 2) { @@ -787,9 +789,9 @@ private static Tuple, List> getResponseInfo( logger.error("Expected an even number of bulk response items, got [{}]", bulkResponse.getItems().length); responses.add(configStoreResponse); if (configStoreResponse.failed() == false && modelFromBulkItem != null) { - modelsSuccessfullyStored.add(modelFromBulkItem); + successfullyStoredModels.add(modelFromBulkItem); } - return new Tuple<>(responses, 
modelsSuccessfullyStored); + return new ResponseInfo(responses, successfullyStoredModels); } var secretsItem = bulkItems[i + 1]; @@ -802,12 +804,12 @@ private static Tuple, List> getResponseInfo( } else { responses.add(configStoreResponse); if (modelFromBulkItem != null) { - modelsSuccessfullyStored.add(modelFromBulkItem); + successfullyStoredModels.add(modelFromBulkItem); } } } - return new Tuple<>(responses, modelsSuccessfullyStored); + return new ResponseInfo(responses, successfullyStoredModels); } private static ModelStoreResponse createModelStoreResponse(BulkItemResponse item, Map docIdToInferenceId) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java deleted file mode 100644 index 59501fdc45170..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.threadpool.Scheduler; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; - -import java.io.Closeable; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.EnumSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.TreeSet; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - -public class ElasticInferenceServiceAuthorizationHandlerV2 implements Closeable { - private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandlerV2.class); - - private record AuthorizedContent( - ElasticInferenceServiceAuthorizationModel taskTypesAndModels, - List configIds, - List defaultModelConfigs - ) { - static AuthorizedContent empty() { - return new AuthorizedContent(ElasticInferenceServiceAuthorizationModel.newDisabledService(), List.of(), List.of()); - } - } - - private final ServiceComponents serviceComponents; - private 
final AtomicReference authorizedContent = new AtomicReference<>(AuthorizedContent.empty()); - private final ModelRegistry modelRegistry; - private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; - private final Map defaultModelsConfigs; - private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); - private final EnumSet implementedTaskTypes; - private final InferenceService inferenceService; - private final Sender sender; - private final Runnable callback; - private final AtomicReference lastAuthTask = new AtomicReference<>(null); - private final AtomicBoolean shutdown = new AtomicBoolean(false); - private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; - - public ElasticInferenceServiceAuthorizationHandlerV2( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings - ) { - this( - serviceComponents, - modelRegistry, - authorizationRequestHandler, - defaultModelsConfigs, - implementedTaskTypes, - Objects.requireNonNull(inferenceService), - sender, - elasticInferenceServiceSettings, - null - ); - } - - // default for testing - ElasticInferenceServiceAuthorizationHandlerV2( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings, - // this is a hack to facilitate testing - Runnable callback - ) { - this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.modelRegistry = Objects.requireNonNull(modelRegistry); - this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); - this.defaultModelsConfigs = Objects.requireNonNull(defaultModelsConfigs); - this.implementedTaskTypes = Objects.requireNonNull(implementedTaskTypes); - // allow the service to be null for testing - this.inferenceService = inferenceService; - this.sender = Objects.requireNonNull(sender); - this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - this.callback = callback; - } - - /** - * Initializes the authorization handler by scheduling the first authorization request. - */ - public void init() { - logger.debug("Initializing authorization logic"); - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); - } - - /** - * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. 
- * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForAuthorizationToComplete(TimeValue waitTime) { - try { - if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { - throw new IllegalStateException("The wait time has expired for authorization to complete."); - } - } catch (InterruptedException e) { - throw new IllegalStateException("Waiting for authorization to complete was interrupted"); - } - } - - public synchronized Set supportedStreamingTasks() { - var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - authorizedStreamingTaskTypes.retainAll(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()); - - return authorizedStreamingTaskTypes; - } - - public synchronized List defaultConfigIds() { - return authorizedContent.get().configIds; - } - - public synchronized void defaultConfigs(ActionListener> defaultsListener) { - var models = authorizedContent.get().defaultModelConfigs.stream().map(DefaultModelConfig::model).toList(); - defaultsListener.onResponse(models); - } - - public synchronized EnumSet supportedTaskTypes() { - return authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes(); - } - - public synchronized boolean hideFromConfigurationApi() { - return authorizedContent.get().taskTypesAndModels.isAuthorized() == false; - } - - @Override - public void close() throws IOException { - shutdown.set(true); - if (lastAuthTask.get() != null) { - lastAuthTask.get().cancel(); - } - } - - private void scheduleAuthorizationRequest() { - try { - if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { - return; - } - - // this call has to be on the individual thread otherwise we get an exception - var random = Randomness.get(); - var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble()); - var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); - - logger.debug( - () -> Strings.format( - "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", - elasticInferenceServiceSettings.getAuthRequestInterval().millis(), - jitter - ) - ); - logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); - - lastAuthTask.set( - serviceComponents.threadPool() - .schedule( - this::scheduleAndSendAuthorizationRequest, - waitTime, - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) - ) - ); - } catch (Exception e) { - logger.warn("Failed scheduling authorization request", e); - } - } - - private void scheduleAndSendAuthorizationRequest() { - if (shutdown.get()) { - return; - } - - scheduleAuthorizationRequest(); - sendAuthorizationRequest(); - } - - private void sendAuthorizationRequest() { - try { - ActionListener listener = ActionListener.wrap((model) -> { - setAuthorizedContent(model); - if (callback != null) { - callback.run(); - } - }, e -> { - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - }); - - authorizationHandler.getAuthorization(listener, sender); - } catch (Exception e) { - logger.warn("Failure while sending the request to retrieve authorization", e); - // we don't need to do anything if there was a failure, everything is disabled by default - 
firstAuthorizationCompletedLatch.countDown(); - } - } - - private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { - logger.debug(() -> Strings.format("Received authorization response, %s", auth)); - - var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); - logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels)); - - // recalculate which default config ids and models are authorized now - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels); - - var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels); - var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); - authorizedContent.set( - new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) - ); - - authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); - handleRevokedDefaultConfigs(authorizedDefaultModelIds); - } - - private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) { - var authorizedModels = auth.getAuthorizedModelIds(); - var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet()); - authorizedDefaultModelIds.retainAll(authorizedModels); - - return authorizedDefaultModelIds; - } - - private List getAuthorizedDefaultConfigIds( - Set authorizedDefaultModelIds, - ElasticInferenceServiceAuthorizationModel auth - ) { - var authorizedConfigIds = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - if (auth.getAuthorizedTaskTypes().contains(modelConfig.model().getTaskType()) == false) { - logger.warn( - Strings.format( - "The authorization response included the default model: %s, " - + "but did not authorize the assumed task type of the model: %s. 
Enabling model.", - id, - modelConfig.model().getTaskType() - ) - ); - } - authorizedConfigIds.add( - new InferenceService.DefaultConfigId( - modelConfig.model().getInferenceEntityId(), - modelConfig.settings(), - inferenceService - ) - ); - } - } - - authorizedConfigIds.sort(Comparator.comparing(InferenceService.DefaultConfigId::inferenceId)); - return authorizedConfigIds; - } - - private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { - var authorizedModels = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - authorizedModels.add(modelConfig); - } - } - - authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId())); - return authorizedModels; - } - - private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { - // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked - var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); - unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); - - // get all the default inference endpoint ids for the unauthorized model ids - var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() - .map(defaultModelsConfigs::get) // get all the model configs - .filter(Objects::nonNull) // limit to only non-null - .map(modelConfig -> modelConfig.model().getInferenceEntityId()) // get the inference ids - .collect(Collectors.toSet()); - - var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { - logger.debug(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); - firstAuthorizationCompletedLatch.countDown(); - }, e -> { - logger.warn( - Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) - ); - firstAuthorizationCompletedLatch.countDown(); - }); - - logger.debug( - () -> Strings.format( - "Synchronizing default inference endpoints, attempting to remove ids: %s", - unauthorizedDefaultInferenceEndpointIds - ) - ); - modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); - } -} From 4ea87d15483a125dced954244e5b6d69fa38e4f1 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 14 Oct 2025 16:50:43 -0400 Subject: [PATCH 06/18] Removing constructor --- .../xpack/inference/registry/ModelRegistry.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 3e49235f617ba..85ca54b3a7137 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -1183,16 +1183,6 @@ public record ModelAndSettings(String inferenceEntityId, MinimalServiceSettings private static class AddModelMetadataTask extends MetadataTask { private final List models = new ArrayList<>(); - AddModelMetadataTask( - ProjectId projectId, - String inferenceEntityId, - MinimalServiceSettings settings, - ActionListener listener - ) { - super(projectId, listener); - this.models.add(new ModelAndSettings(inferenceEntityId, settings)); - } - 
AddModelMetadataTask(ProjectId projectId, List models, ActionListener listener) { super(projectId, listener); this.models.addAll(models); From 155c366f17ae910792ff8dd294a53b3fa6e0cc8a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 15 Oct 2025 10:57:48 -0400 Subject: [PATCH 07/18] Adding more tests --- .../registry/ModelRegistryMetadata.java | 17 +-- .../registry/ModelRegistryMetadataTests.java | 144 ++++++++++++++++++ 2 files changed, 148 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java index 6b13f8a2e30cd..4bf23103af5a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java @@ -93,18 +93,7 @@ public static ModelRegistryMetadata fromState(ProjectMetadata projectMetadata) { ); public ModelRegistryMetadata withAddedModel(String inferenceEntityId, MinimalServiceSettings settings) { - final var existing = modelMap.get(inferenceEntityId); - if (existing != null && settings.equals(existing)) { - return this; - } - var settingsBuilder = ImmutableOpenMap.builder(modelMap); - settingsBuilder.fPut(inferenceEntityId, settings); - if (isUpgraded) { - return new ModelRegistryMetadata(settingsBuilder.build()); - } - var newTombstone = new HashSet<>(tombstones); - newTombstone.remove(inferenceEntityId); - return new ModelRegistryMetadata(settingsBuilder.build(), newTombstone); + return withAddedModels(List.of(new ModelRegistry.ModelAndSettings(inferenceEntityId, settings))); } public ModelRegistryMetadata withAddedModels(List models) { @@ -285,7 +274,9 @@ public boolean equals(Object obj) { return false; } ModelRegistryMetadata other = (ModelRegistryMetadata) obj; - return Objects.equals(this.modelMap, other.modelMap) && isUpgraded == other.isUpgraded; + return Objects.equals(this.modelMap, other.modelMap) + && isUpgraded == other.isUpgraded + && Objects.equals(this.tombstones, other.tombstones); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java index 19af5ad61b988..7001dd0a1d4aa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java @@ -19,11 +19,14 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; public class ModelRegistryMetadataTests extends AbstractChunkedSerializingTestCase { public static ModelRegistryMetadata randomInstance() { @@ -106,4 +109,145 @@ public void testAlreadyUpgraded() { var exc = expectThrows(IllegalArgumentException.class, () -> metadata.withUpgradedModels(indexMetadata.getModelMap())); assertThat(exc.getMessage(), containsString("upgraded")); } + + public void testWithAddedModel_ReturnsSameMetadataInstance() { 
+ var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var newMetadata = metadata.withAddedModel(inferenceId, settings); + assertThat(newMetadata, sameInstance(metadata)); + assertThat(newMetadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + } + + public void testWithAddedModel_ReturnsNewMetadataInstance_ForNewInferenceId() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var newInferenceId = "new_id"; + var newSettings = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModel(newInferenceId, newSettings); + assertThat( + newMetadata, + is(new ModelRegistryMetadata(ImmutableOpenMap.builder(Map.of(inferenceId, settings, newInferenceId, newSettings)).build())) + ); + } + + public void testWithAddedModel_ReturnsNewMetadataInstance_ForNewInferenceId_WithTombstoneRemoved() { + var inferenceId = "id"; + var newInferenceId = "new_id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), Set.of(newInferenceId)); + + var newSettings = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModel(newInferenceId, newSettings); + assertThat( + newMetadata, + is( + new ModelRegistryMetadata( + ImmutableOpenMap.builder(Map.of(inferenceId, settings, newInferenceId, newSettings)).build(), + new HashSet<>() + ) + ) + ); + } + + public void testWithAddedModels_ReturnsSameMetadataInstance() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var newMetadata = metadata.withAddedModels( + List.of(new ModelRegistry.ModelAndSettings(inferenceId, settings), new ModelRegistry.ModelAndSettings(inferenceId, settings)) + ); + assertThat(newMetadata, sameInstance(metadata)); + assertThat(newMetadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + } + + public void testWithAddedModels_ReturnsSameMetadataInstance_MultipleEntriesInMap() { + var inferenceId = "id"; + var inferenceId2 = "id2"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings, inferenceId2, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var newMetadata = metadata.withAddedModels( + List.of( + new ModelRegistry.ModelAndSettings(inferenceId, settings), + new ModelRegistry.ModelAndSettings(inferenceId, settings), + new ModelRegistry.ModelAndSettings(inferenceId2, settings) + ) + ); + assertThat(newMetadata, sameInstance(metadata)); + assertThat(newMetadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + } + + public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new 
ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var inferenceId2 = "new_id"; + var settings2 = MinimalServiceSettingsTests.randomInstance(); + var inferenceId3 = "new_id2"; + var settings3 = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModels( + List.of( + new ModelRegistry.ModelAndSettings(inferenceId2, settings2), + // This should be ignored since it's a duplicate + new ModelRegistry.ModelAndSettings(inferenceId2, settings2), + new ModelRegistry.ModelAndSettings(inferenceId3, settings3) + ) + ); + assertThat( + newMetadata, + is( + new ModelRegistryMetadata( + ImmutableOpenMap.builder(Map.of(inferenceId, settings, inferenceId2, settings2, inferenceId3, settings3)).build() + ) + ) + ); + } + + public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId_WithTombstoneRemoved() { + var inferenceId = "id"; + var newInferenceId = "new_id"; + var newInferenceId2 = "new_id2"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), Set.of(newInferenceId)); + + var newSettings = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModels( + List.of( + // This will cause the new settings to be used for inferenceId + new ModelRegistry.ModelAndSettings(inferenceId, newSettings), + new ModelRegistry.ModelAndSettings(newInferenceId, newSettings), + new ModelRegistry.ModelAndSettings(newInferenceId2, newSettings) + ) + ); + assertThat( + newMetadata, + is( + new ModelRegistryMetadata( + ImmutableOpenMap.builder(Map.of(inferenceId, newSettings, newInferenceId, newSettings, newInferenceId2, newSettings)) + .build(), + new HashSet<>() + ) + ) + ); + } } From b7f41f8d4fa0660422440b6b0c5d46ca14fdc2de Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 16 Oct 2025 16:24:30 -0400 Subject: [PATCH 08/18] Adding in logic to delete models when a failure occurs --- .../inference/registry/ModelRegistry.java | 47 +++++++++++++++---- .../registry/ModelRegistryTests.java | 44 ++++------------- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 85ca54b3a7137..7b6d12424f91c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -725,6 +725,17 @@ private ActionListener getStoreMultipleModelsListener( ActionListener> listener, TimeValue timeout ) { + var cleanupListener = listener.>delegateFailureAndWrap((delegate, responses) -> { + var inferenceIdsToBeRemoved = responses.stream() + .filter(r -> r.modifiedIndex() && r.modelStoreResponse().failed()) + .map(r -> r.modelStoreResponse().inferenceId()) + .collect(Collectors.toSet()); + + var storageResponses = responses.stream().map(StoreResponseWithIndexInfo::modelStoreResponse).toList(); + + deleteModels(inferenceIdsToBeRemoved, ActionListener.running(() -> delegate.onResponse(storageResponses))); + }); + return ActionListener.wrap(bulkItemResponses -> { var docIdToInferenceId = models.stream() .collect(Collectors.toMap(m -> Model.documentId(m.getInferenceEntityId()), Model::getInferenceEntityId, (id1, id2) -> { @@ -753,11 +764,11 
@@ private ActionListener getStoreMultipleModelsListener( if (updateClusterState) { updateClusterState( responseInfo.successfullyStoredModels(), - listener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.responses())), + cleanupListener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.responses())), timeout ); } else { - listener.onResponse(responseInfo.responses()); + cleanupListener.onResponse(responseInfo.responses()); } }, e -> { String errorMessage = format( @@ -769,14 +780,17 @@ private ActionListener getStoreMultipleModelsListener( }); } - private record ResponseInfo(List responses, List successfullyStoredModels) {} + private record StoreResponseWithIndexInfo(ModelStoreResponse modelStoreResponse, boolean modifiedIndex) {} + + private record ResponseInfo(List responses, List successfullyStoredModels) { + } private static ResponseInfo getResponseInfo( BulkResponse bulkResponse, Map docIdToInferenceId, Map inferenceIdToModel ) { - var responses = new ArrayList(); + var responses = new ArrayList(); var successfullyStoredModels = new ArrayList(); var bulkItems = bulkResponse.getItems(); @@ -787,9 +801,21 @@ private static ResponseInfo getResponseInfo( if (i + 1 >= bulkResponse.getItems().length) { logger.error("Expected an even number of bulk response items, got [{}]", bulkResponse.getItems().length); - responses.add(configStoreResponse); - if (configStoreResponse.failed() == false && modelFromBulkItem != null) { - successfullyStoredModels.add(modelFromBulkItem); + + if (configStoreResponse.failed()) { + responses.add(new StoreResponseWithIndexInfo(configStoreResponse, false)); + } else { + // if we didn't get the last item for some reason, assume it is a failure + responses.add( + new StoreResponseWithIndexInfo( + new ModelStoreResponse( + configStoreResponse.inferenceId(), + RestStatus.INTERNAL_SERVER_ERROR, + new IllegalStateException("Failed to receive part of bulk response") + ), + true + ) + ); } return new ResponseInfo(responses, successfullyStoredModels); } @@ -798,11 +824,12 @@ private static ResponseInfo getResponseInfo( var secretsStoreResponse = createModelStoreResponse(secretsItem, docIdToInferenceId); if (configStoreResponse.failed()) { - responses.add(configStoreResponse); + responses.add(new StoreResponseWithIndexInfo(configStoreResponse, secretsStoreResponse.failed() == false)); } else if (secretsStoreResponse.failed()) { - responses.add(secretsStoreResponse); + // if we got here that means the configuration store bulk item succeeded, so we did modify an index + responses.add(new StoreResponseWithIndexInfo(secretsStoreResponse, true)); } else { - responses.add(configStoreResponse); + responses.add(new StoreResponseWithIndexInfo(configStoreResponse, true)); if (modelFromBulkItem != null) { successfullyStoredModels.add(modelFromBulkItem); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 080d7112d3be4..28e8e2d4037cc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -8,14 +8,12 @@ package org.elasticsearch.xpack.inference.registry; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ResourceAlreadyExistsException; import 
org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.common.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -237,6 +235,7 @@ public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflic assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); assertModelAndMinimalSettingsWithSecrets(registry, model1, secrets); + assertIndicesContainExpectedDocsCount(model1, 2); } public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyExists() { @@ -251,7 +250,7 @@ public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyE new TestModel.TestSecretSettings(secrets) ); - storeCorruptedModel(model1); + storeCorruptedModel(model1, false); PlainActionFuture> storeListener = new PlainActionFuture<>(); registry.storeModels(List.of(model1), storeListener, TimeValue.THIRTY_SECONDS); @@ -268,34 +267,6 @@ public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyE assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); } - private void storeCorruptedModel(Model model) { - var bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - bulkRequestBuilder.add( - ModelRegistry.createIndexRequestBuilder( - model.getInferenceEntityId(), - InferenceIndex.INDEX_NAME, - model.getConfigurations(), - false, - client() - ) - ); - - var listener = new PlainActionFuture(); - bulkRequestBuilder.execute(listener); - - var bulkResponse = listener.actionGet(TIMEOUT); - if (bulkResponse.hasFailures()) { - fail( - Strings.format( - "Failed to store model inference id: %s, for test. 
Error: %s", - model.getInferenceEntityId(), - bulkResponse.buildFailureMessage() - ) - ); - } - } - public void testGetModelNoSecrets() { assertStoreModel( registry, @@ -325,7 +296,7 @@ public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { assertStoreModel(registry, model); } - public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVersionConflict() { + public void testStoreModel_ThrowsException_WhenFailureIsAVersionConflict() { var model = TestModel.createRandomInstance(); assertStoreModel(registry, model); @@ -462,8 +433,9 @@ public void testStoreModel_DoesNotDeleteIndexDocs_WhenModelAlreadyExists() { PlainActionFuture secondStoreListener = new PlainActionFuture<>(); registry.storeModel(model, secondStoreListener, TimeValue.THIRTY_SECONDS); - expectThrows(ResourceAlreadyExistsException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS)); - + var exception = expectThrows(ElasticsearchStatusException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS)); + assertThat(exception.getMessage(), containsString("already exists")); + assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); assertIndicesContainExpectedDocsCount(model, 2); } @@ -484,7 +456,9 @@ private void storeCorruptedModelThenStoreModel(boolean storeSecrets) { PlainActionFuture storeListener = new PlainActionFuture<>(); registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); - expectThrows(ResourceAlreadyExistsException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS)); + var exception = expectThrows(ElasticsearchStatusException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS)); + assertThat(exception.getMessage(), containsString("already exists")); + assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); assertIndicesContainExpectedDocsCount(model, 0); } From f4d9f2c628ba7ddedb4973631ba4c5539327190f Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 16 Oct 2025 16:31:46 -0400 Subject: [PATCH 09/18] revert rename changes --- .../xpack/inference/registry/ModelRegistry.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 7b6d12424f91c..e2280854434a2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -531,7 +531,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi SubscribableListener.newForked((subListener) -> { // in this block, we try to update the stored model configurations - var requestBuilder = createIndexRequestBuilder( + var configRequestBuilder = createIndexRequestBuilder( inferenceEntityId, InferenceIndex.INDEX_NAME, newModel.getConfigurations(), @@ -545,7 +545,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); - client.prepareBulk().add(requestBuilder).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener); + client.prepareBulk().add(configRequestBuilder).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener); }).andThen((subListener, configResponse) -> { // in this block, we respond to the success or failure of updating the model configurations, then 
try to store the new secrets @@ -570,7 +570,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi ); } else { // Since the model configurations were successfully updated, we can now try to store the new secrets - var requestBuilder = createIndexRequestBuilder( + var secretsRequestBuilder = createIndexRequestBuilder( newModel.getConfigurations().getInferenceEntityId(), InferenceSecretsIndex.INDEX_NAME, newModel.getSecrets(), @@ -585,7 +585,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi }); client.prepareBulk() - .add(requestBuilder) + .add(secretsRequestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(storeSecretsListener); } @@ -593,7 +593,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi // in this block, we respond to the success or failure of updating the model secrets if (secretsResponse.hasFailures()) { // since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations - var requestBuilder = createIndexRequestBuilder( + var configRequestBuilder = createIndexRequestBuilder( inferenceEntityId, InferenceIndex.INDEX_NAME, existingModel.getConfigurations(), @@ -612,7 +612,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); client.prepareBulk() - .add(requestBuilder) + .add(configRequestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(rollbackConfigListener); } else { From cd8c832914e5cd78d3dc40e9dcf3ff2fd52d2839 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 16 Oct 2025 16:32:49 -0400 Subject: [PATCH 10/18] formatting --- .../xpack/inference/registry/ModelRegistry.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index e2280854434a2..978f10cb47240 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -545,7 +545,10 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); - client.prepareBulk().add(configRequestBuilder).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener); + client.prepareBulk() + .add(configRequestBuilder) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .execute(storeConfigListener); }).andThen((subListener, configResponse) -> { // in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets @@ -782,8 +785,7 @@ private ActionListener getStoreMultipleModelsListener( private record StoreResponseWithIndexInfo(ModelStoreResponse modelStoreResponse, boolean modifiedIndex) {} - private record ResponseInfo(List responses, List successfullyStoredModels) { - } + private record ResponseInfo(List responses, List successfullyStoredModels) {} private static ResponseInfo getResponseInfo( BulkResponse bulkResponse, From b6cebf536c3ab24d1dc2485d498ff785524fbad9 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 17 Oct 2025 17:05:49 -0400 Subject: [PATCH 11/18] Starting on feedback --- .../inference/registry/ModelRegistry.java | 21 ++++++--- 
.../registry/ModelRegistryMetadataTests.java | 44 +++++++++++++++++-- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 978f10cb47240..9228ae9e5585a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -574,7 +574,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi } else { // Since the model configurations were successfully updated, we can now try to store the new secrets var secretsRequestBuilder = createIndexRequestBuilder( - newModel.getConfigurations().getInferenceEntityId(), + inferenceEntityId, InferenceSecretsIndex.INDEX_NAME, newModel.getSecrets(), true, @@ -675,7 +675,7 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< "Inference endpoint [{}] already exists", RestStatus.BAD_REQUEST, failureItem.failureCause(), - failureItem.inferenceId + failureItem.inferenceId() ) ); return; @@ -683,7 +683,7 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< delegate.onFailure( new ElasticsearchStatusException( - format("Failed to store inference endpoint [%s]", failureItem.inferenceId), + format("Failed to store inference endpoint [%s]", failureItem.inferenceId()), RestStatus.INTERNAL_SERVER_ERROR, failureItem.failureCause() ) @@ -748,8 +748,8 @@ private ActionListener getStoreMultipleModelsListener( var inferenceIdToModel = models.stream() .collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity(), (id1, id2) -> id1)); - var inferenceEntityIds = String.join(", ", models.stream().map(Model::getInferenceEntityId).toList()); if (bulkItemResponses.getItems().length == 0) { + var inferenceEntityIds = String.join(", ", models.stream().map(Model::getInferenceEntityId).toList()); logger.warn("Storing inference endpoints [{}] failed, no items were received from the bulk response", inferenceEntityIds); listener.onFailure( @@ -825,6 +825,13 @@ private static ResponseInfo getResponseInfo( var secretsItem = bulkItems[i + 1]; var secretsStoreResponse = createModelStoreResponse(secretsItem, docIdToInferenceId); + assert secretsStoreResponse.inferenceId().equals(configStoreResponse.inferenceId()) + : "Mismatched inference ids in bulk response items, configuration id [" + + configStoreResponse.inferenceId() + + "] secrets id [" + + secretsStoreResponse.inferenceId() + + "]"; + if (configStoreResponse.failed()) { responses.add(new StoreResponseWithIndexInfo(configStoreResponse, secretsStoreResponse.failed() == false)); } else if (secretsStoreResponse.failed()) { @@ -880,7 +887,7 @@ private static Model getModelFromMap(@Nullable String inferenceId, Map models, ActionListener listener, TimeValue timeout) { var inferenceIdsSet = models.stream().map(Model::getInferenceEntityId).collect(Collectors.toSet()); var storeListener = listener.delegateResponse((delegate, exc) -> { - logger.warn(format("Failed to add inference endpoint %s minimal service settings to cluster state", inferenceIdsSet), exc); + logger.warn(format("Failed to add minimal service settings to cluster state for inference endpoints %s", inferenceIdsSet), exc); deleteModels( inferenceIdsSet, ActionListener.running( @@ -888,7 +895,7 @@ private void 
updateClusterState(List models, ActionListener models, ActionListener { @@ -63,7 +64,33 @@ protected ModelRegistryMetadata createTestInstance() { @Override protected ModelRegistryMetadata mutateInstance(ModelRegistryMetadata instance) { - return randomValueOtherThan(instance, this::createTestInstance); + int choice = randomIntBetween(0, 2); + switch (choice) { + case 0: // Mutate modelMap + var models = new HashMap<>(instance.getModelMap()); + models.put(randomAlphaOfLength(10), MinimalServiceSettingsTests.randomInstance()); + if (instance.isUpgraded()) { + return new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + } else { + return new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), new HashSet<>(instance.getTombstones())); + } + case 1: // Mutate tombstones + if (instance.getTombstones() == null) { + return new ModelRegistryMetadata(instance.getModelMap(), Set.of(randomAlphaOfLength(10))); + } else { + var tombstones = new HashSet<>(instance.getTombstones()); + tombstones.add(randomAlphaOfLength(10)); + return new ModelRegistryMetadata(instance.getModelMap(), tombstones); + } + case 2: // Mutate isUpgraded + if (instance.isUpgraded()) { + return new ModelRegistryMetadata(instance.getModelMap(), new HashSet<>()); + } else { + return new ModelRegistryMetadata(instance.getModelMap()); + } + default: + throw new IllegalStateException("Unexpected value: " + choice); + } } @Override @@ -119,7 +146,6 @@ public void testWithAddedModel_ReturnsSameMetadataInstance() { var newMetadata = metadata.withAddedModel(inferenceId, settings); assertThat(newMetadata, sameInstance(metadata)); - assertThat(newMetadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); } public void testWithAddedModel_ReturnsNewMetadataInstance_ForNewInferenceId() { @@ -132,6 +158,9 @@ public void testWithAddedModel_ReturnsNewMetadataInstance_ForNewInferenceId() { var newInferenceId = "new_id"; var newSettings = MinimalServiceSettingsTests.randomInstance(); var newMetadata = metadata.withAddedModel(newInferenceId, newSettings); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(newMetadata, not(is(metadata))); assertThat( newMetadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(Map.of(inferenceId, settings, newInferenceId, newSettings)).build())) @@ -148,6 +177,9 @@ public void testWithAddedModel_ReturnsNewMetadataInstance_ForNewInferenceId_With var newSettings = MinimalServiceSettingsTests.randomInstance(); var newMetadata = metadata.withAddedModel(newInferenceId, newSettings); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(newMetadata, not(is(metadata))); assertThat( newMetadata, is( @@ -170,7 +202,6 @@ public void testWithAddedModels_ReturnsSameMetadataInstance() { List.of(new ModelRegistry.ModelAndSettings(inferenceId, settings), new ModelRegistry.ModelAndSettings(inferenceId, settings)) ); assertThat(newMetadata, sameInstance(metadata)); - assertThat(newMetadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); } public void testWithAddedModels_ReturnsSameMetadataInstance_MultipleEntriesInMap() { @@ -189,7 +220,6 @@ public void testWithAddedModels_ReturnsSameMetadataInstance_MultipleEntriesInMap ) ); assertThat(newMetadata, sameInstance(metadata)); - assertThat(newMetadata, is(new 
ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); } public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId() { @@ -211,6 +241,9 @@ public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId() { new ModelRegistry.ModelAndSettings(inferenceId3, settings3) ) ); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(newMetadata, not(is(metadata))); assertThat( newMetadata, is( @@ -239,6 +272,9 @@ public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId_Wit new ModelRegistry.ModelAndSettings(newInferenceId2, newSettings) ) ); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(newMetadata, not(is(metadata))); assertThat( newMetadata, is( From 5441da4d2416622024a5cd101fa8eef3484b8c73 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 20 Oct 2025 09:57:45 -0400 Subject: [PATCH 12/18] Improving tests --- .../registry/ModelRegistryMetadataTests.java | 30 ++++- .../registry/ModelRegistryTests.java | 126 ++++++++++++++---- 2 files changed, 127 insertions(+), 29 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java index 924da8b1fa99c..8726632a77095 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java @@ -178,7 +178,7 @@ public void testWithAddedModel_ReturnsNewMetadataInstance_ForNewInferenceId_With var newSettings = MinimalServiceSettingsTests.randomInstance(); var newMetadata = metadata.withAddedModel(newInferenceId, newSettings); // ensure metadata hasn't changed - assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), Set.of(newInferenceId)))); assertThat(newMetadata, not(is(metadata))); assertThat( newMetadata, @@ -254,6 +254,34 @@ public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId() { ); } + public void testWithAddedModels_ReturnsNewMetadataInstance_UsesOverridingSettings() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var inferenceId2 = "new_id"; + var settings2 = MinimalServiceSettingsTests.randomInstance(); + var settings3 = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModels( + List.of( + new ModelRegistry.ModelAndSettings(inferenceId2, settings2), + // This should be ignored since it's a duplicate inference id + new ModelRegistry.ModelAndSettings(inferenceId2, settings2), + // This should replace the existing settings for inferenceId2 + new ModelRegistry.ModelAndSettings(inferenceId2, settings3) + ) + ); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(newMetadata, not(is(metadata))); + assertThat( + newMetadata, + is(new 
ModelRegistryMetadata(ImmutableOpenMap.builder(Map.of(inferenceId, settings, inferenceId2, settings3)).build())) + ); + } + public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId_WithTombstoneRemoved() { var inferenceId = "id"; var newInferenceId = "new_id"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 28e8e2d4037cc..8d87222b791c5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -123,7 +123,7 @@ public void testStoreModels_StoresSingleInferenceEndpoint() { var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), is(1)); - assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null))); + assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); assertMinimalServiceSettings(registry, model); @@ -131,22 +131,16 @@ public void testStoreModels_StoresSingleInferenceEndpoint() { registry.getModelWithSecrets(inferenceId, listener); var returnedModel = listener.actionGet(TIMEOUT); - assertThat(returnedModel.inferenceEntityId(), is(model.getInferenceEntityId())); - assertThat(returnedModel.service(), is(model.getConfigurations().getService())); - assertThat(returnedModel.taskType(), is(model.getConfigurations().getTaskType())); - assertEquals(TaskType.SPARSE_EMBEDDING, returnedModel.taskType()); - assertThat(returnedModel.secrets().keySet(), hasSize(1)); - assertThat(returnedModel.secrets().get("secret_settings"), instanceOf(Map.class)); - @SuppressWarnings("unchecked") - var secretSettings = (Map) returnedModel.secrets().get("secret_settings"); - assertThat(secretSettings.get("api_key"), equalTo(secrets)); + assertModel(returnedModel, model, secrets); } public void testStoreModels_StoresMultipleInferenceEndpoints() { var secrets = "secret"; + var inferenceId1 = "1"; + var inferenceId2 = "2"; var model1 = new TestModel( - "1", + inferenceId1, TaskType.SPARSE_EMBEDDING, "foo", new TestModel.TestServiceSettings(null, null, null, null), @@ -155,7 +149,7 @@ public void testStoreModels_StoresMultipleInferenceEndpoints() { ); var model2 = new TestModel( - "2", + inferenceId2, TaskType.TEXT_EMBEDDING, "foo", new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), @@ -168,8 +162,8 @@ public void testStoreModels_StoresMultipleInferenceEndpoints() { var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), is(2)); - assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null))); - assertThat(response.get(1), is(new ModelRegistry.ModelStoreResponse("2", RestStatus.CREATED, null))); + assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse(inferenceId1, RestStatus.CREATED, null))); + assertThat(response.get(1), is(new ModelRegistry.ModelStoreResponse(inferenceId2, RestStatus.CREATED, null))); assertModelAndMinimalSettingsWithSecrets(registry, model1, secrets); assertModelAndMinimalSettingsWithSecrets(registry, model2, secrets); @@ -178,18 +172,17 @@ public void testStoreModels_StoresMultipleInferenceEndpoints() { private static void 
assertModelAndMinimalSettingsWithSecrets(ModelRegistry registry, Model model, String secrets) { assertMinimalServiceSettings(registry, model); - var listener1 = new PlainActionFuture(); - registry.getModelWithSecrets(model.getInferenceEntityId(), listener1); + var listener = new PlainActionFuture(); + registry.getModelWithSecrets(model.getInferenceEntityId(), listener); - var storedModel1 = listener1.actionGet(TIMEOUT); - assertModel(storedModel1, model, secrets); + var storedModel = listener.actionGet(TIMEOUT); + assertModel(storedModel, model, secrets); } private static void assertModel(UnparsedModel model, Model expected, String secrets) { assertThat(model.inferenceEntityId(), is(expected.getInferenceEntityId())); assertThat(model.service(), is(expected.getConfigurations().getService())); assertThat(model.taskType(), is(expected.getConfigurations().getTaskType())); - assertThat(model.taskType(), is(expected.getConfigurations().getTaskType())); assertThat(model.secrets().keySet(), hasSize(1)); assertThat(model.secrets().get("secret_settings"), instanceOf(Map.class)); @SuppressWarnings("unchecked") @@ -200,8 +193,10 @@ private static void assertModel(UnparsedModel model, Model expected, String secr public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflictExists() { var secrets = "secret"; + var inferenceId = "1"; + var model1 = new TestModel( - "1", + inferenceId, TaskType.SPARSE_EMBEDDING, "foo", new TestModel.TestServiceSettings(null, null, null, null), @@ -211,7 +206,7 @@ public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflic var model2 = new TestModel( // using the same inference id as model1 to cause a failure - "1", + inferenceId, TaskType.TEXT_EMBEDDING, "foo", new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), @@ -224,7 +219,7 @@ public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflic var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), is(2)); - assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null))); + assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); assertThat(response.get(1).inferenceId(), is(model2.getInferenceEntityId())); assertThat(response.get(1).status(), is(RestStatus.CONFLICT)); assertTrue(response.get(1).failed()); @@ -241,7 +236,7 @@ public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflic public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyExists() { var secrets = "secret"; - var model1 = new TestModel( + var model = new TestModel( "1", TaskType.SPARSE_EMBEDDING, "foo", @@ -250,13 +245,67 @@ public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyE new TestModel.TestSecretSettings(secrets) ); - storeCorruptedModel(model1, false); + storeCorruptedModel(model, false); PlainActionFuture> storeListener = new PlainActionFuture<>(); - registry.storeModels(List.of(model1), storeListener, TimeValue.THIRTY_SECONDS); + registry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), is(1)); + assertThat(response.get(0).inferenceId(), is(model.getInferenceEntityId())); + assertThat(response.get(0).status(), is(RestStatus.CONFLICT)); + assertTrue(response.get(0).failed()); + + var cause = 
response.get(0).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + // Since there was a partial write, both documents should be removed + assertIndicesContainExpectedDocsCount(model, 0); + } + + public void testStoreModels_OnFailure_RemovesPartialWritesOfInferenceEndpoint() { + var secrets = "secret"; + + var inferenceId1 = "1"; + var inferenceId2 = "2"; + var inferenceId3 = "3"; + + var model1 = new TestModel( + inferenceId1, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + inferenceId2, + TaskType.CHAT_COMPLETION, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model3 = new TestModel( + inferenceId3, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + storeCorruptedModel(model1, false); + storeCorruptedModel(model2, true); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + registry.storeModels(List.of(model1, model2, model3), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), is(3)); assertThat(response.get(0).inferenceId(), is(model1.getInferenceEntityId())); assertThat(response.get(0).status(), is(RestStatus.CONFLICT)); assertTrue(response.get(0).failed()); @@ -265,13 +314,34 @@ public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyE assertNotNull(cause); assertThat(cause, instanceOf(VersionConflictEngineException.class)); assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + + // Since we did a partial write of model1's secrets, both documents should be removed + assertIndicesContainExpectedDocsCount(model1, 0); + + assertThat(response.get(1).inferenceId(), is(model2.getInferenceEntityId())); + assertThat(response.get(1).status(), is(RestStatus.CONFLICT)); + assertTrue(response.get(1).failed()); + + cause = response.get(1).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_2]: version conflict, document already exists")); + + // Since we did a partial write of model2's configurations, both documents should be removed + assertIndicesContainExpectedDocsCount(model2, 0); + + // model3 should be stored successfully + assertModelAndMinimalSettingsWithSecrets(registry, model3, secrets); + assertIndicesContainExpectedDocsCount(model3, 2); } public void testGetModelNoSecrets() { + var inferenceId = "1"; + assertStoreModel( registry, new TestModel( - "1", + inferenceId, TaskType.SPARSE_EMBEDDING, "foo", new TestModel.TestServiceSettings(null, null, null, null), @@ -281,10 +351,10 @@ public void testGetModelNoSecrets() { ); var getListener = new PlainActionFuture(); - registry.getModel("1", getListener); + registry.getModel(inferenceId, getListener); var modelConfig = getListener.actionGet(TIMEOUT); - assertEquals("1", modelConfig.inferenceEntityId()); + assertEquals(inferenceId, 
modelConfig.inferenceEntityId()); assertEquals("foo", modelConfig.service()); assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); assertNotNull(modelConfig.settings().keySet()); From 58f9a75f8e302b7dc4c821b87ffc914977e6d0f4 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 20 Oct 2025 10:14:27 -0400 Subject: [PATCH 13/18] Moving most tests to ModelRegistryIT --- .../integration/ModelRegistryIT.java | 479 +++++++++++++++++- .../inference/registry/ModelRegistry.java | 4 +- .../registry/ModelRegistryTests.java | 427 +--------------- 3 files changed, 470 insertions(+), 440 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 5448811b1c2ae..cf1ace8558405 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -8,16 +8,23 @@ package org.elasticsearch.xpack.inference.integration; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.MinimalServiceSettings; @@ -34,20 +41,23 @@ import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.inference.InferenceIndex; +import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ModelRegistryTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import 
org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; +import org.hamcrest.Matchers; import org.junit.Before; import java.io.IOException; @@ -65,16 +75,21 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.registry.ModelRegistryTests.assertMinimalServiceSettings; +import static org.elasticsearch.xpack.inference.registry.ModelRegistryTests.assertStoreModel; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; @ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 public class ModelRegistryIT extends ESSingleNodeTestCase { @@ -96,7 +111,7 @@ protected Collection> getPlugins() { public void testStoreModel() throws Exception { String inferenceEntityId = "test-store-model"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } public void testStoreModelWithUnknownFields() throws Exception { @@ -104,7 +119,7 @@ public void testStoreModelWithUnknownFields() throws Exception { Model model = buildModelWithUnknownField(inferenceEntityId); ElasticsearchStatusException statusException = expectThrows( ElasticsearchStatusException.class, - () -> ModelRegistryTests.assertStoreModel(modelRegistry, model) + () -> assertStoreModel(modelRegistry, model) ); assertThat( statusException.getRootCause().getMessage(), @@ -116,7 +131,7 @@ public void testStoreModelWithUnknownFields() throws Exception { public void testGetModel() throws Exception { String inferenceEntityId = "test-get-model"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); // now get the model AtomicReference exceptionHolder = new AtomicReference<>(); @@ -148,10 +163,10 @@ public void testGetModel() throws Exception { public void testStoreModelFailsWhenModelExists() throws Exception { String inferenceEntityId = "test-put-trained-model-config-exists"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); // a model with the same id exists - var exc = expectThrows(Exception.class, () -> ModelRegistryTests.assertStoreModel(modelRegistry, model)); + var exc = expectThrows(Exception.class, () -> assertStoreModel(modelRegistry, model)); assertThat(exc.getMessage(), containsString("Inference endpoint [test-put-trained-model-config-exists] already exists")); } @@ -159,7 +174,7 @@ public void testDeleteModel() throws Exception { // put models for (var id : new String[] { "model1", "model2", "model3" }) { Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING); - 
ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } AtomicReference deleteResponseHolder = new AtomicReference<>(); @@ -238,7 +253,7 @@ public void testGetModelsByTaskType() throws InterruptedException { sparseAndTextEmbeddingModels.add(createModel(randomAlphaOfLength(5), TaskType.TEXT_EMBEDDING, service)); for (var model : sparseAndTextEmbeddingModels) { - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } AtomicReference exceptionHolder = new AtomicReference<>(); @@ -277,7 +292,7 @@ public void testGetAllModels() throws InterruptedException { for (int i = 0; i < modelCount; i++) { var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service); createdModels.add(model); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } AtomicReference> modelHolder = new AtomicReference<>(); @@ -303,7 +318,7 @@ public void testGetModelWithSecrets() throws InterruptedException { var secret = "abc"; var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret); - ModelRegistryTests.assertStoreModel(modelRegistry, modelWithSecrets); + assertStoreModel(modelRegistry, modelWithSecrets); AtomicReference exceptionHolder = new AtomicReference<>(); AtomicReference modelHolder = new AtomicReference<>(); @@ -352,7 +367,7 @@ public void testGetAllModels_WithDefaults() throws Exception { var id = randomAlphaOfLength(5) + i; var model = createModel(id, randomFrom(TaskType.values()), serviceName); createdModels.put(id, model); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } AtomicReference> modelHolder = new AtomicReference<>(); @@ -495,8 +510,8 @@ public void testGet_WithDefaults() throws InterruptedException { var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName); var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName); - ModelRegistryTests.assertStoreModel(modelRegistry, configured1); - ModelRegistryTests.assertStoreModel(modelRegistry, configured2); + assertStoreModel(modelRegistry, configured1); + assertStoreModel(modelRegistry, configured2); AtomicReference exceptionHolder = new AtomicReference<>(); AtomicReference modelHolder = new AtomicReference<>(); @@ -546,9 +561,9 @@ public void testGetByTaskType_WithDefaults() throws Exception { var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, serviceName); var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, serviceName); var configuredRerank = createModel("configured-rerank", TaskType.RERANK, serviceName); - ModelRegistryTests.assertStoreModel(modelRegistry, configuredSparse); - ModelRegistryTests.assertStoreModel(modelRegistry, configuredText); - ModelRegistryTests.assertStoreModel(modelRegistry, configuredRerank); + assertStoreModel(modelRegistry, configuredSparse); + assertStoreModel(modelRegistry, configuredText); + assertStoreModel(modelRegistry, configuredRerank); AtomicReference exceptionHolder = new AtomicReference<>(); AtomicReference> modelHolder = new AtomicReference<>(); @@ -577,6 +592,438 @@ public void testGetByTaskType_WithDefaults() throws Exception { assertReturnModelIsModifiable(modelHolder.get().get(0)); } + public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() { + var listener = new 
PlainActionFuture(); + modelRegistry.getModelWithSecrets("1", listener); + + ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), Matchers.is("Inference endpoint not found [1]")); + } + + public void testStoreModels_StoresSingleInferenceEndpoint() { + var inferenceId = "1"; + var secrets = "secret"; + + var model = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(1)); + assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + + assertMinimalServiceSettings(modelRegistry, model); + + var listener = new PlainActionFuture(); + modelRegistry.getModelWithSecrets(inferenceId, listener); + + var returnedModel = listener.actionGet(TIMEOUT); + assertModel(returnedModel, model, secrets); + } + + public void testStoreModels_StoresMultipleInferenceEndpoints() { + var secrets = "secret"; + var inferenceId1 = "1"; + var inferenceId2 = "2"; + + var model1 = new TestModel( + inferenceId1, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + inferenceId2, + TaskType.TEXT_EMBEDDING, + "foo", + new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(2)); + assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId1, RestStatus.CREATED, null))); + assertThat(response.get(1), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId2, RestStatus.CREATED, null))); + + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model2, secrets); + } + + private static void assertModelAndMinimalSettingsWithSecrets(ModelRegistry registry, Model model, String secrets) { + assertMinimalServiceSettings(registry, model); + + var listener = new PlainActionFuture(); + registry.getModelWithSecrets(model.getInferenceEntityId(), listener); + + var storedModel = listener.actionGet(TIMEOUT); + assertModel(storedModel, model, secrets); + } + + private static void assertModel(UnparsedModel model, Model expected, String secrets) { + assertThat(model.inferenceEntityId(), Matchers.is(expected.getInferenceEntityId())); + assertThat(model.service(), Matchers.is(expected.getConfigurations().getService())); + assertThat(model.taskType(), Matchers.is(expected.getConfigurations().getTaskType())); + assertThat(model.secrets().keySet(), hasSize(1)); + assertThat(model.secrets().get("secret_settings"), instanceOf(Map.class)); + 
@SuppressWarnings("unchecked") + var secretSettings = (Map) model.secrets().get("secret_settings"); + assertThat(secretSettings.get("api_key"), Matchers.is(secrets)); + } + + public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflictExists() { + var secrets = "secret"; + + var inferenceId = "1"; + + var model1 = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + // using the same inference id as model1 to cause a failure + inferenceId, + TaskType.TEXT_EMBEDDING, + "foo", + new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(2)); + assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + assertThat(response.get(1).inferenceId(), Matchers.is(model2.getInferenceEntityId())); + assertThat(response.get(1).status(), Matchers.is(RestStatus.CONFLICT)); + assertTrue(response.get(1).failed()); + + var cause = response.get(1).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); + assertIndicesContainExpectedDocsCount(model1, 2); + } + + public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyExists() { + var secrets = "secret"; + + var model = new TestModel( + "1", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + storeCorruptedModel(model, false); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(1)); + assertThat(response.get(0).inferenceId(), Matchers.is(model.getInferenceEntityId())); + assertThat(response.get(0).status(), Matchers.is(RestStatus.CONFLICT)); + assertTrue(response.get(0).failed()); + + var cause = response.get(0).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + // Since there was a partial write, both documents should be removed + assertIndicesContainExpectedDocsCount(model, 0); + } + + public void testStoreModels_OnFailure_RemovesPartialWritesOfInferenceEndpoint() { + var secrets = "secret"; + + var inferenceId1 = "1"; + var inferenceId2 = "2"; + var inferenceId3 = "3"; + + var model1 = new TestModel( + inferenceId1, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new 
TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + inferenceId2, + TaskType.CHAT_COMPLETION, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model3 = new TestModel( + inferenceId3, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + storeCorruptedModel(model1, false); + storeCorruptedModel(model2, true); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model1, model2, model3), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(3)); + assertThat(response.get(0).inferenceId(), Matchers.is(model1.getInferenceEntityId())); + assertThat(response.get(0).status(), Matchers.is(RestStatus.CONFLICT)); + assertTrue(response.get(0).failed()); + + var cause = response.get(0).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + + // Since we did a partial write of model1's secrets, both documents should be removed + assertIndicesContainExpectedDocsCount(model1, 0); + + assertThat(response.get(1).inferenceId(), Matchers.is(model2.getInferenceEntityId())); + assertThat(response.get(1).status(), Matchers.is(RestStatus.CONFLICT)); + assertTrue(response.get(1).failed()); + + cause = response.get(1).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_2]: version conflict, document already exists")); + + // Since we did a partial write of model2's configurations, both documents should be removed + assertIndicesContainExpectedDocsCount(model2, 0); + + // model3 should be stored successfully + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model3, secrets); + assertIndicesContainExpectedDocsCount(model3, 2); + } + + public void testGetModelNoSecrets() { + var inferenceId = "1"; + + assertStoreModel( + modelRegistry, + new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(randomAlphaOfLength(4)) + ) + ); + + var getListener = new PlainActionFuture(); + modelRegistry.getModel(inferenceId, getListener); + + var modelConfig = getListener.actionGet(TIMEOUT); + assertEquals(inferenceId, modelConfig.inferenceEntityId()); + assertEquals("foo", modelConfig.service()); + assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); + assertNotNull(modelConfig.settings().keySet()); + assertThat(modelConfig.secrets().keySet(), empty()); + } + + public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { + var model = TestModel.createRandomInstance(); + assertStoreModel(modelRegistry, model); + } + + public void testStoreModel_ThrowsException_WhenFailureIsAVersionConflict() { + var model = TestModel.createRandomInstance(); + assertStoreModel(modelRegistry, model); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> assertStoreModel(modelRegistry, model)); + assertThat(exception.status(), 
Matchers.is(RestStatus.BAD_REQUEST)); + assertThat( + exception.getMessage(), + Matchers.is(format("Inference endpoint [%s] already exists", model.getConfigurations().getInferenceEntityId())) + ); + } + + public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() { + var listener = new PlainActionFuture(); + modelRegistry.removeDefaultConfigs(Set.of(), listener); + assertTrue(listener.actionGet(TIMEOUT)); + } + + public void testDeleteModels_Returns_ConflictException_WhenModelIsBeingAdded() { + var model = TestModel.createRandomInstance(); + var newModel = TestModel.createRandomInstance(); + modelRegistry.updateModelTransaction(newModel, model, new PlainActionFuture<>()); + + var listener = new PlainActionFuture(); + + modelRegistry.deleteModels(Set.of(newModel.getInferenceEntityId()), listener); + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + containsString("are currently being updated, please wait until after they are finished updating to delete.") + ); + assertThat(exception.status(), Matchers.is(RestStatus.CONFLICT)); + } + + public void testContainsDefaultConfigId() { + modelRegistry.addDefaultIds( + new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class)) + ); + modelRegistry.addDefaultIds( + new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class)) + ); + assertTrue(modelRegistry.containsDefaultConfigId("foo")); + assertFalse(modelRegistry.containsDefaultConfigId("baz")); + } + + public void testDuplicateDefaultIds() { + var id = "my-inference"; + var mockServiceA = mock(InferenceService.class); + when(mockServiceA.name()).thenReturn("service-a"); + var mockServiceB = mock(InferenceService.class); + when(mockServiceB.name()).thenReturn("service-b"); + + modelRegistry.addDefaultIds(new InferenceService.DefaultConfigId(id, MinimalServiceSettingsTests.randomInstance(), mockServiceA)); + var ise = expectThrows( + IllegalStateException.class, + () -> modelRegistry.addDefaultIds( + new InferenceService.DefaultConfigId(id, MinimalServiceSettingsTests.randomInstance(), mockServiceB) + ) + ); + assertThat( + ise.getMessage(), + containsString( + "Cannot add default endpoint to the inference endpoint registry with duplicate inference id [my-inference] declared by " + + "service [service-b]. The inference Id is already use by [service-a] service." 
+ ) + ); + } + + public void testStoreModel_DeletesIndexDocs_WhenInferenceIndexDocumentAlreadyExists() { + storeCorruptedModelThenStoreModel(false); + } + + public void testStoreModel_DeletesIndexDocs_WhenInferenceSecretsIndexDocumentAlreadyExists() { + storeCorruptedModelThenStoreModel(true); + } + + public void testStoreModel_DoesNotDeleteIndexDocs_WhenModelAlreadyExists() { + var model = new TestModel( + "model-id", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings("secret") + ); + + PlainActionFuture firstStoreListener = new PlainActionFuture<>(); + modelRegistry.storeModel(model, firstStoreListener, TimeValue.THIRTY_SECONDS); + firstStoreListener.actionGet(TimeValue.THIRTY_SECONDS); + + assertIndicesContainExpectedDocsCount(model, 2); + + PlainActionFuture secondStoreListener = new PlainActionFuture<>(); + modelRegistry.storeModel(model, secondStoreListener, TimeValue.THIRTY_SECONDS); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS)); + assertThat(exception.getMessage(), containsString("already exists")); + assertThat(exception.status(), Matchers.is(RestStatus.BAD_REQUEST)); + assertIndicesContainExpectedDocsCount(model, 2); + } + + private void storeCorruptedModelThenStoreModel(boolean storeSecrets) { + var model = new TestModel( + "corrupted-model-id", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings("secret") + ); + + storeCorruptedModel(model, storeSecrets); + + assertIndicesContainExpectedDocsCount(model, 1); + + PlainActionFuture storeListener = new PlainActionFuture<>(); + modelRegistry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS)); + assertThat(exception.getMessage(), containsString("already exists")); + assertThat(exception.status(), Matchers.is(RestStatus.BAD_REQUEST)); + + assertIndicesContainExpectedDocsCount(model, 0); + } + + private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) { + SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId())))) + .setSize(2) + .setTrackTotalHits(false) + .request(); + SearchResponse searchResponse = client().search(modelSearch).actionGet(TimeValue.THIRTY_SECONDS); + try { + assertThat(searchResponse.getHits().getHits(), Matchers.arrayWithSize(numberOfDocs)); + } finally { + searchResponse.decRef(); + } + } + + private void storeCorruptedModel(Model model, boolean storeSecrets) { + var listener = new PlainActionFuture(); + + client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add( + ModelRegistry.createIndexRequestBuilder( + model.getInferenceEntityId(), + storeSecrets ? InferenceSecretsIndex.INDEX_NAME : InferenceIndex.INDEX_NAME, + storeSecrets ? 
model.getSecrets() : model.getConfigurations(), + false, + client() + ) + ) + .execute(listener); + + var bulkResponse = listener.actionGet(TIMEOUT); + if (bulkResponse.hasFailures()) { + fail("Failed to store model: " + bulkResponse.buildFailureMessage()); + } + } + + + private void assertInferenceIndexExists() { var indexResponse = client().admin().indices().prepareGetIndex(TEST_REQUEST_TIMEOUT).addIndices(".inference").get(); assertNotNull(indexResponse.getSettings()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 9228ae9e5585a..ffa73ca02324a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -1041,8 +1041,8 @@ private static DeleteByQueryRequest createDeleteRequest(Set inferenceEnt return request; } - // default for testing - static IndexRequestBuilder createIndexRequestBuilder( + // public for testing + public static IndexRequestBuilder createIndexRequestBuilder( String inferenceId, String indexName, ToXContentObject body, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 8d87222b791c5..0f98c5fafc10c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -8,29 +8,19 @@ package org.elasticsearch.xpack.inference.registry; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.MinimalServiceSettingsTests; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.xpack.inference.InferenceIndex; -import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import org.hamcrest.Matchers; @@ -38,18 +28,12 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.List; -import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.core.Strings.format; import static 
org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -69,314 +53,6 @@ public void createComponents() { registry = node().injector().getInstance(ModelRegistry.class); } - public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() { - var listener = new PlainActionFuture(); - registry.getModelWithSecrets("1", listener); - - ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [1]")); - } - - public void testGetModelWithSecrets() { - assertStoreModel( - registry, - new TestModel( - "1", - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings("secret") - ) - ); - - var listener = new PlainActionFuture(); - registry.getModelWithSecrets("1", listener); - - var modelConfig = listener.actionGet(TIMEOUT); - assertEquals("1", modelConfig.inferenceEntityId()); - assertEquals("foo", modelConfig.service()); - assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); - assertNotNull(modelConfig.settings().keySet()); - assertThat(modelConfig.secrets().keySet(), hasSize(1)); - assertThat(modelConfig.secrets().get("secret_settings"), instanceOf(Map.class)); - @SuppressWarnings("unchecked") - var secretSettings = (Map) modelConfig.secrets().get("secret_settings"); - assertThat(secretSettings.get("api_key"), equalTo("secret")); - } - - public void testStoreModels_StoresSingleInferenceEndpoint() { - var inferenceId = "1"; - var secrets = "secret"; - - var model = new TestModel( - inferenceId, - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - PlainActionFuture> storeListener = new PlainActionFuture<>(); - registry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); - - var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); - assertThat(response.size(), is(1)); - assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); - - assertMinimalServiceSettings(registry, model); - - var listener = new PlainActionFuture(); - registry.getModelWithSecrets(inferenceId, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertModel(returnedModel, model, secrets); - } - - public void testStoreModels_StoresMultipleInferenceEndpoints() { - var secrets = "secret"; - var inferenceId1 = "1"; - var inferenceId2 = "2"; - - var model1 = new TestModel( - inferenceId1, - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - var model2 = new TestModel( - inferenceId2, - TaskType.TEXT_EMBEDDING, - "foo", - new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - PlainActionFuture> 
storeListener = new PlainActionFuture<>(); - registry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); - - var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); - assertThat(response.size(), is(2)); - assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse(inferenceId1, RestStatus.CREATED, null))); - assertThat(response.get(1), is(new ModelRegistry.ModelStoreResponse(inferenceId2, RestStatus.CREATED, null))); - - assertModelAndMinimalSettingsWithSecrets(registry, model1, secrets); - assertModelAndMinimalSettingsWithSecrets(registry, model2, secrets); - } - - private static void assertModelAndMinimalSettingsWithSecrets(ModelRegistry registry, Model model, String secrets) { - assertMinimalServiceSettings(registry, model); - - var listener = new PlainActionFuture(); - registry.getModelWithSecrets(model.getInferenceEntityId(), listener); - - var storedModel = listener.actionGet(TIMEOUT); - assertModel(storedModel, model, secrets); - } - - private static void assertModel(UnparsedModel model, Model expected, String secrets) { - assertThat(model.inferenceEntityId(), is(expected.getInferenceEntityId())); - assertThat(model.service(), is(expected.getConfigurations().getService())); - assertThat(model.taskType(), is(expected.getConfigurations().getTaskType())); - assertThat(model.secrets().keySet(), hasSize(1)); - assertThat(model.secrets().get("secret_settings"), instanceOf(Map.class)); - @SuppressWarnings("unchecked") - var secretSettings = (Map) model.secrets().get("secret_settings"); - assertThat(secretSettings.get("api_key"), is(secrets)); - } - - public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflictExists() { - var secrets = "secret"; - - var inferenceId = "1"; - - var model1 = new TestModel( - inferenceId, - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - var model2 = new TestModel( - // using the same inference id as model1 to cause a failure - inferenceId, - TaskType.TEXT_EMBEDDING, - "foo", - new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - PlainActionFuture> storeListener = new PlainActionFuture<>(); - registry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); - - var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); - assertThat(response.size(), is(2)); - assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); - assertThat(response.get(1).inferenceId(), is(model2.getInferenceEntityId())); - assertThat(response.get(1).status(), is(RestStatus.CONFLICT)); - assertTrue(response.get(1).failed()); - - var cause = response.get(1).failureCause(); - assertNotNull(cause); - assertThat(cause, instanceOf(VersionConflictEngineException.class)); - assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); - - assertModelAndMinimalSettingsWithSecrets(registry, model1, secrets); - assertIndicesContainExpectedDocsCount(model1, 2); - } - - public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyExists() { - var secrets = "secret"; - - var model = new TestModel( - "1", - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, 
null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - storeCorruptedModel(model, false); - - PlainActionFuture> storeListener = new PlainActionFuture<>(); - registry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); - - var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); - assertThat(response.size(), is(1)); - assertThat(response.get(0).inferenceId(), is(model.getInferenceEntityId())); - assertThat(response.get(0).status(), is(RestStatus.CONFLICT)); - assertTrue(response.get(0).failed()); - - var cause = response.get(0).failureCause(); - assertNotNull(cause); - assertThat(cause, instanceOf(VersionConflictEngineException.class)); - assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); - // Since there was a partial write, both documents should be removed - assertIndicesContainExpectedDocsCount(model, 0); - } - - public void testStoreModels_OnFailure_RemovesPartialWritesOfInferenceEndpoint() { - var secrets = "secret"; - - var inferenceId1 = "1"; - var inferenceId2 = "2"; - var inferenceId3 = "3"; - - var model1 = new TestModel( - inferenceId1, - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - var model2 = new TestModel( - inferenceId2, - TaskType.CHAT_COMPLETION, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - var model3 = new TestModel( - inferenceId3, - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(secrets) - ); - - storeCorruptedModel(model1, false); - storeCorruptedModel(model2, true); - - PlainActionFuture> storeListener = new PlainActionFuture<>(); - registry.storeModels(List.of(model1, model2, model3), storeListener, TimeValue.THIRTY_SECONDS); - - var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); - assertThat(response.size(), is(3)); - assertThat(response.get(0).inferenceId(), is(model1.getInferenceEntityId())); - assertThat(response.get(0).status(), is(RestStatus.CONFLICT)); - assertTrue(response.get(0).failed()); - - var cause = response.get(0).failureCause(); - assertNotNull(cause); - assertThat(cause, instanceOf(VersionConflictEngineException.class)); - assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); - - // Since we did a partial write of model1's secrets, both documents should be removed - assertIndicesContainExpectedDocsCount(model1, 0); - - assertThat(response.get(1).inferenceId(), is(model2.getInferenceEntityId())); - assertThat(response.get(1).status(), is(RestStatus.CONFLICT)); - assertTrue(response.get(1).failed()); - - cause = response.get(1).failureCause(); - assertNotNull(cause); - assertThat(cause, instanceOf(VersionConflictEngineException.class)); - assertThat(cause.getMessage(), containsString("[model_2]: version conflict, document already exists")); - - // Since we did a partial write of model2's configurations, both documents should be removed - assertIndicesContainExpectedDocsCount(model2, 0); - - // model3 should be stored successfully - assertModelAndMinimalSettingsWithSecrets(registry, model3, secrets); - assertIndicesContainExpectedDocsCount(model3, 
2); - } - - public void testGetModelNoSecrets() { - var inferenceId = "1"; - - assertStoreModel( - registry, - new TestModel( - inferenceId, - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(randomAlphaOfLength(4)) - ) - ); - - var getListener = new PlainActionFuture(); - registry.getModel(inferenceId, getListener); - - var modelConfig = getListener.actionGet(TIMEOUT); - assertEquals(inferenceId, modelConfig.inferenceEntityId()); - assertEquals("foo", modelConfig.service()); - assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); - assertNotNull(modelConfig.settings().keySet()); - assertThat(modelConfig.secrets().keySet(), empty()); - } - - public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { - var model = TestModel.createRandomInstance(); - assertStoreModel(registry, model); - } - - public void testStoreModel_ThrowsException_WhenFailureIsAVersionConflict() { - var model = TestModel.createRandomInstance(); - assertStoreModel(registry, model); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> assertStoreModel(registry, model)); - assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); - assertThat( - exception.getMessage(), - is(format("Inference endpoint [%s] already exists", model.getConfigurations().getInferenceEntityId())) - ); - } public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() { var listener = new PlainActionFuture(); @@ -476,99 +152,6 @@ public void testDuplicateDefaultIds() { ); } - public void testStoreModel_DeletesIndexDocs_WhenInferenceIndexDocumentAlreadyExists() { - storeCorruptedModelThenStoreModel(false); - } - - public void testStoreModel_DeletesIndexDocs_WhenInferenceSecretsIndexDocumentAlreadyExists() { - storeCorruptedModelThenStoreModel(true); - } - - public void testStoreModel_DoesNotDeleteIndexDocs_WhenModelAlreadyExists() { - var model = new TestModel( - "model-id", - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings("secret") - ); - - PlainActionFuture firstStoreListener = new PlainActionFuture<>(); - registry.storeModel(model, firstStoreListener, TimeValue.THIRTY_SECONDS); - firstStoreListener.actionGet(TimeValue.THIRTY_SECONDS); - - assertIndicesContainExpectedDocsCount(model, 2); - - PlainActionFuture secondStoreListener = new PlainActionFuture<>(); - registry.storeModel(model, secondStoreListener, TimeValue.THIRTY_SECONDS); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS)); - assertThat(exception.getMessage(), containsString("already exists")); - assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); - assertIndicesContainExpectedDocsCount(model, 2); - } - - private void storeCorruptedModelThenStoreModel(boolean storeSecrets) { - var model = new TestModel( - "corrupted-model-id", - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings("secret") - ); - - storeCorruptedModel(model, storeSecrets); - - assertIndicesContainExpectedDocsCount(model, 1); - - PlainActionFuture storeListener = new PlainActionFuture<>(); - registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); - - var exception = 
expectThrows(ElasticsearchStatusException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS)); - assertThat(exception.getMessage(), containsString("already exists")); - assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); - - assertIndicesContainExpectedDocsCount(model, 0); - } - - private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) { - SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) - .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId())))) - .setSize(2) - .setTrackTotalHits(false) - .request(); - SearchResponse searchResponse = client().search(modelSearch).actionGet(TimeValue.THIRTY_SECONDS); - try { - assertThat(searchResponse.getHits().getHits(), Matchers.arrayWithSize(numberOfDocs)); - } finally { - searchResponse.decRef(); - } - } - - private void storeCorruptedModel(Model model, boolean storeSecrets) { - var listener = new PlainActionFuture(); - - client().prepareBulk() - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add( - ModelRegistry.createIndexRequestBuilder( - model.getInferenceEntityId(), - storeSecrets ? InferenceSecretsIndex.INDEX_NAME : InferenceIndex.INDEX_NAME, - storeSecrets ? model.getSecrets() : model.getConfigurations(), - false, - client() - ) - ) - .execute(listener); - - var bulkResponse = listener.actionGet(TIMEOUT); - if (bulkResponse.hasFailures()) { - fail("Failed to store model: " + bulkResponse.buildFailureMessage()); - } - } - public static void assertStoreModel(ModelRegistry registry, Model model) { PlainActionFuture storeListener = new PlainActionFuture<>(); registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); @@ -577,12 +160,12 @@ public static void assertStoreModel(ModelRegistry registry, Model model) { assertMinimalServiceSettings(registry, model); } - private static void assertMinimalServiceSettings(ModelRegistry registry, Model model) { + public static void assertMinimalServiceSettings(ModelRegistry registry, Model model) { var settings = registry.getMinimalServiceSettings(model.getInferenceEntityId()); assertNotNull(settings); - assertThat(settings.taskType(), equalTo(model.getTaskType())); - assertThat(settings.dimensions(), equalTo(model.getServiceSettings().dimensions())); - assertThat(settings.elementType(), equalTo(model.getServiceSettings().elementType())); - assertThat(settings.dimensions(), equalTo(model.getServiceSettings().dimensions())); + assertThat(settings.taskType(), Matchers.equalTo(model.getTaskType())); + assertThat(settings.dimensions(), Matchers.equalTo(model.getServiceSettings().dimensions())); + assertThat(settings.elementType(), Matchers.equalTo(model.getServiceSettings().elementType())); + assertThat(settings.dimensions(), Matchers.equalTo(model.getServiceSettings().dimensions())); } } From 1854e50ce0ee1c67eb92c74209de5928b02f398b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 20 Oct 2025 14:20:44 +0000 Subject: [PATCH 14/18] [CI] Auto commit changes from spotless --- .../xpack/inference/integration/ModelRegistryIT.java | 2 -- .../xpack/inference/registry/ModelRegistryTests.java | 1 - 2 files changed, 3 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 
cf1ace8558405..81c52109c384c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -1022,8 +1022,6 @@ private void storeCorruptedModel(Model model, boolean storeSecrets) { } } - - private void assertInferenceIndexExists() { var indexResponse = client().admin().indices().prepareGetIndex(TEST_REQUEST_TIMEOUT).addIndices(".inference").get(); assertNotNull(indexResponse.getSettings()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 0f98c5fafc10c..a3cb0ba244d7a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -53,7 +53,6 @@ public void createComponents() { registry = node().injector().getInstance(ModelRegistry.class); } - public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() { var listener = new PlainActionFuture(); registry.removeDefaultConfigs(Set.of(), listener); From a8db6cf97e9e60bda20976012389cb33f550b333 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 20 Oct 2025 10:57:23 -0400 Subject: [PATCH 15/18] Fixing test --- .../xpack/inference/registry/ModelRegistryMetadataTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java index 8726632a77095..9af21386e93d3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java @@ -301,7 +301,7 @@ public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId_Wit ) ); // ensure metadata hasn't changed - assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), Set.of(newInferenceId)))); assertThat(newMetadata, not(is(metadata))); assertThat( newMetadata, From 426fcbec6e068216ce877af1b0296635786fbf42 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 20 Oct 2025 15:30:14 -0400 Subject: [PATCH 16/18] Removing duplicate tests --- .../integration/ModelRegistryIT.java | 63 +------------------ 1 file changed, 3 insertions(+), 60 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 81c52109c384c..2c65d753fb1a8 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -89,7 +89,6 @@ import static org.mockito.ArgumentMatchers.any; import static 
org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; @ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 public class ModelRegistryIT extends ESSingleNodeTestCase { @@ -108,13 +107,13 @@ protected Collection> getPlugins() { return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); } - public void testStoreModel() throws Exception { + public void testStoreModel() { String inferenceEntityId = "test-store-model"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); assertStoreModel(modelRegistry, model); } - public void testStoreModelWithUnknownFields() throws Exception { + public void testStoreModelWithUnknownFields() { String inferenceEntityId = "test-store-model-unknown-field"; Model model = buildModelWithUnknownField(inferenceEntityId); ElasticsearchStatusException statusException = expectThrows( @@ -160,7 +159,7 @@ public void testGetModel() throws Exception { assertEquals(model, roundTripModel); } - public void testStoreModelFailsWhenModelExists() throws Exception { + public void testStoreModelFailsWhenModelExists() { String inferenceEntityId = "test-put-trained-model-config-exists"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); assertStoreModel(modelRegistry, model); @@ -873,62 +872,6 @@ public void testStoreModel_ThrowsException_WhenFailureIsAVersionConflict() { ); } - public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() { - var listener = new PlainActionFuture(); - modelRegistry.removeDefaultConfigs(Set.of(), listener); - assertTrue(listener.actionGet(TIMEOUT)); - } - - public void testDeleteModels_Returns_ConflictException_WhenModelIsBeingAdded() { - var model = TestModel.createRandomInstance(); - var newModel = TestModel.createRandomInstance(); - modelRegistry.updateModelTransaction(newModel, model, new PlainActionFuture<>()); - - var listener = new PlainActionFuture(); - - modelRegistry.deleteModels(Set.of(newModel.getInferenceEntityId()), listener); - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - containsString("are currently being updated, please wait until after they are finished updating to delete.") - ); - assertThat(exception.status(), Matchers.is(RestStatus.CONFLICT)); - } - - public void testContainsDefaultConfigId() { - modelRegistry.addDefaultIds( - new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class)) - ); - modelRegistry.addDefaultIds( - new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class)) - ); - assertTrue(modelRegistry.containsDefaultConfigId("foo")); - assertFalse(modelRegistry.containsDefaultConfigId("baz")); - } - - public void testDuplicateDefaultIds() { - var id = "my-inference"; - var mockServiceA = mock(InferenceService.class); - when(mockServiceA.name()).thenReturn("service-a"); - var mockServiceB = mock(InferenceService.class); - when(mockServiceB.name()).thenReturn("service-b"); - - modelRegistry.addDefaultIds(new InferenceService.DefaultConfigId(id, MinimalServiceSettingsTests.randomInstance(), mockServiceA)); - var ise = expectThrows( - IllegalStateException.class, - () -> modelRegistry.addDefaultIds( - new InferenceService.DefaultConfigId(id, MinimalServiceSettingsTests.randomInstance(), mockServiceB) - ) - ); - assertThat( - 
ise.getMessage(), - containsString( - "Cannot add default endpoint to the inference endpoint registry with duplicate inference id [my-inference] declared by " - + "service [service-b]. The inference Id is already use by [service-a] service." - ) - ); - } - public void testStoreModel_DeletesIndexDocs_WhenInferenceIndexDocumentAlreadyExists() { storeCorruptedModelThenStoreModel(false); } From 65e008fb3ec8e3ee94347dc201838f7738416a53 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 21 Oct 2025 10:58:42 -0400 Subject: [PATCH 17/18] Handling empty list and duplicates --- .../integration/ModelRegistryIT.java | 42 +++++++++++++++++++ .../inference/registry/ModelRegistry.java | 11 ++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 2c65d753fb1a8..67d67183a37e3 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -599,6 +599,14 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() assertThat(exception.getMessage(), Matchers.is("Inference endpoint not found [1]")); } + public void testStoreModels_ReturnsEmptyList_WhenGivenNoModelsToStore() { + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response, is(List.of())); + } + public void testStoreModels_StoresSingleInferenceEndpoint() { var inferenceId = "1"; var secrets = "secret"; @@ -727,6 +735,40 @@ public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflic assertIndicesContainExpectedDocsCount(model1, 2); } + public void testStoreModels_StoresOneModel_RemovesSecondDuplicateModelFromList_DoesNotThrowException() { + var secrets = "secret"; + var inferenceId = "1"; + var temperature = randomInt(3); + + var model1 = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(temperature), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(temperature), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model1, model1, model2), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(1)); + assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); + assertIndicesContainExpectedDocsCount(model1, 2); + } + public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyExists() { var secrets = "secret"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
index ffa73ca02324a..03d448580a7d7 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
@@ -707,9 +707,16 @@ private void storeModels(
         ActionListener> listener,
         TimeValue timeout
     ) {
+        if (models.isEmpty()) {
+            listener.onResponse(List.of());
+            return;
+        }
+
+        var modelsWithoutDuplicates = models.stream().distinct().toList();
+
         var bulkRequestBuilder = client.prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
 
-        for (var model : models) {
+        for (var model : modelsWithoutDuplicates) {
             bulkRequestBuilder.add(
                 createIndexRequestBuilder(model.getInferenceEntityId(), InferenceIndex.INDEX_NAME, model.getConfigurations(), false, client)
             );
@@ -719,7 +726,7 @@ private void storeModels(
             );
         }
 
-        bulkRequestBuilder.execute(getStoreMultipleModelsListener(models, updateClusterState, listener, timeout));
+        bulkRequestBuilder.execute(getStoreMultipleModelsListener(modelsWithoutDuplicates, updateClusterState, listener, timeout));
     }
 
     private ActionListener getStoreMultipleModelsListener(
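Editor's note: the de-duplication in the hunk above relies on Stream#distinct, which drops later elements that compare equal to an earlier one, so two Model instances describing the same endpoint only collapse if the model type defines value equality. The snippet below is a self-contained illustration of that behaviour using a plain record; it is not code from the patch.

    // Standalone illustration of Stream#distinct (not part of the patch): equal
    // elements after the first are removed, which is the mechanism the
    // modelsWithoutDuplicates line above depends on.
    import java.util.List;

    public final class DistinctSketch {
        // Records get value-based equals()/hashCode(), so equal ids collapse to one entry.
        record Endpoint(String inferenceId) {}

        public static void main(String[] args) {
            var deduplicated = List.of(new Endpoint("1"), new Endpoint("1"), new Endpoint("2"))
                .stream()
                .distinct()
                .toList();
            System.out.println(deduplicated); // prints [Endpoint[inferenceId=1], Endpoint[inferenceId=2]]
        }
    }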
From 7cea6c195f289d9ae81697912a0d744d06f7c141 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 22 Oct 2025 10:57:25 -0400
Subject: [PATCH 18/18] Fixing empty delete

---
 .../inference/registry/ModelRegistry.java      | 18 +++++++++++++++++-
 .../inference/registry/ModelRegistryTests.java |  9 +++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
index 03d448580a7d7..e556b1db9ecd8 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
@@ -743,7 +743,18 @@ private ActionListener getStoreMultipleModelsListener(
 
             var storageResponses = responses.stream().map(StoreResponseWithIndexInfo::modelStoreResponse).toList();
 
-            deleteModels(inferenceIdsToBeRemoved, ActionListener.running(() -> delegate.onResponse(storageResponses)));
+            ActionListener deleteListener = ActionListener.wrap(ignored -> delegate.onResponse(storageResponses), e -> {
+                logger.atWarn()
+                    .withThrowable(e)
+                    .log(
+                        "Failed to clean up partially stored inference endpoints {}. "
+                            + "The service may be in an inconsistent state. Please try deleting and re-adding the endpoints.",
+                        inferenceIdsToBeRemoved
+                    );
+                delegate.onResponse(storageResponses);
+            });
+
+            deleteModels(inferenceIdsToBeRemoved, deleteListener);
         });
 
         return ActionListener.wrap(bulkItemResponses -> {
@@ -950,6 +961,11 @@ public void deleteModels(Set inferenceEntityIds, ActionListener
     }
 
     private void deleteModels(Set inferenceEntityIds, boolean updateClusterState, ActionListener listener) {
+        if (inferenceEntityIds.isEmpty()) {
+            listener.onResponse(true);
+            return;
+        }
+
         var lockedInferenceIds = new HashSet<>(inferenceEntityIds);
         lockedInferenceIds.retainAll(preventDeletionLock);
 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
index a3cb0ba244d7a..44f0dcc1d8962 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
@@ -151,6 +151,15 @@ public void testDuplicateDefaultIds() {
         );
     }
 
+    public void testDeleteModels_Succeeds_WhenNoInferenceIdsAreProvided() {
+        var model = TestModel.createRandomInstance();
+        assertStoreModel(registry, model);
+
+        var listener = new PlainActionFuture();
+        registry.deleteModels(Set.of(), listener);
+        assertTrue(listener.actionGet(TIMEOUT));
+    }
+
     public static void assertStoreModel(ModelRegistry registry, Model model) {
         PlainActionFuture storeListener = new PlainActionFuture<>();
         registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS);