Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Objects;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION;
import static org.elasticsearch.inference.TaskType.COMPLETION;
import static org.elasticsearch.inference.TaskType.RERANK;
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
Expand Down Expand Up @@ -97,6 +98,10 @@ public static MinimalServiceSettings completion() {
return new MinimalServiceSettings(COMPLETION, null, null, null);
}

public static MinimalServiceSettings chatCompletion() {
return new MinimalServiceSettings(CHAT_COMPLETION, null, null, null);
}

public MinimalServiceSettings(Model model) {
this(
model.getTaskType(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file has been contributed to by a Generative AI
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.FeatureFlag;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.rules.RuleChain;
import org.junit.rules.TestRule;

public class BaseMockEISAuthServerTest extends ESRestTestCase {

// The reason we're retrying is there's a race condition between the node retrieving the
// authorization response and running the test. Retrieving the authorization should be very fast since
// we're hosting a local mock server but it's possible it could respond slower. So in the even of a test failure
// we'll automatically retry after waiting a second.
@Rule
public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1));

private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer
.enabledWithRainbowSprinklesAndElser();

private static final ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.setting("xpack.license.self_generated.type", "trial")
.setting("xpack.security.enabled", "true")
// Adding both settings unless one feature flag is disabled in a particular environment
.setting("xpack.inference.elastic.url", mockEISServer::getUrl)
// TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is gone now. @vidok ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it's being removed in this PR: #120842

.setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl)
// This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin
.plugin("inference-service-test")
.user("x_pack_rest_user", "x-pack-test-password")
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
.build();

// The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating
// it to the cluster as a setting.
@ClassRule
public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster);

@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}

@Override
protected Settings restClientSettings() {
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,20 @@ static String mockDenseServiceModelConfig() {
""";
}

protected void deleteModel(String modelId) throws IOException {
static void deleteModel(String modelId) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId);
var response = client().performRequest(request);
assertStatusOkOrCreated(response);
}

protected Response deleteModel(String modelId, String queryParams) throws IOException {
static Response deleteModel(String modelId, String queryParams) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId + "?" + queryParams);
var response = client().performRequest(request);
assertStatusOkOrCreated(response);
return response;
}

protected void deleteModel(String modelId, TaskType taskType) throws IOException {
static void deleteModel(String modelId, TaskType taskType) throws IOException {
var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, modelId));
var response = client().performRequest(request);
assertStatusOkOrCreated(response);
Expand Down Expand Up @@ -229,12 +229,12 @@ protected void putSemanticText(String endpointId, String searchEndpointId, Strin
assertStatusOkOrCreated(response);
}

protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
static Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
return putRequest(endpoint, modelConfig);
}

protected Map<String, Object> updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException {
static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException {
String endpoint = Strings.format("_inference/%s/%s/_update", taskType, inferenceID);
return putRequest(endpoint, modelConfig);
}
Expand Down Expand Up @@ -265,12 +265,12 @@ protected void deletePipeline(String pipelineId) throws IOException {
/**
* Task type should be in modelConfig
*/
protected Map<String, Object> putModel(String modelId, String modelConfig) throws IOException {
static Map<String, Object> putModel(String modelId, String modelConfig) throws IOException {
String endpoint = Strings.format("_inference/%s", modelId);
return putRequest(endpoint, modelConfig);
}

Map<String, Object> putRequest(String endpoint, String body) throws IOException {
static Map<String, Object> putRequest(String endpoint, String body) throws IOException {
var request = new Request("PUT", endpoint);
request.setJsonEntity(body);
var response = client().performRequest(request);
Expand Down Expand Up @@ -318,18 +318,17 @@ protected Map<String, Object> getModel(String modelId) throws IOException {
}

@SuppressWarnings("unchecked")
protected List<Map<String, Object>> getModels(String modelId, TaskType taskType) throws IOException {
static List<Map<String, Object>> getModels(String modelId, TaskType taskType) throws IOException {
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
return (List<Map<String, Object>>) getInternalAsMap(endpoint).get("endpoints");
}

@SuppressWarnings("unchecked")
protected List<Map<String, Object>> getAllModels() throws IOException {
var endpoint = Strings.format("_inference/_all");
static List<Map<String, Object>> getAllModels() throws IOException {
return (List<Map<String, Object>>) getInternalAsMap("_inference/_all").get("endpoints");
}

private Map<String, Object> getInternalAsMap(String endpoint) throws IOException {
private static Map<String, Object> getInternalAsMap(String endpoint) throws IOException {
var request = new Request("GET", endpoint);
var response = client().performRequest(request);
assertStatusOkOrCreated(response);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file has been contributed to by a Generative AI
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;

import java.io.IOException;

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels;
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
import static org.hamcrest.Matchers.hasSize;

public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {

public void testGetDefaultEndpoints() throws IOException {
var allModels = getAllModels();
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);

if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
assertThat(allModels, hasSize(4));
assertThat(chatCompletionModels, hasSize(1));

for (var model : chatCompletionModels) {
assertEquals("chat_completion", model.get("task_type"));
}
} else {
assertThat(allModels, hasSize(3));
assertThat(chatCompletionModels, hasSize(0));
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,8 @@

import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.FeatureFlag;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.rules.RuleChain;
import org.junit.rules.TestRule;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -35,47 +23,7 @@
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.equalTo;

public class InferenceGetServicesIT extends ESRestTestCase {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this to BaseMockEISAuthServerTest


// The reason we're retrying is there's a race condition between the node retrieving the
// authorization response and running the test. Retrieving the authorization should be very fast since
// we're hosting a local mock server but it's possible it could respond slower. So in the even of a test failure
// we'll automatically retry after waiting a second.
@Rule
public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1));

private static final MockElasticInferenceServiceAuthorizationServer mockEISServer = MockElasticInferenceServiceAuthorizationServer
.enabledWithSparseEmbeddingsAndChatCompletion();

private static final ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.setting("xpack.license.self_generated.type", "trial")
.setting("xpack.security.enabled", "true")
// Adding both settings unless one feature flag is disabled in a particular environment
.setting("xpack.inference.elastic.url", mockEISServer::getUrl)
// TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL
.setting("xpack.inference.eis.gateway.url", mockEISServer::getUrl)
// This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin
.plugin("inference-service-test")
.user("x_pack_rest_user", "x-pack-test-password")
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
.build();

// The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating
// it to the cluster as a setting.
@ClassRule
public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster);

@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}

@Override
protected Settings restClientSettings() {
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule
private static final Logger logger = LogManager.getLogger(MockElasticInferenceServiceAuthorizationServer.class);
private final MockWebServer webServer = new MockWebServer();

public static MockElasticInferenceServiceAuthorizationServer enabledWithSparseEmbeddingsAndChatCompletion() {
public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowSprinklesAndElser() {
var server = new MockElasticInferenceServiceAuthorizationServer();

String responseJson = """
{
"models": [
{
"model_name": "model-a",
"task_types": ["embed/text/sparse", "chat"]
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
},
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ protected void masterOperation(
return;
}

if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) {
listener.onFailure(
new ElasticsearchStatusException(
"[{}] is a reserved inference ID. Cannot create a new inference endpoint with a reserved ID.",
RestStatus.BAD_REQUEST,
request.getInferenceEntityId()
)
);
return;
}

var requestAsMap = requestToMap(request);
var resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ public ModelRegistry(Client client) {
defaultConfigIds = new HashMap<>();
}

/**
* Returns true if the provided inference entity id is the same as one of the default
* endpoints ids.
* @param inferenceEntityId the id to search for
* @return true if we find a match and false if not
*/
public boolean containsDefaultConfigId(String inferenceEntityId) {
return defaultConfigIds.containsKey(inferenceEntityId);
}

/**
* Set the default inference ids provided by the services
* @param defaultConfigId The default
Expand Down
Loading