diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 5448811b1c2ae..67d67183a37e3 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -8,16 +8,23 @@ package org.elasticsearch.xpack.inference.integration; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.MinimalServiceSettings; @@ -34,20 +41,23 @@ import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.inference.InferenceIndex; +import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ModelRegistryTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; +import org.hamcrest.Matchers; import org.junit.Before; import java.io.IOException; @@ -65,11 +75,15 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.registry.ModelRegistryTests.assertMinimalServiceSettings; +import static org.elasticsearch.xpack.inference.registry.ModelRegistryTests.assertStoreModel; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; @@ -93,18 +107,18 @@ protected Collection> getPlugins() { return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); } - public void testStoreModel() throws Exception { + public void testStoreModel() { String inferenceEntityId = "test-store-model"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } - public void testStoreModelWithUnknownFields() throws Exception { + public void testStoreModelWithUnknownFields() { String inferenceEntityId = "test-store-model-unknown-field"; Model model = buildModelWithUnknownField(inferenceEntityId); ElasticsearchStatusException statusException = expectThrows( ElasticsearchStatusException.class, - () -> ModelRegistryTests.assertStoreModel(modelRegistry, model) + () -> assertStoreModel(modelRegistry, model) ); assertThat( statusException.getRootCause().getMessage(), @@ -116,7 +130,7 @@ public void testStoreModelWithUnknownFields() throws Exception { public void testGetModel() throws Exception { String inferenceEntityId = "test-get-model"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); // now get the model AtomicReference exceptionHolder = new AtomicReference<>(); @@ -145,13 +159,13 @@ public void testGetModel() throws Exception { assertEquals(model, roundTripModel); } - public void testStoreModelFailsWhenModelExists() throws Exception { + public void testStoreModelFailsWhenModelExists() { String inferenceEntityId = "test-put-trained-model-config-exists"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); // a model with the same id exists - var exc = expectThrows(Exception.class, () -> ModelRegistryTests.assertStoreModel(modelRegistry, model)); + var exc = expectThrows(Exception.class, () -> assertStoreModel(modelRegistry, model)); assertThat(exc.getMessage(), containsString("Inference endpoint [test-put-trained-model-config-exists] already exists")); } @@ -159,7 +173,7 @@ public void testDeleteModel() throws Exception { // put models for (var id : new String[] { "model1", "model2", "model3" }) { Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } AtomicReference deleteResponseHolder = new AtomicReference<>(); @@ -238,7 +252,7 @@ public void testGetModelsByTaskType() throws InterruptedException { sparseAndTextEmbeddingModels.add(createModel(randomAlphaOfLength(5), TaskType.TEXT_EMBEDDING, service)); for (var model : sparseAndTextEmbeddingModels) { - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } AtomicReference exceptionHolder = new AtomicReference<>(); @@ -277,7 +291,7 @@ public void testGetAllModels() throws InterruptedException { for (int i = 0; i < modelCount; i++) { var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service); createdModels.add(model); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } AtomicReference> modelHolder = new AtomicReference<>(); @@ -303,7 +317,7 @@ public void testGetModelWithSecrets() throws InterruptedException { var secret = "abc"; var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret); - ModelRegistryTests.assertStoreModel(modelRegistry, modelWithSecrets); + assertStoreModel(modelRegistry, modelWithSecrets); AtomicReference exceptionHolder = new AtomicReference<>(); AtomicReference modelHolder = new AtomicReference<>(); @@ -352,7 +366,7 @@ public void testGetAllModels_WithDefaults() throws Exception { var id = randomAlphaOfLength(5) + i; var model = createModel(id, randomFrom(TaskType.values()), serviceName); createdModels.put(id, model); - ModelRegistryTests.assertStoreModel(modelRegistry, model); + assertStoreModel(modelRegistry, model); } AtomicReference> modelHolder = new AtomicReference<>(); @@ -495,8 +509,8 @@ public void testGet_WithDefaults() throws InterruptedException { var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName); var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName); - ModelRegistryTests.assertStoreModel(modelRegistry, configured1); - ModelRegistryTests.assertStoreModel(modelRegistry, configured2); + assertStoreModel(modelRegistry, configured1); + assertStoreModel(modelRegistry, configured2); AtomicReference exceptionHolder = new AtomicReference<>(); AtomicReference modelHolder = new AtomicReference<>(); @@ -546,9 +560,9 @@ public void testGetByTaskType_WithDefaults() throws Exception { var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, serviceName); var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, serviceName); var configuredRerank = createModel("configured-rerank", TaskType.RERANK, serviceName); - ModelRegistryTests.assertStoreModel(modelRegistry, configuredSparse); - ModelRegistryTests.assertStoreModel(modelRegistry, configuredText); - ModelRegistryTests.assertStoreModel(modelRegistry, configuredRerank); + assertStoreModel(modelRegistry, configuredSparse); + assertStoreModel(modelRegistry, configuredText); + assertStoreModel(modelRegistry, configuredRerank); AtomicReference exceptionHolder = new AtomicReference<>(); AtomicReference> modelHolder = new AtomicReference<>(); @@ -577,6 +591,422 @@ public void testGetByTaskType_WithDefaults() throws Exception { assertReturnModelIsModifiable(modelHolder.get().get(0)); } + public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() { + var listener = new PlainActionFuture(); + modelRegistry.getModelWithSecrets("1", listener); + + ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), Matchers.is("Inference endpoint not found [1]")); + } + + public void testStoreModels_ReturnsEmptyList_WhenGivenNoModelsToStore() { + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response, is(List.of())); + } + + public void testStoreModels_StoresSingleInferenceEndpoint() { + var inferenceId = "1"; + var secrets = "secret"; + + var model = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(1)); + assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + + assertMinimalServiceSettings(modelRegistry, model); + + var listener = new PlainActionFuture(); + modelRegistry.getModelWithSecrets(inferenceId, listener); + + var returnedModel = listener.actionGet(TIMEOUT); + assertModel(returnedModel, model, secrets); + } + + public void testStoreModels_StoresMultipleInferenceEndpoints() { + var secrets = "secret"; + var inferenceId1 = "1"; + var inferenceId2 = "2"; + + var model1 = new TestModel( + inferenceId1, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + inferenceId2, + TaskType.TEXT_EMBEDDING, + "foo", + new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(2)); + assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId1, RestStatus.CREATED, null))); + assertThat(response.get(1), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId2, RestStatus.CREATED, null))); + + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model2, secrets); + } + + private static void assertModelAndMinimalSettingsWithSecrets(ModelRegistry registry, Model model, String secrets) { + assertMinimalServiceSettings(registry, model); + + var listener = new PlainActionFuture(); + registry.getModelWithSecrets(model.getInferenceEntityId(), listener); + + var storedModel = listener.actionGet(TIMEOUT); + assertModel(storedModel, model, secrets); + } + + private static void assertModel(UnparsedModel model, Model expected, String secrets) { + assertThat(model.inferenceEntityId(), Matchers.is(expected.getInferenceEntityId())); + assertThat(model.service(), Matchers.is(expected.getConfigurations().getService())); + assertThat(model.taskType(), Matchers.is(expected.getConfigurations().getTaskType())); + assertThat(model.secrets().keySet(), hasSize(1)); + assertThat(model.secrets().get("secret_settings"), instanceOf(Map.class)); + @SuppressWarnings("unchecked") + var secretSettings = (Map) model.secrets().get("secret_settings"); + assertThat(secretSettings.get("api_key"), Matchers.is(secrets)); + } + + public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflictExists() { + var secrets = "secret"; + + var inferenceId = "1"; + + var model1 = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + // using the same inference id as model1 to cause a failure + inferenceId, + TaskType.TEXT_EMBEDDING, + "foo", + new TestModel.TestServiceSettings("model", 123, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(2)); + assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + assertThat(response.get(1).inferenceId(), Matchers.is(model2.getInferenceEntityId())); + assertThat(response.get(1).status(), Matchers.is(RestStatus.CONFLICT)); + assertTrue(response.get(1).failed()); + + var cause = response.get(1).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); + assertIndicesContainExpectedDocsCount(model1, 2); + } + + public void testStoreModels_StoresOneModel_RemovesSecondDuplicateModelFromList_DoesNotThrowException() { + var secrets = "secret"; + var inferenceId = "1"; + var temperature = randomInt(3); + + var model1 = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(temperature), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(temperature), + new TestModel.TestSecretSettings(secrets) + ); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model1, model1, model2), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(1)); + assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); + assertIndicesContainExpectedDocsCount(model1, 2); + } + + public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyExists() { + var secrets = "secret"; + + var model = new TestModel( + "1", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + storeCorruptedModel(model, false); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(1)); + assertThat(response.get(0).inferenceId(), Matchers.is(model.getInferenceEntityId())); + assertThat(response.get(0).status(), Matchers.is(RestStatus.CONFLICT)); + assertTrue(response.get(0).failed()); + + var cause = response.get(0).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + // Since there was a partial write, both documents should be removed + assertIndicesContainExpectedDocsCount(model, 0); + } + + public void testStoreModels_OnFailure_RemovesPartialWritesOfInferenceEndpoint() { + var secrets = "secret"; + + var inferenceId1 = "1"; + var inferenceId2 = "2"; + var inferenceId3 = "3"; + + var model1 = new TestModel( + inferenceId1, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model2 = new TestModel( + inferenceId2, + TaskType.CHAT_COMPLETION, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + var model3 = new TestModel( + inferenceId3, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(secrets) + ); + + storeCorruptedModel(model1, false); + storeCorruptedModel(model2, true); + + PlainActionFuture> storeListener = new PlainActionFuture<>(); + modelRegistry.storeModels(List.of(model1, model2, model3), storeListener, TimeValue.THIRTY_SECONDS); + + var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(response.size(), Matchers.is(3)); + assertThat(response.get(0).inferenceId(), Matchers.is(model1.getInferenceEntityId())); + assertThat(response.get(0).status(), Matchers.is(RestStatus.CONFLICT)); + assertTrue(response.get(0).failed()); + + var cause = response.get(0).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists")); + + // Since we did a partial write of model1's secrets, both documents should be removed + assertIndicesContainExpectedDocsCount(model1, 0); + + assertThat(response.get(1).inferenceId(), Matchers.is(model2.getInferenceEntityId())); + assertThat(response.get(1).status(), Matchers.is(RestStatus.CONFLICT)); + assertTrue(response.get(1).failed()); + + cause = response.get(1).failureCause(); + assertNotNull(cause); + assertThat(cause, instanceOf(VersionConflictEngineException.class)); + assertThat(cause.getMessage(), containsString("[model_2]: version conflict, document already exists")); + + // Since we did a partial write of model2's configurations, both documents should be removed + assertIndicesContainExpectedDocsCount(model2, 0); + + // model3 should be stored successfully + assertModelAndMinimalSettingsWithSecrets(modelRegistry, model3, secrets); + assertIndicesContainExpectedDocsCount(model3, 2); + } + + public void testGetModelNoSecrets() { + var inferenceId = "1"; + + assertStoreModel( + modelRegistry, + new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(randomAlphaOfLength(4)) + ) + ); + + var getListener = new PlainActionFuture(); + modelRegistry.getModel(inferenceId, getListener); + + var modelConfig = getListener.actionGet(TIMEOUT); + assertEquals(inferenceId, modelConfig.inferenceEntityId()); + assertEquals("foo", modelConfig.service()); + assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); + assertNotNull(modelConfig.settings().keySet()); + assertThat(modelConfig.secrets().keySet(), empty()); + } + + public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { + var model = TestModel.createRandomInstance(); + assertStoreModel(modelRegistry, model); + } + + public void testStoreModel_ThrowsException_WhenFailureIsAVersionConflict() { + var model = TestModel.createRandomInstance(); + assertStoreModel(modelRegistry, model); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> assertStoreModel(modelRegistry, model)); + assertThat(exception.status(), Matchers.is(RestStatus.BAD_REQUEST)); + assertThat( + exception.getMessage(), + Matchers.is(format("Inference endpoint [%s] already exists", model.getConfigurations().getInferenceEntityId())) + ); + } + + public void testStoreModel_DeletesIndexDocs_WhenInferenceIndexDocumentAlreadyExists() { + storeCorruptedModelThenStoreModel(false); + } + + public void testStoreModel_DeletesIndexDocs_WhenInferenceSecretsIndexDocumentAlreadyExists() { + storeCorruptedModelThenStoreModel(true); + } + + public void testStoreModel_DoesNotDeleteIndexDocs_WhenModelAlreadyExists() { + var model = new TestModel( + "model-id", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings("secret") + ); + + PlainActionFuture firstStoreListener = new PlainActionFuture<>(); + modelRegistry.storeModel(model, firstStoreListener, TimeValue.THIRTY_SECONDS); + firstStoreListener.actionGet(TimeValue.THIRTY_SECONDS); + + assertIndicesContainExpectedDocsCount(model, 2); + + PlainActionFuture secondStoreListener = new PlainActionFuture<>(); + modelRegistry.storeModel(model, secondStoreListener, TimeValue.THIRTY_SECONDS); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS)); + assertThat(exception.getMessage(), containsString("already exists")); + assertThat(exception.status(), Matchers.is(RestStatus.BAD_REQUEST)); + assertIndicesContainExpectedDocsCount(model, 2); + } + + private void storeCorruptedModelThenStoreModel(boolean storeSecrets) { + var model = new TestModel( + "corrupted-model-id", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings("secret") + ); + + storeCorruptedModel(model, storeSecrets); + + assertIndicesContainExpectedDocsCount(model, 1); + + PlainActionFuture storeListener = new PlainActionFuture<>(); + modelRegistry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS)); + assertThat(exception.getMessage(), containsString("already exists")); + assertThat(exception.status(), Matchers.is(RestStatus.BAD_REQUEST)); + + assertIndicesContainExpectedDocsCount(model, 0); + } + + private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) { + SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId())))) + .setSize(2) + .setTrackTotalHits(false) + .request(); + SearchResponse searchResponse = client().search(modelSearch).actionGet(TimeValue.THIRTY_SECONDS); + try { + assertThat(searchResponse.getHits().getHits(), Matchers.arrayWithSize(numberOfDocs)); + } finally { + searchResponse.decRef(); + } + } + + private void storeCorruptedModel(Model model, boolean storeSecrets) { + var listener = new PlainActionFuture(); + + client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add( + ModelRegistry.createIndexRequestBuilder( + model.getInferenceEntityId(), + storeSecrets ? InferenceSecretsIndex.INDEX_NAME : InferenceIndex.INDEX_NAME, + storeSecrets ? model.getSecrets() : model.getConfigurations(), + false, + client() + ) + ) + .execute(listener); + + var bulkResponse = listener.actionGet(TIMEOUT); + if (bulkResponse.hasFailures()) { + fail("Failed to store model: " + bulkResponse.buildFailureMessage()); + } + } + private void assertInferenceIndexExists() { var indexResponse = client().admin().indices().prepareGetIndex(TEST_REQUEST_TIMEOUT).addIndices(".inference").get(); assertNotNull(indexResponse.getSettings()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index a042839b23d4e..e556b1db9ecd8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -17,10 +17,8 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; @@ -46,6 +44,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.gateway.GatewayService; @@ -662,163 +661,286 @@ public void storeModel(Model model, ActionListener listener, TimeValue } private void storeModel(Model model, boolean updateClusterState, ActionListener listener, TimeValue timeout) { - ActionListener bulkResponseActionListener = getStoreIndexListener(model, updateClusterState, listener, timeout); - String inferenceEntityId = model.getConfigurations().getInferenceEntityId(); - var configRequestBuilder = createIndexRequestBuilder( - inferenceEntityId, - InferenceIndex.INDEX_NAME, - model.getConfigurations(), - false, - client - ); - var secretsRequestBuilder = createIndexRequestBuilder( - inferenceEntityId, - InferenceSecretsIndex.INDEX_NAME, - model.getSecrets(), - false, - client - ); + storeModels(List.of(model), updateClusterState, listener.delegateFailureAndWrap((delegate, responses) -> { + var firstFailureResponse = responses.stream().filter(ModelStoreResponse::failed).findFirst(); + if (firstFailureResponse.isPresent() == false) { + delegate.onResponse(Boolean.TRUE); + return; + } + + var failureItem = firstFailureResponse.get(); + if (ExceptionsHelper.unwrapCause(failureItem.failureCause()) instanceof VersionConflictEngineException) { + delegate.onFailure( + new ElasticsearchStatusException( + "Inference endpoint [{}] already exists", + RestStatus.BAD_REQUEST, + failureItem.failureCause(), + failureItem.inferenceId() + ) + ); + return; + } + + delegate.onFailure( + new ElasticsearchStatusException( + format("Failed to store inference endpoint [%s]", failureItem.inferenceId()), + RestStatus.INTERNAL_SERVER_ERROR, + failureItem.failureCause() + ) + ); + }), timeout); + } + + public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) { + public boolean failed() { + return failureCause != null; + } + } - client.prepareBulk() - .add(configRequestBuilder) - .add(secretsRequestBuilder) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .execute(bulkResponseActionListener); + public void storeModels(List models, ActionListener> listener, TimeValue timeout) { + storeModels(models, true, listener, timeout); } - private ActionListener getStoreIndexListener( - Model model, + private void storeModels( + List models, boolean updateClusterState, - ActionListener listener, + ActionListener> listener, TimeValue timeout ) { - // If there was a partial failure in writing to the indices, we need to clean up - AtomicBoolean partialFailure = new AtomicBoolean(false); - var cleanupListener = listener.delegateResponse((delegate, ex) -> { - if (partialFailure.get()) { - deleteModel(model.getInferenceEntityId(), ActionListener.running(() -> delegate.onFailure(ex))); - } else { - delegate.onFailure(ex); - } + if (models.isEmpty()) { + listener.onResponse(List.of()); + return; + } + + var modelsWithoutDuplicates = models.stream().distinct().toList(); + + var bulkRequestBuilder = client.prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (var model : modelsWithoutDuplicates) { + bulkRequestBuilder.add( + createIndexRequestBuilder(model.getInferenceEntityId(), InferenceIndex.INDEX_NAME, model.getConfigurations(), false, client) + ); + + bulkRequestBuilder.add( + createIndexRequestBuilder(model.getInferenceEntityId(), InferenceSecretsIndex.INDEX_NAME, model.getSecrets(), false, client) + ); + } + + bulkRequestBuilder.execute(getStoreMultipleModelsListener(modelsWithoutDuplicates, updateClusterState, listener, timeout)); + } + + private ActionListener getStoreMultipleModelsListener( + List models, + boolean updateClusterState, + ActionListener> listener, + TimeValue timeout + ) { + var cleanupListener = listener.>delegateFailureAndWrap((delegate, responses) -> { + var inferenceIdsToBeRemoved = responses.stream() + .filter(r -> r.modifiedIndex() && r.modelStoreResponse().failed()) + .map(r -> r.modelStoreResponse().inferenceId()) + .collect(Collectors.toSet()); + + var storageResponses = responses.stream().map(StoreResponseWithIndexInfo::modelStoreResponse).toList(); + + ActionListener deleteListener = ActionListener.wrap(ignored -> delegate.onResponse(storageResponses), e -> { + logger.atWarn() + .withThrowable(e) + .log( + "Failed to clean up partially stored inference endpoints {}. " + + "The service may be in an inconsistent state. Please try deleting and re-adding the endpoints.", + inferenceIdsToBeRemoved + ); + delegate.onResponse(storageResponses); + }); + + deleteModels(inferenceIdsToBeRemoved, deleteListener); }); + return ActionListener.wrap(bulkItemResponses -> { - var inferenceEntityId = model.getInferenceEntityId(); + var docIdToInferenceId = models.stream() + .collect(Collectors.toMap(m -> Model.documentId(m.getInferenceEntityId()), Model::getInferenceEntityId, (id1, id2) -> { + logger.warn("Encountered duplicate inference ids when storing endpoints: [{}]", id1); + return id1; + })); + var inferenceIdToModel = models.stream() + .collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity(), (id1, id2) -> id1)); if (bulkItemResponses.getItems().length == 0) { - logger.warn( - format("Storing inference endpoint [%s] failed, no items were received from the bulk response", inferenceEntityId) - ); + var inferenceEntityIds = String.join(", ", models.stream().map(Model::getInferenceEntityId).toList()); + logger.warn("Storing inference endpoints [{}] failed, no items were received from the bulk response", inferenceEntityIds); - cleanupListener.onFailure( + listener.onFailure( new ElasticsearchStatusException( - format( - "Failed to store inference endpoint [%s], invalid bulk response received. Try reinitializing the service", - inferenceEntityId - ), - RestStatus.INTERNAL_SERVER_ERROR + "Failed to store inference endpoints [{}], empty bulk response received.", + RestStatus.INTERNAL_SERVER_ERROR, + inferenceEntityIds ) ); return; } - BulkItemResponse.Failure failure = getFirstBulkFailure(bulkItemResponses); - - if (failure == null) { - if (updateClusterState) { - var storeListener = getStoreMetadataListener(inferenceEntityId, cleanupListener); - try { - metadataTaskQueue.submitTask( - "add model [" + inferenceEntityId + "]", - new AddModelMetadataTask( - ProjectId.DEFAULT, - inferenceEntityId, - new MinimalServiceSettings(model), - storeListener - ), - timeout - ); - } catch (Exception exc) { - storeListener.onFailure(exc); - } - } else { - cleanupListener.onResponse(Boolean.TRUE); - } - return; - } - - for (BulkItemResponse aResponse : bulkItemResponses.getItems()) { - logBulkFailure(inferenceEntityId, aResponse); - partialFailure.compareAndSet(false, aResponse.isFailed() == false); - } + var responseInfo = getResponseInfo(bulkItemResponses, docIdToInferenceId, inferenceIdToModel); - if (ExceptionsHelper.unwrapCause(failure.getCause()) instanceof VersionConflictEngineException) { - cleanupListener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId)); - return; + if (updateClusterState) { + updateClusterState( + responseInfo.successfullyStoredModels(), + cleanupListener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.responses())), + timeout + ); + } else { + cleanupListener.onResponse(responseInfo.responses()); } - - cleanupListener.onFailure( - new ElasticsearchStatusException( - format("Failed to store inference endpoint [%s]", inferenceEntityId), - RestStatus.INTERNAL_SERVER_ERROR, - failure.getCause() - ) - ); }, e -> { - String errorMessage = format("Failed to store inference endpoint [%s]", model.getInferenceEntityId()); + String errorMessage = format( + "Failed to store inference endpoints [%s]", + models.stream().map(Model::getInferenceEntityId).collect(Collectors.joining(", ")) + ); logger.warn(errorMessage, e); - cleanupListener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e)); + listener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e)); }); } - private ActionListener getStoreMetadataListener(String inferenceEntityId, ActionListener listener) { - return new ActionListener<>() { - @Override - public void onResponse(AcknowledgedResponse resp) { - listener.onResponse(true); - } + private record StoreResponseWithIndexInfo(ModelStoreResponse modelStoreResponse, boolean modifiedIndex) {} - @Override - public void onFailure(Exception exc) { - logger.warn( - format("Failed to add inference endpoint [%s] minimal service settings to cluster state", inferenceEntityId), - exc - ); - deleteModel(inferenceEntityId, ActionListener.running(() -> { - listener.onFailure( - new ElasticsearchStatusException( - format( - "Failed to add the inference endpoint [%s]. The service may be in an " - + "inconsistent state. Please try deleting and re-adding the endpoint.", - inferenceEntityId + private record ResponseInfo(List responses, List successfullyStoredModels) {} + + private static ResponseInfo getResponseInfo( + BulkResponse bulkResponse, + Map docIdToInferenceId, + Map inferenceIdToModel + ) { + var responses = new ArrayList(); + var successfullyStoredModels = new ArrayList(); + + var bulkItems = bulkResponse.getItems(); + for (int i = 0; i < bulkItems.length; i += 2) { + var configurationItem = bulkItems[i]; + var configStoreResponse = createModelStoreResponse(configurationItem, docIdToInferenceId); + var modelFromBulkItem = getModelFromMap(docIdToInferenceId.get(configurationItem.getId()), inferenceIdToModel); + + if (i + 1 >= bulkResponse.getItems().length) { + logger.error("Expected an even number of bulk response items, got [{}]", bulkResponse.getItems().length); + + if (configStoreResponse.failed()) { + responses.add(new StoreResponseWithIndexInfo(configStoreResponse, false)); + } else { + // if we didn't get the last item for some reason, assume it is a failure + responses.add( + new StoreResponseWithIndexInfo( + new ModelStoreResponse( + configStoreResponse.inferenceId(), + RestStatus.INTERNAL_SERVER_ERROR, + new IllegalStateException("Failed to receive part of bulk response") ), - RestStatus.INTERNAL_SERVER_ERROR, - exc + true ) ); - })); + } + return new ResponseInfo(responses, successfullyStoredModels); } - }; + + var secretsItem = bulkItems[i + 1]; + var secretsStoreResponse = createModelStoreResponse(secretsItem, docIdToInferenceId); + + assert secretsStoreResponse.inferenceId().equals(configStoreResponse.inferenceId()) + : "Mismatched inference ids in bulk response items, configuration id [" + + configStoreResponse.inferenceId() + + "] secrets id [" + + secretsStoreResponse.inferenceId() + + "]"; + + if (configStoreResponse.failed()) { + responses.add(new StoreResponseWithIndexInfo(configStoreResponse, secretsStoreResponse.failed() == false)); + } else if (secretsStoreResponse.failed()) { + // if we got here that means the configuration store bulk item succeeded, so we did modify an index + responses.add(new StoreResponseWithIndexInfo(secretsStoreResponse, true)); + } else { + responses.add(new StoreResponseWithIndexInfo(configStoreResponse, true)); + if (modelFromBulkItem != null) { + successfullyStoredModels.add(modelFromBulkItem); + } + } + } + + return new ResponseInfo(responses, successfullyStoredModels); } - private static void logBulkFailure(String inferenceEntityId, BulkItemResponse item) { - if (item.isFailed()) { + private static ModelStoreResponse createModelStoreResponse(BulkItemResponse item, Map docIdToInferenceId) { + var failure = item.getFailure(); + + String inferenceIdOrUnknown = "unknown"; + var inferenceIdMaybeNull = docIdToInferenceId.get(item.getId()); + if (inferenceIdMaybeNull == null) { + logger.warn("Failed to find inference id for document id [{}]", item.getId()); + } else { + inferenceIdOrUnknown = inferenceIdMaybeNull; + } + + if (item.isFailed() && failure != null) { logger.warn( - format("Failed to store inference endpoint [%s] index: [%s]", inferenceEntityId, item.getIndex()), - item.getFailure().getCause() + format( + "Failed to store document id: [%s] inference id: [%s] index: [%s] bulk failure message [%s]", + item.getId(), + inferenceIdOrUnknown, + item.getIndex(), + item.getFailureMessage() + ) ); + + return new ModelStoreResponse(inferenceIdOrUnknown, item.status(), failure.getCause()); + } else { + return new ModelStoreResponse(inferenceIdOrUnknown, item.status(), null); } } - private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkResponse) { - for (BulkItemResponse item : bulkResponse.getItems()) { - if (item.isFailed()) { - return item.getFailure(); - } + private static Model getModelFromMap(@Nullable String inferenceId, Map inferenceIdToModel) { + if (inferenceId != null) { + return inferenceIdToModel.get(inferenceId); } return null; } + private void updateClusterState(List models, ActionListener listener, TimeValue timeout) { + var inferenceIdsSet = models.stream().map(Model::getInferenceEntityId).collect(Collectors.toSet()); + var storeListener = listener.delegateResponse((delegate, exc) -> { + logger.warn(format("Failed to add minimal service settings to cluster state for inference endpoints %s", inferenceIdsSet), exc); + deleteModels( + inferenceIdsSet, + ActionListener.running( + () -> delegate.onFailure( + new ElasticsearchStatusException( + format( + "Failed to add the inference endpoints %s. The service may be in an " + + "inconsistent state. Please try deleting and re-adding the endpoints.", + inferenceIdsSet + ), + RestStatus.INTERNAL_SERVER_ERROR, + exc + ) + ) + ) + ); + }); + + try { + metadataTaskQueue.submitTask( + format("add model metadata for %s", inferenceIdsSet), + new AddModelMetadataTask( + ProjectId.DEFAULT, + models.stream() + .map(model -> new ModelAndSettings(model.getInferenceEntityId(), new MinimalServiceSettings(model))) + .toList(), + storeListener + ), + timeout + ); + } catch (Exception exc) { + storeListener.onFailure(exc); + } + } + public synchronized void removeDefaultConfigs(Set inferenceEntityIds, ActionListener listener) { if (inferenceEntityIds.isEmpty()) { listener.onResponse(true); @@ -839,6 +961,11 @@ public void deleteModels(Set inferenceEntityIds, ActionListener } private void deleteModels(Set inferenceEntityIds, boolean updateClusterState, ActionListener listener) { + if (inferenceEntityIds.isEmpty()) { + listener.onResponse(true); + return; + } + var lockedInferenceIds = new HashSet<>(inferenceEntityIds); lockedInferenceIds.retainAll(preventDeletionLock); @@ -937,22 +1064,8 @@ private static DeleteByQueryRequest createDeleteRequest(Set inferenceEnt return request; } - private static IndexRequest createIndexRequest(String docId, String indexName, ToXContentObject body, boolean allowOverwriting) { - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - var request = new IndexRequest(indexName); - XContentBuilder source = body.toXContent( - builder, - new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString())) - ); - var operation = allowOverwriting ? DocWriteRequest.OpType.INDEX : DocWriteRequest.OpType.CREATE; - - return request.opType(operation).id(docId).source(source); - } catch (IOException ex) { - throw new ElasticsearchException(format("Unexpected serialization exception for index [%s] doc [%s]", indexName, docId), ex); - } - } - - static IndexRequestBuilder createIndexRequestBuilder( + // public for testing + public static IndexRequestBuilder createIndexRequestBuilder( String inferenceId, String indexName, ToXContentObject body, @@ -1124,24 +1237,19 @@ ModelRegistryMetadata executeTask(ModelRegistryMetadata current) { } } + public record ModelAndSettings(String inferenceEntityId, MinimalServiceSettings settings) {} + private static class AddModelMetadataTask extends MetadataTask { - private final String inferenceEntityId; - private final MinimalServiceSettings settings; + private final List models = new ArrayList<>(); - AddModelMetadataTask( - ProjectId projectId, - String inferenceEntityId, - MinimalServiceSettings settings, - ActionListener listener - ) { + AddModelMetadataTask(ProjectId projectId, List models, ActionListener listener) { super(projectId, listener); - this.inferenceEntityId = inferenceEntityId; - this.settings = settings; + this.models.addAll(models); } @Override ModelRegistryMetadata executeTask(ModelRegistryMetadata current) { - return current.withAddedModel(inferenceEntityId, settings); + return current.withAddedModels(models); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java index 7f67b876bcad5..4bf23103af5a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java @@ -93,17 +93,31 @@ public static ModelRegistryMetadata fromState(ProjectMetadata projectMetadata) { ); public ModelRegistryMetadata withAddedModel(String inferenceEntityId, MinimalServiceSettings settings) { - final var existing = modelMap.get(inferenceEntityId); - if (existing != null && settings.equals(existing)) { + return withAddedModels(List.of(new ModelRegistry.ModelAndSettings(inferenceEntityId, settings))); + } + + public ModelRegistryMetadata withAddedModels(List models) { + var modifiedMap = false; + ImmutableOpenMap.Builder settingsBuilder = ImmutableOpenMap.builder(modelMap); + + for (var model : models) { + if (model.settings().equals(modelMap.get(model.inferenceEntityId())) == false) { + modifiedMap = true; + + settingsBuilder.fPut(model.inferenceEntityId(), model.settings()); + } + } + + if (modifiedMap == false) { return this; } - var settingsBuilder = ImmutableOpenMap.builder(modelMap); - settingsBuilder.fPut(inferenceEntityId, settings); + if (isUpgraded) { return new ModelRegistryMetadata(settingsBuilder.build()); } + var newTombstone = new HashSet<>(tombstones); - newTombstone.remove(inferenceEntityId); + models.forEach(existing -> newTombstone.remove(existing.inferenceEntityId())); return new ModelRegistryMetadata(settingsBuilder.build(), newTombstone); } @@ -260,7 +274,9 @@ public boolean equals(Object obj) { return false; } ModelRegistryMetadata other = (ModelRegistryMetadata) obj; - return Objects.equals(this.modelMap, other.modelMap) && isUpgraded == other.isUpgraded; + return Objects.equals(this.modelMap, other.modelMap) + && isUpgraded == other.isUpgraded + && Objects.equals(this.tombstones, other.tombstones); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java index 19af5ad61b988..9af21386e93d3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java @@ -19,11 +19,15 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; public class ModelRegistryMetadataTests extends AbstractChunkedSerializingTestCase { public static ModelRegistryMetadata randomInstance() { @@ -60,7 +64,33 @@ protected ModelRegistryMetadata createTestInstance() { @Override protected ModelRegistryMetadata mutateInstance(ModelRegistryMetadata instance) { - return randomValueOtherThan(instance, this::createTestInstance); + int choice = randomIntBetween(0, 2); + switch (choice) { + case 0: // Mutate modelMap + var models = new HashMap<>(instance.getModelMap()); + models.put(randomAlphaOfLength(10), MinimalServiceSettingsTests.randomInstance()); + if (instance.isUpgraded()) { + return new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + } else { + return new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), new HashSet<>(instance.getTombstones())); + } + case 1: // Mutate tombstones + if (instance.getTombstones() == null) { + return new ModelRegistryMetadata(instance.getModelMap(), Set.of(randomAlphaOfLength(10))); + } else { + var tombstones = new HashSet<>(instance.getTombstones()); + tombstones.add(randomAlphaOfLength(10)); + return new ModelRegistryMetadata(instance.getModelMap(), tombstones); + } + case 2: // Mutate isUpgraded + if (instance.isUpgraded()) { + return new ModelRegistryMetadata(instance.getModelMap(), new HashSet<>()); + } else { + return new ModelRegistryMetadata(instance.getModelMap()); + } + default: + throw new IllegalStateException("Unexpected value: " + choice); + } } @Override @@ -106,4 +136,182 @@ public void testAlreadyUpgraded() { var exc = expectThrows(IllegalArgumentException.class, () -> metadata.withUpgradedModels(indexMetadata.getModelMap())); assertThat(exc.getMessage(), containsString("upgraded")); } + + public void testWithAddedModel_ReturnsSameMetadataInstance() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var newMetadata = metadata.withAddedModel(inferenceId, settings); + assertThat(newMetadata, sameInstance(metadata)); + } + + public void testWithAddedModel_ReturnsNewMetadataInstance_ForNewInferenceId() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var newInferenceId = "new_id"; + var newSettings = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModel(newInferenceId, newSettings); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(newMetadata, not(is(metadata))); + assertThat( + newMetadata, + is(new ModelRegistryMetadata(ImmutableOpenMap.builder(Map.of(inferenceId, settings, newInferenceId, newSettings)).build())) + ); + } + + public void testWithAddedModel_ReturnsNewMetadataInstance_ForNewInferenceId_WithTombstoneRemoved() { + var inferenceId = "id"; + var newInferenceId = "new_id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), Set.of(newInferenceId)); + + var newSettings = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModel(newInferenceId, newSettings); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), Set.of(newInferenceId)))); + assertThat(newMetadata, not(is(metadata))); + assertThat( + newMetadata, + is( + new ModelRegistryMetadata( + ImmutableOpenMap.builder(Map.of(inferenceId, settings, newInferenceId, newSettings)).build(), + new HashSet<>() + ) + ) + ); + } + + public void testWithAddedModels_ReturnsSameMetadataInstance() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var newMetadata = metadata.withAddedModels( + List.of(new ModelRegistry.ModelAndSettings(inferenceId, settings), new ModelRegistry.ModelAndSettings(inferenceId, settings)) + ); + assertThat(newMetadata, sameInstance(metadata)); + } + + public void testWithAddedModels_ReturnsSameMetadataInstance_MultipleEntriesInMap() { + var inferenceId = "id"; + var inferenceId2 = "id2"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings, inferenceId2, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var newMetadata = metadata.withAddedModels( + List.of( + new ModelRegistry.ModelAndSettings(inferenceId, settings), + new ModelRegistry.ModelAndSettings(inferenceId, settings), + new ModelRegistry.ModelAndSettings(inferenceId2, settings) + ) + ); + assertThat(newMetadata, sameInstance(metadata)); + } + + public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var inferenceId2 = "new_id"; + var settings2 = MinimalServiceSettingsTests.randomInstance(); + var inferenceId3 = "new_id2"; + var settings3 = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModels( + List.of( + new ModelRegistry.ModelAndSettings(inferenceId2, settings2), + // This should be ignored since it's a duplicate + new ModelRegistry.ModelAndSettings(inferenceId2, settings2), + new ModelRegistry.ModelAndSettings(inferenceId3, settings3) + ) + ); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(newMetadata, not(is(metadata))); + assertThat( + newMetadata, + is( + new ModelRegistryMetadata( + ImmutableOpenMap.builder(Map.of(inferenceId, settings, inferenceId2, settings2, inferenceId3, settings3)).build() + ) + ) + ); + } + + public void testWithAddedModels_ReturnsNewMetadataInstance_UsesOverridingSettings() { + var inferenceId = "id"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var inferenceId2 = "new_id"; + var settings2 = MinimalServiceSettingsTests.randomInstance(); + var settings3 = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModels( + List.of( + new ModelRegistry.ModelAndSettings(inferenceId2, settings2), + // This should be ignored since it's a duplicate inference id + new ModelRegistry.ModelAndSettings(inferenceId2, settings2), + // This should replace the existing settings for inferenceId2 + new ModelRegistry.ModelAndSettings(inferenceId2, settings3) + ) + ); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()))); + assertThat(newMetadata, not(is(metadata))); + assertThat( + newMetadata, + is(new ModelRegistryMetadata(ImmutableOpenMap.builder(Map.of(inferenceId, settings, inferenceId2, settings3)).build())) + ); + } + + public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId_WithTombstoneRemoved() { + var inferenceId = "id"; + var newInferenceId = "new_id"; + var newInferenceId2 = "new_id2"; + var settings = MinimalServiceSettingsTests.randomInstance(); + + var models = new HashMap<>(Map.of(inferenceId, settings)); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), Set.of(newInferenceId)); + + var newSettings = MinimalServiceSettingsTests.randomInstance(); + var newMetadata = metadata.withAddedModels( + List.of( + // This will cause the new settings to be used for inferenceId + new ModelRegistry.ModelAndSettings(inferenceId, newSettings), + new ModelRegistry.ModelAndSettings(newInferenceId, newSettings), + new ModelRegistry.ModelAndSettings(newInferenceId2, newSettings) + ) + ); + // ensure metadata hasn't changed + assertThat(metadata, is(new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), Set.of(newInferenceId)))); + assertThat(newMetadata, not(is(metadata))); + assertThat( + newMetadata, + is( + new ModelRegistryMetadata( + ImmutableOpenMap.builder(Map.of(inferenceId, newSettings, newInferenceId, newSettings, newInferenceId2, newSettings)) + .build(), + new HashSet<>() + ) + ) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 2980621e52c6f..44f0dcc1d8962 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -8,29 +8,19 @@ package org.elasticsearch.xpack.inference.registry; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ResourceAlreadyExistsException; -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.MinimalServiceSettingsTests; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.xpack.inference.InferenceIndex; -import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import org.hamcrest.Matchers; @@ -38,17 +28,12 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.core.Strings.format; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -68,85 +53,6 @@ public void createComponents() { registry = node().injector().getInstance(ModelRegistry.class); } - public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() { - var listener = new PlainActionFuture(); - registry.getModelWithSecrets("1", listener); - - ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [1]")); - } - - public void testGetModelWithSecrets() { - assertStoreModel( - registry, - new TestModel( - "1", - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings("secret") - ) - ); - - var listener = new PlainActionFuture(); - registry.getModelWithSecrets("1", listener); - - var modelConfig = listener.actionGet(TIMEOUT); - assertEquals("1", modelConfig.inferenceEntityId()); - assertEquals("foo", modelConfig.service()); - assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); - assertNotNull(modelConfig.settings().keySet()); - assertThat(modelConfig.secrets().keySet(), hasSize(1)); - assertThat(modelConfig.secrets().get("secret_settings"), instanceOf(Map.class)); - @SuppressWarnings("unchecked") - var secretSettings = (Map) modelConfig.secrets().get("secret_settings"); - assertThat(secretSettings.get("api_key"), equalTo("secret")); - } - - public void testGetModelNoSecrets() { - assertStoreModel( - registry, - new TestModel( - "1", - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings(randomAlphaOfLength(4)) - ) - ); - - var getListener = new PlainActionFuture(); - registry.getModel("1", getListener); - - var modelConfig = getListener.actionGet(TIMEOUT); - assertEquals("1", modelConfig.inferenceEntityId()); - assertEquals("foo", modelConfig.service()); - assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); - assertNotNull(modelConfig.settings().keySet()); - assertThat(modelConfig.secrets().keySet(), empty()); - } - - public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { - var model = TestModel.createRandomInstance(); - assertStoreModel(registry, model); - } - - public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVersionConflict() { - var model = TestModel.createRandomInstance(); - assertStoreModel(registry, model); - - ResourceAlreadyExistsException exception = expectThrows( - ResourceAlreadyExistsException.class, - () -> assertStoreModel(registry, model) - ); - assertThat( - exception.getMessage(), - is(format("Inference endpoint [%s] already exists", model.getConfigurations().getInferenceEntityId())) - ); - } - public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() { var listener = new PlainActionFuture(); registry.removeDefaultConfigs(Set.of(), listener); @@ -245,94 +151,13 @@ public void testDuplicateDefaultIds() { ); } - public void testStoreModel_DeletesIndexDocs_WhenInferenceIndexDocumentAlreadyExists() { - storeCorruptedModelThenStoreModel(false); - } - - public void testStoreModel_DeletesIndexDocs_WhenInferenceSecretsIndexDocumentAlreadyExists() { - storeCorruptedModelThenStoreModel(true); - } - - public void testStoreModel_DoesNotDeleteIndexDocs_WhenModelAlreadyExists() { - var model = new TestModel( - "model-id", - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings("secret") - ); - - PlainActionFuture firstStoreListener = new PlainActionFuture<>(); - registry.storeModel(model, firstStoreListener, TimeValue.THIRTY_SECONDS); - firstStoreListener.actionGet(TimeValue.THIRTY_SECONDS); - - assertIndicesContainExpectedDocsCount(model, 2); - - PlainActionFuture secondStoreListener = new PlainActionFuture<>(); - registry.storeModel(model, secondStoreListener, TimeValue.THIRTY_SECONDS); - - expectThrows(ResourceAlreadyExistsException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS)); - - assertIndicesContainExpectedDocsCount(model, 2); - } - - private void storeCorruptedModelThenStoreModel(boolean storeSecrets) { - var model = new TestModel( - "corrupted-model-id", - TaskType.SPARSE_EMBEDDING, - "foo", - new TestModel.TestServiceSettings(null, null, null, null), - new TestModel.TestTaskSettings(randomInt(3)), - new TestModel.TestSecretSettings("secret") - ); - - storeCorruptedModel(model, storeSecrets); - - assertIndicesContainExpectedDocsCount(model, 1); - - PlainActionFuture storeListener = new PlainActionFuture<>(); - registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); - - expectThrows(ResourceAlreadyExistsException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS)); - - assertIndicesContainExpectedDocsCount(model, 0); - } - - private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) { - SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) - .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId())))) - .setSize(2) - .setTrackTotalHits(false) - .request(); - SearchResponse searchResponse = client().search(modelSearch).actionGet(TimeValue.THIRTY_SECONDS); - try { - assertThat(searchResponse.getHits().getHits(), Matchers.arrayWithSize(numberOfDocs)); - } finally { - searchResponse.decRef(); - } - } - - private void storeCorruptedModel(Model model, boolean storeSecrets) { - var listener = new PlainActionFuture(); - - client().prepareBulk() - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add( - ModelRegistry.createIndexRequestBuilder( - model.getInferenceEntityId(), - storeSecrets ? InferenceSecretsIndex.INDEX_NAME : InferenceIndex.INDEX_NAME, - storeSecrets ? model.getSecrets() : model.getConfigurations(), - false, - client() - ) - ) - .execute(listener); + public void testDeleteModels_Succeeds_WhenNoInferenceIdsAreProvided() { + var model = TestModel.createRandomInstance(); + assertStoreModel(registry, model); - var bulkResponse = listener.actionGet(TIMEOUT); - if (bulkResponse.hasFailures()) { - fail("Failed to store model: " + bulkResponse.buildFailureMessage()); - } + var listener = new PlainActionFuture(); + registry.deleteModels(Set.of(), listener); + assertTrue(listener.actionGet(TIMEOUT)); } public static void assertStoreModel(ModelRegistry registry, Model model) { @@ -340,11 +165,15 @@ public static void assertStoreModel(ModelRegistry registry, Model model) { registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); assertTrue(storeListener.actionGet(TimeValue.THIRTY_SECONDS)); + assertMinimalServiceSettings(registry, model); + } + + public static void assertMinimalServiceSettings(ModelRegistry registry, Model model) { var settings = registry.getMinimalServiceSettings(model.getInferenceEntityId()); assertNotNull(settings); - assertThat(settings.taskType(), equalTo(model.getTaskType())); - assertThat(settings.dimensions(), equalTo(model.getServiceSettings().dimensions())); - assertThat(settings.elementType(), equalTo(model.getServiceSettings().elementType())); - assertThat(settings.dimensions(), equalTo(model.getServiceSettings().dimensions())); + assertThat(settings.taskType(), Matchers.equalTo(model.getTaskType())); + assertThat(settings.dimensions(), Matchers.equalTo(model.getServiceSettings().dimensions())); + assertThat(settings.elementType(), Matchers.equalTo(model.getServiceSettings().elementType())); + assertThat(settings.dimensions(), Matchers.equalTo(model.getServiceSettings().dimensions())); } }