Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
db5a7b3
Starting new response class
jonathan-buttner Nov 14, 2025
5060aab
Writing tests
jonathan-buttner Nov 17, 2025
6a31be2
Fixing tests
jonathan-buttner Nov 18, 2025
70ef1a5
[CI] Auto commit changes from spotless
Nov 18, 2025
20d7881
Successful tests
jonathan-buttner Nov 18, 2025
4a56082
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 18, 2025
057f15e
Merge branch 'ml-eis-auth-v2' of github.com:jonathan-buttner/elastics…
jonathan-buttner Nov 18, 2025
2546292
Removing unused code
jonathan-buttner Nov 18, 2025
839ebe0
Renaming
jonathan-buttner Nov 18, 2025
e7b2371
[CI] Auto commit changes from spotless
Nov 18, 2025
f6b2d74
Working integration tests
jonathan-buttner Nov 19, 2025
f2a3d1a
Fixing forbidden calls
jonathan-buttner Nov 19, 2025
3b370db
Merge branch 'ml-eis-auth-v2' of github.com:jonathan-buttner/elastics…
jonathan-buttner Nov 19, 2025
58361ea
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 19, 2025
b2aecd2
Fixing tests
jonathan-buttner Nov 19, 2025
3c0e1b3
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 19, 2025
5715af6
Fixing integration tests
jonathan-buttner Nov 19, 2025
7834662
Refactoring tests
jonathan-buttner Nov 19, 2025
cd4388a
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 19, 2025
a540427
Adding some comments
jonathan-buttner Nov 20, 2025
92335ef
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 20, 2025
6656baf
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 24, 2025
e8e3577
Fixing gp llm v2 name
jonathan-buttner Nov 24, 2025
05b0564
Updating test name for rerank
jonathan-buttner Nov 24, 2025
a4543db
Removing named writeable
jonathan-buttner Nov 25, 2025
43a8f31
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 25, 2025
7dd71f0
Removing import
jonathan-buttner Nov 25, 2025
b920069
comments
jonathan-buttner Nov 25, 2025
e4e9095
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 25, 2025
6b6d21b
Adding support for completion
jonathan-buttner Nov 25, 2025
578546f
Fixing tests
jonathan-buttner Nov 25, 2025
6aa0ea5
Addressing feedback
jonathan-buttner Nov 25, 2025
2379066
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Nov 25, 2025
eb6def2
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Dec 2, 2025
4d3613b
Addressing feedback
jonathan-buttner Dec 2, 2025
f8c2635
Refactoring into single if and removing listener
jonathan-buttner Dec 3, 2025
06f1f66
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Dec 3, 2025
7ed5393
Merge branch 'main' into ml-eis-auth-v2
jonathan-buttner Dec 3, 2025
9869cdd
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-a…
jonathan-buttner Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.core.XPackClientPlugin;
Expand Down Expand Up @@ -43,7 +44,7 @@ protected StoreInferenceEndpointsAction.Request createTestInstance() {

@Override
protected StoreInferenceEndpointsAction.Request mutateInstance(StoreInferenceEndpointsAction.Request instance) throws IOException {
var newModels = new ArrayList<>(instance.getModels());
var newModels = new ArrayList<Model>(instance.getModels());
newModels.add(ModelTests.randomModel());
return new StoreInferenceEndpointsAction.Request(newModels, instance.masterNodeTimeout());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ dependencies {
javaRestTestImplementation project(path: xpackModule('core'))
javaRestTestImplementation project(path: xpackModule('inference'))
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')

// Allow javaRestTest to see unit-test classes from x-pack:plugin:inference so we can import some variables
javaRestTestImplementation(testArtifact(project(xpackModule('inference'))))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is so the qa tests can imported from the unit tests in the inference plugin.


// Added this to have access to MockWebServer within the tests
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
Expand All @@ -26,6 +25,7 @@

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModel;
import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID;

public class BaseMockEISAuthServerTest extends ESRestTestCase {

Expand Down Expand Up @@ -93,6 +93,6 @@ public void ensureEisPreconfiguredEndpointsExist() throws Exception {
// available
// Technically this only needs to be done before the suite runs but the underlying client is created in @Before and not statically
// for the suite
assertBusy(() -> getModel(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2));
assertBusy(() -> getModel(ELSER_V2_ENDPOINT_ID));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels;
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.GP_LLM_V2_COMPLETION_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

Expand Down Expand Up @@ -55,12 +61,12 @@ public void testGetDefaultEndpoints() throws IOException {
assertEquals("completion", model.get("task_type"));
}

assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, ".gp-llm-v2-chat_completion", TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, ".gp-llm-v2-completion", TaskType.COMPLETION);
assertInferenceIdTaskType(allModels, ".elser-2-elastic", TaskType.SPARSE_EMBEDDING);
assertInferenceIdTaskType(allModels, ".jina-embeddings-v3", TaskType.TEXT_EMBEDDING);
assertInferenceIdTaskType(allModels, ".jina-reranker-v2", TaskType.RERANK);
assertInferenceIdTaskType(allModels, RAINBOW_SPRINKLES_ENDPOINT_ID, TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, GP_LLM_V2_COMPLETION_ENDPOINT_ID, TaskType.COMPLETION);
assertInferenceIdTaskType(allModels, ELSER_V2_ENDPOINT_ID, TaskType.SPARSE_EMBEDDING);
assertInferenceIdTaskType(allModels, JINA_EMBED_V3_ENDPOINT_ID, TaskType.TEXT_EMBEDDING);
assertInferenceIdTaskType(allModels, RERANK_V1_ENDPOINT_ID, TaskType.RERANK);
}

private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,11 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"streaming_completion_test_service",
"completion_test_service",
"hugging_face",
"elastic",
"amazon_sagemaker",
"mistral",
"nvidia",
"watsonxai"
"watsonxai",
"nvidia"
).toArray()
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,16 @@
import org.junit.runners.model.Statement;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints;

public class MockElasticInferenceServiceAuthorizationServer implements TestRule {

private static final Logger logger = LogManager.getLogger(MockElasticInferenceServiceAuthorizationServer.class);
private final MockWebServer webServer = new MockWebServer();

public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowSprinklesAndElser() {
var server = new MockElasticInferenceServiceAuthorizationServer();

server.enqueueAuthorizeAllModelsResponse();
return server;
}

public void enqueueAuthorizeAllModelsResponse() {
String responseJson = """
{
"models": [
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
},
{
"model_name": "gp-llm-v2",
"task_types": ["chat"]
},
{
"model_name": "elser_model_2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "jina-embeddings-v3",
"task_types": ["embed/text/dense"]
},
{
"model_name": "jina-reranker-v2",
"task_types": ["rerank/text/text-similarity"]
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var authResponseBody = getEisAuthorizationResponseWithMultipleEndpoints("ignored").responseJson();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authResponseBody));
}

public String getUrl() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller;
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings;
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
Expand All @@ -37,38 +37,36 @@
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;

public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]";

public static final String EMPTY_AUTH_RESPONSE = """
{
"models": [
]
}
""";

public static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """
{
"models": [
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
}
]
}
""";
public static final Set<String> EIS_PRECONFIGURED_ENDPOINT_IDS = Set.of(
RAINBOW_SPRINKLES_ENDPOINT_ID,
ELSER_V2_ENDPOINT_ID,
JINA_EMBED_V3_ENDPOINT_ID,
RERANK_V1_ENDPOINT_ID
);

public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]";

private static final MockWebServer webServer = new MockWebServer();
private static String gatewayUrl;
private static String chatCompletionResponseBody;

private ModelRegistry modelRegistry;
private AuthorizationTaskExecutor authorizationTaskExecutor;
Expand All @@ -77,7 +75,8 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
public static void initClass() throws IOException {
webServer.start();
gatewayUrl = getUrl(webServer);
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
chatCompletionResponseBody = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl).responseJson();
}

@Before
Expand All @@ -94,7 +93,7 @@ public void shutdown() {
static void removeEisPreconfiguredEndpoints(ModelRegistry modelRegistry) {
// Delete all the eis preconfigured endpoints
var listener = new PlainActionFuture<Boolean>();
modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, listener);
modelRegistry.deleteModels(EIS_PRECONFIGURED_ENDPOINT_IDS, listener);
listener.actionGet(TimeValue.THIRTY_SECONDS);
}

Expand Down Expand Up @@ -123,7 +122,7 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
public void testCreatesEisChatCompletionEndpoint() throws Exception {
assertNoAuthorizedEisEndpoints();

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
restartPollingTaskAndWaitForAuthResponse();

assertChatCompletionEndpointExists();
Expand All @@ -149,7 +148,7 @@ static void assertNoAuthorizedEisEndpoints(
var eisEndpoints = getEisEndpoints(modelRegistry);
assertThat(eisEndpoints, empty());

for (String eisPreconfiguredEndpoints : InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS) {
for (String eisPreconfiguredEndpoints : EIS_PRECONFIGURED_ENDPOINT_IDS) {
assertFalse(modelRegistry.containsPreconfiguredInferenceEndpointId(eisPreconfiguredEndpoints));
}
}
Expand Down Expand Up @@ -228,13 +227,13 @@ static void cancelAuthorizationTask(AdminClient adminClient) throws Exception {
public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception {
assertNoAuthorizedEisEndpoints();

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
restartPollingTaskAndWaitForAuthResponse();

assertChatCompletionEndpointExists();

// Simulate that the model is no longer authorized
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
restartPollingTaskAndWaitForAuthResponse();

assertChatCompletionEndpointExists();
Expand All @@ -250,55 +249,45 @@ static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) {

var rainbowSprinklesModel = eisEndpoints.get(0);
assertChatCompletionUnparsedModel(rainbowSprinklesModel);
assertTrue(
modelRegistry.containsPreconfiguredInferenceEndpointId(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)
);
assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(RAINBOW_SPRINKLES_ENDPOINT_ID));
}

static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) {
assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION));
assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME));
assertThat(rainbowSprinklesModel.inferenceEntityId(), is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
assertThat(rainbowSprinklesModel.inferenceEntityId(), is(RAINBOW_SPRINKLES_ENDPOINT_ID));
}

public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception {
assertNoAuthorizedEisEndpoints();

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
restartPollingTaskAndWaitForAuthResponse();

assertChatCompletionEndpointExists();

// Simulate that the model is no longer authorized
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
restartPollingTaskAndWaitForAuthResponse();

assertChatCompletionEndpointExists();

// Simulate that a text embedding model is now authorized
var authorizedTextEmbeddingResponse = """
{
"models": [
{
"model_name": "jina-embeddings-v3",
"task_types": ["embed/text/dense"]
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authorizedTextEmbeddingResponse));
var jinaEmbedResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl)
.responseJson();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponseBody));

restartPollingTaskAndWaitForAuthResponse();

var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity()));
assertThat(eisEndpoints.size(), is(2));

assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
assertChatCompletionUnparsedModel(eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
assertTrue(eisEndpoints.containsKey(RAINBOW_SPRINKLES_ENDPOINT_ID));
assertChatCompletionUnparsedModel(eisEndpoints.get(RAINBOW_SPRINKLES_ENDPOINT_ID));

assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID));
assertTrue(eisEndpoints.containsKey(JINA_EMBED_V3_ENDPOINT_ID));

var textEmbeddingEndpoint = eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID);
var textEmbeddingEndpoint = eisEndpoints.get(JINA_EMBED_V3_ENDPOINT_ID);
assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING));
assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME));
}
Expand All @@ -307,7 +296,7 @@ public void testRestartsTaskAfterAbort() throws Exception {
// Ensure the task is created and we get an initial authorization response
assertNoAuthorizedEisEndpoints();

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
// Abort the task and ensure it is restarted
restartPollingTaskAndWaitForAuthResponse();
}
Expand Down
Loading