diff --git a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java index be380d74093af..4c81296725809 100644 --- a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java @@ -21,6 +21,7 @@ import java.util.Objects; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; import static org.elasticsearch.inference.TaskType.COMPLETION; import static org.elasticsearch.inference.TaskType.RERANK; import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; @@ -97,6 +98,10 @@ public static MinimalServiceSettings completion() { return new MinimalServiceSettings(COMPLETION, null, null, null); } + public static MinimalServiceSettings chatCompletion() { + return new MinimalServiceSettings(CHAT_COMPLETION, null, null, null); + } + public MinimalServiceSettings(Model model) { this( model.getTaskType(), diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java new file mode 100644 index 0000000000000..230b7ff576296 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file has been contributed to by a Generative AI + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.rules.RuleChain; +import org.junit.rules.TestRule; + +public class BaseMockEISAuthServerTest extends ESRestTestCase { + + // The reason we're retrying is there's a race condition between the node retrieving the + // authorization response and running the test. Retrieving the authorization should be very fast since + // we're hosting a local mock server, but it's possible it could respond more slowly. So in the event of a test failure + // we'll automatically retry after waiting a second.
+ @Rule + public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1)); + + private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer + .enabledWithRainbowSprinklesAndElser(); + + private static final ElasticsearchCluster cluster = ElasticsearchCluster.local() + .distribution(DistributionType.DEFAULT) + .setting("xpack.license.self_generated.type", "trial") + .setting("xpack.security.enabled", "true") + // Adding both settings unless one feature flag is disabled in a particular environment + .setting("xpack.inference.elastic.url", mockEISServer::getUrl) + // TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL + .setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl) + // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin + .plugin("inference-service-test") + .user("x_pack_rest_user", "x-pack-test-password") + .feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED) + .build(); + + // The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating + // it to the cluster as a setting. + @ClassRule + public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @Override + protected Settings restClientSettings() { + String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 49b2f5b041b9e..5174b5bbb8cb4 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -171,20 +171,20 @@ static String mockDenseServiceModelConfig() { """; } - protected void deleteModel(String modelId) throws IOException { + static void deleteModel(String modelId) throws IOException { var request = new Request("DELETE", "_inference/" + modelId); var response = client().performRequest(request); assertStatusOkOrCreated(response); } - protected Response deleteModel(String modelId, String queryParams) throws IOException { + static Response deleteModel(String modelId, String queryParams) throws IOException { var request = new Request("DELETE", "_inference/" + modelId + "?" 
+ queryParams); var response = client().performRequest(request); assertStatusOkOrCreated(response); return response; } - protected void deleteModel(String modelId, TaskType taskType) throws IOException { + static void deleteModel(String modelId, TaskType taskType) throws IOException { var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, modelId)); var response = client().performRequest(request); assertStatusOkOrCreated(response); @@ -229,12 +229,12 @@ protected void putSemanticText(String endpointId, String searchEndpointId, Strin assertStatusOkOrCreated(response); } - protected Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException { + static Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException { String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId); return putRequest(endpoint, modelConfig); } - protected Map updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException { + static Map updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException { String endpoint = Strings.format("_inference/%s/%s/_update", taskType, inferenceID); return putRequest(endpoint, modelConfig); } @@ -265,12 +265,12 @@ protected void deletePipeline(String pipelineId) throws IOException { /** * Task type should be in modelConfig */ - protected Map putModel(String modelId, String modelConfig) throws IOException { + static Map putModel(String modelId, String modelConfig) throws IOException { String endpoint = Strings.format("_inference/%s", modelId); return putRequest(endpoint, modelConfig); } - Map putRequest(String endpoint, String body) throws IOException { + static Map putRequest(String endpoint, String body) throws IOException { var request = new Request("PUT", endpoint); request.setJsonEntity(body); var response = client().performRequest(request); @@ -318,18 +318,17 @@ protected Map getModel(String modelId) throws IOException { } @SuppressWarnings("unchecked") - protected List> getModels(String modelId, TaskType taskType) throws IOException { + static List> getModels(String modelId, TaskType taskType) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return (List>) getInternalAsMap(endpoint).get("endpoints"); } @SuppressWarnings("unchecked") - protected List> getAllModels() throws IOException { - var endpoint = Strings.format("_inference/_all"); + static List> getAllModels() throws IOException { return (List>) getInternalAsMap("_inference/_all").get("endpoints"); } - private Map getInternalAsMap(String endpoint) throws IOException { + private static Map getInternalAsMap(String endpoint) throws IOException { var request = new Request("GET", endpoint); var response = client().performRequest(request); assertStatusOkOrCreated(response); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java new file mode 100644 index 0000000000000..76483a5f62fec --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file has been contributed to by a Generative AI + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; + +import java.io.IOException; + +import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels; +import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels; +import static org.hamcrest.Matchers.hasSize; + +public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest { + + public void testGetDefaultEndpoints() throws IOException { + var allModels = getAllModels(); + var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); + + if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled() + || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) { + assertThat(allModels, hasSize(4)); + assertThat(chatCompletionModels, hasSize(1)); + + for (var model : chatCompletionModels) { + assertEquals("chat_completion", model.get("task_type")); + } + } else { + assertThat(allModels, hasSize(3)); + assertThat(chatCompletionModels, hasSize(0)); + } + + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index b448acd5f4a74..856fdeb6287e9 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -11,20 +11,8 @@ import org.elasticsearch.client.Request; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.cluster.ElasticsearchCluster; -import org.elasticsearch.test.cluster.FeatureFlag; -import org.elasticsearch.test.cluster.local.distribution.DistributionType; -import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; -import org.junit.ClassRule; -import org.junit.Rule; -import org.junit.rules.RuleChain; -import org.junit.rules.TestRule; import java.io.IOException; import java.util.ArrayList; @@ -35,47 +23,7 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated; import static org.hamcrest.Matchers.equalTo; -public class InferenceGetServicesIT extends ESRestTestCase { - - // The reason we're retrying is there's a race condition between the node retrieving the - // authorization response and running the test. Retrieving the authorization should be very fast since - // we're hosting a local mock server but it's possible it could respond slower. So in the even of a test failure - // we'll automatically retry after waiting a second. 
- @Rule - public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1)); - - private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer - .enabledWithSparseEmbeddingsAndChatCompletion(); - - private static final ElasticsearchCluster cluster = ElasticsearchCluster.local() - .distribution(DistributionType.DEFAULT) - .setting("xpack.license.self_generated.type", "trial") - .setting("xpack.security.enabled", "true") - // Adding both settings unless one feature flag is disabled in a particular environment - .setting("xpack.inference.elastic.url", mockEISServer::getUrl) - // TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL - .setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl) - // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin - .plugin("inference-service-test") - .user("x_pack_rest_user", "x-pack-test-password") - .feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED) - .build(); - - // The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating - // it to the cluster as a setting. - @ClassRule - public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster); - - @Override - protected String getTestRestCluster() { - return cluster.getHttpAddresses(); - } - - @Override - protected Settings restClientSettings() { - String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); - return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); - } +public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @SuppressWarnings("unchecked") public void testGetServicesWithoutTaskType() throws IOException { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 8960a7e1b0258..3ea011c1317cc 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -23,15 +23,19 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule private static final Logger logger = LogManager.getLogger(MockElasticInferenceServiceAuthorizationServer.class); private final MockWebServer webServer = new MockWebServer(); - public static MockElasticInferenceServiceAuthorizationServer enabledWithSparseEmbeddingsAndChatCompletion() { + public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowSprinklesAndElser() { var server = new MockElasticInferenceServiceAuthorizationServer(); String responseJson = """ { "models": [ { - "model_name": "model-a", - "task_types": ["embed/text/sparse", "chat"] + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + }, + { + "model_name": "elser-v2", + "task_types": ["embed/text/sparse"] } ] } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 6168edeca4820..73af12dacfadf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -108,6 +108,17 @@ protected void masterOperation( return; } + if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) { + listener.onFailure( + new ElasticsearchStatusException( + "[{}] is a reserved inference ID. Cannot create a new inference endpoint with a reserved ID.", + RestStatus.BAD_REQUEST, + request.getInferenceEntityId() + ) + ); + return; + } + var requestAsMap = requestToMap(request); var resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 1369ebf7dd87b..a9642a685aec9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -114,6 +114,16 @@ public ModelRegistry(Client client) { defaultConfigIds = new HashMap<>(); } + /** + * Returns true if the provided inference entity id is the same as one of the default + * endpoints ids. + * @param inferenceEntityId the id to search for + * @return true if we find a match and false if not + */ + public boolean containsDefaultConfigId(String inferenceEntityId) { + return defaultConfigIds.containsKey(inferenceEntityId); + } + /** * Set the default inference ids provided by the services * @param defaultConfigId The default diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index f96d3cb325b09..8b8723b54d683 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.services.elastic; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -16,9 +18,12 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import 
org.elasticsearch.inference.ModelSecrets; @@ -47,11 +52,14 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorization; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; @@ -77,8 +85,11 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; + private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class); private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); private static final String SERVICE_NAME = "Elastic"; + static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; + static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); /** * The task types that the {@link InferenceAction.Request} can accept. @@ -87,10 +98,13 @@ public class ElasticInferenceService extends SenderService { private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; private Configuration configuration; - private final AtomicReference> enabledTaskTypesRef = new AtomicReference<>(EnumSet.noneOf(TaskType.class)); + private final AtomicReference authRef = new AtomicReference<>(AuthorizedContent.empty()); private final ModelRegistry modelRegistry; private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; private final CountDownLatch authorizationCompletedLatch = new CountDownLatch(1); + // model ids to model information, used for the default config methods to return the list of models and default + // configs + private final Map defaultModelsConfigs; public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -104,15 +118,48 @@ public ElasticInferenceService( this.modelRegistry = Objects.requireNonNull(modelRegistry); this.authorizationHandler = Objects.requireNonNull(authorizationHandler); - configuration = new Configuration(enabledTaskTypesRef.get()); + configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); + defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents); getAuthorization(); } + private static Map initDefaultEndpoints( + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + return Map.of( + DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, + new DefaultModelConfig( + new ElasticInferenceServiceCompletionModel( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + NAME, + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ), + MinimalServiceSettings.chatCompletion() + ) + ); + } + + private record DefaultModelConfig(Model model, 
MinimalServiceSettings settings) {} + + private record AuthorizedContent( + ElasticInferenceServiceAuthorization taskTypesAndModels, + List configIds, + List defaultModelConfigs + ) { + static AuthorizedContent empty() { + return new AuthorizedContent(ElasticInferenceServiceAuthorization.newDisabledService(), List.of(), List.of()); + } + } + private void getAuthorization() { try { ActionListener listener = ActionListener.wrap(result -> { - setEnabledTaskTypes(result); + setAuthorizedContent(result); authorizationCompletedLatch.countDown(); }, e -> { // we don't need to do anything if there was a failure, everything is disabled by default @@ -126,17 +173,63 @@ private void getAuthorization() { } } - private synchronized void setEnabledTaskTypes(ElasticInferenceServiceAuthorization auth) { - enabledTaskTypesRef.set(filterTaskTypesByAuthorization(auth)); - configuration = new Configuration(enabledTaskTypesRef.get()); + private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorization auth) { + var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); + + // recalculate which default config ids and models are authorized now + var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(auth); + var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(auth); + authRef.set(new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)); + + configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); defaultConfigIds().forEach(modelRegistry::addDefaultIds); } - private static EnumSet filterTaskTypesByAuthorization(ElasticInferenceServiceAuthorization auth) { - var implementedTaskTypes = EnumSet.copyOf(IMPLEMENTED_TASK_TYPES); - implementedTaskTypes.retainAll(auth.enabledTaskTypes()); - return implementedTaskTypes; + private List getAuthorizedDefaultConfigIds(ElasticInferenceServiceAuthorization auth) { + var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); + + var authorizedConfigIds = new ArrayList(); + for (var id : authorizedDefaultModelIds) { + var modelConfig = defaultModelsConfigs.get(id); + if (modelConfig != null) { + if (auth.getAuthorizedTaskTypes().contains(modelConfig.model.getTaskType()) == false) { + logger.warn( + Strings.format( + "The authorization response included the default model: %s, " + + "but did not authorize the assumed task type of the model: %s. 
Enabling model.", + id, + modelConfig.model.getTaskType() + ) + ); + } + authorizedConfigIds.add(new DefaultConfigId(modelConfig.model.getInferenceEntityId(), modelConfig.settings(), this)); + } + } + + return authorizedConfigIds; + } + + private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) { + var authorizedModels = auth.getAuthorizedModelIds(); + var authorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); + authorizedDefaultModelIds.retainAll(authorizedModels); + + return authorizedDefaultModelIds; + } + + private List getAuthorizedDefaultModelsObjects(ElasticInferenceServiceAuthorization auth) { + var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); + + var authorizedModels = new ArrayList(); + for (var id : authorizedDefaultModelIds) { + var modelConfig = defaultModelsConfigs.get(id); + if (modelConfig != null) { + authorizedModels.add(modelConfig); + } + } + + return authorizedModels; } // Default for testing @@ -152,20 +245,25 @@ void waitForAuthorizationToComplete(TimeValue waitTime) { @Override public synchronized Set supportedStreamingTasks() { - var enabledStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - enabledStreamingTaskTypes.retainAll(enabledTaskTypesRef.get()); + var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); + authorizedStreamingTaskTypes.retainAll(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); - if (enabledStreamingTaskTypes.isEmpty() == false) { - enabledStreamingTaskTypes.add(TaskType.ANY); + if (authorizedStreamingTaskTypes.isEmpty() == false) { + authorizedStreamingTaskTypes.add(TaskType.ANY); } - return enabledStreamingTaskTypes; + return authorizedStreamingTaskTypes; } @Override public synchronized List defaultConfigIds() { - // TODO once we have the enabledTaskTypes figure out which default endpoints we should expose - return List.of(); + return authRef.get().configIds; + } + + @Override + public synchronized void defaultConfigs(ActionListener> defaultsListener) { + var models = authRef.get().defaultModelConfigs.stream().map(config -> config.model).toList(); + defaultsListener.onResponse(models); } @Override @@ -298,12 +396,12 @@ public synchronized InferenceServiceConfiguration getConfiguration() { @Override public synchronized EnumSet supportedTaskTypes() { - return enabledTaskTypesRef.get(); + return authRef.get().taskTypesAndModels.getAuthorizedTaskTypes(); } @Override public synchronized boolean hideFromConfigurationApi() { - return enabledTaskTypesRef.get().isEmpty(); + return authRef.get().taskTypesAndModels.isAuthorized() == false; } private static ElasticInferenceServiceModel createModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java index eac64021ac85a..76721bb6dcd7b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java @@ -12,16 +12,20 @@ import java.util.EnumSet; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; /** - * Provides 
a structure for governing which models (if any) a cluster has access to according to the upstream Elastic Inference Service. - * @param enabledModels a mapping of model ids to a set of {@link TaskType} to indicate which models are available and for which task types + * This is a helper class for managing the response from {@link ElasticInferenceServiceAuthorizationHandler}. */ -public record ElasticInferenceServiceAuthorization(Map> enabledModels) { +public class ElasticInferenceServiceAuthorization { + + private final Map> taskTypeToModels; + private final EnumSet authorizedTaskTypes; + private final Set authorizedModelIds; /** * Converts an authorization response from Elastic Inference Service into the {@link ElasticInferenceServiceAuthorization} format. @@ -30,45 +34,98 @@ public record ElasticInferenceServiceAuthorization(Map * @return a new {@link ElasticInferenceServiceAuthorization} */ public static ElasticInferenceServiceAuthorization of(ElasticInferenceServiceAuthorizationResponseEntity responseEntity) { - var enabledModels = new HashMap>(); + var taskTypeToModelsMap = new HashMap>(); + var enabledTaskTypesSet = EnumSet.noneOf(TaskType.class); + var enabledModelsSet = new HashSet(); for (var model : responseEntity.getAuthorizedModels()) { // if there are no task types we'll ignore the model because it's likely we didn't understand // the task type and don't support it anyway if (model.taskTypes().isEmpty() == false) { - enabledModels.put(model.modelName(), model.taskTypes()); + for (var taskType : model.taskTypes()) { + taskTypeToModelsMap.merge(taskType, Set.of(model.modelName()), (existingModelIds, newModelIds) -> { + var combinedNames = new HashSet<>(existingModelIds); + combinedNames.addAll(newModelIds); + return combinedNames; + }); + enabledTaskTypesSet.add(taskType); + } + enabledModelsSet.add(model.modelName()); } } - return new ElasticInferenceServiceAuthorization(enabledModels); + return new ElasticInferenceServiceAuthorization(taskTypeToModelsMap, enabledModelsSet, enabledTaskTypesSet); } /** * Returns an object indicating that the cluster has no access to Elastic Inference Service. */ public static ElasticInferenceServiceAuthorization newDisabledService() { - return new ElasticInferenceServiceAuthorization(); + return new ElasticInferenceServiceAuthorization(Map.of(), Set.of(), EnumSet.noneOf(TaskType.class)); + } + + private ElasticInferenceServiceAuthorization( + Map> taskTypeToModels, + Set authorizedModelIds, + EnumSet authorizedTaskTypes + ) { + this.taskTypeToModels = Objects.requireNonNull(taskTypeToModels); + this.authorizedModelIds = Objects.requireNonNull(authorizedModelIds); + this.authorizedTaskTypes = Objects.requireNonNull(authorizedTaskTypes); + } + + /** + * Returns true if at least one task type and model is authorized. + * @return true if this cluster is authorized for at least one model and task type. 
+ */ + public boolean isAuthorized() { + return authorizedModelIds.isEmpty() == false && taskTypeToModels.isEmpty() == false && authorizedTaskTypes.isEmpty() == false; } - public ElasticInferenceServiceAuthorization { - Objects.requireNonNull(enabledModels); + public Set getAuthorizedModelIds() { + return Set.copyOf(authorizedModelIds); + } - for (var taskTypes : enabledModels.values()) { - if (taskTypes.isEmpty()) { - throw new IllegalArgumentException("Authorization task types must not be empty"); + public EnumSet getAuthorizedTaskTypes() { + return EnumSet.copyOf(authorizedTaskTypes); + } + + /** + * Returns a new {@link ElasticInferenceServiceAuthorization} object retaining only the specified task types + * and applicable models that leverage those task types. Any task types not specified in the passed in set will be + * excluded from the returned object. This is essentially an intersection. + * @param taskTypes the task types to retain in the newly created object + * @return a new object containing models and task types limited to the specified set. + */ + public ElasticInferenceServiceAuthorization newLimitedToTaskTypes(EnumSet taskTypes) { + var newTaskTypeToModels = new HashMap>(); + var taskTypesThatHaveModels = EnumSet.noneOf(TaskType.class); + + for (var taskType : taskTypes) { + var models = taskTypeToModels.get(taskType); + if (models != null) { + newTaskTypeToModels.put(taskType, models); + // we only want task types that correspond to actual models to ensure we're only enabling valid task types + taskTypesThatHaveModels.add(taskType); } } - } - private ElasticInferenceServiceAuthorization() { - this(Map.of()); + Set newEnabledModels = newTaskTypeToModels.values().stream().flatMap(Set::stream).collect(Collectors.toSet()); + + return new ElasticInferenceServiceAuthorization(newTaskTypeToModels, newEnabledModels, taskTypesThatHaveModels); } - public boolean isEnabled() { - return enabledModels.isEmpty() == false; + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + ElasticInferenceServiceAuthorization that = (ElasticInferenceServiceAuthorization) o; + return Objects.equals(taskTypeToModels, that.taskTypeToModels) + && Objects.equals(authorizedTaskTypes, that.authorizedTaskTypes) + && Objects.equals(authorizedModelIds, that.authorizedModelIds); } - public EnumSet enabledTaskTypes() { - return enabledModels.values().stream().flatMap(Set::stream).collect(Collectors.toCollection(() -> EnumSet.noneOf(TaskType.class))); + @Override + public int hashCode() { + return Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index b26f80efb1930..5125ade21339d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -74,7 +74,7 @@ public ElasticInferenceServiceCompletionModel( } - ElasticInferenceServiceCompletionModel( + public ElasticInferenceServiceCompletionModel( String inferenceEntityId, TaskType taskType, String service, diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 1e47a9b8d5ab6..162bcc8f09713 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -310,6 +310,20 @@ public void testIdMatchedDefault() { assertFalse(matched.isPresent()); } + public void testContainsDefaultConfigId() { + var client = mockClient(); + var registry = new ModelRegistry(client); + + registry.addDefaultIds( + new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class)) + ); + registry.addDefaultIds( + new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class)) + ); + assertTrue(registry.containsDefaultConfigId("foo")); + assertFalse(registry.containsDefaultConfigId("baz")); + } + public void testTaskTypeMatchedDefaults() { var defaultConfigIds = new ArrayList(); defaultConfigIds.add( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 334119f999e4c..5a3a9a29d7564 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -20,9 +20,11 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -38,6 +40,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; @@ -554,7 +557,16 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() thr public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception { try ( var service = createServiceWithMockSender( - new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING))) + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + 
EnumSet.of(TaskType.TEXT_EMBEDDING) + ) + ) + ) + ) ) ) { assertTrue(service.hideFromConfigurationApi()); @@ -564,7 +576,16 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNo public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception { try ( var service = createServiceWithMockSender( - new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.of(TaskType.CHAT_COMPLETION))) + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.CHAT_COMPLETION) + ) + ) + ) + ) ) ) { assertFalse(service.hideFromConfigurationApi()); @@ -574,7 +595,16 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro public void testGetConfiguration() throws Exception { try ( var service = createServiceWithMockSender( - new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))) + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION) + ) + ) + ) + ) ) ) { String content = XContentHelper.stripWhitespace(""" @@ -685,7 +715,16 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO try ( var service = createServiceWithMockSender( // this service doesn't yet support text embedding so we should still have no task types - new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING))) + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ) + ) + ) + ) ) ) { String content = XContentHelper.stripWhitespace(""" @@ -758,6 +797,60 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi service.waitForAuthorizationToComplete(TIMEOUT); assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); assertTrue(service.defaultConfigIds().isEmpty()); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertTrue(listener.actionGet(TIMEOUT).isEmpty()); + } + } + + public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimplementedTaskTypes() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "model-a", + "task_types": ["embed/text/sparse"] + }, + { + "model_name": "model-b", + "task_types": ["embed"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + } + } + + public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "model-a", + "task_types": ["embed/text/sparse"] + }, + { + "model_name": "model-b", + "task_types": ["chat"] + } + ] + } 
+ """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); } } @@ -780,6 +873,79 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat service.waitForAuthorizationToComplete(TIMEOUT); assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertTrue(service.defaultConfigIds().isEmpty()); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertTrue(listener.actionGet(TIMEOUT).isEmpty()); + } + } + + public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIncorrect() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["embed/text/sparse"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertThat( + service.defaultConfigIds(), + is( + List.of( + new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service) + ) + ) + ); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + } + } + + public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat( + service.defaultConfigIds(), + is( + List.of( + new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service) + ) + ) + ); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index 43cac4c54aa3c..a819bf1b4a513 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.EnumSet; import java.util.List; +import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -77,8 +78,9 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.enabledTaskTypes().isEmpty()); - assertFalse(authResponse.isEnabled()); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger).warn(loggerArgsCaptor.capture()); @@ -97,8 +99,9 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.enabledTaskTypes().isEmpty()); - assertFalse(authResponse.isEnabled()); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger).warn(loggerArgsCaptor.capture()); @@ -131,8 +134,9 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.enabledTaskTypes().isEmpty()); - assertFalse(authResponse.isEnabled()); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger).warn(loggerArgsCaptor.capture()); @@ -181,8 +185,9 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertThat(authResponse.enabledTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - assertTrue(authResponse.isEnabled()); + assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a"))); + assertTrue(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(1)).debug(loggerArgsCaptor.capture()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java index 20b52cb7bb314..559de47232a7b 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java @@ -13,17 +13,23 @@ import java.util.EnumSet; import java.util.List; -import java.util.Map; +import java.util.Set; import static org.hamcrest.Matchers.is; public class ElasticInferenceServiceAuthorizationTests extends ESTestCase { public static ElasticInferenceServiceAuthorization createEnabledAuth() { - return new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING))); + return ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING)) + ) + ) + ); } - public void testIsEnabled_ReturnsFalse_WithEmptyMap() { - assertFalse(ElasticInferenceServiceAuthorization.newDisabledService().isEnabled()); + public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { + assertFalse(ElasticInferenceServiceAuthorization.newDisabledService().isAuthorized()); } public void testExcludes_ModelsWithoutTaskTypes() { @@ -31,31 +37,196 @@ public void testExcludes_ModelsWithoutTaskTypes() { List.of(new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.noneOf(TaskType.class))) ); var auth = ElasticInferenceServiceAuthorization.of(response); - assertTrue(auth.enabledTaskTypes().isEmpty()); - assertFalse(auth.isEnabled()); + assertTrue(auth.getAuthorizedTaskTypes().isEmpty()); + assertFalse(auth.isAuthorized()); } - public void testConstructor_WithModelWithoutTaskTypes_ThrowsException() { - expectThrows( - IllegalArgumentException.class, - () -> new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.noneOf(TaskType.class))) + public void testEnabledTaskTypes_MergesFromSeparateModels() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING)), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.SPARSE_EMBEDDING)) + ) + ) ); + assertThat(auth.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); + assertThat(auth.getAuthorizedModelIds(), is(Set.of("model-1", "model-2"))); } - public void testEnabledTaskTypes_MergesFromSeparateModels() { + public void testEnabledTaskTypes_FromSingleEntry() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ); + + assertThat(auth.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); + assertThat(auth.getAuthorizedModelIds(), is(Set.of("model-1"))); + } + + public void testNewLimitToTaskTypes_SingleModel() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + 
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.CHAT_COMPLETION)) + ) + ) + ); + assertThat( - new ElasticInferenceServiceAuthorization( - Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING), "model-2", EnumSet.of(TaskType.SPARSE_EMBEDDING)) - ).enabledTaskTypes(), - is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)) + auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING)), + is( + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ) + ) + ) + ) + ) ); } - public void testEnabledTaskTypes_FromSingleEntry() { + public void testNewLimitToTaskTypes_MultipleModels_OnlyTextEmbedding() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.TEXT_EMBEDDING)) + ) + ) + ); + assertThat( - new ElasticInferenceServiceAuthorization(Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))) - .enabledTaskTypes(), - is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)) + auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING)), + is( + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-2", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ) + ) + ) + ) + ) ); } + + public void testNewLimitToTaskTypes_MultipleModels_MultipleTaskTypes() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-text-sparse", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-sparse", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-chat-completion", + EnumSet.of(TaskType.CHAT_COMPLETION) + ) + ) + ) + ); + + var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)); + assertThat( + a, + is( + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-text-sparse", + EnumSet.of(TaskType.TEXT_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-chat-completion", + EnumSet.of(TaskType.CHAT_COMPLETION) + ) + ) + ) + ) + ) + ); + } + + public void testNewLimitToTaskTypes_DuplicateModelNames() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new 
ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK) + ) + ) + ) + ); + + var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)); + assertThat( + a, + is( + ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK) + ) + ) + ) + ) + ) + ); + } + + public void testNewLimitToTaskTypes_ReturnsDisabled_WhenNoOverlapForTaskTypes() { + var auth = ElasticInferenceServiceAuthorization.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-2", + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING) + ) + ) + ) + ); + + var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.RERANK)); + assertThat(a, is(ElasticInferenceServiceAuthorization.newDisabledService())); + } }
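Reviewer note: a minimal, self-contained sketch of the task-type intersection that the new ElasticInferenceServiceAuthorization#newLimitedToTaskTypes performs, written with plain collections rather than the production classes. The class name, enum, and model ids below are illustrative assumptions for the demo, not the real API.

// Illustrative sketch only: mirrors the intersection of authorized task types with
// the task types the service actually implements, keeping only task types that are
// backed by at least one authorized model.
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class AuthorizationIntersectionSketch {

    enum TaskType { SPARSE_EMBEDDING, TEXT_EMBEDDING, CHAT_COMPLETION, RERANK }

    public static void main(String[] args) {
        // Authorization response flattened to task type -> model ids, e.g. derived from
        // {"models":[{"model_name":"rainbow-sprinkles","task_types":["chat"]}, ...]}
        Map<TaskType, Set<String>> taskTypeToModels = new HashMap<>();
        taskTypeToModels.put(TaskType.CHAT_COMPLETION, Set.of("rainbow-sprinkles"));
        taskTypeToModels.put(TaskType.TEXT_EMBEDDING, Set.of("model-b"));

        // Only the task types the service implements are retained (the intersection).
        EnumSet<TaskType> implemented = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);

        Map<TaskType, Set<String>> limited = new HashMap<>();
        for (TaskType taskType : implemented) {
            Set<String> models = taskTypeToModels.get(taskType);
            if (models != null) {
                // keep only task types that correspond to at least one authorized model
                limited.put(taskType, models);
            }
        }

        Set<String> authorizedModels = new HashSet<>();
        limited.values().forEach(authorizedModels::addAll);

        System.out.println("authorized task types: " + limited.keySet()); // [CHAT_COMPLETION]
        System.out.println("authorized models: " + authorizedModels);     // [rainbow-sprinkles]
    }
}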
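Similarly, a small standalone sketch of the reserved-ID guard added to TransportPutInferenceModelAction: a PUT whose inference id collides with a default config id is rejected before the request is processed further. The set contents and exception type are stand-ins; the real code consults ModelRegistry#containsDefaultConfigId and throws an ElasticsearchStatusException with RestStatus.BAD_REQUEST.

// Illustrative sketch only: rejects creation of an endpoint that reuses a reserved
// (default) inference id. The default id shown is an assumption for the demo.
import java.util.Set;

public class ReservedIdGuardSketch {

    // stands in for ModelRegistry#containsDefaultConfigId
    private static final Set<String> DEFAULT_CONFIG_IDS = Set.of(".rainbow-sprinkles-elastic");

    static void putInferenceEndpoint(String inferenceEntityId) {
        if (DEFAULT_CONFIG_IDS.contains(inferenceEntityId)) {
            throw new IllegalArgumentException(
                "[" + inferenceEntityId + "] is a reserved inference ID. "
                    + "Cannot create a new inference endpoint with a reserved ID."
            );
        }
        // ... continue with normal endpoint creation
        System.out.println("creating endpoint [" + inferenceEntityId + "]");
    }

    public static void main(String[] args) {
        putInferenceEndpoint("my-endpoint"); // accepted
        try {
            putInferenceEndpoint(".rainbow-sprinkles-elastic");
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // rejected: reserved inference ID
        }
    }
}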