diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java
new file mode 100644
index 0000000000000..fbb5e25079f62
--- /dev/null
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java
@@ -0,0 +1,277 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.integration;
+
+import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsAction;
+import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsRequest;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.IndexOptions;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.license.GetLicenseAction;
+import org.elasticsearch.license.License;
+import org.elasticsearch.license.LicenseSettings;
+import org.elasticsearch.license.PostStartBasicAction;
+import org.elasticsearch.license.PostStartBasicRequest;
+import org.elasticsearch.license.PutLicenseAction;
+import org.elasticsearch.license.PutLicenseRequest;
+import org.elasticsearch.license.TestUtils;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.protocol.xpack.license.GetLicenseRequest;
+import org.elasticsearch.reindex.ReindexPlugin;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.test.InternalTestCluster;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
+import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
+import org.elasticsearch.xpack.inference.InferenceIndex;
+import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
+import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
+import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
+import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.CoreMatchers.equalTo;
+
+public class SemanticTextIndexOptionsIT extends ESIntegTestCase {
+    private static final String INDEX_NAME = "test-index";
+    private static final Map<String, Object> BBQ_COMPATIBLE_SERVICE_SETTINGS = Map.of(
+        "model",
+        "my_model",
+        "dimensions",
+        256,
+        "similarity",
+        "cosine",
+        "api_key",
+        "my_api_key"
+    );
+
+    private final Map<String, TaskType> inferenceIds = new HashMap<>();
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
+        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class);
+    }
+
+    @Before
+    public void resetLicense() throws Exception {
+        setLicense(License.LicenseType.TRIAL);
+    }
+
+    @After
+    public void cleanUp() {
+        assertAcked(
+            safeGet(
+                client().admin()
+                    .indices()
+                    .prepareDelete(INDEX_NAME)
+                    .setIndicesOptions(
+                        IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build()
+                    )
+                    .execute()
+            )
+        );
+
+        for (var entry : inferenceIds.entrySet()) {
+            assertAcked(
+                safeGet(
+                    client().execute(
+                        DeleteInferenceEndpointAction.INSTANCE,
+                        new DeleteInferenceEndpointAction.Request(entry.getKey(), entry.getValue(), true, false)
+                    )
+                )
+            );
+        }
+    }
+
+    public void testValidateIndexOptionsWithBasicLicense() throws Exception {
+        final String inferenceId = "test-inference-id-1";
+        final String inferenceFieldName = "inference_field";
+        createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, BBQ_COMPATIBLE_SERVICE_SETTINGS);
+        downgradeLicenseAndRestartCluster();
+
+        IndexOptions indexOptions = new DenseVectorFieldMapper.Int8HnswIndexOptions(
+            randomIntBetween(1, 100),
+            randomIntBetween(1, 10_000),
+            null,
+            null
+        );
+        assertAcked(
+            safeGet(prepareCreate(INDEX_NAME).setMapping(generateMapping(inferenceFieldName, inferenceId, indexOptions)).execute())
+        );
+
+        final Map<String, Object> expectedFieldMapping = generateExpectedFieldMapping(inferenceFieldName, inferenceId, indexOptions);
+        assertThat(getFieldMappings(inferenceFieldName, false), equalTo(expectedFieldMapping));
+    }
+
+    private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map<String, Object> serviceSettings) throws IOException {
+        final String service = switch (taskType) {
+            case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME;
+            case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME;
+            default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
+        };
+
+        final BytesReference content;
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            builder.startObject();
+            builder.field("service", service);
+            builder.field("service_settings", serviceSettings);
+            builder.endObject();
+
+            content = BytesReference.bytes(builder);
+        }
+
+        PutInferenceModelAction.Request request = new PutInferenceModelAction.Request(
+            taskType,
+            inferenceId,
+            content,
+            XContentType.JSON,
+            TEST_REQUEST_TIMEOUT
+        );
+        var responseFuture = client().execute(PutInferenceModelAction.INSTANCE, request);
+        assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
+
+        inferenceIds.put(inferenceId, taskType);
+    }
+
+    private static XContentBuilder generateMapping(String inferenceFieldName, String inferenceId, @Nullable IndexOptions indexOptions)
+        throws IOException {
+        XContentBuilder mapping = XContentFactory.jsonBuilder();
+        mapping.startObject();
+        mapping.field("properties");
+        generateFieldMapping(mapping, inferenceFieldName, inferenceId, indexOptions);
+        mapping.endObject();
+
+        return mapping;
+    }
+
+    private static void generateFieldMapping(
+        XContentBuilder builder,
+        String inferenceFieldName,
+        String inferenceId,
+        @Nullable IndexOptions indexOptions
+    ) throws IOException {
+        builder.startObject();
+        builder.startObject(inferenceFieldName);
+        builder.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
+        builder.field("inference_id", inferenceId);
+        if (indexOptions != null) {
+            builder.startObject("index_options");
+            if (indexOptions instanceof DenseVectorFieldMapper.DenseVectorIndexOptions) {
+                builder.field("dense_vector");
+                indexOptions.toXContent(builder, ToXContent.EMPTY_PARAMS);
+            }
+            builder.endObject();
+        }
+        builder.endObject();
+        builder.endObject();
+    }
+
+    private static Map<String, Object> generateExpectedFieldMapping(
+        String inferenceFieldName,
+        String inferenceId,
+        @Nullable IndexOptions indexOptions
+    ) throws IOException {
+        Map<String, Object> expectedFieldMapping;
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            generateFieldMapping(builder, inferenceFieldName, inferenceId, indexOptions);
+            expectedFieldMapping = XContentHelper.convertToMap(BytesReference.bytes(builder), false, XContentType.JSON).v2();
+        }
+
+        return expectedFieldMapping;
+    }
+
+    @SuppressWarnings("unchecked")
+    private static Map<String, Object> filterNullOrEmptyValues(Map<String, Object> map) {
+        Map<String, Object> filteredMap = new HashMap<>();
+        for (var entry : map.entrySet()) {
+            Object value = entry.getValue();
+            if (entry.getValue() instanceof Map<?, ?> mapValue) {
+                if (mapValue.isEmpty()) {
+                    continue;
+                }
+
+                value = filterNullOrEmptyValues((Map<String, Object>) mapValue);
+            }
+
+            if (value != null) {
+                filteredMap.put(entry.getKey(), value);
+            }
+        }
+
+        return filteredMap;
+    }
+
+    private static Map<String, Object> getFieldMappings(String fieldName, boolean includeDefaults) {
+        var request = new GetFieldMappingsRequest().indices(INDEX_NAME).fields(fieldName).includeDefaults(includeDefaults);
+        return safeGet(client().execute(GetFieldMappingsAction.INSTANCE, request)).fieldMappings(INDEX_NAME, fieldName).sourceAsMap();
+    }
+
+    private static void setLicense(License.LicenseType type) throws Exception {
+        if (type == License.LicenseType.BASIC) {
+            assertAcked(
+                safeGet(
+                    client().execute(
+                        PostStartBasicAction.INSTANCE,
+                        new PostStartBasicRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT).acknowledge(true)
+                    )
+                )
+            );
+        } else {
+            License license = TestUtils.generateSignedLicense(
+                type.getTypeName(),
+                License.VERSION_CURRENT,
+                -1,
+                TimeValue.timeValueHours(24)
+            );
+            assertAcked(
+                safeGet(
+                    client().execute(
+                        PutLicenseAction.INSTANCE,
+                        new PutLicenseRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT).license(license)
+                    )
+                )
+            );
+        }
+    }
+
+    private static void assertLicense(License.LicenseType type) {
+        var getLicenseResponse = safeGet(client().execute(GetLicenseAction.INSTANCE, new GetLicenseRequest(TEST_REQUEST_TIMEOUT)));
+        assertThat(getLicenseResponse.license().type(), equalTo(type.getTypeName()));
+    }
+
+    private void downgradeLicenseAndRestartCluster() throws Exception {
+        // Downgrade the license and restart the cluster to force the model registry to rebuild
+        setLicense(License.LicenseType.BASIC);
+        internalCluster().fullRestart(new InternalTestCluster.RestartCallback());
+        ensureGreen(InferenceIndex.INDEX_NAME);
+        assertLicense(License.LicenseType.BASIC);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 5400bf6acc673..8f111bcb2c785 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -1223,7 +1223,7 @@ static boolean indexVersionDefaultsToBbqHnsw(IndexVersion indexVersion) {
             || indexVersion.between(SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0);
     }
 
-    static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDenseVectorIndexOptions() {
+    public static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDenseVectorIndexOptions() {
         int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
         int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
         DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE);