From db5a7b393cf26b61b500393837b94ee6f3182cab Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 14 Nov 2025 16:51:59 -0500 Subject: [PATCH 01/24] Starting new response class --- ...cInferenceServiceAuthorizationRequest.java | 7 +- ...eServiceAuthorizationResponseEntityV2.java | 359 ++++++++++++++++++ ...renceServiceAuthorizationRequestTests.java | 11 + 3 files changed, 374 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequest.java index 42176b0a67515..4440eb4af2365 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequest.java @@ -9,6 +9,7 @@ import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.request.Request; @@ -24,6 +25,7 @@ public class ElasticInferenceServiceAuthorizationRequest extends ElasticInferenc private final URI uri; private final TraceContextHandler traceContextHandler; + static final String AUTHORIZATION_PATH = "/api/v2/authorizations"; public ElasticInferenceServiceAuthorizationRequest( String url, @@ -35,10 +37,9 @@ public ElasticInferenceServiceAuthorizationRequest( this.traceContextHandler = new 
TraceContextHandler(traceContext); } - private URI createUri(String url) throws ElasticsearchStatusException { + private static URI createUri(String url) throws ElasticsearchStatusException { try { - // TODO, consider transforming the base URL into a URI for better error handling. - return new URI(url + "/api/v1/authorizations"); + return new URIBuilder(url).setPath(AUTHORIZATION_PATH).build(); } catch (URISyntaxException e) { throw new ElasticsearchStatusException( "Failed to create URI for service [" + ElasticInferenceService.NAME + "]: " + e.getMessage(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java new file mode 100644 index 0000000000000..85c9b1700401b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java @@ -0,0 +1,359 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.response; + +import org.apache.http.auth.AUTH; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/* + +{ + "inference_endpoints": [ + { + "id": ".rainbow-sprinkles-elastic", + "model_name": "rainbow-sprinkles", + "task_type": "chat", + "status": "ga", + "properties": [ + "multilingual" + ], + "release_date": "2024-05-01", + "end_of_life_date": "2025-12-31" + }, + { + "id": ".elastic-elser-v2", + "model_name": 
"elser_model_2", + "task_type": "embed/text/sparse", + "status": "preview", + "properties": [ + "english" + ], + "release_date": "2024-05-01", + "configuration": { + "chunking_settings": { + "strategy": "sentence", + "max_chunk_size": 250, + "sentence_overlap": 1 + } + } + }, + { + "id": ".jina-embeddings-v3", + "model_name": "jina-embeddings-v3", + "task_type": "embed/text/dense", + "status": "beta", + "properties": [ + "multilingual", + "open-weights" + ], + "release_date": "2024-05-01", + "configuration": { + "similarity": "cosine", + "dimension": 1024, + "element_type": "float", + "chunking_settings": { + "strategy": "sentence", + "max_chunk_size": 250, + "sentence_overlap": 1 + } + } + } + ] +} + */ +public class ElasticInferenceServiceAuthorizationResponseEntityV2 implements InferenceServiceResults { + + public static final String NAME = "elastic_inference_service_auth_results_v2"; + + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntityV2.class); + private static final String AUTH_FIELD_NAME = "authorized_models"; + private static final Map ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of( + "embed/text/sparse", + TaskType.SPARSE_EMBEDDING, + "chat", + TaskType.CHAT_COMPLETION, + "embed/text/dense", + TaskType.TEXT_EMBEDDING, + "rerank/text/text-similarity", + TaskType.RERANK + ); + + private static final String INFERENCE_ENDPOINTS = "inference_endpoints"; + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + ElasticInferenceServiceAuthorizationResponseEntityV2.class.getSimpleName(), + true, + args -> new ElasticInferenceServiceAuthorizationResponseEntityV2((List) args[0]) + ); + + static { + PARSER.declareObjectArray( + constructorArg(), + AuthorizedEndpoint.AUTHORIZED_ENDPOINT_PARSER::apply, + new ParseField(INFERENCE_ENDPOINTS) + ); + } + + public record AuthorizedEndpoint( + String id, + String modelName, + String taskType, + 
String status, + @Nullable List properties, + String releaseDate, + @Nullable Configuration configuration + ) implements Writeable, ToXContentObject { + + private static final String ID = "id"; + private static final String MODEL_NAME = "model_name"; + private static final String TASK_TYPE = "task_type"; + private static final String STATUS = "status"; + private static final String PROPERTIES = "properties"; + private static final String RELEASE_DATE = "release_date"; + private static final String CONFIGURATION = "configuration"; + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser AUTHORIZED_ENDPOINT_PARSER = new ConstructingObjectParser<>( + AuthorizedEndpoint.class.getSimpleName(), + true, + args -> new AuthorizedEndpoint( + (String) args[0], + (String) args[1], + (String) args[2], + (String) args[3], + (List) args[4], + (String) args[5], + (Configuration) args[6] + ) + ); + + static { + AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(ID)); + AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(MODEL_NAME)); + AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(TASK_TYPE)); + AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(STATUS)); + AUTHORIZED_ENDPOINT_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(PROPERTIES)); + AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(RELEASE_DATE)); + AUTHORIZED_ENDPOINT_PARSER.declareObject( + optionalConstructorArg(), + Configuration.CONFIGURATION_PARSER::apply, + new ParseField(CONFIGURATION) + ); + } + + private static EnumSet toTaskTypes(List stringTaskTypes) { + var taskTypes = EnumSet.noneOf(TaskType.class); + for (String taskType : stringTaskTypes) { + var mappedTaskType = ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING.get(taskType); + if (mappedTaskType != null) { + taskTypes.add(mappedTaskType); + } + } + + return taskTypes; + } + + public 
AuthorizedEndpoint(StreamInput in) throws IOException { + this(in.readString(), in.readEnumSet(TaskType.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelName); + out.writeEnumSet(taskTypes); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(ID, ) + + builder.field("model_name", modelName); + builder.field("task_types", taskTypes.stream().map(TaskType::toString).collect(Collectors.toList())); + + builder.endObject(); + + return builder; + } + + @Override + public String toString() { + return Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes); + } + } + + public record Configuration( + @Nullable String similarity, + @Nullable Integer dimensions, + @Nullable String elementType, + @Nullable Map chunkingSettings + ) implements Writeable, ToXContentObject { + + private static final String SIMILARITY = "similarity"; + private static final String DIMENSIONS = "dimensions"; + private static final String ELEMENT_TYPE = "element_type"; + private static final String CHUNKING_SETTINGS = "chunking_settings"; + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser CONFIGURATION_PARSER = new ConstructingObjectParser<>( + Configuration.class.getSimpleName(), + true, + args -> new Configuration((String) args[0], (Integer) args[1], (String) args[2], (Map) args[3]) + ); + + static { + CONFIGURATION_PARSER.declareString(optionalConstructorArg(), new ParseField(SIMILARITY)); + CONFIGURATION_PARSER.declareInt(optionalConstructorArg(), new ParseField(DIMENSIONS)); + CONFIGURATION_PARSER.declareString(optionalConstructorArg(), new ParseField(ELEMENT_TYPE)); + CONFIGURATION_PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField(CHUNKING_SETTINGS)); + } + + public Configuration(StreamInput in) throws IOException { + this(in.readOptionalString(), 
in.readOptionalVInt(), in.readOptionalString(), in.readGenericMap()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(similarity); + out.writeOptionalVInt(dimensions); + out.writeOptionalString(elementType); + out.writeGenericMap(chunkingSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + + if (elementType != null) { + builder.field(ELEMENT_TYPE, elementType); + } + + if (chunkingSettings != null) { + builder.field(CHUNKING_SETTINGS, chunkingSettings); + } + + builder.endObject(); + return builder; + } + } + + private final List authorizedModels; + + public ElasticInferenceServiceAuthorizationResponseEntityV2(List authorizedModels) { + this.authorizedModels = Objects.requireNonNull(authorizedModels); + } + + /** + * Create an empty response + */ + public ElasticInferenceServiceAuthorizationResponseEntityV2() { + this(List.of()); + } + + public ElasticInferenceServiceAuthorizationResponseEntityV2(StreamInput in) throws IOException { + this(in.readCollectionAsList(AuthorizedEndpoint::new)); + } + + public static ElasticInferenceServiceAuthorizationResponseEntityV2 fromResponse(Request request, HttpResult response) + throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + return PARSER.apply(jsonParser, null); + } + } + + public List getAuthorizedModels() { + return authorizedModels; + } + + @Override + public String toString() { + return authorizedModels.stream().map(AuthorizedEndpoint::toString).collect(Collectors.joining(", ")); + } + + @Override + public 
Iterator toXContentChunked(ToXContent.Params params) { + throw new UnsupportedOperationException(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(authorizedModels); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public List transformToCoordinationFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Map asMap() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ElasticInferenceServiceAuthorizationResponseEntityV2 that = (ElasticInferenceServiceAuthorizationResponseEntityV2) o; + return Objects.equals(authorizedModels, that.authorizedModels); + } + + @Override + public int hashCode() { + return Objects.hash(authorizedModels); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java index 4a547607b9f5c..9e01c2ff4a5a0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java @@ -13,6 +13,10 @@ import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.Before; +import java.net.URI; +import java.net.URISyntaxException; + +import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest.AUTHORIZATION_PATH; import static 
org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; @@ -37,4 +41,11 @@ public void testCreateUriThrowsForInvalidBaseUrl() { assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); assertThat(exception.getMessage(), containsString("Failed to create URI for service")); } + + public void testCreateUri_CreatesUri() throws URISyntaxException { + String url = "https://inference.us-east-1.aws.svc.elastic.cloud"; + + var request = new ElasticInferenceServiceAuthorizationRequest(url, traceContext, randomElasticInferenceServiceRequestMetadata()); + assertThat(request.getURI(), is(new URI(url + AUTHORIZATION_PATH))); + } } From 5060aab50594d1dc6d9550f38bf828ba51820f5b Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 17 Nov 2025 16:40:40 -0500 Subject: [PATCH 02/24] Writing tests --- .../action/StoreInferenceEndpointsAction.java | 6 +- ...eInferenceEndpointsActionRequestTests.java | 3 +- .../TransportGetInferenceServicesAction.java | 12 +- .../inference/registry/ModelRegistry.java | 14 +- .../authorization/AuthorizationModel.java | 305 ++++++++++++ .../authorization/AuthorizationPoller.java | 34 +- ...nceServiceAuthorizationRequestHandler.java | 12 +- ...lasticInferenceServiceCompletionModel.java | 4 +- ...eServiceAuthorizationResponseEntityV2.java | 170 ++++--- .../elastic/ElasticInferenceServiceTests.java | 44 +- .../AuthorizationModelTests.java | 451 ++++++++++++++++++ .../AuthorizationPollerTests.java | 110 ++--- ...ferenceServiceAuthorizationModelTests.java | 385 --------------- ...rviceAuthorizationRequestHandlerTests.java | 30 +- ...enceServiceAuthorizationEntityV2Tests.java | 82 ++++ 15 files changed, 1039 insertions(+), 623 deletions(-) create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationEntityV2Tests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java index aa613cda60399..01f2ac02dc284 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java @@ -35,9 +35,9 @@ public StoreInferenceEndpointsAction() { } public static class Request extends AcknowledgedRequest { - private final List models; + private final List models; - public Request(List models, TimeValue timeout) { + public Request(List models, TimeValue timeout) { super(timeout, DEFAULT_ACK_TIMEOUT); this.models = Objects.requireNonNull(models); } @@ -53,7 +53,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(models); } - public List getModels() { + public List getModels() { return models; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java index 3673296c29ce7..1bc9655ff74b9 
100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.core.XPackClientPlugin; @@ -43,7 +44,7 @@ protected StoreInferenceEndpointsAction.Request createTestInstance() { @Override protected StoreInferenceEndpointsAction.Request mutateInstance(StoreInferenceEndpointsAction.Request instance) throws IOException { - var newModels = new ArrayList<>(instance.getModels()); + var newModels = new ArrayList(instance.getModels()); newModels.add(ModelTests.randomModel()); return new StoreInferenceEndpointsAction.Request(newModels, instance.masterNodeTimeout()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index 18c83df4067ed..605b410075d05 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import 
org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationModel; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import java.util.ArrayList; @@ -122,7 +122,7 @@ private void getServiceConfigurationsForServicesAndEis( ArrayList> availableServices, @Nullable TaskType requestedTaskType ) { - SubscribableListener.newForked(authModelListener -> { + SubscribableListener.newForked(authModelListener -> { // Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender)); }).>andThen((configurationListener, authorizationModel) -> { @@ -133,8 +133,8 @@ private void getServiceConfigurationsForServicesAndEis( return; } - var config = ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes()); - if (requestedTaskType != null && authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType) == false) { + var config = ElasticInferenceService.createConfiguration(authorizationModel.getTaskTypes()); + if (requestedTaskType != null && authorizationModel.getTaskTypes().contains(requestedTaskType) == false) { configurationListener.onResponse(serviceConfigs); return; } @@ -150,14 +150,14 @@ private void getServiceConfigurationsForServicesAndEis( ); } - private void getEisAuthorization(ActionListener listener, Sender sender) { + private void getEisAuthorization(ActionListener listener, Sender sender) { var disabledServiceListener = listener.delegateResponse((delegate, e) -> { logger.warn( "Failed to retrieve authorization information from the " + "Elastic Inference Service while determining service configurations. 
Marking service as disabled.", e ); - delegate.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); + delegate.onResponse(AuthorizationModel.empty()); }); eisAuthorizationRequestHandler.getAuthorization(disabledServiceListener, sender); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index cf731b22807e2..62832b3538e7e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -712,12 +712,12 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< }), timeout); } - public void storeModels(List models, ActionListener> listener, TimeValue timeout) { + public void storeModels(List models, ActionListener> listener, TimeValue timeout) { storeModels(models, true, listener, timeout); } private void storeModels( - List models, + List models, boolean updateClusterState, ActionListener> listener, TimeValue timeout @@ -745,7 +745,7 @@ private void storeModels( } private ActionListener getStoreMultipleModelsListener( - List models, + List models, boolean updateClusterState, ActionListener> listener, TimeValue timeout @@ -818,12 +818,12 @@ private ActionListener getStoreMultipleModelsListener( private record StoreResponseWithIndexInfo(ModelStoreResponse modelStoreResponse, boolean modifiedIndex) {} - private record ResponseInfo(List responses, List successfullyStoredModels) {} + private record ResponseInfo(List responses, List successfullyStoredModels) {} private static ResponseInfo getResponseInfo( BulkResponse bulkResponse, Map docIdToInferenceId, - Map inferenceIdToModel + Map inferenceIdToModel ) { var responses = new ArrayList(); var successfullyStoredModels = new ArrayList(); @@ -909,7 +909,7 @@ 
private static ModelStoreResponse createModelStoreResponse(BulkItemResponse item } } - private static Model getModelFromMap(@Nullable String inferenceId, Map inferenceIdToModel) { + private static Model getModelFromMap(@Nullable String inferenceId, Map inferenceIdToModel) { if (inferenceId != null) { return inferenceIdToModel.get(inferenceId); } @@ -917,7 +917,7 @@ private static Model getModelFromMap(@Nullable String inferenceId, Map models, ActionListener listener, TimeValue timeout) { + private void updateClusterState(List models, ActionListener listener, TimeValue timeout) { var inferenceIdsSet = models.stream().map(Model::getInferenceEntityId).collect(Collectors.toSet()); var storeListener = listener.delegateResponse((delegate, exc) -> { logger.warn(format("Failed to add minimal service settings to cluster state for inference endpoints %s", inferenceIdsSet), exc); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java new file mode 100644 index 0000000000000..9608c556e4d0e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java @@ -0,0 +1,305 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; + +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import 
java.util.Objects;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION;
+
+/**
+ * Transforms the response from {@link ElasticInferenceServiceAuthorizationRequestHandler} into a format for consumption by the service.
+ * <p>
+ * Each authorized endpoint in the response is converted into the matching {@link ElasticInferenceServiceModel}. Endpoints that
+ * cannot be converted (unknown or unsupported task type, or a failure while building the model) are logged and left out.
+ */
+public class AuthorizationModel {
+
+    private static final Logger logger = LogManager.getLogger(AuthorizationModel.class);
+    private static final String UNKNOWN_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unknown task type [{}], skipping";
+
+    /**
+     * Creates an {@link AuthorizationModel} from the authorization response.
+     *
+     * @param responseEntity the parsed authorization response
+     * @param baseEisUrl the base URL used to construct the {@link ElasticInferenceServiceComponents} shared by all created models
+     * @return a model containing every endpoint of the response that could be converted
+     */
+    public static AuthorizationModel of(ElasticInferenceServiceAuthorizationResponseEntityV2 responseEntity, String baseEisUrl) {
+        var components = new ElasticInferenceServiceComponents(baseEisUrl);
+        return createInternal(responseEntity.getAuthorizedEndpoints(), components);
+    }
+
+    /**
+     * Converts the response endpoints one by one, keeping only those for which {@link #createModel} produced a model.
+     */
+    private static AuthorizationModel createInternal(
+        List<ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint> responseEndpoints,
+        ElasticInferenceServiceComponents components
+    ) {
+        var validEndpoints = new ArrayList<ElasticInferenceServiceModel>();
+        for (var authorizedEndpoint : responseEndpoints) {
+            var model = createModel(authorizedEndpoint, components);
+            if (model != null) {
+                validEndpoints.add(model);
+            }
+        }
+
+        return new AuthorizationModel(validEndpoints);
+    }
+
+    /**
+     * Builds the model for a single authorized endpoint.
+     *
+     * @return the model, or {@code null} if the task type is unknown or not handled here, or if building the model threw
+     */
+    private static ElasticInferenceServiceModel createModel(
+        ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint,
+        ElasticInferenceServiceComponents components
+    ) {
+        try {
+            var taskType = getTaskType(authorizedEndpoint.taskType());
+            if (taskType == null) {
+                logger.warn(UNKNOWN_TASK_TYPE_LOG_MESSAGE, authorizedEndpoint.id(), authorizedEndpoint.taskType());
+                return null;
+            }
+
+            return switch (taskType) {
+                case CHAT_COMPLETION -> createCompletionModel(authorizedEndpoint, components);
+                case SPARSE_EMBEDDING -> createSparseEmbeddingsModel(authorizedEndpoint, components);
+                case TEXT_EMBEDDING -> createDenseTextEmbeddingsModel(authorizedEndpoint, components);
+                case RERANK -> createRerankModel(authorizedEndpoint, components);
+                default -> {
+                    // A valid TaskType that has no corresponding model in this class
+                    logger.warn(UNKNOWN_TASK_TYPE_LOG_MESSAGE, authorizedEndpoint.id(), taskType);
+                    yield null;
+                }
+            };
+        } catch (Exception e) {
+            // A single malformed endpoint must not prevent the remaining endpoints from being used
+            logger.atWarn()
+                .withThrowable(e)
+                .log(
+                    "Failed to create model for authorized endpoint id [{}] with task type [{}], skipping",
+                    authorizedEndpoint.id(),
+                    authorizedEndpoint.taskType()
+                );
+            return null;
+        }
+    }
+
+    /**
+     * Parses the task type string.
+     *
+     * @return the parsed {@link TaskType}, or {@code null} if parsing throws an {@link IllegalArgumentException}
+     */
+    private static TaskType getTaskType(String taskType) {
+        try {
+            return TaskType.fromString(taskType);
+        } catch (IllegalArgumentException e) {
+            return null;
+        }
+    }
+
+    private static ElasticInferenceServiceCompletionModel createCompletionModel(
+        ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint,
+        ElasticInferenceServiceComponents components
+    ) {
+        return new ElasticInferenceServiceCompletionModel(
+            authorizedEndpoint.id(),
+            CHAT_COMPLETION,
+            ElasticInferenceService.NAME,
+            new ElasticInferenceServiceCompletionServiceSettings(authorizedEndpoint.modelName()),
+            EmptyTaskSettings.INSTANCE,
+            EmptySecretSettings.INSTANCE,
+            components
+        );
+    }
+
+    private static ElasticInferenceServiceSparseEmbeddingsModel createSparseEmbeddingsModel(
+        ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint,
+        ElasticInferenceServiceComponents components
+    ) {
+        return new ElasticInferenceServiceSparseEmbeddingsModel(
+            authorizedEndpoint.id(),
+            TaskType.SPARSE_EMBEDDING,
+            ElasticInferenceService.NAME,
+            new ElasticInferenceServiceSparseEmbeddingsServiceSettings(authorizedEndpoint.modelName(), null),
+            EmptyTaskSettings.INSTANCE,
+            EmptySecretSettings.INSTANCE,
+            components,
+            ChunkingSettingsBuilder.fromMap(getChunkingSettingsMap(getConfigurationOrEmpty(authorizedEndpoint)))
+        );
+    }
+
+    /**
+     * Returns the endpoint's configuration, substituting the empty configuration when the response did not include one.
+     */
+    private static ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration getConfigurationOrEmpty(
+        ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint
+    ) {
+        if (authorizedEndpoint.configuration() != null) {
+            return authorizedEndpoint.configuration();
+        }
+
+        return ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.EMPTY;
+    }
+
+    /**
+     * Returns the chunking settings of the configuration, or a new empty map when none were provided.
+     */
+    private static Map<String, Object> getChunkingSettingsMap(
+        ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration configuration
+    ) {
+        return Objects.requireNonNullElse(configuration.chunkingSettings(), new HashMap<>());
+    }
+
+    /**
+     * @throws IllegalArgumentException if the configuration lacks the element type, dimensions or similarity
+     */
+    private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEmbeddingsModel(
+        ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint,
+        ElasticInferenceServiceComponents components
+    ) {
+        var config = getConfigurationOrEmpty(authorizedEndpoint);
+        validateConfigurationForTextEmbedding(config);
+
+        return new ElasticInferenceServiceDenseTextEmbeddingsModel(
+            authorizedEndpoint.id(),
+            TaskType.TEXT_EMBEDDING,
+            ElasticInferenceService.NAME,
+            new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
+                authorizedEndpoint.modelName(),
+                getSimilarityMeasure(config),
+                config.dimensions(),
+                null
+            ),
+            EmptyTaskSettings.INSTANCE,
+            EmptySecretSettings.INSTANCE,
+            components,
+            ChunkingSettingsBuilder.fromMap(getChunkingSettingsMap(config))
+        );
+    }
+
+    /**
+     * Ensures the fields required for a text embedding endpoint (element type, dimensions, similarity) are present.
+     */
+    private static void validateConfigurationForTextEmbedding(ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration config) {
+        validateFieldPresent(
+            ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.ELEMENT_TYPE,
+            config.elementType(),
+            TaskType.TEXT_EMBEDDING
+        );
+        validateFieldPresent(
+            ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.DIMENSIONS,
+            config.dimensions(),
+            TaskType.TEXT_EMBEDDING
+        );
+        validateFieldPresent(
+            ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.SIMILARITY,
+            config.similarity(),
+            TaskType.TEXT_EMBEDDING
+        );
+    }
+
+    /**
+     * @throws IllegalArgumentException if {@code fieldValue} is {@code null}
+     */
+    private static void validateFieldPresent(String field, Object fieldValue, TaskType taskType) {
+        if (fieldValue == null) {
+            throw new IllegalArgumentException(
+                Strings.format("Required field [%s] is missing for task_type [%s]", field, taskType.toString())
+            );
+        }
+    }
+
+    /**
+     * Parses the similarity of the configuration. The presence check is repeated here so the method is safe to call on its own.
+     */
+    private static SimilarityMeasure getSimilarityMeasure(
+        ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration configuration
+    ) {
+        validateFieldPresent(
+            ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.SIMILARITY,
+            configuration.similarity(),
+            TaskType.TEXT_EMBEDDING
+        );
+
+        return SimilarityMeasure.fromString(configuration.similarity());
+    }
+
+    private static ElasticInferenceServiceRerankModel createRerankModel(
+        ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint,
+        ElasticInferenceServiceComponents components
+    ) {
+        return new ElasticInferenceServiceRerankModel(
+            authorizedEndpoint.id(),
+            TaskType.RERANK,
+            ElasticInferenceService.NAME,
+            new ElasticInferenceServiceRerankServiceSettings(authorizedEndpoint.modelName()),
+            EmptyTaskSettings.INSTANCE,
+            EmptySecretSettings.INSTANCE,
+            components
+        );
+    }
+
+    /**
+     * Returns an object indicating that the cluster is not authorized for any endpoints from EIS.
+     */
+    public static AuthorizationModel empty() {
+        return new AuthorizationModel(List.of());
+    }
+
+    // Keyed by inference entity id
+    private final Map<String, ElasticInferenceServiceModel> authorizedEndpoints;
+    private final EnumSet<TaskType> taskTypes;
+
+    // Default for testing
+    AuthorizationModel(List<ElasticInferenceServiceModel> authorizedEndpoints) {
+        Objects.requireNonNull(authorizedEndpoints);
+        // If the same inference id occurs more than once, the first occurrence wins
+        this.authorizedEndpoints = authorizedEndpoints.stream()
+            .collect(Collectors.toMap(ElasticInferenceServiceModel::getInferenceEntityId, Function.identity(), (a, b) -> a, HashMap::new));
+
+        var taskTypesSet = EnumSet.noneOf(TaskType.class);
+        taskTypesSet.addAll(this.authorizedEndpoints.values().stream().map(ElasticInferenceServiceModel::getTaskType).toList());
+        this.taskTypes = taskTypesSet;
+    }
+
+    /**
+     * Returns true if at least one endpoint is authorized.
+     * @return true if this cluster is authorized for at least one endpoint.
+     */
+    public boolean isAuthorized() {
+        return authorizedEndpoints.isEmpty() == false;
+    }
+
+    /**
+     * Returns a new {@link AuthorizationModel} object retaining only the specified task types
+     * and applicable models that leverage those task types. Any task types not specified in the provided parameter will be
+     * excluded from the returned object. This is essentially an intersection.
+     * @param taskTypes the task types to retain in the newly created object
+     * @return a new object containing endpoints limited to the specified task types
+     */
+    public AuthorizationModel newLimitedToTaskTypes(EnumSet<TaskType> taskTypes) {
+        var endpoints = this.authorizedEndpoints.values().stream().filter(endpoint -> taskTypes.contains(endpoint.getTaskType())).toList();
+        return new AuthorizationModel(endpoints);
+    }
+
+    /**
+     * Returns a copy of the task types covered by the authorized endpoints.
+     */
+    public EnumSet<TaskType> getTaskTypes() {
+        return EnumSet.copyOf(taskTypes);
+    }
+
+    /**
+     * Returns the inference ids of all authorized endpoints.
+     * <p>
+     * The result is a new, modifiable set that is independent of this object, so callers may narrow it down
+     * (for example with {@link Set#removeAll}) without affecting this model. An unmodifiable copy would make such
+     * calls fail with an {@link UnsupportedOperationException}.
+     */
+    public Set<String> getEndpointIds() {
+        return new java.util.HashSet<>(authorizedEndpoints.keySet());
+    }
+
+    /**
+     * Returns the authorized endpoints for the given inference ids. Ids without an authorized endpoint are ignored.
+     */
+    public List<ElasticInferenceServiceModel> getEndpoints(Set<String> endpointIds) {
+        return endpointIds.stream().map(authorizedEndpoints::get).filter(Objects::nonNull).toList();
+    }
+
+    @Override
+    public String toString() {
+        return Strings.format("AuthorizationModel{authorizedEndpoints=%s, taskTypes=%s}", authorizedEndpoints, taskTypes);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AuthorizationModel that = (AuthorizationModel) o;
+        return Objects.equals(authorizedEndpoints, that.authorizedEndpoints) && Objects.equals(taskTypes, that.taskTypes);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(authorizedEndpoints, taskTypes);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index 70572e2b9c996..4dacafd7950df 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.persistent.PersistentTasksService; import org.elasticsearch.tasks.TaskId; @@ -28,18 +29,15 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import java.util.EnumSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.IMPLEMENTED_TASK_TYPES; @@ -233,7 +231,7 @@ void sendAuthorizationRequest() { delegate.onResponse(null); }); - SubscribableListener.newForked( + SubscribableListener.newForked( authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender) ) .andThenApply(this::getNewInferenceEndpointsToStore) @@ -241,31 
+239,27 @@ void sendAuthorizationRequest() { .addListener(finalListener); } - private Set getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) { + private List getNewInferenceEndpointsToStore(AuthorizationModel authModel) { var scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); - var authorizedModelIds = scopedAuthModel.getAuthorizedModelIds(); + var newEndpointIds = scopedAuthModel.getEndpointIds(); var existingInferenceIds = modelRegistry.getInferenceIds(); - var newInferenceIds = authorizedModelIds.stream() - .map(InternalPreconfiguredEndpoints::getWithModelName) - .flatMap(List::stream) - .map(model -> model.configurations().getInferenceEntityId()) - .collect(Collectors.toSet()); - - newInferenceIds.removeAll(existingInferenceIds); - return newInferenceIds; + newEndpointIds.removeAll(existingInferenceIds); + return scopedAuthModel.getEndpoints(newEndpointIds); } - private void storePreconfiguredModels(Set newInferenceIds, ActionListener listener) { - if (newInferenceIds.isEmpty()) { + private void storePreconfiguredModels(List newEndpoints, ActionListener listener) { + if (newEndpoints.isEmpty()) { listener.onResponse(null); return; } - logger.info("Storing new EIS preconfigured inference endpoints with inference IDs {}", newInferenceIds); - var modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, elasticInferenceServiceComponents); - var storeRequest = new StoreInferenceEndpointsAction.Request(modelsToAdd, TimeValue.THIRTY_SECONDS); + logger.info( + "Storing new EIS preconfigured inference endpoints with inference IDs {}", + newEndpoints.stream().map(Model::getInferenceEntityId).toList() + ); + var storeRequest = new StoreInferenceEndpointsAction.Request(newEndpoints, TimeValue.THIRTY_SECONDS); ActionListener logResultsListener = ActionListener.wrap(responses -> { for (var response : responses.getResults()) { @@ -278,7 +272,7 @@ private void 
storePreconfiguredModels(Set newInferenceIds, ActionListene .log("Successfully stored EIS preconfigured inference endpoint with inference ID [{}]", response.inferenceId()); } } - }, e -> logger.atWarn().withThrowable(e).log("Failed to store new EIS preconfigured inference endpoints [{}]", newInferenceIds)); + }, e -> logger.atWarn().withThrowable(e).log("Failed to store new EIS preconfigured inference endpoints [{}]", newEndpoints)); client.execute( StoreInferenceEndpointsAction.INSTANCE, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 02800105ef83d..51a94914390eb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -23,7 +23,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.Objects; @@ -46,7 +46,7 @@ public class ElasticInferenceServiceAuthorizationRequestHandler { private static ResponseHandler createAuthResponseHandler() { return new 
ElasticInferenceServiceResponseHandler( Strings.format("%s authorization", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), - ElasticInferenceServiceAuthorizationResponseEntity::fromResponse + ElasticInferenceServiceAuthorizationResponseEntityV2::fromResponse ); } @@ -73,13 +73,13 @@ public ElasticInferenceServiceAuthorizationRequestHandler(@Nullable String baseU * @param listener a listener to receive the response * @param sender a {@link Sender} for making the request to the Elastic Inference Service */ - public void getAuthorization(ActionListener listener, Sender sender) { + public void getAuthorization(ActionListener listener, Sender sender) { try { logger.debug("Retrieving authorization information from the Elastic Inference Service."); if (Strings.isNullOrEmpty(baseUrl)) { logger.debug("The base URL for the authorization service is not valid, rejecting authorization."); - listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); + listener.onResponse(AuthorizationModel.empty()); return; } @@ -96,9 +96,9 @@ public void getAuthorization(ActionListener { - if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { + if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntityV2 authResponseEntity) { logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); - return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity); + return AuthorizationModel.of(authResponseEntity, baseUrl); } var errorMessage = Strings.format( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 969bf06d47fe0..31d0e171ac673 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -78,8 +78,8 @@ public ElasticInferenceServiceCompletionModel( TaskType taskType, String service, ElasticInferenceServiceCompletionServiceSettings serviceSettings, - @Nullable TaskSettings taskSettings, - @Nullable SecretSettings secretSettings, + TaskSettings taskSettings, + SecretSettings secretSettings, ElasticInferenceServiceComponents elasticInferenceServiceComponents ) { super( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java index 85c9b1700401b..d4528c8401a0c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.elastic.response; -import org.apache.http.auth.AUTH; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -102,19 +101,6 @@ public class ElasticInferenceServiceAuthorizationResponseEntityV2 implements Inf public static final String NAME = "elastic_inference_service_auth_results_v2"; - private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntityV2.class); - private static final String AUTH_FIELD_NAME = 
"authorized_models"; - private static final Map ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of( - "embed/text/sparse", - TaskType.SPARSE_EMBEDDING, - "chat", - TaskType.CHAT_COMPLETION, - "embed/text/dense", - TaskType.TEXT_EMBEDDING, - "rerank/text/text-similarity", - TaskType.RERANK - ); - private static final String INFERENCE_ENDPOINTS = "inference_endpoints"; @SuppressWarnings("unchecked") @@ -173,52 +159,109 @@ public record AuthorizedEndpoint( AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(STATUS)); AUTHORIZED_ENDPOINT_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(PROPERTIES)); AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(RELEASE_DATE)); - AUTHORIZED_ENDPOINT_PARSER.declareObject( - optionalConstructorArg(), - Configuration.CONFIGURATION_PARSER::apply, - new ParseField(CONFIGURATION) - ); - } - - private static EnumSet toTaskTypes(List stringTaskTypes) { - var taskTypes = EnumSet.noneOf(TaskType.class); - for (String taskType : stringTaskTypes) { - var mappedTaskType = ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING.get(taskType); - if (mappedTaskType != null) { - taskTypes.add(mappedTaskType); - } - } - - return taskTypes; + AUTHORIZED_ENDPOINT_PARSER.declareObject(optionalConstructorArg(), Configuration.PARSER::apply, new ParseField(CONFIGURATION)); } public AuthorizedEndpoint(StreamInput in) throws IOException { - this(in.readString(), in.readEnumSet(TaskType.class)); + this( + in.readString(), + in.readString(), + in.readString(), + in.readString(), + in.readOptionalCollectionAsList(StreamInput::readString), + in.readString(), + in.readOptionalWriteable(Configuration::new) + ); } @Override public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); out.writeString(modelName); - out.writeEnumSet(taskTypes); + out.writeString(taskType); + out.writeString(status); + out.writeOptionalCollection(properties, StreamOutput::writeString); + 
out.writeString(releaseDate); + out.writeOptionalWriteable(configuration); } - @Override + @Override + public String toString() { + return Strings.format( + "AuthorizedEndpoint{id='%s', modelName='%s', taskType='%s', status='%s', " + + "properties=%s, releaseDate='%s', configuration=%s}", + id, + modelName, + taskType, + status, + properties, + releaseDate, + configuration + ); + } + + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ID, ) - - builder.field("model_name", modelName); - builder.field("task_types", taskTypes.stream().map(TaskType::toString).collect(Collectors.toList())); + builder.field(ID, id); + builder.field(MODEL_NAME, modelName); + builder.field(TASK_TYPE, taskType); + builder.field(STATUS, status); + if (properties != null) { + builder.field(PROPERTIES, properties); + } + builder.field(RELEASE_DATE, releaseDate); + if (configuration != null) { + builder.field(CONFIGURATION, configuration); + } builder.endObject(); return builder; } + } + + public record TaskTypeObject(String eisTaskType, String elasticsearchTaskType) implements Writeable, ToXContentObject { + + private static final String EIS_TASK_TYPE_FIELD = "eis"; + private static final String ELASTICSEARCH_TASK_TYPE_FIELD = "elasticsearch"; + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + TaskTypeObject.class.getSimpleName(), + true, + args -> new TaskTypeObject((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(optionalConstructorArg(), new ParseField(EIS_TASK_TYPE_FIELD)); + PARSER.declareString(constructorArg(), new ParseField(ELASTICSEARCH_TASK_TYPE_FIELD)); + } + + public TaskTypeObject(StreamInput in) throws IOException { + this(in.readOptionalString(), in.readString()); + } @Override public String toString() { - return Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes); + return 
Strings.format("TaskTypeObject{eisTaskType='%s', elasticsearchTaskType='%s'}", eisTaskType, elasticsearchTaskType); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(eisTaskType); + out.writeString(elasticsearchTaskType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (eisTaskType != null) { + builder.field(EIS_TASK_TYPE_FIELD, eisTaskType); + } + builder.field(ELASTICSEARCH_TASK_TYPE_FIELD, elasticsearchTaskType); + builder.endObject(); + return builder; } } @@ -229,23 +272,25 @@ public record Configuration( @Nullable Map chunkingSettings ) implements Writeable, ToXContentObject { - private static final String SIMILARITY = "similarity"; - private static final String DIMENSIONS = "dimensions"; - private static final String ELEMENT_TYPE = "element_type"; - private static final String CHUNKING_SETTINGS = "chunking_settings"; + public static final Configuration EMPTY = new Configuration(null, null, null, null); + + public static final String SIMILARITY = "similarity"; + public static final String DIMENSIONS = "dimensions"; + public static final String ELEMENT_TYPE = "element_type"; + public static final String CHUNKING_SETTINGS = "chunking_settings"; @SuppressWarnings("unchecked") - public static final ConstructingObjectParser CONFIGURATION_PARSER = new ConstructingObjectParser<>( + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( Configuration.class.getSimpleName(), true, args -> new Configuration((String) args[0], (Integer) args[1], (String) args[2], (Map) args[3]) ); static { - CONFIGURATION_PARSER.declareString(optionalConstructorArg(), new ParseField(SIMILARITY)); - CONFIGURATION_PARSER.declareInt(optionalConstructorArg(), new ParseField(DIMENSIONS)); - CONFIGURATION_PARSER.declareString(optionalConstructorArg(), new ParseField(ELEMENT_TYPE)); - 
CONFIGURATION_PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField(CHUNKING_SETTINGS)); + PARSER.declareString(optionalConstructorArg(), new ParseField(SIMILARITY)); + PARSER.declareInt(optionalConstructorArg(), new ParseField(DIMENSIONS)); + PARSER.declareString(optionalConstructorArg(), new ParseField(ELEMENT_TYPE)); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField(CHUNKING_SETTINGS)); } public Configuration(StreamInput in) throws IOException { @@ -282,12 +327,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + @Override + public String toString() { + return Strings.format( + "Configuration{similarity='%s', dimensions=%s, elementType='%s', chunkingSettings=%s}", + similarity, + dimensions, + elementType, + chunkingSettings + ); + } } - private final List authorizedModels; + private final List authorizedEndpoints; public ElasticInferenceServiceAuthorizationResponseEntityV2(List authorizedModels) { - this.authorizedModels = Objects.requireNonNull(authorizedModels); + this.authorizedEndpoints = Objects.requireNonNull(authorizedModels); } /** @@ -310,23 +366,23 @@ public static ElasticInferenceServiceAuthorizationResponseEntityV2 fromResponse( } } - public List getAuthorizedModels() { - return authorizedModels; + public List getAuthorizedEndpoints() { + return authorizedEndpoints; } @Override public String toString() { - return authorizedModels.stream().map(AuthorizedEndpoint::toString).collect(Collectors.joining(", ")); + return authorizedEndpoints.stream().map(AuthorizedEndpoint::toString).collect(Collectors.joining(", ")); } @Override public Iterator toXContentChunked(ToXContent.Params params) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("Not implemented"); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(authorizedModels); + 
out.writeCollection(authorizedEndpoints); } @Override @@ -349,11 +405,11 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ElasticInferenceServiceAuthorizationResponseEntityV2 that = (ElasticInferenceServiceAuthorizationResponseEntityV2) o; - return Objects.equals(authorizedModels, that.authorizedModels); + return Objects.equals(authorizedEndpoints, that.authorizedEndpoints); } @Override public int hashCode() { - return Objects.hash(authorizedModels); + return Objects.hash(authorizedEndpoints); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 4b17cab04471a..abec53f9111b8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -48,14 +48,11 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModelTests; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import 
org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -911,27 +908,8 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio } } - public void testHideFromConfigurationApi_ThrowsUnsupported_WithNoAvailableModels() throws Exception { - try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); - } - } - public void testHideFromConfigurationApi_ThrowsUnsupported_WithAvailableModels() throws Exception { - try ( - var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ) - ) - ) { + try (var service = createServiceWithMockSender()) { expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } @@ -1015,21 +993,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { } public void testGetConfiguration_ThrowsUnsupported() throws Exception { - try ( - var service = createServiceWithMockSender( - // this service doesn't yet support text embedding so we should still have no task types - ElasticInferenceServiceAuthorizationModel.of( - new 
ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ) { + try (var service = createServiceWithMockSender()) { expectThrows(UnsupportedOperationException.class, service::getConfiguration); } } @@ -1154,10 +1118,6 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp } private ElasticInferenceService createServiceWithMockSender() { - return createServiceWithMockSender(ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth()); - } - - private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel auth) { var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java new file mode 100644 index 0000000000000..2c363b04ea277 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java @@ -0,0 +1,451 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsInAnyOrder; + +public class AuthorizationModelTests extends ESTestCase { + 
public static AuthorizationModel createAuthorizationModel(TaskType taskType) { + String id = randomAlphaOfLength(10); + String name = randomAlphaOfLength(10); + String url = "url"; + ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint endpoint; + var status = randomFrom("ga", "beta", "preview"); + + + switch (taskType) { + case CHAT_COMPLETION -> endpoint = new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.CHAT_COMPLETION.toString(), + status, + null, + "", + null + ); + case SPARSE_EMBEDDING -> endpoint = new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.SPARSE_EMBEDDING.toString(), + status, + null, + "", + null + ); + case TEXT_EMBEDDING -> endpoint = new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.TEXT_EMBEDDING.toString(), + status, + null, + "", + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + randomFrom(SimilarityMeasure.values()).toString(), + randomInt(), + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + null + ) + ); + case RERANK -> endpoint = new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.RERANK.toString(), + status, + null, + "", + null + ); + default -> throw new IllegalArgumentException("Unsupported task type: " + taskType); + } + + var response = new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(endpoint)); + return AuthorizationModel.of(response, url); + } + + public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { + assertFalse(new AuthorizationModel(List.of()).isAuthorized()); + assertFalse(AuthorizationModel.empty().isAuthorized()); + } + + public void testExcludes_EndpointsWithoutValidTaskTypes() { + var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + 
"id", + "name", + "invalid_task_type", + "ga", + null, + "", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + "id2", + "name", + TaskType.ANY.toString(), + "ga", + null, + "", + null + ) + ) + ); + var auth = AuthorizationModel.of(response, "url"); + assertTrue(auth.getTaskTypes().isEmpty()); + assertFalse(auth.isAuthorized()); + } + + public void testReturnsAuthorizedTaskTypes() { + var id1 = "id1"; + var id2 = "id2"; + + var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id1, + "name1", + TaskType.CHAT_COMPLETION.toString(), + "ga", + null, + "", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id2, + "name2", + TaskType.SPARSE_EMBEDDING.toString(), + "ga", + null, + "", + null + ) + ) + ); + + var auth = AuthorizationModel.of(response, "url"); + assertThat(auth.getTaskTypes(), is(Set.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + assertThat(auth.getEndpointIds(), is(Set.of(id1, id2))); + assertTrue(auth.isAuthorized()); + } + + public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { + var id = "id1"; + + var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + "name1", + TaskType.CHAT_COMPLETION.toString(), + "ga", + null, + "", + null + ), + // This should be ignored because the id is a duplicate + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + "name2", + TaskType.SPARSE_EMBEDDING.toString(), + "ga", + null, + "", + null + ) + ) + ); + + var auth = AuthorizationModel.of(response, "url"); + assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION))); + assertThat(auth.getEndpointIds(), is(Set.of(id))); + assertTrue(auth.isAuthorized()); + } + + public void 
testReturnsAuthorizedEndpoints() { + var id1 = "id1"; + var id2 = "id2"; + + var name1 = "name1"; + var name2 = "name2"; + + var similarity = SimilarityMeasure.COSINE; + var dimensions = 123; + + var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id1, + name1, + TaskType.CHAT_COMPLETION.toString(), + "ga", + null, + "", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id2, + name2, + TaskType.TEXT_EMBEDDING.toString(), + "ga", + null, + "", + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + similarity.toString(), + dimensions, + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + null + ) + ) + ) + ); + + var url = "url"; + + var auth = AuthorizationModel.of(response, url); + assertThat(auth.getEndpointIds(), is(Set.of(id1, id2))); + assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING))); + assertTrue(auth.isAuthorized()); + + var chatCompletionEndpoint = new ElasticInferenceServiceCompletionModel( + id1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(name1), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + var textEmbeddingEndpoint = new ElasticInferenceServiceDenseTextEmbeddingsModel( + id2, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(name2, similarity, dimensions, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ); + + assertThat(auth.getEndpoints(Set.of(id1, id2)), containsInAnyOrder(chatCompletionEndpoint, textEmbeddingEndpoint)); + + assertThat(auth.getEndpoints(Set.of(id2)), is(List.of(textEmbeddingEndpoint))); + + 
assertThat(auth.getEndpoints(Set.of()), is(List.of())); + } + + public void testReturnsAuthorizedEndpoints_FiltersInvalid() { + var id1 = "id1"; + var id2 = "invalid_text_embedding"; + + var name1 = "name1"; + var name2 = "name2"; + + var dimensions = 123; + + var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id1, + name1, + TaskType.CHAT_COMPLETION.toString(), + "ga", + null, + "", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id2, + name2, + TaskType.TEXT_EMBEDDING.toString(), + "ga", + null, + "", + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + null, + dimensions, + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + null + ) + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id2, + name2, + TaskType.TEXT_EMBEDDING.toString(), + "ga", + null, + "", + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + SimilarityMeasure.DOT_PRODUCT.toString(), + dimensions, + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + // Invalid chunking settings + Map.of("unexpected_field", "unexpected_value") + ) + ) + ) + ); + + var url = "url"; + + var auth = AuthorizationModel.of(response, url); + assertThat(auth.getEndpointIds(), is(Set.of(id1))); + assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION))); + assertTrue(auth.isAuthorized()); + + var chatCompletionEndpoint = new ElasticInferenceServiceCompletionModel( + id1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(name1), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + + assertThat(auth.getEndpoints(Set.of(id1, id2)), is(List.of(chatCompletionEndpoint))); + + assertThat(auth.getEndpoints(Set.of(id2)), is(List.of())); + 
assertThat(auth.getEndpoints(Set.of()), is(List.of())); + } + + public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { + var idChat = "id_chat"; + var idSparse = "id_sparse"; + var idDense = "id_dense"; + var idRerank = "id_rerank"; + + var nameChat = "chat_model"; + var nameSparse = "sparse_model"; + var nameDense = "dense_model"; + var nameRerank = "rerank_model"; + + var similarity = SimilarityMeasure.COSINE; + var dimensions = 256; + var elementType = DenseVectorFieldMapper.ElementType.FLOAT.toString(); + + var url = "base_url"; + + var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + idChat, + nameChat, + TaskType.CHAT_COMPLETION.toString(), + "ga", + null, + "", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + idSparse, + nameSparse, + TaskType.SPARSE_EMBEDDING.toString(), + "ga", + null, + "", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + idDense, + nameDense, + TaskType.TEXT_EMBEDDING.toString(), + "ga", + null, + "", + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + similarity.toString(), + dimensions, + elementType, + null + ) + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + idRerank, + nameRerank, + TaskType.RERANK.toString(), + "ga", + null, + "", + null + ) + ) + ); + + var auth = AuthorizationModel.of(response, url); + + var endpoints = auth.getEndpoints(Set.of(idChat, idSparse, idDense, idRerank)); + assertThat(endpoints.size(), is(4)); + assertThat( + endpoints, + containsInAnyOrder( + new ElasticInferenceServiceCompletionModel( + idChat, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(nameChat), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ), 
+ new ElasticInferenceServiceSparseEmbeddingsModel( + idSparse, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(nameSparse, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ElasticInferenceServiceDenseTextEmbeddingsModel( + idDense, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(nameDense, similarity, dimensions, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ElasticInferenceServiceRerankModel( + idRerank, + TaskType.RERANK, + ElasticInferenceService.NAME, + new ElasticInferenceServiceRerankServiceSettings(nameRerank), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index d0d3a67b2d9d5..dc8eb780c801e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -11,19 +11,26 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import 
org.elasticsearch.persistent.PersistentTasksService; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -37,6 +44,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationEntityV2Tests.createAuthorizedEndpoint; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -82,20 +90,13 @@ public void 
testSendsAuthorizationRequest_WhenModelRegistryIsReady() { when(mockRegistry.isReady()).thenReturn(true); when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + var url = "eis-url"; + var sparseModel = createAuthorizedEndpoint(TaskType.SPARSE_EMBEDDING); var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(0); listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) + AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url) ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -123,9 +124,15 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { capturedRequest.getModels(), is( List.of( - PreconfiguredEndpointModelAdapter.createModel( - InternalPreconfiguredEndpoints.getWithInferenceId(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2), - new ElasticInferenceServiceComponents("") + new ElasticInferenceServiceSparseEmbeddingsModel( + sparseModel.id(), + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(sparseModel.modelName(), null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + ChunkingSettingsBuilder.DEFAULT_SETTINGS ) ) ) @@ -133,66 +140,18 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { } public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInferenceIdAlreadyExists() { - var mockRegistry = mock(ModelRegistry.class); - 
when(mockRegistry.isReady()).thenReturn(true); - when(mockRegistry.getInferenceIds()).thenReturn(Set.of(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2, "id2")); + var url = "eis-url"; + var sparseModel = createAuthorizedEndpoint(TaskType.SPARSE_EMBEDDING); - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - - var mockClient = mock(Client.class); - when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); - - var poller = new AuthorizationPoller( - new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), - createWithEmptySettings(taskQueue.getThreadPool()), - mockAuthHandler, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - mockRegistry, - mockClient, - null - ); - - poller.sendAuthorizationRequest(); - verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); - } - - public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMapping() { var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); - when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of(sparseModel.id(), "id2")); var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = 
invocation.getArgument(0); + ActionListener listener = invocation.getArgument(0); listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - // This is a model id that does not exist in the preconfigured endpoints map so it will not be stored - "abc", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) + AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url) ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -216,25 +175,18 @@ public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMa } public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegration_DoesNotSupport() { + var url = "eis-url"; + var completionModel = createAuthorizedEndpoint(TaskType.COMPLETION); + var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(0); listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, - // EIS does not yet support completions so this model will be ignored - EnumSet.of(TaskType.COMPLETION) - ) - ) - ) - ) + AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(completionModel)), url) ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java deleted file mode 100644 index 18e937a290c2b..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java +++ /dev/null @@ -1,385 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; - -import java.util.EnumSet; -import java.util.List; -import java.util.Set; - -import static org.hamcrest.Matchers.is; - -public class ElasticInferenceServiceAuthorizationModelTests extends ESTestCase { - public static ElasticInferenceServiceAuthorizationModel createEnabledAuth() { - return ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING)) - ) - ) - ); - } - - public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { - assertFalse(ElasticInferenceServiceAuthorizationModel.newDisabledService().isAuthorized()); - } - - public void testExcludes_ModelsWithoutTaskTypes() { - var response = new ElasticInferenceServiceAuthorizationResponseEntity( - List.of(new 
ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.noneOf(TaskType.class))) - ); - var auth = ElasticInferenceServiceAuthorizationModel.of(response); - assertTrue(auth.getAuthorizedTaskTypes().isEmpty()); - assertFalse(auth.isAuthorized()); - } - - public void testEnabledTaskTypes_MergesFromSeparateModels() { - var auth = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING)), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.SPARSE_EMBEDDING)) - ) - ) - ); - assertThat(auth.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); - assertThat(auth.getAuthorizedModelIds(), is(Set.of("model-1", "model-2"))); - } - - public void testEnabledTaskTypes_FromSingleEntry() { - var auth = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ); - - assertThat(auth.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))); - assertThat(auth.getAuthorizedModelIds(), is(Set.of("model-1"))); - } - - public void testNewLimitToTaskTypes_SingleModel() { - var auth = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.CHAT_COMPLETION)) - ) - ) - ); - - assertThat( - 
auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING)), - is( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ); - } - - public void testNewLimitToTaskTypes_MultipleModels_OnlyTextEmbedding() { - var auth = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.TEXT_EMBEDDING)) - ) - ) - ); - - assertThat( - auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING)), - is( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-2", - EnumSet.of(TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ); - } - - public void testNewLimitToTaskTypes_MultipleModels_MultipleTaskTypes() { - var auth = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-text-sparse", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-sparse", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-chat-completion", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ); - - 
var limitedAuth = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)); - assertThat( - limitedAuth, - is( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-text-sparse", - EnumSet.of(TaskType.TEXT_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-chat-completion", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ) - ) - ); - } - - public void testNewLimitToTaskTypes_DuplicateModelNames() { - var auth = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK) - ) - ) - ) - ); - - var limitedAuth = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)); - assertThat( - limitedAuth, - is( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK) - ) - ) - ) - ) - ) - ); - } - - public void testNewLimitToTaskTypes_ReturnsDisabled_WhenNoOverlapForTaskTypes() { - var auth = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ), - new 
ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-2", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING) - ) - ) - ) - ); - - var limitedAuth = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.RERANK)); - assertThat(limitedAuth, is(ElasticInferenceServiceAuthorizationModel.newDisabledService())); - } - - public void testMerge_CombinesCorrectly() { - var auth1 = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ); - - var auth2 = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.SPARSE_EMBEDDING)) - ) - ) - ); - - assertThat( - auth1.merge(auth2), - is( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-2", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ) - ); - } - - public void testMerge_AddsNewTaskType() { - var auth1 = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ); - - var auth2 = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new 
ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.CHAT_COMPLETION)) - ) - ) - ); - - assertThat( - auth1.merge(auth2), - is( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-2", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ) - ) - ); - } - - public void testMerge_IgnoresDuplicates() { - var auth1 = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ); - - var auth2 = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ); - - assertThat( - auth1.merge(auth2), - is( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ) - ); - } - - public void testMerge_CombinesCorrectlyWithEmptyModel() { - var auth1 = ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ); - - var auth2 = 
ElasticInferenceServiceAuthorizationModel.newDisabledService(); - - assertThat( - auth1.merge(auth2), - is( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ) - ); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index e3d24ea2ec8f7..850cf4aa1c716 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -77,12 +77,12 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(null, threadPool, logger); try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertTrue(authResponse.getTaskTypes().isEmpty()); + assertTrue(authResponse.getEndpointIds().isEmpty()); assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); @@ -99,12 +99,12 @@ public void 
testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("", threadPool, logger); try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertTrue(authResponse.getTaskTypes().isEmpty()); + assertTrue(authResponse.getEndpointIds().isEmpty()); assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); @@ -135,7 +135,7 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep queueWebServerResponsesForRetries(responseJson); - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT)); @@ -181,12 +181,12 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a"))); + assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + assertThat(authResponse.getEndpointIds(), is(Set.of("model-a"))); assertTrue(authResponse.isAuthorized()); var 
loggerArgsCaptor = ArgumentCaptor.forClass(String.class); @@ -203,8 +203,8 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { var logger = mock(Logger.class); var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); - PlainActionFuture listener = new PlainActionFuture<>(); - ActionListener onlyOnceListener = ActionListener.assertOnce(listener); + PlainActionFuture listener = new PlainActionFuture<>(); + ActionListener onlyOnceListener = ActionListener.assertOnce(listener); String responseJson = """ { "models": [ @@ -222,8 +222,8 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { authHandler.waitForAuthRequestCompletion(TIMEOUT); var authResponse = listener.actionGet(TIMEOUT); - assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a"))); + assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + assertThat(authResponse.getEndpointIds(), is(Set.of("model-a"))); assertTrue(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); @@ -249,7 +249,7 @@ public void testGetAuthorization_InvalidResponse() throws IOException { var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("abc", threadPool, logger); try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationEntityV2Tests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationEntityV2Tests.java new file mode 100644 index 0000000000000..2ed0b38eb303b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationEntityV2Tests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.response; + +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; + +public class ElasticInferenceServiceAuthorizationEntityV2Tests extends ESTestCase { + public static ElasticInferenceServiceAuthorizationResponseEntityV2 createResponse() { + return new ElasticInferenceServiceAuthorizationResponseEntityV2( + randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) + ); + } + + public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { + var id = randomAlphaOfLength(10); + var name = randomAlphaOfLength(10); + var status = randomFrom("ga", "beta", "preview"); + + return switch (taskType) { + case CHAT_COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.CHAT_COMPLETION.toString(), + status, + null, + "", + null + ); + case SPARSE_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + 
TaskType.SPARSE_EMBEDDING.toString(), + status, + null, + "", + null + ); + case TEXT_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.TEXT_EMBEDDING.toString(), + status, + null, + "", + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + randomFrom(SimilarityMeasure.values()).toString(), + randomInt(), + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + null + ) + ); + case RERANK -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.RERANK.toString(), + status, + null, + "", + null + ); + case COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.COMPLETION.toString(), + status, + null, + "", + null + ); + default -> throw new IllegalArgumentException("Unsupported task type: " + taskType); + }; + } +} From 6a31be2d0e872eb34909e31a1493bb4833a19011 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 18 Nov 2025 11:21:52 -0500 Subject: [PATCH 03/24] Fixing tests --- .../authorization/AuthorizationPoller.java | 3 +- .../PreconfiguredEndpointModelAdapter.java | 44 --- ...eServiceAuthorizationResponseEntityV2.java | 74 +--- .../AuthorizationModelTests.java | 144 +++++--- .../AuthorizationPollerTests.java | 61 +-- ...rviceAuthorizationRequestHandlerTests.java | 74 +++- ...reconfiguredEndpointModelAdapterTests.java | 183 --------- ...enceServiceAuthorizationEntityV2Tests.java | 82 ----- ...iceAuthorizationResponseEntityV2Tests.java | 346 ++++++++++++++++++ 9 files changed, 523 insertions(+), 488 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java delete mode 100644 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationEntityV2Tests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index 4dacafd7950df..9ab58d554ff82 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import java.util.EnumSet; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -242,7 +243,7 @@ void sendAuthorizationRequest() { private List getNewInferenceEndpointsToStore(AuthorizationModel authModel) { var scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); - var newEndpointIds = scopedAuthModel.getEndpointIds(); + var newEndpointIds = new HashSet<>(scopedAuthModel.getEndpointIds()); var existingInferenceIds = modelRegistry.getInferenceIds(); newEndpointIds.removeAll(existingInferenceIds); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java deleted file mode 100644 index ab23da7cab5b2..0000000000000 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; -import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; - -import java.util.List; -import java.util.Set; - -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS; - -public class PreconfiguredEndpointModelAdapter { - public static List getModels(Set inferenceIds, ElasticInferenceServiceComponents elasticInferenceServiceComponents) { - return inferenceIds.stream() - .sorted() - .filter(EIS_PRECONFIGURED_ENDPOINT_IDS::contains) - .map(id -> createModel(InternalPreconfiguredEndpoints.getWithInferenceId(id), elasticInferenceServiceComponents)) - .toList(); - } - - public static Model createModel( - InternalPreconfiguredEndpoints.MinimalModel minimalModel, - ElasticInferenceServiceComponents elasticInferenceServiceComponents - ) { - return new ElasticInferenceServiceModel( - minimalModel.configurations(), - new ModelSecrets(EmptySecretSettings.INSTANCE), - minimalModel.rateLimitServiceSettings(), - elasticInferenceServiceComponents - ); - } - - private PreconfiguredEndpointModelAdapter() {} -} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java index d4528c8401a0c..7d13eeeb89a6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java @@ -15,9 +15,6 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContent; @@ -31,7 +28,6 @@ import org.elasticsearch.xpack.inference.external.request.Request; import java.io.IOException; -import java.util.EnumSet; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -41,62 +37,6 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; -/* - -{ - "inference_endpoints": [ - { - "id": ".rainbow-sprinkles-elastic", - "model_name": "rainbow-sprinkles", - "task_type": "chat", - "status": "ga", - "properties": [ - "multilingual" - ], - "release_date": "2024-05-01", - "end_of_life_date": "2025-12-31" - }, - { - "id": ".elastic-elser-v2", - "model_name": "elser_model_2", - "task_type": "embed/text/sparse", - "status": "preview", - "properties": [ - "english" - ], - "release_date": 
"2024-05-01", - "configuration": { - "chunking_settings": { - "strategy": "sentence", - "max_chunk_size": 250, - "sentence_overlap": 1 - } - } - }, - { - "id": ".jina-embeddings-v3", - "model_name": "jina-embeddings-v3", - "task_type": "embed/text/dense", - "status": "beta", - "properties": [ - "multilingual", - "open-weights" - ], - "release_date": "2024-05-01", - "configuration": { - "similarity": "cosine", - "dimension": 1024, - "element_type": "float", - "chunking_settings": { - "strategy": "sentence", - "max_chunk_size": 250, - "sentence_overlap": 1 - } - } - } - ] -} - */ public class ElasticInferenceServiceAuthorizationResponseEntityV2 implements InferenceServiceResults { public static final String NAME = "elastic_inference_service_auth_results_v2"; @@ -126,6 +66,7 @@ public record AuthorizedEndpoint( String status, @Nullable List properties, String releaseDate, + @Nullable String endOfLifeDate, @Nullable Configuration configuration ) implements Writeable, ToXContentObject { @@ -135,6 +76,7 @@ public record AuthorizedEndpoint( private static final String STATUS = "status"; private static final String PROPERTIES = "properties"; private static final String RELEASE_DATE = "release_date"; + private static final String END_OF_LIFE_DATE = "end_of_life_date"; private static final String CONFIGURATION = "configuration"; @SuppressWarnings("unchecked") @@ -148,7 +90,8 @@ public record AuthorizedEndpoint( (String) args[3], (List) args[4], (String) args[5], - (Configuration) args[6] + (String) args[6], + (Configuration) args[7] ) ); @@ -159,6 +102,7 @@ public record AuthorizedEndpoint( AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(STATUS)); AUTHORIZED_ENDPOINT_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(PROPERTIES)); AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(RELEASE_DATE)); + AUTHORIZED_ENDPOINT_PARSER.declareString(optionalConstructorArg(), new ParseField(END_OF_LIFE_DATE)); 
AUTHORIZED_ENDPOINT_PARSER.declareObject(optionalConstructorArg(), Configuration.PARSER::apply, new ParseField(CONFIGURATION)); } @@ -170,6 +114,7 @@ public AuthorizedEndpoint(StreamInput in) throws IOException { in.readString(), in.readOptionalCollectionAsList(StreamInput::readString), in.readString(), + in.readOptionalString(), in.readOptionalWriteable(Configuration::new) ); } @@ -182,6 +127,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(status); out.writeOptionalCollection(properties, StreamOutput::writeString); out.writeString(releaseDate); + out.writeOptionalString(endOfLifeDate); out.writeOptionalWriteable(configuration); } @@ -189,13 +135,14 @@ public void writeTo(StreamOutput out) throws IOException { public String toString() { return Strings.format( "AuthorizedEndpoint{id='%s', modelName='%s', taskType='%s', status='%s', " - + "properties=%s, releaseDate='%s', configuration=%s}", + + "properties=%s, releaseDate='%s', endOfLifeDate='%s', configuration=%s}", id, modelName, taskType, status, properties, releaseDate, + endOfLifeDate, configuration ); } @@ -212,6 +159,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(PROPERTIES, properties); } builder.field(RELEASE_DATE, releaseDate); + if (endOfLifeDate != null) { + builder.field(END_OF_LIFE_DATE, endOfLifeDate); + } if (configuration != null) { builder.field(CONFIGURATION, configuration); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java index 2c363b04ea277..dce473017eb3a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; +import java.util.EnumSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -34,62 +35,6 @@ import static org.hamcrest.Matchers.containsInAnyOrder; public class AuthorizationModelTests extends ESTestCase { - public static AuthorizationModel createAuthorizationModel(TaskType taskType) { - String id = randomAlphaOfLength(10); - String name = randomAlphaOfLength(10); - String url = "url"; - ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint endpoint; - var status = randomFrom("ga", "beta", "preview"); - - - switch (taskType) { - case CHAT_COMPLETION -> endpoint = new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - TaskType.CHAT_COMPLETION.toString(), - status, - null, - "", - null - ); - case SPARSE_EMBEDDING -> endpoint = new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - TaskType.SPARSE_EMBEDDING.toString(), - status, - null, - "", - null - ); - case TEXT_EMBEDDING -> endpoint = new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - TaskType.TEXT_EMBEDDING.toString(), - status, - null, - "", - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( - randomFrom(SimilarityMeasure.values()).toString(), - randomInt(), - DenseVectorFieldMapper.ElementType.FLOAT.toString(), - null - ) - ); - case RERANK -> endpoint = new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - TaskType.RERANK.toString(), - status, - null, - "", - null - ); - default -> throw new 
IllegalArgumentException("Unsupported task type: " + taskType); - } - - var response = new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(endpoint)); - return AuthorizationModel.of(response, url); - } public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { assertFalse(new AuthorizationModel(List.of()).isAuthorized()); @@ -106,6 +51,7 @@ public void testExcludes_EndpointsWithoutValidTaskTypes() { "ga", null, "", + "", null ), new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( @@ -115,6 +61,7 @@ public void testExcludes_EndpointsWithoutValidTaskTypes() { "ga", null, "", + "", null ) ) @@ -137,6 +84,7 @@ public void testReturnsAuthorizedTaskTypes() { "ga", null, "", + "", null ), new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( @@ -146,6 +94,7 @@ public void testReturnsAuthorizedTaskTypes() { "ga", null, "", + "", null ) ) @@ -169,6 +118,7 @@ public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { "ga", null, "", + "", null ), // This should be ignored because the id is a duplicate @@ -179,6 +129,7 @@ public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { "ga", null, "", + "", null ) ) @@ -209,6 +160,7 @@ public void testReturnsAuthorizedEndpoints() { "ga", null, "", + "", null ), new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( @@ -218,6 +170,7 @@ public void testReturnsAuthorizedEndpoints() { "ga", null, "", + "", new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( similarity.toString(), dimensions, @@ -256,12 +209,80 @@ public void testReturnsAuthorizedEndpoints() { ); assertThat(auth.getEndpoints(Set.of(id1, id2)), containsInAnyOrder(chatCompletionEndpoint, textEmbeddingEndpoint)); - assertThat(auth.getEndpoints(Set.of(id2)), is(List.of(textEmbeddingEndpoint))); - assertThat(auth.getEndpoints(Set.of()), is(List.of())); } + public void testScopesToTaskType() { + var id1 = "id1"; + var id2 = "id2"; 
+ + var name1 = "name1"; + var name2 = "name2"; + + var similarity = SimilarityMeasure.COSINE; + var dimensions = 123; + + var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id1, + name1, + TaskType.CHAT_COMPLETION.toString(), + "ga", + null, + "", + "", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id2, + name2, + TaskType.TEXT_EMBEDDING.toString(), + "ga", + null, + "", + "", + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + similarity.toString(), + dimensions, + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + null + ) + ) + ) + ); + + var url = "url"; + + var auth = AuthorizationModel.of(response, url); + assertThat(auth.getEndpointIds(), is(Set.of(id1, id2))); + assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING))); + assertTrue(auth.isAuthorized()); + + var scopedToChatCompletion = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.CHAT_COMPLETION)); + assertThat(scopedToChatCompletion.getEndpointIds(), is(Set.of(id1))); + assertThat(scopedToChatCompletion.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION))); + assertTrue(scopedToChatCompletion.isAuthorized()); + + var chatCompletionEndpoint = new ElasticInferenceServiceCompletionModel( + id1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(name1), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + + assertThat(auth.getEndpoints(Set.of(id1)), is(List.of(chatCompletionEndpoint))); + + var scopedToNone = auth.newLimitedToTaskTypes(EnumSet.noneOf(TaskType.class)); + assertThat(scopedToNone.getEndpointIds(), is(Set.of())); + assertThat(scopedToNone.getTaskTypes(), is(Set.of())); + assertFalse(scopedToNone.isAuthorized()); + } + public void 
testReturnsAuthorizedEndpoints_FiltersInvalid() { var id1 = "id1"; var id2 = "invalid_text_embedding"; @@ -280,6 +301,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { "ga", null, "", + "", null ), new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( @@ -289,6 +311,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { "ga", null, "", + "", new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( null, dimensions, @@ -303,6 +326,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { "ga", null, "", + "", new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( SimilarityMeasure.DOT_PRODUCT.toString(), dimensions, @@ -363,6 +387,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "ga", null, "", + "", null ), new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( @@ -372,6 +397,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "ga", null, "", + "", null ), new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( @@ -381,6 +407,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "ga", null, "", + "", new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( similarity.toString(), dimensions, @@ -395,6 +422,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "ga", null, "", + "", null ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index dc8eb780c801e..91aa697d1d9c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -26,15 +26,12 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.junit.Before; import org.mockito.ArgumentCaptor; -import java.util.EnumSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -44,7 +41,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationEntityV2Tests.createAuthorizedEndpoint; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2Tests.createAuthorizedEndpoint; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -95,9 +92,7 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { 
ActionListener listener = invocation.getArgument(0); - listener.onResponse( - AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url) - ); + listener.onResponse(AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -150,9 +145,7 @@ public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInfe var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse( - AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url) - ); + listener.onResponse(AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -210,27 +203,19 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra } public void testSendsTwoAuthorizationRequests() throws InterruptedException { + var url = "eis-url"; + var sparseModel = createAuthorizedEndpoint(TaskType.SPARSE_EMBEDDING); + var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); - when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + // Since the registry is already aware of the sparse endpoint, the authorization poller will not consider it a new + // one and not attempt to store it. 
+ when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", sparseModel.id())); var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - // this is an unknown model id so it won't trigger storing an inference endpoint because - // it doesn't map to a known one - "abc", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ); + ActionListener listener = invocation.getArgument(0); + listener.onResponse(AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -273,27 +258,19 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { } public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throws InterruptedException { + var url = "eis-url"; + var sparseModel = createAuthorizedEndpoint(TaskType.SPARSE_EMBEDDING); + var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); - when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + // Since the registry is already aware of the sparse endpoint, the authorization poller will not consider it a new + // one and not attempt to store it. 
+ when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", sparseModel.id())); var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - // this is an unknown model id so it won't trigger storing an inference endpoint because - // it doesn't map to a known one - "abc", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ); + ActionListener listener = invocation.getArgument(0); + listener.onResponse(AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 850cf4aa1c716..f7f1e5add05a4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -13,6 +13,8 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceResults; import 
org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -20,11 +22,16 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xpack.core.inference.chunking.SentenceBoundaryChunkingSettings; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -124,12 +131,16 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep try (var sender = senderFactory.createSender()) { String responseJson = """ { - "models": [ - { - "invalid-field": "model-a", - "task-types": ["embed/text/sparse", "chat"] - } - ] + "inference_endpoints": [ + { + "id": 123, + "model_name": "elastic-rerank-v1", + "task_type": "rerank", + "status": "preview", + "properties": [], + "release_date": "2024-05-01" + } + ] } """; @@ -139,13 +150,13 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep authHandler.getAuthorization(listener, sender); var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT)); - 
assertThat(exception.getMessage(), containsString("failed to parse field [models]")); + assertThat(exception.getMessage(), containsString("failed to parse field [inference_endpoints]")); var stringCaptor = ArgumentCaptor.forClass(String.class); var exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(logger).warn(stringCaptor.capture(), exceptionCaptor.capture()); var message = stringCaptor.getValue(); - assertThat(message, containsString("failed to parse field [models]")); + assertThat(message, containsString("failed to parse field [inference_endpoints]")); var capturedException = exceptionCaptor.getValue(); assertThat(capturedException, instanceOf(XContentParseException.class)); @@ -207,12 +218,21 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { ActionListener onlyOnceListener = ActionListener.assertOnce(listener); String responseJson = """ { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse", "chat"] - } - ] + "id": ".elastic-elser-v2", + "model_name": "elser_model_2", + "task_type": "sparse_embedding", + "status": "preview", + "properties": [ + "english" + ], + "release_date": "2024-05-01", + "configuration": { + "chunking_settings": { + "strategy": "sentence", + "max_chunk_size": 250, + "sentence_overlap": 1 + } + } } """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); @@ -221,10 +241,32 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { authHandler.getAuthorization(onlyOnceListener, sender); authHandler.waitForAuthRequestCompletion(TIMEOUT); + var endpointId = ".elastic-elser-v2"; + var endpointName = "elser_model_2"; + var maxChunkSize = 250; + var sentenceOverlap = 1; + var authResponse = listener.actionGet(TIMEOUT); - assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - assertThat(authResponse.getEndpointIds(), is(Set.of("model-a"))); + 
assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + assertThat(authResponse.getEndpointIds(), is(Set.of(endpointId))); assertTrue(authResponse.isAuthorized()); + assertThat( + authResponse.getEndpoints(Set.of(endpointId)), + is( + List.of( + new ElasticInferenceServiceSparseEmbeddingsModel( + endpointId, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(endpointName, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(eisGatewayUrl), + new SentenceBoundaryChunkingSettings(maxChunkSize, sentenceOverlap) + ) + ) + ) + ); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(1)).debug(loggerArgsCaptor.capture()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java deleted file mode 100644 index a3fc723309a9f..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; - -import java.util.Collections; -import java.util.List; -import java.util.Set; - -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID; -import static 
org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_RERANK_ENDPOINT_ID_V1; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_RERANK_MODEL_ID_V1; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DENSE_TEXT_EMBEDDINGS_DIMENSIONS; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.GP_LLM_V2_MODEL_ID; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.defaultDenseTextEmbeddingsSimilarity; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; - -public class PreconfiguredEndpointModelAdapterTests extends ESTestCase { - - private static final ElasticInferenceServiceSparseEmbeddingsServiceSettings SPARSE_SETTINGS = - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null); - private static final ElasticInferenceServiceCompletionServiceSettings COMPLETION_SETTINGS = - new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); - private static final ElasticInferenceServiceCompletionServiceSettings GP_LLM_V2_COMPLETION_SETTINGS = - new ElasticInferenceServiceCompletionServiceSettings(GP_LLM_V2_MODEL_ID); - private static final ElasticInferenceServiceDenseTextEmbeddingsServiceSettings DENSE_SETTINGS = - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, - defaultDenseTextEmbeddingsSimilarity(), - DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - null - ); - private static final 
ElasticInferenceServiceRerankServiceSettings RERANK_SETTINGS = new ElasticInferenceServiceRerankServiceSettings( - DEFAULT_RERANK_MODEL_ID_V1 - ); - private static final ElasticInferenceServiceComponents EIS_COMPONENTS = new ElasticInferenceServiceComponents(""); - - public void testGetModelsWithValidId() { - var endpointIds = Set.of( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, - DEFAULT_ELSER_ENDPOINT_ID_V2, - DEFAULT_RERANK_ENDPOINT_ID_V1, - DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID - ); - var models = PreconfiguredEndpointModelAdapter.getModels(endpointIds, EIS_COMPONENTS); - - assertThat(models, hasSize(endpointIds.size())); - assertThat( - models, - containsInAnyOrder( - new ElasticInferenceServiceModel( - new ModelConfigurations( - DEFAULT_ELSER_ENDPOINT_ID_V2, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - SPARSE_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - new ModelSecrets(EmptySecretSettings.INSTANCE), - SPARSE_SETTINGS, - EIS_COMPONENTS - ), - new ElasticInferenceServiceModel( - new ModelConfigurations( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - TaskType.CHAT_COMPLETION, - ElasticInferenceService.NAME, - COMPLETION_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - new ModelSecrets(EmptySecretSettings.INSTANCE), - COMPLETION_SETTINGS, - EIS_COMPONENTS - ), - new ElasticInferenceServiceModel( - new ModelConfigurations( - GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, - TaskType.CHAT_COMPLETION, - ElasticInferenceService.NAME, - GP_LLM_V2_COMPLETION_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - new ModelSecrets(EmptySecretSettings.INSTANCE), - GP_LLM_V2_COMPLETION_SETTINGS, - EIS_COMPONENTS - ), - new ElasticInferenceServiceModel( - new ModelConfigurations( - DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, - TaskType.TEXT_EMBEDDING, - ElasticInferenceService.NAME, - DENSE_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - new ModelSecrets(EmptySecretSettings.INSTANCE), - 
DENSE_SETTINGS, - EIS_COMPONENTS - ), - new ElasticInferenceServiceModel( - new ModelConfigurations( - DEFAULT_RERANK_ENDPOINT_ID_V1, - TaskType.RERANK, - ElasticInferenceService.NAME, - RERANK_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - new ModelSecrets(EmptySecretSettings.INSTANCE), - RERANK_SETTINGS, - EIS_COMPONENTS - ) - ) - ); - } - - public void testGetModelsWithValidAndInvalidIds() { - var models = PreconfiguredEndpointModelAdapter.getModels( - Set.of(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, "some-invalid-id", DEFAULT_ELSER_ENDPOINT_ID_V2), - EIS_COMPONENTS - ); - - assertThat(models, hasSize(2)); - assertThat( - models, - containsInAnyOrder( - new ElasticInferenceServiceModel( - new ModelConfigurations( - DEFAULT_ELSER_ENDPOINT_ID_V2, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - SPARSE_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - new ModelSecrets(EmptySecretSettings.INSTANCE), - SPARSE_SETTINGS, - EIS_COMPONENTS - ), - new ElasticInferenceServiceModel( - new ModelConfigurations( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - TaskType.CHAT_COMPLETION, - ElasticInferenceService.NAME, - COMPLETION_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - new ModelSecrets(EmptySecretSettings.INSTANCE), - COMPLETION_SETTINGS, - EIS_COMPONENTS - ) - ) - ); - } - - public void testGetModelsWithOnlyInvalidId() { - assertThat(PreconfiguredEndpointModelAdapter.getModels(Collections.singleton("nonexistent-id"), EIS_COMPONENTS), is(List.of())); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationEntityV2Tests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationEntityV2Tests.java deleted file mode 100644 index 2ed0b38eb303b..0000000000000 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationEntityV2Tests.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.response; - -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; - -public class ElasticInferenceServiceAuthorizationEntityV2Tests extends ESTestCase { - public static ElasticInferenceServiceAuthorizationResponseEntityV2 createResponse() { - return new ElasticInferenceServiceAuthorizationResponseEntityV2( - randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) - ); - } - - public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { - var id = randomAlphaOfLength(10); - var name = randomAlphaOfLength(10); - var status = randomFrom("ga", "beta", "preview"); - - return switch (taskType) { - case CHAT_COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - TaskType.CHAT_COMPLETION.toString(), - status, - null, - "", - null - ); - case SPARSE_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - TaskType.SPARSE_EMBEDDING.toString(), - status, - null, - "", - null - ); - case TEXT_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - 
TaskType.TEXT_EMBEDDING.toString(), - status, - null, - "", - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( - randomFrom(SimilarityMeasure.values()).toString(), - randomInt(), - DenseVectorFieldMapper.ElementType.FLOAT.toString(), - null - ) - ); - case RERANK -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - TaskType.RERANK.toString(), - status, - null, - "", - null - ); - case COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - id, - name, - TaskType.COMPLETION.toString(), - status, - null, - "", - null - ); - default -> throw new IllegalArgumentException("Unsupported task type: " + taskType); - }; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java new file mode 100644 index 0000000000000..7ba368a9cf832 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java @@ -0,0 +1,346 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.response; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.chunking.SentenceBoundaryChunkingSettings; +import org.elasticsearch.xpack.core.inference.chunking.WordBoundaryChunkingSettings; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; +import 
org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.is; + +public class ElasticInferenceServiceAuthorizationResponseEntityV2Tests extends AbstractBWCWireSerializationTestCase< + ElasticInferenceServiceAuthorizationResponseEntityV2> { + public static ElasticInferenceServiceAuthorizationResponseEntityV2 createResponse() { + return new ElasticInferenceServiceAuthorizationResponseEntityV2( + randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) + ); + } + + public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { + var id = randomAlphaOfLength(10); + var name = randomAlphaOfLength(10); + var status = randomFrom("ga", "beta", "preview"); + + return switch (taskType) { + case CHAT_COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.CHAT_COMPLETION.toString(), + status, + null, + "", + "", + null + ); + case SPARSE_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.SPARSE_EMBEDDING.toString(), + status, + null, + "", + "", + null + ); + case TEXT_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.TEXT_EMBEDDING.toString(), + status, + null, + "", + "", + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + randomFrom(SimilarityMeasure.values()).toString(), + randomInt(), + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + null + ) + ); + case RERANK -> new 
ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.RERANK.toString(), + status, + null, + "", + "", + null + ); + case COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + id, + name, + TaskType.COMPLETION.toString(), + status, + null, + "", + "", + null + ); + default -> throw new IllegalArgumentException("Unsupported task type: " + taskType); + }; + } + + public void testParseAllFields() throws IOException { + var json = """ + { + "inference_endpoints": [ + { + "id": ".rainbow-sprinkles-elastic", + "model_name": "rainbow-sprinkles", + "task_type": "chat_completion", + "status": "ga", + "properties": [ + "multilingual" + ], + "release_date": "2024-05-01", + "end_of_life_date": "2025-12-31" + }, + { + "id": ".elastic-elser-v2", + "model_name": "elser_model_2", + "task_type": "sparse_embedding", + "status": "preview", + "properties": [ + "english" + ], + "release_date": "2024-05-01", + "configuration": { + "chunking_settings": { + "strategy": "sentence", + "max_chunk_size": 250, + "sentence_overlap": 1 + } + } + }, + { + "id": ".jina-embeddings-v3", + "model_name": "jina-embeddings-v3", + "task_type": "text_embedding", + "status": "beta", + "properties": [ + "multilingual", + "open-weights" + ], + "release_date": "2024-05-01", + "configuration": { + "similarity": "cosine", + "dimensions": 1024, + "element_type": "float", + "chunking_settings": { + "strategy": "word", + "max_chunk_size": 500, + "overlap": 2 + } + } + }, + { + "id": ".elastic-rerank-v1", + "model_name": "elastic-rerank-v1", + "task_type": "rerank", + "status": "preview", + "properties": [], + "release_date": "2024-05-01" + } + ] + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, json)) { + var entity = ElasticInferenceServiceAuthorizationResponseEntityV2.PARSER.apply(parser, null); + + var rainbowSprinklesId = ".rainbow-sprinkles-elastic"; + var elserModelId = ".elastic-elser-v2"; + var 
jinaEmbeddingsId = ".jina-embeddings-v3"; + var elasticRerankId = ".elastic-rerank-v1"; + + var rainbowSprinklesModelName = "rainbow-sprinkles"; + var elserModelName = "elser_model_2"; + var jinaEmbeddingsModelName = "jina-embeddings-v3"; + var elasticRerankModelName = "elastic-rerank-v1"; + + var elserMaxChunkSize = 250; + var elserSentenceOverlap = 1; + + var jinaDimensions = 1024; + var jinaElementType = DenseVectorFieldMapper.ElementType.FLOAT; + var jinaSimilarity = SimilarityMeasure.COSINE; + var jinaMaxChunkSize = 500; + var jinaOverlap = 2; + + var expected = new ElasticInferenceServiceAuthorizationResponseEntityV2( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + rainbowSprinklesId, + rainbowSprinklesModelName, + "chat_completion", + "ga", + List.of("multilingual"), + "2024-05-01", + "2025-12-31", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + elserModelId, + elserModelName, + "sparse_embedding", + "preview", + List.of("english"), + "2024-05-01", + null, + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + null, + null, + null, + Map.of("strategy", "sentence", "max_chunk_size", elserMaxChunkSize, "sentence_overlap", elserSentenceOverlap) + ) + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + jinaEmbeddingsId, + jinaEmbeddingsModelName, + "text_embedding", + "beta", + List.of("multilingual", "open-weights"), + "2024-05-01", + null, + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + jinaSimilarity.toString(), + jinaDimensions, + jinaElementType.toString(), + Map.of("strategy", "word", "max_chunk_size", jinaMaxChunkSize, "overlap", jinaOverlap) + ) + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + elasticRerankId, + elasticRerankModelName, + "rerank", + "preview", + List.of(), + "2024-05-01", + null, + null + ) + ) + ); + + assertThat(entity, 
is(expected)); + + var url = "http://example.com/authorize"; + var authModel = AuthorizationModel.of(expected, url); + + assertThat(authModel.getEndpointIds(), containsInAnyOrder(rainbowSprinklesId, elasticRerankId, elserModelId, jinaEmbeddingsId)); + assertThat( + authModel.getTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK)) + ); + + assertThat( + authModel.getEndpoints(Set.of(rainbowSprinklesId, jinaEmbeddingsId, elasticRerankId, elserModelId)), + containsInAnyOrder( + new ElasticInferenceServiceCompletionModel( + rainbowSprinklesId, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(rainbowSprinklesModelName), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ), + new ElasticInferenceServiceSparseEmbeddingsModel( + elserModelId, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(elserModelName, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new SentenceBoundaryChunkingSettings(elserMaxChunkSize, elserSentenceOverlap) + ), + new ElasticInferenceServiceDenseTextEmbeddingsModel( + jinaEmbeddingsId, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + jinaEmbeddingsModelName, + jinaSimilarity, + jinaDimensions, + null + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new WordBoundaryChunkingSettings(jinaMaxChunkSize, jinaOverlap) + ), + new ElasticInferenceServiceRerankModel( + elasticRerankId, + TaskType.RERANK, + ElasticInferenceService.NAME, + new ElasticInferenceServiceRerankServiceSettings(elasticRerankModelName), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new 
ElasticInferenceServiceComponents(url) + ) + ) + ); + } + } + + @Override + protected ElasticInferenceServiceAuthorizationResponseEntityV2 mutateInstanceForVersion( + ElasticInferenceServiceAuthorizationResponseEntityV2 instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return ElasticInferenceServiceAuthorizationResponseEntityV2::new; + } + + @Override + protected ElasticInferenceServiceAuthorizationResponseEntityV2 createTestInstance() { + return createResponse(); + } + + @Override + protected ElasticInferenceServiceAuthorizationResponseEntityV2 mutateInstance( + ElasticInferenceServiceAuthorizationResponseEntityV2 instance + ) throws IOException { + var newEndpoints = new ArrayList<>(instance.getAuthorizedEndpoints()); + newEndpoints.add(createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))); + return new ElasticInferenceServiceAuthorizationResponseEntityV2(newEndpoints); + } +} From 70ef1a57225878b56ec2b3632cf9463adb8cee4f Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 18 Nov 2025 16:51:44 +0000 Subject: [PATCH 04/24] [CI] Auto commit changes from spotless --- .../completion/ElasticInferenceServiceCompletionModel.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 31d0e171ac673..fef27fa760bfa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -8,7 +8,6 @@ package 
org.elasticsearch.xpack.inference.services.elastic.completion; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; From 20d7881b480ee995ec198d591a6eeed949366732 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 18 Nov 2025 15:00:03 -0500 Subject: [PATCH 05/24] Successful tests --- ...rviceAuthorizationRequestHandlerTests.java | 32 +- ...iceAuthorizationResponseEntityV2Tests.java | 383 +++++++++--------- 2 files changed, 213 insertions(+), 202 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index f7f1e5add05a4..920fabd59bc99 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -30,6 +30,8 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2Tests; import 
org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.junit.After; @@ -47,6 +49,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES; import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -179,26 +182,25 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); try (var sender = senderFactory.createSender()) { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse", "chat"] - } - ] - } - """; + var responseData = ElasticInferenceServiceAuthorizationResponseEntityV2Tests.EisAuthorizationResponseData + .getEisAuthorizationData(eisGatewayUrl); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseData.responseJson())); PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); - assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - assertThat(authResponse.getEndpointIds(), is(Set.of("model-a"))); + assertThat( + authResponse.getTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK)) + ); + 
assertThat(authResponse.getEndpointIds(), containsInAnyOrder(responseData.inferenceIds().toArray(String[]::new))); assertTrue(authResponse.isAuthorized()); + assertThat( + authResponse.getEndpoints(responseData.inferenceIds()), + containsInAnyOrder(responseData.expectedEndpoints().toArray(ElasticInferenceServiceModel[]::new)) + ); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(1)).debug(loggerArgsCaptor.capture()); @@ -217,6 +219,8 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); ActionListener onlyOnceListener = ActionListener.assertOnce(listener); String responseJson = """ + { + "inference_endpoints": [ { "id": ".elastic-elser-v2", "model_name": "elser_model_2", @@ -234,6 +238,8 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { } } } + ] + } """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java index 7ba368a9cf832..026819c22647b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; @@ -36,12 +37,197 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; public class ElasticInferenceServiceAuthorizationResponseEntityV2Tests extends AbstractBWCWireSerializationTestCase< ElasticInferenceServiceAuthorizationResponseEntityV2> { + + public static String EIS_AUTHORIZATION_RESPONSE_V2 = """ + { + "inference_endpoints": [ + { + "id": ".rainbow-sprinkles-elastic", + "model_name": "rainbow-sprinkles", + "task_type": "chat_completion", + "status": "ga", + "properties": [ + "multilingual" + ], + "release_date": "2024-05-01", + "end_of_life_date": "2025-12-31" + }, + { + "id": ".elastic-elser-v2", + "model_name": "elser_model_2", + "task_type": "sparse_embedding", + "status": "preview", + "properties": [ + "english" + ], + "release_date": "2024-05-01", + "configuration": { + "chunking_settings": { + "strategy": "sentence", + "max_chunk_size": 250, + "sentence_overlap": 1 + } + } + }, + { + "id": ".jina-embeddings-v3", + "model_name": "jina-embeddings-v3", + "task_type": "text_embedding", + "status": "beta", + "properties": [ + "multilingual", + "open-weights" + ], + "release_date": "2024-05-01", + "configuration": { + "similarity": "cosine", + "dimensions": 1024, + "element_type": "float", + "chunking_settings": { + "strategy": "word", + "max_chunk_size": 500, + "overlap": 2 + } + } + }, + { + "id": ".elastic-rerank-v1", + "model_name": "elastic-rerank-v1", + "task_type": "rerank", + "status": "preview", + 
"properties": [], + "release_date": "2024-05-01" + } + ] + } + """; + + public record EisAuthorizationResponseData( + String responseJson, + ElasticInferenceServiceAuthorizationResponseEntityV2 responseEntity, + List expectedEndpoints, + Set inferenceIds + ) { + + public static EisAuthorizationResponseData getEisAuthorizationData(String url) { + + var authorizedEndpoints = List.of( + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + ".rainbow-sprinkles-elastic", + "rainbow-sprinkles", + "chat_completion", + "ga", + List.of("multilingual"), + "2024-05-01", + "2025-12-31", + null + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + ".elastic-elser-v2", + "elser_model_2", + "sparse_embedding", + "preview", + List.of("english"), + "2024-05-01", + null, + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + null, + null, + null, + Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1) + ) + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + ".jina-embeddings-v3", + "jina-embeddings-v3", + "text_embedding", + "beta", + List.of("multilingual", "open-weights"), + "2024-05-01", + null, + new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + "cosine", + 1024, + "float", + Map.of("strategy", "word", "max_chunk_size", 500, "overlap", 2) + ) + ), + new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + ".elastic-rerank-v1", + "elastic-rerank-v1", + "rerank", + "preview", + List.of(), + "2024-05-01", + null, + null + ) + ); + + var inferenceIds = authorizedEndpoints.stream() + .map(ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint::id) + .collect(Collectors.toSet()); + + return new EisAuthorizationResponseData( + EIS_AUTHORIZATION_RESPONSE_V2, + new ElasticInferenceServiceAuthorizationResponseEntityV2(authorizedEndpoints), + List.of( + new ElasticInferenceServiceCompletionModel( 
+ ".rainbow-sprinkles-elastic", + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ), + new ElasticInferenceServiceSparseEmbeddingsModel( + ".elastic-elser-v2", + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser_model_2", null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new SentenceBoundaryChunkingSettings(250, 1) + ), + new ElasticInferenceServiceDenseTextEmbeddingsModel( + ".jina-embeddings-v3", + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + "jina-embeddings-v3", + SimilarityMeasure.COSINE, + 1024, + null + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new WordBoundaryChunkingSettings(500, 2) + ), + new ElasticInferenceServiceRerankModel( + ".elastic-rerank-v1", + TaskType.RERANK, + ElasticInferenceService.NAME, + new ElasticInferenceServiceRerankServiceSettings("elastic-rerank-v1"), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ) + ), + inferenceIds + ); + } + } + public static ElasticInferenceServiceAuthorizationResponseEntityV2 createResponse() { return new ElasticInferenceServiceAuthorizationResponseEntityV2( randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) @@ -114,205 +300,24 @@ public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEnd } public void testParseAllFields() throws IOException { - var json = """ - { - "inference_endpoints": [ - { - "id": ".rainbow-sprinkles-elastic", - "model_name": "rainbow-sprinkles", - "task_type": 
"chat_completion", - "status": "ga", - "properties": [ - "multilingual" - ], - "release_date": "2024-05-01", - "end_of_life_date": "2025-12-31" - }, - { - "id": ".elastic-elser-v2", - "model_name": "elser_model_2", - "task_type": "sparse_embedding", - "status": "preview", - "properties": [ - "english" - ], - "release_date": "2024-05-01", - "configuration": { - "chunking_settings": { - "strategy": "sentence", - "max_chunk_size": 250, - "sentence_overlap": 1 - } - } - }, - { - "id": ".jina-embeddings-v3", - "model_name": "jina-embeddings-v3", - "task_type": "text_embedding", - "status": "beta", - "properties": [ - "multilingual", - "open-weights" - ], - "release_date": "2024-05-01", - "configuration": { - "similarity": "cosine", - "dimensions": 1024, - "element_type": "float", - "chunking_settings": { - "strategy": "word", - "max_chunk_size": 500, - "overlap": 2 - } - } - }, - { - "id": ".elastic-rerank-v1", - "model_name": "elastic-rerank-v1", - "task_type": "rerank", - "status": "preview", - "properties": [], - "release_date": "2024-05-01" - } - ] - } - """; - try (var parser = createParser(JsonXContent.jsonXContent, json)) { + var url = "http://example.com/authorize"; + var responseData = EisAuthorizationResponseData.getEisAuthorizationData(url); + try (var parser = createParser(JsonXContent.jsonXContent, responseData.responseJson())) { var entity = ElasticInferenceServiceAuthorizationResponseEntityV2.PARSER.apply(parser, null); - var rainbowSprinklesId = ".rainbow-sprinkles-elastic"; - var elserModelId = ".elastic-elser-v2"; - var jinaEmbeddingsId = ".jina-embeddings-v3"; - var elasticRerankId = ".elastic-rerank-v1"; - - var rainbowSprinklesModelName = "rainbow-sprinkles"; - var elserModelName = "elser_model_2"; - var jinaEmbeddingsModelName = "jina-embeddings-v3"; - var elasticRerankModelName = "elastic-rerank-v1"; - - var elserMaxChunkSize = 250; - var elserSentenceOverlap = 1; - - var jinaDimensions = 1024; - var jinaElementType = 
DenseVectorFieldMapper.ElementType.FLOAT; - var jinaSimilarity = SimilarityMeasure.COSINE; - var jinaMaxChunkSize = 500; - var jinaOverlap = 2; - - var expected = new ElasticInferenceServiceAuthorizationResponseEntityV2( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - rainbowSprinklesId, - rainbowSprinklesModelName, - "chat_completion", - "ga", - List.of("multilingual"), - "2024-05-01", - "2025-12-31", - null - ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - elserModelId, - elserModelName, - "sparse_embedding", - "preview", - List.of("english"), - "2024-05-01", - null, - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( - null, - null, - null, - Map.of("strategy", "sentence", "max_chunk_size", elserMaxChunkSize, "sentence_overlap", elserSentenceOverlap) - ) - ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - jinaEmbeddingsId, - jinaEmbeddingsModelName, - "text_embedding", - "beta", - List.of("multilingual", "open-weights"), - "2024-05-01", - null, - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( - jinaSimilarity.toString(), - jinaDimensions, - jinaElementType.toString(), - Map.of("strategy", "word", "max_chunk_size", jinaMaxChunkSize, "overlap", jinaOverlap) - ) - ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - elasticRerankId, - elasticRerankModelName, - "rerank", - "preview", - List.of(), - "2024-05-01", - null, - null - ) - ) - ); - - assertThat(entity, is(expected)); + assertThat(entity, is(responseData.responseEntity())); - var url = "http://example.com/authorize"; - var authModel = AuthorizationModel.of(expected, url); + var authModel = AuthorizationModel.of(responseData.responseEntity(), url); + assertThat(authModel.getEndpointIds(), containsInAnyOrder(responseData.inferenceIds().toArray(String[]::new))); - assertThat(authModel.getEndpointIds(), 
containsInAnyOrder(rainbowSprinklesId, elasticRerankId, elserModelId, jinaEmbeddingsId)); assertThat( authModel.getTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK)) ); - assertThat( - authModel.getEndpoints(Set.of(rainbowSprinklesId, jinaEmbeddingsId, elasticRerankId, elserModelId)), - containsInAnyOrder( - new ElasticInferenceServiceCompletionModel( - rainbowSprinklesId, - TaskType.CHAT_COMPLETION, - ElasticInferenceService.NAME, - new ElasticInferenceServiceCompletionServiceSettings(rainbowSprinklesModelName), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url) - ), - new ElasticInferenceServiceSparseEmbeddingsModel( - elserModelId, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(elserModelName, null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url), - new SentenceBoundaryChunkingSettings(elserMaxChunkSize, elserSentenceOverlap) - ), - new ElasticInferenceServiceDenseTextEmbeddingsModel( - jinaEmbeddingsId, - TaskType.TEXT_EMBEDDING, - ElasticInferenceService.NAME, - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - jinaEmbeddingsModelName, - jinaSimilarity, - jinaDimensions, - null - ), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url), - new WordBoundaryChunkingSettings(jinaMaxChunkSize, jinaOverlap) - ), - new ElasticInferenceServiceRerankModel( - elasticRerankId, - TaskType.RERANK, - ElasticInferenceService.NAME, - new ElasticInferenceServiceRerankServiceSettings(elasticRerankModelName), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url) - ) - ) + authModel.getEndpoints(responseData.inferenceIds()), + 
containsInAnyOrder(responseData.expectedEndpoints().toArray(ElasticInferenceServiceModel[]::new)) ); } } From 2546292826502e8b4fd8a5d18cd8719824c33a1d Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 18 Nov 2025 16:11:52 -0500 Subject: [PATCH 06/24] Removing unused code --- .../authorization/AuthorizationModel.java | 36 +- ...ticInferenceServiceAuthorizationModel.java | 176 ---------- ...nceServiceAuthorizationRequestHandler.java | 6 +- ...ava => AuthorizationResponseEntityV2.java} | 24 +- ...nceServiceAuthorizationResponseEntity.java | 194 ----------- .../http/sender/HttpRequestSenderTests.java | 37 +- .../AuthorizationModelTests.java | 83 ++--- .../AuthorizationPollerTests.java | 138 ++------ ...rviceAuthorizationRequestHandlerTests.java | 67 +--- ...renceServiceAuthorizationRequestTests.java | 7 +- ...> AuthorizationResponseEntityV2Tests.java} | 320 +++++++++++------- ...rviceAuthorizationResponseEntityTests.java | 64 ---- 12 files changed, 306 insertions(+), 846 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/{ElasticInferenceServiceAuthorizationResponseEntityV2.java => AuthorizationResponseEntityV2.java} (92%) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/{ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java => AuthorizationResponseEntityV2Tests.java} (50%) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java index 9608c556e4d0e..e11448bca18ec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java @@ -24,7 +24,7 @@ import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; @@ -48,13 +48,13 @@ public class AuthorizationModel { private static final Logger logger = LogManager.getLogger(AuthorizationModel.class); private static final String UNKNOWN_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unknown task type [{}], skipping"; - public static AuthorizationModel of(ElasticInferenceServiceAuthorizationResponseEntityV2 responseEntity, String baseEisUrl) { + public static AuthorizationModel of(AuthorizationResponseEntityV2 responseEntity, String baseEisUrl) { var components = new ElasticInferenceServiceComponents(baseEisUrl); return 
createInternal(responseEntity.getAuthorizedEndpoints(), components); } private static AuthorizationModel createInternal( - List responseEndpoints, + List responseEndpoints, ElasticInferenceServiceComponents components ) { var validEndpoints = new ArrayList(); @@ -69,7 +69,7 @@ private static AuthorizationModel createInternal( } private static ElasticInferenceServiceModel createModel( - ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { try { @@ -110,7 +110,7 @@ private static TaskType getTaskType(String taskType) { } private static ElasticInferenceServiceCompletionModel createCompletionModel( - ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceCompletionModel( @@ -125,7 +125,7 @@ private static ElasticInferenceServiceCompletionModel createCompletionModel( } private static ElasticInferenceServiceSparseEmbeddingsModel createSparseEmbeddingsModel( - ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceSparseEmbeddingsModel( @@ -140,24 +140,24 @@ private static ElasticInferenceServiceSparseEmbeddingsModel createSparseEmbeddin ); } - private static ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration getConfigurationOrEmpty( - ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint + private static AuthorizationResponseEntityV2.Configuration getConfigurationOrEmpty( + AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint ) { if (authorizedEndpoint.configuration() != 
null) { return authorizedEndpoint.configuration(); } - return ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.EMPTY; + return AuthorizationResponseEntityV2.Configuration.EMPTY; } private static Map getChunkingSettingsMap( - ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration configuration + AuthorizationResponseEntityV2.Configuration configuration ) { return Objects.requireNonNullElse(configuration.chunkingSettings(), new HashMap<>()); } private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEmbeddingsModel( - ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { var config = getConfigurationOrEmpty(authorizedEndpoint); @@ -180,19 +180,19 @@ private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEm ); } - private static void validateConfigurationForTextEmbedding(ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration config) { + private static void validateConfigurationForTextEmbedding(AuthorizationResponseEntityV2.Configuration config) { validateFieldPresent( - ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.ELEMENT_TYPE, + AuthorizationResponseEntityV2.Configuration.ELEMENT_TYPE, config.elementType(), TaskType.TEXT_EMBEDDING ); validateFieldPresent( - ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.DIMENSIONS, + AuthorizationResponseEntityV2.Configuration.DIMENSIONS, config.dimensions(), TaskType.TEXT_EMBEDDING ); validateFieldPresent( - ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.SIMILARITY, + AuthorizationResponseEntityV2.Configuration.SIMILARITY, config.similarity(), TaskType.TEXT_EMBEDDING ); @@ -207,10 +207,10 @@ private static void validateFieldPresent(String field, Object fieldValue, TaskTy } private static SimilarityMeasure getSimilarityMeasure( - 
ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration configuration + AuthorizationResponseEntityV2.Configuration configuration ) { validateFieldPresent( - ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration.SIMILARITY, + AuthorizationResponseEntityV2.Configuration.SIMILARITY, configuration.similarity(), TaskType.TEXT_EMBEDDING ); @@ -219,7 +219,7 @@ private static SimilarityMeasure getSimilarityMeasure( } private static ElasticInferenceServiceRerankModel createRerankModel( - ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceRerankModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java deleted file mode 100644 index e8dc9c12c94f1..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; - -import java.util.EnumSet; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Transforms the response from {@link ElasticInferenceServiceAuthorizationRequestHandler} into a format for consumption by the service. - */ -public class ElasticInferenceServiceAuthorizationModel { - - private final Map> taskTypeToModels; - private final EnumSet authorizedTaskTypes; - private final Set authorizedModelIds; - - /** - * Converts an authorization response from Elastic Inference Service into the {@link ElasticInferenceServiceAuthorizationModel} format. - * - * @param responseEntity the {@link ElasticInferenceServiceAuthorizationResponseEntity} response from the upstream gateway. 
- * @return a new {@link ElasticInferenceServiceAuthorizationModel} - */ - public static ElasticInferenceServiceAuthorizationModel of(ElasticInferenceServiceAuthorizationResponseEntity responseEntity) { - var taskTypeToModelsMap = new HashMap>(); - var enabledTaskTypesSet = EnumSet.noneOf(TaskType.class); - var enabledModelsSet = new HashSet(); - - for (var model : responseEntity.getAuthorizedModels()) { - // if there are no task types we'll ignore the model because it's likely we didn't understand - // the task type and don't support it anyway - if (model.taskTypes().isEmpty() == false) { - for (var taskType : model.taskTypes()) { - taskTypeToModelsMap.merge(taskType, Set.of(model.modelName()), (existingModelIds, newModelIds) -> { - var combinedNames = new HashSet<>(existingModelIds); - combinedNames.addAll(newModelIds); - return combinedNames; - }); - enabledTaskTypesSet.add(taskType); - } - enabledModelsSet.add(model.modelName()); - } - } - - return new ElasticInferenceServiceAuthorizationModel(taskTypeToModelsMap, enabledModelsSet, enabledTaskTypesSet); - } - - /** - * Returns an object indicating that the cluster has no access to Elastic Inference Service. - */ - public static ElasticInferenceServiceAuthorizationModel newDisabledService() { - return new ElasticInferenceServiceAuthorizationModel(Map.of(), Set.of(), EnumSet.noneOf(TaskType.class)); - } - - private ElasticInferenceServiceAuthorizationModel( - Map> taskTypeToModels, - Set authorizedModelIds, - EnumSet authorizedTaskTypes - ) { - this.taskTypeToModels = Objects.requireNonNull(taskTypeToModels); - this.authorizedModelIds = Objects.requireNonNull(authorizedModelIds); - this.authorizedTaskTypes = Objects.requireNonNull(authorizedTaskTypes); - } - - /** - * Returns true if at least one task type and model is authorized. - * @return true if this cluster is authorized for at least one model and task type. 
- */ - public boolean isAuthorized() { - return authorizedModelIds.isEmpty() == false && taskTypeToModels.isEmpty() == false && authorizedTaskTypes.isEmpty() == false; - } - - public Set getAuthorizedModelIds() { - return Set.copyOf(authorizedModelIds); - } - - public EnumSet getAuthorizedTaskTypes() { - return EnumSet.copyOf(authorizedTaskTypes); - } - - /** - * Returns a new {@link ElasticInferenceServiceAuthorizationModel} object retaining only the specified task types - * and applicable models that leverage those task types. Any task types not specified in the passed in set will be - * excluded from the returned object. This is essentially an intersection. - * @param taskTypes the task types to retain in the newly created object - * @return a new object containing models and task types limited to the specified set. - */ - public ElasticInferenceServiceAuthorizationModel newLimitedToTaskTypes(EnumSet taskTypes) { - var newTaskTypeToModels = new HashMap>(); - var taskTypesThatHaveModels = EnumSet.noneOf(TaskType.class); - - for (var taskType : taskTypes) { - var models = taskTypeToModels.get(taskType); - if (models != null) { - newTaskTypeToModels.put(taskType, models); - // we only want task types that correspond to actual models to ensure we're only enabling valid task types - taskTypesThatHaveModels.add(taskType); - } - } - - return new ElasticInferenceServiceAuthorizationModel( - newTaskTypeToModels, - enabledModels(newTaskTypeToModels), - taskTypesThatHaveModels - ); - } - - private static Set enabledModels(Map> taskTypeToModels) { - return taskTypeToModels.values().stream().flatMap(Set::stream).collect(Collectors.toSet()); - } - - /** - * Returns a new {@link ElasticInferenceServiceAuthorizationModel} that combines the current model and the passed in one. 
- * @param other model to merge into this one - * @return a new model - */ - public ElasticInferenceServiceAuthorizationModel merge(ElasticInferenceServiceAuthorizationModel other) { - Map> newTaskTypeToModels = taskTypeToModels.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, e -> new HashSet<>(e.getValue()))); - - for (var entry : other.taskTypeToModels.entrySet()) { - newTaskTypeToModels.merge(entry.getKey(), new HashSet<>(entry.getValue()), (existingModelIds, newModelIds) -> { - existingModelIds.addAll(newModelIds); - return existingModelIds; - }); - } - - var newAuthorizedTaskTypes = authorizedTaskTypes.isEmpty() ? EnumSet.noneOf(TaskType.class) : EnumSet.copyOf(authorizedTaskTypes); - newAuthorizedTaskTypes.addAll(other.authorizedTaskTypes); - - return new ElasticInferenceServiceAuthorizationModel( - newTaskTypeToModels, - enabledModels(newTaskTypeToModels), - newAuthorizedTaskTypes - ); - } - - @Override - public boolean equals(Object o) { - if (o == null || getClass() != o.getClass()) return false; - ElasticInferenceServiceAuthorizationModel that = (ElasticInferenceServiceAuthorizationModel) o; - return Objects.equals(taskTypeToModels, that.taskTypeToModels) - && Objects.equals(authorizedTaskTypes, that.authorizedTaskTypes) - && Objects.equals(authorizedModelIds, that.authorizedModelIds); - } - - @Override - public int hashCode() { - return Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds); - } - - @Override - public String toString() { - return "{" - + "taskTypeToModels=" - + taskTypeToModels - + ", authorizedTaskTypes=" - + authorizedTaskTypes - + ", authorizedModelIds=" - + authorizedModelIds - + '}'; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 6476925da7f10..15b97e66ec8de 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -24,7 +24,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.Objects; @@ -47,7 +47,7 @@ public class ElasticInferenceServiceAuthorizationRequestHandler { private static ResponseHandler createAuthResponseHandler() { return new ElasticInferenceServiceResponseHandler( Strings.format("%s authorization", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), - ElasticInferenceServiceAuthorizationResponseEntityV2::fromResponse + AuthorizationResponseEntityV2::fromResponse ); } @@ -119,7 +119,7 @@ public void getAuthorization(ActionListener listener, Sender sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); }) .andThenApply(authResult -> { - if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntityV2 authResponseEntity) { + if (authResult instanceof AuthorizationResponseEntityV2 authResponseEntity) { logger.debug(() -> 
Strings.format("Received authorization information from gateway %s", authResponseEntity)); return AuthorizationModel.of(authResponseEntity, baseUrl); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2.java similarity index 92% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2.java index 7d13eeeb89a6b..d7d058a03f125 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2.java @@ -37,19 +37,18 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; -public class ElasticInferenceServiceAuthorizationResponseEntityV2 implements InferenceServiceResults { +public class AuthorizationResponseEntityV2 implements InferenceServiceResults { public static final String NAME = "elastic_inference_service_auth_results_v2"; private static final String INFERENCE_ENDPOINTS = "inference_endpoints"; @SuppressWarnings("unchecked") - public static ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - ElasticInferenceServiceAuthorizationResponseEntityV2.class.getSimpleName(), - true, - args -> new ElasticInferenceServiceAuthorizationResponseEntityV2((List) args[0]) - ); + public static ConstructingObjectParser PARSER = 
new ConstructingObjectParser<>( + AuthorizationResponseEntityV2.class.getSimpleName(), + true, + args -> new AuthorizationResponseEntityV2((List) args[0]) + ); static { PARSER.declareObjectArray( @@ -292,23 +291,22 @@ public String toString() { private final List authorizedEndpoints; - public ElasticInferenceServiceAuthorizationResponseEntityV2(List authorizedModels) { + public AuthorizationResponseEntityV2(List authorizedModels) { this.authorizedEndpoints = Objects.requireNonNull(authorizedModels); } /** * Create an empty response */ - public ElasticInferenceServiceAuthorizationResponseEntityV2() { + public AuthorizationResponseEntityV2() { this(List.of()); } - public ElasticInferenceServiceAuthorizationResponseEntityV2(StreamInput in) throws IOException { + public AuthorizationResponseEntityV2(StreamInput in) throws IOException { this(in.readCollectionAsList(AuthorizedEndpoint::new)); } - public static ElasticInferenceServiceAuthorizationResponseEntityV2 fromResponse(Request request, HttpResult response) - throws IOException { + public static AuthorizationResponseEntityV2 fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -354,7 +352,7 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - ElasticInferenceServiceAuthorizationResponseEntityV2 that = (ElasticInferenceServiceAuthorizationResponseEntityV2) o; + AuthorizationResponseEntityV2 that = (AuthorizationResponseEntityV2) o; return Objects.equals(authorizedEndpoints, that.authorizedEndpoints); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java deleted file mode 100644 index 4e2eec9de0456..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.response; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.request.Request; - -import java.io.IOException; -import java.util.EnumSet; -import java.util.Iterator; -import java.util.List; -import 
java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; - -public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults { - - public static final String NAME = "elastic_inference_service_auth_results"; - - private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntity.class); - private static final String AUTH_FIELD_NAME = "authorized_models"; - private static final Map ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of( - "embed/text/sparse", - TaskType.SPARSE_EMBEDDING, - "chat", - TaskType.CHAT_COMPLETION, - "embed/text/dense", - TaskType.TEXT_EMBEDDING, - "rerank/text/text-similarity", - TaskType.RERANK - ); - - @SuppressWarnings("unchecked") - public static ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - ElasticInferenceServiceAuthorizationResponseEntity.class.getSimpleName(), - args -> new ElasticInferenceServiceAuthorizationResponseEntity((List) args[0]) - ); - - static { - PARSER.declareObjectArray(constructorArg(), AuthorizedModel.AUTHORIZED_MODEL_PARSER::apply, new ParseField("models")); - } - - public record AuthorizedModel(String modelName, EnumSet taskTypes) implements Writeable, ToXContentObject { - - @SuppressWarnings("unchecked") - public static ConstructingObjectParser AUTHORIZED_MODEL_PARSER = new ConstructingObjectParser<>( - AuthorizedModel.class.getSimpleName(), - args -> new AuthorizedModel((String) args[0], toTaskTypes((List) args[1])) - ); - - static { - AUTHORIZED_MODEL_PARSER.declareString(constructorArg(), new ParseField("model_name")); - AUTHORIZED_MODEL_PARSER.declareStringArray(constructorArg(), new ParseField("task_types")); - } - - private static EnumSet toTaskTypes(List stringTaskTypes) { - var taskTypes = EnumSet.noneOf(TaskType.class); - for (String taskType : stringTaskTypes) { - var mappedTaskType = 
ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING.get(taskType); - if (mappedTaskType != null) { - taskTypes.add(mappedTaskType); - } - } - - return taskTypes; - } - - public AuthorizedModel(StreamInput in) throws IOException { - this(in.readString(), in.readEnumSet(TaskType.class)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(modelName); - out.writeEnumSet(taskTypes); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - builder.field("model_name", modelName); - builder.field("task_types", taskTypes.stream().map(TaskType::toString).collect(Collectors.toList())); - - builder.endObject(); - - return builder; - } - - @Override - public String toString() { - return Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes); - } - } - - private final List authorizedModels; - - public ElasticInferenceServiceAuthorizationResponseEntity(List authorizedModels) { - this.authorizedModels = Objects.requireNonNull(authorizedModels); - } - - /** - * Create an empty response - */ - public ElasticInferenceServiceAuthorizationResponseEntity() { - this(List.of()); - } - - public ElasticInferenceServiceAuthorizationResponseEntity(StreamInput in) throws IOException { - this(in.readCollectionAsList(AuthorizedModel::new)); - } - - public static ElasticInferenceServiceAuthorizationResponseEntity fromResponse(Request request, HttpResult response) throws IOException { - var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { - return PARSER.apply(jsonParser, null); - } - } - - public List getAuthorizedModels() { - return authorizedModels; - } - - @Override - public String toString() { - return 
authorizedModels.stream().map(AuthorizedModel::toString).collect(Collectors.joining(", ")); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - throw new UnsupportedOperationException(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(authorizedModels); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public Map asMap() { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ElasticInferenceServiceAuthorizationResponseEntity that = (ElasticInferenceServiceAuthorizationResponseEntity) o; - return Objects.equals(authorizedModels, that.authorizedModels); - } - - @Override - public int hashCode() { - return Objects.hash(authorizedModels); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 11a635cf5f41a..87fc69d96b7be 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -15,7 +15,6 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; @@ -33,14 +32,13 
@@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.After; import org.junit.Before; import java.io.IOException; import java.util.ArrayList; -import java.util.EnumSet; import java.util.List; import java.util.Locale; import java.util.concurrent.CountDownLatch; @@ -57,6 +55,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2Tests.getEisElserAuthorizationResponse; import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -302,17 +301,9 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce try (var sender = createSender(senderFactory)) { sender.startSynchronously(); - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse", "chat"] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var url = getUrl(webServer); + var 
elserResponse = getEisElserAuthorizationResponse(url); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse.responseJson())); PlainActionFuture listener = new PlainActionFuture<>(); var request = new ElasticInferenceServiceAuthorizationRequest( @@ -323,25 +314,15 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce ); var responseHandler = new ElasticInferenceServiceResponseHandler( String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), - ElasticInferenceServiceAuthorizationResponseEntity::fromResponse + AuthorizationResponseEntityV2::fromResponse ); sender.sendWithoutQueuing(mock(Logger.class), request, responseHandler, null, listener); var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(ElasticInferenceServiceAuthorizationResponseEntity.class)); - var authResponse = (ElasticInferenceServiceAuthorizationResponseEntity) result; - assertThat( - authResponse.getAuthorizedModels(), - is( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-a", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION) - ) - ) - ) - ); + assertThat(result, instanceOf(AuthorizationResponseEntityV2.class)); + var authResponse = (AuthorizationResponseEntityV2) result; + assertThat(authResponse.getAuthorizedEndpoints(), is(elserResponse.responseEntity().getAuthorizedEndpoints())); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java index dce473017eb3a..f8b39d2c3ef77 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java @@ -22,7 +22,7 @@ import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; @@ -42,28 +42,10 @@ public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { } public void testExcludes_EndpointsWithoutValidTaskTypes() { - var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntityV2( List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - "id", - "name", - "invalid_task_type", - "ga", - null, - "", - "", - null - ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - "id2", - "name", - TaskType.ANY.toString(), - "ga", - null, - "", - "", - null - ) + new AuthorizationResponseEntityV2.AuthorizedEndpoint("id", "name", "invalid_task_type", "ga", null, "", "", null), + new AuthorizationResponseEntityV2.AuthorizedEndpoint("id2", "name", TaskType.ANY.toString(), "ga", null, "", "", null) ) ); var auth = AuthorizationModel.of(response, "url"); @@ -75,9 +57,9 @@ public void testReturnsAuthorizedTaskTypes() { var id1 = "id1"; var id2 = "id2"; - var response = new 
ElasticInferenceServiceAuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntityV2( List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id1, "name1", TaskType.CHAT_COMPLETION.toString(), @@ -87,7 +69,7 @@ public void testReturnsAuthorizedTaskTypes() { "", null ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id2, "name2", TaskType.SPARSE_EMBEDDING.toString(), @@ -109,9 +91,9 @@ public void testReturnsAuthorizedTaskTypes() { public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { var id = "id1"; - var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntityV2( List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id, "name1", TaskType.CHAT_COMPLETION.toString(), @@ -122,7 +104,7 @@ public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { null ), // This should be ignored because the id is a duplicate - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id, "name2", TaskType.SPARSE_EMBEDDING.toString(), @@ -151,9 +133,9 @@ public void testReturnsAuthorizedEndpoints() { var similarity = SimilarityMeasure.COSINE; var dimensions = 123; - var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntityV2( List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id1, name1, TaskType.CHAT_COMPLETION.toString(), @@ -163,7 +145,7 @@ public void testReturnsAuthorizedEndpoints() { "", null ), - new 
ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id2, name2, TaskType.TEXT_EMBEDDING.toString(), @@ -171,7 +153,7 @@ public void testReturnsAuthorizedEndpoints() { null, "", "", - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntityV2.Configuration( similarity.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -223,9 +205,9 @@ public void testScopesToTaskType() { var similarity = SimilarityMeasure.COSINE; var dimensions = 123; - var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntityV2( List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id1, name1, TaskType.CHAT_COMPLETION.toString(), @@ -235,7 +217,7 @@ public void testScopesToTaskType() { "", null ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id2, name2, TaskType.TEXT_EMBEDDING.toString(), @@ -243,7 +225,7 @@ public void testScopesToTaskType() { null, "", "", - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntityV2.Configuration( similarity.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -292,9 +274,9 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { var dimensions = 123; - var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntityV2( List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id1, name1, TaskType.CHAT_COMPLETION.toString(), @@ -304,7 +286,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { "", null ), - new 
ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id2, name2, TaskType.TEXT_EMBEDDING.toString(), @@ -312,14 +294,14 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntityV2.Configuration( null, dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), null ) ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( id2, name2, TaskType.TEXT_EMBEDDING.toString(), @@ -327,7 +309,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntityV2.Configuration( SimilarityMeasure.DOT_PRODUCT.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -378,9 +360,9 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { var url = "base_url"; - var response = new ElasticInferenceServiceAuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntityV2( List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( idChat, nameChat, TaskType.CHAT_COMPLETION.toString(), @@ -390,7 +372,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "", null ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( idSparse, nameSparse, TaskType.SPARSE_EMBEDDING.toString(), @@ -400,7 +382,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "", null ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( idDense, nameDense, 
TaskType.TEXT_EMBEDDING.toString(), @@ -408,14 +390,9 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { null, "", "", - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( - similarity.toString(), - dimensions, - elementType, - null - ) + new AuthorizationResponseEntityV2.Configuration(similarity.toString(), dimensions, elementType, null) ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( idRerank, nameRerank, TaskType.RERANK.toString(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index fe9851eb4d01e..e06782d34725f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.junit.Before; @@ -41,12 +41,9 @@ import 
java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -<<<<<<< HEAD -import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2Tests.createAuthorizedEndpoint; -======= import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureTests.createMockCCMFeature; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMServiceTests.createMockCCMService; ->>>>>>> d74334e978a26a7fb76f312898431a000cdbc3b6 +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2Tests.createAuthorizedEndpoint; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -191,7 +188,7 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -222,23 +219,7 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { poller.sendAuthorizationRequest(); verify(mockClient).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); var capturedRequest = requestArgCaptor.getValue(); - assertThat( - capturedRequest.getModels(), - is( - List.of( - new ElasticInferenceServiceSparseEmbeddingsModel( - sparseModel.id(), - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - new 
ElasticInferenceServiceSparseEmbeddingsServiceSettings(sparseModel.modelName(), null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url), - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ) - ) - ) - ); + assertThat(capturedRequest.getModels(), is(List.of(createSparseEndpoint(sparseModel.id(), sparseModel.modelName(), url)))); verify(mockPersistentTasksService, never()).sendCompletionRequest( eq(persistentTaskId), @@ -250,26 +231,31 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { ); } + private ElasticInferenceServiceSparseEmbeddingsModel createSparseEndpoint(String endpointId, String modelName, String url) { + return new ElasticInferenceServiceSparseEmbeddingsModel( + endpointId, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelName, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ); + } + public void testSendsAuthorizationRequest_WhenCCMIsNotConfigurable() { var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + var url = "eis-url"; + var sparseModel = createAuthorizedEndpoint(TaskType.SPARSE_EMBEDDING); + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ); + ActionListener listener = invocation.getArgument(0); + 
listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -301,17 +287,7 @@ public void testSendsAuthorizationRequest_WhenCCMIsNotConfigurable() { poller.sendAuthorizationRequest(); verify(mockClient).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); var capturedRequest = requestArgCaptor.getValue(); - assertThat( - capturedRequest.getModels(), - is( - List.of( - PreconfiguredEndpointModelAdapter.createModel( - InternalPreconfiguredEndpoints.getWithInferenceId(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2), - new ElasticInferenceServiceComponents("") - ) - ) - ) - ); + assertThat(capturedRequest.getModels(), is(List.of(createSparseEndpoint(sparseModel.id(), sparseModel.modelName(), url)))); verify(mockPersistentTasksService, never()).sendCompletionRequest( eq(persistentTaskId), @@ -333,68 +309,8 @@ public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInfe var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { -<<<<<<< HEAD ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url)); -======= - ActionListener listener = invocation.getArgument(0); - listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - - var mockClient = mock(Client.class); - when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); - - var poller = new 
AuthorizationPoller( - new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), - createWithEmptySettings(taskQueue.getThreadPool()), - mockAuthHandler, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - mockRegistry, - mockClient, - createMockCCMFeature(true), - createMockCCMService(true), - null - ); - - poller.sendAuthorizationRequest(); - verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); - } - - public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMapping() { - var mockRegistry = mock(ModelRegistry.class); - when(mockRegistry.isReady()).thenReturn(true); - when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); - - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - // This is a model id that does not exist in the preconfigured endpoints map so it will not be stored - "abc", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ) - ) - ) - ) - ); ->>>>>>> d74334e978a26a7fb76f312898431a000cdbc3b6 + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -430,7 +346,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); listener.onResponse( - AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(completionModel)), url) + AuthorizationModel.of(new 
AuthorizationResponseEntityV2(List.of(completionModel)), url) ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -468,7 +384,7 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -525,7 +441,7 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntityV2(List.of(sparseModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 9ade6e4c45c80..407ea56eb4732 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -14,8 +14,6 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -24,18 +22,13 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.xpack.core.inference.chunking.SentenceBoundaryChunkingSettings; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityV2Tests; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2Tests; 
import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -54,6 +47,7 @@ import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createApplierFactory; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createNoopApplierFactory; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2Tests.getEisElserAuthorizationResponse; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; @@ -67,29 +61,6 @@ public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); - private static final String ELSER_EIS_RESPONSE = """ - { - "inference_endpoints": [ - { - "id": ".elastic-elser-v2", - "model_name": "elser_model_2", - "task_type": "sparse_embedding", - "status": "preview", - "properties": [ - "english" - ], - "release_date": "2024-05-01", - "configuration": { - "chunking_settings": { - "strategy": "sentence", - "max_chunk_size": 250, - "sentence_overlap": 1 - } - } - } - ] - } - """; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -221,8 +192,9 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { ); try (var sender = senderFactory.createSender()) { - var responseData = ElasticInferenceServiceAuthorizationResponseEntityV2Tests.EisAuthorizationResponseData - .getEisAuthorizationData(eisGatewayUrl); + var responseData = AuthorizationResponseEntityV2Tests.getEisAuthorizationResponseWithMultipleEndpoints( + eisGatewayUrl + ); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseData.responseJson())); @@ -269,8 
+241,10 @@ public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws I createApplierFactory(secret) ); + var elserResponse = getEisElserAuthorizationResponse(eisGatewayUrl); + try (var sender = senderFactory.createSender()) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(ELSER_EIS_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse.responseJson())); PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); @@ -310,38 +284,21 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); ActionListener onlyOnceListener = ActionListener.assertOnce(listener); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(ELSER_EIS_RESPONSE)); + var elserResponse = getEisElserAuthorizationResponse(eisGatewayUrl); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse.responseJson())); try (var sender = senderFactory.createSender()) { authHandler.getAuthorization(onlyOnceListener, sender); authHandler.waitForAuthRequestCompletion(TIMEOUT); var endpointId = ".elastic-elser-v2"; - var endpointName = "elser_model_2"; - var maxChunkSize = 250; - var sentenceOverlap = 1; var authResponse = listener.actionGet(TIMEOUT); assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); assertThat(authResponse.getEndpointIds(), is(Set.of(endpointId))); assertTrue(authResponse.isAuthorized()); - assertThat( - authResponse.getEndpoints(Set.of(endpointId)), - is( - List.of( - new ElasticInferenceServiceSparseEmbeddingsModel( - endpointId, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(endpointName, null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(eisGatewayUrl), - new 
SentenceBoundaryChunkingSettings(maxChunkSize, sentenceOverlap) - ) - ) - ) - ); + assertThat(authResponse.getEndpoints(Set.of(endpointId)), is(elserResponse.expectedEndpoints())); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(1)).debug(loggerArgsCaptor.capture()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java index 715ee931f6150..be54db3c97084 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceAuthorizationRequestTests.java @@ -51,7 +51,12 @@ public void testCreateUriThrowsForInvalidBaseUrl() { public void testCreateUri_CreatesUri() throws URISyntaxException { String url = "https://inference.us-east-1.aws.svc.elastic.cloud"; - var request = new ElasticInferenceServiceAuthorizationRequest(url, traceContext, randomElasticInferenceServiceRequestMetadata()); + var request = new ElasticInferenceServiceAuthorizationRequest( + url, + traceContext, + randomElasticInferenceServiceRequestMetadata(), + CCMAuthenticationApplierFactory.NOOP_APPLIER + ); assertThat(request.getURI(), is(new URI(url + AUTHORIZATION_PATH))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2Tests.java similarity index 50% rename from 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2Tests.java index 026819c22647b..4f87e320f23ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityV2Tests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2Tests.java @@ -42,8 +42,38 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; -public class ElasticInferenceServiceAuthorizationResponseEntityV2Tests extends AbstractBWCWireSerializationTestCase< - ElasticInferenceServiceAuthorizationResponseEntityV2> { +public class AuthorizationResponseEntityV2Tests extends AbstractBWCWireSerializationTestCase { + + public record EisAuthorizationResponse( + String responseJson, + AuthorizationResponseEntityV2 responseEntity, + List expectedEndpoints, + Set inferenceIds + ) {} + + public static final String EIS_ELSER_RESPONSE = """ + { + "inference_endpoints": [ + { + "id": ".elastic-elser-v2", + "model_name": "elser_model_2", + "task_type": "sparse_embedding", + "status": "preview", + "properties": [ + "english" + ], + "release_date": "2024-05-01", + "configuration": { + "chunking_settings": { + "strategy": "sentence", + "max_chunk_size": 250, + "sentence_overlap": 1 + } + } + } + ] + } + """; public static String EIS_AUTHORIZATION_RESPONSE_V2 = """ { @@ -109,138 +139,173 @@ public class ElasticInferenceServiceAuthorizationResponseEntityV2Tests extends A } """; - public record EisAuthorizationResponseData( - String responseJson, - ElasticInferenceServiceAuthorizationResponseEntityV2 responseEntity, - List expectedEndpoints, - Set inferenceIds - 
) { + public static EisAuthorizationResponse getEisElserAuthorizationResponse(String url) { + + var authorizedEndpoints = List.of( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( + ".elastic-elser-v2", + "elser_model_2", + "sparse_embedding", + "preview", + List.of("english"), + "2024-05-01", + null, + new AuthorizationResponseEntityV2.Configuration( + null, + null, + null, + Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1) + ) + ) + ); - public static EisAuthorizationResponseData getEisAuthorizationData(String url) { + var inferenceIds = authorizedEndpoints.stream() + .map(AuthorizationResponseEntityV2.AuthorizedEndpoint::id) + .collect(Collectors.toSet()); - var authorizedEndpoints = List.of( - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - ".rainbow-sprinkles-elastic", - "rainbow-sprinkles", - "chat_completion", - "ga", - List.of("multilingual"), - "2024-05-01", - "2025-12-31", - null - ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + return new EisAuthorizationResponse( + EIS_ELSER_RESPONSE, + new AuthorizationResponseEntityV2(authorizedEndpoints), + List.of( + new ElasticInferenceServiceSparseEmbeddingsModel( ".elastic-elser-v2", - "elser_model_2", - "sparse_embedding", - "preview", - List.of("english"), - "2024-05-01", + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser_model_2", null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new SentenceBoundaryChunkingSettings(250, 1) + ) + ), + inferenceIds + ); + } + + public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEndpoints(String url) { + + var authorizedEndpoints = List.of( + new AuthorizationResponseEntityV2.AuthorizedEndpoint( + ".rainbow-sprinkles-elastic", + "rainbow-sprinkles", + "chat_completion", + "ga", + List.of("multilingual"), + 
"2024-05-01", + "2025-12-31", + null + ), + new AuthorizationResponseEntityV2.AuthorizedEndpoint( + ".elastic-elser-v2", + "elser_model_2", + "sparse_embedding", + "preview", + List.of("english"), + "2024-05-01", + null, + new AuthorizationResponseEntityV2.Configuration( null, - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( - null, - null, - null, - Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1) - ) - ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - ".jina-embeddings-v3", - "jina-embeddings-v3", - "text_embedding", - "beta", - List.of("multilingual", "open-weights"), - "2024-05-01", null, - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( - "cosine", - 1024, - "float", - Map.of("strategy", "word", "max_chunk_size", 500, "overlap", 2) - ) - ), - new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( - ".elastic-rerank-v1", - "elastic-rerank-v1", - "rerank", - "preview", - List.of(), - "2024-05-01", null, - null + Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1) ) - ); + ), + new AuthorizationResponseEntityV2.AuthorizedEndpoint( + ".jina-embeddings-v3", + "jina-embeddings-v3", + "text_embedding", + "beta", + List.of("multilingual", "open-weights"), + "2024-05-01", + null, + new AuthorizationResponseEntityV2.Configuration( + "cosine", + 1024, + "float", + Map.of("strategy", "word", "max_chunk_size", 500, "overlap", 2) + ) + ), + new AuthorizationResponseEntityV2.AuthorizedEndpoint( + ".elastic-rerank-v1", + "elastic-rerank-v1", + "rerank", + "preview", + List.of(), + "2024-05-01", + null, + null + ) + ); - var inferenceIds = authorizedEndpoints.stream() - .map(ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint::id) - .collect(Collectors.toSet()); + var inferenceIds = authorizedEndpoints.stream() + .map(AuthorizationResponseEntityV2.AuthorizedEndpoint::id) + .collect(Collectors.toSet()); 
- return new EisAuthorizationResponseData( - EIS_AUTHORIZATION_RESPONSE_V2, - new ElasticInferenceServiceAuthorizationResponseEntityV2(authorizedEndpoints), - List.of( - new ElasticInferenceServiceCompletionModel( - ".rainbow-sprinkles-elastic", - TaskType.CHAT_COMPLETION, - ElasticInferenceService.NAME, - new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url) - ), - new ElasticInferenceServiceSparseEmbeddingsModel( - ".elastic-elser-v2", - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser_model_2", null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url), - new SentenceBoundaryChunkingSettings(250, 1) - ), - new ElasticInferenceServiceDenseTextEmbeddingsModel( - ".jina-embeddings-v3", - TaskType.TEXT_EMBEDDING, - ElasticInferenceService.NAME, - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - "jina-embeddings-v3", - SimilarityMeasure.COSINE, - 1024, - null - ), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url), - new WordBoundaryChunkingSettings(500, 2) + return new EisAuthorizationResponse( + EIS_AUTHORIZATION_RESPONSE_V2, + new AuthorizationResponseEntityV2(authorizedEndpoints), + List.of( + new ElasticInferenceServiceCompletionModel( + ".rainbow-sprinkles-elastic", + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ), + new ElasticInferenceServiceSparseEmbeddingsModel( + ".elastic-elser-v2", + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser_model_2", null), + 
EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new SentenceBoundaryChunkingSettings(250, 1) + ), + new ElasticInferenceServiceDenseTextEmbeddingsModel( + ".jina-embeddings-v3", + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + "jina-embeddings-v3", + SimilarityMeasure.COSINE, + 1024, + null ), - new ElasticInferenceServiceRerankModel( - ".elastic-rerank-v1", - TaskType.RERANK, - ElasticInferenceService.NAME, - new ElasticInferenceServiceRerankServiceSettings("elastic-rerank-v1"), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url) - ) + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new WordBoundaryChunkingSettings(500, 2) ), - inferenceIds - ); - } + new ElasticInferenceServiceRerankModel( + ".elastic-rerank-v1", + TaskType.RERANK, + ElasticInferenceService.NAME, + new ElasticInferenceServiceRerankServiceSettings("elastic-rerank-v1"), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ) + ), + inferenceIds + ); } - public static ElasticInferenceServiceAuthorizationResponseEntityV2 createResponse() { - return new ElasticInferenceServiceAuthorizationResponseEntityV2( + public static AuthorizationResponseEntityV2 createResponse() { + return new AuthorizationResponseEntityV2( randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) ); } - public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { + public static AuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { var id = randomAlphaOfLength(10); var name = randomAlphaOfLength(10); var status = randomFrom("ga", "beta", "preview"); return switch 
(taskType) { - case CHAT_COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + case CHAT_COMPLETION -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( id, name, TaskType.CHAT_COMPLETION.toString(), @@ -250,7 +315,7 @@ public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEnd "", null ); - case SPARSE_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + case SPARSE_EMBEDDING -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( id, name, TaskType.SPARSE_EMBEDDING.toString(), @@ -260,7 +325,7 @@ public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEnd "", null ); - case TEXT_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + case TEXT_EMBEDDING -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( id, name, TaskType.TEXT_EMBEDDING.toString(), @@ -268,14 +333,14 @@ public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEnd null, "", "", - new ElasticInferenceServiceAuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntityV2.Configuration( randomFrom(SimilarityMeasure.values()).toString(), randomInt(), DenseVectorFieldMapper.ElementType.FLOAT.toString(), null ) ); - case RERANK -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + case RERANK -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( id, name, TaskType.RERANK.toString(), @@ -285,7 +350,7 @@ public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEnd "", null ); - case COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEndpoint( + case COMPLETION -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( id, name, TaskType.COMPLETION.toString(), @@ -302,9 +367,9 @@ public static ElasticInferenceServiceAuthorizationResponseEntityV2.AuthorizedEnd public void testParseAllFields() throws IOException { var url = 
"http://example.com/authorize"; - var responseData = EisAuthorizationResponseData.getEisAuthorizationData(url); + var responseData = getEisAuthorizationResponseWithMultipleEndpoints(url); try (var parser = createParser(JsonXContent.jsonXContent, responseData.responseJson())) { - var entity = ElasticInferenceServiceAuthorizationResponseEntityV2.PARSER.apply(parser, null); + var entity = AuthorizationResponseEntityV2.PARSER.apply(parser, null); assertThat(entity, is(responseData.responseEntity())); @@ -323,29 +388,24 @@ public void testParseAllFields() throws IOException { } @Override - protected ElasticInferenceServiceAuthorizationResponseEntityV2 mutateInstanceForVersion( - ElasticInferenceServiceAuthorizationResponseEntityV2 instance, - TransportVersion version - ) { + protected AuthorizationResponseEntityV2 mutateInstanceForVersion(AuthorizationResponseEntityV2 instance, TransportVersion version) { return instance; } @Override - protected Writeable.Reader instanceReader() { - return ElasticInferenceServiceAuthorizationResponseEntityV2::new; + protected Writeable.Reader instanceReader() { + return AuthorizationResponseEntityV2::new; } @Override - protected ElasticInferenceServiceAuthorizationResponseEntityV2 createTestInstance() { + protected AuthorizationResponseEntityV2 createTestInstance() { return createResponse(); } @Override - protected ElasticInferenceServiceAuthorizationResponseEntityV2 mutateInstance( - ElasticInferenceServiceAuthorizationResponseEntityV2 instance - ) throws IOException { + protected AuthorizationResponseEntityV2 mutateInstance(AuthorizationResponseEntityV2 instance) throws IOException { var newEndpoints = new ArrayList<>(instance.getAuthorizedEndpoints()); newEndpoints.add(createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))); - return new ElasticInferenceServiceAuthorizationResponseEntityV2(newEndpoints); + return new AuthorizationResponseEntityV2(newEndpoints); } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java deleted file mode 100644 index 36387bb9f7fa4..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.response; - -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.json.JsonXContent; - -import java.io.IOException; -import java.util.EnumSet; -import java.util.List; - -import static org.hamcrest.Matchers.is; - -public class ElasticInferenceServiceAuthorizationResponseEntityTests extends ESTestCase { - - public void testParseAllFields() throws IOException { - String json = """ - { - "models": [ - { - "model_name": "test_model", - "task_types": ["embed/text/sparse", "chat"] - } - ] - } - """; - - try (var parser = createParser(JsonXContent.jsonXContent, json)) { - var entity = ElasticInferenceServiceAuthorizationResponseEntity.PARSER.apply(parser, null); - var expected = new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "test_model", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION) - ) - ) - ); - - assertThat(entity, is(expected)); - } - } - - public void testParsing_EmptyModels() throws IOException { - 
String json = """ - { - "models": [] - } - """; - - try (var parser = createParser(JsonXContent.jsonXContent, json)) { - var entity = ElasticInferenceServiceAuthorizationResponseEntity.PARSER.apply(parser, null); - var expected = new ElasticInferenceServiceAuthorizationResponseEntity(List.of()); - - assertThat(entity, is(expected)); - } - } - -} From 839ebe0adc88699cc123026d931f6bdea6e1f98a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 18 Nov 2025 16:53:37 -0500 Subject: [PATCH 07/24] Renaming --- .../InferenceNamedWriteablesProvider.java | 8 +++ .../authorization/AuthorizationModel.java | 36 +++++------ ...nceServiceAuthorizationRequestHandler.java | 6 +- ....java => AuthorizationResponseEntity.java} | 18 +++--- .../http/sender/HttpRequestSenderTests.java | 10 +-- .../AuthorizationModelTests.java | 60 ++++++++--------- .../AuthorizationPollerTests.java | 16 ++--- ...rviceAuthorizationRequestHandlerTests.java | 6 +- ... => AuthorizationResponseEntityTests.java} | 64 +++++++++---------- 9 files changed, 114 insertions(+), 110 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/{AuthorizationResponseEntityV2.java => AuthorizationResponseEntity.java} (95%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/{AuthorizationResponseEntityV2Tests.java => AuthorizationResponseEntityTests.java} (85%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index f39ca9f66f3ca..b60b7ad3a584c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -75,6 +75,7 @@ import 
org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; @@ -704,6 +705,13 @@ private static void addInferenceResultsNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java index e11448bca18ec..a1f2cd0bdba83 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java @@ -24,7 +24,7 @@ import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; import 
org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; @@ -48,13 +48,13 @@ public class AuthorizationModel { private static final Logger logger = LogManager.getLogger(AuthorizationModel.class); private static final String UNKNOWN_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unknown task type [{}], skipping"; - public static AuthorizationModel of(AuthorizationResponseEntityV2 responseEntity, String baseEisUrl) { + public static AuthorizationModel of(AuthorizationResponseEntity responseEntity, String baseEisUrl) { var components = new ElasticInferenceServiceComponents(baseEisUrl); return createInternal(responseEntity.getAuthorizedEndpoints(), components); } private static AuthorizationModel createInternal( - List responseEndpoints, + List responseEndpoints, ElasticInferenceServiceComponents components ) { var validEndpoints = new ArrayList(); @@ -69,7 +69,7 @@ private static AuthorizationModel createInternal( } private static ElasticInferenceServiceModel createModel( - AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { try { @@ -110,7 +110,7 @@ private static TaskType getTaskType(String taskType) { } private static ElasticInferenceServiceCompletionModel createCompletionModel( - AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceCompletionModel( @@ -125,7 +125,7 @@ private static ElasticInferenceServiceCompletionModel createCompletionModel( } private static ElasticInferenceServiceSparseEmbeddingsModel createSparseEmbeddingsModel( - 
AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceSparseEmbeddingsModel( @@ -140,24 +140,24 @@ private static ElasticInferenceServiceSparseEmbeddingsModel createSparseEmbeddin ); } - private static AuthorizationResponseEntityV2.Configuration getConfigurationOrEmpty( - AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint + private static AuthorizationResponseEntity.Configuration getConfigurationOrEmpty( + AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint ) { if (authorizedEndpoint.configuration() != null) { return authorizedEndpoint.configuration(); } - return AuthorizationResponseEntityV2.Configuration.EMPTY; + return AuthorizationResponseEntity.Configuration.EMPTY; } private static Map getChunkingSettingsMap( - AuthorizationResponseEntityV2.Configuration configuration + AuthorizationResponseEntity.Configuration configuration ) { return Objects.requireNonNullElse(configuration.chunkingSettings(), new HashMap<>()); } private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEmbeddingsModel( - AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { var config = getConfigurationOrEmpty(authorizedEndpoint); @@ -180,19 +180,19 @@ private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEm ); } - private static void validateConfigurationForTextEmbedding(AuthorizationResponseEntityV2.Configuration config) { + private static void validateConfigurationForTextEmbedding(AuthorizationResponseEntity.Configuration config) { validateFieldPresent( - AuthorizationResponseEntityV2.Configuration.ELEMENT_TYPE, + AuthorizationResponseEntity.Configuration.ELEMENT_TYPE, config.elementType(), TaskType.TEXT_EMBEDDING ); 
validateFieldPresent( - AuthorizationResponseEntityV2.Configuration.DIMENSIONS, + AuthorizationResponseEntity.Configuration.DIMENSIONS, config.dimensions(), TaskType.TEXT_EMBEDDING ); validateFieldPresent( - AuthorizationResponseEntityV2.Configuration.SIMILARITY, + AuthorizationResponseEntity.Configuration.SIMILARITY, config.similarity(), TaskType.TEXT_EMBEDDING ); @@ -207,10 +207,10 @@ private static void validateFieldPresent(String field, Object fieldValue, TaskTy } private static SimilarityMeasure getSimilarityMeasure( - AuthorizationResponseEntityV2.Configuration configuration + AuthorizationResponseEntity.Configuration configuration ) { validateFieldPresent( - AuthorizationResponseEntityV2.Configuration.SIMILARITY, + AuthorizationResponseEntity.Configuration.SIMILARITY, configuration.similarity(), TaskType.TEXT_EMBEDDING ); @@ -219,7 +219,7 @@ private static SimilarityMeasure getSimilarityMeasure( } private static ElasticInferenceServiceRerankModel createRerankModel( - AuthorizationResponseEntityV2.AuthorizedEndpoint authorizedEndpoint, + AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceRerankModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 15b97e66ec8de..5638b7b346497 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -24,7 +24,7 @@ import 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.Objects; @@ -47,7 +47,7 @@ public class ElasticInferenceServiceAuthorizationRequestHandler { private static ResponseHandler createAuthResponseHandler() { return new ElasticInferenceServiceResponseHandler( Strings.format("%s authorization", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), - AuthorizationResponseEntityV2::fromResponse + AuthorizationResponseEntity::fromResponse ); } @@ -119,7 +119,7 @@ public void getAuthorization(ActionListener listener, Sender sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); }) .andThenApply(authResult -> { - if (authResult instanceof AuthorizationResponseEntityV2 authResponseEntity) { + if (authResult instanceof AuthorizationResponseEntity authResponseEntity) { logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); return AuthorizationModel.of(authResponseEntity, baseUrl); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java similarity index 95% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2.java rename to 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java index d7d058a03f125..6ba2664e52993 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java @@ -37,17 +37,17 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; -public class AuthorizationResponseEntityV2 implements InferenceServiceResults { +public class AuthorizationResponseEntity implements InferenceServiceResults { public static final String NAME = "elastic_inference_service_auth_results_v2"; private static final String INFERENCE_ENDPOINTS = "inference_endpoints"; @SuppressWarnings("unchecked") - public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - AuthorizationResponseEntityV2.class.getSimpleName(), + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + AuthorizationResponseEntity.class.getSimpleName(), true, - args -> new AuthorizationResponseEntityV2((List) args[0]) + args -> new AuthorizationResponseEntity((List) args[0]) ); static { @@ -291,22 +291,22 @@ public String toString() { private final List authorizedEndpoints; - public AuthorizationResponseEntityV2(List authorizedModels) { + public AuthorizationResponseEntity(List authorizedModels) { this.authorizedEndpoints = Objects.requireNonNull(authorizedModels); } /** * Create an empty response */ - public AuthorizationResponseEntityV2() { + public AuthorizationResponseEntity() { this(List.of()); } - public AuthorizationResponseEntityV2(StreamInput in) throws IOException { + public AuthorizationResponseEntity(StreamInput in) throws IOException { 
this(in.readCollectionAsList(AuthorizedEndpoint::new)); } - public static AuthorizationResponseEntityV2 fromResponse(Request request, HttpResult response) throws IOException { + public static AuthorizationResponseEntity fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -352,7 +352,7 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - AuthorizationResponseEntityV2 that = (AuthorizationResponseEntityV2) o; + AuthorizationResponseEntity that = (AuthorizationResponseEntity) o; return Objects.equals(authorizedEndpoints, that.authorizedEndpoints); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 87fc69d96b7be..0420b58db9d1f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -32,7 +32,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; import 
org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.After; import org.junit.Before; @@ -55,7 +55,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2Tests.getEisElserAuthorizationResponse; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisElserAuthorizationResponse; import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -314,14 +314,14 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce ); var responseHandler = new ElasticInferenceServiceResponseHandler( String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), - AuthorizationResponseEntityV2::fromResponse + AuthorizationResponseEntity::fromResponse ); sender.sendWithoutQueuing(mock(Logger.class), request, responseHandler, null, listener); var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(AuthorizationResponseEntityV2.class)); - var authResponse = (AuthorizationResponseEntityV2) result; + assertThat(result, instanceOf(AuthorizationResponseEntity.class)); + var authResponse = (AuthorizationResponseEntity) result; assertThat(authResponse.getAuthorizedEndpoints(), is(elserResponse.responseEntity().getAuthorizedEndpoints())); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java index f8b39d2c3ef77..ea8881852b1f7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java @@ -22,7 +22,7 @@ import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; @@ -42,10 +42,10 @@ public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { } public void testExcludes_EndpointsWithoutValidTaskTypes() { - var response = new AuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntity( List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint("id", "name", "invalid_task_type", "ga", null, "", "", null), - new AuthorizationResponseEntityV2.AuthorizedEndpoint("id2", "name", TaskType.ANY.toString(), "ga", null, "", "", null) + new AuthorizationResponseEntity.AuthorizedEndpoint("id", "name", "invalid_task_type", "ga", null, "", "", null), + new AuthorizationResponseEntity.AuthorizedEndpoint("id2", "name", TaskType.ANY.toString(), "ga", null, "", "", null) ) ); var 
auth = AuthorizationModel.of(response, "url"); @@ -57,9 +57,9 @@ public void testReturnsAuthorizedTaskTypes() { var id1 = "id1"; var id2 = "id2"; - var response = new AuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntity( List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id1, "name1", TaskType.CHAT_COMPLETION.toString(), @@ -69,7 +69,7 @@ public void testReturnsAuthorizedTaskTypes() { "", null ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id2, "name2", TaskType.SPARSE_EMBEDDING.toString(), @@ -91,9 +91,9 @@ public void testReturnsAuthorizedTaskTypes() { public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { var id = "id1"; - var response = new AuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntity( List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id, "name1", TaskType.CHAT_COMPLETION.toString(), @@ -104,7 +104,7 @@ public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { null ), // This should be ignored because the id is a duplicate - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id, "name2", TaskType.SPARSE_EMBEDDING.toString(), @@ -133,9 +133,9 @@ public void testReturnsAuthorizedEndpoints() { var similarity = SimilarityMeasure.COSINE; var dimensions = 123; - var response = new AuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntity( List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, TaskType.CHAT_COMPLETION.toString(), @@ -145,7 +145,7 @@ public void testReturnsAuthorizedEndpoints() { "", null ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new 
AuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, TaskType.TEXT_EMBEDDING.toString(), @@ -153,7 +153,7 @@ public void testReturnsAuthorizedEndpoints() { null, "", "", - new AuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntity.Configuration( similarity.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -205,9 +205,9 @@ public void testScopesToTaskType() { var similarity = SimilarityMeasure.COSINE; var dimensions = 123; - var response = new AuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntity( List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, TaskType.CHAT_COMPLETION.toString(), @@ -217,7 +217,7 @@ public void testScopesToTaskType() { "", null ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, TaskType.TEXT_EMBEDDING.toString(), @@ -225,7 +225,7 @@ public void testScopesToTaskType() { null, "", "", - new AuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntity.Configuration( similarity.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -274,9 +274,9 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { var dimensions = 123; - var response = new AuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntity( List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, TaskType.CHAT_COMPLETION.toString(), @@ -286,7 +286,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { "", null ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, TaskType.TEXT_EMBEDDING.toString(), @@ -294,14 +294,14 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new 
AuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntity.Configuration( null, dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), null ) ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, TaskType.TEXT_EMBEDDING.toString(), @@ -309,7 +309,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new AuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntity.Configuration( SimilarityMeasure.DOT_PRODUCT.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -360,9 +360,9 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { var url = "base_url"; - var response = new AuthorizationResponseEntityV2( + var response = new AuthorizationResponseEntity( List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( idChat, nameChat, TaskType.CHAT_COMPLETION.toString(), @@ -372,7 +372,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "", null ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( idSparse, nameSparse, TaskType.SPARSE_EMBEDDING.toString(), @@ -382,7 +382,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "", null ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( idDense, nameDense, TaskType.TEXT_EMBEDDING.toString(), @@ -390,9 +390,9 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { null, "", "", - new AuthorizationResponseEntityV2.Configuration(similarity.toString(), dimensions, elementType, null) + new AuthorizationResponseEntity.Configuration(similarity.toString(), dimensions, elementType, null) ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( 
idRerank, nameRerank, TaskType.RERANK.toString(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index e06782d34725f..df733298a47bb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.junit.Before; @@ -43,7 +43,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureTests.createMockCCMFeature; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMServiceTests.createMockCCMService; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2Tests.createAuthorizedEndpoint; +import static 
org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createAuthorizedEndpoint; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -188,7 +188,7 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -255,7 +255,7 @@ public void testSendsAuthorizationRequest_WhenCCMIsNotConfigurable() { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -310,7 +310,7 @@ public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInfe var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -346,7 +346,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra 
doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); listener.onResponse( - AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(completionModel)), url) + AuthorizationModel.of(new AuthorizationResponseEntity(List.of(completionModel)), url) ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -384,7 +384,7 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -441,7 +441,7 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntityV2(List.of(sparseModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 407ea56eb4732..45f60b9346b4c 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -28,7 +28,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2Tests; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -47,7 +47,7 @@ import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createApplierFactory; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createNoopApplierFactory; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityV2Tests.getEisElserAuthorizationResponse; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisElserAuthorizationResponse; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; @@ -192,7 +192,7 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { ); try (var sender = senderFactory.createSender()) { - var responseData = AuthorizationResponseEntityV2Tests.getEisAuthorizationResponseWithMultipleEndpoints( + var responseData = 
AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints( eisGatewayUrl ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2Tests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java similarity index 85% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2Tests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index 4f87e320f23ab..57b08507c2099 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityV2Tests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -42,11 +42,11 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; -public class AuthorizationResponseEntityV2Tests extends AbstractBWCWireSerializationTestCase { +public class AuthorizationResponseEntityTests extends AbstractBWCWireSerializationTestCase { public record EisAuthorizationResponse( String responseJson, - AuthorizationResponseEntityV2 responseEntity, + AuthorizationResponseEntity responseEntity, List expectedEndpoints, Set inferenceIds ) {} @@ -142,7 +142,7 @@ public record EisAuthorizationResponse( public static EisAuthorizationResponse getEisElserAuthorizationResponse(String url) { var authorizedEndpoints = List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( ".elastic-elser-v2", "elser_model_2", "sparse_embedding", @@ -150,7 +150,7 @@ public static EisAuthorizationResponse getEisElserAuthorizationResponse(String u 
List.of("english"), "2024-05-01", null, - new AuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntity.Configuration( null, null, null, @@ -159,13 +159,11 @@ public static EisAuthorizationResponse getEisElserAuthorizationResponse(String u ) ); - var inferenceIds = authorizedEndpoints.stream() - .map(AuthorizationResponseEntityV2.AuthorizedEndpoint::id) - .collect(Collectors.toSet()); + var inferenceIds = authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); return new EisAuthorizationResponse( EIS_ELSER_RESPONSE, - new AuthorizationResponseEntityV2(authorizedEndpoints), + new AuthorizationResponseEntity(authorizedEndpoints), List.of( new ElasticInferenceServiceSparseEmbeddingsModel( ".elastic-elser-v2", @@ -185,7 +183,7 @@ public static EisAuthorizationResponse getEisElserAuthorizationResponse(String u public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEndpoints(String url) { var authorizedEndpoints = List.of( - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( ".rainbow-sprinkles-elastic", "rainbow-sprinkles", "chat_completion", @@ -195,7 +193,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn "2025-12-31", null ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( ".elastic-elser-v2", "elser_model_2", "sparse_embedding", @@ -203,14 +201,14 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn List.of("english"), "2024-05-01", null, - new AuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntity.Configuration( null, null, null, Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1) ) ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( ".jina-embeddings-v3", 
"jina-embeddings-v3", "text_embedding", @@ -218,14 +216,14 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn List.of("multilingual", "open-weights"), "2024-05-01", null, - new AuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntity.Configuration( "cosine", 1024, "float", Map.of("strategy", "word", "max_chunk_size", 500, "overlap", 2) ) ), - new AuthorizationResponseEntityV2.AuthorizedEndpoint( + new AuthorizationResponseEntity.AuthorizedEndpoint( ".elastic-rerank-v1", "elastic-rerank-v1", "rerank", @@ -237,13 +235,11 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn ) ); - var inferenceIds = authorizedEndpoints.stream() - .map(AuthorizationResponseEntityV2.AuthorizedEndpoint::id) - .collect(Collectors.toSet()); + var inferenceIds = authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); return new EisAuthorizationResponse( EIS_AUTHORIZATION_RESPONSE_V2, - new AuthorizationResponseEntityV2(authorizedEndpoints), + new AuthorizationResponseEntity(authorizedEndpoints), List.of( new ElasticInferenceServiceCompletionModel( ".rainbow-sprinkles-elastic", @@ -293,19 +289,19 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn ); } - public static AuthorizationResponseEntityV2 createResponse() { - return new AuthorizationResponseEntityV2( + public static AuthorizationResponseEntity createResponse() { + return new AuthorizationResponseEntity( randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) ); } - public static AuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { + public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { var id = randomAlphaOfLength(10); var name = randomAlphaOfLength(10); var status = randomFrom("ga", "beta", "preview"); return 
switch (taskType) { - case CHAT_COMPLETION -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( + case CHAT_COMPLETION -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, TaskType.CHAT_COMPLETION.toString(), @@ -315,7 +311,7 @@ public static AuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedE "", null ); - case SPARSE_EMBEDDING -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( + case SPARSE_EMBEDDING -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, TaskType.SPARSE_EMBEDDING.toString(), @@ -325,7 +321,7 @@ public static AuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedE "", null ); - case TEXT_EMBEDDING -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( + case TEXT_EMBEDDING -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, TaskType.TEXT_EMBEDDING.toString(), @@ -333,14 +329,14 @@ public static AuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedE null, "", "", - new AuthorizationResponseEntityV2.Configuration( + new AuthorizationResponseEntity.Configuration( randomFrom(SimilarityMeasure.values()).toString(), randomInt(), DenseVectorFieldMapper.ElementType.FLOAT.toString(), null ) ); - case RERANK -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( + case RERANK -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, TaskType.RERANK.toString(), @@ -350,7 +346,7 @@ public static AuthorizationResponseEntityV2.AuthorizedEndpoint createAuthorizedE "", null ); - case COMPLETION -> new AuthorizationResponseEntityV2.AuthorizedEndpoint( + case COMPLETION -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, TaskType.COMPLETION.toString(), @@ -369,7 +365,7 @@ public void testParseAllFields() throws IOException { var url = "http://example.com/authorize"; var responseData = getEisAuthorizationResponseWithMultipleEndpoints(url); try (var parser = createParser(JsonXContent.jsonXContent, responseData.responseJson())) { - var entity = 
AuthorizationResponseEntityV2.PARSER.apply(parser, null); + var entity = AuthorizationResponseEntity.PARSER.apply(parser, null); assertThat(entity, is(responseData.responseEntity())); @@ -388,24 +384,24 @@ public void testParseAllFields() throws IOException { } @Override - protected AuthorizationResponseEntityV2 mutateInstanceForVersion(AuthorizationResponseEntityV2 instance, TransportVersion version) { + protected AuthorizationResponseEntity mutateInstanceForVersion(AuthorizationResponseEntity instance, TransportVersion version) { return instance; } @Override - protected Writeable.Reader instanceReader() { - return AuthorizationResponseEntityV2::new; + protected Writeable.Reader instanceReader() { + return AuthorizationResponseEntity::new; } @Override - protected AuthorizationResponseEntityV2 createTestInstance() { + protected AuthorizationResponseEntity createTestInstance() { return createResponse(); } @Override - protected AuthorizationResponseEntityV2 mutateInstance(AuthorizationResponseEntityV2 instance) throws IOException { + protected AuthorizationResponseEntity mutateInstance(AuthorizationResponseEntity instance) throws IOException { var newEndpoints = new ArrayList<>(instance.getAuthorizedEndpoints()); newEndpoints.add(createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))); - return new AuthorizationResponseEntityV2(newEndpoints); + return new AuthorizationResponseEntity(newEndpoints); } } From e7b23716009d00eb85aae959fe14bee284ccbe1b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 18 Nov 2025 22:01:36 +0000 Subject: [PATCH 08/24] [CI] Auto commit changes from spotless --- .../authorization/AuthorizationModel.java | 32 ++++--------------- .../AuthorizationPollerTests.java | 4 +-- ...rviceAuthorizationRequestHandlerTests.java | 4 +-- 3 files changed, 8 insertions(+), 32 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java index a1f2cd0bdba83..5c29d466bced1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java @@ -150,9 +150,7 @@ private static AuthorizationResponseEntity.Configuration getConfigurationOrEmpty return AuthorizationResponseEntity.Configuration.EMPTY; } - private static Map getChunkingSettingsMap( - AuthorizationResponseEntity.Configuration configuration - ) { + private static Map getChunkingSettingsMap(AuthorizationResponseEntity.Configuration configuration) { return Objects.requireNonNullElse(configuration.chunkingSettings(), new HashMap<>()); } @@ -181,21 +179,9 @@ private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEm } private static void validateConfigurationForTextEmbedding(AuthorizationResponseEntity.Configuration config) { - validateFieldPresent( - AuthorizationResponseEntity.Configuration.ELEMENT_TYPE, - config.elementType(), - TaskType.TEXT_EMBEDDING - ); - validateFieldPresent( - AuthorizationResponseEntity.Configuration.DIMENSIONS, - config.dimensions(), - TaskType.TEXT_EMBEDDING - ); - validateFieldPresent( - AuthorizationResponseEntity.Configuration.SIMILARITY, - config.similarity(), - TaskType.TEXT_EMBEDDING - ); + validateFieldPresent(AuthorizationResponseEntity.Configuration.ELEMENT_TYPE, config.elementType(), TaskType.TEXT_EMBEDDING); + validateFieldPresent(AuthorizationResponseEntity.Configuration.DIMENSIONS, config.dimensions(), TaskType.TEXT_EMBEDDING); + validateFieldPresent(AuthorizationResponseEntity.Configuration.SIMILARITY, config.similarity(), 
TaskType.TEXT_EMBEDDING); } private static void validateFieldPresent(String field, Object fieldValue, TaskType taskType) { @@ -206,14 +192,8 @@ private static void validateFieldPresent(String field, Object fieldValue, TaskTy } } - private static SimilarityMeasure getSimilarityMeasure( - AuthorizationResponseEntity.Configuration configuration - ) { - validateFieldPresent( - AuthorizationResponseEntity.Configuration.SIMILARITY, - configuration.similarity(), - TaskType.TEXT_EMBEDDING - ); + private static SimilarityMeasure getSimilarityMeasure(AuthorizationResponseEntity.Configuration configuration) { + validateFieldPresent(AuthorizationResponseEntity.Configuration.SIMILARITY, configuration.similarity(), TaskType.TEXT_EMBEDDING); return SimilarityMeasure.fromString(configuration.similarity()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index df733298a47bb..218d211b7c9e9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -345,9 +345,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse( - AuthorizationModel.of(new AuthorizationResponseEntity(List.of(completionModel)), url) - ); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(completionModel)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 45f60b9346b4c..b51faa46e5f6f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -192,9 +192,7 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { ); try (var sender = senderFactory.createSender()) { - var responseData = AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints( - eisGatewayUrl - ); + var responseData = AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints(eisGatewayUrl); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseData.responseJson())); From f6b2d744faecac5ee6da70a903d583d4409dc337 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 19 Nov 2025 09:17:52 -0500 Subject: [PATCH 09/24] Working integration tests --- ...etModelsWithElasticInferenceServiceIT.java | 3 +- ...icInferenceServiceAuthorizationServer.java | 97 +++++++---- .../AuthorizationTaskExecutorIT.java | 44 +++-- ...horizationTaskExecutorMultipleNodesIT.java | 7 +- .../InternalPreconfiguredEndpoints.java | 152 ------------------ .../InternalPreconfiguredEndpointsTests.java | 23 --- ...rviceAuthorizationRequestHandlerTests.java | 13 +- .../AuthorizationResponseEntityTests.java | 9 +- 8 files changed, 116 insertions(+), 232 deletions(-) delete mode 100644 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpointsTests.java diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index b413a38a052e8..5c066e47f1f22 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -18,6 +18,7 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -51,7 +52,7 @@ public void testGetDefaultEndpoints() throws IOException { assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION); assertInferenceIdTaskType(allModels, ".gp-llm-v2-chat_completion", TaskType.CHAT_COMPLETION); - assertInferenceIdTaskType(allModels, ".elser-2-elastic", TaskType.SPARSE_EMBEDDING); + assertInferenceIdTaskType(allModels, DEFAULT_ELSER_ENDPOINT_ID_V2, TaskType.SPARSE_EMBEDDING); assertInferenceIdTaskType(allModels, ".jina-embeddings-v3", TaskType.TEXT_EMBEDDING); assertInferenceIdTaskType(allModels, ".elastic-rerank-v1", TaskType.RERANK); } diff --git 
a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 8bb9a7e576baf..5fd08dd7e724a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -23,38 +23,79 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule private static final Logger logger = LogManager.getLogger(MockElasticInferenceServiceAuthorizationServer.class); private final MockWebServer webServer = new MockWebServer(); - public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowSprinklesAndElser() { - var server = new MockElasticInferenceServiceAuthorizationServer(); - - server.enqueueAuthorizeAllModelsResponse(); - return server; - } - public void enqueueAuthorizeAllModelsResponse() { String responseJson = """ { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "gp-llm-v2", - "task_types": ["chat"] - }, - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "jina-embeddings-v3", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "elastic-rerank-v1", - "task_types": ["rerank/text/text-similarity"] + "inference_endpoints": [ + { + "id": ".rainbow-sprinkles-elastic", + "model_name": "rainbow-sprinkles", + "task_type": "chat_completion", + "status": "ga", + "properties": [ + "multilingual" + ], + "release_date": "2024-05-01", + "end_of_life_date": "2025-12-31" + }, + { + "id": 
".gp-llm-v2-chat_completion", + "model_name": "gp-llm-v2", + "task_type": "chat_completion", + "status": "ga", + "properties": [ + "multilingual" + ], + "release_date": "2024-05-01", + "end_of_life_date": "2025-12-31" + }, + { + "id": ".elser-2-elastic", + "model_name": "elser_model_2", + "task_type": "sparse_embedding", + "status": "preview", + "properties": [ + "english" + ], + "release_date": "2024-05-01", + "configuration": { + "chunking_settings": { + "strategy": "sentence", + "max_chunk_size": 250, + "sentence_overlap": 1 } - ] + } + }, + { + "id": ".jina-embeddings-v3", + "model_name": "jina-embeddings-v3", + "task_type": "text_embedding", + "status": "beta", + "properties": [ + "multilingual", + "open-weights" + ], + "release_date": "2024-05-01", + "configuration": { + "similarity": "cosine", + "dimensions": 1024, + "element_type": "float", + "chunking_settings": { + "strategy": "word", + "max_chunk_size": 500, + "overlap": 2 + } + } + }, + { + "id": ".elastic-rerank-v1", + "model_name": "elastic-rerank-v1", + "task_type": "rerank", + "status": "preview", + "properties": [], + "release_date": "2024-05-01" + } + ] } """; diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index e506f20899581..e0305133be3e9 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -25,7 +25,6 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; 
-import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; @@ -37,6 +36,7 @@ import java.io.IOException; import java.util.Collection; import java.util.List; +import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; @@ -47,6 +47,30 @@ import static org.hamcrest.Matchers.not; public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { + + // rainbow-sprinkles + public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".rainbow-sprinkles-elastic"; + + // gp-llm-v2 + public static final String GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion"; + + // elser-2 + public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = ".elser-2-elastic"; + + // multilingual-text-embed + public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; + + // rerank-v1 + public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = ".elastic-rerank-v1"; + + public static final Set EIS_PRECONFIGURED_ENDPOINT_IDS = Set.of( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, + DEFAULT_ELSER_ENDPOINT_ID_V2, + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + DEFAULT_RERANK_ENDPOINT_ID_V1 + ); + public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; public static final String EMPTY_AUTH_RESPONSE = """ @@ -94,7 +118,7 @@ public void shutdown() { static void removeEisPreconfiguredEndpoints(ModelRegistry modelRegistry) { // Delete all the eis preconfigured endpoints var listener = new PlainActionFuture(); - modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, 
listener); + modelRegistry.deleteModels(EIS_PRECONFIGURED_ENDPOINT_IDS, listener); listener.actionGet(TimeValue.THIRTY_SECONDS); } @@ -149,7 +173,7 @@ static void assertNoAuthorizedEisEndpoints( var eisEndpoints = getEisEndpoints(modelRegistry); assertThat(eisEndpoints, empty()); - for (String eisPreconfiguredEndpoints : InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS) { + for (String eisPreconfiguredEndpoints : EIS_PRECONFIGURED_ENDPOINT_IDS) { assertFalse(modelRegistry.containsPreconfiguredInferenceEndpointId(eisPreconfiguredEndpoints)); } } @@ -250,15 +274,13 @@ static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) { var rainbowSprinklesModel = eisEndpoints.get(0); assertChatCompletionUnparsedModel(rainbowSprinklesModel); - assertTrue( - modelRegistry.containsPreconfiguredInferenceEndpointId(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1) - ); + assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); } static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION)); assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME)); - assertThat(rainbowSprinklesModel.inferenceEntityId(), is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertThat(rainbowSprinklesModel.inferenceEntityId(), is(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); } public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception { @@ -293,12 +315,12 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity())); assertThat(eisEndpoints.size(), is(2)); - assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); - 
assertChatCompletionUnparsedModel(eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertTrue(eisEndpoints.containsKey(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertChatCompletionUnparsedModel(eisEndpoints.get(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); - assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID)); + assertTrue(eisEndpoints.containsKey(DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID)); - var textEmbeddingEndpoint = eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID); + var textEmbeddingEndpoint = eisEndpoints.get(DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID); assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING)); assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME)); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index 2bc5e2df2853b..e94004c0f0972 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import 
org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; import org.junit.AfterClass; @@ -34,6 +33,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.EMPTY_AUTH_RESPONSE; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.cancelAuthorizationTask; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForTask; @@ -130,10 +130,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun var rainbowSprinklesEndpoint = eisEndpoints.get(0); assertThat(rainbowSprinklesEndpoint.getService(), is(ElasticInferenceService.NAME)); - assertThat( - rainbowSprinklesEndpoint.getInferenceEntityId(), - is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1) - ); + assertThat(rainbowSprinklesEndpoint.getInferenceEntityId(), is(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); assertThat(rainbowSprinklesEndpoint.getTaskType(), is(TaskType.CHAT_COMPLETION)); }); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java index 8b1ee97a2840d..df95a936c2c91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java @@ -7,162 +7,10 @@ package org.elasticsearch.xpack.inference.services.elastic; 
-import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Function; - -import static java.util.stream.Collectors.toMap; - -/** - * Represents the preconfigured endpoints that are included in Elasticsearch. EIS will support dynamic preconfigured endpoints which means - * it can provide new preconfigured endpoints that do not exist in the source here. 
- */ public class InternalPreconfiguredEndpoints { - // rainbow-sprinkles - public static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; - public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".rainbow-sprinkles-elastic"; - - // gp-llm-v2 - public static final String GP_LLM_V2_MODEL_ID = "gp-llm-v2"; - public static final String GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion"; - // elser-2 - public static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = ".elser-2-elastic"; - // multilingual-text-embed - public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; - public static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "jina-embeddings-v3"; - public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; - - // rerank-v1 - public static final String DEFAULT_RERANK_MODEL_ID_V1 = "elastic-rerank-v1"; - public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = ".elastic-rerank-v1"; - - public record MinimalModel( - ModelConfigurations configurations, - ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings - ) {} - - private static final ElasticInferenceServiceCompletionServiceSettings COMPLETION_SERVICE_SETTINGS = - new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); - private static final ElasticInferenceServiceCompletionServiceSettings GP_LLM_V2_COMPLETION_SERVICE_SETTINGS = - new ElasticInferenceServiceCompletionServiceSettings(GP_LLM_V2_MODEL_ID); - private static final ElasticInferenceServiceSparseEmbeddingsServiceSettings SPARSE_EMBEDDINGS_SERVICE_SETTINGS = - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null); - private static final ElasticInferenceServiceDenseTextEmbeddingsServiceSettings DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS = - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( 
- DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, - defaultDenseTextEmbeddingsSimilarity(), - DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - null - ); - private static final ElasticInferenceServiceRerankServiceSettings RERANK_SERVICE_SETTINGS = - new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1); - - // A single model name can map to multiple inference endpoints, so we need a String to a List - private static final Map> MODEL_NAME_TO_MINIMAL_MODELS = Map.of( - DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, - List.of( - new MinimalModel( - new ModelConfigurations( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - TaskType.CHAT_COMPLETION, - ElasticInferenceService.NAME, - COMPLETION_SERVICE_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - COMPLETION_SERVICE_SETTINGS - ) - ), - GP_LLM_V2_MODEL_ID, - List.of( - new MinimalModel( - new ModelConfigurations( - GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, - TaskType.CHAT_COMPLETION, - ElasticInferenceService.NAME, - GP_LLM_V2_COMPLETION_SERVICE_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - GP_LLM_V2_COMPLETION_SERVICE_SETTINGS - ) - ), - DEFAULT_ELSER_2_MODEL_ID, - List.of( - new MinimalModel( - new ModelConfigurations( - DEFAULT_ELSER_ENDPOINT_ID_V2, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - SPARSE_EMBEDDINGS_SERVICE_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - SPARSE_EMBEDDINGS_SERVICE_SETTINGS - ) - ), - DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, - List.of( - new MinimalModel( - new ModelConfigurations( - DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, - TaskType.TEXT_EMBEDDING, - ElasticInferenceService.NAME, - DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS - ) - ), - DEFAULT_RERANK_MODEL_ID_V1, - List.of( - new MinimalModel( - new ModelConfigurations( - DEFAULT_RERANK_ENDPOINT_ID_V1, - TaskType.RERANK, - ElasticInferenceService.NAME, - RERANK_SERVICE_SETTINGS, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), 
- RERANK_SERVICE_SETTINGS - ) - ) - ); - - private static final Map INFERENCE_ID_TO_MINIMAL_MODEL = MODEL_NAME_TO_MINIMAL_MODELS.entrySet() - .stream() - .flatMap(entry -> entry.getValue().stream()) - .collect(toMap(m -> m.configurations().getInferenceEntityId(), Function.identity())); - - public static final Set EIS_PRECONFIGURED_ENDPOINT_IDS = Set.copyOf(INFERENCE_ID_TO_MINIMAL_MODEL.keySet()); - - public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { - return SimilarityMeasure.COSINE; - } - - public static List getWithModelName(String modelName) { - var minimalModels = MODEL_NAME_TO_MINIMAL_MODELS.get(modelName); - if (minimalModels == null) { - return List.of(); - } - - return minimalModels; - } - - public static MinimalModel getWithInferenceId(String inferenceId) { - return INFERENCE_ID_TO_MINIMAL_MODEL.get(inferenceId); - } - private InternalPreconfiguredEndpoints() {} } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpointsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpointsTests.java deleted file mode 100644 index cfd166c7d240e..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpointsTests.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.elastic; - -import org.elasticsearch.test.ESTestCase; - -import static org.hamcrest.Matchers.hasSize; - -public class InternalPreconfiguredEndpointsTests extends ESTestCase { - public void testGetWithModelName_ReturnsAnEmptyList_IfNameDoesNotExist() { - assertThat(InternalPreconfiguredEndpoints.getWithModelName("non-existent-model"), hasSize(0)); - } - - public void testGetWithModelName_ReturnsChatCompletionModels() { - var models = InternalPreconfiguredEndpoints.getWithModelName(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); - assertThat(models, hasSize(1)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 45f60b9346b4c..c8483c3d79eef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -45,6 +45,7 @@ import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createApplierFactory; import static 
org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createNoopApplierFactory; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisElserAuthorizationResponse; @@ -192,9 +193,7 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { ); try (var sender = senderFactory.createSender()) { - var responseData = AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints( - eisGatewayUrl - ); + var responseData = AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints(eisGatewayUrl); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseData.responseJson())); @@ -252,7 +251,7 @@ public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws I var authResponse = listener.actionGet(TIMEOUT); assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - assertThat(authResponse.getEndpointIds(), is(Set.of(".elastic-elser-v2"))); + assertThat(authResponse.getEndpointIds(), is(Set.of(DEFAULT_ELSER_ENDPOINT_ID_V2))); assertTrue(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); @@ -292,13 +291,11 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { authHandler.getAuthorization(onlyOnceListener, sender); authHandler.waitForAuthRequestCompletion(TIMEOUT); - var endpointId = ".elastic-elser-v2"; - var authResponse = listener.actionGet(TIMEOUT); assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - assertThat(authResponse.getEndpointIds(), is(Set.of(endpointId))); + assertThat(authResponse.getEndpointIds(), is(Set.of(DEFAULT_ELSER_ENDPOINT_ID_V2))); assertTrue(authResponse.isAuthorized()); - assertThat(authResponse.getEndpoints(Set.of(endpointId)), is(elserResponse.expectedEndpoints())); + 
assertThat(authResponse.getEndpoints(Set.of(DEFAULT_ELSER_ENDPOINT_ID_V2)), is(elserResponse.expectedEndpoints())); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(1)).debug(loggerArgsCaptor.capture()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index 57b08507c2099..3357c9872222d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -39,6 +39,7 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; @@ -143,7 +144,7 @@ public static EisAuthorizationResponse getEisElserAuthorizationResponse(String u var authorizedEndpoints = List.of( new AuthorizationResponseEntity.AuthorizedEndpoint( - ".elastic-elser-v2", + DEFAULT_ELSER_ENDPOINT_ID_V2, "elser_model_2", "sparse_embedding", "preview", @@ -166,7 +167,7 @@ public static EisAuthorizationResponse getEisElserAuthorizationResponse(String u new AuthorizationResponseEntity(authorizedEndpoints), List.of( new ElasticInferenceServiceSparseEmbeddingsModel( - ".elastic-elser-v2", + DEFAULT_ELSER_ENDPOINT_ID_V2, TaskType.SPARSE_EMBEDDING, ElasticInferenceService.NAME, new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser_model_2", null), @@ -194,7 +195,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn null ), new 
AuthorizationResponseEntity.AuthorizedEndpoint( - ".elastic-elser-v2", + DEFAULT_ELSER_ENDPOINT_ID_V2, "elser_model_2", "sparse_embedding", "preview", @@ -251,7 +252,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn new ElasticInferenceServiceComponents(url) ), new ElasticInferenceServiceSparseEmbeddingsModel( - ".elastic-elser-v2", + DEFAULT_ELSER_ENDPOINT_ID_V2, TaskType.SPARSE_EMBEDDING, ElasticInferenceService.NAME, new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser_model_2", null), From f2a3d1a74d12af6da6e0c4cf57e00e97d8b81141 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 19 Nov 2025 09:21:38 -0500 Subject: [PATCH 10/24] Fixing forbidden calls --- .../authorization/AuthorizationModel.java | 34 ++++--------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java index a1f2cd0bdba83..505280564f13d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java @@ -150,9 +150,7 @@ private static AuthorizationResponseEntity.Configuration getConfigurationOrEmpty return AuthorizationResponseEntity.Configuration.EMPTY; } - private static Map getChunkingSettingsMap( - AuthorizationResponseEntity.Configuration configuration - ) { + private static Map getChunkingSettingsMap(AuthorizationResponseEntity.Configuration configuration) { return Objects.requireNonNullElse(configuration.chunkingSettings(), new HashMap<>()); } @@ -181,21 +179,9 @@ private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEm } private static 
void validateConfigurationForTextEmbedding(AuthorizationResponseEntity.Configuration config) { - validateFieldPresent( - AuthorizationResponseEntity.Configuration.ELEMENT_TYPE, - config.elementType(), - TaskType.TEXT_EMBEDDING - ); - validateFieldPresent( - AuthorizationResponseEntity.Configuration.DIMENSIONS, - config.dimensions(), - TaskType.TEXT_EMBEDDING - ); - validateFieldPresent( - AuthorizationResponseEntity.Configuration.SIMILARITY, - config.similarity(), - TaskType.TEXT_EMBEDDING - ); + validateFieldPresent(AuthorizationResponseEntity.Configuration.ELEMENT_TYPE, config.elementType(), TaskType.TEXT_EMBEDDING); + validateFieldPresent(AuthorizationResponseEntity.Configuration.DIMENSIONS, config.dimensions(), TaskType.TEXT_EMBEDDING); + validateFieldPresent(AuthorizationResponseEntity.Configuration.SIMILARITY, config.similarity(), TaskType.TEXT_EMBEDDING); } private static void validateFieldPresent(String field, Object fieldValue, TaskType taskType) { @@ -206,14 +192,8 @@ private static void validateFieldPresent(String field, Object fieldValue, TaskTy } } - private static SimilarityMeasure getSimilarityMeasure( - AuthorizationResponseEntity.Configuration configuration - ) { - validateFieldPresent( - AuthorizationResponseEntity.Configuration.SIMILARITY, - configuration.similarity(), - TaskType.TEXT_EMBEDDING - ); + private static SimilarityMeasure getSimilarityMeasure(AuthorizationResponseEntity.Configuration configuration) { + validateFieldPresent(AuthorizationResponseEntity.Configuration.SIMILARITY, configuration.similarity(), TaskType.TEXT_EMBEDDING); return SimilarityMeasure.fromString(configuration.similarity()); } @@ -288,7 +268,7 @@ public List getEndpoints(Set endpointIds) @Override public String toString() { - return String.format("AuthorizationModel{authorizedEndpoints=%s, taskTypes=%s}", authorizedEndpoints, taskTypes); + return Strings.format("AuthorizationModel{authorizedEndpoints=%s, taskTypes=%s}", authorizedEndpoints, taskTypes); } @Override 
From b2aecd2fab8b4e0a2f2cbc39dc67793ab8d84cf0 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 19 Nov 2025 10:47:20 -0500 Subject: [PATCH 11/24] Fixing tests --- .../elastic/response/AuthorizationResponseEntityTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index 3357c9872222d..218f5ece71efb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -56,7 +56,7 @@ public record EisAuthorizationResponse( { "inference_endpoints": [ { - "id": ".elastic-elser-v2", + "id": ".elser-2-elastic", "model_name": "elser_model_2", "task_type": "sparse_embedding", "status": "preview", @@ -91,7 +91,7 @@ public record EisAuthorizationResponse( "end_of_life_date": "2025-12-31" }, { - "id": ".elastic-elser-v2", + "id": ".elser-2-elastic", "model_name": "elser_model_2", "task_type": "sparse_embedding", "status": "preview", From 5715af616cdff62c2f45e51dc6483ab7a4589f4d Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 19 Nov 2025 13:42:07 -0500 Subject: [PATCH 12/24] Fixing integration tests --- .../AuthorizationTaskExecutorIT.java | 92 ++---- ...horizationTaskExecutorMultipleNodesIT.java | 15 +- .../inference/integration/CCMServiceIT.java | 6 +- .../AuthorizationResponseEntityTests.java | 301 +++++++++++------- 4 files changed, 234 insertions(+), 180 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index e0305133be3e9..02e6442e95528 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -42,57 +43,30 @@ import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; public 
class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { - // rainbow-sprinkles - public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".rainbow-sprinkles-elastic"; - - // gp-llm-v2 - public static final String GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion"; - - // elser-2 - public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = ".elser-2-elastic"; - - // multilingual-text-embed - public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; - - // rerank-v1 - public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = ".elastic-rerank-v1"; - public static final Set EIS_PRECONFIGURED_ENDPOINT_IDS = Set.of( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, - DEFAULT_ELSER_ENDPOINT_ID_V2, - DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, - DEFAULT_RERANK_ENDPOINT_ID_V1 + RAINBOW_SPRINKLES_ENDPOINT_ID_V1, + ELSER_V2_ENDPOINT_ID, + JINA_EMBED_ENDPOINT_ID, + RERANK_V1_ENDPOINT_ID ); public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; - public static final String EMPTY_AUTH_RESPONSE = """ - { - "models": [ - ] - } - """; - - public static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; + private static AuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; private ModelRegistry modelRegistry; private AuthorizationTaskExecutor authorizationTaskExecutor; @@ -101,7 +75,8 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE)); + chatCompletionResponse = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl); } @Before @@ -147,7 +122,7 @@ protected Collection> getPlugins() { public void testCreatesEisChatCompletionEndpoint() throws Exception { assertNoAuthorizedEisEndpoints(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); @@ -252,13 +227,13 @@ static void cancelAuthorizationTask(AdminClient adminClient) throws Exception { public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception { assertNoAuthorizedEisEndpoints(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); // Simulate that the model is no longer authorized - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE)); restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); @@ -274,53 +249,44 @@ static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) { var rainbowSprinklesModel = eisEndpoints.get(0); assertChatCompletionUnparsedModel(rainbowSprinklesModel); - assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); } static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { 
assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION)); assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME)); - assertThat(rainbowSprinklesModel.inferenceEntityId(), is(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertThat(rainbowSprinklesModel.inferenceEntityId(), is(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); } public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception { assertNoAuthorizedEisEndpoints(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); // Simulate that the model is no longer authorized - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE)); restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); // Simulate that a text embedding model is now authorized - var authorizedTextEmbeddingResponse = """ - { - "models": [ - { - "model_name": "jina-embeddings-v3", - "task_types": ["embed/text/dense"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authorizedTextEmbeddingResponse)); + var jinaEmbedResponse = AuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponse.responseJson())); + restartPollingTaskAndWaitForAuthResponse(); var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity())); assertThat(eisEndpoints.size(), is(2)); - assertTrue(eisEndpoints.containsKey(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); - 
assertChatCompletionUnparsedModel(eisEndpoints.get(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertTrue(eisEndpoints.containsKey(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); + assertChatCompletionUnparsedModel(eisEndpoints.get(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); - assertTrue(eisEndpoints.containsKey(DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID)); + assertTrue(eisEndpoints.containsKey(JINA_EMBED_ENDPOINT_ID)); - var textEmbeddingEndpoint = eisEndpoints.get(DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID); + var textEmbeddingEndpoint = eisEndpoints.get(JINA_EMBED_ENDPOINT_ID); assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING)); assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME)); } @@ -329,7 +295,7 @@ public void testRestartsTaskAfterAbort() throws Exception { // Ensure the task is created and we get an initial authorization response assertNoAuthorizedEisEndpoints(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE)); // Abort the task and ensure it is restarted restartPollingTaskAndWaitForAuthResponse(); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index e94004c0f0972..8840051841079 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import 
org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -32,11 +33,11 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE; -import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1; -import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.EMPTY_AUTH_RESPONSE; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.cancelAuthorizationTask; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForTask; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.not; @@ -54,12 +55,14 @@ public class AuthorizationTaskExecutorMultipleNodesIT extends ESIntegTestCase { private static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; + private static 
AuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; @BeforeClass public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE)); + chatCompletionResponse = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl); } @Before @@ -110,7 +113,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun ); // queue a response that authorizes one model - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); assertTrue("expected the node to shutdown properly", internalCluster().stopNode(nodeNameMapping.get(pollerTask.node()))); @@ -130,7 +133,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun var rainbowSprinklesEndpoint = eisEndpoints.get(0); assertThat(rainbowSprinklesEndpoint.getService(), is(ElasticInferenceService.NAME)); - assertThat(rainbowSprinklesEndpoint.getInferenceEntityId(), is(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertThat(rainbowSprinklesEndpoint.getInferenceEntityId(), is(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); assertThat(rainbowSprinklesEndpoint.getTaskType(), is(TaskType.CHAT_COMPLETION)); }); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java index b206ee02b6602..e3c25fba317c4 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java +++ 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMModel; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; +import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -31,7 +32,6 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTH_TASK_ACTION; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.assertChatCompletionEndpointExists; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.getEisEndpoints; @@ -48,6 +48,7 @@ public class CCMServiceIT extends CCMSingleNodeIT { private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; + private static AuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; private AuthorizationTaskExecutor authorizationTaskExecutor; private ModelRegistry modelRegistry; @@ -75,6 +76,7 @@ public void delete(ActionListener listener) { public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); + chatCompletionResponse = AuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse(gatewayUrl); } @Before @@ -137,7 +139,7 @@ public void testCreatesEisChatCompletionEndpoint() throws Exception { var eisEndpoints = getEisEndpoints(modelRegistry); assertThat(eisEndpoints, empty()); - 
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); var listener = new TestPlainActionFuture(); ccmService.get().storeConfiguration(new CCMModel(new SecureString("secret".toCharArray())), listener); listener.actionGet(TimeValue.THIRTY_SECONDS); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index 218f5ece71efb..25152959219ac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -39,12 +39,27 @@ import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; public class AuthorizationResponseEntityTests extends AbstractBWCWireSerializationTestCase { + // rainbow-sprinkles + public static final String RAINBOW_SPRINKLES_ENDPOINT_ID_V1 = ".rainbow-sprinkles-elastic"; + public static final String RAINBOW_SPRINKLES_MODEL_NAME = "rainbow-sprinkles"; + + // elser-2 + public static final String ELSER_V2_ENDPOINT_ID = ".elser-2-elastic"; + public static final String ELSER_V2_MODEL_NAME = "elser_model_2"; + + // multilingual-text-embed + public static final String JINA_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; + public static final String JINA_EMBED_MODEL_NAME = "jina-embeddings-v3"; + + // rerank-v1 + public static final String 
RERANK_V1_ENDPOINT_ID = ".elastic-rerank-v1"; + public static final String RERANK_V1_MODEL_NAME = "elastic-rerank-v2"; + public record EisAuthorizationResponse( String responseJson, AuthorizationResponseEntity responseEntity, @@ -52,6 +67,58 @@ public record EisAuthorizationResponse( Set inferenceIds ) {} + public static final String EIS_EMPTY_RESPONSE = """ + { + "inference_endpoints": [] + } + """; + + public static final String EIS_RAINBOW_SPRINKLES_RESPONSE = """ + { + "inference_endpoints": [ + { + "id": ".rainbow-sprinkles-elastic", + "model_name": "rainbow-sprinkles", + "task_type": "chat_completion", + "status": "ga", + "properties": [ + "multilingual" + ], + "release_date": "2024-05-01", + "end_of_life_date": "2025-12-31" + } + ] + } + """; + + public static final String EIS_JINA_EMBED_RESPONSE = """ + { + "inference_endpoints": [ + { + "id": ".jina-embeddings-v3", + "model_name": "jina-embeddings-v3", + "task_type": "text_embedding", + "status": "beta", + "properties": [ + "multilingual", + "open-weights" + ], + "release_date": "2024-05-01", + "configuration": { + "similarity": "cosine", + "dimensions": 1024, + "element_type": "float", + "chunking_settings": { + "strategy": "word", + "max_chunk_size": 500, + "overlap": 2 + } + } + } + ] + } + """; + public static final String EIS_ELSER_RESPONSE = """ { "inference_endpoints": [ @@ -141,92 +208,57 @@ public record EisAuthorizationResponse( """; public static EisAuthorizationResponse getEisElserAuthorizationResponse(String url) { - - var authorizedEndpoints = List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( - DEFAULT_ELSER_ENDPOINT_ID_V2, - "elser_model_2", - "sparse_embedding", - "preview", - List.of("english"), - "2024-05-01", - null, - new AuthorizationResponseEntity.Configuration( - null, - null, - null, - Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1) - ) - ) - ); + var authorizedEndpoints = List.of(createElserAuthorizedEndpoint()); var inferenceIds = 
authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); return new EisAuthorizationResponse( EIS_ELSER_RESPONSE, new AuthorizationResponseEntity(authorizedEndpoints), - List.of( - new ElasticInferenceServiceSparseEmbeddingsModel( - DEFAULT_ELSER_ENDPOINT_ID_V2, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser_model_2", null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url), - new SentenceBoundaryChunkingSettings(250, 1) - ) - ), + List.of(createElserExpectedEndpoint(url)), inferenceIds ); } - public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEndpoints(String url) { + private static ElasticInferenceServiceModel createElserExpectedEndpoint(String url) { + return new ElasticInferenceServiceSparseEmbeddingsModel( + ELSER_V2_ENDPOINT_ID, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(ELSER_V2_MODEL_NAME, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new SentenceBoundaryChunkingSettings(250, 1) + ); + } - var authorizedEndpoints = List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( - ".rainbow-sprinkles-elastic", - "rainbow-sprinkles", - "chat_completion", - "ga", - List.of("multilingual"), - "2024-05-01", - "2025-12-31", - null - ), - new AuthorizationResponseEntity.AuthorizedEndpoint( - DEFAULT_ELSER_ENDPOINT_ID_V2, - "elser_model_2", - "sparse_embedding", - "preview", - List.of("english"), - "2024-05-01", + private static AuthorizationResponseEntity.AuthorizedEndpoint createElserAuthorizedEndpoint() { + return new AuthorizationResponseEntity.AuthorizedEndpoint( + ELSER_V2_ENDPOINT_ID, + ELSER_V2_MODEL_NAME, + "sparse_embedding", + "preview", + List.of("english"), + "2024-05-01", + null, + new 
AuthorizationResponseEntity.Configuration( null, - new AuthorizationResponseEntity.Configuration( - null, - null, - null, - Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1) - ) - ), - new AuthorizationResponseEntity.AuthorizedEndpoint( - ".jina-embeddings-v3", - "jina-embeddings-v3", - "text_embedding", - "beta", - List.of("multilingual", "open-weights"), - "2024-05-01", null, - new AuthorizationResponseEntity.Configuration( - "cosine", - 1024, - "float", - Map.of("strategy", "word", "max_chunk_size", 500, "overlap", 2) - ) - ), + null, + Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1) + ) + ); + } + + public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEndpoints(String url) { + var authorizedEndpoints = List.of( + createRainbowSprinklesAuthorizedEndpoint(), + createElserAuthorizedEndpoint(), + createJinaEmbedAuthorizedEndpoint(), new AuthorizationResponseEntity.AuthorizedEndpoint( - ".elastic-rerank-v1", - "elastic-rerank-v1", + RERANK_V1_ENDPOINT_ID, + RERANK_V1_MODEL_NAME, "rerank", "preview", List.of(), @@ -242,45 +274,14 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn EIS_AUTHORIZATION_RESPONSE_V2, new AuthorizationResponseEntity(authorizedEndpoints), List.of( - new ElasticInferenceServiceCompletionModel( - ".rainbow-sprinkles-elastic", - TaskType.CHAT_COMPLETION, - ElasticInferenceService.NAME, - new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url) - ), - new ElasticInferenceServiceSparseEmbeddingsModel( - DEFAULT_ELSER_ENDPOINT_ID_V2, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser_model_2", null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url), - new 
SentenceBoundaryChunkingSettings(250, 1) - ), - new ElasticInferenceServiceDenseTextEmbeddingsModel( - ".jina-embeddings-v3", - TaskType.TEXT_EMBEDDING, - ElasticInferenceService.NAME, - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - "jina-embeddings-v3", - SimilarityMeasure.COSINE, - 1024, - null - ), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url), - new WordBoundaryChunkingSettings(500, 2) - ), + createRainbowSprinklesExpectedEndpoint(url), + createElserExpectedEndpoint(url), + createJinaExpectedEndpoint(url), new ElasticInferenceServiceRerankModel( - ".elastic-rerank-v1", + RERANK_V1_ENDPOINT_ID, TaskType.RERANK, ElasticInferenceService.NAME, - new ElasticInferenceServiceRerankServiceSettings("elastic-rerank-v1"), + new ElasticInferenceServiceRerankServiceSettings(RERANK_V1_MODEL_NAME), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, new ElasticInferenceServiceComponents(url) @@ -290,6 +291,88 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn ); } + private static AuthorizationResponseEntity.AuthorizedEndpoint createRainbowSprinklesAuthorizedEndpoint() { + return new AuthorizationResponseEntity.AuthorizedEndpoint( + RAINBOW_SPRINKLES_ENDPOINT_ID_V1, + RAINBOW_SPRINKLES_MODEL_NAME, + "chat_completion", + "ga", + List.of("multilingual"), + "2024-05-01", + "2025-12-31", + null + ); + } + + private static ElasticInferenceServiceModel createRainbowSprinklesExpectedEndpoint(String url) { + return new ElasticInferenceServiceCompletionModel( + RAINBOW_SPRINKLES_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(RAINBOW_SPRINKLES_MODEL_NAME), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + } + + public static EisAuthorizationResponse getEisRainbowSprinklesAuthorizationResponse(String url) { + var 
authorizedEndpoints = List.of(createRainbowSprinklesAuthorizedEndpoint()); + + var inferenceIds = authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); + + return new EisAuthorizationResponse( + EIS_RAINBOW_SPRINKLES_RESPONSE, + new AuthorizationResponseEntity(authorizedEndpoints), + List.of(createRainbowSprinklesExpectedEndpoint(url)), + inferenceIds + ); + } + + public static EisAuthorizationResponse getEisJinaEmbedAuthorizationResponse(String url) { + var authorizedEndpoints = List.of(createJinaEmbedAuthorizedEndpoint()); + + var inferenceIds = authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); + + return new EisAuthorizationResponse( + EIS_JINA_EMBED_RESPONSE, + new AuthorizationResponseEntity(authorizedEndpoints), + List.of(createJinaExpectedEndpoint(url)), + inferenceIds + ); + } + + private static AuthorizationResponseEntity.AuthorizedEndpoint createJinaEmbedAuthorizedEndpoint() { + return new AuthorizationResponseEntity.AuthorizedEndpoint( + JINA_EMBED_ENDPOINT_ID, + JINA_EMBED_MODEL_NAME, + "text_embedding", + "beta", + List.of("multilingual", "open-weights"), + "2024-05-01", + null, + new AuthorizationResponseEntity.Configuration( + "cosine", + 1024, + "float", + Map.of("strategy", "word", "max_chunk_size", 500, "overlap", 2) + ) + ); + } + + private static ElasticInferenceServiceModel createJinaExpectedEndpoint(String url) { + return new ElasticInferenceServiceDenseTextEmbeddingsModel( + JINA_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(JINA_EMBED_MODEL_NAME, SimilarityMeasure.COSINE, 1024, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url), + new WordBoundaryChunkingSettings(500, 2) + ); + } + public static AuthorizationResponseEntity createResponse() { return new 
AuthorizationResponseEntity( randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) From 7834662b41a12249a11d8b6d00536f18f83450c3 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 19 Nov 2025 15:21:53 -0500 Subject: [PATCH 13/24] Refactoring tests --- .../qa/inference-service-tests/build.gradle | 4 + .../inference/BaseMockEISAuthServerTest.java | 4 +- ...etModelsWithElasticInferenceServiceIT.java | 16 ++- ...icInferenceServiceAuthorizationServer.java | 79 +---------- .../AuthorizationTaskExecutorIT.java | 20 +-- ...horizationTaskExecutorMultipleNodesIT.java | 4 +- .../authorization/AuthorizationModel.java | 2 +- .../response/AuthorizationResponseEntity.java | 18 +-- .../AuthorizationModelTests.java | 56 +++++--- ...rviceAuthorizationRequestHandlerTests.java | 12 +- .../AuthorizationResponseEntityTests.java | 125 ++++++++++++++---- 11 files changed, 185 insertions(+), 155 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/build.gradle b/x-pack/plugin/inference/qa/inference-service-tests/build.gradle index 593258ef48e1f..c5e3bebd2ca6a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/build.gradle +++ b/x-pack/plugin/inference/qa/inference-service-tests/build.gradle @@ -5,6 +5,10 @@ dependencies { javaRestTestImplementation project(path: xpackModule('core')) javaRestTestImplementation project(path: xpackModule('inference')) clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') + + // Allow javaRestTest to see unit-test classes from x-pack:plugin:inference so we can import some variables + javaRestTestImplementation(testArtifact(project(xpackModule('inference')))) + // Added this to have access to MockWebServer within the tests javaRestTestImplementation(testArtifact(project(xpackModule('core')))) } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java 
b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index be4bacec10600..143ce5d1fe2be 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -17,7 +17,6 @@ import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; -import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import org.junit.Before; import org.junit.ClassRule; import org.junit.Rule; @@ -26,6 +25,7 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModel; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; public class BaseMockEISAuthServerTest extends ESRestTestCase { @@ -93,6 +93,6 @@ public void ensureEisPreconfiguredEndpointsExist() throws Exception { // available // Technically this only needs to be done before the suite runs but the underlying client is created in @Before and not statically // for the suite - assertBusy(() -> getModel(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2)); + assertBusy(() -> getModel(ELSER_V2_ENDPOINT_ID)); } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 
5c066e47f1f22..6cc0128f566ba 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -18,7 +18,11 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -50,11 +54,11 @@ public void testGetDefaultEndpoints() throws IOException { assertEquals("chat_completion", model.get("task_type")); } - assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION); - assertInferenceIdTaskType(allModels, ".gp-llm-v2-chat_completion", TaskType.CHAT_COMPLETION); - assertInferenceIdTaskType(allModels, DEFAULT_ELSER_ENDPOINT_ID_V2, TaskType.SPARSE_EMBEDDING); - assertInferenceIdTaskType(allModels, ".jina-embeddings-v3", TaskType.TEXT_EMBEDDING); - assertInferenceIdTaskType(allModels, 
".elastic-rerank-v1", TaskType.RERANK); + assertInferenceIdTaskType(allModels, RAINBOW_SPRINKLES_ENDPOINT_ID, TaskType.CHAT_COMPLETION); + assertInferenceIdTaskType(allModels, GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION); + assertInferenceIdTaskType(allModels, ELSER_V2_ENDPOINT_ID, TaskType.SPARSE_EMBEDDING); + assertInferenceIdTaskType(allModels, JINA_EMBED_V3_ENDPOINT_ID, TaskType.TEXT_EMBEDDING); + assertInferenceIdTaskType(allModels, RERANK_V1_ENDPOINT_ID, TaskType.RERANK); } private static void assertInferenceIdTaskType(List> models, String inferenceId, TaskType taskType) { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 5fd08dd7e724a..f6e5b2f515f70 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -17,6 +17,7 @@ import org.junit.runners.model.Statement; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints; public class MockElasticInferenceServiceAuthorizationServer implements TestRule { @@ -24,82 +25,8 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule private final MockWebServer webServer = new MockWebServer(); public void enqueueAuthorizeAllModelsResponse() { - String responseJson = """ - { - "inference_endpoints": [ - { - "id": ".rainbow-sprinkles-elastic", - "model_name": 
"rainbow-sprinkles", - "task_type": "chat_completion", - "status": "ga", - "properties": [ - "multilingual" - ], - "release_date": "2024-05-01", - "end_of_life_date": "2025-12-31" - }, - { - "id": ".gp-llm-v2-chat_completion", - "model_name": "gp-llm-v2", - "task_type": "chat_completion", - "status": "ga", - "properties": [ - "multilingual" - ], - "release_date": "2024-05-01", - "end_of_life_date": "2025-12-31" - }, - { - "id": ".elser-2-elastic", - "model_name": "elser_model_2", - "task_type": "sparse_embedding", - "status": "preview", - "properties": [ - "english" - ], - "release_date": "2024-05-01", - "configuration": { - "chunking_settings": { - "strategy": "sentence", - "max_chunk_size": 250, - "sentence_overlap": 1 - } - } - }, - { - "id": ".jina-embeddings-v3", - "model_name": "jina-embeddings-v3", - "task_type": "text_embedding", - "status": "beta", - "properties": [ - "multilingual", - "open-weights" - ], - "release_date": "2024-05-01", - "configuration": { - "similarity": "cosine", - "dimensions": 1024, - "element_type": "float", - "chunking_settings": { - "strategy": "word", - "max_chunk_size": 500, - "overlap": 2 - } - } - }, - { - "id": ".elastic-rerank-v1", - "model_name": "elastic-rerank-v1", - "task_type": "rerank", - "status": "preview", - "properties": [], - "release_date": "2024-05-01" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var authResponse = getEisAuthorizationResponseWithMultipleEndpoints("ignored"); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authResponse.responseJson())); } public String getUrl() { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index 02e6442e95528..9f67e354ed695 100644 --- 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -45,8 +45,8 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse; import static org.hamcrest.Matchers.empty; @@ -56,9 +56,9 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { public static final Set EIS_PRECONFIGURED_ENDPOINT_IDS = Set.of( - RAINBOW_SPRINKLES_ENDPOINT_ID_V1, + RAINBOW_SPRINKLES_ENDPOINT_ID, ELSER_V2_ENDPOINT_ID, - JINA_EMBED_ENDPOINT_ID, + JINA_EMBED_V3_ENDPOINT_ID, RERANK_V1_ENDPOINT_ID ); @@ -249,13 +249,13 @@ static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) { var rainbowSprinklesModel = eisEndpoints.get(0); 
assertChatCompletionUnparsedModel(rainbowSprinklesModel); - assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); + assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(RAINBOW_SPRINKLES_ENDPOINT_ID)); } static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION)); assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME)); - assertThat(rainbowSprinklesModel.inferenceEntityId(), is(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); + assertThat(rainbowSprinklesModel.inferenceEntityId(), is(RAINBOW_SPRINKLES_ENDPOINT_ID)); } public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception { @@ -281,12 +281,12 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity())); assertThat(eisEndpoints.size(), is(2)); - assertTrue(eisEndpoints.containsKey(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); - assertChatCompletionUnparsedModel(eisEndpoints.get(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); + assertTrue(eisEndpoints.containsKey(RAINBOW_SPRINKLES_ENDPOINT_ID)); + assertChatCompletionUnparsedModel(eisEndpoints.get(RAINBOW_SPRINKLES_ENDPOINT_ID)); - assertTrue(eisEndpoints.containsKey(JINA_EMBED_ENDPOINT_ID)); + assertTrue(eisEndpoints.containsKey(JINA_EMBED_V3_ENDPOINT_ID)); - var textEmbeddingEndpoint = eisEndpoints.get(JINA_EMBED_ENDPOINT_ID); + var textEmbeddingEndpoint = eisEndpoints.get(JINA_EMBED_V3_ENDPOINT_ID); assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING)); assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME)); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index 8840051841079..0632a478120cd 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -36,7 +36,7 @@ import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.cancelAuthorizationTask; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForTask; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -133,7 +133,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun var rainbowSprinklesEndpoint = eisEndpoints.get(0); assertThat(rainbowSprinklesEndpoint.getService(), is(ElasticInferenceService.NAME)); - assertThat(rainbowSprinklesEndpoint.getInferenceEntityId(), is(RAINBOW_SPRINKLES_ENDPOINT_ID_V1)); + assertThat(rainbowSprinklesEndpoint.getInferenceEntityId(), is(RAINBOW_SPRINKLES_ENDPOINT_ID)); assertThat(rainbowSprinklesEndpoint.getTaskType(), is(TaskType.CHAT_COMPLETION)); }); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java index 505280564f13d..0ce773c07d4ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java @@ -73,7 +73,7 @@ private static ElasticInferenceServiceModel createModel( ElasticInferenceServiceComponents components ) { try { - var taskType = getTaskType(authorizedEndpoint.taskType()); + var taskType = getTaskType(authorizedEndpoint.taskType().elasticsearchTaskType()); if (taskType == null) { logger.warn(UNKNOWN_TASK_TYPE_LOG_MESSAGE, authorizedEndpoint.id(), authorizedEndpoint.taskType()); return null; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java index 6ba2664e52993..aef87bec4f852 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java @@ -61,7 +61,7 @@ public class AuthorizationResponseEntity implements InferenceServiceResults { public record AuthorizedEndpoint( String id, String modelName, - String taskType, + TaskTypeObject taskType, String status, @Nullable List properties, String releaseDate, @@ -71,7 +71,7 @@ public record AuthorizedEndpoint( private static final String ID = "id"; private static final String MODEL_NAME = "model_name"; - private 
static final String TASK_TYPE = "task_type"; + private static final String TASK_TYPE = "task_types"; private static final String STATUS = "status"; private static final String PROPERTIES = "properties"; private static final String RELEASE_DATE = "release_date"; @@ -85,7 +85,7 @@ public record AuthorizedEndpoint( args -> new AuthorizedEndpoint( (String) args[0], (String) args[1], - (String) args[2], + (TaskTypeObject) args[2], (String) args[3], (List) args[4], (String) args[5], @@ -97,7 +97,7 @@ public record AuthorizedEndpoint( static { AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(ID)); AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(MODEL_NAME)); - AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(TASK_TYPE)); + AUTHORIZED_ENDPOINT_PARSER.declareObject(constructorArg(), TaskTypeObject.PARSER::apply, new ParseField(TASK_TYPE)); AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(STATUS)); AUTHORIZED_ENDPOINT_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(PROPERTIES)); AUTHORIZED_ENDPOINT_PARSER.declareString(constructorArg(), new ParseField(RELEASE_DATE)); @@ -109,7 +109,7 @@ public AuthorizedEndpoint(StreamInput in) throws IOException { this( in.readString(), in.readString(), - in.readString(), + new TaskTypeObject(in), in.readString(), in.readOptionalCollectionAsList(StreamInput::readString), in.readString(), @@ -122,7 +122,7 @@ public AuthorizedEndpoint(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeString(id); out.writeString(modelName); - out.writeString(taskType); + taskType.writeTo(out); out.writeString(status); out.writeOptionalCollection(properties, StreamOutput::writeString); out.writeString(releaseDate); @@ -183,12 +183,12 @@ public record TaskTypeObject(String eisTaskType, String elasticsearchTaskType) i ); static { - PARSER.declareString(optionalConstructorArg(), new 
ParseField(EIS_TASK_TYPE_FIELD)); + PARSER.declareString(constructorArg(), new ParseField(EIS_TASK_TYPE_FIELD)); PARSER.declareString(constructorArg(), new ParseField(ELASTICSEARCH_TASK_TYPE_FIELD)); } public TaskTypeObject(StreamInput in) throws IOException { - this(in.readOptionalString(), in.readString()); + this(in.readString(), in.readString()); } @Override @@ -198,7 +198,7 @@ public String toString() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(eisTaskType); + out.writeString(eisTaskType); out.writeString(elasticsearchTaskType); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java index ea8881852b1f7..7902cc1327996 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java @@ -31,6 +31,10 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_CHAT_PATH; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_EMBED_PATH; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_SPARSE_PATH; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createTaskTypeObject; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -44,8 +48,26 @@ public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { public void testExcludes_EndpointsWithoutValidTaskTypes() { var 
response = new AuthorizationResponseEntity( List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint("id", "name", "invalid_task_type", "ga", null, "", "", null), - new AuthorizationResponseEntity.AuthorizedEndpoint("id2", "name", TaskType.ANY.toString(), "ga", null, "", "", null) + new AuthorizationResponseEntity.AuthorizedEndpoint( + "id", + "name", + createTaskTypeObject("", "invalid_task_type"), + "ga", + null, + "", + "", + null + ), + new AuthorizationResponseEntity.AuthorizedEndpoint( + "id2", + "name", + createTaskTypeObject("", TaskType.ANY.toString()), + "ga", + null, + "", + "", + null + ) ) ); var auth = AuthorizationModel.of(response, "url"); @@ -62,7 +84,7 @@ public void testReturnsAuthorizedTaskTypes() { new AuthorizationResponseEntity.AuthorizedEndpoint( id1, "name1", - TaskType.CHAT_COMPLETION.toString(), + createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), "ga", null, "", @@ -72,7 +94,7 @@ public void testReturnsAuthorizedTaskTypes() { new AuthorizationResponseEntity.AuthorizedEndpoint( id2, "name2", - TaskType.SPARSE_EMBEDDING.toString(), + createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), "ga", null, "", @@ -96,7 +118,7 @@ public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { new AuthorizationResponseEntity.AuthorizedEndpoint( id, "name1", - TaskType.CHAT_COMPLETION.toString(), + createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), "ga", null, "", @@ -107,7 +129,7 @@ public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { new AuthorizationResponseEntity.AuthorizedEndpoint( id, "name2", - TaskType.SPARSE_EMBEDDING.toString(), + createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), "ga", null, "", @@ -138,7 +160,7 @@ public void testReturnsAuthorizedEndpoints() { new AuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, - TaskType.CHAT_COMPLETION.toString(), + 
createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), "ga", null, "", @@ -148,7 +170,7 @@ public void testReturnsAuthorizedEndpoints() { new AuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, - TaskType.TEXT_EMBEDDING.toString(), + createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), "ga", null, "", @@ -210,7 +232,7 @@ public void testScopesToTaskType() { new AuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, - TaskType.CHAT_COMPLETION.toString(), + createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), "ga", null, "", @@ -220,7 +242,7 @@ public void testScopesToTaskType() { new AuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, - TaskType.TEXT_EMBEDDING.toString(), + createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), "ga", null, "", @@ -279,7 +301,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { new AuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, - TaskType.CHAT_COMPLETION.toString(), + createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), "ga", null, "", @@ -289,7 +311,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { new AuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, - TaskType.TEXT_EMBEDDING.toString(), + createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), "ga", null, "", @@ -304,7 +326,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { new AuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, - TaskType.TEXT_EMBEDDING.toString(), + createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), "ga", null, "", @@ -365,7 +387,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { new AuthorizationResponseEntity.AuthorizedEndpoint( idChat, nameChat, - TaskType.CHAT_COMPLETION.toString(), + createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), "ga", null, "", @@ -375,7 +397,7 @@ 
public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { new AuthorizationResponseEntity.AuthorizedEndpoint( idSparse, nameSparse, - TaskType.SPARSE_EMBEDDING.toString(), + createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), "ga", null, "", @@ -385,7 +407,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { new AuthorizationResponseEntity.AuthorizedEndpoint( idDense, nameDense, - TaskType.TEXT_EMBEDDING.toString(), + createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), "ga", null, "", @@ -395,7 +417,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { new AuthorizationResponseEntity.AuthorizedEndpoint( idRerank, nameRerank, - TaskType.RERANK.toString(), + createTaskTypeObject(EIS_SPARSE_PATH, TaskType.RERANK.toString()), "ga", null, "", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index c8483c3d79eef..9f7726300dae2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests; import org.junit.After; import org.junit.Before; import 
org.mockito.ArgumentCaptor; @@ -45,9 +44,10 @@ import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.bearerToken; import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; -import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createApplierFactory; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createNoopApplierFactory; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisElserAuthorizationResponse; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; @@ -193,7 +193,7 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { ); try (var sender = senderFactory.createSender()) { - var responseData = AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints(eisGatewayUrl); + var responseData = getEisAuthorizationResponseWithMultipleEndpoints(eisGatewayUrl); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseData.responseJson())); @@ -251,7 +251,7 @@ public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws I var authResponse = listener.actionGet(TIMEOUT); assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - assertThat(authResponse.getEndpointIds(), 
is(Set.of(DEFAULT_ELSER_ENDPOINT_ID_V2))); + assertThat(authResponse.getEndpointIds(), is(Set.of(ELSER_V2_ENDPOINT_ID))); assertTrue(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); @@ -293,9 +293,9 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { var authResponse = listener.actionGet(TIMEOUT); assertThat(authResponse.getTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - assertThat(authResponse.getEndpointIds(), is(Set.of(DEFAULT_ELSER_ENDPOINT_ID_V2))); + assertThat(authResponse.getEndpointIds(), is(Set.of(ELSER_V2_ENDPOINT_ID))); assertTrue(authResponse.isAuthorized()); - assertThat(authResponse.getEndpoints(Set.of(DEFAULT_ELSER_ENDPOINT_ID_V2)), is(elserResponse.expectedEndpoints())); + assertThat(authResponse.getEndpoints(Set.of(ELSER_V2_ENDPOINT_ID)), is(elserResponse.expectedEndpoints())); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(1)).debug(loggerArgsCaptor.capture()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index 25152959219ac..8133932cac804 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -45,20 +45,28 @@ public class AuthorizationResponseEntityTests extends AbstractBWCWireSerializationTestCase { // rainbow-sprinkles - public static final String RAINBOW_SPRINKLES_ENDPOINT_ID_V1 = ".rainbow-sprinkles-elastic"; + public static final String RAINBOW_SPRINKLES_ENDPOINT_ID = ".rainbow-sprinkles-elastic"; public static final String RAINBOW_SPRINKLES_MODEL_NAME 
= "rainbow-sprinkles"; + public static final String EIS_CHAT_PATH = "chat"; + + // gp-llm-v2 + public static final String GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion"; + public static final String GP_LLM_V1_MODEL_NAME = "gp-llm-v2"; // elser-2 public static final String ELSER_V2_ENDPOINT_ID = ".elser-2-elastic"; public static final String ELSER_V2_MODEL_NAME = "elser_model_2"; + public static final String EIS_SPARSE_PATH = "embed/text/sparse"; // multilingual-text-embed - public static final String JINA_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; - public static final String JINA_EMBED_MODEL_NAME = "jina-embeddings-v3"; + public static final String JINA_EMBED_V3_ENDPOINT_ID = ".jina-embeddings-v3"; + public static final String JINA_EMBED_V3_MODEL_NAME = "jina-embeddings-v3"; + public static final String EIS_EMBED_PATH = "embed/text/dense"; // rerank-v1 public static final String RERANK_V1_ENDPOINT_ID = ".elastic-rerank-v1"; - public static final String RERANK_V1_MODEL_NAME = "elastic-rerank-v2"; + public static final String RERANK_V1_MODEL_NAME = "elastic-rerank-v1"; + public static final String EIS_RERANK_PATH = "rerank/text/text-similarity"; public record EisAuthorizationResponse( String responseJson, @@ -79,7 +87,10 @@ public record EisAuthorizationResponse( { "id": ".rainbow-sprinkles-elastic", "model_name": "rainbow-sprinkles", - "task_type": "chat_completion", + "task_types": { + "eis": "chat", + "elasticsearch": "chat_completion" + }, "status": "ga", "properties": [ "multilingual" @@ -97,7 +108,10 @@ public record EisAuthorizationResponse( { "id": ".jina-embeddings-v3", "model_name": "jina-embeddings-v3", - "task_type": "text_embedding", + "task_types": { + "eis": "embed/text/dense", + "elasticsearch": "text_embedding" + }, "status": "beta", "properties": [ "multilingual", @@ -125,7 +139,10 @@ public record EisAuthorizationResponse( { "id": ".elser-2-elastic", "model_name": "elser_model_2", - "task_type": "sparse_embedding", + 
"task_types": { + "eis": "embed/text/sparse", + "elasticsearch": "sparse_embedding" + }, "status": "preview", "properties": [ "english" @@ -149,7 +166,10 @@ public record EisAuthorizationResponse( { "id": ".rainbow-sprinkles-elastic", "model_name": "rainbow-sprinkles", - "task_type": "chat_completion", + "task_types": { + "eis": "chat", + "elasticsearch": "chat_completion" + }, "status": "ga", "properties": [ "multilingual" @@ -157,10 +177,26 @@ public record EisAuthorizationResponse( "release_date": "2024-05-01", "end_of_life_date": "2025-12-31" }, + { + "id": ".gp-llm-v2-chat_completion", + "model_name": "gp-llm-v2", + "task_types": { + "eis": "chat", + "elasticsearch": "chat_completion" + }, + "status": "ga", + "properties": [ + "multilingual" + ], + "release_date": "2024-05-01" + }, { "id": ".elser-2-elastic", "model_name": "elser_model_2", - "task_type": "sparse_embedding", + "task_types": { + "eis": "embed/text/sparse", + "elasticsearch": "sparse_embedding" + }, "status": "preview", "properties": [ "english" @@ -177,7 +213,10 @@ public record EisAuthorizationResponse( { "id": ".jina-embeddings-v3", "model_name": "jina-embeddings-v3", - "task_type": "text_embedding", + "task_types": { + "eis": "embed/text/dense", + "elasticsearch": "text_embedding" + }, "status": "beta", "properties": [ "multilingual", @@ -198,7 +237,10 @@ public record EisAuthorizationResponse( { "id": ".elastic-rerank-v1", "model_name": "elastic-rerank-v1", - "task_type": "rerank", + "task_types": { + "eis": "rerank/text/text-similarity", + "elasticsearch": "rerank" + }, "status": "preview", "properties": [], "release_date": "2024-05-01" @@ -237,7 +279,7 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createElserAuthori return new AuthorizationResponseEntity.AuthorizedEndpoint( ELSER_V2_ENDPOINT_ID, ELSER_V2_MODEL_NAME, - "sparse_embedding", + createTaskTypeObject(EIS_SPARSE_PATH, "sparse_embedding"), "preview", List.of("english"), "2024-05-01", @@ -251,15 +293,20 @@ private 
static AuthorizationResponseEntity.AuthorizedEndpoint createElserAuthori ); } + public static AuthorizationResponseEntity.TaskTypeObject createTaskTypeObject(String eisTaskType, String elasticsearchTaskType) { + return new AuthorizationResponseEntity.TaskTypeObject(eisTaskType, elasticsearchTaskType); + } + public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEndpoints(String url) { var authorizedEndpoints = List.of( createRainbowSprinklesAuthorizedEndpoint(), + createGpLlmV2AuthorizedEndpoint(), createElserAuthorizedEndpoint(), createJinaEmbedAuthorizedEndpoint(), new AuthorizationResponseEntity.AuthorizedEndpoint( RERANK_V1_ENDPOINT_ID, RERANK_V1_MODEL_NAME, - "rerank", + createTaskTypeObject(EIS_RERANK_PATH, "rerank"), "preview", List.of(), "2024-05-01", @@ -275,6 +322,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn new AuthorizationResponseEntity(authorizedEndpoints), List.of( createRainbowSprinklesExpectedEndpoint(url), + createGpLlmV2ExpectedEndpoint(url), createElserExpectedEndpoint(url), createJinaExpectedEndpoint(url), new ElasticInferenceServiceRerankModel( @@ -293,9 +341,9 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn private static AuthorizationResponseEntity.AuthorizedEndpoint createRainbowSprinklesAuthorizedEndpoint() { return new AuthorizationResponseEntity.AuthorizedEndpoint( - RAINBOW_SPRINKLES_ENDPOINT_ID_V1, + RAINBOW_SPRINKLES_ENDPOINT_ID, RAINBOW_SPRINKLES_MODEL_NAME, - "chat_completion", + createTaskTypeObject(EIS_CHAT_PATH, "chat_completion"), "ga", List.of("multilingual"), "2024-05-01", @@ -304,9 +352,34 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createRainbowSprin ); } + private static ElasticInferenceServiceModel createGpLlmV2ExpectedEndpoint(String url) { + return new ElasticInferenceServiceCompletionModel( + GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new 
ElasticInferenceServiceCompletionServiceSettings(GP_LLM_V1_MODEL_NAME), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + } + + private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2AuthorizedEndpoint() { + return new AuthorizationResponseEntity.AuthorizedEndpoint( + GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID, + GP_LLM_V1_MODEL_NAME, + createTaskTypeObject(EIS_CHAT_PATH, "chat_completion"), + "ga", + List.of("multilingual"), + "2024-05-01", + null, + null + ); + } + private static ElasticInferenceServiceModel createRainbowSprinklesExpectedEndpoint(String url) { return new ElasticInferenceServiceCompletionModel( - RAINBOW_SPRINKLES_ENDPOINT_ID_V1, + RAINBOW_SPRINKLES_ENDPOINT_ID, TaskType.CHAT_COMPLETION, ElasticInferenceService.NAME, new ElasticInferenceServiceCompletionServiceSettings(RAINBOW_SPRINKLES_MODEL_NAME), @@ -344,9 +417,9 @@ public static EisAuthorizationResponse getEisJinaEmbedAuthorizationResponse(Stri private static AuthorizationResponseEntity.AuthorizedEndpoint createJinaEmbedAuthorizedEndpoint() { return new AuthorizationResponseEntity.AuthorizedEndpoint( - JINA_EMBED_ENDPOINT_ID, - JINA_EMBED_MODEL_NAME, - "text_embedding", + JINA_EMBED_V3_ENDPOINT_ID, + JINA_EMBED_V3_MODEL_NAME, + createTaskTypeObject(EIS_EMBED_PATH, "text_embedding"), "beta", List.of("multilingual", "open-weights"), "2024-05-01", @@ -362,10 +435,10 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createJinaEmbedAut private static ElasticInferenceServiceModel createJinaExpectedEndpoint(String url) { return new ElasticInferenceServiceDenseTextEmbeddingsModel( - JINA_EMBED_ENDPOINT_ID, + JINA_EMBED_V3_ENDPOINT_ID, TaskType.TEXT_EMBEDDING, ElasticInferenceService.NAME, - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(JINA_EMBED_MODEL_NAME, SimilarityMeasure.COSINE, 1024, null), + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(JINA_EMBED_V3_MODEL_NAME, 
SimilarityMeasure.COSINE, 1024, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, new ElasticInferenceServiceComponents(url), @@ -388,7 +461,7 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd case CHAT_COMPLETION -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, - TaskType.CHAT_COMPLETION.toString(), + createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), status, null, "", @@ -398,7 +471,7 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd case SPARSE_EMBEDDING -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, - TaskType.SPARSE_EMBEDDING.toString(), + createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), status, null, "", @@ -408,7 +481,7 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd case TEXT_EMBEDDING -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, - TaskType.TEXT_EMBEDDING.toString(), + createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), status, null, "", @@ -423,7 +496,7 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd case RERANK -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, - TaskType.RERANK.toString(), + createTaskTypeObject(EIS_RERANK_PATH, TaskType.RERANK.toString()), status, null, "", @@ -433,7 +506,7 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd case COMPLETION -> new AuthorizationResponseEntity.AuthorizedEndpoint( id, name, - TaskType.COMPLETION.toString(), + createTaskTypeObject(EIS_CHAT_PATH, TaskType.COMPLETION.toString()), status, null, "", From a540427533465118192ecb3441ae55794180e652 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 20 Nov 2025 09:13:55 -0500 Subject: [PATCH 14/24] Adding some comments --- .../InferenceNamedWriteablesProvider.java | 6 + .../authorization/AuthorizationModel.java | 10 +- 
.../response/AuthorizationResponseEntity.java | 3 + .../AuthorizationModelTests.java | 137 ++++++++++++++++-- 4 files changed, 139 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index b60b7ad3a584c..c49ae4bf92f46 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -705,6 +705,12 @@ private static void addInferenceResultsNamedWriteables(List createDenseTextEmbeddingsModel(authorizedEndpoint, components); case RERANK -> createRerankModel(authorizedEndpoint, components); default -> { - logger.warn(UNKNOWN_TASK_TYPE_LOG_MESSAGE, authorizedEndpoint.id(), taskType); + logger.info(UNSUPPORTED_TASK_TYPE_LOG_MESSAGE, authorizedEndpoint.id(), taskType); yield null; } }; @@ -227,7 +228,12 @@ public static AuthorizationModel empty() { AuthorizationModel(List authorizedEndpoints) { Objects.requireNonNull(authorizedEndpoints); this.authorizedEndpoints = authorizedEndpoints.stream() - .collect(Collectors.toMap(ElasticInferenceServiceModel::getInferenceEntityId, Function.identity(), (a, b) -> a, HashMap::new)); + .collect( + Collectors.toMap(ElasticInferenceServiceModel::getInferenceEntityId, Function.identity(), (firstModel, secondModel) -> { + logger.warn("Found inference id collision for id [{}], ignoring second model", firstModel.inferenceEntityId()); + return firstModel; + }, HashMap::new) + ); var taskTypesSet = EnumSet.noneOf(TaskType.class); taskTypesSet.addAll(this.authorizedEndpoints.values().stream().map(ElasticInferenceServiceModel::getTaskType).toList()); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java index aef87bec4f852..23a354cee9705 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java @@ -37,6 +37,9 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +/** + * Handles parsing the v2 authorization response from the Elastic Inference Service. + */ public class AuthorizationResponseEntity implements InferenceServiceResults { public static final String NAME = "elastic_inference_service_auth_results_v2"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java index 7902cc1327996..94b4e2a7ecb04 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java @@ -37,12 +37,24 @@ import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createTaskTypeObject; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; public class AuthorizationModelTests extends ESTestCase { public void 
testIsAuthorized_ReturnsFalse_WithEmptyMap() { assertFalse(new AuthorizationModel(List.of()).isAuthorized()); - assertFalse(AuthorizationModel.empty().isAuthorized()); + { + var emptyAuthUsingMethod = AuthorizationModel.empty(); + assertFalse(emptyAuthUsingMethod.isAuthorized()); + assertThat(emptyAuthUsingMethod.getEndpointIds(), empty()); + assertThat(emptyAuthUsingMethod, is(new AuthorizationModel(List.of()))); + } + { + var emptyAuthUsingOf = AuthorizationModel.of(new AuthorizationResponseEntity(List.of()), "url"); + assertFalse(emptyAuthUsingOf.isAuthorized()); + assertThat(emptyAuthUsingOf.getEndpointIds(), empty()); + assertThat(emptyAuthUsingOf, is(new AuthorizationModel(List.of()))); + } } public void testExcludes_EndpointsWithoutValidTaskTypes() { @@ -110,6 +122,54 @@ public void testReturnsAuthorizedTaskTypes() { assertTrue(auth.isAuthorized()); } + public void testIgnoresDuplicateId() { + var id1 = "id1"; + var name1 = "name1"; + + var response = new AuthorizationResponseEntity( + List.of( + new AuthorizationResponseEntity.AuthorizedEndpoint( + id1, + name1, + createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), + "ga", + null, + "", + "", + null + ), + new AuthorizationResponseEntity.AuthorizedEndpoint( + id1, + "name2", + createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), + "ga", + null, + "", + "", + null + ) + ) + ); + + var auth = AuthorizationModel.of(response, "url"); + assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION))); + assertThat(auth.getEndpointIds(), is(Set.of(id1))); + assertTrue(auth.isAuthorized()); + + var url = "url"; + var chatCompletionEndpoint = new ElasticInferenceServiceCompletionModel( + id1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(name1), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + + 
assertThat(auth.getEndpoints(Set.of(id1)), is(List.of(chatCompletionEndpoint))); + } + public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { var id = "id1"; @@ -289,18 +349,21 @@ public void testScopesToTaskType() { public void testReturnsAuthorizedEndpoints_FiltersInvalid() { var id1 = "id1"; - var id2 = "invalid_text_embedding"; + var invalidTextEmbedding1 = "invalid_text_embedding1"; + var invalidTextEmbedding2 = "invalid_text_embedding2"; + var invalidTextEmbedding3 = "invalid_text_embedding3"; + var invalidTextEmbedding4 = "invalid_text_embedding4"; - var name1 = "name1"; - var name2 = "name2"; + var name = "name1"; var dimensions = 123; var response = new AuthorizationResponseEntity( List.of( + // Valid chat completion new AuthorizationResponseEntity.AuthorizedEndpoint( id1, - name1, + name, createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), "ga", null, @@ -308,9 +371,10 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { "", null ), + // Missing similarity measure new AuthorizationResponseEntity.AuthorizedEndpoint( - id2, - name2, + invalidTextEmbedding1, + name, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), "ga", null, @@ -323,9 +387,10 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null ) ), + // Invalid chunking settings new AuthorizationResponseEntity.AuthorizedEndpoint( - id2, - name2, + invalidTextEmbedding2, + name, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), "ga", null, @@ -335,9 +400,51 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { SimilarityMeasure.DOT_PRODUCT.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), - // Invalid chunking settings Map.of("unexpected_field", "unexpected_value") ) + ), + // Invalid similarity measure + new AuthorizationResponseEntity.AuthorizedEndpoint( + invalidTextEmbedding3, + name, + createTaskTypeObject(EIS_EMBED_PATH, 
TaskType.TEXT_EMBEDDING.toString()), + "ga", + null, + "", + "", + new AuthorizationResponseEntity.Configuration( + "invalid_similarity", + dimensions, + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + null + ) + ), + // Missing dimensions + new AuthorizationResponseEntity.AuthorizedEndpoint( + invalidTextEmbedding4, + name, + createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), + "ga", + null, + "", + "", + new AuthorizationResponseEntity.Configuration( + SimilarityMeasure.COSINE.toString(), + null, + DenseVectorFieldMapper.ElementType.FLOAT.toString(), + null + ) + ), + // Missing element type + new AuthorizationResponseEntity.AuthorizedEndpoint( + invalidTextEmbedding4, + name, + createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), + "ga", + null, + "", + "", + new AuthorizationResponseEntity.Configuration(SimilarityMeasure.COSINE.toString(), 123, null, null) ) ) ); @@ -353,16 +460,16 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { id1, TaskType.CHAT_COMPLETION, ElasticInferenceService.NAME, - new ElasticInferenceServiceCompletionServiceSettings(name1), + new ElasticInferenceServiceCompletionServiceSettings(name), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, new ElasticInferenceServiceComponents(url) ); - assertThat(auth.getEndpoints(Set.of(id1, id2)), is(List.of(chatCompletionEndpoint))); - - assertThat(auth.getEndpoints(Set.of(id2)), is(List.of())); - assertThat(auth.getEndpoints(Set.of()), is(List.of())); + assertThat( + auth.getEndpoints(Set.of(id1, invalidTextEmbedding1, invalidTextEmbedding2, invalidTextEmbedding3, invalidTextEmbedding4)), + is(List.of(chatCompletionEndpoint)) + ); } public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { From e8e3577631813eef6f24eb7293ccbd18f5a48b37 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 24 Nov 2025 15:33:08 -0500 Subject: [PATCH 15/24] Fixing gp llm v2 name --- 
...erenceGetModelsWithElasticInferenceServiceIT.java | 4 ++-- .../response/AuthorizationResponseEntityTests.java | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 6cc0128f566ba..f3349b2b496ec 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -19,7 +19,7 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; @@ -55,7 +55,7 @@ public void testGetDefaultEndpoints() throws IOException { } 
assertInferenceIdTaskType(allModels, RAINBOW_SPRINKLES_ENDPOINT_ID, TaskType.CHAT_COMPLETION); - assertInferenceIdTaskType(allModels, GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION); + assertInferenceIdTaskType(allModels, GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION); assertInferenceIdTaskType(allModels, ELSER_V2_ENDPOINT_ID, TaskType.SPARSE_EMBEDDING); assertInferenceIdTaskType(allModels, JINA_EMBED_V3_ENDPOINT_ID, TaskType.TEXT_EMBEDDING); assertInferenceIdTaskType(allModels, RERANK_V1_ENDPOINT_ID, TaskType.RERANK); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index 8133932cac804..5d8560bb6f10e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -50,8 +50,8 @@ public class AuthorizationResponseEntityTests extends AbstractBWCWireSerializati public static final String EIS_CHAT_PATH = "chat"; // gp-llm-v2 - public static final String GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion"; - public static final String GP_LLM_V1_MODEL_NAME = "gp-llm-v2"; + public static final String GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion"; + public static final String GP_LLM_V2_MODEL_NAME = "gp-llm-v2"; // elser-2 public static final String ELSER_V2_ENDPOINT_ID = ".elser-2-elastic"; @@ -354,10 +354,10 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createRainbowSprin private static ElasticInferenceServiceModel createGpLlmV2ExpectedEndpoint(String url) { return new ElasticInferenceServiceCompletionModel( - 
GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID, + GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION, ElasticInferenceService.NAME, - new ElasticInferenceServiceCompletionServiceSettings(GP_LLM_V1_MODEL_NAME), + new ElasticInferenceServiceCompletionServiceSettings(GP_LLM_V2_MODEL_NAME), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, new ElasticInferenceServiceComponents(url) @@ -366,8 +366,8 @@ private static ElasticInferenceServiceModel createGpLlmV2ExpectedEndpoint(String private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2AuthorizedEndpoint() { return new AuthorizationResponseEntity.AuthorizedEndpoint( - GP_LLM_V1_CHAT_COMPLETION_ENDPOINT_ID, - GP_LLM_V1_MODEL_NAME, + GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, + GP_LLM_V2_MODEL_NAME, createTaskTypeObject(EIS_CHAT_PATH, "chat_completion"), "ga", List.of("multilingual"), From 05b0564986838b8fbb8c8878385e0a52efc0320f Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 24 Nov 2025 15:36:14 -0500 Subject: [PATCH 16/24] Updating test name for rerank --- ...cInferenceServiceAuthorizationRequestHandlerTests.java | 2 +- .../response/AuthorizationResponseEntityTests.java | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 9f7726300dae2..aca8c30cc244a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -143,7 +143,7 @@ public void 
testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep "inference_endpoints": [ { "id": 123, - "model_name": "elastic-rerank-v1", + "model_name": "jina-reranker-v2", "task_type": "rerank", "status": "preview", "properties": [], diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index 5d8560bb6f10e..cb4438cba7236 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -64,8 +64,8 @@ public class AuthorizationResponseEntityTests extends AbstractBWCWireSerializati public static final String EIS_EMBED_PATH = "embed/text/dense"; // rerank-v1 - public static final String RERANK_V1_ENDPOINT_ID = ".elastic-rerank-v1"; - public static final String RERANK_V1_MODEL_NAME = "elastic-rerank-v1"; + public static final String RERANK_V1_ENDPOINT_ID = ".jina-reranker-v2"; + public static final String RERANK_V1_MODEL_NAME = "jina-reranker-v2"; public static final String EIS_RERANK_PATH = "rerank/text/text-similarity"; public record EisAuthorizationResponse( @@ -235,8 +235,8 @@ public record EisAuthorizationResponse( } }, { - "id": ".elastic-rerank-v1", - "model_name": "elastic-rerank-v1", + "id": ".jina-reranker-v2", + "model_name": "jina-reranker-v2", "task_types": { "eis": "rerank/text/text-similarity", "elasticsearch": "rerank" From a4543db3547283288bf68e55bcf4e36970d66bed Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 24 Nov 2025 20:53:15 -0500 Subject: [PATCH 17/24] Removing named writeable --- .../inference/InferenceNamedWriteablesProvider.java | 13 ------------- 1 file changed, 13 
deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 7072b4af45120..50d7119511829 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -706,19 +706,6 @@ private static void addInferenceResultsNamedWriteables(List namedWriteables) { From 7dd71f0ee1332f5157f95f741cbe1b5c5207a468 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 24 Nov 2025 20:54:48 -0500 Subject: [PATCH 18/24] Removing import --- .../xpack/inference/InferenceNamedWriteablesProvider.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 50d7119511829..bbd73505922a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -76,7 +76,6 @@ import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; import 
org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; From b920069d29714d1e63be9d35388d3eb854551ebc Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 24 Nov 2025 20:58:04 -0500 Subject: [PATCH 19/24] comments --- .../elastic/response/AuthorizationResponseEntity.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java index 23a354cee9705..994c4e2ebeaf1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java @@ -39,6 +39,12 @@ /** * Handles parsing the v2 authorization response from the Elastic Inference Service. + * + * Note: This class does not really need to be {@link InferenceServiceResults}. We do this so that we can leverage the existing + * {@link org.elasticsearch.xpack.inference.external.http.sender.Sender} framework. + * + * Because of this, we don't need to register this class as a named writeable in the NamedWriteableRegistry. It will never be + * sent over the wire between nodes. 
*/ public class AuthorizationResponseEntity implements InferenceServiceResults { From 6b6d21b42b7a77cb82e9ded7b392ff6c30d133ae Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 25 Nov 2025 09:21:44 -0500 Subject: [PATCH 20/24] Adding support for completion --- ...etModelsWithElasticInferenceServiceIT.java | 3 +- .../authorization/AuthorizationModel.java | 8 +-- .../AuthorizationResponseEntityTests.java | 49 +++++++++++++++++-- 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 671cdce58510e..0f1e4cd535956 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -20,6 +20,7 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V2_COMPLETION_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; import 
static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; @@ -62,7 +63,7 @@ public void testGetDefaultEndpoints() throws IOException { assertInferenceIdTaskType(allModels, RAINBOW_SPRINKLES_ENDPOINT_ID, TaskType.CHAT_COMPLETION); assertInferenceIdTaskType(allModels, GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION); - assertInferenceIdTaskType(allModels, ".gp-llm-v2-completion", TaskType.COMPLETION); + assertInferenceIdTaskType(allModels, GP_LLM_V2_COMPLETION_ENDPOINT_ID, TaskType.COMPLETION); assertInferenceIdTaskType(allModels, ELSER_V2_ENDPOINT_ID, TaskType.SPARSE_EMBEDDING); assertInferenceIdTaskType(allModels, JINA_EMBED_V3_ENDPOINT_ID, TaskType.TEXT_EMBEDDING); assertInferenceIdTaskType(allModels, RERANK_V1_ENDPOINT_ID, TaskType.RERANK); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java index 34df0383bc9a7..89349a9a775bc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java @@ -38,8 +38,6 @@ import java.util.function.Function; import java.util.stream.Collectors; -import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; - /** * Transforms the response from {@link ElasticInferenceServiceAuthorizationRequestHandler} into a format for consumption by the service. 
*/ @@ -81,7 +79,8 @@ private static ElasticInferenceServiceModel createModel( } return switch (taskType) { - case CHAT_COMPLETION -> createCompletionModel(authorizedEndpoint, components); + case CHAT_COMPLETION -> createCompletionModel(authorizedEndpoint, TaskType.CHAT_COMPLETION, components); + case COMPLETION -> createCompletionModel(authorizedEndpoint, TaskType.COMPLETION, components); case SPARSE_EMBEDDING -> createSparseEmbeddingsModel(authorizedEndpoint, components); case TEXT_EMBEDDING -> createDenseTextEmbeddingsModel(authorizedEndpoint, components); case RERANK -> createRerankModel(authorizedEndpoint, components); @@ -112,11 +111,12 @@ private static TaskType getTaskType(String taskType) { private static ElasticInferenceServiceCompletionModel createCompletionModel( AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, + TaskType taskType, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceCompletionModel( authorizedEndpoint.id(), - CHAT_COMPLETION, + taskType, ElasticInferenceService.NAME, new ElasticInferenceServiceCompletionServiceSettings(authorizedEndpoint.modelName()), EmptyTaskSettings.INSTANCE, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index cb4438cba7236..aa2d8f3a101db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -51,6 +51,7 @@ public class AuthorizationResponseEntityTests extends AbstractBWCWireSerializati // gp-llm-v2 public static final String GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion"; + 
public static final String GP_LLM_V2_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-completion"; public static final String GP_LLM_V2_MODEL_NAME = "gp-llm-v2"; // elser-2 @@ -190,6 +191,19 @@ public record EisAuthorizationResponse( ], "release_date": "2024-05-01" }, + { + "id": ".gp-llm-v2-completion", + "model_name": "gp-llm-v2", + "task_types": { + "eis": "chat", + "elasticsearch": "completion" + }, + "status": "ga", + "properties": [ + "multilingual" + ], + "release_date": "2024-05-01" + }, { "id": ".elser-2-elastic", "model_name": "elser_model_2", @@ -300,7 +314,8 @@ public static AuthorizationResponseEntity.TaskTypeObject createTaskTypeObject(St public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEndpoints(String url) { var authorizedEndpoints = List.of( createRainbowSprinklesAuthorizedEndpoint(), - createGpLlmV2AuthorizedEndpoint(), + createGpLlmV2ChatCompletionAuthorizedEndpoint(), + createGpLlmV2CompletionAuthorizedEndpoint(), createElserAuthorizedEndpoint(), createJinaEmbedAuthorizedEndpoint(), new AuthorizationResponseEntity.AuthorizedEndpoint( @@ -322,7 +337,8 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn new AuthorizationResponseEntity(authorizedEndpoints), List.of( createRainbowSprinklesExpectedEndpoint(url), - createGpLlmV2ExpectedEndpoint(url), + createGpLlmV2ChatCompletionExpectedEndpoint(url), + createGpLlmV2CompletionExpectedEndpoint(url), createElserExpectedEndpoint(url), createJinaExpectedEndpoint(url), new ElasticInferenceServiceRerankModel( @@ -352,7 +368,7 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createRainbowSprin ); } - private static ElasticInferenceServiceModel createGpLlmV2ExpectedEndpoint(String url) { + private static ElasticInferenceServiceModel createGpLlmV2ChatCompletionExpectedEndpoint(String url) { return new ElasticInferenceServiceCompletionModel( GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION, @@ -364,7 +380,19 @@ private static 
ElasticInferenceServiceModel createGpLlmV2ExpectedEndpoint(String ); } - private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2AuthorizedEndpoint() { + private static ElasticInferenceServiceModel createGpLlmV2CompletionExpectedEndpoint(String url) { + return new ElasticInferenceServiceCompletionModel( + GP_LLM_V2_COMPLETION_ENDPOINT_ID, + TaskType.COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(GP_LLM_V2_MODEL_NAME), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + } + + private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2ChatCompletionAuthorizedEndpoint() { return new AuthorizationResponseEntity.AuthorizedEndpoint( GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, GP_LLM_V2_MODEL_NAME, @@ -377,6 +405,19 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2Autho ); } + private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2CompletionAuthorizedEndpoint() { + return new AuthorizationResponseEntity.AuthorizedEndpoint( + GP_LLM_V2_COMPLETION_ENDPOINT_ID, + GP_LLM_V2_MODEL_NAME, + createTaskTypeObject(EIS_CHAT_PATH, "completion"), + "ga", + List.of("multilingual"), + "2024-05-01", + null, + null + ); + } + private static ElasticInferenceServiceModel createRainbowSprinklesExpectedEndpoint(String url) { return new ElasticInferenceServiceCompletionModel( RAINBOW_SPRINKLES_ENDPOINT_ID, From 578546f534bb645cc00a8f9a0751b734cb746bfe Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 25 Nov 2025 09:57:53 -0500 Subject: [PATCH 21/24] Fixing tests --- .../authorization/AuthorizationPoller.java | 5 ---- .../AuthorizationPollerTests.java | 5 ++-- ...rviceAuthorizationRequestHandlerTests.java | 10 ++++++- .../AuthorizationResponseEntityTests.java | 27 ++++++++++++++++++- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index 96767c4810c10..974059fb6db90 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -27,7 +27,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService; @@ -61,7 +60,6 @@ public class AuthorizationPoller extends AllocatedPersistentTask { private final AtomicBoolean shutdown = new AtomicBoolean(false); private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; private final AtomicBoolean initialized = new AtomicBoolean(false); - private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; private final Client client; private final CountDownLatch receivedFirstAuthResponseLatch = new CountDownLatch(1); private final CCMFeature ccmFeature; @@ -118,9 +116,6 @@ private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); this.sender = Objects.requireNonNull(sender); this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - 
this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( - elasticInferenceServiceSettings.getElasticInferenceServiceUrl() - ); this.modelRegistry = Objects.requireNonNull(modelRegistry); this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN); this.ccmFeature = Objects.requireNonNull(ccmFeature); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index 218d211b7c9e9..74b7a2411714e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureTests.createMockCCMFeature; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMServiceTests.createMockCCMService; import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createAuthorizedEndpoint; +import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createInvalidTaskTypeAuthorizedEndpoint; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -336,7 +337,7 @@ public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInfe public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegration_DoesNotSupport() { var url = "eis-url"; - var completionModel = createAuthorizedEndpoint(TaskType.COMPLETION); + var invalidTaskTypeEndpoint = createInvalidTaskTypeAuthorizedEndpoint(); var 
mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); @@ -345,7 +346,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(completionModel)), url)); + listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(invalidTaskTypeEndpoint)), url)); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index aca8c30cc244a..b6f40d3c3d95f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -203,7 +203,15 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { var authResponse = listener.actionGet(TIMEOUT); assertThat( authResponse.getTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK)) + is( + EnumSet.of( + TaskType.CHAT_COMPLETION, + TaskType.SPARSE_EMBEDDING, + TaskType.TEXT_EMBEDDING, + TaskType.RERANK, + TaskType.COMPLETION + ) + ) ); assertThat(authResponse.getEndpointIds(), containsInAnyOrder(responseData.inferenceIds().toArray(String[]::new))); assertTrue(authResponse.isAuthorized()); diff 
--git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java index aa2d8f3a101db..ebdd0ff5686fc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java @@ -493,6 +493,23 @@ public static AuthorizationResponseEntity createResponse() { ); } + public static AuthorizationResponseEntity.AuthorizedEndpoint createInvalidTaskTypeAuthorizedEndpoint() { + var id = randomAlphaOfLength(10); + var name = randomAlphaOfLength(10); + var status = randomFrom("ga", "beta", "preview"); + + return new AuthorizationResponseEntity.AuthorizedEndpoint( + id, + name, + createTaskTypeObject("invalid/task/type", TaskType.ANY.toString()), + status, + null, + "", + "", + null + ); + } + public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { var id = randomAlphaOfLength(10); var name = randomAlphaOfLength(10); @@ -572,7 +589,15 @@ public void testParseAllFields() throws IOException { assertThat( authModel.getTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK)) + is( + EnumSet.of( + TaskType.CHAT_COMPLETION, + TaskType.SPARSE_EMBEDDING, + TaskType.TEXT_EMBEDDING, + TaskType.RERANK, + TaskType.COMPLETION + ) + ) ); assertThat( authModel.getEndpoints(responseData.inferenceIds()), From 6aa0ea50bf694f5d266e53b89fe1e4679594df2b Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 25 Nov 2025 13:49:44 -0500 Subject: [PATCH 22/24] Addressing feedback --- .../inference/BaseMockEISAuthServerTest.java | 2 +- 
...etModelsWithElasticInferenceServiceIT.java | 12 +- .../inference/InferenceGetServicesIT.java | 3 +- ...icInferenceServiceAuthorizationServer.java | 2 +- .../AuthorizationTaskExecutorIT.java | 18 +-- ...horizationTaskExecutorMultipleNodesIT.java | 10 +- .../inference/integration/CCMServiceIT.java | 8 +- .../TransportGetInferenceServicesAction.java | 12 +- .../authorization/AuthorizationPoller.java | 4 +- ...icInferenceServiceAuthorizationModel.java} | 88 +++++++----- ...nceServiceAuthorizationRequestHandler.java | 12 +- ...ceServiceAuthorizationResponseEntity.java} | 23 ++-- .../http/sender/HttpRequestSenderTests.java | 10 +- .../AuthorizationPollerTests.java | 60 ++++++--- ...erenceServiceAuthorizationModelTests.java} | 127 ++++++++++-------- ...rviceAuthorizationRequestHandlerTests.java | 22 +-- ...viceAuthorizationResponseEntityTests.java} | 106 ++++++++------- 17 files changed, 304 insertions(+), 215 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/{AuthorizationModel.java => ElasticInferenceServiceAuthorizationModel.java} (74%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/{AuthorizationResponseEntity.java => ElasticInferenceServiceAuthorizationResponseEntity.java} (93%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/{AuthorizationModelTests.java => ElasticInferenceServiceAuthorizationModelTests.java} (78%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/{AuthorizationResponseEntityTests.java => ElasticInferenceServiceAuthorizationResponseEntityTests.java} (80%) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java 
b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index 143ce5d1fe2be..8eac4923529a4 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -25,7 +25,7 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModel; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; public class BaseMockEISAuthServerTest extends ESRestTestCase { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 0f1e4cd535956..9a0d78ad66d23 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -18,12 +18,12 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels; -import static 
org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V2_COMPLETION_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.GP_LLM_V2_COMPLETION_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; diff --git 
a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 2194c5b08122a..2f79f320c89da 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -176,7 +176,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "hugging_face", "amazon_sagemaker", "mistral", - "watsonxai" + "watsonxai", + "elastic" ).toArray() ) ); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index f6e5b2f515f70..a25e82c4a336f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -17,7 +17,7 @@ import org.junit.runners.model.Statement; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints; public class 
MockElasticInferenceServiceAuthorizationServer implements TestRule { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index 9f67e354ed695..431509791255c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -28,7 +28,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -43,12 +43,12 @@ import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; -import static 
org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -66,7 +66,7 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; - private static AuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; + private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; private ModelRegistry modelRegistry; private AuthorizationTaskExecutor authorizationTaskExecutor; @@ -273,7 +273,7 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep assertChatCompletionEndpointExists(); // 
Simulate that a text embedding model is now authorized - var jinaEmbedResponse = AuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl); + var jinaEmbedResponse = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponse.responseJson())); restartPollingTaskAndWaitForAuthResponse(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index 0632a478120cd..cf3b08fea3bfb 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -20,7 +20,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -35,9 +35,9 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.cancelAuthorizationTask; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForTask; -import 
static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.not; @@ -55,7 +55,7 @@ public class AuthorizationTaskExecutorMultipleNodesIT extends ESIntegTestCase { private static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; - private static AuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; + private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; @BeforeClass public static void initClass() throws IOException { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java index 057cec0695ff1..3a9147f5f4fc1 100644 --- 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java @@ -23,7 +23,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMModel; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -49,7 +49,7 @@ public class CCMServiceIT extends CCMSingleNodeIT { private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; - private static AuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; + private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; private AuthorizationTaskExecutor authorizationTaskExecutor; private ModelRegistry modelRegistry; @@ -79,7 +79,9 @@ public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); - chatCompletionResponse = AuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse(gatewayUrl); + chatCompletionResponse = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse( + gatewayUrl + ); } @Before diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index 605b410075d05..6f7f47abad355 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationModel; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import java.util.ArrayList; @@ -122,7 +122,7 @@ private void getServiceConfigurationsForServicesAndEis( ArrayList> availableServices, @Nullable TaskType requestedTaskType ) { - SubscribableListener.newForked(authModelListener -> { + SubscribableListener.newForked(authModelListener -> { // Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender)); }).>andThen((configurationListener, authorizationModel) -> { @@ -133,12 +133,14 @@ private void getServiceConfigurationsForServicesAndEis( return; } - var config = ElasticInferenceService.createConfiguration(authorizationModel.getTaskTypes()); + // If there was a requested task type and the authorization response from EIS doesn't support it, we'll exclude EIS as a valid + // service if (requestedTaskType != null && authorizationModel.getTaskTypes().contains(requestedTaskType) == false) { configurationListener.onResponse(serviceConfigs); return; } + var config = 
ElasticInferenceService.createConfiguration(authorizationModel.getTaskTypes()); serviceConfigs.add(config); serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService)); configurationListener.onResponse(serviceConfigs); @@ -150,14 +152,14 @@ private void getServiceConfigurationsForServicesAndEis( ); } - private void getEisAuthorization(ActionListener listener, Sender sender) { + private void getEisAuthorization(ActionListener listener, Sender sender) { var disabledServiceListener = listener.delegateResponse((delegate, e) -> { logger.warn( "Failed to retrieve authorization information from the " + "Elastic Inference Service while determining service configurations. Marking service as disabled.", e ); - delegate.onResponse(AuthorizationModel.empty()); + delegate.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized()); }); eisAuthorizationRequestHandler.getAuthorization(disabledServiceListener, sender); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index 974059fb6db90..f8ef0596b4749 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -306,7 +306,7 @@ private void shouldSendAuthRequest(ActionListener> } private void sendRequest(ActionListener listener) { - SubscribableListener.newForked( + SubscribableListener.newForked( authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender) ) .andThenApply(this::getNewInferenceEndpointsToStore) @@ -314,7 +314,7 @@ private void sendRequest(ActionListener listener) { .addListener(listener); } - private List 
getNewInferenceEndpointsToStore(AuthorizationModel authModel) { + private List getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) { logger.debug("Received authorization response, {}", authModel); var scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java similarity index 74% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java index 89349a9a775bc..dec387036f0b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java @@ -24,7 +24,7 @@ import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import 
org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; @@ -39,21 +39,25 @@ import java.util.stream.Collectors; /** - * Transforms the response from {@link ElasticInferenceServiceAuthorizationRequestHandler} into a format for consumption by the service. + * Transforms the response from {@link ElasticInferenceServiceAuthorizationRequestHandler} into a format + * for consumption by the {@link ElasticInferenceService}. */ -public class AuthorizationModel { +public class ElasticInferenceServiceAuthorizationModel { - private static final Logger logger = LogManager.getLogger(AuthorizationModel.class); + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationModel.class); private static final String UNKNOWN_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unknown task type [{}], skipping"; private static final String UNSUPPORTED_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unsupported task type [{}], skipping"; - public static AuthorizationModel of(AuthorizationResponseEntity responseEntity, String baseEisUrl) { + public static ElasticInferenceServiceAuthorizationModel of( + ElasticInferenceServiceAuthorizationResponseEntity responseEntity, + String baseEisUrl + ) { var components = new ElasticInferenceServiceComponents(baseEisUrl); return createInternal(responseEntity.getAuthorizedEndpoints(), components); } - private static AuthorizationModel createInternal( - List responseEndpoints, + private static ElasticInferenceServiceAuthorizationModel createInternal( + List responseEndpoints, ElasticInferenceServiceComponents components ) { var validEndpoints = new ArrayList(); @@ -64,11 +68,11 @@ private static AuthorizationModel createInternal( } } - return new AuthorizationModel(validEndpoints); + return new ElasticInferenceServiceAuthorizationModel(validEndpoints); } private static ElasticInferenceServiceModel createModel( - 
AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, + ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { try { @@ -81,7 +85,7 @@ private static ElasticInferenceServiceModel createModel( return switch (taskType) { case CHAT_COMPLETION -> createCompletionModel(authorizedEndpoint, TaskType.CHAT_COMPLETION, components); case COMPLETION -> createCompletionModel(authorizedEndpoint, TaskType.COMPLETION, components); - case SPARSE_EMBEDDING -> createSparseEmbeddingsModel(authorizedEndpoint, components); + case SPARSE_EMBEDDING -> createSparseTextEmbeddingsModel(authorizedEndpoint, components); case TEXT_EMBEDDING -> createDenseTextEmbeddingsModel(authorizedEndpoint, components); case RERANK -> createRerankModel(authorizedEndpoint, components); default -> { @@ -110,7 +114,7 @@ private static TaskType getTaskType(String taskType) { } private static ElasticInferenceServiceCompletionModel createCompletionModel( - AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, + ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, TaskType taskType, ElasticInferenceServiceComponents components ) { @@ -125,8 +129,8 @@ private static ElasticInferenceServiceCompletionModel createCompletionModel( ); } - private static ElasticInferenceServiceSparseEmbeddingsModel createSparseEmbeddingsModel( - AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, + private static ElasticInferenceServiceSparseEmbeddingsModel createSparseTextEmbeddingsModel( + ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceSparseEmbeddingsModel( @@ -141,22 +145,26 @@ private static ElasticInferenceServiceSparseEmbeddingsModel createSparseEmbeddin ); } - private static AuthorizationResponseEntity.Configuration getConfigurationOrEmpty( - 
AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint + private static ElasticInferenceServiceAuthorizationResponseEntity.Configuration getConfigurationOrEmpty( + ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint ) { if (authorizedEndpoint.configuration() != null) { return authorizedEndpoint.configuration(); } - return AuthorizationResponseEntity.Configuration.EMPTY; + return ElasticInferenceServiceAuthorizationResponseEntity.Configuration.EMPTY; } - private static Map getChunkingSettingsMap(AuthorizationResponseEntity.Configuration configuration) { + private static Map getChunkingSettingsMap( + ElasticInferenceServiceAuthorizationResponseEntity.Configuration configuration + ) { + // We intentionally want to return an empty map here instead of null, because ChunkingSettingsBuilder.fromMap() + // will return the "new" default value in that case return Objects.requireNonNullElse(configuration.chunkingSettings(), new HashMap<>()); } private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEmbeddingsModel( - AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, + ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { var config = getConfigurationOrEmpty(authorizedEndpoint); @@ -179,10 +187,22 @@ private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEm ); } - private static void validateConfigurationForTextEmbedding(AuthorizationResponseEntity.Configuration config) { - validateFieldPresent(AuthorizationResponseEntity.Configuration.ELEMENT_TYPE, config.elementType(), TaskType.TEXT_EMBEDDING); - validateFieldPresent(AuthorizationResponseEntity.Configuration.DIMENSIONS, config.dimensions(), TaskType.TEXT_EMBEDDING); - validateFieldPresent(AuthorizationResponseEntity.Configuration.SIMILARITY, config.similarity(), TaskType.TEXT_EMBEDDING); + private static void 
validateConfigurationForTextEmbedding(ElasticInferenceServiceAuthorizationResponseEntity.Configuration config) { + validateFieldPresent( + ElasticInferenceServiceAuthorizationResponseEntity.Configuration.ELEMENT_TYPE, + config.elementType(), + TaskType.TEXT_EMBEDDING + ); + validateFieldPresent( + ElasticInferenceServiceAuthorizationResponseEntity.Configuration.DIMENSIONS, + config.dimensions(), + TaskType.TEXT_EMBEDDING + ); + validateFieldPresent( + ElasticInferenceServiceAuthorizationResponseEntity.Configuration.SIMILARITY, + config.similarity(), + TaskType.TEXT_EMBEDDING + ); } private static void validateFieldPresent(String field, Object fieldValue, TaskType taskType) { @@ -193,14 +213,18 @@ private static void validateFieldPresent(String field, Object fieldValue, TaskTy } } - private static SimilarityMeasure getSimilarityMeasure(AuthorizationResponseEntity.Configuration configuration) { - validateFieldPresent(AuthorizationResponseEntity.Configuration.SIMILARITY, configuration.similarity(), TaskType.TEXT_EMBEDDING); + private static SimilarityMeasure getSimilarityMeasure(ElasticInferenceServiceAuthorizationResponseEntity.Configuration configuration) { + validateFieldPresent( + ElasticInferenceServiceAuthorizationResponseEntity.Configuration.SIMILARITY, + configuration.similarity(), + TaskType.TEXT_EMBEDDING + ); return SimilarityMeasure.fromString(configuration.similarity()); } private static ElasticInferenceServiceRerankModel createRerankModel( - AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, + ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components ) { return new ElasticInferenceServiceRerankModel( @@ -217,15 +241,15 @@ private static ElasticInferenceServiceRerankModel createRerankModel( /** * Returns an object indicating that the cluster is not authorized for any endpoints from EIS. 
*/ - public static AuthorizationModel empty() { - return new AuthorizationModel(List.of()); + public static ElasticInferenceServiceAuthorizationModel unauthorized() { + return new ElasticInferenceServiceAuthorizationModel(List.of()); } private final Map authorizedEndpoints; private final EnumSet taskTypes; // Default for testing - AuthorizationModel(List authorizedEndpoints) { + ElasticInferenceServiceAuthorizationModel(List authorizedEndpoints) { Objects.requireNonNull(authorizedEndpoints); this.authorizedEndpoints = authorizedEndpoints.stream() .collect( @@ -249,15 +273,15 @@ public boolean isAuthorized() { } /** - * Returns a new {@link AuthorizationModel} object retaining only the specified task types + * Returns a new {@link ElasticInferenceServiceAuthorizationModel} object retaining only the specified task types * and applicable models that leverage those task types. Any task types not specified in the provided parameter will be * excluded from the returned object. This is essentially an intersection. 
* @param taskTypes the task types to retain in the newly created object * @return a new object containing endpoints limited to the specified task types */ - public AuthorizationModel newLimitedToTaskTypes(EnumSet taskTypes) { + public ElasticInferenceServiceAuthorizationModel newLimitedToTaskTypes(EnumSet taskTypes) { var endpoints = this.authorizedEndpoints.values().stream().filter(endpoint -> taskTypes.contains(endpoint.getTaskType())).toList(); - return new AuthorizationModel(endpoints); + return new ElasticInferenceServiceAuthorizationModel(endpoints); } public EnumSet getTaskTypes() { @@ -280,7 +304,7 @@ public String toString() { @Override public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; - AuthorizationModel that = (AuthorizationModel) o; + ElasticInferenceServiceAuthorizationModel that = (ElasticInferenceServiceAuthorizationModel) o; return Objects.equals(authorizedEndpoints, that.authorizedEndpoints) && Objects.equals(taskTypes, that.taskTypes); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 5638b7b346497..acaa3fa85e693 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -24,7 +24,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import 
org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.Objects; @@ -47,7 +47,7 @@ public class ElasticInferenceServiceAuthorizationRequestHandler { private static ResponseHandler createAuthResponseHandler() { return new ElasticInferenceServiceResponseHandler( Strings.format("%s authorization", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), - AuthorizationResponseEntity::fromResponse + ElasticInferenceServiceAuthorizationResponseEntity::fromResponse ); } @@ -88,13 +88,13 @@ public ElasticInferenceServiceAuthorizationRequestHandler( * @param listener a listener to receive the response * @param sender a {@link Sender} for making the request to the Elastic Inference Service */ - public void getAuthorization(ActionListener listener, Sender sender) { + public void getAuthorization(ActionListener listener, Sender sender) { try { logger.debug("Retrieving authorization information from the Elastic Inference Service."); if (Strings.isNullOrEmpty(baseUrl)) { logger.debug("The base URL for the authorization service is not valid, rejecting authorization."); - listener.onResponse(AuthorizationModel.empty()); + listener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized()); return; } @@ -119,9 +119,9 @@ public void getAuthorization(ActionListener listener, Sender sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener); }) .andThenApply(authResult -> { - if (authResult instanceof AuthorizationResponseEntity authResponseEntity) { + if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { logger.debug(() -> Strings.format("Received authorization 
information from gateway %s", authResponseEntity)); - return AuthorizationModel.of(authResponseEntity, baseUrl); + return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity, baseUrl); } var errorMessage = Strings.format( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java similarity index 93% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java index 994c4e2ebeaf1..be3dc28ee0517 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java @@ -46,18 +46,19 @@ * Because of this, we don't need to register this class as a named writeable in the NamedWriteableRegistry. It will never be * sent over the wire between nodes. 
*/ -public class AuthorizationResponseEntity implements InferenceServiceResults { +public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults { public static final String NAME = "elastic_inference_service_auth_results_v2"; private static final String INFERENCE_ENDPOINTS = "inference_endpoints"; @SuppressWarnings("unchecked") - public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - AuthorizationResponseEntity.class.getSimpleName(), - true, - args -> new AuthorizationResponseEntity((List) args[0]) - ); + public static ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + ElasticInferenceServiceAuthorizationResponseEntity.class.getSimpleName(), + true, + args -> new ElasticInferenceServiceAuthorizationResponseEntity((List) args[0]) + ); static { PARSER.declareObjectArray( @@ -300,22 +301,22 @@ public String toString() { private final List authorizedEndpoints; - public AuthorizationResponseEntity(List authorizedModels) { + public ElasticInferenceServiceAuthorizationResponseEntity(List authorizedModels) { this.authorizedEndpoints = Objects.requireNonNull(authorizedModels); } /** * Create an empty response */ - public AuthorizationResponseEntity() { + public ElasticInferenceServiceAuthorizationResponseEntity() { this(List.of()); } - public AuthorizationResponseEntity(StreamInput in) throws IOException { + public ElasticInferenceServiceAuthorizationResponseEntity(StreamInput in) throws IOException { this(in.readCollectionAsList(AuthorizedEndpoint::new)); } - public static AuthorizationResponseEntity fromResponse(Request request, HttpResult response) throws IOException { + public static ElasticInferenceServiceAuthorizationResponseEntity fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = 
XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -361,7 +362,7 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - AuthorizationResponseEntity that = (AuthorizationResponseEntity) o; + ElasticInferenceServiceAuthorizationResponseEntity that = (ElasticInferenceServiceAuthorizationResponseEntity) o; return Objects.equals(authorizedEndpoints, that.authorizedEndpoints); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 0420b58db9d1f..78f08e8f91a7c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -32,7 +32,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.After; import org.junit.Before; @@ -55,7 +55,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; import static 
org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisElserAuthorizationResponse; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisElserAuthorizationResponse; import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -314,14 +314,14 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce ); var responseHandler = new ElasticInferenceServiceResponseHandler( String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), - AuthorizationResponseEntity::fromResponse + ElasticInferenceServiceAuthorizationResponseEntity::fromResponse ); sender.sendWithoutQueuing(mock(Logger.class), request, responseHandler, null, listener); var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(AuthorizationResponseEntity.class)); - var authResponse = (AuthorizationResponseEntity) result; + assertThat(result, instanceOf(ElasticInferenceServiceAuthorizationResponseEntity.class)); + var authResponse = (ElasticInferenceServiceAuthorizationResponseEntity) result; assertThat(authResponse.getAuthorizedEndpoints(), is(elserResponse.responseEntity().getAuthorizedEndpoints())); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index 74b7a2411714e..eb0419ccf9f27 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.junit.Before; @@ -43,8 +43,8 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureTests.createMockCCMFeature; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMServiceTests.createMockCCMService; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createAuthorizedEndpoint; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createInvalidTaskTypeAuthorizedEndpoint; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.createAuthorizedEndpoint; +import static 
org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.createInvalidTaskTypeAuthorizedEndpoint; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -188,8 +188,13 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { var sparseModel = createAuthorizedEndpoint(TaskType.SPARSE_EMBEDDING); var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity(List.of(sparseModel)), + url + ) + ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -255,8 +260,13 @@ public void testSendsAuthorizationRequest_WhenCCMIsNotConfigurable() { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity(List.of(sparseModel)), + url + ) + ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -310,8 +320,13 @@ public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInfe var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new 
AuthorizationResponseEntity(List.of(sparseModel)), url)); + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity(List.of(sparseModel)), + url + ) + ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -345,8 +360,13 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(invalidTaskTypeEndpoint)), url)); + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity(List.of(invalidTaskTypeEndpoint)), + url + ) + ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -382,8 +402,13 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity(List.of(sparseModel)), + url + ) + ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -439,8 +464,13 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - 
listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(sparseModel)), url)); + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity(List.of(sparseModel)), + url + ) + ); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java similarity index 78% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java index 94b4e2a7ecb04..0a2b80d222ca0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java @@ -22,7 +22,7 @@ import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; import 
org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; @@ -31,36 +31,39 @@ import java.util.Map; import java.util.Set; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_CHAT_PATH; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_EMBED_PATH; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.EIS_SPARSE_PATH; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createTaskTypeObject; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_CHAT_PATH; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_EMBED_PATH; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_SPARSE_PATH; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.createTaskTypeObject; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.empty; -public class AuthorizationModelTests extends ESTestCase { +public class ElasticInferenceServiceAuthorizationModelTests extends ESTestCase { public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { - assertFalse(new AuthorizationModel(List.of()).isAuthorized()); + assertFalse(new ElasticInferenceServiceAuthorizationModel(List.of()).isAuthorized()); { - var emptyAuthUsingMethod = AuthorizationModel.empty(); + var emptyAuthUsingMethod = 
ElasticInferenceServiceAuthorizationModel.unauthorized(); assertFalse(emptyAuthUsingMethod.isAuthorized()); assertThat(emptyAuthUsingMethod.getEndpointIds(), empty()); - assertThat(emptyAuthUsingMethod, is(new AuthorizationModel(List.of()))); + assertThat(emptyAuthUsingMethod, is(new ElasticInferenceServiceAuthorizationModel(List.of()))); } { - var emptyAuthUsingOf = AuthorizationModel.of(new AuthorizationResponseEntity(List.of()), "url"); + var emptyAuthUsingOf = ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity(List.of()), + "url" + ); assertFalse(emptyAuthUsingOf.isAuthorized()); assertThat(emptyAuthUsingOf.getEndpointIds(), empty()); - assertThat(emptyAuthUsingOf, is(new AuthorizationModel(List.of()))); + assertThat(emptyAuthUsingOf, is(new ElasticInferenceServiceAuthorizationModel(List.of()))); } } public void testExcludes_EndpointsWithoutValidTaskTypes() { - var response = new AuthorizationResponseEntity( + var response = new ElasticInferenceServiceAuthorizationResponseEntity( List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( "id", "name", createTaskTypeObject("", "invalid_task_type"), @@ -70,7 +73,7 @@ public void testExcludes_EndpointsWithoutValidTaskTypes() { "", null ), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( "id2", "name", createTaskTypeObject("", TaskType.ANY.toString()), @@ -82,7 +85,7 @@ public void testExcludes_EndpointsWithoutValidTaskTypes() { ) ) ); - var auth = AuthorizationModel.of(response, "url"); + var auth = ElasticInferenceServiceAuthorizationModel.of(response, "url"); assertTrue(auth.getTaskTypes().isEmpty()); assertFalse(auth.isAuthorized()); } @@ -91,9 +94,9 @@ public void testReturnsAuthorizedTaskTypes() { var id1 = "id1"; var id2 = "id2"; - var response = new AuthorizationResponseEntity( + var response = 
new ElasticInferenceServiceAuthorizationResponseEntity( List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id1, "name1", createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), @@ -103,7 +106,7 @@ public void testReturnsAuthorizedTaskTypes() { "", null ), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id2, "name2", createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), @@ -116,7 +119,7 @@ public void testReturnsAuthorizedTaskTypes() { ) ); - var auth = AuthorizationModel.of(response, "url"); + var auth = ElasticInferenceServiceAuthorizationModel.of(response, "url"); assertThat(auth.getTaskTypes(), is(Set.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); assertThat(auth.getEndpointIds(), is(Set.of(id1, id2))); assertTrue(auth.isAuthorized()); @@ -126,9 +129,9 @@ public void testIgnoresDuplicateId() { var id1 = "id1"; var name1 = "name1"; - var response = new AuthorizationResponseEntity( + var response = new ElasticInferenceServiceAuthorizationResponseEntity( List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), @@ -138,7 +141,7 @@ public void testIgnoresDuplicateId() { "", null ), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id1, "name2", createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), @@ -151,7 +154,7 @@ public void testIgnoresDuplicateId() { ) ); - var auth = AuthorizationModel.of(response, "url"); + var auth = ElasticInferenceServiceAuthorizationModel.of(response, "url"); assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION))); 
assertThat(auth.getEndpointIds(), is(Set.of(id1))); assertTrue(auth.isAuthorized()); @@ -173,9 +176,9 @@ public void testIgnoresDuplicateId() { public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { var id = "id1"; - var response = new AuthorizationResponseEntity( + var response = new ElasticInferenceServiceAuthorizationResponseEntity( List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id, "name1", createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), @@ -186,7 +189,7 @@ public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { null ), // This should be ignored because the id is a duplicate - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id, "name2", createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), @@ -199,7 +202,7 @@ public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { ) ); - var auth = AuthorizationModel.of(response, "url"); + var auth = ElasticInferenceServiceAuthorizationModel.of(response, "url"); assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION))); assertThat(auth.getEndpointIds(), is(Set.of(id))); assertTrue(auth.isAuthorized()); @@ -215,9 +218,9 @@ public void testReturnsAuthorizedEndpoints() { var similarity = SimilarityMeasure.COSINE; var dimensions = 123; - var response = new AuthorizationResponseEntity( + var response = new ElasticInferenceServiceAuthorizationResponseEntity( List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), @@ -227,7 +230,7 @@ public void testReturnsAuthorizedEndpoints() { "", null ), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new 
ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -235,7 +238,7 @@ public void testReturnsAuthorizedEndpoints() { null, "", "", - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( similarity.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -247,7 +250,7 @@ public void testReturnsAuthorizedEndpoints() { var url = "url"; - var auth = AuthorizationModel.of(response, url); + var auth = ElasticInferenceServiceAuthorizationModel.of(response, url); assertThat(auth.getEndpointIds(), is(Set.of(id1, id2))); assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING))); assertTrue(auth.isAuthorized()); @@ -287,9 +290,9 @@ public void testScopesToTaskType() { var similarity = SimilarityMeasure.COSINE; var dimensions = 123; - var response = new AuthorizationResponseEntity( + var response = new ElasticInferenceServiceAuthorizationResponseEntity( List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id1, name1, createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), @@ -299,7 +302,7 @@ public void testScopesToTaskType() { "", null ), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id2, name2, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -307,7 +310,7 @@ public void testScopesToTaskType() { null, "", "", - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( similarity.toString(), dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -319,7 +322,7 @@ public void testScopesToTaskType() { var url = "url"; - var auth = AuthorizationModel.of(response, 
url); + var auth = ElasticInferenceServiceAuthorizationModel.of(response, url); assertThat(auth.getEndpointIds(), is(Set.of(id1, id2))); assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING))); assertTrue(auth.isAuthorized()); @@ -358,10 +361,10 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { var dimensions = 123; - var response = new AuthorizationResponseEntity( + var response = new ElasticInferenceServiceAuthorizationResponseEntity( List.of( // Valid chat completion - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id1, name, createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), @@ -372,7 +375,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null ), // Missing similarity measure - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( invalidTextEmbedding1, name, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -380,7 +383,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( null, dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -388,7 +391,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { ) ), // Invalid chunking settings - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( invalidTextEmbedding2, name, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -396,7 +399,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( SimilarityMeasure.DOT_PRODUCT.toString(), 
dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -404,7 +407,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { ) ), // Invalid similarity measure - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( invalidTextEmbedding3, name, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -412,7 +415,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( "invalid_similarity", dimensions, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -420,7 +423,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { ) ), // Missing dimensions - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( invalidTextEmbedding4, name, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -428,7 +431,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( SimilarityMeasure.COSINE.toString(), null, DenseVectorFieldMapper.ElementType.FLOAT.toString(), @@ -436,7 +439,7 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { ) ), // Missing element type - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( invalidTextEmbedding4, name, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -444,14 +447,19 @@ public void testReturnsAuthorizedEndpoints_FiltersInvalid() { null, "", "", - new AuthorizationResponseEntity.Configuration(SimilarityMeasure.COSINE.toString(), 123, null, null) + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( + 
SimilarityMeasure.COSINE.toString(), + 123, + null, + null + ) ) ) ); var url = "url"; - var auth = AuthorizationModel.of(response, url); + var auth = ElasticInferenceServiceAuthorizationModel.of(response, url); assertThat(auth.getEndpointIds(), is(Set.of(id1))); assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION))); assertTrue(auth.isAuthorized()); @@ -489,9 +497,9 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { var url = "base_url"; - var response = new AuthorizationResponseEntity( + var response = new ElasticInferenceServiceAuthorizationResponseEntity( List.of( - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( idChat, nameChat, createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), @@ -501,7 +509,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "", null ), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( idSparse, nameSparse, createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), @@ -511,7 +519,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { "", null ), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( idDense, nameDense, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -519,9 +527,14 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { null, "", "", - new AuthorizationResponseEntity.Configuration(similarity.toString(), dimensions, elementType, null) + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( + similarity.toString(), + dimensions, + elementType, + null + ) ), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( idRerank, 
nameRerank, createTaskTypeObject(EIS_SPARSE_PATH, TaskType.RERANK.toString()), @@ -534,7 +547,7 @@ public void testCreatesAllSupportedTaskTypesAndReturnsCorrectModels() { ) ); - var auth = AuthorizationModel.of(response, url); + var auth = ElasticInferenceServiceAuthorizationModel.of(response, url); var endpoints = auth.getEndpoints(Set.of(idChat, idSparse, idDense, idRerank)); assertThat(endpoints.size(), is(4)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index b6f40d3c3d95f..87c3817dbb577 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -46,9 +46,9 @@ import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createApplierFactory; import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactoryTests.createNoopApplierFactory; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints; -import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.getEisElserAuthorizationResponse; +import static 
org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints; +import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisElserAuthorizationResponse; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; @@ -88,7 +88,7 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(null, threadPool, logger, createNoopApplierFactory()); try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); @@ -110,7 +110,7 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("", threadPool, logger, createNoopApplierFactory()); try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); @@ -155,7 +155,7 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep queueWebServerResponsesForRetries(responseWithInvalidIdField); - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var exception = 
expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT)); @@ -197,7 +197,7 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseData.responseJson())); - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); @@ -253,7 +253,7 @@ public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws I try (var sender = senderFactory.createSender()) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse.responseJson())); - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var authResponse = listener.actionGet(TIMEOUT); @@ -288,8 +288,8 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { createNoopApplierFactory() ); - PlainActionFuture listener = new PlainActionFuture<>(); - ActionListener onlyOnceListener = ActionListener.assertOnce(listener); + PlainActionFuture listener = new PlainActionFuture<>(); + ActionListener onlyOnceListener = ActionListener.assertOnce(listener); var elserResponse = getEisElserAuthorizationResponse(eisGatewayUrl); @@ -328,7 +328,7 @@ public void testGetAuthorization_InvalidResponse() throws IOException { var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("abc", threadPool, logger, createNoopApplierFactory()); try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java similarity index 80% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java index ebdd0ff5686fc..d8655e3e16abb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java @@ -21,7 +21,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; -import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationModel; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; @@ -42,7 +42,8 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; -public class AuthorizationResponseEntityTests 
extends AbstractBWCWireSerializationTestCase { +public class ElasticInferenceServiceAuthorizationResponseEntityTests extends AbstractBWCWireSerializationTestCase< + ElasticInferenceServiceAuthorizationResponseEntity> { // rainbow-sprinkles public static final String RAINBOW_SPRINKLES_ENDPOINT_ID = ".rainbow-sprinkles-elastic"; @@ -71,7 +72,7 @@ public class AuthorizationResponseEntityTests extends AbstractBWCWireSerializati public record EisAuthorizationResponse( String responseJson, - AuthorizationResponseEntity responseEntity, + ElasticInferenceServiceAuthorizationResponseEntity responseEntity, List expectedEndpoints, Set inferenceIds ) {} @@ -266,11 +267,13 @@ public record EisAuthorizationResponse( public static EisAuthorizationResponse getEisElserAuthorizationResponse(String url) { var authorizedEndpoints = List.of(createElserAuthorizedEndpoint()); - var inferenceIds = authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); + var inferenceIds = authorizedEndpoints.stream() + .map(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint::id) + .collect(Collectors.toSet()); return new EisAuthorizationResponse( EIS_ELSER_RESPONSE, - new AuthorizationResponseEntity(authorizedEndpoints), + new ElasticInferenceServiceAuthorizationResponseEntity(authorizedEndpoints), List.of(createElserExpectedEndpoint(url)), inferenceIds ); @@ -289,8 +292,8 @@ private static ElasticInferenceServiceModel createElserExpectedEndpoint(String u ); } - private static AuthorizationResponseEntity.AuthorizedEndpoint createElserAuthorizedEndpoint() { - return new AuthorizationResponseEntity.AuthorizedEndpoint( + private static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createElserAuthorizedEndpoint() { + return new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( ELSER_V2_ENDPOINT_ID, ELSER_V2_MODEL_NAME, createTaskTypeObject(EIS_SPARSE_PATH, "sparse_embedding"), @@ -298,7 
+301,7 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createElserAuthori List.of("english"), "2024-05-01", null, - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( null, null, null, @@ -307,8 +310,11 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createElserAuthori ); } - public static AuthorizationResponseEntity.TaskTypeObject createTaskTypeObject(String eisTaskType, String elasticsearchTaskType) { - return new AuthorizationResponseEntity.TaskTypeObject(eisTaskType, elasticsearchTaskType); + public static ElasticInferenceServiceAuthorizationResponseEntity.TaskTypeObject createTaskTypeObject( + String eisTaskType, + String elasticsearchTaskType + ) { + return new ElasticInferenceServiceAuthorizationResponseEntity.TaskTypeObject(eisTaskType, elasticsearchTaskType); } public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEndpoints(String url) { @@ -318,7 +324,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn createGpLlmV2CompletionAuthorizedEndpoint(), createElserAuthorizedEndpoint(), createJinaEmbedAuthorizedEndpoint(), - new AuthorizationResponseEntity.AuthorizedEndpoint( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( RERANK_V1_ENDPOINT_ID, RERANK_V1_MODEL_NAME, createTaskTypeObject(EIS_RERANK_PATH, "rerank"), @@ -330,11 +336,13 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn ) ); - var inferenceIds = authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); + var inferenceIds = authorizedEndpoints.stream() + .map(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint::id) + .collect(Collectors.toSet()); return new EisAuthorizationResponse( EIS_AUTHORIZATION_RESPONSE_V2, - new AuthorizationResponseEntity(authorizedEndpoints), + new 
ElasticInferenceServiceAuthorizationResponseEntity(authorizedEndpoints), List.of( createRainbowSprinklesExpectedEndpoint(url), createGpLlmV2ChatCompletionExpectedEndpoint(url), @@ -355,8 +363,8 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn ); } - private static AuthorizationResponseEntity.AuthorizedEndpoint createRainbowSprinklesAuthorizedEndpoint() { - return new AuthorizationResponseEntity.AuthorizedEndpoint( + private static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createRainbowSprinklesAuthorizedEndpoint() { + return new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( RAINBOW_SPRINKLES_ENDPOINT_ID, RAINBOW_SPRINKLES_MODEL_NAME, createTaskTypeObject(EIS_CHAT_PATH, "chat_completion"), @@ -392,8 +400,8 @@ private static ElasticInferenceServiceModel createGpLlmV2CompletionExpectedEndpo ); } - private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2ChatCompletionAuthorizedEndpoint() { - return new AuthorizationResponseEntity.AuthorizedEndpoint( + private static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2ChatCompletionAuthorizedEndpoint() { + return new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, GP_LLM_V2_MODEL_NAME, createTaskTypeObject(EIS_CHAT_PATH, "chat_completion"), @@ -405,8 +413,8 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2ChatC ); } - private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2CompletionAuthorizedEndpoint() { - return new AuthorizationResponseEntity.AuthorizedEndpoint( + private static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2CompletionAuthorizedEndpoint() { + return new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( GP_LLM_V2_COMPLETION_ENDPOINT_ID, GP_LLM_V2_MODEL_NAME, createTaskTypeObject(EIS_CHAT_PATH, 
"completion"), @@ -433,11 +441,13 @@ private static ElasticInferenceServiceModel createRainbowSprinklesExpectedEndpoi public static EisAuthorizationResponse getEisRainbowSprinklesAuthorizationResponse(String url) { var authorizedEndpoints = List.of(createRainbowSprinklesAuthorizedEndpoint()); - var inferenceIds = authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); + var inferenceIds = authorizedEndpoints.stream() + .map(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint::id) + .collect(Collectors.toSet()); return new EisAuthorizationResponse( EIS_RAINBOW_SPRINKLES_RESPONSE, - new AuthorizationResponseEntity(authorizedEndpoints), + new ElasticInferenceServiceAuthorizationResponseEntity(authorizedEndpoints), List.of(createRainbowSprinklesExpectedEndpoint(url)), inferenceIds ); @@ -446,18 +456,20 @@ public static EisAuthorizationResponse getEisRainbowSprinklesAuthorizationRespon public static EisAuthorizationResponse getEisJinaEmbedAuthorizationResponse(String url) { var authorizedEndpoints = List.of(createJinaEmbedAuthorizedEndpoint()); - var inferenceIds = authorizedEndpoints.stream().map(AuthorizationResponseEntity.AuthorizedEndpoint::id).collect(Collectors.toSet()); + var inferenceIds = authorizedEndpoints.stream() + .map(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint::id) + .collect(Collectors.toSet()); return new EisAuthorizationResponse( EIS_JINA_EMBED_RESPONSE, - new AuthorizationResponseEntity(authorizedEndpoints), + new ElasticInferenceServiceAuthorizationResponseEntity(authorizedEndpoints), List.of(createJinaExpectedEndpoint(url)), inferenceIds ); } - private static AuthorizationResponseEntity.AuthorizedEndpoint createJinaEmbedAuthorizedEndpoint() { - return new AuthorizationResponseEntity.AuthorizedEndpoint( + private static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createJinaEmbedAuthorizedEndpoint() { + return new 
ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( JINA_EMBED_V3_ENDPOINT_ID, JINA_EMBED_V3_MODEL_NAME, createTaskTypeObject(EIS_EMBED_PATH, "text_embedding"), @@ -465,7 +477,7 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createJinaEmbedAut List.of("multilingual", "open-weights"), "2024-05-01", null, - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( "cosine", 1024, "float", @@ -487,18 +499,18 @@ private static ElasticInferenceServiceModel createJinaExpectedEndpoint(String ur ); } - public static AuthorizationResponseEntity createResponse() { - return new AuthorizationResponseEntity( + public static ElasticInferenceServiceAuthorizationResponseEntity createResponse() { + return new ElasticInferenceServiceAuthorizationResponseEntity( randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) ); } - public static AuthorizationResponseEntity.AuthorizedEndpoint createInvalidTaskTypeAuthorizedEndpoint() { + public static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createInvalidTaskTypeAuthorizedEndpoint() { var id = randomAlphaOfLength(10); var name = randomAlphaOfLength(10); var status = randomFrom("ga", "beta", "preview"); - return new AuthorizationResponseEntity.AuthorizedEndpoint( + return new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id, name, createTaskTypeObject("invalid/task/type", TaskType.ANY.toString()), @@ -510,13 +522,13 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createInvalidTaskTy ); } - public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { + public static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) { var id = randomAlphaOfLength(10); var name = randomAlphaOfLength(10); var status = randomFrom("ga", 
"beta", "preview"); return switch (taskType) { - case CHAT_COMPLETION -> new AuthorizationResponseEntity.AuthorizedEndpoint( + case CHAT_COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id, name, createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), @@ -526,7 +538,7 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd "", null ); - case SPARSE_EMBEDDING -> new AuthorizationResponseEntity.AuthorizedEndpoint( + case SPARSE_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id, name, createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), @@ -536,7 +548,7 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd "", null ); - case TEXT_EMBEDDING -> new AuthorizationResponseEntity.AuthorizedEndpoint( + case TEXT_EMBEDDING -> new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id, name, createTaskTypeObject(EIS_EMBED_PATH, TaskType.TEXT_EMBEDDING.toString()), @@ -544,14 +556,14 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd null, "", "", - new AuthorizationResponseEntity.Configuration( + new ElasticInferenceServiceAuthorizationResponseEntity.Configuration( randomFrom(SimilarityMeasure.values()).toString(), randomInt(), DenseVectorFieldMapper.ElementType.FLOAT.toString(), null ) ); - case RERANK -> new AuthorizationResponseEntity.AuthorizedEndpoint( + case RERANK -> new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id, name, createTaskTypeObject(EIS_RERANK_PATH, TaskType.RERANK.toString()), @@ -561,7 +573,7 @@ public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEnd "", null ); - case COMPLETION -> new AuthorizationResponseEntity.AuthorizedEndpoint( + case COMPLETION -> new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( id, name, createTaskTypeObject(EIS_CHAT_PATH, 
TaskType.COMPLETION.toString()), @@ -580,11 +592,11 @@ public void testParseAllFields() throws IOException { var url = "http://example.com/authorize"; var responseData = getEisAuthorizationResponseWithMultipleEndpoints(url); try (var parser = createParser(JsonXContent.jsonXContent, responseData.responseJson())) { - var entity = AuthorizationResponseEntity.PARSER.apply(parser, null); + var entity = ElasticInferenceServiceAuthorizationResponseEntity.PARSER.apply(parser, null); assertThat(entity, is(responseData.responseEntity())); - var authModel = AuthorizationModel.of(responseData.responseEntity(), url); + var authModel = ElasticInferenceServiceAuthorizationModel.of(responseData.responseEntity(), url); assertThat(authModel.getEndpointIds(), containsInAnyOrder(responseData.inferenceIds().toArray(String[]::new))); assertThat( @@ -607,24 +619,28 @@ public void testParseAllFields() throws IOException { } @Override - protected AuthorizationResponseEntity mutateInstanceForVersion(AuthorizationResponseEntity instance, TransportVersion version) { + protected ElasticInferenceServiceAuthorizationResponseEntity mutateInstanceForVersion( + ElasticInferenceServiceAuthorizationResponseEntity instance, + TransportVersion version + ) { return instance; } @Override - protected Writeable.Reader instanceReader() { - return AuthorizationResponseEntity::new; + protected Writeable.Reader instanceReader() { + return ElasticInferenceServiceAuthorizationResponseEntity::new; } @Override - protected AuthorizationResponseEntity createTestInstance() { + protected ElasticInferenceServiceAuthorizationResponseEntity createTestInstance() { return createResponse(); } @Override - protected AuthorizationResponseEntity mutateInstance(AuthorizationResponseEntity instance) throws IOException { + protected ElasticInferenceServiceAuthorizationResponseEntity mutateInstance(ElasticInferenceServiceAuthorizationResponseEntity instance) + throws IOException { var newEndpoints = new 
ArrayList<>(instance.getAuthorizedEndpoints()); newEndpoints.add(createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))); - return new AuthorizationResponseEntity(newEndpoints); + return new ElasticInferenceServiceAuthorizationResponseEntity(newEndpoints); } } From 4d3613b980728a946769e74e8f2f03d00ed41762 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 2 Dec 2025 17:50:38 -0500 Subject: [PATCH 23/24] Addressing feedback --- .../action/StoreInferenceEndpointsAction.java | 6 +-- ...icInferenceServiceAuthorizationServer.java | 4 +- .../AuthorizationTaskExecutorIT.java | 15 +++--- ...horizationTaskExecutorMultipleNodesIT.java | 7 ++- .../inference/integration/CCMServiceIT.java | 8 ++-- .../TransportGetInferenceServicesAction.java | 1 + .../inference/registry/ModelRegistry.java | 14 +++--- .../authorization/AuthorizationPoller.java | 4 +- ...ticInferenceServiceAuthorizationModel.java | 12 ++--- ...nceServiceAuthorizationResponseEntity.java | 4 +- .../elastic/ElasticInferenceServiceTests.java | 2 +- .../AuthorizationPollerTests.java | 8 ++-- ...ferenceServiceAuthorizationModelTests.java | 34 +------------- ...rviceAuthorizationRequestHandlerTests.java | 4 +- ...rviceAuthorizationResponseEntityTests.java | 46 +++++++++++-------- 15 files changed, 71 insertions(+), 98 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java index 01f2ac02dc284..aa613cda60399 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java @@ -35,9 +35,9 @@ public StoreInferenceEndpointsAction() { } public static class Request extends AcknowledgedRequest { - private final List 
models; + private final List models; - public Request(List models, TimeValue timeout) { + public Request(List models, TimeValue timeout) { super(timeout, DEFAULT_ACK_TIMEOUT); this.models = Objects.requireNonNull(models); } @@ -53,7 +53,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(models); } - public List getModels() { + public List getModels() { return models; } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index a25e82c4a336f..aa37096feea76 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -25,8 +25,8 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule private final MockWebServer webServer = new MockWebServer(); public void enqueueAuthorizeAllModelsResponse() { - var authResponse = getEisAuthorizationResponseWithMultipleEndpoints("ignored"); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authResponse.responseJson())); + var authResponseBody = getEisAuthorizationResponseWithMultipleEndpoints("ignored").responseJson(); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authResponseBody)); } public String getUrl() { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java 
index 431509791255c..6916497014a0c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -66,7 +66,7 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; - private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; + private static String chatCompletionResponseBody; private ModelRegistry modelRegistry; private AuthorizationTaskExecutor authorizationTaskExecutor; @@ -76,7 +76,7 @@ public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE)); - chatCompletionResponse = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl); + chatCompletionResponseBody = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl).responseJson(); } @Before @@ -122,7 +122,7 @@ protected Collection> getPlugins() { public void testCreatesEisChatCompletionEndpoint() throws Exception { assertNoAuthorizedEisEndpoints(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody)); restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); @@ -227,7 +227,7 @@ static void cancelAuthorizationTask(AdminClient adminClient) throws Exception { public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception { assertNoAuthorizedEisEndpoints(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); + 
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody)); restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); @@ -261,7 +261,7 @@ static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesMode public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception { assertNoAuthorizedEisEndpoints(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody)); restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); @@ -273,8 +273,9 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep assertChatCompletionEndpointExists(); // Simulate that a text embedding model is now authorized - var jinaEmbedResponse = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponse.responseJson())); + var jinaEmbedResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl) + .responseJson(); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponseBody)); restartPollingTaskAndWaitForAuthResponse(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index cf3b08fea3bfb..7d8c20870589b 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java +++ 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -55,14 +54,14 @@ public class AuthorizationTaskExecutorMultipleNodesIT extends ESIntegTestCase { private static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; - private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; + private static String chatCompletionResponseBody; @BeforeClass public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE)); - chatCompletionResponse = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl); + chatCompletionResponseBody = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl).responseJson(); } @Before @@ -113,7 +112,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun ); // queue a response that authorizes one model - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody)); assertTrue("expected the node to shutdown properly", internalCluster().stopNode(nodeNameMapping.get(pollerTask.node()))); diff --git 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java index 3a9147f5f4fc1..9153bb51bf1a1 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java @@ -49,7 +49,7 @@ public class CCMServiceIT extends CCMSingleNodeIT { private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; - private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse; + private static String chatCompletionResponseBody; private AuthorizationTaskExecutor authorizationTaskExecutor; private ModelRegistry modelRegistry; @@ -79,9 +79,9 @@ public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); - chatCompletionResponse = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse( + chatCompletionResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse( gatewayUrl - ); + ).responseJson(); } @Before @@ -144,7 +144,7 @@ public void testCreatesEisChatCompletionEndpoint() throws Exception { var eisEndpoints = getEisEndpoints(modelRegistry); assertThat(eisEndpoints, empty()); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson())); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody)); var listener = new TestPlainActionFuture(); ccmService.get().storeConfiguration(new CCMModel(new SecureString("secret".toCharArray())), listener); listener.actionGet(TimeValue.THIRTY_SECONDS); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index 6f7f47abad355..b57de343ff2e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -127,6 +127,7 @@ private void getServiceConfigurationsForServicesAndEis( threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender)); }).>andThen((configurationListener, authorizationModel) -> { var serviceConfigs = getServiceConfigurationsForServices(availableServices); + serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService)); if (authorizationModel.isAuthorized() == false) { configurationListener.onResponse(serviceConfigs); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 62832b3538e7e..cf731b22807e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -712,12 +712,12 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< }), timeout); } - public void storeModels(List models, ActionListener> listener, TimeValue timeout) { + public void storeModels(List models, ActionListener> listener, TimeValue timeout) { storeModels(models, true, listener, timeout); } private void storeModels( - List models, + List models, boolean updateClusterState, ActionListener> listener, TimeValue timeout @@ -745,7 
+745,7 @@ private void storeModels( } private ActionListener getStoreMultipleModelsListener( - List models, + List models, boolean updateClusterState, ActionListener> listener, TimeValue timeout @@ -818,12 +818,12 @@ private ActionListener getStoreMultipleModelsListener( private record StoreResponseWithIndexInfo(ModelStoreResponse modelStoreResponse, boolean modifiedIndex) {} - private record ResponseInfo(List responses, List successfullyStoredModels) {} + private record ResponseInfo(List responses, List successfullyStoredModels) {} private static ResponseInfo getResponseInfo( BulkResponse bulkResponse, Map docIdToInferenceId, - Map inferenceIdToModel + Map inferenceIdToModel ) { var responses = new ArrayList(); var successfullyStoredModels = new ArrayList(); @@ -909,7 +909,7 @@ private static ModelStoreResponse createModelStoreResponse(BulkItemResponse item } } - private static Model getModelFromMap(@Nullable String inferenceId, Map inferenceIdToModel) { + private static Model getModelFromMap(@Nullable String inferenceId, Map inferenceIdToModel) { if (inferenceId != null) { return inferenceIdToModel.get(inferenceId); } @@ -917,7 +917,7 @@ private static Model getModelFromMap(@Nullable String inferenceId, Map models, ActionListener listener, TimeValue timeout) { + private void updateClusterState(List models, ActionListener listener, TimeValue timeout) { var inferenceIdsSet = models.stream().map(Model::getInferenceEntityId).collect(Collectors.toSet()); var storeListener = listener.delegateResponse((delegate, exc) -> { logger.warn(format("Failed to add minimal service settings to cluster state for inference endpoints %s", inferenceIdsSet), exc); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index f8ef0596b4749..2532a64f27bb9 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -314,7 +314,7 @@ private void sendRequest(ActionListener listener) { .addListener(listener); } - private List getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) { + private List getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) { logger.debug("Received authorization response, {}", authModel); var scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); @@ -328,7 +328,7 @@ private List getNewInferenceEndpointsToStore(ElasticInferenceSe return scopedAuthModel.getEndpoints(newEndpointIds); } - private void storePreconfiguredModels(List newEndpoints, ActionListener listener) { + private void storePreconfiguredModels(List newEndpoints, ActionListener listener) { if (newEndpoints.isEmpty()) { listener.onResponse(null); return; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java index dec387036f0b9..395be67d20d4a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; import 
org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; @@ -48,6 +49,7 @@ public class ElasticInferenceServiceAuthorizationModel { private static final String UNKNOWN_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unknown task type [{}], skipping"; private static final String UNSUPPORTED_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unsupported task type [{}], skipping"; + // public because it's used in tests outside the package public static ElasticInferenceServiceAuthorizationModel of( ElasticInferenceServiceAuthorizationResponseEntity responseEntity, String baseEisUrl @@ -214,12 +216,6 @@ private static void validateFieldPresent(String field, Object fieldValue, TaskTy } private static SimilarityMeasure getSimilarityMeasure(ElasticInferenceServiceAuthorizationResponseEntity.Configuration configuration) { - validateFieldPresent( - ElasticInferenceServiceAuthorizationResponseEntity.Configuration.SIMILARITY, - configuration.similarity(), - TaskType.TEXT_EMBEDDING - ); - return SimilarityMeasure.fromString(configuration.similarity()); } @@ -292,8 +288,8 @@ public Set getEndpointIds() { return Set.copyOf(authorizedEndpoints.keySet()); } - public List getEndpoints(Set endpointIds) { - return endpointIds.stream().map(authorizedEndpoints::get).filter(Objects::nonNull).toList(); + public List getEndpoints(Set endpointIds) { + return endpointIds.stream().map(authorizedEndpoints::get).filter(Objects::nonNull).toList(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java index be3dc28ee0517..6004fe29d30a6 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java @@ -301,8 +301,8 @@ public String toString() { private final List authorizedEndpoints; - public ElasticInferenceServiceAuthorizationResponseEntity(List authorizedModels) { - this.authorizedEndpoints = Objects.requireNonNull(authorizedModels); + public ElasticInferenceServiceAuthorizationResponseEntity(List authorizedEndpoints) { + this.authorizedEndpoints = Objects.requireNonNull(authorizedEndpoints); } /** diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index e6f811f4aef65..393a58e7fdc85 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -933,7 +933,7 @@ public void testChunkedInfer_noInputs() throws IOException { } } - public void testHideFromConfigurationApi_ThrowsUnsupported_WithAvailableModels() throws Exception { + public void testHideFromConfigurationApi_ThrowsUnsupported() throws Exception { try (var service = createServiceWithMockSender()) { expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index 
eb0419ccf9f27..ead9702eef587 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -182,7 +182,7 @@ public void testOnlyMarksCompletedOnce() { public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); - when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of()); var url = "eis-url"; var sparseModel = createAuthorizedEndpoint(TaskType.SPARSE_EMBEDDING); @@ -253,7 +253,7 @@ private ElasticInferenceServiceSparseEmbeddingsModel createSparseEndpoint(String public void testSendsAuthorizationRequest_WhenCCMIsNotConfigurable() { var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); - when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of()); var url = "eis-url"; var sparseModel = createAuthorizedEndpoint(TaskType.SPARSE_EMBEDDING); @@ -356,7 +356,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); - when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of()); var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { @@ -398,7 +398,7 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { when(mockRegistry.isReady()).thenReturn(true); // Since the registry is already aware of the sparse endpoint, the authorization poller will not consider it a new // one and not attempt to store it. 
- when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", sparseModel.id())); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of(sparseModel.id())); var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java index 0a2b80d222ca0..aa9ec6260873f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java @@ -171,40 +171,8 @@ public void testIgnoresDuplicateId() { ); assertThat(auth.getEndpoints(Set.of(id1)), is(List.of(chatCompletionEndpoint))); - } - - public void testReturnsAuthorizedTaskTypes_UsesFirstInferenceId_IfDuplicates() { - var id = "id1"; - - var response = new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( - id, - "name1", - createTaskTypeObject(EIS_CHAT_PATH, TaskType.CHAT_COMPLETION.toString()), - "ga", - null, - "", - "", - null - ), - // This should be ignored because the id is a duplicate - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( - id, - "name2", - createTaskTypeObject(EIS_SPARSE_PATH, TaskType.SPARSE_EMBEDDING.toString()), - "ga", - null, - "", - "", - null - ) - ) - ); - - var auth = ElasticInferenceServiceAuthorizationModel.of(response, "url"); assertThat(auth.getTaskTypes(), is(Set.of(TaskType.CHAT_COMPLETION))); - assertThat(auth.getEndpointIds(), is(Set.of(id))); + 
assertThat(auth.getEndpointIds(), is(Set.of(id1))); assertTrue(auth.isAuthorized()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index b8f197dddbca6..3bdac6ba1c132 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -248,10 +248,10 @@ public void testGetAuthorization_ReturnsAValidResponse_WithAuthHeader() throws I createApplierFactory(secret) ); - var elserResponse = getEisElserAuthorizationResponse(eisGatewayUrl); + var elserResponseBody = getEisElserAuthorizationResponse(eisGatewayUrl).responseJson(); try (var sender = senderFactory.createSender()) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse.responseJson())); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponseBody)); PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java index d8655e3e16abb..bd3c6cf58c958 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntityTests.java @@ -324,16 +324,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn createGpLlmV2CompletionAuthorizedEndpoint(), createElserAuthorizedEndpoint(), createJinaEmbedAuthorizedEndpoint(), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( - RERANK_V1_ENDPOINT_ID, - RERANK_V1_MODEL_NAME, - createTaskTypeObject(EIS_RERANK_PATH, "rerank"), - "preview", - List.of(), - "2024-05-01", - null, - null - ) + createRerankV1AuthorizedEndpoint() ); var inferenceIds = authorizedEndpoints.stream() @@ -349,15 +340,7 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn createGpLlmV2CompletionExpectedEndpoint(url), createElserExpectedEndpoint(url), createJinaExpectedEndpoint(url), - new ElasticInferenceServiceRerankModel( - RERANK_V1_ENDPOINT_ID, - TaskType.RERANK, - ElasticInferenceService.NAME, - new ElasticInferenceServiceRerankServiceSettings(RERANK_V1_MODEL_NAME), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url) - ) + createRerankV1ExpectedEndpoint(url) ), inferenceIds ); @@ -499,6 +482,31 @@ private static ElasticInferenceServiceModel createJinaExpectedEndpoint(String ur ); } + private static ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint createRerankV1AuthorizedEndpoint() { + return new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint( + RERANK_V1_ENDPOINT_ID, + RERANK_V1_MODEL_NAME, + createTaskTypeObject(EIS_RERANK_PATH, "rerank"), + "preview", + List.of(), + "2024-05-01", + null, + null + ); + } + + private static ElasticInferenceServiceRerankModel createRerankV1ExpectedEndpoint(String url) { + return new ElasticInferenceServiceRerankModel( + RERANK_V1_ENDPOINT_ID, + TaskType.RERANK, + ElasticInferenceService.NAME, + new 
ElasticInferenceServiceRerankServiceSettings(RERANK_V1_MODEL_NAME), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + } + public static ElasticInferenceServiceAuthorizationResponseEntity createResponse() { return new ElasticInferenceServiceAuthorizationResponseEntity( randomList(1, 5, () -> createAuthorizedEndpoint(randomFrom(ElasticInferenceService.IMPLEMENTED_TASK_TYPES))) From f8c263507e08bcac3735c7b79e17e20cdde00224 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 2 Dec 2025 19:50:27 -0500 Subject: [PATCH 24/24] Refactoring into single if and removing listener --- .../TransportGetInferenceServicesAction.java | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index b57de343ff2e7..e776ba0690613 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -125,26 +125,21 @@ private void getServiceConfigurationsForServicesAndEis( SubscribableListener.newForked(authModelListener -> { // Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender)); - }).>andThen((configurationListener, authorizationModel) -> { + }).andThenApply((authorizationModel) -> { var serviceConfigs = getServiceConfigurationsForServices(availableServices); - serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService)); - - if 
(authorizationModel.isAuthorized() == false) { - configurationListener.onResponse(serviceConfigs); - return; - } - // If there was a requested task type and the authorization response from EIS doesn't support it, we'll exclude EIS as a valid - // service - if (requestedTaskType != null && authorizationModel.getTaskTypes().contains(requestedTaskType) == false) { - configurationListener.onResponse(serviceConfigs); - return; + if (authorizationModel.isAuthorized() == false + // If there was a requested task type and the authorization response from EIS doesn't support it, we'll exclude EIS as a + // valid service + || (requestedTaskType != null && authorizationModel.getTaskTypes().contains(requestedTaskType) == false)) { + serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService)); + return serviceConfigs; } var config = ElasticInferenceService.createConfiguration(authorizationModel.getTaskTypes()); serviceConfigs.add(config); serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService)); - configurationListener.onResponse(serviceConfigs); + return serviceConfigs; }) .addListener( listener.delegateFailureAndWrap(