diff --git a/muted-tests.yml b/muted-tests.yml index 53c034dfcc3fd..04b9d6a6bd2de 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -414,15 +414,6 @@ tests: - class: org.elasticsearch.indices.mapping.UpdateMappingIntegrationIT method: testUpdateMappingConcurrently issue: https://github.com/elastic/elasticsearch/issues/137758 -- class: org.elasticsearch.xpack.inference.integration.CCMPersistentStorageServiceIT - method: testDelete_RemovesCCMConfiguration - issue: https://github.com/elastic/elasticsearch/issues/137786 -- class: org.elasticsearch.xpack.inference.integration.CCMPersistentStorageServiceIT - method: testDelete_DoesNotThrow_WhenTheConfigurationDoesNotExist - issue: https://github.com/elastic/elasticsearch/issues/137797 -- class: org.elasticsearch.xpack.inference.integration.CCMServiceIT - method: testIsEnabled_ReturnsTrue_WhenCCMConfigurationIsPresent - issue: https://github.com/elastic/elasticsearch/issues/137798 - class: org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceTests method: testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull issue: https://github.com/elastic/elasticsearch/issues/137823 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index bc403a78bc2a8..be4bacec10600 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -25,6 +25,7 @@ import org.junit.rules.TestRule; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModel; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT; public class BaseMockEISAuthServerTest extends ESRestTestCase { @@ -46,6 +47,9 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase { // calls which would result in a test failure because the webserver is only expecting a single request // So to ensure we avoid that all together, this flag indicates that we'll only perform a single authorization request .setting("xpack.inference.elastic.periodic_authorization_enabled", "false") + // Setting to false so that the CCM logic will be skipped when running the tests, the authorization logic skip trying to determine + // if CCM is enabled + .setting(CCM_SUPPORTED_ENVIRONMENT.getKey(), "false") // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin .plugin("inference-service-test") .user("x_pack_rest_user", "x-pack-test-password") diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudForbiddenIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudForbiddenIT.java index 0a2f04a0aae89..fc0b3a2ddd295 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudForbiddenIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudForbiddenIT.java @@ -21,6 +21,7 @@ import org.junit.ClassRule; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature.CCM_FORBIDDEN_EXCEPTION; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; @@ -31,7 +32,7 @@ public class CCMCrudForbiddenIT extends CCMRestBaseIT { .distribution(DistributionType.DEFAULT) .setting("xpack.license.self_generated.type", "basic") .setting("xpack.security.enabled", "true") - .setting("xpack.inference.elastic.allow_configuring_ccm", "false") + .setting(CCM_SUPPORTED_ENVIRONMENT.getKey(), "false") .user("x_pack_rest_user", "x-pack-test-password") .build(); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudIT.java index 2dcfcbd93dda8..a9f73adc8f352 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudIT.java @@ -26,6 +26,7 @@ import java.io.IOException; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_CCM_PATH; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; @@ -36,7 +37,7 @@ public class CCMCrudIT extends CCMRestBaseIT { .distribution(DistributionType.DEFAULT) .setting("xpack.license.self_generated.type", "basic") .setting("xpack.security.enabled", "true") - .setting("xpack.inference.elastic.allow_configuring_ccm", "true") + .setting(CCM_SUPPORTED_ENVIRONMENT.getKey(), "true") .user("x_pack_rest_user", "x-pack-test-password") .build(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index 8450ceab04848..e506f20899581 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -87,6 +88,10 @@ public void createComponents() { @After public void shutdown() { + removeEisPreconfiguredEndpoints(modelRegistry); + } + + static void removeEisPreconfiguredEndpoints(ModelRegistry modelRegistry) { // Delete all the eis preconfigured endpoints var listener = new PlainActionFuture(); modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, listener); @@ -101,6 +106,8 @@ public static void cleanUpClass() { @Override protected Settings nodeSettings() { return Settings.builder() + // Disable CCM to ensure that only the authorization task executor is initialized in the inference plugin when it is created + .put(CCMSettings.CCM_SUPPORTED_ENVIRONMENT.getKey(), false) .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl) // Ensure that the polling logic only occurs once so we can deterministically control when an authorization response is // received @@ -123,7 +130,15 @@ public void testCreatesEisChatCompletionEndpoint() throws Exception { } private void assertNoAuthorizedEisEndpoints() throws Exception { - waitForTask(AUTH_TASK_ACTION, admin()); + assertNoAuthorizedEisEndpoints(admin(), authorizationTaskExecutor, modelRegistry); + } + + static void assertNoAuthorizedEisEndpoints( + AdminClient adminClient, + AuthorizationTaskExecutor authorizationTaskExecutor, + ModelRegistry modelRegistry + ) throws Exception { + waitForTask(AUTH_TASK_ACTION, adminClient); assertBusy(() -> { var newPoller = authorizationTaskExecutor.getCurrentPollerTask(); @@ -131,7 +146,7 @@ private void assertNoAuthorizedEisEndpoints() throws Exception { newPoller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS); }); - var eisEndpoints = getEisEndpoints(); + var eisEndpoints = getEisEndpoints(modelRegistry); assertThat(eisEndpoints, empty()); for (String eisPreconfiguredEndpoints : InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS) { @@ -153,7 +168,22 @@ public static TaskInfo waitForTask(String taskAction, AdminClient adminClient) t return taskRef.get(); } + static void waitForNoTask(String taskAction, AdminClient adminClient) throws Exception { + var builder = new ListTasksRequestBuilder(adminClient.cluster()); + + assertBusy(() -> { + var response = builder.get(); + var authPollerTask = response.getTasks().stream().filter(task -> task.action().equals(taskAction)).findFirst(); + assertFalse(authPollerTask.isPresent()); + }); + + } + private List getEisEndpoints() { + return getEisEndpoints(modelRegistry); + } + + static List getEisEndpoints(ModelRegistry modelRegistry) { var listener = new PlainActionFuture>(); modelRegistry.getAllModels(false, listener); @@ -162,17 +192,26 @@ private List getEisEndpoints() { } private void restartPollingTaskAndWaitForAuthResponse() throws Exception { - cancelAuthorizationTask(admin()); + restartPollingTaskAndWaitForAuthResponse(admin(), authorizationTaskExecutor); + } + + static void restartPollingTaskAndWaitForAuthResponse(AdminClient adminClient, AuthorizationTaskExecutor authTaskExecutor) + throws Exception { + cancelAuthorizationTask(adminClient); // wait for the new task to be recreated and an authorization response to be processed + waitForAuthorizationToComplete(authTaskExecutor); + } + + static void waitForAuthorizationToComplete(AuthorizationTaskExecutor authTaskExecutor) throws Exception { assertBusy(() -> { - var newPoller = authorizationTaskExecutor.getCurrentPollerTask(); + var newPoller = authTaskExecutor.getCurrentPollerTask(); assertNotNull(newPoller); newPoller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS); }); } - public static void cancelAuthorizationTask(AdminClient adminClient) throws Exception { + static void cancelAuthorizationTask(AdminClient adminClient) throws Exception { var pollerTask = waitForTask(AUTH_TASK_ACTION, adminClient); var builder = new CancelTasksRequestBuilder(adminClient.cluster()); @@ -202,7 +241,11 @@ public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthor } private void assertChatCompletionEndpointExists() { - var eisEndpoints = getEisEndpoints(); + assertChatCompletionEndpointExists(modelRegistry); + } + + static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) { + var eisEndpoints = getEisEndpoints(modelRegistry); assertThat(eisEndpoints.size(), is(1)); var rainbowSprinklesModel = eisEndpoints.get(0); @@ -212,7 +255,7 @@ private void assertChatCompletionEndpointExists() { ); } - private void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { + static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION)); assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME)); assertThat(rainbowSprinklesModel.inferenceEntityId(), is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index cb92c70d27442..2bc5e2df2853b 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -85,6 +86,8 @@ protected Collection> nodePlugins() { protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { return Settings.builder() .put(super.nodeSettings(nodeOrdinal, otherSettings)) + // Disable CCM to ensure that only the authorization task executor is initialized in the inference plugin when it is created + .put(CCMSettings.CCM_SUPPORTED_ENVIRONMENT.getKey(), false) .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial") .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl) .put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java index 47446dc9517f3..b206ee02b6602 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java @@ -9,16 +9,49 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.TestPlainActionFuture; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMModel; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; +import org.junit.After; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.BeforeClass; +import java.io.IOException; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTH_TASK_ACTION; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.assertChatCompletionEndpointExists; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.getEisEndpoints; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.removeEisPreconfiguredEndpoints; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForAuthorizationToComplete; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForNoTask; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForTask; +import static org.elasticsearch.xpack.inference.integration.ModelRegistryIT.buildElserModelConfig; +import static org.elasticsearch.xpack.inference.registry.ModelRegistryTests.assertStoreModel; +import static org.hamcrest.Matchers.empty; + public class CCMServiceIT extends CCMSingleNodeIT { private static final AtomicReference ccmService = new AtomicReference<>(); + private static final MockWebServer webServer = new MockWebServer(); + private static String gatewayUrl; + + private AuthorizationTaskExecutor authorizationTaskExecutor; + private ModelRegistry modelRegistry; + public CCMServiceIT() { super(new Provider() { @Override @@ -38,9 +71,47 @@ public void delete(ActionListener listener) { }); } + @BeforeClass + public static void initClass() throws IOException { + webServer.start(); + gatewayUrl = getUrl(webServer); + } + @Before public void createComponents() { ccmService.set(node().injector().getInstance(CCMService.class)); + modelRegistry = node().injector().getInstance(ModelRegistry.class); + authorizationTaskExecutor = node().injector().getInstance(AuthorizationTaskExecutor.class); + } + + @After + public void shutdown() { + // disable CCM to clean up any stored configuration + disableCCM(); + + removeEisPreconfiguredEndpoints(modelRegistry); + } + + private void disableCCM() { + var listener = new PlainActionFuture(); + ccmService.get().disableCCM(listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); + } + + @AfterClass + public static void cleanUpClass() { + webServer.close(); + } + + @Override + protected Settings nodeSettings() { + return Settings.builder() + .put(CCMSettings.CCM_SUPPORTED_ENVIRONMENT.getKey(), true) + .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl) + // Ensure that the polling logic only occurs once so we can deterministically control when an authorization response is + // received + .put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false) + .build(); } public void testIsEnabled_ReturnsFalse_WhenNoCCMConfigurationStored() { @@ -58,4 +129,48 @@ public void testIsEnabled_ReturnsTrue_WhenCCMConfigurationIsPresent() { assertTrue(listener.actionGet(TimeValue.THIRTY_SECONDS)); } + + public void testCreatesEisChatCompletionEndpoint() throws Exception { + disableCCM(); + waitForNoTask(AUTH_TASK_ACTION, admin()); + + var eisEndpoints = getEisEndpoints(modelRegistry); + assertThat(eisEndpoints, empty()); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + var listener = new TestPlainActionFuture(); + ccmService.get().storeConfiguration(new CCMModel(new SecureString("secret".toCharArray())), listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); + + // Force a cluster state update to ensure the authorization task is created + forceClusterUpdate(); + + waitForTask(AUTH_TASK_ACTION, admin()); + waitForAuthorizationToComplete(authorizationTaskExecutor); + + assertChatCompletionEndpointExists(modelRegistry); + } + + private void forceClusterUpdate() { + var model = buildElserModelConfig("test-store-model", TaskType.SPARSE_EMBEDDING); + assertStoreModel(modelRegistry, model); + } + + public void testDisableCCM_RemovesAuthorizationTask() throws Exception { + disableCCM(); + waitForNoTask(AUTH_TASK_ACTION, admin()); + + var listener = new TestPlainActionFuture(); + ccmService.get().storeConfiguration(new CCMModel(new SecureString("secret".toCharArray())), listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); + + // Force a cluster state update to ensure the authorization task is created + forceClusterUpdate(); + + waitForTask(AUTH_TASK_ACTION, admin()); + waitForAuthorizationToComplete(authorizationTaskExecutor); + + disableCCM(); + waitForNoTask(AUTH_TASK_ACTION, admin()); + } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 708edc4279148..1adf2e52060a2 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -1041,7 +1041,7 @@ private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) { } } - private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) { + static Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) { return switch (taskType) { case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalModel( inferenceEntityId, diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 67a0cd11c6a4c..27f8b991be1a3 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -42,6 +42,7 @@ exports org.elasticsearch.xpack.inference.rest; exports org.elasticsearch.xpack.inference.services; exports org.elasticsearch.xpack.inference.services.elastic.ccm; + exports org.elasticsearch.xpack.inference.services.elastic.authorization; exports org.elasticsearch.xpack.inference; exports org.elasticsearch.xpack.inference.action.task; exports org.elasticsearch.xpack.inference.telemetry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 810421580425c..efae908cc7cea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.support.MappedActionFilter; import org.elasticsearch.client.internal.Client; @@ -151,9 +153,11 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMCache; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMIndex; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMInformedSettings; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMPersistentStorageService; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; @@ -200,6 +204,8 @@ public class InferencePlugin extends Plugin ClusterPlugin, PersistentTaskPlugin { + private static final Logger logger = LogManager.getLogger(InferencePlugin.class); + /** * When this setting is true the verification check that * connects to the external service will not be made at @@ -281,6 +287,7 @@ public List getActions() { new ActionHandler(PutCCMConfigurationAction.INSTANCE, TransportPutCCMConfigurationAction.class), new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class), new ActionHandler(CCMCache.ClearCCMCacheAction.INSTANCE, CCMCache.ClearCCMCacheAction.class), + new ActionHandler(AuthorizationTaskExecutor.Action.INSTANCE, AuthorizationTaskExecutor.Action.class), new ActionHandler(GetInferenceFieldsAction.INSTANCE, TransportGetInferenceFieldsAction.class) ); } @@ -315,6 +322,9 @@ public List getRestHandlers( @Override public Collection createComponents(PluginServices services) { var components = new ArrayList<>(); + ccmFeature.set(new CCMFeature(settings)); + components.add(ccmFeature.get()); + var throttlerManager = new ThrottlerManager(settings, services.threadPool()); throttlerManager.init(services.clusterService()); @@ -338,67 +348,54 @@ public Collection createComponents(PluginServices services) { var inferenceServices = new ArrayList<>(inferenceServiceExtensions); inferenceServices.add(this::getInferenceServiceFactories); - var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); + var inferenceServiceSettings = new CCMInformedSettings(settings, ccmFeature.get()); inferenceServiceSettings.init(services.clusterService()); - // Create a separate instance of HTTPClientManager with its own SSL configuration (`xpack.inference.elastic.http.ssl.*`). - var elasticInferenceServiceHttpClientManager = HttpClientManager.create( - settings, - services.threadPool(), - services.clusterService(), + var eisRequestSenderFactoryComponents = createEisRequestSenderFactory( + services, throttlerManager, - getSslService(), - inferenceServiceSettings.getConnectionTtl() + inferenceServiceSettings, + ccmFeature.get() ); + var elasticInferenceServiceHttpClientManager = eisRequestSenderFactoryComponents.httpClientManager(); + elasticInferenceServiceFactory.set(eisRequestSenderFactoryComponents.factory()); - var elasticInferenceServiceRequestSenderFactory = new HttpRequestSender.Factory( - serviceComponents.get(), - elasticInferenceServiceHttpClientManager, - services.clusterService() - ); - elasticInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); + var sageMakerSchemas = new SageMakerSchemas(); + var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); - var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( - inferenceServiceSettings.getElasticInferenceServiceUrl(), - services.threadPool() + var ccmRelatedComponents = createCCMDependentComponents( + services, + inferenceServiceSettings, + serviceComponents.get(), + elasticInferenceServiceFactory.get().createSender(), + modelRegistry.get(), + ccmFeature.get() ); + components.addAll(ccmRelatedComponents.components()); - var authTaskExecutor = AuthorizationTaskExecutor.create( - services.clusterService(), - new AuthorizationPoller.Parameters( + inferenceServices.add(() -> List.of(context -> { + var eisService = new ElasticInferenceService( + elasticInferenceServiceFactory.get(), serviceComponents.get(), - authorizationHandler, - elasticInferenceServiceFactory.get().createSender(), inferenceServiceSettings, - modelRegistry.get(), - services.client() - ) - ); - authorizationTaskExecutorRef.set(authTaskExecutor); - - var sageMakerSchemas = new SageMakerSchemas(); - var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); - inferenceServices.add( - () -> List.of( - context -> new ElasticInferenceService( - elasticInferenceServiceFactory.get(), - serviceComponents.get(), - inferenceServiceSettings, - context + context, + ccmRelatedComponents.ccmAuthApplierFactory() + ); + eisService.init(); + return eisService; + }, + context -> new SageMakerService( + new SageMakerModelBuilder(sageMakerSchemas), + new SageMakerClient( + new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())), + services.threadPool() ), - context -> new SageMakerService( - new SageMakerModelBuilder(sageMakerSchemas), - new SageMakerClient( - new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())), - services.threadPool() - ), - sageMakerSchemas, - services.threadPool(), - sageMakerConfigurations::getOrCompute, - context - ) + sageMakerSchemas, + services.threadPool(), + sageMakerConfigurations::getOrCompute, + context ) - ); + )); var meterRegistry = services.telemetryProvider().getMeterRegistry(); var inferenceStats = InferenceStats.create(meterRegistry); @@ -436,7 +433,6 @@ public Collection createComponents(PluginServices services) { new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager) ); components.add(inferenceStatsBinding); - components.add(authorizationHandler); components.add(new PluginComponentBinding<>(Sender.class, elasticInferenceServiceFactory.get().createSender())); components.add( new InferenceEndpointRegistry( @@ -449,27 +445,104 @@ public Collection createComponents(PluginServices services) { ) ); - components.add(authTaskExecutor); - components.addAll(createCCMComponents(services)); - return components; } - private Collection createCCMComponents(PluginServices services) { - ccmFeature.set(new CCMFeature(settings)); + private record CCMRelatedComponents(Collection components, CCMAuthenticationApplierFactory ccmAuthApplierFactory) {} + + private CCMRelatedComponents createCCMDependentComponents( + PluginServices services, + ElasticInferenceServiceSettings inferenceServiceSettings, + ServiceComponents serviceComponents, + Sender sender, + ModelRegistry modelRegistry, + CCMFeature ccmFeature + ) { var ccmPersistentStorageService = new CCMPersistentStorageService(services.client()); - return List.of( - new CCMService(ccmPersistentStorageService), - ccmFeature.get(), - ccmPersistentStorageService, - new CCMCache( + var ccmService = new CCMService(ccmPersistentStorageService, services.client()); + var ccmAuthApplierFactory = new CCMAuthenticationApplierFactory(ccmFeature, ccmService); + + var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + inferenceServiceSettings.getElasticInferenceServiceUrl(), + services.threadPool(), + ccmAuthApplierFactory + ); + + var authTaskExecutor = AuthorizationTaskExecutor.create( + services.clusterService(), + new AuthorizationPoller.Parameters( + serviceComponents, + authorizationHandler, + sender, + inferenceServiceSettings, + modelRegistry, + services.client(), + ccmFeature, + ccmService + ) + ); + authorizationTaskExecutorRef.set(authTaskExecutor); + + // If CCM is not allowed in this environment then we can initialize the auth poller task because + // authentication with EIS will be through certs that are already configured. If CCM configuration is allowed, + // we need to wait for the user to provide an API key before we can start polling EIS + if (ccmFeature.isCcmSupportedEnvironment() == false) { + logger.info("CCM configuration is not permitted - starting EIS authorization task executor"); + authTaskExecutor.startAndLazyCreateTask(); + } + + return new CCMRelatedComponents( + List.of( + authorizationHandler, + authTaskExecutor, + ccmService, ccmPersistentStorageService, + new CCMCache( + ccmPersistentStorageService, + services.clusterService(), + settings, + services.featureService(), + services.projectResolver(), + services.client() + ) + ), + ccmAuthApplierFactory + ); + } + + private record EisRequestSenderComponents(HttpRequestSender.Factory factory, HttpClientManager httpClientManager) {} + + private EisRequestSenderComponents createEisRequestSenderFactory( + PluginServices services, + ThrottlerManager throttlerManager, + ElasticInferenceServiceSettings inferenceServiceSettings, + CCMFeature ccmFeature + ) { + // Create a separate instance of HTTPClientManager with its own SSL configuration (`xpack.inference.elastic.http.ssl.*`). + HttpClientManager manager; + if (ccmFeature.isCcmSupportedEnvironment()) { + // If ccm is configurable then we aren't using mTLS so ignore the ssl service + manager = HttpClientManager.create( + settings, + services.threadPool(), services.clusterService(), + throttlerManager, + inferenceServiceSettings.getConnectionTtl() + ); + } else { + manager = HttpClientManager.create( settings, - services.featureService(), - services.projectResolver(), - services.client() - ) + services.threadPool(), + services.clusterService(), + throttlerManager, + getSslService(), + inferenceServiceSettings.getConnectionTtl() + ); + } + + return new EisRequestSenderComponents( + new HttpRequestSender.Factory(serviceComponents.get(), manager, services.clusterService()), + manager ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteCCMConfigurationAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteCCMConfigurationAction.java index c6f457d6f1e8d..3db4055a6147e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteCCMConfigurationAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteCCMConfigurationAction.java @@ -69,7 +69,7 @@ protected void masterOperation( ClusterState state, ActionListener listener ) { - if (ccmFeature.allowConfiguringCcm() == false) { + if (ccmFeature.isCcmSupportedEnvironment() == false) { listener.onFailure(CCM_FORBIDDEN_EXCEPTION); return; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetCCMConfigurationAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetCCMConfigurationAction.java index 6281dda2ba4ee..122883bd68f70 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetCCMConfigurationAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetCCMConfigurationAction.java @@ -50,7 +50,7 @@ public TransportGetCCMConfigurationAction( @Override protected void doExecute(Task task, GetCCMConfigurationAction.Request request, ActionListener listener) { - if (ccmFeature.allowConfiguringCcm() == false) { + if (ccmFeature.isCcmSupportedEnvironment() == false) { listener.onFailure(CCM_FORBIDDEN_EXCEPTION); return; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java index 6aa2d7e7cc0b7..a074dac1b62f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java @@ -70,7 +70,7 @@ protected void masterOperation( ClusterState state, ActionListener listener ) { - if (ccmFeature.allowConfiguringCcm() == false) { + if (ccmFeature.isCcmSupportedEnvironment() == false) { listener.onFailure(CCM_FORBIDDEN_EXCEPTION); return; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java index 08db124fdf369..43e1f6754fe3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java @@ -24,6 +24,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ssl.SSLService; @@ -104,7 +105,17 @@ public static HttpClientManager create( ClusterService clusterService, ThrottlerManager throttlerManager ) { - PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(); + return create(settings, threadPool, clusterService, throttlerManager, null); + } + + public static HttpClientManager create( + Settings settings, + ThreadPool threadPool, + ClusterService clusterService, + ThrottlerManager throttlerManager, + @Nullable TimeValue connectionTtl + ) { + PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(connectionTtl); return new HttpClientManager(settings, connectionManager, threadPool, clusterService, throttlerManager); } @@ -173,7 +184,7 @@ private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIO ); } - private static PoolingNHttpClientConnectionManager createConnectionManager() { + private static PoolingNHttpClientConnectionManager createConnectionManager(@Nullable TimeValue connectionTtl) { ConnectingIOReactor ioReactor; try { var configBuilder = IOReactorConfig.custom().setSoKeepAlive(true); @@ -184,13 +195,27 @@ private static PoolingNHttpClientConnectionManager createConnectionManager() { throw new ElasticsearchException(message, e); } + var registry = RegistryBuilder.create() + .register("http", NoopIOSessionStrategy.INSTANCE) + .register("https", SSLIOSessionStrategy.getDefaultStrategy()) + .build(); + + // -1 is used as the default within the PoolingNHttpClientConnectionManager to indicate no TTL + var connectionTtlMillis = connectionTtl == null ? -1 : connectionTtl.getMillis(); + /* - The max time to live for open connections in the pool will not be set because we don't specify a ttl in the constructor. - This meaning that there should not be a limit. - We can control the TTL dynamically using the IdleConnectionEvictor and keep-alive strategy. + If the connection TTL is not set, the TTL will be controlled using the IdleConnectionEvictor and keep-alive strategy. The max idle time cluster setting will dictate how much time an open connection can be unused for before it can be closed. */ - return new PoolingNHttpClientConnectionManager(ioReactor); + return new PoolingNHttpClientConnectionManager( + ioReactor, + null, + registry, + null, + null, + Math.toIntExact(connectionTtlMillis), + TimeUnit.MILLISECONDS + ); } private void addSettingsUpdateConsumers(ClusterService clusterService) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java index b2f073c1eb69c..2c94a40dfb62a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java @@ -22,7 +22,11 @@ public class RequestUtils { public static Header createAuthBearerHeader(SecureString apiKey) { - return new BasicHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.toString()); + return new BasicHeader(HttpHeaders.AUTHORIZATION, bearerToken(apiKey.toString())); + } + + public static String bearerToken(String apiKey) { + return "Bearer " + apiKey; } public static URI buildUri(URI accountUri, String service, CheckedSupplier uriBuilder) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteCCMConfigurationAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteCCMConfigurationAction.java index ddd19f767db45..f635e5814e051 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteCCMConfigurationAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteCCMConfigurationAction.java @@ -46,7 +46,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - if (ccmFeature.allowConfiguringCcm() == false) { + if (ccmFeature.isCcmSupportedEnvironment() == false) { throw CCM_FORBIDDEN_EXCEPTION; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetCCMConfigurationAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetCCMConfigurationAction.java index c66cb9fa2722c..7341dab634d85 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetCCMConfigurationAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetCCMConfigurationAction.java @@ -44,7 +44,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - if (ccmFeature.allowConfiguringCcm() == false) { + if (ccmFeature.isCcmSupportedEnvironment() == false) { throw CCM_FORBIDDEN_EXCEPTION; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutCCMConfigurationAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutCCMConfigurationAction.java index edc4707eae9ec..b57b12caff973 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutCCMConfigurationAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutCCMConfigurationAction.java @@ -46,7 +46,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - if (ccmFeature.allowConfiguringCcm() == false) { + if (ccmFeature.isCcmSupportedEnvironment() == false) { throw CCM_FORBIDDEN_EXCEPTION; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/DefaultModelConfig.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/DefaultModelConfig.java deleted file mode 100644 index dcdf5bce1fbb4..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/DefaultModelConfig.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic; - -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; - -public record DefaultModelConfig(Model model, MinimalServiceSettings settings) { - -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 8a6d2626ad521..0ba5dddb5fcae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -39,7 +39,6 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -49,6 +48,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionCreator; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; @@ -59,12 +59,10 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Set; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; -import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -108,26 +106,36 @@ public class ElasticInferenceService extends SenderService { ); private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + private final CCMAuthenticationApplierFactory ccmAuthenticationApplierFactory; + private ElasticInferenceServiceActionCreator actionCreator; public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - InferenceServiceExtension.InferenceServiceFactoryContext context + InferenceServiceExtension.InferenceServiceFactoryContext context, + CCMAuthenticationApplierFactory ccmAuthApplierFactory ) { - this(factory, serviceComponents, elasticInferenceServiceSettings, context.clusterService()); + this(factory, serviceComponents, elasticInferenceServiceSettings, context.clusterService(), ccmAuthApplierFactory); } public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - ClusterService clusterService + ClusterService clusterService, + CCMAuthenticationApplierFactory ccmAuthApplierFactory ) { super(factory, serviceComponents, clusterService); this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); + this.ccmAuthenticationApplierFactory = ccmAuthApplierFactory; + } + + public void init() { + // Wait to initialize the action creator until the sender is constructed + this.actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), ccmAuthenticationApplierFactory); } @Override @@ -166,18 +174,12 @@ protected void doUnifiedCompletionInfer( var completionModel = (ElasticInferenceServiceCompletionModel) model; var overriddenModel = ElasticInferenceServiceCompletionModel.of(completionModel, inputs.getRequest()); - var errorMessage = constructFailedToSendRequestMessage( - String.format(Locale.ROOT, "%s completions", ELASTIC_INFERENCE_SERVICE_IDENTIFIER) - ); - var requestManager = ElasticInferenceServiceUnifiedCompletionRequestManager.of( + actionCreator.create( overriddenModel, - getServiceComponents().threadPool(), - currentTraceInfo + currentTraceInfo, + listener.delegateFailureAndWrap((delegate, action) -> action.execute(inputs, timeout, delegate)) ); - var action = new SenderExecutableAction(getSender(), requestManager, errorMessage); - - action.execute(inputs, timeout, listener); } @Override @@ -198,7 +200,7 @@ protected void doInfer( return; } - if (model instanceof ElasticInferenceServiceExecutableActionModel == false) { + if (model instanceof ElasticInferenceServiceModel == false) { listener.onFailure(createInvalidModelException(model)); return; } @@ -208,11 +210,13 @@ protected void doInfer( // generating a different "traceparent" as every task and every REST request creates a new span). var currentTraceInfo = getCurrentTraceInfo(); - ElasticInferenceServiceExecutableActionModel elasticInferenceServiceModel = (ElasticInferenceServiceExecutableActionModel) model; - var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), currentTraceInfo); + var elasticInferenceServiceModel = (ElasticInferenceServiceModel) model; - var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings); - action.execute(inputs, timeout, listener); + actionCreator.create( + elasticInferenceServiceModel, + currentTraceInfo, + listener.delegateFailureAndWrap((delegate, action) -> action.execute(inputs, timeout, delegate)) + ); } @Override @@ -228,8 +232,6 @@ protected void doChunkedInfer( ActionListener> listener ) { if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel denseTextEmbeddingsModel) { - var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); - List batchedRequests = new EmbeddingRequestChunker<>( inputs, DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, @@ -237,16 +239,24 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + actionCreator.create( + denseTextEmbeddingsModel, + getCurrentTraceInfo(), + request.listener() + .delegateFailureAndWrap( + (delegate, action) -> action.execute( + new EmbeddingsInput(request.batch().inputs(), inputType), + timeout, + delegate + ) + ) + ); } return; } if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) { - var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); - List batchedRequests = new EmbeddingRequestChunker<>( inputs, SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE, @@ -254,8 +264,18 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + actionCreator.create( + sparseTextEmbeddingsModel, + getCurrentTraceInfo(), + request.listener() + .delegateFailureAndWrap( + (delegate, action) -> action.execute( + new EmbeddingsInput(request.batch().inputs(), inputType), + timeout, + delegate + ) + ) + ); } return; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceExecutableActionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceExecutableActionModel.java deleted file mode 100644 index 974235bff0d59..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceExecutableActionModel.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic; - -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor; - -import java.util.Map; - -public abstract class ElasticInferenceServiceExecutableActionModel extends ElasticInferenceServiceModel { - - public ElasticInferenceServiceExecutableActionModel( - ModelConfigurations configurations, - ModelSecrets secrets, - ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings, - ElasticInferenceServiceComponents elasticInferenceServiceComponents - ) { - super(configurations, secrets, rateLimitServiceSettings, elasticInferenceServiceComponents); - } - - public ElasticInferenceServiceExecutableActionModel( - ElasticInferenceServiceExecutableActionModel model, - ServiceSettings serviceSettings - ) { - super(model, serviceSettings); - } - - public abstract ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings); -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index a8b1cd8b11b7e..f038a81ad0ede 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceSparseEmbeddingsRequest; import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceSparseEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; @@ -39,10 +40,9 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast private static final ResponseHandler HANDLER = createSparseEmbeddingsHandler(); private final ElasticInferenceServiceSparseEmbeddingsModel model; - private final Truncator truncator; - private final TraceContext traceContext; + private final CCMAuthenticationApplierFactory.AuthApplier authApplier; private static ResponseHandler createSparseEmbeddingsHandler() { return new ElasticInferenceServiceResponseHandler( @@ -54,12 +54,14 @@ private static ResponseHandler createSparseEmbeddingsHandler() { public ElasticInferenceServiceSparseEmbeddingsRequestManager( ElasticInferenceServiceSparseEmbeddingsModel model, ServiceComponents serviceComponents, - TraceContext traceContext + TraceContext traceContext, + CCMAuthenticationApplierFactory.AuthApplier authApplier ) { super(serviceComponents.threadPool(), model); this.model = model; this.truncator = serviceComponents.truncator(); this.traceContext = traceContext; + this.authApplier = authApplier; } @Override @@ -81,7 +83,8 @@ public void execute( model, traceContext, requestMetadata(), - inputType + inputType, + authApplier ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedCompletionRequestManager.java index f9d4c96ae32d0..c8dfc0b938232 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedCompletionRequestManager.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; @@ -34,26 +35,31 @@ public class ElasticInferenceServiceUnifiedCompletionRequestManager extends Elas public static ElasticInferenceServiceUnifiedCompletionRequestManager of( ElasticInferenceServiceCompletionModel model, ThreadPool threadPool, - TraceContext traceContext + TraceContext traceContext, + CCMAuthenticationApplierFactory.AuthApplier authApplier ) { return new ElasticInferenceServiceUnifiedCompletionRequestManager( Objects.requireNonNull(model), Objects.requireNonNull(threadPool), - Objects.requireNonNull(traceContext) + Objects.requireNonNull(traceContext), + Objects.requireNonNull(authApplier) ); } private final ElasticInferenceServiceCompletionModel model; private final TraceContext traceContext; + private final CCMAuthenticationApplierFactory.AuthApplier authApplier; private ElasticInferenceServiceUnifiedCompletionRequestManager( ElasticInferenceServiceCompletionModel model, ThreadPool threadPool, - TraceContext traceContext + TraceContext traceContext, + CCMAuthenticationApplierFactory.AuthApplier authApplier ) { super(threadPool, model); this.model = model; this.traceContext = traceContext; + this.authApplier = authApplier; } @Override @@ -68,7 +74,8 @@ public void execute( inferenceInputs.castTo(UnifiedChatInput.class), model, traceContext, - requestMetadata() + requestMetadata(), + authApplier ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index e980d7f713495..6f32eae1462cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -7,105 +7,48 @@ package org.elasticsearch.xpack.inference.services.elastic.action; -import org.elasticsearch.common.Strings; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; -import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; -import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager; -import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest; -import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; -import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext; -public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor { - - static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler( - "elastic dense text embedding", - ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse - ); - - static final ResponseHandler RERANK_HANDLER = new ElasticInferenceServiceResponseHandler( - "elastic rerank", - (request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response) - ); +public class ElasticInferenceServiceActionCreator { private final Sender sender; - private final ServiceComponents serviceComponents; + private final CCMAuthenticationApplierFactory ccmAuthenticationApplierFactory; - private final TraceContext traceContext; - - public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) { + public ElasticInferenceServiceActionCreator( + Sender sender, + ServiceComponents serviceComponents, + CCMAuthenticationApplierFactory ccmAuthenticationApplierFactory + ) { this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.traceContext = traceContext; - } - - @Override - public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) { - var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext); - var errorMessage = constructFailedToSendRequestMessage( - Strings.format("%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER) - ); - return new SenderExecutableAction(sender, requestManager, errorMessage); - } - - @Override - public ExecutableAction create(ElasticInferenceServiceRerankModel model) { - var threadPool = serviceComponents.threadPool(); - var requestManager = new GenericRequestManager<>( - threadPool, - model, - RERANK_HANDLER, - (rerankInput) -> new ElasticInferenceServiceRerankRequest( - rerankInput.getQuery(), - rerankInput.getChunks(), - rerankInput.getTopN(), - model, - traceContext, - extractRequestMetadataFromThreadContext(threadPool.getThreadContext()) - ), - QueryAndDocsInputs.class - ); - var errorMessage = constructFailedToSendRequestMessage(Strings.format("%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)); - return new SenderExecutableAction(sender, requestManager, errorMessage); + this.ccmAuthenticationApplierFactory = Objects.requireNonNull(ccmAuthenticationApplierFactory); } - @Override - public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model) { - var threadPool = serviceComponents.threadPool(); - - var manager = new GenericRequestManager<>( - threadPool, - model, - DENSE_TEXT_EMBEDDINGS_HANDLER, - (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( - model, - embeddingsInput.getTextInputs(), - traceContext, - extractRequestMetadataFromThreadContext(threadPool.getThreadContext()), - embeddingsInput.getInputType() - ), - EmbeddingsInput.class - ); - - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings"); - return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage); + public void create( + T model, + TraceContext traceContext, + ActionListener listener + ) { + var authListener = listener.delegateFailureAndWrap((delegate, applier) -> { + var strategy = ModelStrategyFactory.getStrategy(model); + var requestManager = strategy.createRequestManager(model, serviceComponents, traceContext, applier); + delegate.onResponse( + new SenderExecutableAction(sender, requestManager, constructFailedToSendRequestMessage(strategy.requestDescription())) + ); + }); + + ccmAuthenticationApplierFactory.getAuthenticationApplier(authListener); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java deleted file mode 100644 index 4f8a9c9ec20a4..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.action; - -import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; - -public interface ElasticInferenceServiceActionVisitor { - - ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model); - - ExecutableAction create(ElasticInferenceServiceRerankModel model); - - ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model); -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ModelStrategyFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ModelStrategyFactory.java new file mode 100644 index 0000000000000..7e13384d9ad8b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ModelStrategyFactory.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.action; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUnifiedCompletionRequestManager; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; +import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext; + +record ModelStrategyFactory(ServiceComponents serviceComponents) { + + public interface Strategy { + RequestManager createRequestManager( + T model, + ServiceComponents serviceComponents, + TraceContext traceContext, + CCMAuthenticationApplierFactory.AuthApplier authApplier + ); + + String requestDescription(); + } + + private static final String SPARSE_EMBEDDINGS_REQUEST_DESCRIPTION = Strings.format( + "%s sparse embeddings", + ELASTIC_INFERENCE_SERVICE_IDENTIFIER + ); + + private static final Strategy SPARSE_EMBEDDINGS_STRATEGY = new Strategy<>() { + @Override + public RequestManager createRequestManager( + ElasticInferenceServiceSparseEmbeddingsModel model, + ServiceComponents serviceComponents, + TraceContext traceContext, + CCMAuthenticationApplierFactory.AuthApplier authApplier + ) { + return new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext, authApplier); + } + + @Override + public String requestDescription() { + return SPARSE_EMBEDDINGS_REQUEST_DESCRIPTION; + } + }; + + private static final String RERANK_REQUEST_DESCRIPTION = Strings.format("%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER); + + private static final ResponseHandler RERANK_HANDLER = new ElasticInferenceServiceResponseHandler( + RERANK_REQUEST_DESCRIPTION, + (request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response) + ); + + private static final Strategy RERANK_STRATEGY = new Strategy<>() { + @Override + public RequestManager createRequestManager( + ElasticInferenceServiceRerankModel model, + ServiceComponents serviceComponents, + TraceContext traceContext, + CCMAuthenticationApplierFactory.AuthApplier authApplier + ) { + var metadata = extractRequestMetadataFromThreadContext(serviceComponents.threadPool().getThreadContext()); + return new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + RERANK_HANDLER, + (rerankInput) -> new ElasticInferenceServiceRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getTopN(), + model, + traceContext, + metadata, + authApplier + ), + QueryAndDocsInputs.class + ); + } + + @Override + public String requestDescription() { + return RERANK_REQUEST_DESCRIPTION; + } + }; + + private static final String DENSE_TEXT_EMBEDDINGS_REQUEST_DESCRIPTION = Strings.format( + "%s dense text embeddings", + ELASTIC_INFERENCE_SERVICE_IDENTIFIER + ); + + private static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler( + DENSE_TEXT_EMBEDDINGS_REQUEST_DESCRIPTION, + ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse + ); + + private static final Strategy EMBEDDING_STRATEGY = new Strategy<>() { + @Override + public RequestManager createRequestManager( + ElasticInferenceServiceDenseTextEmbeddingsModel model, + ServiceComponents serviceComponents, + TraceContext traceContext, + CCMAuthenticationApplierFactory.AuthApplier authApplier + ) { + var metadata = extractRequestMetadataFromThreadContext(serviceComponents.threadPool().getThreadContext()); + return new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + DENSE_TEXT_EMBEDDINGS_HANDLER, + (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( + model, + embeddingsInput.getTextInputs(), + traceContext, + metadata, + embeddingsInput.getInputType(), + authApplier + ), + EmbeddingsInput.class + ); + } + + @Override + public String requestDescription() { + return DENSE_TEXT_EMBEDDINGS_REQUEST_DESCRIPTION; + } + }; + + private static final String CHAT_COMPLETIONS_REQUEST_DESCRIPTION = Strings.format( + "%s chat completions", + ELASTIC_INFERENCE_SERVICE_IDENTIFIER + ); + + private static final Strategy CHAT_COMPLETIONS_STRATEGY = new Strategy<>() { + @Override + public RequestManager createRequestManager( + ElasticInferenceServiceCompletionModel model, + ServiceComponents serviceComponents, + TraceContext traceContext, + CCMAuthenticationApplierFactory.AuthApplier authApplier + ) { + return ElasticInferenceServiceUnifiedCompletionRequestManager.of( + model, + serviceComponents.threadPool(), + traceContext, + authApplier + ); + } + + @Override + public String requestDescription() { + return CHAT_COMPLETIONS_REQUEST_DESCRIPTION; + } + }; + + @SuppressWarnings("unchecked") + public static Strategy getStrategy(T model) { + return switch (model) { + case ElasticInferenceServiceSparseEmbeddingsModel ignored -> (Strategy) SPARSE_EMBEDDINGS_STRATEGY; + case ElasticInferenceServiceRerankModel ignored -> (Strategy) RERANK_STRATEGY; + case ElasticInferenceServiceDenseTextEmbeddingsModel ignored -> (Strategy) EMBEDDING_STRATEGY; + case ElasticInferenceServiceCompletionModel ignored -> (Strategy) CHAT_COMPLETIONS_STRATEGY; + default -> throw new IllegalArgumentException("No strategy found for model type: " + model.getClass().getSimpleName()); + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index 70572e2b9c996..3fe4b306a5281 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -29,6 +29,8 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService; import java.util.EnumSet; import java.util.List; @@ -39,6 +41,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; @@ -62,6 +65,8 @@ public class AuthorizationPoller extends AllocatedPersistentTask { private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; private final Client client; private final CountDownLatch receivedFirstAuthResponseLatch = new CountDownLatch(1); + private final CCMFeature ccmFeature; + private final CCMService ccmService; public record TaskFields(long id, String type, String action, String description, TaskId parentTask, Map headers) {} @@ -71,7 +76,9 @@ public record Parameters( Sender sender, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, - Client client + Client client, + CCMFeature ccmFeature, + CCMService ccmService ) {} public static AuthorizationPoller create(TaskFields taskFields, Parameters parameters) { @@ -87,6 +94,8 @@ private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { parameters.elasticInferenceServiceSettings, parameters.modelRegistry, parameters.client, + parameters.ccmFeature, + parameters.ccmService, null ); } @@ -100,6 +109,8 @@ private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, Client client, + CCMFeature ccmFeature, + CCMService ccmService, // this is a hack to facilitate testing Runnable callback ) { @@ -113,6 +124,8 @@ private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { ); this.modelRegistry = Objects.requireNonNull(modelRegistry); this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN); + this.ccmFeature = Objects.requireNonNull(ccmFeature); + this.ccmService = Objects.requireNonNull(ccmService); this.callback = callback; } @@ -149,18 +162,24 @@ protected void init( @Override protected void onCancelled() { - shutdown(); - markAsCompleted(); + shutdownInternal(this::markAsCompleted); } private void shutdownAndMarkTaskAsFailed(Exception e) { - shutdown(); - markAsFailed(e); + shutdownInternal(() -> markAsFailed(e)); } // default for testing void shutdown() { - shutdown.set(true); + shutdownInternal(() -> {}); + } + + private void shutdownInternal(Runnable completionRunnable) { + if (shutdown.compareAndSet(false, true)) { + // Marking a task as completed and then failed (or vice versa) results in an exception, + // so we need to ensure only one is called. + completionRunnable.run(); + } var authTask = lastAuthTask.get(); if (authTask != null) { @@ -219,10 +238,6 @@ private void scheduleAndSendAuthorizationRequest() { // default for testing void sendAuthorizationRequest() { - if (modelRegistry.isReady() == false) { - return; - } - var finalListener = ActionListener.running(() -> { if (callback != null) { callback.run(); @@ -233,18 +248,85 @@ void sendAuthorizationRequest() { delegate.onResponse(null); }); + shouldSendAuthRequest(ActionListener.wrap(action -> action.accept(finalListener), e -> { + logger.atWarn().withThrowable(e).log("Failed determining whether to send authorization request"); + finalListener.onFailure(e); + })); + } + + private class ShutdownAction implements Consumer> { + @Override + public void accept(ActionListener listener) { + logger.info("Skipping sending authorization request and completing task, because poller is shutting down"); + // We should already be shutdown, so this should just be a noop + shutdownInternal(AuthorizationPoller.this::markAsCompleted); + listener.onResponse(null); + } + } + + private record RegistryNotReadyAction() implements Consumer> { + @Override + public void accept(ActionListener listener) { + logger.info("Skipping sending authorization request, because model registry is not ready"); + listener.onResponse(null); + } + } + + private class SendAuthRequestAction implements Consumer> { + @Override + public void accept(ActionListener listener) { + sendRequest(listener); + } + } + + private class CCMDisabledAction implements Consumer> { + @Override + public void accept(ActionListener listener) { + logger.info("Skipping sending authorization request and completing task, because CCM is not enabled"); + shutdownInternal(AuthorizationPoller.this::markAsCompleted); + listener.onResponse(null); + } + } + + private void shouldSendAuthRequest(ActionListener>> listener) { + if (shutdown.get()) { + listener.onResponse(new ShutdownAction()); + return; + } + if (modelRegistry.isReady() == false) { + listener.onResponse(new RegistryNotReadyAction()); + return; + } + if (ccmFeature.isCcmSupportedEnvironment() == false) { + listener.onResponse(new SendAuthRequestAction()); + return; + } + + ccmService.isEnabled(listener.delegateFailureAndWrap((delegate, enabled) -> { + if (enabled == null || enabled == false) { + delegate.onResponse(new CCMDisabledAction()); + return; + } + delegate.onResponse(new SendAuthRequestAction()); + })); + } + + private void sendRequest(ActionListener listener) { SubscribableListener.newForked( authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender) ) .andThenApply(this::getNewInferenceEndpointsToStore) .andThen((storeListener, newInferenceIds) -> storePreconfiguredModels(newInferenceIds, storeListener)) - .addListener(finalListener); + .addListener(listener); } private Set getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) { + logger.debug("Received authorization response, {}", authModel); var scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); + logger.debug("Authorization entity limited to service task types, {}", scopedAuthModel); var authorizedModelIds = scopedAuthModel.getAuthorizedModelIds(); + logger.debug("Authorized model IDs from EIS: {}", authorizedModelIds); var existingInferenceIds = modelRegistry.getInferenceIds(); var newInferenceIds = authorizedModelIds.stream() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java index bca830eb7b948..a0442b35b290c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -10,14 +10,23 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.FixForMultiProject; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.persistent.ClusterPersistentTasksCustomMetadata; import org.elasticsearch.persistent.PersistentTaskParams; @@ -25,19 +34,31 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.transport.RemoteTransportException; +import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xpack.inference.common.BroadcastMessageAction; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME; +/** + * Handles creating a persistent task that will periodically poll the Elastic Inference Service for which models are authorized. + * A cluster state listener is run on each node to ensure that the persistent task is created. Only one task will exist within the cluster. + * The task will only be created if CCM cannot be configured or if CCM is configurable (for an on-prem cluster) and is enabled. + * When a user enables CCM the logic will immediately try to create the persistent task + * to avoid having to wait for the next cluster update. + */ public class AuthorizationTaskExecutor extends PersistentTasksExecutor implements ClusterStateListener { private static final Logger logger = LogManager.getLogger(AuthorizationTaskExecutor.class); @@ -46,18 +67,17 @@ public class AuthorizationTaskExecutor extends PersistentTasksExecutor currentTask = new AtomicReference<>(); + private final AtomicBoolean running = new AtomicBoolean(false); public static AuthorizationTaskExecutor create(ClusterService clusterService, AuthorizationPoller.Parameters parameters) { Objects.requireNonNull(clusterService); Objects.requireNonNull(parameters); - var executor = new AuthorizationTaskExecutor( + return new AuthorizationTaskExecutor( clusterService, new PersistentTasksService(clusterService, parameters.serviceComponents().threadPool(), parameters.client()), parameters ); - executor.init(); - return executor; } // default for testing @@ -72,14 +92,99 @@ public static AuthorizationTaskExecutor create(ClusterService clusterService, Au this.pollerParameters = Objects.requireNonNull(pollerParameters); } - // default for testing - void init() { + /** + * Starts the authorization task executor without starting the persistent task. The persistent task will be created + * when the next cluster state change event occurs. This is needed because + * we can't start the persistent task until after the plugin has finished initializing. Otherwise, we'll + * get an error indicating that it isn't aware of whether the task is a cluster scoped task. + */ + public synchronized void startAndLazyCreateTask() { + startInternal(false); + } + + /** + * Starts the authorization task executor and creates the persistent task if it doesn't already exist. This should only be called from + * a context where the cluster state is already initialized. Don't call this from the plugin + * {@link org.elasticsearch.xpack.inference.InferencePlugin#createComponents(Plugin.PluginServices)}. Use + * {@link #startAndLazyCreateTask()} instead. + */ + public synchronized void startAndImmediatelyCreateTask() { + startInternal(true); + } + + private void startInternal(boolean createPersistentTask) { + var eisUrl = pollerParameters.elasticInferenceServiceSettings().getElasticInferenceServiceUrl(); + + logger.info("Authorization task executor EIS URL: [{}]", eisUrl); + // If the EIS url is not configured, then we won't be able to interact with the service, so don't start the task. - if (Strings.isNullOrEmpty(pollerParameters.elasticInferenceServiceSettings().getElasticInferenceServiceUrl()) == false) { + if (Strings.isNullOrEmpty(eisUrl) == false && running.compareAndSet(false, true)) { + logger.info("Starting authorization task executor"); + + if (createPersistentTask) { + sendStartRequest(clusterService.state()); + } + clusterService.addListener(this); } } + private void sendStartRequest(@Nullable ClusterState state) { + if (running.get() == false || authorizationTaskExists(state)) { + return; + } + + logger.info("Creating authorization poller task"); + persistentTasksService.sendClusterStartRequest( + TASK_NAME, + TASK_NAME, + new AuthorizationTaskParams(), + TimeValue.THIRTY_SECONDS, + ActionListener.wrap( + persistentTask -> logger.info("Finished creating authorization poller task, id {}", persistentTask.getId()), + exception -> { + var thrownException = exception instanceof RemoteTransportException ? exception.getCause() : exception; + if (thrownException instanceof ResourceAlreadyExistsException == false) { + logger.error("Failed to create authorization poller task", exception); + } + } + ) + ); + } + + private static boolean authorizationTaskExists(@Nullable ClusterState state) { + if (state == null) { + return false; + } + + return ClusterPersistentTasksCustomMetadata.getTaskWithId(state, TASK_NAME) != null; + } + + public synchronized void stop() { + if (running.compareAndSet(true, false)) { + logger.info("Shutting down authorization task executor"); + clusterService.removeListener(this); + + sendStopRequest(); + } + } + + private void sendStopRequest() { + persistentTasksService.sendClusterRemoveRequest( + TASK_NAME, + TimeValue.THIRTY_SECONDS, + ActionListener.wrap( + persistentTask -> logger.info("Stopped authorization poller task, id {}", persistentTask.getId()), + exception -> { + var thrownException = exception instanceof RemoteTransportException ? exception.getCause() : exception; + if (thrownException instanceof ResourceNotFoundException == false) { + logger.error("Failed to stop authorization poller task", exception); + } + } + ) + ); + } + /** * This method should only be used for testing purposes to get the current running task. */ @@ -92,6 +197,7 @@ protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskPara var authPoller = (AuthorizationPoller) task; currentTask.set(authPoller); authPoller.start(); + logger.info("Started authorization poller task with id {}", task.getId()); } @FixForMultiProject( @@ -120,26 +226,7 @@ protected AuthorizationPoller createTask( @Override public void clusterChanged(ClusterChangedEvent event) { - if (authorizationTaskExists(event)) { - return; - } - - persistentTasksService.sendClusterStartRequest( - TASK_NAME, - TASK_NAME, - new AuthorizationTaskParams(), - TimeValue.THIRTY_SECONDS, - ActionListener.wrap(persistentTask -> logger.debug("Created authorization poller task"), exception -> { - var thrownException = exception instanceof RemoteTransportException ? exception.getCause() : exception; - if (thrownException instanceof ResourceAlreadyExistsException == false) { - logger.error("Failed to create authorization poller task", exception); - } - }) - ); - } - - private static boolean authorizationTaskExists(ClusterChangedEvent event) { - return ClusterPersistentTasksCustomMetadata.getTaskWithId(event.state(), TASK_NAME) != null; + sendStartRequest(event.state()); } public static List getNamedXContentParsers() { @@ -157,4 +244,50 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(PersistentTaskParams.class, AuthorizationPoller.TASK_NAME, AuthorizationTaskParams::new) ); } + + /** + * This action is used to broadcast to all the nodes that the authorization task executor should start or stop. + * This is specifically useful for CCM, since whether to do the polling depends on the CCM + * configuration to exist first. + */ + public static class Action extends BroadcastMessageAction { + public static final String NAME = "cluster:internal/xpack/inference/update_authorization_task"; + public static final ActionType INSTANCE = new ActionType<>(NAME); + + private final AuthorizationTaskExecutor authorizationTaskExecutor; + + @Inject + public Action( + TransportService transportService, + ClusterService clusterService, + ActionFilters actionFilters, + AuthorizationTaskExecutor authorizationTaskExecutor + ) { + super(NAME, clusterService, transportService, actionFilters, Message::new); + this.authorizationTaskExecutor = authorizationTaskExecutor; + } + + @Override + protected void receiveMessage(Message message) { + if (message.enable()) { + authorizationTaskExecutor.startAndImmediatelyCreateTask(); + } else { + authorizationTaskExecutor.stop(); + } + } + } + + public record Message(boolean enable) implements Writeable { + public static final Message ENABLE_MESSAGE = new Message(true); + public static final Message DISABLE_MESSAGE = new Message(false); + + public Message(StreamInput in) throws IOException { + this(in.readBoolean()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(enable); + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 02800105ef83d..a86866c03fb2c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -54,18 +55,32 @@ private static ResponseHandler createAuthResponseHandler() { private final ThreadPool threadPool; private final Logger logger; private final CountDownLatch requestCompleteLatch = new CountDownLatch(1); - - public ElasticInferenceServiceAuthorizationRequestHandler(@Nullable String baseUrl, ThreadPool threadPool) { - this.baseUrl = baseUrl; - this.threadPool = Objects.requireNonNull(threadPool); - logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationRequestHandler.class); + private CCMAuthenticationApplierFactory authFactory; + + public ElasticInferenceServiceAuthorizationRequestHandler( + @Nullable String baseUrl, + ThreadPool threadPool, + CCMAuthenticationApplierFactory authFactory + ) { + this( + baseUrl, + Objects.requireNonNull(threadPool), + LogManager.getLogger(ElasticInferenceServiceAuthorizationRequestHandler.class), + authFactory + ); } // only use for testing - ElasticInferenceServiceAuthorizationRequestHandler(@Nullable String baseUrl, ThreadPool threadPool, Logger logger) { + ElasticInferenceServiceAuthorizationRequestHandler( + @Nullable String baseUrl, + ThreadPool threadPool, + Logger logger, + CCMAuthenticationApplierFactory authFactory + ) { this.baseUrl = baseUrl; this.threadPool = Objects.requireNonNull(threadPool); this.logger = Objects.requireNonNull(logger); + this.authFactory = Objects.requireNonNull(authFactory); } /** @@ -91,25 +106,34 @@ public void getAuthorization(ActionListenerandThen((authListener) -> { - var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); - var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata); - sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); - }).andThenApply(authResult -> { - if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { - logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); - return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity); - } - - var errorMessage = Strings.format( - "%s Received an invalid response type from the Elastic Inference Service: %s", - FAILED_TO_RETRIEVE_MESSAGE, - authResult.getClass().getSimpleName() - ); - - logger.warn(errorMessage); - throw new ElasticsearchException(errorMessage); - }).addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown)); + SubscribableListener.newForked(sender::startAsynchronously) + .andThen(authFactory::getAuthenticationApplier) + .andThen((authListener, authApplier) -> { + var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); + var request = new ElasticInferenceServiceAuthorizationRequest( + baseUrl, + getCurrentTraceInfo(), + requestMetadata, + authApplier + ); + sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); + }) + .andThenApply(authResult -> { + if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { + logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); + return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity); + } + + var errorMessage = Strings.format( + "%s Received an invalid response type from the Elastic Inference Service: %s", + FAILED_TO_RETRIEVE_MESSAGE, + authResult.getClass().getSimpleName() + ); + + logger.warn(errorMessage); + throw new ElasticsearchException(errorMessage); + }) + .addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown)); } catch (Exception e) { logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e)); requestCompleteLatch.countDown(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMAuthenticationApplierFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMAuthenticationApplierFactory.java new file mode 100644 index 0000000000000..607167ad8878b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMAuthenticationApplierFactory.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.ccm; + +import org.apache.http.client.methods.HttpRequestBase; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.rest.RestStatus; + +import java.util.Objects; +import java.util.function.Function; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; +import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_CCM_PATH; + +/** + * Returns a class to handle modifying the HTTP requests with the appropriate CCM authentication information if CCM is configured. + */ +public class CCMAuthenticationApplierFactory { + + public static final NoopApplier NOOP_APPLIER = new NoopApplier(); + + private final CCMFeature ccmFeature; + private final CCMService ccmService; + + public CCMAuthenticationApplierFactory(CCMFeature ccmFeature, CCMService ccmService) { + this.ccmFeature = Objects.requireNonNull(ccmFeature); + this.ccmService = Objects.requireNonNull(ccmService); + } + + public interface AuthApplier extends Function {} + + public void getAuthenticationApplier(ActionListener listener) { + if (ccmFeature.isCcmSupportedEnvironment() == false) { + listener.onResponse(NOOP_APPLIER); + return; + } + + SubscribableListener.newForked(ccmService::isEnabled).andThen((ccmModelListener, enabled) -> { + if (enabled == null || enabled == false) { + listener.onFailure( + new ElasticsearchStatusException( + "Cloud connected mode is not configured, please configure it using PUT {} " + + "before accessing the Elastic Inference Service.", + RestStatus.BAD_REQUEST, + INFERENCE_CCM_PATH + ) + ); + return; + } + + ccmService.getConfiguration(ccmModelListener); + }).andThenApply(ccmModel -> new AuthenticationHeaderApplier(ccmModel.apiKey())).addListener(listener); + } + + /** + * If CCM is configured and enabled this class will apply the appropriate authentication header to the request. + */ + public record AuthenticationHeaderApplier(SecureString apiKey) implements AuthApplier { + public AuthenticationHeaderApplier(String apiKey) { + this(new SecureString(Objects.requireNonNull(apiKey).toCharArray())); + } + + @Override + public HttpRequestBase apply(HttpRequestBase request) { + request.setHeader(createAuthBearerHeader(apiKey)); + return request; + } + } + + /** + * If CCM is not configured this class will not modify the request because no authentication is necessary since mTLS certs are used. + */ + public record NoopApplier() implements AuthApplier { + @Override + public HttpRequestBase apply(HttpRequestBase request) { + return request; + } + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMFeature.java index cf9256da9dfc8..17f63478d7322 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMFeature.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMFeature.java @@ -17,13 +17,13 @@ public class CCMFeature { RestStatus.FORBIDDEN ); - private final boolean allowConfiguringCcm; + private final boolean isCcmSupportedEnvironment; public CCMFeature(Settings settings) { - allowConfiguringCcm = CCMSettings.ALLOW_CONFIGURING_CCM.get(settings); + isCcmSupportedEnvironment = CCMSettings.CCM_SUPPORTED_ENVIRONMENT.get(settings); } - public boolean allowConfiguringCcm() { - return allowConfiguringCcm && CCMFeatureFlag.FEATURE_FLAG.isEnabled(); + public boolean isCcmSupportedEnvironment() { + return isCcmSupportedEnvironment && CCMFeatureFlag.FEATURE_FLAG.isEnabled(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMInformedSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMInformedSettings.java new file mode 100644 index 0000000000000..9b095a64e29bd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMInformedSettings.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.ccm; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; + +import java.util.Objects; + +/** + * Wraps the settings to default to a static URL if the user hasn't already set one using the + * {@link ElasticInferenceServiceSettings#ELASTIC_INFERENCE_SERVICE_URL} setting. + */ +public class CCMInformedSettings extends ElasticInferenceServiceSettings { + static final String DEFAULT_CCM_URL = "https://inference.us-east-1.aws.svc.elastic.cloud"; + + private final CCMFeature ccmFeature; + + public CCMInformedSettings(Settings settings, CCMFeature ccmFeature) { + super(settings); + this.ccmFeature = Objects.requireNonNull(ccmFeature); + } + + @Override + public String getElasticInferenceServiceUrl() { + String urlFromSettings = super.getElasticInferenceServiceUrl(); + if (ccmFeature.isCcmSupportedEnvironment() == false || Strings.isNullOrEmpty(urlFromSettings) == false) { + return urlFromSettings; + } + + return DEFAULT_CCM_URL; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMService.java index 1d3f7445682b9..028f0f693146e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMService.java @@ -7,18 +7,28 @@ package org.elasticsearch.xpack.inference.services.elastic.ccm; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import java.util.Objects; public class CCMService { + private static final Logger logger = LogManager.getLogger(CCMService.class); + private final CCMPersistentStorageService ccmPersistentStorageService; + private final Client client; - public CCMService(CCMPersistentStorageService ccmPersistentStorageService) { + public CCMService(CCMPersistentStorageService ccmPersistentStorageService, Client client) { this.ccmPersistentStorageService = Objects.requireNonNull(ccmPersistentStorageService); - // TODO initialize class to handle storing whether CCM is enabled + this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN); // TODO initialize the cache for the CCM configuration } @@ -37,8 +47,23 @@ public void isEnabled(ActionListener listener) { } public void storeConfiguration(CCMModel model, ActionListener listener) { + SubscribableListener.newForked(storeListener -> ccmPersistentStorageService.store(model, storeListener)) + .andThen( + enableAuthExecutorListener -> client.execute( + AuthorizationTaskExecutor.Action.INSTANCE, + AuthorizationTaskExecutor.Action.request(AuthorizationTaskExecutor.Message.ENABLE_MESSAGE, null), + ActionListener.wrap(ack -> { + logger.debug("Successfully enabled authorization task executor"); + enableAuthExecutorListener.onResponse(null); + }, e -> { + logger.atDebug().withThrowable(e).log("Failed to enable authorization task executor"); + enableAuthExecutorListener.onFailure(e); + }) + ) + ) + .addListener(listener); + // TODO invalidate the cache - ccmPersistentStorageService.store(model, listener); } public void getConfiguration(ActionListener listener) { @@ -47,7 +72,20 @@ public void getConfiguration(ActionListener listener) { } public void disableCCM(ActionListener listener) { + SubscribableListener.newForked( + disableAuthExecutorListener -> client.execute( + AuthorizationTaskExecutor.Action.INSTANCE, + AuthorizationTaskExecutor.Action.request(AuthorizationTaskExecutor.Message.DISABLE_MESSAGE, null), + ActionListener.wrap(ack -> { + logger.debug("Successfully disabled authorization task executor"); + disableAuthExecutorListener.onResponse(null); + }, e -> { + logger.atDebug().withThrowable(e).log("Failed to disable authorization task executor"); + disableAuthExecutorListener.onFailure(e); + }) + ) + ).andThen(ccmPersistentStorageService::delete).addListener(listener); + // TODO implement invalidating the cache - ccmPersistentStorageService.delete(listener); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMSettings.java index 0daf39edaa1f5..c029b342d8e2c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMSettings.java @@ -12,13 +12,13 @@ import java.util.List; public class CCMSettings { - public static final Setting ALLOW_CONFIGURING_CCM = Setting.boolSetting( - "xpack.inference.elastic.allow_configuring_ccm", + public static final Setting CCM_SUPPORTED_ENVIRONMENT = Setting.boolSetting( + "xpack.inference.elastic.ccm_supported_environment", true, Setting.Property.NodeScope ); public static List> getSettingsDefinitions() { - return List.of(ALLOW_CONFIGURING_CCM); + return List.of(CCM_SUPPORTED_ENVIRONMENT); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java index dfbfaf47e2d2e..0b6295eda939e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java @@ -18,17 +18,15 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel; -import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; -public class ElasticInferenceServiceDenseTextEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel { +public class ElasticInferenceServiceDenseTextEmbeddingsModel extends ElasticInferenceServiceModel { private final URI uri; @@ -82,11 +80,6 @@ public ElasticInferenceServiceDenseTextEmbeddingsModel( this.uri = createUri(); } - @Override - public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) { - return visitor.create(this); - } - @Override public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings getServiceSettings() { return (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) super.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequest.java index 42176b0a67515..38b5ac438b2af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequest.java @@ -13,6 +13,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; @@ -28,9 +29,10 @@ public class ElasticInferenceServiceAuthorizationRequest extends ElasticInferenc public ElasticInferenceServiceAuthorizationRequest( String url, TraceContext traceContext, - ElasticInferenceServiceRequestMetadata requestMetadata + ElasticInferenceServiceRequestMetadata requestMetadata, + CCMAuthenticationApplierFactory.AuthApplier authApplier ) { - super(requestMetadata); + super(requestMetadata, authApplier); this.uri = createUri(Objects.requireNonNull(url)); this.traceContextHandler = new TraceContextHandler(traceContext); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java index c5d525822fc72..9899d7a009197 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java @@ -17,6 +17,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; @@ -41,9 +42,10 @@ public ElasticInferenceServiceDenseTextEmbeddingsRequest( List inputs, TraceContext traceContext, ElasticInferenceServiceRequestMetadata metadata, - InputType inputType + InputType inputType, + CCMAuthenticationApplierFactory.AuthApplier authApplier ) { - super(metadata); + super(metadata, authApplier); this.inputs = inputs; this.model = Objects.requireNonNull(model); this.uri = model.uri(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequest.java index cea0378aba414..9cef0d84f7da9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequest.java @@ -14,6 +14,9 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; + +import java.util.Objects; import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_ES_VERSION; import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; @@ -21,9 +24,14 @@ public abstract class ElasticInferenceServiceRequest implements Request { private final ElasticInferenceServiceRequestMetadata metadata; + protected final CCMAuthenticationApplierFactory.AuthApplier authApplier; - public ElasticInferenceServiceRequest(ElasticInferenceServiceRequestMetadata metadata) { - this.metadata = metadata; + public ElasticInferenceServiceRequest( + ElasticInferenceServiceRequestMetadata metadata, + CCMAuthenticationApplierFactory.AuthApplier authApplier + ) { + this.metadata = Objects.requireNonNull(metadata); + this.authApplier = Objects.requireNonNull(authApplier); } public ElasticInferenceServiceRequestMetadata getMetadata() { @@ -51,6 +59,8 @@ public final HttpRequest createHttpRequest() { request.addHeader(X_ELASTIC_ES_VERSION, esVersion); } + request = authApplier.apply(request); + return new HttpRequest(request, getInferenceEntityId()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java index 63b26f2a1223b..05ee26f568d63 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; @@ -38,9 +39,10 @@ public ElasticInferenceServiceRerankRequest( Integer topN, ElasticInferenceServiceRerankModel model, TraceContext traceContext, - ElasticInferenceServiceRequestMetadata metadata + ElasticInferenceServiceRequestMetadata metadata, + CCMAuthenticationApplierFactory.AuthApplier authApplier ) { - super(metadata); + super(metadata, authApplier); this.query = query; this.documents = documents; this.topN = topN; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java index fa6e9ef5ae935..fff7b7f9dbd3e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; @@ -43,9 +44,10 @@ public ElasticInferenceServiceSparseEmbeddingsRequest( ElasticInferenceServiceSparseEmbeddingsModel model, TraceContext traceContext, ElasticInferenceServiceRequestMetadata metadata, - InputType inputType + InputType inputType, + CCMAuthenticationApplierFactory.AuthApplier authApplier ) { - super(metadata); + super(metadata, authApplier); this.truncator = truncator; this.truncationResult = truncationResult; this.model = Objects.requireNonNull(model); @@ -99,7 +101,8 @@ public Request truncate() { model, traceContextHandler.traceContext(), getMetadata(), - inputType + inputType, + authApplier ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java index 2f41285a3345b..20388bee77957 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java @@ -16,6 +16,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; @@ -34,9 +35,10 @@ public ElasticInferenceServiceUnifiedChatCompletionRequest( UnifiedChatInput unifiedChatInput, ElasticInferenceServiceCompletionModel model, TraceContext traceContext, - ElasticInferenceServiceRequestMetadata requestMetadata + ElasticInferenceServiceRequestMetadata requestMetadata, + CCMAuthenticationApplierFactory.AuthApplier authApplier ) { - super(requestMetadata); + super(requestMetadata, authApplier); this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.model = Objects.requireNonNull(model); this.traceContextHandler = new TraceContextHandler(traceContext); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java index 38e71d74b1716..2bb28f5a74617 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java @@ -17,17 +17,15 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel; -import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; -public class ElasticInferenceServiceRerankModel extends ElasticInferenceServiceExecutableActionModel { +public class ElasticInferenceServiceRerankModel extends ElasticInferenceServiceModel { private final URI uri; @@ -70,11 +68,6 @@ public ElasticInferenceServiceRerankModel( this.uri = createUri(); } - @Override - public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) { - return visitor.create(this); - } - @Override public ElasticInferenceServiceRerankServiceSettings getServiceSettings() { return (ElasticInferenceServiceRerankServiceSettings) super.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java index 4dead9850a423..26143e5b1b760 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java @@ -18,17 +18,15 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel; -import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; -public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel { +public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceModel { private final URI uri; @@ -82,11 +80,6 @@ public ElasticInferenceServiceSparseEmbeddingsModel( this.uri = createUri(); } - @Override - public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) { - return visitor.create(this); - } - @Override public ElasticInferenceServiceSparseEmbeddingsServiceSettings getServiceSettings() { return (ElasticInferenceServiceSparseEmbeddingsServiceSettings) super.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index a5ded520b69ed..11a635cf5f41a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -317,7 +318,8 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce var request = new ElasticInferenceServiceAuthorizationRequest( getUrl(webServer), new TraceContext("", ""), - randomElasticInferenceServiceRequestMetadata() + randomElasticInferenceServiceRequestMetadata(), + CCMAuthenticationApplierFactory.NOOP_APPLIER ); var responseHandler = new ElasticInferenceServiceResponseHandler( String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/RequestUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/RequestUtilsTests.java index b6690373ec097..2e72cc7a9c61e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/RequestUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/RequestUtilsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.test.ESTestCase; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static org.hamcrest.Matchers.is; @@ -20,4 +21,8 @@ public void testCreateAuthBearerHeader() { assertThat(header.getName(), is("Authorization")); assertThat(header.getValue(), is("Bearer abc")); } + + public void testBearerToken() { + assertThat(bearerToken("abc"), is("Bearer abc")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java index 58a357684961c..905910ffbf267 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.inference.external.request.elastic; +import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -18,6 +20,7 @@ import java.util.List; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; @@ -67,6 +70,38 @@ public void testTruncate_DoesNotTruncate() throws IOException { assertThat(requestMap.get("top_n"), is(topN)); } + public void testDecorate_HttpRequest_WithAuthorizationHeader() { + var url = "http://eis-gateway.com"; + var query = "query"; + var documents = List.of("document 1", "document 2", "document 3"); + var modelId = "my-model-id"; + var topN = 3; + var secret = "secret"; + + var request = new ElasticInferenceServiceRerankRequest( + query, + documents, + topN, + ElasticInferenceServiceRerankModelTests.createModel(url, modelId), + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + randomElasticInferenceServiceRequestMetadata(), + new CCMAuthenticationApplierFactory.AuthenticationHeaderApplier(secret) + ); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var traceParent = request.getTraceContext().traceParent(); + var traceState = request.getTraceContext().traceState(); + + assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent)); + assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState)); + var headers = httpPost.getHeaders(HttpHeaders.AUTHORIZATION); + assertThat(headers.length, is(1)); + assertThat(headers[0].getValue(), is(bearerToken(secret))); + } + private ElasticInferenceServiceRerankRequest createRequest( String url, String modelId, @@ -82,7 +117,8 @@ private ElasticInferenceServiceRerankRequest createRequest( topN, rerankModel, new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomElasticInferenceServiceRequestMetadata() + randomElasticInferenceServiceRequestMetadata(), + CCMAuthenticationApplierFactory.NOOP_APPLIER ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 4b17cab04471a..5e70fe01eee97 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -86,6 +86,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createNoopApplierFactory; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.contains; @@ -1162,12 +1163,15 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - return new ElasticInferenceService( + var service = new ElasticInferenceService( factory, createWithEmptySettings(threadPool), new ElasticInferenceServiceSettings(Settings.EMPTY), - mockClusterServiceEmpty() + mockClusterServiceEmpty(), + createNoopApplierFactory() ); + service.init(); + return service; } private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory) { @@ -1175,11 +1179,14 @@ private ElasticInferenceService createService(HttpRequestSender.Factory senderFa } private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory, String elasticInferenceServiceURL) { - return new ElasticInferenceService( + var service = new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), - mockClusterServiceEmpty() + mockClusterServiceEmpty(), + createNoopApplierFactory() ); + service.init(); + return service; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index ab715b7f73e96..6a0d0d6ab8e95 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -10,11 +10,13 @@ import org.apache.http.HttpHeaders; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.TestPlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -23,12 +25,16 @@ import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -46,7 +52,10 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createApplierFactory; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createNoopApplierFactory; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -97,8 +106,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( @@ -120,9 +128,74 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce ) ); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertHeadersWithoutAuth(webServer.requests()); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("hello world")); + assertThat(requestMap.get("model"), is("my-model-id")); + } + } + + private ExecutableAction createAction(Sender sender, ElasticInferenceServiceModel model) { + return createAction(sender, model, createNoopApplierFactory()); + } + + private ExecutableAction createAction(Sender sender, ElasticInferenceServiceModel model, CCMAuthenticationApplierFactory factory) { + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), factory); + var actionCreatorListener = new TestPlainActionFuture(); + actionCreator.create(model, createTraceContext(), actionCreatorListener); + + return actionCreatorListener.actionGet(TIMEOUT); + } + + @SuppressWarnings("unchecked") + public void testExecute_SparseEmbedding_AddsAuthorizationHeader() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"); + var secret = "secret-token"; + var action = createAction(sender, model, createApplierFactory(secret)); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat( + result.asMap(), + is( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( + List.of( + new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hello", 2.1259406f, "greet", 1.7073475f), false) + ) + ) + ) + ); + + assertHeadersWithAuth(webServer.requests(), secret); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(2)); @@ -133,6 +206,25 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce } } + private static void assertHeadersWithAuth(List requests, String secret) { + var request = assertSingleRequestSent(requests); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo(bearerToken(secret))); + } + + private static MockRequest assertSingleRequestSent(List requests) { + assertThat(requests, hasSize(1)); + var request = requests.get(0); + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + return request; + } + + private static void assertHeadersWithoutAuth(List requests) { + var request = assertSingleRequestSent(requests); + assertNull(request.getHeader(HttpHeaders.AUTHORIZATION)); + } + @SuppressWarnings("unchecked") public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOException { // timeout as zero for no retries @@ -158,8 +250,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( @@ -217,8 +308,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOExc var documents = List.of("document 1", "document 2", "document 3"); var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), modelId); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new QueryAndDocsInputs(query, documents, null, topN, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); @@ -237,9 +327,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOExc ) ); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertHeadersWithoutAuth(webServer.requests()); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); @@ -259,6 +347,74 @@ public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOExc } } + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForRerankAction_WithAuthHeader() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "results": [ + { + "index": 0, + "relevance_score": 0.94 + }, + { + "index": 1, + "relevance_score": 0.21 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var modelId = "my-model-id"; + var topN = 3; + var query = "query"; + var documents = List.of("document 1", "document 2", "document 3"); + + var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), modelId); + var secret = "secret-token"; + var action = createAction(sender, model, createApplierFactory(secret)); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new QueryAndDocsInputs(query, documents, null, topN, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat( + result.asMap(), + equalTo( + RankedDocsResultsTests.buildExpectationRerank( + List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.94f)), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 0.21f)) + ) + ) + ) + ); + + assertHeadersWithAuth(webServer.requests(), secret); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + + assertThat(requestMap.size(), is(4)); + + assertThat(requestMap.get("documents"), instanceOf(List.class)); + List requestDocuments = (List) requestMap.get("documents"); + assertThat(requestDocuments.get(0), equalTo(documents.get(0))); + assertThat(requestDocuments.get(1), equalTo(documents.get(1))); + assertThat(requestDocuments.get(2), equalTo(documents.get(2))); + + assertThat(requestMap.get("top_n"), equalTo(topN)); + assertThat(requestMap.get("query"), equalTo(query)); + assertThat(requestMap.get("model"), equalTo(modelId)); + } + } + @SuppressWarnings("unchecked") public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -286,8 +442,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( @@ -308,9 +463,67 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() var secondEmbedding = textEmbeddingResults.embeddings().get(1); assertThat(secondEmbedding.values(), is(new float[] { 1.8342123f, 2.3456789f, 0.7654321f })); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertHeadersWithoutAuth(webServer.requests()); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("hello world", "second text")); + assertThat(requestMap.get("model"), is("my-dense-model-id")); + } + } + + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_WithAuth() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.startSynchronously(); + + String responseJson = """ + { + "data": [ + [ + 2.1259406, + 1.7073475, + 0.9020516 + ], + [ + 1.8342123, + 2.3456789, + 0.7654321 + ] + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); + var secret = "secret-token"; + var action = createAction(sender, model, createApplierFactory(secret)); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("hello world", "second text"), InputType.UNSPECIFIED), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result, instanceOf(DenseEmbeddingFloatResults.class)); + var textEmbeddingResults = (DenseEmbeddingFloatResults) result; + assertThat(textEmbeddingResults.embeddings(), hasSize(2)); + + var firstEmbedding = textEmbeddingResults.embeddings().get(0); + assertThat(firstEmbedding.values(), is(new float[] { 2.1259406f, 1.7073475f, 0.9020516f })); + + var secondEmbedding = textEmbeddingResults.embeddings().get(1); + assertThat(secondEmbedding.values(), is(new float[] { 1.8342123f, 2.3456789f, 0.7654321f })); + + assertHeadersWithAuth(webServer.requests(), secret); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(2)); @@ -342,8 +555,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( @@ -397,8 +609,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( @@ -439,8 +650,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new EmbeddingsInput(List.of(), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); @@ -487,8 +697,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( @@ -551,8 +760,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { // truncated to 1 token = 3 characters var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id", 1); - var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); - var action = actionCreator.create(model); + var action = createAction(sender, model); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index d0d3a67b2d9d5..75ca0b00a8e12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -37,9 +37,12 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureTests.createMockCCMFeature; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMServiceTests.createMockCCMService; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -69,12 +72,105 @@ public void testDoesNotSendAuthorizationRequest_WhenModelRegistryIsNotReady() { ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false), null ); + var persistentTaskId = "id"; + var allocationId = 0L; + + var mockPersistentTasksService = mock(PersistentTasksService.class); + poller.init(mockPersistentTasksService, mock(TaskManager.class), persistentTaskId, allocationId); + poller.sendAuthorizationRequest(); verify(authorizationRequestHandler, never()).getAuthorization(any(), any()); + verify(mockPersistentTasksService, never()).sendCompletionRequest( + eq(persistentTaskId), + eq(allocationId), + isNull(), + isNull(), + any(), + any() + ); + } + + public void testDoesNotSendAuthorizationRequest_WhenCCMIsDisabled() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + + var authorizationRequestHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + authorizationRequestHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mock(Client.class), + createMockCCMFeature(true), + createMockCCMService(false), + null + ); + + var persistentTaskId = "id"; + var allocationId = 0L; + + var mockPersistentTasksService = mock(PersistentTasksService.class); + poller.init(mockPersistentTasksService, mock(TaskManager.class), persistentTaskId, allocationId); + + poller.sendAuthorizationRequest(); + + verify(authorizationRequestHandler, never()).getAuthorization(any(), any()); + verify(mockPersistentTasksService, times(1)).sendCompletionRequest( + eq(persistentTaskId), + eq(allocationId), + isNull(), + isNull(), + any(), + any() + ); + } + + public void testOnlyMarksCompletedOnce() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + + var authorizationRequestHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + authorizationRequestHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mock(Client.class), + createMockCCMFeature(true), + createMockCCMService(false), + null + ); + + var persistentTaskId = "id"; + var allocationId = 0L; + + var mockPersistentTasksService = mock(PersistentTasksService.class); + poller.init(mockPersistentTasksService, mock(TaskManager.class), persistentTaskId, allocationId); + + poller.sendAuthorizationRequest(); + poller.sendAuthorizationRequest(); + + verify(authorizationRequestHandler, never()).getAuthorization(any(), any()); + verify(mockPersistentTasksService, times(1)).sendCompletionRequest( + eq(persistentTaskId), + eq(allocationId), + isNull(), + isNull(), + any(), + any() + ); } public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { @@ -111,9 +207,90 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, + createMockCCMFeature(true), + createMockCCMService(true), + null + ); + + var persistentTaskId = "id"; + var allocationId = 0L; + + var mockPersistentTasksService = mock(PersistentTasksService.class); + poller.init(mockPersistentTasksService, mock(TaskManager.class), persistentTaskId, allocationId); + + var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class); + + poller.sendAuthorizationRequest(); + verify(mockClient).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); + var capturedRequest = requestArgCaptor.getValue(); + assertThat( + capturedRequest.getModels(), + is( + List.of( + PreconfiguredEndpointModelAdapter.createModel( + InternalPreconfiguredEndpoints.getWithInferenceId(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2), + new ElasticInferenceServiceComponents("") + ) + ) + ) + ); + + verify(mockPersistentTasksService, never()).sendCompletionRequest( + eq(persistentTaskId), + eq(allocationId), + any(), + any(), + any(), + any() + ); + } + + public void testSendsAuthorizationRequest_WhenCCMIsNotConfigurable() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mockClient, + // CCM is not configurable so we should send the request because it doesn't depend on an api key + createMockCCMFeature(false), + createMockCCMService(false), null ); + var persistentTaskId = "id"; + var allocationId = 0L; + + var mockPersistentTasksService = mock(PersistentTasksService.class); + poller.init(mockPersistentTasksService, mock(TaskManager.class), persistentTaskId, allocationId); + var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class); poller.sendAuthorizationRequest(); @@ -130,6 +307,15 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { ) ) ); + + verify(mockPersistentTasksService, never()).sendCompletionRequest( + eq(persistentTaskId), + eq(allocationId), + any(), + any(), + any(), + any() + ); } public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInferenceIdAlreadyExists() { @@ -166,6 +352,8 @@ public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInfe ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, + createMockCCMFeature(true), + createMockCCMService(true), null ); @@ -208,6 +396,8 @@ public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMa ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, + createMockCCMFeature(true), + createMockCCMService(true), null ); @@ -250,6 +440,8 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, + createMockCCMFeature(true), + createMockCCMService(true), null ); @@ -309,6 +501,8 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, + createMockCCMFeature(true), + createMockCCMService(true), callback ); pollerRef.set(poller); @@ -369,6 +563,8 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw settingsMock, mockRegistry, mockClient, + createMockCCMFeature(true), + createMockCCMService(true), callback ); @@ -392,5 +588,15 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw any() ); verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); + + poller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS); + verify(mockPersistentTasksService, never()).sendCompletionRequest( + eq(persistentTaskId), + eq(allocationId), + isNull(), + isNull(), + any(), + any() + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorMessageTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorMessageTests.java new file mode 100644 index 0000000000000..f05edefaa0931 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorMessageTests.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; + +public class AuthorizationTaskExecutorMessageTests extends AbstractBWCWireSerializationTestCase { + + @Override + protected AuthorizationTaskExecutor.Message mutateInstanceForVersion( + AuthorizationTaskExecutor.Message instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AuthorizationTaskExecutor.Message::new; + } + + @Override + protected AuthorizationTaskExecutor.Message createTestInstance() { + return new AuthorizationTaskExecutor.Message(randomBoolean()); + } + + @Override + protected AuthorizationTaskExecutor.Message mutateInstance(AuthorizationTaskExecutor.Message instance) throws IOException { + if (instance.enable()) { + return AuthorizationTaskExecutor.Message.DISABLE_MESSAGE; + } else { + return AuthorizationTaskExecutor.Message.ENABLE_MESSAGE; + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java index 15b586d62890d..70bb2801e64e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java @@ -33,12 +33,15 @@ import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureTests.createMockCCMFeature; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMServiceTests.createMockCCMService; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class AuthorizationTaskExecutorTests extends ESTestCase { @@ -63,6 +66,226 @@ public void tearDown() throws Exception { terminate(threadPool); } + public void testMultipleCallsToStart_OnlyRegistersOnce() { + var eisUrl = "abc"; + var mockClusterService = createMockEmptyClusterService(); + var executor = new AuthorizationTaskExecutor( + mockClusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) + ) + ); + executor.startAndImmediatelyCreateTask(); + executor.startAndImmediatelyCreateTask(); + + verify(mockClusterService, times(1)).addListener(executor); + verify(persistentTasksService, times(1)).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + + public void testStartLazy_OnlyRegistersOnce_NeverCallsPersistentTaskService() { + var eisUrl = "abc"; + var mockClusterService = createMockEmptyClusterService(); + var executor = new AuthorizationTaskExecutor( + mockClusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) + ) + ); + executor.startAndLazyCreateTask(); + executor.startAndLazyCreateTask(); + + verify(mockClusterService, times(1)).addListener(executor); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + + private static ClusterService createMockEmptyClusterService() { + var mockClusterService = mock(ClusterService.class); + when(mockClusterService.state()).thenReturn(ClusterState.EMPTY_STATE); + return mockClusterService; + } + + public void testDoesNotRegisterListener_IfUrlIsEmpty() { + var eisUrl = ""; + var mockClusterService = createMockEmptyClusterService(); + var executor = new AuthorizationTaskExecutor( + mockClusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) + ) + ); + executor.startAndImmediatelyCreateTask(); + executor.startAndImmediatelyCreateTask(); + + verify(mockClusterService, never()).addListener(executor); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + + public void testMultipleCallsToStart_AndStop() { + var eisUrl = "abc"; + var mockClusterService = createMockEmptyClusterService(); + var executor = new AuthorizationTaskExecutor( + mockClusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) + ) + ); + executor.startAndImmediatelyCreateTask(); + executor.startAndImmediatelyCreateTask(); + executor.stop(); + executor.stop(); + verify(mockClusterService, times(1)).addListener(executor); + verify(persistentTasksService, times(1)).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + verify(mockClusterService, times(1)).removeListener(executor); + verify(persistentTasksService, times(1)).sendClusterRemoveRequest(eq(AuthorizationPoller.TASK_NAME), any(), any()); + + executor.startAndImmediatelyCreateTask(); + executor.stop(); + verify(mockClusterService, times(2)).addListener(executor); + verify(persistentTasksService, times(2)).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + verify(mockClusterService, times(2)).removeListener(executor); + verify(persistentTasksService, times(2)).sendClusterRemoveRequest(eq(AuthorizationPoller.TASK_NAME), any(), any()); + } + + public void testCallsSendClusterStartRequest_WhenStartIsCalled() { + var eisUrl = "abc"; + var mockClusterService = createMockEmptyClusterService(); + var executor = new AuthorizationTaskExecutor( + mockClusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) + ) + ); + executor.startAndImmediatelyCreateTask(); + + verify(mockClusterService, times(1)).addListener(executor); + verify(persistentTasksService, times(1)).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + + public void testDoesNotCallSendClusterStartRequest_WhenStartIsCalled_WhenItIsAlreadyInClusterState() { + var initialState = initialState(); + var state = ClusterState.builder(initialState) + .metadata( + Metadata.builder(initialState.metadata()) + .putCustom( + ClusterPersistentTasksCustomMetadata.TYPE, + ClusterPersistentTasksCustomMetadata.builder() + .addTask( + AuthorizationPoller.TASK_NAME, + AuthorizationPoller.TASK_NAME, + AuthorizationTaskParams.INSTANCE, + NO_NODE_FOUND + ) + .build() + ) + ) + .build(); + + var mockClusterService = mock(ClusterService.class); + when(mockClusterService.state()).thenReturn(state); + + var eisUrl = "abc"; + var executor = new AuthorizationTaskExecutor( + mockClusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) + ) + ); + executor.startAndImmediatelyCreateTask(); + + verify(mockClusterService, times(1)).addListener(executor); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + public void testCreatesTask_WhenItDoesNotExistOnClusterStateChange() { var eisUrl = "abc"; @@ -75,15 +298,18 @@ public void testCreatesTask_WhenItDoesNotExistOnClusterStateChange() { mock(Sender.class), ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mock(ModelRegistry.class), - mock(Client.class) + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) ) ); - executor.init(); + executor.startAndImmediatelyCreateTask(); var listener1 = new PlainActionFuture(); clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener1); listener1.actionGet(TimeValue.THIRTY_SECONDS); - verify(persistentTasksService, times(1)).sendClusterStartRequest( + // 2 because the first call is from the start() and the second is from the cluster state change. + verify(persistentTasksService, times(2)).sendClusterStartRequest( eq(AuthorizationPoller.TASK_NAME), eq(AuthorizationPoller.TASK_NAME), eq(AuthorizationTaskParams.INSTANCE), @@ -124,10 +350,12 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis mock(Sender.class), ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mock(ModelRegistry.class), - mock(Client.class) + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) ) ); - executor.init(); + executor.startAndImmediatelyCreateTask(); var listener = new PlainActionFuture(); clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener); @@ -151,10 +379,12 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis mock(Sender.class), ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mock(ModelRegistry.class), - mock(Client.class) + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) ) ); - executor.init(); + executor.startAndImmediatelyCreateTask(); var listener = new PlainActionFuture(); clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener); @@ -170,30 +400,30 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis public void testDoesNotCreateTask_OnClusterStateChange_WhenItAlreadyExists() { var initialState = initialState(); - var event = new ClusterChangedEvent( - "testClusterChanged", - ClusterState.builder(initialState) - .metadata( - Metadata.builder(initialState.metadata()) - .putCustom( - ClusterPersistentTasksCustomMetadata.TYPE, - ClusterPersistentTasksCustomMetadata.builder() - .addTask( - AuthorizationPoller.TASK_NAME, - AuthorizationPoller.TASK_NAME, - AuthorizationTaskParams.INSTANCE, - NO_NODE_FOUND - ) - .build() - ) - ) - .build(), - ClusterState.EMPTY_STATE - ); + var state = ClusterState.builder(initialState) + .metadata( + Metadata.builder(initialState.metadata()) + .putCustom( + ClusterPersistentTasksCustomMetadata.TYPE, + ClusterPersistentTasksCustomMetadata.builder() + .addTask( + AuthorizationPoller.TASK_NAME, + AuthorizationPoller.TASK_NAME, + AuthorizationTaskParams.INSTANCE, + NO_NODE_FOUND + ) + .build() + ) + ) + .build(); + var event = new ClusterChangedEvent("testClusterChanged", state, ClusterState.EMPTY_STATE); + + var mockClusterService = mock(ClusterService.class); + when(mockClusterService.state()).thenReturn(state); var eisUrl = "abc"; var executor = new AuthorizationTaskExecutor( - clusterService, + mockClusterService, persistentTasksService, new AuthorizationPoller.Parameters( createWithEmptySettings(threadPool), @@ -201,10 +431,12 @@ public void testDoesNotCreateTask_OnClusterStateChange_WhenItAlreadyExists() { mock(Sender.class), ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mock(ModelRegistry.class), - mock(Client.class) + mock(Client.class), + createMockCCMFeature(false), + createMockCCMService(false) ) ); - executor.init(); + executor.startAndImmediatelyCreateTask(); executor.clusterChanged(event); verify(persistentTasksService, never()).sendClusterStartRequest( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index e3d24ea2ec8f7..03a20c5bfeefc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.elastic.authorization; +import org.apache.hc.core5.http.HttpHeaders; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; @@ -16,6 +17,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -39,7 +41,10 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createApplierFactory; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createNoopApplierFactory; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -74,7 +79,7 @@ public void shutdown() throws IOException { public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(null, threadPool, logger); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(null, threadPool, logger, createNoopApplierFactory()); try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -96,7 +101,7 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("", threadPool, logger); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("", threadPool, logger, createNoopApplierFactory()); try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -119,7 +124,12 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var eisGatewayUrl = getUrl(webServer); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + eisGatewayUrl, + threadPool, + logger, + createNoopApplierFactory() + ); try (var sender = senderFactory.createSender()) { String responseJson = """ @@ -165,7 +175,12 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var eisGatewayUrl = getUrl(webServer); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + eisGatewayUrl, + threadPool, + logger, + createNoopApplierFactory() + ); try (var sender = senderFactory.createSender()) { String responseJson = """ @@ -194,14 +209,76 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { var message = loggerArgsCaptor.getValue(); assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); + + assertNoAuthHeader(webServer.requests()); } } + private static void assertNoAuthHeader(List requests) { + assertThat(requests.size(), is(1)); + assertNull(requests.get(0).getHeader(HttpHeaders.AUTHORIZATION)); + } + + public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws IOException { + var secret = "secret-token"; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + var logger = mock(Logger.class); + + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + eisGatewayUrl, + threadPool, + logger, + createApplierFactory(secret) + ); + + try (var sender = senderFactory.createSender()) { + String responseJson = """ + { + "models": [ + { + "model_name": "model-a", + "task_types": ["embed/text/sparse", "chat"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = new PlainActionFuture<>(); + authHandler.getAuthorization(listener, sender); + + var authResponse = listener.actionGet(TIMEOUT); + assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a"))); + assertTrue(authResponse.isAuthorized()); + + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); + verify(logger, times(1)).debug(loggerArgsCaptor.capture()); + + var message = loggerArgsCaptor.getValue(); + assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); + + assertAuthHeader(webServer.requests(), secret); + } + } + + private static void assertAuthHeader(List requests, String secret) { + assertThat(requests.size(), is(1)); + assertThat(requests.get(0).getHeader(HttpHeaders.AUTHORIZATION), is(bearerToken(secret))); + } + public void testGetAuthorization_OnResponseCalledOnce() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var eisGatewayUrl = getUrl(webServer); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + eisGatewayUrl, + threadPool, + logger, + createNoopApplierFactory() + ); PlainActionFuture listener = new PlainActionFuture<>(); ActionListener onlyOnceListener = ActionListener.assertOnce(listener); @@ -246,7 +323,7 @@ public void testGetAuthorization_InvalidResponse() throws IOException { }).when(senderMock).sendWithoutQueuing(any(), any(), any(), any(), any()); var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("abc", threadPool, logger); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("abc", threadPool, logger, createNoopApplierFactory()); try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMAuthenticationApplierFactoryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMAuthenticationApplierFactoryTests.java new file mode 100644 index 0000000000000..19b2cb7d7cf61 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMAuthenticationApplierFactoryTests.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.ccm; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpGet; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.TestPlainActionFuture; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CCMAuthenticationApplierFactoryTests extends ESTestCase { + + public static CCMAuthenticationApplierFactory createNoopApplierFactory() { + var mockFactory = mock(CCMAuthenticationApplierFactory.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(CCMAuthenticationApplierFactory.NOOP_APPLIER); + return Void.TYPE; + }).when(mockFactory).getAuthenticationApplier(any()); + + return mockFactory; + } + + public static CCMAuthenticationApplierFactory createApplierFactory(String secret) { + var mockFactory = mock(CCMAuthenticationApplierFactory.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CCMAuthenticationApplierFactory.AuthenticationHeaderApplier(secret)); + return Void.TYPE; + }).when(mockFactory).getAuthenticationApplier(any()); + + return mockFactory; + } + + public void testNoopApplierReturnsSameRequest() { + var applier = CCMAuthenticationApplierFactory.NOOP_APPLIER; + var request = new HttpGet("http://localhost"); + var result = applier.apply(request); + assertThat(result, sameInstance(request)); + } + + public void testAuthenticationHeaderApplierSetsAuthorizationHeader() { + var secret = "my-secret"; + var applier = new CCMAuthenticationApplierFactory.AuthenticationHeaderApplier(secret); + var request = new HttpGet("http://localhost"); + applier.apply(request); + assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is(bearerToken(secret))); + } + + public void testGetAuthenticationApplier_ReturnsNoopWhenConfiguringCCMIsDisabled() { + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(false); + var ccmService = mock(CCMService.class); + + var factory = new CCMAuthenticationApplierFactory(ccmFeature, ccmService); + var listener = new TestPlainActionFuture(); + factory.getAuthenticationApplier(listener); + + assertThat(listener.actionGet(TimeValue.THIRTY_SECONDS), sameInstance(CCMAuthenticationApplierFactory.NOOP_APPLIER)); + } + + public void testGetAuthenticationApplier_ReturnsFailure_WhenConfiguringCCMIsEnabled_ButHasNotBeenConfiguredYet() { + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(true); + + var ccmService = mock(CCMService.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(false); + return Void.TYPE; + }).when(ccmService).isEnabled(any()); + + var factory = new CCMAuthenticationApplierFactory(ccmFeature, ccmService); + var listener = new TestPlainActionFuture(); + factory.getAuthenticationApplier(listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TimeValue.THIRTY_SECONDS)); + assertThat( + exception.getMessage(), + containsString( + "Cloud connected mode is not configured, please configure it using PUT _inference/_ccm " + + "before accessing the Elastic Inference Service." + ) + ); + } + + public void testGetAuthenticationApplier_ReturnsApiKey_WhenConfiguringCCMIsEnabled_AndSet() { + var secret = "secret"; + + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(true); + + var ccmService = mock(CCMService.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return Void.TYPE; + }).when(ccmService).isEnabled(any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CCMModel(new SecureString(secret.toCharArray()))); + return Void.TYPE; + }).when(ccmService).getConfiguration(any()); + + var factory = new CCMAuthenticationApplierFactory(ccmFeature, ccmService); + var listener = new TestPlainActionFuture(); + factory.getAuthenticationApplier(listener); + + var applier = listener.actionGet(TimeValue.THIRTY_SECONDS); + assertThat(applier, is(new CCMAuthenticationApplierFactory.AuthenticationHeaderApplier(secret))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMFeatureTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMFeatureTests.java new file mode 100644 index 0000000000000..9f2d763cc6fb7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMFeatureTests.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.ccm; + +import org.elasticsearch.test.ESTestCase; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CCMFeatureTests extends ESTestCase { + + public static CCMFeature createMockCCMFeature(boolean enabled) { + var mockFeature = mock(CCMFeature.class); + when(mockFeature.isCcmSupportedEnvironment()).thenReturn(enabled); + + return mockFeature; + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMInformedSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMInformedSettingsTests.java new file mode 100644 index 0000000000000..6f2a1232fa1e5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMInformedSettingsTests.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.ccm; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL; +import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMInformedSettings.DEFAULT_CCM_URL; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CCMInformedSettingsTests extends ESTestCase { + public void testGetElasticInferenceServiceUrl_ReturnsSettingUrl_WhenConfiguringCCMIsPermitted_ButSettingUrlExists() { + var url = "http://custom-url.com"; + var settings = Settings.builder().put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), url).build(); + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(true); + + var informedSettings = new CCMInformedSettings(settings, ccmFeature); + + assertThat(informedSettings.getElasticInferenceServiceUrl(), is(url)); + } + + public void testGetElasticInferenceServiceUrl_ReturnsCCMDefault_WhenConfiguringCCMIsPermitted_ButSettingUrlIsEmpty() { + var settings = Settings.builder().put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), "").build(); + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(true); + + var informedSettings = new CCMInformedSettings(settings, ccmFeature); + + assertThat(informedSettings.getElasticInferenceServiceUrl(), is(DEFAULT_CCM_URL)); + } + + public void testGetElasticInferenceServiceUrl_ReturnsCCMDefault_WhenConfiguringCCMIsPermitted_ButSettingUrlIsNull() { + var settings = Settings.builder().put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), (String) null).build(); + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(true); + + var informedSettings = new CCMInformedSettings(settings, ccmFeature); + + assertThat(informedSettings.getElasticInferenceServiceUrl(), is(DEFAULT_CCM_URL)); + } + + public void testGetElasticInferenceServiceUrl_ReturnsCCMDefault_WhenConfiguringCCMIsPermitted_ButSettingUrlIsAbsent() { + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(true); + + var informedSettings = new CCMInformedSettings(Settings.EMPTY, ccmFeature); + + assertThat(informedSettings.getElasticInferenceServiceUrl(), is(DEFAULT_CCM_URL)); + } + + public void testGetElasticInferenceServiceUrl_ReturnsSettingUrl_WhenConfiguringCCMIsNotPermitted() { + var url = "http://custom-url.com"; + var settings = Settings.builder().put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), url).build(); + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(false); + + var informedSettings = new CCMInformedSettings(settings, ccmFeature); + + assertThat(informedSettings.getElasticInferenceServiceUrl(), is(url)); + } + + public void testGetElasticInferenceServiceUrl_ReturnsEmpty_WhenConfiguringCCMIsNotPermitted_AndSettingUrlIsEmpty() { + var url = ""; + var settings = Settings.builder().put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), url).build(); + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(false); + + var informedSettings = new CCMInformedSettings(settings, ccmFeature); + + assertThat(informedSettings.getElasticInferenceServiceUrl(), is(url)); + } + + public void testGetElasticInferenceServiceUrl_ReturnsEmpty_WhenConfiguringCCMIsNotPermitted_AndSettingUrlIsNull() { + var settings = Settings.builder().put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), (String) null).build(); + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(false); + + var informedSettings = new CCMInformedSettings(settings, ccmFeature); + + assertThat(informedSettings.getElasticInferenceServiceUrl(), is("")); + } + + public void testGetElasticInferenceServiceUrl_ReturnsEmpty_WhenConfiguringCCMIsNotPermitted_AndSettingUrlAbsent() { + var ccmFeature = mock(CCMFeature.class); + when(ccmFeature.isCcmSupportedEnvironment()).thenReturn(false); + + var informedSettings = new CCMInformedSettings(Settings.EMPTY, ccmFeature); + + assertThat(informedSettings.getElasticInferenceServiceUrl(), is("")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMServiceTests.java new file mode 100644 index 0000000000000..ef24573ee22d2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMServiceTests.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.ccm; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.test.ESTestCase; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +public class CCMServiceTests extends ESTestCase { + + public static CCMService createMockCCMService(boolean enabled) { + var mockService = mock(CCMService.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(enabled); + return Void.TYPE; + }).when(mockService).isEnabled(any()); + + return mockService; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java index 4a547607b9f5c..257e2fe68edf2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.Before; @@ -31,7 +32,12 @@ public void testCreateUriThrowsForInvalidBaseUrl() { ElasticsearchStatusException exception = assertThrows( ElasticsearchStatusException.class, - () -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext, randomElasticInferenceServiceRequestMetadata()) + () -> new ElasticInferenceServiceAuthorizationRequest( + invalidUrl, + traceContext, + randomElasticInferenceServiceRequestMetadata(), + CCMAuthenticationApplierFactory.NOOP_APPLIER + ) ); assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java index b310a208123ac..d0fbbf37a7ad3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java @@ -9,10 +9,12 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.inference.InputType; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -21,6 +23,7 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.equalTo; @@ -158,7 +161,8 @@ public void testDecorate_HttpRequest_WithProductUseCase() { List.of(input), new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), new ElasticInferenceServiceRequestMetadata("my-product-origin", "my-product-use-case-from-metadata", "1.2.3"), - inputType + inputType, + CCMAuthenticationApplierFactory.NOOP_APPLIER ); var httpRequest = request.createHttpRequest(); @@ -173,6 +177,33 @@ public void testDecorate_HttpRequest_WithProductUseCase() { } } + public void testDecorate_HttpRequest_WithAuthorizationHeader() { + var input = "elastic"; + var modelId = "my-model-id"; + var url = "http://eis-gateway.com"; + var secret = "secret"; + + for (var inputType : List.of(InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST, InputType.UNSPECIFIED)) { + var request = new ElasticInferenceServiceDenseTextEmbeddingsRequest( + ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(url, modelId), + List.of(input), + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + new ElasticInferenceServiceRequestMetadata("my-product-origin", "my-product-use-case-from-metadata", "1.2.3"), + inputType, + new CCMAuthenticationApplierFactory.AuthenticationHeaderApplier(new SecureString(secret.toCharArray())) + ); + + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var headers = httpPost.getHeaders(HttpHeaders.AUTHORIZATION); + assertThat(headers.length, is(1)); + assertThat(headers[0].getValue(), is(bearerToken(secret))); + } + } + private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest( String url, String modelId, @@ -186,7 +217,8 @@ private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest( inputs, new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), randomElasticInferenceServiceRequestMetadata(), - inputType + inputType, + CCMAuthenticationApplierFactory.NOOP_APPLIER ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequestTests.java index da8247d15a837..32714dd2c973a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequestTests.java @@ -7,20 +7,38 @@ package org.elasticsearch.xpack.inference.services.elastic.request; +import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpRequestBase; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import java.net.URI; import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_ES_VERSION; import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; public class ElasticInferenceServiceRequestTests extends ESTestCase { + public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithAuthorizationHeader() { + var secret = "secret"; + var productOrigin = "elastic"; + var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest( + new ElasticInferenceServiceRequestMetadata(productOrigin, null, null), + new CCMAuthenticationApplierFactory.AuthenticationHeaderApplier(new SecureString(secret.toCharArray())) + ); + var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase().getHeaders(HttpHeaders.AUTHORIZATION).length, equalTo(1)); + assertThat(httpRequest.httpRequestBase().getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is(bearerToken(secret))); + } + public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithProductOrigin() { var productOrigin = "elastic"; var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest( @@ -63,7 +81,14 @@ public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_Wi private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest( ElasticInferenceServiceRequestMetadata requestMetadata ) { - return new ElasticInferenceServiceRequest(requestMetadata) { + return getDummyElasticInferenceServiceRequest(requestMetadata, CCMAuthenticationApplierFactory.NOOP_APPLIER); + } + + private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest( + ElasticInferenceServiceRequestMetadata requestMetadata, + CCMAuthenticationApplierFactory.AuthApplier authApplier + ) { + return new ElasticInferenceServiceRequest(requestMetadata, authApplier) { @Override protected HttpRequestBase createHttpRequestBase() { return new HttpGet("http://localhost:8080"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequestTests.java index 9fae589e2428e..1e19e66383feb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequestTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.io.IOException; @@ -23,6 +24,7 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.equalTo; @@ -110,7 +112,8 @@ public void testDecorate_HttpRequest_WithProductUseCase() { ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url, modelId), new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), new ElasticInferenceServiceRequestMetadata("my-product-origin", "my-product-use-case-from-metadata", "1.2.3"), - inputType + inputType, + CCMAuthenticationApplierFactory.NOOP_APPLIER ); var httpRequest = request.createHttpRequest(); @@ -125,6 +128,34 @@ public void testDecorate_HttpRequest_WithProductUseCase() { } } + public void testDecorate_HttpRequest_WithAuthorizationHeader() { + var input = "elastic"; + var modelId = "my-model-id"; + var url = "http://eis-gateway.com"; + var secret = "secret"; + + for (var inputType : List.of(InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST, InputType.UNSPECIFIED)) { + var request = new ElasticInferenceServiceSparseEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url, modelId), + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + new ElasticInferenceServiceRequestMetadata("my-product-origin", "my-product-use-case-from-metadata", "1.2.3"), + inputType, + new CCMAuthenticationApplierFactory.AuthenticationHeaderApplier(secret) + ); + + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var headers = httpPost.getHeaders(HttpHeaders.AUTHORIZATION); + assertThat(headers.length, is(1)); + assertThat(headers[0].getValue(), is(bearerToken(secret))); + } + } + public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, String modelId, String input, InputType inputType) { var embeddingsModel = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url, modelId); @@ -134,7 +165,8 @@ public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, embeddingsModel, new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), randomElasticInferenceServiceRequestMetadata(), - inputType + inputType, + CCMAuthenticationApplierFactory.NOOP_APPLIER ); } } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index dbde7cc6bbb90..706e7d903ebcc 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -334,6 +334,7 @@ public class Constants { "cluster:internal/xpack/inference/fields/get", "cluster:internal/xpack/inference/rerankwindowsize/get", "cluster:internal/xpack/inference/unified", + "cluster:internal/xpack/inference/update_authorization_task", "cluster:internal/xpack/ml/auditor/reset", "cluster:internal/xpack/ml/coordinatedinference", "cluster:internal/xpack/ml/datafeed/isolate",