From c08371d71a7bb8705c0f18f2474a96f4bd3526da Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 24 Jan 2025 16:14:32 -0500 Subject: [PATCH 1/5] Starting new auth class implementation --- .../inference/BaseMockEISAuthServerTest.java | 67 ++++++++++++++++ .../inference/InferenceBaseRestTest.java | 21 +++-- ...etModelsWithElasticInferenceServiceIT.java | 33 ++++++++ .../inference/InferenceGetServicesIT.java | 54 +------------ .../TransportPutInferenceModelAction.java | 11 +++ .../inference/registry/ModelRegistry.java | 10 +++ .../elastic/ElasticInferenceService.java | 35 +++++++- .../ElasticInferenceServiceAuthorization.java | 80 ++++++++++++++----- ...lasticInferenceServiceCompletionModel.java | 2 +- 9 files changed, 223 insertions(+), 90 deletions(-) create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java new file mode 100644 index 0000000000000..d257ad9632553 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file has been contributed to by a Generative AI + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.rules.RuleChain; +import org.junit.rules.TestRule; + +public class BaseMockEISAuthServerTest extends ESRestTestCase { + + // The reason we're retrying is there's a race condition between the node retrieving the + // authorization response and running the test. Retrieving the authorization should be very fast since + // we're hosting a local mock server but it's possible it could respond slower. So in the even of a test failure + // we'll automatically retry after waiting a second. + @Rule + public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1)); + + private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer + .enabledWithSparseAndChatCompletion(); + + private static final ElasticsearchCluster cluster = ElasticsearchCluster.local() + .distribution(DistributionType.DEFAULT) + .setting("xpack.license.self_generated.type", "trial") + .setting("xpack.security.enabled", "true") + // Adding both settings unless one feature flag is disabled in a particular environment + .setting("xpack.inference.elastic.url", mockEISServer::getUrl) + .setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl) + .plugin("inference-service-test") + .user("x_pack_rest_user", "x-pack-test-password") + .feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED) + .build(); + + // The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating + // it to the cluster as a setting. + @ClassRule + public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @Override + protected Settings restClientSettings() { + String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); + return Settings.builder() + .put(ThreadContext.PREFIX + ".Authorization", token) + .put(CLIENT_SOCKET_TIMEOUT, "120s") // Long timeout for model download + .build(); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 49b2f5b041b9e..5174b5bbb8cb4 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -171,20 +171,20 @@ static String mockDenseServiceModelConfig() { """; } - protected void deleteModel(String modelId) throws IOException { + static void deleteModel(String modelId) throws IOException { var request = new Request("DELETE", "_inference/" + modelId); var response = client().performRequest(request); assertStatusOkOrCreated(response); } - protected Response deleteModel(String modelId, String queryParams) throws IOException { + static Response deleteModel(String modelId, String queryParams) throws IOException { var request = new Request("DELETE", "_inference/" + modelId + "?" + queryParams); var response = client().performRequest(request); assertStatusOkOrCreated(response); return response; } - protected void deleteModel(String modelId, TaskType taskType) throws IOException { + static void deleteModel(String modelId, TaskType taskType) throws IOException { var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, modelId)); var response = client().performRequest(request); assertStatusOkOrCreated(response); @@ -229,12 +229,12 @@ protected void putSemanticText(String endpointId, String searchEndpointId, Strin assertStatusOkOrCreated(response); } - protected Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException { + static Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException { String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId); return putRequest(endpoint, modelConfig); } - protected Map updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException { + static Map updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException { String endpoint = Strings.format("_inference/%s/%s/_update", taskType, inferenceID); return putRequest(endpoint, modelConfig); } @@ -265,12 +265,12 @@ protected void deletePipeline(String pipelineId) throws IOException { /** * Task type should be in modelConfig */ - protected Map putModel(String modelId, String modelConfig) throws IOException { + static Map putModel(String modelId, String modelConfig) throws IOException { String endpoint = Strings.format("_inference/%s", modelId); return putRequest(endpoint, modelConfig); } - Map putRequest(String endpoint, String body) throws IOException { + static Map putRequest(String endpoint, String body) throws IOException { var request = new Request("PUT", endpoint); request.setJsonEntity(body); var response = client().performRequest(request); @@ -318,18 +318,17 @@ protected Map getModel(String modelId) throws IOException { } @SuppressWarnings("unchecked") - protected List> getModels(String modelId, TaskType taskType) throws IOException { + static List> getModels(String modelId, TaskType taskType) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return (List>) getInternalAsMap(endpoint).get("endpoints"); } @SuppressWarnings("unchecked") - protected List> getAllModels() throws IOException { - var endpoint = Strings.format("_inference/_all"); + static List> getAllModels() throws IOException { return (List>) getInternalAsMap("_inference/_all").get("endpoints"); } - private Map getInternalAsMap(String endpoint) throws IOException { + private static Map getInternalAsMap(String endpoint) throws IOException { var request = new Request("GET", endpoint); var response = client().performRequest(request); assertStatusOkOrCreated(response); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java new file mode 100644 index 0000000000000..3c6f85ba3cd71 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file has been contributed to by a Generative AI + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.inference.TaskType; + +import java.io.IOException; + +import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels; +import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels; +import static org.hamcrest.Matchers.hasSize; + +public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest { + + public void testGetDefaultEndpoints() throws IOException { + var allModels = getAllModels(); + int numModels = 4; + assertThat(allModels, hasSize(numModels)); + + var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); + assertThat(chatCompletionModels, hasSize(1)); + for (var model : chatCompletionModels) { + assertEquals("chat_completion", model.get("task_type")); + } + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index b448acd5f4a74..856fdeb6287e9 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -11,20 +11,8 @@ import org.elasticsearch.client.Request; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.cluster.ElasticsearchCluster; -import org.elasticsearch.test.cluster.FeatureFlag; -import org.elasticsearch.test.cluster.local.distribution.DistributionType; -import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; -import org.junit.ClassRule; -import org.junit.Rule; -import org.junit.rules.RuleChain; -import org.junit.rules.TestRule; import java.io.IOException; import java.util.ArrayList; @@ -35,47 +23,7 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated; import static org.hamcrest.Matchers.equalTo; -public class InferenceGetServicesIT extends ESRestTestCase { - - // The reason we're retrying is there's a race condition between the node retrieving the - // authorization response and running the test. Retrieving the authorization should be very fast since - // we're hosting a local mock server but it's possible it could respond slower. So in the even of a test failure - // we'll automatically retry after waiting a second. - @Rule - public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1)); - - private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer - .enabledWithSparseEmbeddingsAndChatCompletion(); - - private static final ElasticsearchCluster cluster = ElasticsearchCluster.local() - .distribution(DistributionType.DEFAULT) - .setting("xpack.license.self_generated.type", "trial") - .setting("xpack.security.enabled", "true") - // Adding both settings unless one feature flag is disabled in a particular environment - .setting("xpack.inference.elastic.url", mockEISServer::getUrl) - // TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL - .setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl) - // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin - .plugin("inference-service-test") - .user("x_pack_rest_user", "x-pack-test-password") - .feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED) - .build(); - - // The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating - // it to the cluster as a setting. - @ClassRule - public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster); - - @Override - protected String getTestRestCluster() { - return cluster.getHttpAddresses(); - } - - @Override - protected Settings restClientSettings() { - String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); - return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); - } +public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @SuppressWarnings("unchecked") public void testGetServicesWithoutTaskType() throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 6168edeca4820..4824f4e44b35a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -108,6 +108,17 @@ protected void masterOperation( return; } + if (modelRegistry.matchesDefaultConfigId(request.getInferenceEntityId())) { + listener.onFailure( + new ElasticsearchStatusException( + "[{}] is a reserved inference ID. Cannot create a new inference endpoint with a reserved ID.", + RestStatus.BAD_REQUEST, + request.getInferenceEntityId() + ) + ); + return; + } + var requestAsMap = requestToMap(request); var resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 4506a05d58054..8dc97dd8f8b2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -112,6 +112,16 @@ public ModelRegistry(Client client) { defaultConfigIds = new ArrayList<>(); } + /** + * Returns true if the provided inference entity id is the same as one of the default + * endpoints ids. + * @param inferenceEntityId the id to search for + * @return true if we find a match and false if not + */ + public boolean matchesDefaultConfigId(String inferenceEntityId) { + return idMatchedDefault(inferenceEntityId, defaultConfigIds).isPresent(); + } + /** * Set the default inference ids provided by the services * @param defaultConfigIds The defaults diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 29f1e7cf70e77..0b6a0a5a9bd29 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,6 +16,8 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -47,6 +49,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorization; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -79,6 +82,8 @@ public class ElasticInferenceService extends SenderService { private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); private static final String SERVICE_NAME = "Elastic"; + private static final String DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1 = ".rainbow-sprinkles-elastic"; + private static final String DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".eis-alpha-1"; /** * The task types that the {@link InferenceAction.Request} can accept. @@ -88,9 +93,11 @@ public class ElasticInferenceService extends SenderService { private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; private Configuration configuration; private final AtomicReference> enabledTaskTypesRef = new AtomicReference<>(EnumSet.noneOf(TaskType.class)); + private final Set enabledModels; private final ModelRegistry modelRegistry; private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; private final CountDownLatch authorizationCompletedLatch = new CountDownLatch(1); + private final Map defaultEndpoints; public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -105,14 +112,31 @@ public ElasticInferenceService( this.authorizationHandler = Objects.requireNonNull(authorizationHandler); configuration = new Configuration(enabledTaskTypesRef.get()); + enabledModels = Set.of(); + this.defaultEndpoints = initDefaultEndpoints(); getAuthorization(); } + private Map initDefaultEndpoints() { + return Map.of( + DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, + new ElasticInferenceServiceCompletionModel( + DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + NAME, + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ) + ); + } + private void getAuthorization() { try { ActionListener listener = ActionListener.wrap(result -> { - setEnabledTaskTypes(result); + setEnabledTaskTypesAndModels(result); authorizationCompletedLatch.countDown(); }, e -> { // we don't need to do anything if there was a failure, everything is disabled by default @@ -126,7 +150,7 @@ private void getAuthorization() { } } - private synchronized void setEnabledTaskTypes(ElasticInferenceServiceAuthorization auth) { + private synchronized void setEnabledTaskTypesAndModels(ElasticInferenceServiceAuthorization auth) { enabledTaskTypesRef.set(filterTaskTypesByAuthorization(auth)); configuration = new Configuration(enabledTaskTypesRef.get()); @@ -135,7 +159,7 @@ private synchronized void setEnabledTaskTypes(ElasticInferenceServiceAuthorizati private static EnumSet filterTaskTypesByAuthorization(ElasticInferenceServiceAuthorization auth) { var implementedTaskTypes = EnumSet.copyOf(IMPLEMENTED_TASK_TYPES); - implementedTaskTypes.retainAll(auth.enabledTaskTypes()); + implementedTaskTypes.retainAll(auth.getEnabledTaskTypes()); return implementedTaskTypes; } @@ -168,6 +192,11 @@ public synchronized List defaultConfigIds() { return List.of(); } + @Override + public void defaultConfigs(ActionListener> defaultsListener) { + // defaultsListener.onResponse(defaultEndpoints); + } + @Override protected void doUnifiedCompletionInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java index eac64021ac85a..0e4805a2a2822 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java @@ -12,16 +12,17 @@ import java.util.EnumSet; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; -/** - * Provides a structure for governing which models (if any) a cluster has access to according to the upstream Elastic Inference Service. - * @param enabledModels a mapping of model ids to a set of {@link TaskType} to indicate which models are available and for which task types - */ -public record ElasticInferenceServiceAuthorization(Map> enabledModels) { +public class ElasticInferenceServiceAuthorization { + + private final Map> taskTypeToModels; + private final EnumSet enabledTaskTypes; + private final Set enabledModels; /** * Converts an authorization response from Elastic Inference Service into the {@link ElasticInferenceServiceAuthorization} format. @@ -30,45 +31,80 @@ public record ElasticInferenceServiceAuthorization(Map * @return a new {@link ElasticInferenceServiceAuthorization} */ public static ElasticInferenceServiceAuthorization of(ElasticInferenceServiceAuthorizationResponseEntity responseEntity) { - var enabledModels = new HashMap>(); + var taskTypeToModelsMap = new HashMap>(); + var enabledTaskTypesSet = EnumSet.noneOf(TaskType.class); + var enabledModelsSet = new HashSet(); for (var model : responseEntity.getAuthorizedModels()) { // if there are no task types we'll ignore the model because it's likely we didn't understand // the task type and don't support it anyway if (model.taskTypes().isEmpty() == false) { - enabledModels.put(model.modelName(), model.taskTypes()); + for (var taskType : model.taskTypes()) { + taskTypeToModelsMap.merge(taskType, Set.of(model.modelName()), (existingModelIds, newModelIds) -> { + existingModelIds.addAll(newModelIds); + return existingModelIds; + }); + enabledTaskTypesSet.add(taskType); + } + enabledModelsSet.add(model.modelName()); } } - return new ElasticInferenceServiceAuthorization(enabledModels); + return new ElasticInferenceServiceAuthorization(taskTypeToModelsMap, enabledModelsSet, enabledTaskTypesSet); } /** * Returns an object indicating that the cluster has no access to Elastic Inference Service. */ public static ElasticInferenceServiceAuthorization newDisabledService() { - return new ElasticInferenceServiceAuthorization(); + return new ElasticInferenceServiceAuthorization(Map.of(), Set.of(), EnumSet.noneOf(TaskType.class)); } - public ElasticInferenceServiceAuthorization { - Objects.requireNonNull(enabledModels); + private ElasticInferenceServiceAuthorization( + Map> taskTypeToModels, + Set enabledModels, + EnumSet enabledTaskTypes + ) { + this.taskTypeToModels = Objects.requireNonNull(taskTypeToModels); + this.enabledModels = Objects.requireNonNull(enabledModels); + this.enabledTaskTypes = Objects.requireNonNull(enabledTaskTypes); + } - for (var taskTypes : enabledModels.values()) { - if (taskTypes.isEmpty()) { - throw new IllegalArgumentException("Authorization task types must not be empty"); - } - } + public boolean isEnabled() { + return enabledModels.isEmpty() == false && taskTypeToModels.isEmpty() == false && enabledTaskTypes.isEmpty() == false; } - private ElasticInferenceServiceAuthorization() { - this(Map.of()); + public Set getEnabledModels() { + return Set.copyOf(enabledModels); } - public boolean isEnabled() { - return enabledModels.isEmpty() == false; + public EnumSet getEnabledTaskTypes() { + return EnumSet.copyOf(enabledTaskTypes); + } + + public Map> getTaskTypeToModels() { + return Map.copyOf(taskTypeToModels); } - public EnumSet enabledTaskTypes() { - return enabledModels.values().stream().flatMap(Set::stream).collect(Collectors.toCollection(() -> EnumSet.noneOf(TaskType.class))); + /** + * Returns a new {@link ElasticInferenceServiceAuthorization} object retaining only the specified task types + * and applicable models that leverage those task types. Any task types not specified in the passed in set will be + * excluded from the returned object. This is essentially an intersection. + * @param taskTypes the task types to retain in the newly created object + * @return a new object containing models and task types limited to the specified set. + */ + public ElasticInferenceServiceAuthorization newLimitedToTaskTypes(EnumSet taskTypes) { + var newTaskTypeToModels = new HashMap>(); + + for (var taskType : taskTypes) { + var models = taskTypeToModels.get(taskType); + if (models != null) { + newTaskTypeToModels.put(taskType, models); + } + } + + Set newEnabledModels = newTaskTypeToModels.values().stream().flatMap(Set::stream).collect(Collectors.toSet()); + + return new ElasticInferenceServiceAuthorization(newTaskTypeToModels, newEnabledModels, taskTypes); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index b26f80efb1930..5125ade21339d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -74,7 +74,7 @@ public ElasticInferenceServiceCompletionModel( } - ElasticInferenceServiceCompletionModel( + public ElasticInferenceServiceCompletionModel( String inferenceEntityId, TaskType taskType, String service, From 665e700c500101f295139fff690b10a7e36e33ba Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 24 Jan 2025 17:30:11 -0500 Subject: [PATCH 2/5] Fixing some tests --- .../elastic/ElasticInferenceService.java | 100 +++++++++++++----- .../ElasticInferenceServiceAuthorization.java | 21 +++- .../elastic/ElasticInferenceServiceTests.java | 45 +++++++- ...renceServiceAuthorizationHandlerTests.java | 13 ++- ...ticInferenceServiceAuthorizationTests.java | 53 ++++++---- 5 files changed, 174 insertions(+), 58 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 0b6a0a5a9bd29..f9249c7198b5a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -53,8 +53,10 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; @@ -82,8 +84,8 @@ public class ElasticInferenceService extends SenderService { private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); private static final String SERVICE_NAME = "Elastic"; - private static final String DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1 = ".rainbow-sprinkles-elastic"; - private static final String DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".eis-alpha-1"; + private static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = "rainbow-sprinkles"; + private static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1); /** * The task types that the {@link InferenceAction.Request} can accept. @@ -92,12 +94,12 @@ public class ElasticInferenceService extends SenderService { private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; private Configuration configuration; - private final AtomicReference> enabledTaskTypesRef = new AtomicReference<>(EnumSet.noneOf(TaskType.class)); - private final Set enabledModels; + private final AtomicReference authRef = new AtomicReference<>(AuthorizedContent.empty()); private final ModelRegistry modelRegistry; private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; private final CountDownLatch authorizationCompletedLatch = new CountDownLatch(1); - private final Map defaultEndpoints; + // model ids to model object + private final Map defaultModels; public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -111,21 +113,20 @@ public ElasticInferenceService( this.modelRegistry = Objects.requireNonNull(modelRegistry); this.authorizationHandler = Objects.requireNonNull(authorizationHandler); - configuration = new Configuration(enabledTaskTypesRef.get()); - enabledModels = Set.of(); - this.defaultEndpoints = initDefaultEndpoints(); + configuration = new Configuration(authRef.get().authorizedTaskTypesAndModels.getEnabledTaskTypes()); + defaultModels = initDefaultEndpoints(elasticInferenceServiceComponents); getAuthorization(); } - private Map initDefaultEndpoints() { + private static Map initDefaultEndpoints(ElasticInferenceServiceComponents elasticInferenceServiceComponents) { return Map.of( - DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, + DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, new ElasticInferenceServiceCompletionModel( - DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, TaskType.CHAT_COMPLETION, NAME, - new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1, null), + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents @@ -133,10 +134,20 @@ private Map initDefaultEndpoints() { ); } + private record AuthorizedContent( + ElasticInferenceServiceAuthorization authorizedTaskTypesAndModels, + List enabledConfigIds, + List enabledModelObjects + ) { + static AuthorizedContent empty() { + return new AuthorizedContent(ElasticInferenceServiceAuthorization.newDisabledService(), List.of(), List.of()); + } + } + private void getAuthorization() { try { ActionListener listener = ActionListener.wrap(result -> { - setEnabledTaskTypesAndModels(result); + setAuthorizedContent(result); authorizationCompletedLatch.countDown(); }, e -> { // we don't need to do anything if there was a failure, everything is disabled by default @@ -150,17 +161,53 @@ private void getAuthorization() { } } - private synchronized void setEnabledTaskTypesAndModels(ElasticInferenceServiceAuthorization auth) { - enabledTaskTypesRef.set(filterTaskTypesByAuthorization(auth)); - configuration = new Configuration(enabledTaskTypesRef.get()); + private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorization auth) { + var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); + + // recalculate which default config ids and models are authorized now + var enabledDefaultConfigIds = getEnabledDefaultConfigIds(auth); + var enabledDefaultModelObjects = getEnabledDefaultModelsObjects(auth); + authRef.set(new AuthorizedContent(authorizedTaskTypesAndModels, enabledDefaultConfigIds, enabledDefaultModelObjects)); + + configuration = new Configuration(authRef.get().authorizedTaskTypesAndModels.getEnabledTaskTypes()); defaultConfigIds().forEach(modelRegistry::addDefaultIds); } - private static EnumSet filterTaskTypesByAuthorization(ElasticInferenceServiceAuthorization auth) { - var implementedTaskTypes = EnumSet.copyOf(IMPLEMENTED_TASK_TYPES); - implementedTaskTypes.retainAll(auth.getEnabledTaskTypes()); - return implementedTaskTypes; + private List getEnabledDefaultConfigIds(ElasticInferenceServiceAuthorization auth) { + var enabledDefaultModelIds = getEnabledDefaultModelIds(auth); + + var enabledConfigIds = new ArrayList(); + for (var id : enabledDefaultModelIds) { + var model = defaultModels.get(id); + if (model != null) { + enabledConfigIds.add(new DefaultConfigId(id, model.getTaskType(), this)); + } + } + + return enabledConfigIds; + } + + private Set getEnabledDefaultModelIds(ElasticInferenceServiceAuthorization auth) { + var enabledModels = auth.getEnabledModels(); + var enabledDefaultModelIds = new HashSet<>(defaultModels.keySet()); + enabledDefaultModelIds.retainAll(enabledModels); + + return enabledDefaultModelIds; + } + + private List getEnabledDefaultModelsObjects(ElasticInferenceServiceAuthorization auth) { + var enabledDefaultModelIds = getEnabledDefaultModelIds(auth); + + var enabledModels = new ArrayList(); + for (var id : enabledDefaultModelIds) { + var model = defaultModels.get(id); + if (model != null) { + enabledModels.add(model); + } + } + + return enabledModels; } // Default for testing @@ -177,7 +224,7 @@ void waitForAuthorizationToComplete(TimeValue waitTime) { @Override public synchronized Set supportedStreamingTasks() { var enabledStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - enabledStreamingTaskTypes.retainAll(enabledTaskTypesRef.get()); + enabledStreamingTaskTypes.retainAll(authRef.get().authorizedTaskTypesAndModels.getEnabledTaskTypes()); if (enabledStreamingTaskTypes.isEmpty() == false) { enabledStreamingTaskTypes.add(TaskType.ANY); @@ -188,13 +235,12 @@ public synchronized Set supportedStreamingTasks() { @Override public synchronized List defaultConfigIds() { - // TODO once we have the enabledTaskTypes figure out which default endpoints we should expose - return List.of(); + return authRef.get().enabledConfigIds; } @Override - public void defaultConfigs(ActionListener> defaultsListener) { - // defaultsListener.onResponse(defaultEndpoints); + public synchronized void defaultConfigs(ActionListener> defaultsListener) { + defaultsListener.onResponse(authRef.get().enabledModelObjects); } @Override @@ -327,12 +373,12 @@ public synchronized InferenceServiceConfiguration getConfiguration() { @Override public synchronized EnumSet supportedTaskTypes() { - return enabledTaskTypesRef.get(); + return authRef.get().authorizedTaskTypesAndModels.getEnabledTaskTypes(); } @Override public synchronized boolean hideFromConfigurationApi() { - return enabledTaskTypesRef.get().isEmpty(); + return authRef.get().authorizedTaskTypesAndModels.isEnabled() == false; } private static ElasticInferenceServiceModel createModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java index 0e4805a2a2822..545ebe0cb12e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java @@ -18,6 +18,9 @@ import java.util.Set; import java.util.stream.Collectors; +/** + * This is a helper class for managing the response from {@link ElasticInferenceServiceAuthorizationHandler}. + */ public class ElasticInferenceServiceAuthorization { private final Map> taskTypeToModels; @@ -82,10 +85,6 @@ public EnumSet getEnabledTaskTypes() { return EnumSet.copyOf(enabledTaskTypes); } - public Map> getTaskTypeToModels() { - return Map.copyOf(taskTypeToModels); - } - /** * Returns a new {@link ElasticInferenceServiceAuthorization} object retaining only the specified task types * and applicable models that leverage those task types. Any task types not specified in the passed in set will be @@ -107,4 +106,18 @@ public ElasticInferenceServiceAuthorization newLimitedToTaskTypes(EnumSet new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.noneOf(TaskType.class))) - ); - } - public void testEnabledTaskTypes_MergesFromSeparateModels() { - assertThat( - new ElasticInferenceServiceAuthorization( - Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING), "model-2", EnumSet.of(TaskType.SPARSE_EMBEDDING)) - ).enabledTaskTypes(), - is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)) + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING)), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.SPARSE_EMBEDDING)) + ) + ) ); + assertThat(auth.getEnabledTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); + assertThat(auth.getEnabledModels(), is(Set.of("model-1", "model-2"))); } public void testEnabledTaskTypes_FromSingleEntry() { - assertThat( - new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))) - .enabledTaskTypes(), - is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)) + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) ); + + assertThat(auth.getEnabledTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); + assertThat(auth.getEnabledModels(), is(Set.of("model-1"))); + } + + public void testNewLimitToTaskTypes() { + fail("TODO"); } } From 0025f1ab54938c3f417ef72a6a0ecb74ce31d0d6 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 27 Jan 2025 15:18:17 -0500 Subject: [PATCH 3/5] Working tests --- .../inference/BaseMockEISAuthServerTest.java | 7 +- ...etModelsWithElasticInferenceServiceIT.java | 20 ++- ...icInferenceServiceAuthorizationServer.java | 4 +- .../elastic/ElasticInferenceService.java | 19 ++- .../ElasticInferenceServiceAuthorization.java | 10 +- .../elastic/ElasticInferenceServiceTests.java | 120 +++++++++++++ ...ticInferenceServiceAuthorizationTests.java | 160 +++++++++++++++++- 7 files changed, 319 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index d257ad9632553..586b5bb1cd5b5 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -32,7 +32,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase { public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1)); private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer - .enabledWithSparseAndChatCompletion(); + .enabledRainbowSprinkles(); private static final ElasticsearchCluster cluster = ElasticsearchCluster.local() .distribution(DistributionType.DEFAULT) @@ -59,9 +59,6 @@ protected String getTestRestCluster() { @Override protected Settings restClientSettings() { String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); - return Settings.builder() - .put(ThreadContext.PREFIX + ".Authorization", token) - .put(CLIENT_SOCKET_TIMEOUT, "120s") // Long timeout for model download - .build(); + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 3c6f85ba3cd71..76483a5f62fec 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import java.io.IOException; @@ -21,13 +22,20 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS public void testGetDefaultEndpoints() throws IOException { var allModels = getAllModels(); - int numModels = 4; - assertThat(allModels, hasSize(numModels)); - var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); - assertThat(chatCompletionModels, hasSize(1)); - for (var model : chatCompletionModels) { - assertEquals("chat_completion", model.get("task_type")); + + if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled() + || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) { + assertThat(allModels, hasSize(4)); + assertThat(chatCompletionModels, hasSize(1)); + + for (var model : chatCompletionModels) { + assertEquals("chat_completion", model.get("task_type")); + } + } else { + assertThat(allModels, hasSize(3)); + assertThat(chatCompletionModels, hasSize(0)); } + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 8960a7e1b0258..afb22cde67a6f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -23,14 +23,14 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule private static final Logger logger = LogManager.getLogger(MockElasticInferenceServiceAuthorizationServer.class); private final MockWebServer webServer = new MockWebServer(); - public static MockElasticInferenceServiceAuthorizationServer enabledWithSparseEmbeddingsAndChatCompletion() { + public static MockElasticInferenceServiceAuthorizationServer enabledRainbowSprinkles() { var server = new MockElasticInferenceServiceAuthorizationServer(); String responseJson = """ { "models": [ { - "model_name": "model-a", + "model_name": "rainbow-sprinkles", "task_types": ["embed/text/sparse", "chat"] } ] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index f9249c7198b5a..c5ed4dd322229 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.services.elastic; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -82,10 +84,11 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; + private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class); private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); private static final String SERVICE_NAME = "Elastic"; - private static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = "rainbow-sprinkles"; - private static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1); + static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; + static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); /** * The task types that the {@link InferenceAction.Request} can accept. @@ -181,7 +184,17 @@ private List getEnabledDefaultConfigIds(ElasticInferenceService for (var id : enabledDefaultModelIds) { var model = defaultModels.get(id); if (model != null) { - enabledConfigIds.add(new DefaultConfigId(id, model.getTaskType(), this)); + if (auth.getEnabledTaskTypes().contains(model.getTaskType()) == false) { + logger.warn( + Strings.format( + "The authorization response included the default model: %s, " + + "but did not authorize the assumed task type of the model: %s. Enabling model.", + id, + model.getTaskType() + ) + ); + } + enabledConfigIds.add(new DefaultConfigId(model.getInferenceEntityId(), model.getTaskType(), this)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java index 545ebe0cb12e1..d9f524da35226 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java @@ -44,8 +44,9 @@ public static ElasticInferenceServiceAuthorization of(ElasticInferenceServiceAut if (model.taskTypes().isEmpty() == false) { for (var taskType : model.taskTypes()) { taskTypeToModelsMap.merge(taskType, Set.of(model.modelName()), (existingModelIds, newModelIds) -> { - existingModelIds.addAll(newModelIds); - return existingModelIds; + var combinedNames = new HashSet<>(existingModelIds); + combinedNames.addAll(newModelIds); + return combinedNames; }); enabledTaskTypesSet.add(taskType); } @@ -94,17 +95,20 @@ public EnumSet getEnabledTaskTypes() { */ public ElasticInferenceServiceAuthorization newLimitedToTaskTypes(EnumSet taskTypes) { var newTaskTypeToModels = new HashMap>(); + var taskTypesThatHaveModels = EnumSet.noneOf(TaskType.class); for (var taskType : taskTypes) { var models = taskTypeToModels.get(taskType); if (models != null) { newTaskTypeToModels.put(taskType, models); + // we only want task types that correspond to actual models to ensure we're only enabling valid task types + taskTypesThatHaveModels.add(taskType); } } Set newEnabledModels = newTaskTypeToModels.values().stream().flatMap(Set::stream).collect(Collectors.toSet()); - return new ElasticInferenceServiceAuthorization(newTaskTypeToModels, newEnabledModels, taskTypes); + return new ElasticInferenceServiceAuthorization(newTaskTypeToModels, newEnabledModels, taskTypesThatHaveModels); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 6e31cc661e11b..5d777e4825598 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -795,6 +796,60 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi service.waitForAuthorizationToComplete(TIMEOUT); assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); assertTrue(service.defaultConfigIds().isEmpty()); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertTrue(listener.actionGet(TIMEOUT).isEmpty()); + } + } + + public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimplementedTaskTypes() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "model-a", + "task_types": ["embed/text/sparse"] + }, + { + "model_name": "model-b", + "task_types": ["embed"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + } + } + + public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "model-a", + "task_types": ["embed/text/sparse"] + }, + { + "model_name": "model-b", + "task_types": ["chat"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); } } @@ -817,6 +872,71 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat service.waitForAuthorizationToComplete(TIMEOUT); assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertTrue(service.defaultConfigIds().isEmpty()); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertTrue(listener.actionGet(TIMEOUT).isEmpty()); + } + } + + public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIncorrect() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["embed/text/sparse"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertThat( + service.defaultConfigIds(), + is(List.of(new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION, service))) + ); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + } + } + + public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat( + service.defaultConfigIds(), + is(List.of(new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION, service))) + ); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java index dbb8cb7f7d247..6ce456ac5af91 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java @@ -70,7 +70,163 @@ public void testEnabledTaskTypes_FromSingleEntry() { assertThat(auth.getEnabledModels(), is(Set.of("model-1"))); } - public void testNewLimitToTaskTypes() { - fail("TODO"); + public void testNewLimitToTaskTypes_SingleModel() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.CHAT_COMPLETION)) + ) + ) + ); + + assertThat( + auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING)), + is( + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ) + ) + ) + ) + ) + ); + } + + public void testNewLimitToTaskTypes_MultipleModels_OnlyTextEmbedding() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.TEXT_EMBEDDING)) + ) + ) + ); + + assertThat( + auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING)), + is( + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-2", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ) + ) + ) + ) + ) + ); + } + + public void testNewLimitToTaskTypes_MultipleModels_MultipleTaskTypes() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-text-sparse", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-sparse", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-chat-completion", + EnumSet.of(TaskType.CHAT_COMPLETION) + ) + ) + ) + ); + + var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)); + assertThat( + a, + is( + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-text-sparse", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-chat-completion", + EnumSet.of(TaskType.CHAT_COMPLETION) + ) + ) + ) + ) + ) + ); + } + + public void testNewLimitToTaskTypes_DuplicateModelNames() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK) + ) + ) + ) + ); + + var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)); + assertThat( + a, + is( + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK) + ) + ) + ) + ) + ) + ); + } + + public void testNewLimitToTaskTypes_ReturnsDisabled_WhenNoOverlapForTaskTypes() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-2", + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING) + ) + ) + ) + ); + + var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.RERANK)); + assertThat(a, is(ElasticInferenceServiceAuthorization.newDisabledService())); } } From 9a637f1b252a22b719479d313b82ec8eb460fe89 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 27 Jan 2025 15:46:06 -0500 Subject: [PATCH 4/5] Refactoring --- .../inference/BaseMockEISAuthServerTest.java | 4 +- ...icInferenceServiceAuthorizationServer.java | 8 ++- .../elastic/ElasticInferenceService.java | 70 +++++++++---------- .../ElasticInferenceServiceAuthorization.java | 34 +++++---- ...renceServiceAuthorizationHandlerTests.java | 24 +++---- ...ticInferenceServiceAuthorizationTests.java | 16 ++--- 6 files changed, 83 insertions(+), 73 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index 586b5bb1cd5b5..230b7ff576296 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -32,7 +32,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase { public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1)); private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer - .enabledRainbowSprinkles(); + .enabledWithRainbowSprinklesAndElser(); private static final ElasticsearchCluster cluster = ElasticsearchCluster.local() .distribution(DistributionType.DEFAULT) @@ -40,7 +40,9 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase { .setting("xpack.security.enabled", "true") // Adding both settings unless one feature flag is disabled in a particular environment .setting("xpack.inference.elastic.url", mockEISServer::getUrl) + // TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL .setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl) + // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin .plugin("inference-service-test") .user("x_pack_rest_user", "x-pack-test-password") .feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index afb22cde67a6f..b01a5dbeca18a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -23,7 +23,7 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule private static final Logger logger = LogManager.getLogger(MockElasticInferenceServiceAuthorizationServer.class); private final MockWebServer webServer = new MockWebServer(); - public static MockElasticInferenceServiceAuthorizationServer enabledRainbowSprinkles() { + public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowSprinklesAndElser() { var server = new MockElasticInferenceServiceAuthorizationServer(); String responseJson = """ @@ -31,7 +31,11 @@ public static MockElasticInferenceServiceAuthorizationServer enabledRainbowSprin "models": [ { "model_name": "rainbow-sprinkles", - "task_types": ["embed/text/sparse", "chat"] + "task_types": ["chat"] + }, + { + "model_name": ".elser_model_2", + "task_types": ["embed/text/sparse"] } ] } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index c5ed4dd322229..498db2ae127ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -116,7 +116,7 @@ public ElasticInferenceService( this.modelRegistry = Objects.requireNonNull(modelRegistry); this.authorizationHandler = Objects.requireNonNull(authorizationHandler); - configuration = new Configuration(authRef.get().authorizedTaskTypesAndModels.getEnabledTaskTypes()); + configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); defaultModels = initDefaultEndpoints(elasticInferenceServiceComponents); getAuthorization(); @@ -138,9 +138,9 @@ private static Map initDefaultEndpoints(ElasticInferenceServiceCo } private record AuthorizedContent( - ElasticInferenceServiceAuthorization authorizedTaskTypesAndModels, - List enabledConfigIds, - List enabledModelObjects + ElasticInferenceServiceAuthorization taskTypesAndModels, + List configIds, + List modelObjects ) { static AuthorizedContent empty() { return new AuthorizedContent(ElasticInferenceServiceAuthorization.newDisabledService(), List.of(), List.of()); @@ -168,23 +168,23 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); // recalculate which default config ids and models are authorized now - var enabledDefaultConfigIds = getEnabledDefaultConfigIds(auth); - var enabledDefaultModelObjects = getEnabledDefaultModelsObjects(auth); - authRef.set(new AuthorizedContent(authorizedTaskTypesAndModels, enabledDefaultConfigIds, enabledDefaultModelObjects)); + var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(auth); + var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(auth); + authRef.set(new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)); - configuration = new Configuration(authRef.get().authorizedTaskTypesAndModels.getEnabledTaskTypes()); + configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); defaultConfigIds().forEach(modelRegistry::addDefaultIds); } - private List getEnabledDefaultConfigIds(ElasticInferenceServiceAuthorization auth) { - var enabledDefaultModelIds = getEnabledDefaultModelIds(auth); + private List getAuthorizedDefaultConfigIds(ElasticInferenceServiceAuthorization auth) { + var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); - var enabledConfigIds = new ArrayList(); - for (var id : enabledDefaultModelIds) { + var authorizedConfigIds = new ArrayList(); + for (var id : authorizedDefaultModelIds) { var model = defaultModels.get(id); if (model != null) { - if (auth.getEnabledTaskTypes().contains(model.getTaskType()) == false) { + if (auth.getAuthorizedTaskTypes().contains(model.getTaskType()) == false) { logger.warn( Strings.format( "The authorization response included the default model: %s, " @@ -194,33 +194,33 @@ private List getEnabledDefaultConfigIds(ElasticInferenceService ) ); } - enabledConfigIds.add(new DefaultConfigId(model.getInferenceEntityId(), model.getTaskType(), this)); + authorizedConfigIds.add(new DefaultConfigId(model.getInferenceEntityId(), model.getTaskType(), this)); } } - return enabledConfigIds; + return authorizedConfigIds; } - private Set getEnabledDefaultModelIds(ElasticInferenceServiceAuthorization auth) { - var enabledModels = auth.getEnabledModels(); - var enabledDefaultModelIds = new HashSet<>(defaultModels.keySet()); - enabledDefaultModelIds.retainAll(enabledModels); + private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) { + var authorizedModels = auth.getAuthorizedModelIds(); + var authorizedDefaultModelIds = new HashSet<>(defaultModels.keySet()); + authorizedDefaultModelIds.retainAll(authorizedModels); - return enabledDefaultModelIds; + return authorizedDefaultModelIds; } - private List getEnabledDefaultModelsObjects(ElasticInferenceServiceAuthorization auth) { - var enabledDefaultModelIds = getEnabledDefaultModelIds(auth); + private List getAuthorizedDefaultModelsObjects(ElasticInferenceServiceAuthorization auth) { + var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); - var enabledModels = new ArrayList(); - for (var id : enabledDefaultModelIds) { + var authorizedModels = new ArrayList(); + for (var id : authorizedDefaultModelIds) { var model = defaultModels.get(id); if (model != null) { - enabledModels.add(model); + authorizedModels.add(model); } } - return enabledModels; + return authorizedModels; } // Default for testing @@ -236,24 +236,24 @@ void waitForAuthorizationToComplete(TimeValue waitTime) { @Override public synchronized Set supportedStreamingTasks() { - var enabledStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - enabledStreamingTaskTypes.retainAll(authRef.get().authorizedTaskTypesAndModels.getEnabledTaskTypes()); + var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); + authorizedStreamingTaskTypes.retainAll(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); - if (enabledStreamingTaskTypes.isEmpty() == false) { - enabledStreamingTaskTypes.add(TaskType.ANY); + if (authorizedStreamingTaskTypes.isEmpty() == false) { + authorizedStreamingTaskTypes.add(TaskType.ANY); } - return enabledStreamingTaskTypes; + return authorizedStreamingTaskTypes; } @Override public synchronized List defaultConfigIds() { - return authRef.get().enabledConfigIds; + return authRef.get().configIds; } @Override public synchronized void defaultConfigs(ActionListener> defaultsListener) { - defaultsListener.onResponse(authRef.get().enabledModelObjects); + defaultsListener.onResponse(authRef.get().modelObjects); } @Override @@ -386,12 +386,12 @@ public synchronized InferenceServiceConfiguration getConfiguration() { @Override public synchronized EnumSet supportedTaskTypes() { - return authRef.get().authorizedTaskTypesAndModels.getEnabledTaskTypes(); + return authRef.get().taskTypesAndModels.getAuthorizedTaskTypes(); } @Override public synchronized boolean hideFromConfigurationApi() { - return authRef.get().authorizedTaskTypesAndModels.isEnabled() == false; + return authRef.get().taskTypesAndModels.isAuthorized() == false; } private static ElasticInferenceServiceModel createModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java index d9f524da35226..76721bb6dcd7b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java @@ -24,8 +24,8 @@ public class ElasticInferenceServiceAuthorization { private final Map> taskTypeToModels; - private final EnumSet enabledTaskTypes; - private final Set enabledModels; + private final EnumSet authorizedTaskTypes; + private final Set authorizedModelIds; /** * Converts an authorization response from Elastic Inference Service into the {@link ElasticInferenceServiceAuthorization} format. @@ -66,24 +66,28 @@ public static ElasticInferenceServiceAuthorization newDisabledService() { private ElasticInferenceServiceAuthorization( Map> taskTypeToModels, - Set enabledModels, - EnumSet enabledTaskTypes + Set authorizedModelIds, + EnumSet authorizedTaskTypes ) { this.taskTypeToModels = Objects.requireNonNull(taskTypeToModels); - this.enabledModels = Objects.requireNonNull(enabledModels); - this.enabledTaskTypes = Objects.requireNonNull(enabledTaskTypes); + this.authorizedModelIds = Objects.requireNonNull(authorizedModelIds); + this.authorizedTaskTypes = Objects.requireNonNull(authorizedTaskTypes); } - public boolean isEnabled() { - return enabledModels.isEmpty() == false && taskTypeToModels.isEmpty() == false && enabledTaskTypes.isEmpty() == false; + /** + * Returns true if at least one task type and model is authorized. + * @return true if this cluster is authorized for at least one model and task type. + */ + public boolean isAuthorized() { + return authorizedModelIds.isEmpty() == false && taskTypeToModels.isEmpty() == false && authorizedTaskTypes.isEmpty() == false; } - public Set getEnabledModels() { - return Set.copyOf(enabledModels); + public Set getAuthorizedModelIds() { + return Set.copyOf(authorizedModelIds); } - public EnumSet getEnabledTaskTypes() { - return EnumSet.copyOf(enabledTaskTypes); + public EnumSet getAuthorizedTaskTypes() { + return EnumSet.copyOf(authorizedTaskTypes); } /** @@ -116,12 +120,12 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; ElasticInferenceServiceAuthorization that = (ElasticInferenceServiceAuthorization) o; return Objects.equals(taskTypeToModels, that.taskTypeToModels) - && Objects.equals(enabledTaskTypes, that.enabledTaskTypes) - && Objects.equals(enabledModels, that.enabledModels); + && Objects.equals(authorizedTaskTypes, that.authorizedTaskTypes) + && Objects.equals(authorizedModelIds, that.authorizedModelIds); } @Override public int hashCode() { - return Objects.hash(taskTypeToModels, enabledTaskTypes, enabledModels); + return Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index 7e5f26e9e02d6..a819bf1b4a513 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -78,9 +78,9 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getEnabledTaskTypes().isEmpty()); - assertTrue(authResponse.getEnabledModels().isEmpty()); - assertFalse(authResponse.isEnabled()); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger).warn(loggerArgsCaptor.capture()); @@ -99,9 +99,9 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getEnabledTaskTypes().isEmpty()); - assertTrue(authResponse.getEnabledModels().isEmpty()); - assertFalse(authResponse.isEnabled()); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger).warn(loggerArgsCaptor.capture()); @@ -134,9 +134,9 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getEnabledTaskTypes().isEmpty()); - assertTrue(authResponse.getEnabledModels().isEmpty()); - assertFalse(authResponse.isEnabled()); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger).warn(loggerArgsCaptor.capture()); @@ -185,9 +185,9 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertThat(authResponse.getEnabledTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - assertThat(authResponse.getEnabledModels(), is(Set.of("model-a"))); - assertTrue(authResponse.isEnabled()); + assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a"))); + assertTrue(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(1)).debug(loggerArgsCaptor.capture()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java index 6ce456ac5af91..559de47232a7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java @@ -28,8 +28,8 @@ public static ElasticInferenceServiceAuthorization createEnabledAuth() { ); } - public void testIsEnabled_ReturnsFalse_WithEmptyMap() { - assertFalse(ElasticInferenceServiceAuthorization.newDisabledService().isEnabled()); + public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { + assertFalse(ElasticInferenceServiceAuthorization.newDisabledService().isAuthorized()); } public void testExcludes_ModelsWithoutTaskTypes() { @@ -37,8 +37,8 @@ public void testExcludes_ModelsWithoutTaskTypes() { List.of(new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.noneOf(TaskType.class))) ); var auth = ElasticInferenceServiceAuthorization.of(response); - assertTrue(auth.getEnabledTaskTypes().isEmpty()); - assertFalse(auth.isEnabled()); + assertTrue(auth.getAuthorizedTaskTypes().isEmpty()); + assertFalse(auth.isAuthorized()); } public void testEnabledTaskTypes_MergesFromSeparateModels() { @@ -50,8 +50,8 @@ public void testEnabledTaskTypes_MergesFromSeparateModels() { ) ) ); - assertThat(auth.getEnabledTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); - assertThat(auth.getEnabledModels(), is(Set.of("model-1", "model-2"))); + assertThat(auth.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); + assertThat(auth.getAuthorizedModelIds(), is(Set.of("model-1", "model-2"))); } public void testEnabledTaskTypes_FromSingleEntry() { @@ -66,8 +66,8 @@ public void testEnabledTaskTypes_FromSingleEntry() { ) ); - assertThat(auth.getEnabledTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); - assertThat(auth.getEnabledModels(), is(Set.of("model-1"))); + assertThat(auth.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); + assertThat(auth.getAuthorizedModelIds(), is(Set.of("model-1"))); } public void testNewLimitToTaskTypes_SingleModel() { From 1bc2555ab043e91f012c34685cd183901d010196 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 28 Jan 2025 08:42:48 -0500 Subject: [PATCH 5/5] Addressing feedback and pull main --- .../inference/MinimalServiceSettings.java | 5 ++ ...icInferenceServiceAuthorizationServer.java | 2 +- .../TransportPutInferenceModelAction.java | 2 +- .../inference/registry/ModelRegistry.java | 4 +- .../elastic/ElasticInferenceService.java | 60 +++++++++++-------- .../registry/ModelRegistryTests.java | 14 +++++ .../elastic/ElasticInferenceServiceTests.java | 13 +++- 7 files changed, 69 insertions(+), 31 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java index be380d74093af..4c81296725809 100644 --- a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java @@ -21,6 +21,7 @@ import java.util.Objects; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; import static org.elasticsearch.inference.TaskType.COMPLETION; import static org.elasticsearch.inference.TaskType.RERANK; import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; @@ -97,6 +98,10 @@ public static MinimalServiceSettings completion() { return new MinimalServiceSettings(COMPLETION, null, null, null); } + public static MinimalServiceSettings chatCompletion() { + return new MinimalServiceSettings(CHAT_COMPLETION, null, null, null); + } + public MinimalServiceSettings(Model model) { this( model.getTaskType(), diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index b01a5dbeca18a..3ea011c1317cc 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -34,7 +34,7 @@ public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowS "task_types": ["chat"] }, { - "model_name": ".elser_model_2", + "model_name": "elser-v2", "task_types": ["embed/text/sparse"] } ] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 4824f4e44b35a..73af12dacfadf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -108,7 +108,7 @@ protected void masterOperation( return; } - if (modelRegistry.matchesDefaultConfigId(request.getInferenceEntityId())) { + if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) { listener.onFailure( new ElasticsearchStatusException( "[{}] is a reserved inference ID. Cannot create a new inference endpoint with a reserved ID.", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 35f0a1d32f27c..a9642a685aec9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -120,8 +120,8 @@ public ModelRegistry(Client client) { * @param inferenceEntityId the id to search for * @return true if we find a match and false if not */ - public boolean matchesDefaultConfigId(String inferenceEntityId) { - return idMatchedDefault(inferenceEntityId, defaultConfigIds).isPresent(); + public boolean containsDefaultConfigId(String inferenceEntityId) { + return defaultConfigIds.containsKey(inferenceEntityId); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 498db2ae127ed..8b6f8cd3d70c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -23,6 +23,7 @@ import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -101,8 +102,9 @@ public class ElasticInferenceService extends SenderService { private final ModelRegistry modelRegistry; private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; private final CountDownLatch authorizationCompletedLatch = new CountDownLatch(1); - // model ids to model object - private final Map defaultModels; + // model ids to model information, used for the default config methods to return the list of models and default + // configs + private final Map defaultModelsConfigs; public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -117,30 +119,37 @@ public ElasticInferenceService( this.authorizationHandler = Objects.requireNonNull(authorizationHandler); configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); - defaultModels = initDefaultEndpoints(elasticInferenceServiceComponents); + defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents); getAuthorization(); } - private static Map initDefaultEndpoints(ElasticInferenceServiceComponents elasticInferenceServiceComponents) { + private static Map initDefaultEndpoints( + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { return Map.of( DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, - new ElasticInferenceServiceCompletionModel( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - TaskType.CHAT_COMPLETION, - NAME, - new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents + new DefaultModelConfig( + new ElasticInferenceServiceCompletionModel( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + NAME, + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ), + MinimalServiceSettings.chatCompletion() ) ); } + private record DefaultModelConfig(Model model, MinimalServiceSettings settings) {} + private record AuthorizedContent( ElasticInferenceServiceAuthorization taskTypesAndModels, List configIds, - List modelObjects + List defaultModelConfigs ) { static AuthorizedContent empty() { return new AuthorizedContent(ElasticInferenceServiceAuthorization.newDisabledService(), List.of(), List.of()); @@ -182,19 +191,19 @@ private List getAuthorizedDefaultConfigIds(ElasticInferenceServ var authorizedConfigIds = new ArrayList(); for (var id : authorizedDefaultModelIds) { - var model = defaultModels.get(id); - if (model != null) { - if (auth.getAuthorizedTaskTypes().contains(model.getTaskType()) == false) { + var modelConfig = defaultModelsConfigs.get(id); + if (modelConfig != null) { + if (auth.getAuthorizedTaskTypes().contains(modelConfig.model.getTaskType()) == false) { logger.warn( Strings.format( "The authorization response included the default model: %s, " + "but did not authorize the assumed task type of the model: %s. Enabling model.", id, - model.getTaskType() + modelConfig.model.getTaskType() ) ); } - authorizedConfigIds.add(new DefaultConfigId(model.getInferenceEntityId(), model.getTaskType(), this)); + authorizedConfigIds.add(new DefaultConfigId(modelConfig.model.getInferenceEntityId(), modelConfig.settings(), this)); } } @@ -203,20 +212,20 @@ private List getAuthorizedDefaultConfigIds(ElasticInferenceServ private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) { var authorizedModels = auth.getAuthorizedModelIds(); - var authorizedDefaultModelIds = new HashSet<>(defaultModels.keySet()); + var authorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); authorizedDefaultModelIds.retainAll(authorizedModels); return authorizedDefaultModelIds; } - private List getAuthorizedDefaultModelsObjects(ElasticInferenceServiceAuthorization auth) { + private List getAuthorizedDefaultModelsObjects(ElasticInferenceServiceAuthorization auth) { var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); - var authorizedModels = new ArrayList(); + var authorizedModels = new ArrayList(); for (var id : authorizedDefaultModelIds) { - var model = defaultModels.get(id); - if (model != null) { - authorizedModels.add(model); + var modelConfig = defaultModelsConfigs.get(id); + if (modelConfig != null) { + authorizedModels.add(modelConfig); } } @@ -253,7 +262,8 @@ public synchronized List defaultConfigIds() { @Override public synchronized void defaultConfigs(ActionListener> defaultsListener) { - defaultsListener.onResponse(authRef.get().modelObjects); + var models = authRef.get().defaultModelConfigs.stream().map(config -> config.model).toList(); + defaultsListener.onResponse(models); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 1e47a9b8d5ab6..162bcc8f09713 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -310,6 +310,20 @@ public void testIdMatchedDefault() { assertFalse(matched.isPresent()); } + public void testContainsDefaultConfigId() { + var client = mockClient(); + var registry = new ModelRegistry(client); + + registry.addDefaultIds( + new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class)) + ); + registry.addDefaultIds( + new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class)) + ); + assertTrue(registry.containsDefaultConfigId("foo")); + assertFalse(registry.containsDefaultConfigId("baz")); + } + public void testTaskTypeMatchedDefaults() { var defaultConfigIds = new ArrayList(); defaultConfigIds.add( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 5d777e4825598..82fc3a8e93424 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -900,7 +901,11 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertThat( service.defaultConfigIds(), - is(List.of(new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION, service))) + is( + List.of( + new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service) + ) + ) ); assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); @@ -930,7 +935,11 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); assertThat( service.defaultConfigIds(), - is(List.of(new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION, service))) + is( + List.of( + new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service) + ) + ) ); assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));