Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ default void init(Client client) {}
* Whether this service should be hidden from the API. Should be used for services
* that are not ready to be used.
*/
default Boolean hideFromConfigurationApi() {
return Boolean.FALSE;
default boolean hideFromConfigurationApi() {
return false;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@
import org.elasticsearch.xcontent.XContentType;

import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

Expand Down Expand Up @@ -80,14 +78,11 @@ public InferenceServiceConfiguration(StreamInput in) throws IOException {
private static final ConstructingObjectParser<InferenceServiceConfiguration, Void> PARSER = new ConstructingObjectParser<>(
"inference_service_configuration",
true,
args -> {
List<String> taskTypes = (ArrayList<String>) args[2];
return new InferenceServiceConfiguration.Builder().setService((String) args[0])
.setName((String) args[1])
.setTaskTypes(EnumSet.copyOf(taskTypes.stream().map(TaskType::fromString).collect(Collectors.toList())))
.setConfigurations((Map<String, SettingsConfiguration>) args[3])
.build();
}
args -> new InferenceServiceConfiguration.Builder().setService((String) args[0])
.setName((String) args[1])
.setTaskTypes((List<String>) args[2])
.setConfigurations((Map<String, SettingsConfiguration>) args[3])
.build()
);

static {
Expand Down Expand Up @@ -195,6 +190,16 @@ public Builder setTaskTypes(EnumSet<TaskType> taskTypes) {
return this;
}

/**
 * Sets the task types from their string names.
 * <p>
 * Every name is converted with {@code TaskType.fromStringOrStatusException}. The builder's
 * task types are only replaced once all names have been converted, so a failing conversion
 * leaves the previous value untouched. An empty list yields an empty set.
 *
 * @param taskTypes the task type names to convert
 * @return this builder
 */
public Builder setTaskTypes(List<String> taskTypes) {
    EnumSet<TaskType> parsedTaskTypes = EnumSet.noneOf(TaskType.class);
    taskTypes.forEach(taskTypeName -> parsedTaskTypes.add(TaskType.fromStringOrStatusException(taskTypeName)));

    this.taskTypes = parsedTaskTypes;
    return this;
}

public Builder setConfigurations(Map<String, SettingsConfiguration> configurations) {
this.configurations = configurations;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,47 @@ public void testToXContent() throws IOException {
assertToXContentEquivalent(originalBytes, toXContent(parsed, XContentType.JSON, humanReadable), XContentType.JSON);
}

/**
 * Round-trips a service configuration whose {@code task_types} array is empty.
 * <p>
 * The JSON is parsed, serialized again through {@code toShuffledXContent}, parsed a second
 * time, and the two serialized forms are compared with {@code assertToXContentEquivalent}.
 */
public void testToXContent_EmptyTaskTypes() throws IOException {
    final String json = XContentHelper.stripWhitespace("""
{
"service": "some_provider",
"name": "Some Provider",
"task_types": [],
"configurations": {
"text_field_configuration": {
"description": "Wow, this tooltip is useful.",
"label": "Very important field",
"required": true,
"sensitive": true,
"updatable": false,
"type": "str"
},
"numeric_field_configuration": {
"default_value": 3,
"description": "Wow, this tooltip is useful.",
"label": "Very important numeric field",
"required": true,
"sensitive": false,
"updatable": true,
"type": "int"
}
}
}
""");

    final InferenceServiceConfiguration fromJson = InferenceServiceConfiguration.fromXContentBytes(
        new BytesArray(json),
        XContentType.JSON
    );

    final boolean humanReadable = true;
    final BytesReference shuffledBytes = toShuffledXContent(fromJson, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable);

    // Parse the shuffled output again so both sides of the comparison come from the serializer.
    final InferenceServiceConfiguration reparsed;
    try (XContentParser shuffledParser = createParser(XContentType.JSON.xContent(), shuffledBytes)) {
        reparsed = InferenceServiceConfiguration.fromXContent(shuffledParser);
    }

    assertToXContentEquivalent(shuffledBytes, toXContent(reparsed, XContentType.JSON, humanReadable), XContentType.JSON);
}

public void testToMap() {
InferenceServiceConfiguration configField = InferenceServiceConfigurationTestUtils.getRandomServiceConfigurationField();
Map<String, Object> configFieldAsMap = configField.toMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ dependencies {
javaRestTestImplementation project(path: xpackModule('core'))
javaRestTestImplementation project(path: xpackModule('inference'))
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
// Added this to have access to MockWebServer within the tests
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
}

tasks.named("javaRestTest").configure {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public void testAttachToDeployment() throws IOException {

CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
var response = startMlNodeDeploymemnt(modelId, deploymentId);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);

var inferenceId = "inference_on_existing_deployment";
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
Expand Down Expand Up @@ -58,7 +58,7 @@ public void testAttachWithModelId() throws IOException {

CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
var response = startMlNodeDeploymemnt(modelId, deploymentId);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);

var inferenceId = "inference_on_existing_deployment";
var putModel = putModel(inferenceId, endpointConfig(modelId, deploymentId), TaskType.SPARSE_EMBEDDING);
Expand Down Expand Up @@ -93,7 +93,7 @@ public void testModelIdDoesNotMatch() throws IOException {

CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
var response = startMlNodeDeploymemnt(modelId, deploymentId);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);

var inferenceId = "inference_on_existing_deployment";
var e = expectThrows(
Expand Down Expand Up @@ -123,7 +123,7 @@ public void testNumAllocationsIsUpdated() throws IOException {

CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
var response = startMlNodeDeploymemnt(modelId, deploymentId);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);

var inferenceId = "test_num_allocations_updated";
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
Expand All @@ -145,7 +145,7 @@ public void testNumAllocationsIsUpdated() throws IOException {
)
);

assertOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
assertStatusOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));

var updatedServiceSettings = getModel(inferenceId).get("service_settings");
assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ public class InferenceBaseRestTest extends ESRestTestCase {
.user("x_pack_rest_user", "x-pack-test-password")
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
.build();

@ClassRule
public static MlModelServer mlModelServer = new MlModelServer();

Expand Down Expand Up @@ -175,20 +174,20 @@ static String mockDenseServiceModelConfig() {
protected void deleteModel(String modelId) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
}

protected Response deleteModel(String modelId, String queryParams) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId + "?" + queryParams);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
return response;
}

protected void deleteModel(String modelId, TaskType taskType) throws IOException {
var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, modelId));
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
}

protected void putSemanticText(String endpointId, String indexName) throws IOException {
Expand All @@ -207,7 +206,7 @@ protected void putSemanticText(String endpointId, String indexName) throws IOExc
""", endpointId);
request.setJsonEntity(body);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
}

protected void putSemanticText(String endpointId, String searchEndpointId, String indexName) throws IOException {
Expand All @@ -227,7 +226,7 @@ protected void putSemanticText(String endpointId, String searchEndpointId, Strin
""", endpointId, searchEndpointId);
request.setJsonEntity(body);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
}

protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
Expand Down Expand Up @@ -260,7 +259,7 @@ protected Map<String, Object> putPipeline(String pipelineId, String modelId) thr
protected void deletePipeline(String pipelineId) throws IOException {
var request = new Request("DELETE", Strings.format("_ingest/pipeline/%s", pipelineId));
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
}

/**
Expand All @@ -275,15 +274,15 @@ Map<String, Object> putRequest(String endpoint, String body) throws IOException
var request = new Request("PUT", endpoint);
request.setJsonEntity(body);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
return entityAsMap(response);
}

Map<String, Object> postRequest(String endpoint, String body) throws IOException {
var request = new Request("POST", endpoint);
request.setJsonEntity(body);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
return entityAsMap(response);
}

Expand All @@ -300,15 +299,15 @@ protected Map<String, Object> putE5TrainedModels() throws IOException {

request.setJsonEntity(body);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
return entityAsMap(response);
}

protected Map<String, Object> deployE5TrainedModels() throws IOException {
var request = new Request("POST", "_ml/trained_models/.multilingual-e5-small/deployment/_start?wait_for=fully_allocated");

var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
return entityAsMap(response);
}

Expand All @@ -330,31 +329,13 @@ protected List<Map<String, Object>> getAllModels() throws IOException {
return (List<Map<String, Object>>) getInternalAsMap("_inference/_all").get("endpoints");
}

/**
 * Requests {@code _inference/_services} and returns the response body as a list.
 *
 * @return the response as produced by {@code getInternalAsList}
 * @throws IOException if the request cannot be performed or the response cannot be read
 */
protected List<Object> getAllServices() throws IOException {
    // The path is a constant without format specifiers, so it is passed as-is
    // instead of going through a formatting call with no arguments.
    return getInternalAsList("_inference/_services");
}

/**
 * Requests {@code _inference/_services/<taskType>} and returns the response body as a list.
 *
 * @param taskType the task type appended to the services path
 * @return the response as produced by {@code getInternalAsList}
 * @throws IOException if the request cannot be performed or the response cannot be read
 */
protected List<Object> getServices(TaskType taskType) throws IOException {
    // No unchecked cast happens here (the helper already returns List<Object>),
    // so no warning suppression is required on this method.
    return getInternalAsList(Strings.format("_inference/_services/%s", taskType));
}

private Map<String, Object> getInternalAsMap(String endpoint) throws IOException {
var request = new Request("GET", endpoint);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
return entityAsMap(response);
}

/**
 * Performs a GET against {@code endpoint}, checks the response with
 * {@code assertOkOrCreated}, and returns the body converted by {@code entityAsList}.
 *
 * @param endpoint the request path
 * @return the response body as a list
 * @throws IOException if the request cannot be performed or the response cannot be read
 */
private List<Object> getInternalAsList(String endpoint) throws IOException {
    final Response listResponse = client().performRequest(new Request("GET", endpoint));
    assertOkOrCreated(listResponse);
    return entityAsList(listResponse);
}

protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
var endpoint = Strings.format("_inference/%s", modelId);
return inferInternal(endpoint, input, null, Map.of());
Expand Down Expand Up @@ -475,7 +456,7 @@ private Map<String, Object> inferInternal(
) throws IOException {
var request = createInferenceRequest(endpoint, input, query, queryParameters);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
return entityAsMap(response);
}

Expand Down Expand Up @@ -511,7 +492,7 @@ protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int
}
}

protected static void assertOkOrCreated(Response response) throws IOException {
static void assertStatusOkOrCreated(Response response) throws IOException {
int statusCode = response.getStatusLine().getStatusCode();
// Once EntityUtils.toString(entity) is called the entity cannot be reused.
// Avoid that call with check here.
Expand All @@ -527,7 +508,7 @@ protected Map<String, Object> getTrainedModel(String inferenceEntityId) throws I
var endpoint = Strings.format("_ml/trained_models/%s/_stats", inferenceEntityId);
var request = new Request("GET", endpoint);
var response = client().performRequest(request);
assertOkOrCreated(response);
assertStatusOkOrCreated(response);
return entityAsMap(response);
}
}
Loading