diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 7119f44d36444..dc30b6680c17c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -224,6 +224,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_TELEMETRY_MEMORY_ADDED = def(8_748_00_0); public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE = def(8_749_00_0); public static final TransportVersion SEMANTIC_TEXT_SEARCH_INFERENCE_ID = def(8_750_00_0); + public static final TransportVersion ML_INFERENCE_CHUNKING_SETTINGS = def(8_751_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java new file mode 100644 index 0000000000000..2e9072626b0a8 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.io.stream.VersionedNamedWriteable; +import org.elasticsearch.xcontent.ToXContentObject; + +public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable { + ChunkingStrategy getChunkingStrategy(); +} diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java b/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java new file mode 100644 index 0000000000000..bb5e0254834a3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.Strings; + +import java.util.EnumSet; + +public enum ChunkingStrategy { + WORD("word"), + SENTENCE("sentence"); + + private final String chunkingStrategy; + + ChunkingStrategy(String strategy) { + this.chunkingStrategy = strategy; + } + + @Override + public String toString() { + return chunkingStrategy; + } + + public static ChunkingStrategy fromString(String strategy) { + return EnumSet.allOf(ChunkingStrategy.class) + .stream() + .filter(cs -> cs.chunkingStrategy.equals(strategy)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException(Strings.format("Invalid chunkingStrategy %s", strategy))); + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java index 6a562d58f1c8a..e5bd5a629a912 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java @@ -29,6 +29,7 @@ public class ModelConfigurations implements ToFilteredXContentObject, VersionedN public static final String SERVICE = "service"; public static final String SERVICE_SETTINGS = "service_settings"; public static final String TASK_SETTINGS = "task_settings"; + public static final String CHUNKING_SETTINGS = "chunking_settings"; private static final String NAME = "inference_model"; public static ModelConfigurations of(Model model, TaskSettings taskSettings) { @@ -40,7 +41,8 @@ public static ModelConfigurations of(Model model, TaskSettings taskSettings) { model.getConfigurations().getTaskType(), model.getConfigurations().getService(), model.getServiceSettings(), - taskSettings + taskSettings, + model.getConfigurations().getChunkingSettings() ); } @@ -53,7 +55,8 @@ public static ModelConfigurations of(Model model, ServiceSettings serviceSetting model.getConfigurations().getTaskType(), 
model.getConfigurations().getService(), serviceSettings, - model.getTaskSettings() + model.getTaskSettings(), + model.getConfigurations().getChunkingSettings() ); } @@ -62,6 +65,7 @@ public static ModelConfigurations of(Model model, ServiceSettings serviceSetting private final String service; private final ServiceSettings serviceSettings; private final TaskSettings taskSettings; + private final ChunkingSettings chunkingSettings; /** * Allows no task settings to be defined. This will default to the {@link EmptyTaskSettings} object. @@ -82,6 +86,23 @@ public ModelConfigurations( this.service = Objects.requireNonNull(service); this.serviceSettings = Objects.requireNonNull(serviceSettings); this.taskSettings = Objects.requireNonNull(taskSettings); + this.chunkingSettings = null; + } + + public ModelConfigurations( + String inferenceEntityId, + TaskType taskType, + String service, + ServiceSettings serviceSettings, + TaskSettings taskSettings, + ChunkingSettings chunkingSettings + ) { + this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); + this.taskType = Objects.requireNonNull(taskType); + this.service = Objects.requireNonNull(service); + this.serviceSettings = Objects.requireNonNull(serviceSettings); + this.taskSettings = Objects.requireNonNull(taskSettings); + this.chunkingSettings = chunkingSettings; } public ModelConfigurations(StreamInput in) throws IOException { @@ -90,6 +111,9 @@ public ModelConfigurations(StreamInput in) throws IOException { this.service = in.readString(); this.serviceSettings = in.readNamedWriteable(ServiceSettings.class); this.taskSettings = in.readNamedWriteable(TaskSettings.class); + this.chunkingSettings = in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS) + ? 
in.readOptionalNamedWriteable(ChunkingSettings.class) + : null; } @Override @@ -99,6 +123,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(service); out.writeNamedWriteable(serviceSettings); out.writeNamedWriteable(taskSettings); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS)) { + out.writeOptionalNamedWriteable(chunkingSettings); + } } public String getInferenceEntityId() { @@ -121,6 +148,10 @@ public TaskSettings getTaskSettings() { return taskSettings; } + public ChunkingSettings getChunkingSettings() { + return chunkingSettings; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -133,6 +164,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(SERVICE, service); builder.field(SERVICE_SETTINGS, serviceSettings); builder.field(TASK_SETTINGS, taskSettings); + if (chunkingSettings != null) { + builder.field(CHUNKING_SETTINGS, chunkingSettings); + } builder.endObject(); return builder; } @@ -149,6 +183,9 @@ public XContentBuilder toFilteredXContent(XContentBuilder builder, Params params builder.field(SERVICE, service); builder.field(SERVICE_SETTINGS, serviceSettings.getFilteredXContentObject()); builder.field(TASK_SETTINGS, taskSettings); + if (chunkingSettings != null) { + builder.field(CHUNKING_SETTINGS, chunkingSettings); + } builder.endObject(); return builder; } diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index d3622575be3a2..cb98f9de31ff5 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -17,7 +17,8 @@ */ public enum FeatureFlag { TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", 
Version.fromString("8.0.0"), null), - FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null); + FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null), + CHUNKING_SETTINGS_ENABLED("es.inference_chunking_settings_feature_flag_enabled=true", Version.fromString("8.16.0"), null); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/ChunkingSettingsFeatureFlag.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/ChunkingSettingsFeatureFlag.java new file mode 100644 index 0000000000000..fae69058df565 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/ChunkingSettingsFeatureFlag.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * chunking_settings feature flag. When the feature is complete, this flag will be removed. 
+ */ +public class ChunkingSettingsFeatureFlag { + + private ChunkingSettingsFeatureFlag() {} + + private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_chunking_settings"); + + public static boolean isEnabled() { + return FEATURE_FLAG.isEnabled(); + } +} diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java index 38c0768196142..ca1dd5a71ea2f 100644 --- a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java @@ -19,7 +19,7 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasSize; @@ -29,6 +29,7 @@ public class OpenAIServiceMixedIT extends BaseMixedTestCase { private static final String OPEN_AI_EMBEDDINGS_ADDED = "8.12.0"; private static final String OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED = "8.13.0"; + private static final String OPEN_AI_EMBEDDINGS_CHUNKING_SETTINGS_ADDED = "8.16.0"; private static final String OPEN_AI_COMPLETIONS_ADDED = "8.14.0"; private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0"; @@ -50,6 +51,7 @@ public static void shutdown() { openAiChatCompletionsServer.close(); } + @AwaitsFix(bugUrl = "Backport #112074 to 8.16") @SuppressWarnings("unchecked") public void testOpenAiEmbeddings() throws IOException { var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED)); @@ -64,7 
+66,23 @@ public void testOpenAiEmbeddings() throws IOException { String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig(); // queue a response as PUT will call the service openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); - put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING); + + try { + put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING); + } catch (Exception e) { + if (getOldClusterTestVersion().before(OPEN_AI_EMBEDDINGS_CHUNKING_SETTINGS_ADDED)) { + // Chunking settings were added in 8.16.0. if the version is before that, an exception will be thrown if the index mapping + // was created based on a mapping from an old node + assertThat( + e.getMessage(), + containsString( + "One or more nodes in your cluster does not support chunking_settings. " + + "Please update all nodes in your cluster to the latest version to use chunking_settings." + ) + ); + return; + } + } var configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints"); assertThat(configs, hasSize(1)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java index 4f1697bc683c1..1c93494d78636 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java @@ -24,9 +24,10 @@ private InferenceIndex() {} public static final String INDEX_NAME = ".inference"; public static final String INDEX_PATTERN = INDEX_NAME + "*"; + public static final String INDEX_ALIAS = ".inference-alias"; // Increment this version number when the mappings change - private static final int INDEX_MAPPING_VERSION = 1; + private static final int INDEX_MAPPING_VERSION = 2; public static Settings settings() { return Settings.builder() @@ -84,6 +85,50 @@ public static XContentBuilder
mappings() { .startObject("properties") .endObject() .endObject() + .startObject("chunking_settings") + .field("dynamic", "false") + .startObject("properties") + .startObject("strategy") + .field("type", "keyword") + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + } catch (IOException e) { + throw new UncheckedIOException("Failed to build mappings for index " + INDEX_NAME, e); + } + } + + public static XContentBuilder mappingsV1() { + try { + return jsonBuilder().startObject() + .startObject(SINGLE_MAPPING_NAME) + .startObject("_meta") + .field(SystemIndexDescriptor.VERSION_META_KEY, 1) + .endObject() + .field("dynamic", "strict") + .startObject("properties") + .startObject("model_id") + .field("type", "keyword") + .endObject() + .startObject("task_type") + .field("type", "keyword") + .endObject() + .startObject("service") + .field("type", "keyword") + .endObject() + .startObject("service_settings") + .field("dynamic", "false") + .startObject("properties") + .endObject() + .endObject() + .startObject("task_settings") + .field("dynamic", "false") + .startObject("properties") + .endObject() + .endObject() .endObject() .endObject() .endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 9b7b9a0802640..336626cd1db20 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import 
org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; @@ -26,6 +27,8 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings; @@ -108,6 +111,8 @@ public static List getNamedWriteables() { // Empty default task settings namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); + addChunkingSettingsNamedWriteables(namedWriteables); + // Empty default secret settings namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new)); @@ -444,6 +449,19 @@ private static void addChunkedInferenceResultsNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ChunkingSettings.class, + SentenceBoundaryChunkingSettings.NAME, + SentenceBoundaryChunkingSettings::new + ) + ); + } + private static void addInferenceResultsNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry(InferenceServiceResults.class, SparseEmbeddingResults.NAME, SparseEmbeddingResults::new) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index a6972ddc214fc..16bd0942c6c26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -257,16 +257,31 @@ public List getNamedWriteables() { @Override public Collection getSystemIndexDescriptors(Settings settings) { + + var inferenceIndexV1Descriptor = SystemIndexDescriptor.builder() + .setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED) + .setIndexPattern(InferenceIndex.INDEX_PATTERN) + .setAliasName(InferenceIndex.INDEX_ALIAS) + .setPrimaryIndex(InferenceIndex.INDEX_NAME) + .setDescription("Contains inference service and model configuration") + .setMappings(InferenceIndex.mappingsV1()) + .setSettings(InferenceIndex.settings()) + .setVersionMetaKey("version") + .setOrigin(ClientHelper.INFERENCE_ORIGIN) + .build(); + return List.of( SystemIndexDescriptor.builder() .setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED) .setIndexPattern(InferenceIndex.INDEX_PATTERN) + .setAliasName(InferenceIndex.INDEX_ALIAS) .setPrimaryIndex(InferenceIndex.INDEX_NAME) .setDescription("Contains inference service and model configuration") .setMappings(InferenceIndex.mappings()) .setSettings(InferenceIndex.settings()) .setVersionMetaKey("version") .setOrigin(ClientHelper.INFERENCE_ORIGIN) + .setPriorSystemIndexDescriptors(List.of(inferenceIndexV1Descriptor)) .build(), SystemIndexDescriptor.builder() .setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index b5d57e7afa6e7..ec54294432fe8 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -24,6 +24,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.mapper.StrictDynamicMappingException; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; @@ -198,7 +199,19 @@ private void parseAndStoreModel( ActionListener storeModelListener = listener.delegateFailureAndWrap( (delegate, verifiedModel) -> modelRegistry.storeModel( verifiedModel, - delegate.delegateFailureAndWrap((l, r) -> putAndStartModel(service, verifiedModel, l)) + ActionListener.wrap(r -> putAndStartModel(service, verifiedModel, delegate), e -> { + if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) { + delegate.onFailure( + new ElasticsearchStatusException( + "One or more nodes in your cluster does not support chunking_settings. " + + "Please update all nodes in your cluster to the latest version to use chunking_settings.", + RestStatus.BAD_REQUEST + ) + ); + } else { + delegate.onFailure(e); + } + }) ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java new file mode 100644 index 0000000000000..af7c706c807ec --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.inference.ChunkingSettings; + +import java.util.List; + +public interface Chunker { + List chunk(String input, ChunkingSettings chunkingSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java new file mode 100644 index 0000000000000..830f1579348f6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.inference.ChunkingStrategy; + +public class ChunkerBuilder { + public static Chunker fromChunkingStrategy(ChunkingStrategy chunkingStrategy) { + if (chunkingStrategy == null) { + return new WordBoundaryChunker(); + } + + return switch (chunkingStrategy) { + case WORD -> new WordBoundaryChunker(); + case SENTENCE -> new SentenceBoundaryChunker(); + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java new file mode 100644 index 0000000000000..477c3ea6352f5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; + +import java.util.Map; + +public class ChunkingSettingsBuilder { + public static final WordBoundaryChunkingSettings DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); + + public static ChunkingSettings fromMap(Map settings) { + if (settings.isEmpty()) { + return DEFAULT_SETTINGS; + } + if (settings.containsKey(ChunkingSettingsOptions.STRATEGY.toString()) == false) { + throw new IllegalArgumentException("Can't generate Chunker without ChunkingStrategy provided"); + } + + ChunkingStrategy chunkingStrategy = ChunkingStrategy.fromString( + settings.get(ChunkingSettingsOptions.STRATEGY.toString()).toString() + ); + return switch (chunkingStrategy) { + case WORD -> WordBoundaryChunkingSettings.fromMap(settings); + case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(settings); + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java new file mode 100644 index 0000000000000..a85b92dd1a055 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +public enum ChunkingSettingsOptions { + STRATEGY("strategy"), + MAX_CHUNK_SIZE("max_chunk_size"), + OVERLAP("overlap"); + + private final String chunkingSettingsOption; + + ChunkingSettingsOptions(String chunkingSettingsOption) { + this.chunkingSettingsOption = chunkingSettingsOption; + } + + @Override + public String toString() { + return chunkingSettingsOption; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 7587dbf8ca95b..81ebebdb47e4f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; @@ -23,6 +24,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; import java.util.stream.Collectors; /** @@ -60,6 +62,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El private final int wordsPerChunk; private final int chunkOverlap; private final EmbeddingType embeddingType; + private final ChunkingSettings chunkingSettings; private List> chunkedInputs; private List>> floatResults; @@ -82,11 +85,34 @@ public EmbeddingRequestChunker( this.wordsPerChunk = 
wordsPerChunk; this.chunkOverlap = chunkOverlap; this.embeddingType = embeddingType; + this.chunkingSettings = null; + splitIntoBatchedRequests(inputs); + } + + public EmbeddingRequestChunker( + List inputs, + int maxNumberOfInputsPerBatch, + EmbeddingType embeddingType, + ChunkingSettings chunkingSettings + ) { + this.maxNumberOfInputsPerBatch = maxNumberOfInputsPerBatch; + this.wordsPerChunk = DEFAULT_WORDS_PER_CHUNK; // Can be removed after ChunkingConfigurationFeatureFlag is enabled + this.chunkOverlap = DEFAULT_CHUNK_OVERLAP; // Can be removed after ChunkingConfigurationFeatureFlag is enabled + this.embeddingType = embeddingType; + this.chunkingSettings = chunkingSettings; splitIntoBatchedRequests(inputs); } private void splitIntoBatchedRequests(List inputs) { - var chunker = new WordBoundaryChunker(); + Function> chunkFunction; + if (chunkingSettings != null) { + var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); + chunkFunction = input -> chunker.chunk(input, chunkingSettings); + } else { + var chunker = new WordBoundaryChunker(); + chunkFunction = input -> chunker.chunk(input, wordsPerChunk, chunkOverlap); + } + chunkedInputs = new ArrayList<>(inputs.size()); switch (embeddingType) { case FLOAT -> floatResults = new ArrayList<>(inputs.size()); @@ -95,7 +121,7 @@ private void splitIntoBatchedRequests(List inputs) { errors = new AtomicArray<>(inputs.size()); for (int i = 0; i < inputs.size(); i++) { - var chunks = chunker.chunk(inputs.get(i), wordsPerChunk, chunkOverlap); + var chunks = chunkFunction.apply(inputs.get(i)); int numberOfSubBatches = addToBatches(chunks, i); // size the results array with the expected number of request/responses switch (embeddingType) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java index 
258a127dac8ab..3a53ecc7ae958 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java @@ -9,6 +9,9 @@ import com.ibm.icu.text.BreakIterator; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.ChunkingSettings; + import java.util.ArrayList; import java.util.List; import java.util.Locale; @@ -23,7 +26,7 @@ * {@code maxNumberWordsPerChunk} it is split on word boundary with * overlap. */ -public class SentenceBoundaryChunker { +public class SentenceBoundaryChunker implements Chunker { private final BreakIterator sentenceIterator; private final BreakIterator wordIterator; @@ -33,6 +36,27 @@ public SentenceBoundaryChunker() { wordIterator = BreakIterator.getWordInstance(Locale.ROOT); } + /** + * Break the input text into small chunks on sentence boundaries. + * + * @param input Text to chunk + * @param chunkingSettings Chunking settings that define maxNumberWordsPerChunk + * @return The input text chunked + */ + @Override + public List chunk(String input, ChunkingSettings chunkingSettings) { + if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) { + return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize); + } else { + throw new IllegalArgumentException( + Strings.format( + "SentenceBoundaryChunker can't use ChunkingSettings with strategy [%s]", + chunkingSettings.getChunkingStrategy() + ) + ); + } + } + /** * Break the input text into small chunks on sentence boundaries. 
* diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java new file mode 100644 index 0000000000000..0d1903895f615 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class SentenceBoundaryChunkingSettings implements ChunkingSettings { + public static final String NAME = "SentenceBoundaryChunkingSettings"; + private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE; + private static final Set VALID_KEYS = Set.of( + ChunkingSettingsOptions.STRATEGY.toString(), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString() + ); + protected final int maxChunkSize; + + public SentenceBoundaryChunkingSettings(Integer maxChunkSize) { + this.maxChunkSize 
= maxChunkSize; + } + + public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { + maxChunkSize = in.readInt(); + } + + public static SentenceBoundaryChunkingSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray(); + if (invalidSettings.length > 0) { + validationException.addValidationError( + Strings.format("Sentence based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings)) + ); + } + + Integer maxChunkSize = ServiceUtils.extractRequiredPositiveInteger( + map, + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + ModelConfigurations.CHUNKING_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new SentenceBoundaryChunkingSettings(maxChunkSize); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + { + builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY); + builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(maxChunkSize); + } + + @Override + public ChunkingStrategy getChunkingStrategy() { + return STRATEGY; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SentenceBoundaryChunkingSettings that = (SentenceBoundaryChunkingSettings) o; + return Objects.equals(maxChunkSize, that.maxChunkSize); + } + + 
@Override + public int hashCode() { + return Objects.hash(maxChunkSize); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java index 4233f917f8f80..c9c752b9aabbc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java @@ -9,6 +9,9 @@ import com.ibm.icu.text.BreakIterator; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.ChunkingSettings; + import java.util.ArrayList; import java.util.List; import java.util.Locale; @@ -24,7 +27,7 @@ * complexity of tracking the start positions of multiple * chunks within the chunk. */ -public class WordBoundaryChunker { +public class WordBoundaryChunker implements Chunker { private BreakIterator wordIterator; @@ -34,6 +37,24 @@ public WordBoundaryChunker() { record ChunkPosition(int start, int end, int wordCount) {} + /** + * Break the input text into small chunks as dictated + * by the chunking parameters + * @param input Text to chunk + * @param chunkingSettings The chunking settings that configure chunkSize and overlap + * @return List of chunked text + */ + @Override + public List chunk(String input, ChunkingSettings chunkingSettings) { + if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) { + return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap); + } else { + throw new IllegalArgumentException( + Strings.format("WordBoundaryChunker can't use ChunkingSettings with strategy [%s]", chunkingSettings.getChunkingStrategy()) + ); + } + } + /** * Break the input text into small chunks as dictated * by the chunking parameters diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java new file mode 100644 index 0000000000000..6517e0eea14d9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class WordBoundaryChunkingSettings implements ChunkingSettings { + public static final String NAME = "WordBoundaryChunkingSettings"; + private static final ChunkingStrategy STRATEGY = ChunkingStrategy.WORD; + private static final Set VALID_KEYS = Set.of( + ChunkingSettingsOptions.STRATEGY.toString(), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + ChunkingSettingsOptions.OVERLAP.toString() + ); + protected final int maxChunkSize; + protected final int overlap; + + public WordBoundaryChunkingSettings(Integer 
maxChunkSize, Integer overlap) { + this.maxChunkSize = maxChunkSize; + this.overlap = overlap; + } + + public WordBoundaryChunkingSettings(StreamInput in) throws IOException { + maxChunkSize = in.readInt(); + overlap = in.readInt(); + } + + public static WordBoundaryChunkingSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray(); + if (invalidSettings.length > 0) { + validationException.addValidationError( + Strings.format("Sentence based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings)) + ); + } + + Integer maxChunkSize = ServiceUtils.extractRequiredPositiveInteger( + map, + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + ModelConfigurations.CHUNKING_SETTINGS, + validationException + ); + + Integer overlap = null; + if (maxChunkSize != null) { + overlap = ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax( + map, + ChunkingSettingsOptions.OVERLAP.toString(), + maxChunkSize / 2, + ModelConfigurations.CHUNKING_SETTINGS, + validationException + ); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new WordBoundaryChunkingSettings(maxChunkSize, overlap); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + { + builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY); + builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); + builder.field(ChunkingSettingsOptions.OVERLAP.toString(), overlap); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS; + } + + @Override + public void 
writeTo(StreamOutput out) throws IOException { + out.writeInt(maxChunkSize); + out.writeInt(overlap); + } + + @Override + public ChunkingStrategy getChunkingStrategy() { + return STRATEGY; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WordBoundaryChunkingSettings that = (WordBoundaryChunkingSettings) o; + return Objects.equals(maxChunkSize, that.maxChunkSize) && Objects.equals(overlap, that.overlap); + } + + @Override + public int hashCode() { + return Objects.hash(maxChunkSize, overlap); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index c75ded629605f..6c4904f8918a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -410,6 +410,24 @@ public static Integer extractRequiredPositiveInteger( return field; } + public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax( + Map map, + String settingName, + int maxValue, + String scope, + ValidationException validationException + ) { + Integer field = extractRequiredPositiveInteger(map, settingName, scope, validationException); + + if (field != null && field > maxValue) { + validationException.addValidationError( + ServiceUtils.mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, field, maxValue) + ); + } + + return field; + } + public static Integer extractOptionalPositiveInteger( Map map, String settingName, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 
cee3ccf676c4c..7cea1ec7df46c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -24,6 +25,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -73,6 +76,13 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + moveModelFromTaskToServiceSettings(taskSettingsMap, serviceSettingsMap); OpenAiModel model = createModel( @@ -80,6 +90,7 @@ public void parseRequestConfig( taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, 
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -100,6 +111,7 @@ private static OpenAiModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage ) { @@ -108,6 +120,7 @@ private static OpenAiModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT @@ -119,6 +132,7 @@ private static OpenAiModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -130,6 +144,7 @@ private static OpenAiModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -157,6 +172,11 @@ public OpenAiModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + moveModelFromTaskToServiceSettings(taskSettingsMap, serviceSettingsMap); return createModelFromPersistent( @@ -164,6 +184,7 @@ public OpenAiModel parsePersistedConfigWithSecrets( taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -174,6 +195,11 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = 
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + moveModelFromTaskToServiceSettings(taskSettingsMap, serviceSettingsMap); return createModelFromPersistent( @@ -181,6 +207,7 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -225,11 +252,22 @@ protected void doChunkedInfer( OpenAiModel openAiModel = (OpenAiModel) model; var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + openAiModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java index 18a1d8a5b658f..5659c46050ad8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.openai.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -36,6 +37,7 @@ public OpenAiEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -45,6 +47,7 @@ public OpenAiEmbeddingsModel( service, OpenAiEmbeddingsServiceSettings.fromMap(serviceSettings, context), OpenAiEmbeddingsTaskSettings.fromMap(taskSettings, context), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -56,10 +59,11 @@ public OpenAiEmbeddingsModel( String service, OpenAiEmbeddingsServiceSettings serviceSettings, OpenAiEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secrets), serviceSettings, secrets diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java index 5afae297b3592..5a1922fd200f5 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java @@ -14,6 +14,8 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; @@ -26,7 +28,8 @@ public static ModelConfigurations createRandomInstance() { taskType, randomAlphaOfLength(6), randomServiceSettings(), - randomTaskSettings(taskType) + randomTaskSettings(taskType), + ChunkingSettingsFeatureFlag.isEnabled() && randomBoolean() ? ChunkingSettingsTests.createRandomChunkingSettings() : null ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index fb841bd6953cb..5abb9000f4d04 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -173,6 +173,19 @@ public static SimilarityMeasure randomSimilarityMeasure() { public record PersistedConfig(Map config, Map secrets) {} + public static PersistedConfig getPersistedConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + + var persistedConfigMap = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + persistedConfigMap.config.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return persistedConfigMap; + } + public static PersistedConfig 
getPersistedConfigMap( Map serviceSettings, Map taskSettings, @@ -197,6 +210,18 @@ public static PersistedConfig getPersistedConfigMap(Map serviceS ); } + public static Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + public static Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java new file mode 100644 index 0000000000000..d2aea45d4603c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.test.ESTestCase; + +import java.util.Map; + +import static org.hamcrest.Matchers.instanceOf; + +public class ChunkerBuilderTests extends ESTestCase { + + public void testNullChunkingStrategy() { + assertThat(ChunkerBuilder.fromChunkingStrategy(null), instanceOf(WordBoundaryChunker.class)); + } + + public void testValidChunkingStrategy() { + chunkingStrategyToExpectedChunkerClassMap().forEach((chunkingStrategy, chunkerClass) -> { + assertThat(ChunkerBuilder.fromChunkingStrategy(chunkingStrategy), instanceOf(chunkerClass)); + }); + } + + private Map> chunkingStrategyToExpectedChunkerClassMap() { + return Map.of(ChunkingStrategy.WORD, WordBoundaryChunker.class, ChunkingStrategy.SENTENCE, SentenceBoundaryChunker.class); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java new file mode 100644 index 0000000000000..061ea677e6fe1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.test.ESTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class ChunkingSettingsBuilderTests extends ESTestCase { + + public static final WordBoundaryChunkingSettings DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); + + public void testEmptyChunkingSettingsMap() { + ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(Collections.emptyMap()); + + assertEquals(DEFAULT_SETTINGS, chunkingSettings); + } + + public void testChunkingStrategyNotProvided() { + Map settings = Map.of(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), randomNonNegativeInt()); + + assertThrows(IllegalArgumentException.class, () -> { ChunkingSettingsBuilder.fromMap(settings); }); + } + + public void testValidChunkingSettingsMap() { + chunkingSettingsMapToChunkingSettings().forEach((chunkingSettingsMap, chunkingSettings) -> { + assertEquals(chunkingSettings, ChunkingSettingsBuilder.fromMap(new HashMap<>(chunkingSettingsMap))); + }); + } + + private Map, ChunkingSettings> chunkingSettingsMapToChunkingSettings() { + var maxChunkSize = randomNonNegativeInt(); + var overlap = randomIntBetween(1, maxChunkSize / 2); + return Map.of( + Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + ChunkingStrategy.WORD.toString(), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.OVERLAP.toString(), + overlap + ), + new WordBoundaryChunkingSettings(maxChunkSize, overlap), + Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + ChunkingStrategy.SENTENCE.toString(), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize + ), + new SentenceBoundaryChunkingSettings(maxChunkSize) + ); + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java new file mode 100644 index 0000000000000..2482586c75595 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.test.ESTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class ChunkingSettingsTests extends ESTestCase { + + public static ChunkingSettings createRandomChunkingSettings() { + ChunkingStrategy randomStrategy = randomFrom(ChunkingStrategy.values()); + + switch (randomStrategy) { + case WORD -> { + var maxChunkSize = randomNonNegativeInt(); + return new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2)); + } + case SENTENCE -> { + return new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); + } + default -> throw new IllegalArgumentException("Unsupported random strategy [" + randomStrategy + "]"); + } + } + + public static Map createRandomChunkingSettingsMap() { + ChunkingStrategy randomStrategy = randomFrom(ChunkingStrategy.values()); + Map chunkingSettingsMap = new HashMap<>(); + chunkingSettingsMap.put(ChunkingSettingsOptions.STRATEGY.toString(), randomStrategy.toString()); + + switch (randomStrategy) { + case WORD -> { + var maxChunkSize = randomNonNegativeInt(); + chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), 
maxChunkSize); + chunkingSettingsMap.put(ChunkingSettingsOptions.OVERLAP.toString(), randomIntBetween(1, maxChunkSize / 2)); + + } + case SENTENCE -> { + chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), randomNonNegativeInt()); + } + default -> { + } + } + return chunkingSettingsMap; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategyTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategyTests.java new file mode 100644 index 0000000000000..802cea5986b30 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategyTests.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.test.ESTestCase; + +public class ChunkingStrategyTests extends ESTestCase { + + public void testValidChunkingStrategy() { + ChunkingStrategy expected = randomFrom(ChunkingStrategy.values()); + + assertEquals(expected, ChunkingStrategy.fromString(expected.toString())); + } + + public void testInvalidChunkingStrategy() { + assertThrows(IllegalArgumentException.class, () -> ChunkingStrategy.fromString(randomAlphaOfLength(10))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index cb89846b197fc..cf862ee6fb4b8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -94,6 +94,50 @@ public void testManyInputsMakeManyBatches() { } } + public void testChunkingSettingsProvided() { + int maxNumInputsPerBatch = 10; + int numInputs = maxNumInputsPerBatch * 3 + 1; // requires 4 batches + var inputs = new ArrayList(); + + for (int i = 0; i < numInputs; i++) { + inputs.add("input " + i); + } + var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); + + var batches = new EmbeddingRequestChunker( + inputs, + maxNumInputsPerBatch, + embeddingType, + ChunkingSettingsTests.createRandomChunkingSettings() + ).batchRequestsWithListeners(testListener()); + assertThat(batches, hasSize(4)); + assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch)); + 
assertThat(batches.get(3).batch().inputs(), hasSize(1)); + + assertEquals("input 0", batches.get(0).batch().inputs().get(0)); + assertEquals("input 9", batches.get(0).batch().inputs().get(9)); + assertThat( + batches.get(1).batch().inputs(), + contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") + ); + assertEquals("input 20", batches.get(2).batch().inputs().get(0)); + assertEquals("input 29", batches.get(2).batch().inputs().get(9)); + assertThat(batches.get(3).batch().inputs(), contains("input 30")); + + int inputIndex = 0; + var subBatches = batches.get(0).batch().subBatches(); + for (int i = 0; i < batches.size(); i++) { + var subBatch = subBatches.get(i); + assertThat(subBatch.requests(), contains(inputs.get(i))); + assertEquals(0, subBatch.positions().chunkIndex()); + assertEquals(inputIndex, subBatch.positions().inputIndex()); + assertEquals(1, subBatch.positions().embeddingCount()); + inputIndex++; + } + } + public void testLongInputChunkedOverMultipleBatches() { int batchSize = 5; int chunkSize = 20; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java index 5bf282a07067a..335752faa6b22 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java @@ -9,6 +9,7 @@ import com.ibm.icu.text.BreakIterator; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; @@ -144,6 +145,29 @@ public void testCountWords_WithSymbols() { } } + public void testChunkSplitLargeChunkSizesWithChunkingSettings() { + for (int maxWordsPerChunk : new int[] { 100, 200 }) 
{ + var chunker = new SentenceBoundaryChunker(); + SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk); + var chunks = chunker.chunk(TEST_TEXT, chunkingSettings); + + int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); + assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks)); + + for (var chunk : chunks) { + assertTrue(Character.isUpperCase(chunk.charAt(0))); + var trailingWhiteSpaceRemoved = chunk.strip(); + var lastChar = trailingWhiteSpaceRemoved.charAt(trailingWhiteSpaceRemoved.length() - 1); + assertThat(lastChar, Matchers.is('.')); + } + } + } + + public void testInvalidChunkingSettingsProvided() { + ChunkingSettings chunkingSettings = new WordBoundaryChunkingSettings(randomNonNegativeInt(), randomNonNegativeInt()); + assertThrows(IllegalArgumentException.class, () -> { new SentenceBoundaryChunker().chunk(TEST_TEXT, chunkingSettings); }); + } + private int[] sentenceSizes(String text) { var sentences = text.split("\\.\\s+"); var lengths = new int[sentences.length]; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java new file mode 100644 index 0000000000000..3f304a593144b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +public class SentenceBoundaryChunkingSettingsTests extends AbstractWireSerializingTestCase { + + public void testMaxChunkSizeNotProvided() { + assertThrows( + ValidationException.class, + () -> { SentenceBoundaryChunkingSettings.fromMap(buildChunkingSettingsMap(Optional.empty())); } + ); + } + + public void testInvalidInputsProvided() { + var chunkingSettingsMap = buildChunkingSettingsMap(Optional.of(randomNonNegativeInt())); + chunkingSettingsMap.put(randomAlphaOfLength(10), randomNonNegativeInt()); + + assertThrows(ValidationException.class, () -> { SentenceBoundaryChunkingSettings.fromMap(chunkingSettingsMap); }); + } + + public void testValidInputsProvided() { + int maxChunkSize = randomNonNegativeInt(); + SentenceBoundaryChunkingSettings settings = SentenceBoundaryChunkingSettings.fromMap( + buildChunkingSettingsMap(Optional.of(maxChunkSize)) + ); + + assertEquals(settings.getChunkingStrategy(), ChunkingStrategy.SENTENCE); + assertEquals(settings.maxChunkSize, maxChunkSize); + } + + public Map buildChunkingSettingsMap(Optional maxChunkSize) { + Map settingsMap = new HashMap<>(); + settingsMap.put(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.SENTENCE.toString()); + maxChunkSize.ifPresent(maxChunkSizeValue -> settingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSizeValue)); + + return settingsMap; + } + + @Override + protected Writeable.Reader instanceReader() { + return SentenceBoundaryChunkingSettings::new; + } + + @Override + protected SentenceBoundaryChunkingSettings 
createTestInstance() { + return new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); + } + + @Override + protected SentenceBoundaryChunkingSettings mutateInstance(SentenceBoundaryChunkingSettings instance) throws IOException { + var chunkSize = randomValueOtherThan(instance.maxChunkSize, ESTestCase::randomNonNegativeInt); + + return new SentenceBoundaryChunkingSettings(chunkSize); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java index 864d01507ca35..21d8c65ad7dcd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java @@ -9,6 +9,7 @@ import com.ibm.icu.text.BreakIterator; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.test.ESTestCase; import java.util.List; @@ -53,6 +54,9 @@ public class WordBoundaryChunkerTests extends ESTestCase { + " خليفہ المومنين يا خليفہ المسلمين يا صحابی يا رضي الله عنه چئي۔ (ب) آنحضور ﷺ جي گھروارين کان علاوه ڪنھن کي ام المومنين " + "چئي۔ (ج) آنحضور ﷺ جي خاندان جي اھل بيت کان علاوہڍه ڪنھن کي اھل بيت چئي۔ (د) پنھنجي عبادت گاھ کي مسجد چئي۔" }; + private static final int DEFAULT_MAX_CHUNK_SIZE = 250; + private static final int DEFAULT_OVERLAP = 100; + public static int NUM_WORDS_IN_TEST_TEXT; static { var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); @@ -104,6 +108,41 @@ public void testNumberOfChunks() { } } + public void testNumberOfChunksWithWordBoundaryChunkingSettings() { + for (int numWords : new int[] { 10, 22, 50, 73, 100 }) { + var sb = new StringBuilder(); + for (int i = 0; i < numWords; i++) { + sb.append(i).append(' '); + } + var whiteSpacedText = sb.toString(); + 
assertExpectedNumberOfChunksWithWordBoundaryChunkingSettings( + whiteSpacedText, + numWords, + new WordBoundaryChunkingSettings(10, 4) + ); + assertExpectedNumberOfChunksWithWordBoundaryChunkingSettings( + whiteSpacedText, + numWords, + new WordBoundaryChunkingSettings(10, 2) + ); + assertExpectedNumberOfChunksWithWordBoundaryChunkingSettings( + whiteSpacedText, + numWords, + new WordBoundaryChunkingSettings(20, 4) + ); + assertExpectedNumberOfChunksWithWordBoundaryChunkingSettings( + whiteSpacedText, + numWords, + new WordBoundaryChunkingSettings(20, 10) + ); + } + } + + public void testInvalidChunkingSettingsProvided() { + ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); + assertThrows(IllegalArgumentException.class, () -> { new WordBoundaryChunker().chunk(TEST_TEXT, chunkingSettings); }); + } + public void testWindowSpanningWithOverlapNumWordsInOverlapSection() { int chunkSize = 10; int windowSize = 3; @@ -206,6 +245,16 @@ public void testPunctuation() { assertThat(chunks, contains("Won't you chunk")); } + private void assertExpectedNumberOfChunksWithWordBoundaryChunkingSettings( + String input, + int numWords, + WordBoundaryChunkingSettings chunkingSettings + ) { + var chunks = new WordBoundaryChunker().chunk(input, chunkingSettings); + int expected = expectedNumberOfChunks(numWords, chunkingSettings.maxChunkSize, chunkingSettings.overlap); + assertEquals(expected, chunks.size()); + } + private void assertExpectedNumberOfChunks(String input, int numWords, int windowSize, int overlap) { var chunks = new WordBoundaryChunker().chunk(input, windowSize, overlap); int expected = expectedNumberOfChunks(numWords, windowSize, overlap); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java new file mode 100644 index 
0000000000000..c5515f7bf0512
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.chunking;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.ChunkingStrategy;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+public class WordBoundaryChunkingSettingsTests extends AbstractWireSerializingTestCase<WordBoundaryChunkingSettings> {
+
+    public void testMaxChunkSizeNotProvided() {
+        assertThrows(ValidationException.class, () -> {
+            WordBoundaryChunkingSettings.fromMap(buildChunkingSettingsMap(Optional.empty(), Optional.of(randomNonNegativeInt())));
+        });
+    }
+
+    public void testOverlapNotProvided() {
+        assertThrows(ValidationException.class, () -> {
+            WordBoundaryChunkingSettings.fromMap(buildChunkingSettingsMap(Optional.of(randomNonNegativeInt()), Optional.empty()));
+        });
+    }
+
+    public void testInvalidInputsProvided() {
+        var chunkingSettingsMap = buildChunkingSettingsMap(Optional.of(randomNonNegativeInt()), Optional.of(randomNonNegativeInt()));
+        chunkingSettingsMap.put(randomAlphaOfLength(10), randomNonNegativeInt());
+
+        assertThrows(ValidationException.class, () -> { WordBoundaryChunkingSettings.fromMap(chunkingSettingsMap); });
+    }
+
+    public void testOverlapGreaterThanHalfMaxChunkSize() {
+        // Lower bound must be 1, not 0: randomNonNegativeInt() can return 0, which would make
+        // randomIntBetween((0 / 2) + 1, 0) an invalid (empty) range and fail the test spuriously.
+        var maxChunkSize = randomIntBetween(1, Integer.MAX_VALUE);
+        var overlap = randomIntBetween((maxChunkSize / 2) + 1, maxChunkSize);
+        assertThrows(ValidationException.class, () ->
{
+            WordBoundaryChunkingSettings.fromMap(buildChunkingSettingsMap(Optional.of(maxChunkSize), Optional.of(overlap)));
+        });
+    }
+
+    public void testValidInputsProvided() {
+        // Lower bound must be 2: randomNonNegativeInt() can return 0 or 1, which would make
+        // randomIntBetween(1, maxChunkSize / 2) an invalid (empty) range and fail the test spuriously.
+        int maxChunkSize = randomIntBetween(2, Integer.MAX_VALUE);
+        int overlap = randomIntBetween(1, maxChunkSize / 2);
+        WordBoundaryChunkingSettings settings = WordBoundaryChunkingSettings.fromMap(
+            buildChunkingSettingsMap(Optional.of(maxChunkSize), Optional.of(overlap))
+        );
+
+        assertEquals(settings.getChunkingStrategy(), ChunkingStrategy.WORD);
+        assertEquals(settings.maxChunkSize, maxChunkSize);
+        assertEquals(settings.overlap, overlap);
+    }
+
+    public Map<String, Object> buildChunkingSettingsMap(Optional<Integer> maxChunkSize, Optional<Integer> overlap) {
+        Map<String, Object> settingsMap = new HashMap<>();
+        settingsMap.put(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.WORD.toString());
+        maxChunkSize.ifPresent(maxChunkSizeValue -> settingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSizeValue));
+        overlap.ifPresent(overlapValue -> settingsMap.put(ChunkingSettingsOptions.OVERLAP.toString(), overlapValue));
+
+        return settingsMap;
+    }
+
+    @Override
+    protected Writeable.Reader<WordBoundaryChunkingSettings> instanceReader() {
+        return WordBoundaryChunkingSettings::new;
+    }
+
+    @Override
+    protected WordBoundaryChunkingSettings createTestInstance() {
+        // Bounded below at 2 so randomIntBetween(1, maxChunkSize / 2) is never an empty range.
+        var maxChunkSize = randomIntBetween(2, Integer.MAX_VALUE);
+        return new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2));
+    }
+
+    @Override
+    protected WordBoundaryChunkingSettings mutateInstance(WordBoundaryChunkingSettings instance) throws IOException {
+        var valueToMutate = randomFrom(List.of(ChunkingSettingsOptions.MAX_CHUNK_SIZE, ChunkingSettingsOptions.OVERLAP));
+        var maxChunkSize = instance.maxChunkSize;
+        var overlap = instance.overlap;
+
+        if (valueToMutate.equals(ChunkingSettingsOptions.MAX_CHUNK_SIZE)) {
+            while (maxChunkSize == instance.maxChunkSize) {
+                // Bounded below at 2 so the overlap re-draw below always has a valid range.
+                maxChunkSize = randomIntBetween(2, Integer.MAX_VALUE);
+            }
+
+            if (overlap > maxChunkSize / 2) {
+                overlap = randomIntBetween(1,
maxChunkSize / 2); + } + } else if (valueToMutate.equals(ChunkingSettingsOptions.OVERLAP)) { + while (overlap == instance.overlap) { + overlap = randomIntBetween(1, maxChunkSize / 2); + } + } + + return new WordBoundaryChunkingSettings(maxChunkSize, overlap); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 86af5e431d78d..e5f0989b43976 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -38,6 +38,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalTimeValue; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.getEmbeddingSize; @@ -545,6 +546,64 @@ public void testExtractRequiredPositiveInteger_AddsErrorWhenKeyIsMissing() { assertThat(validation.validationErrors().get(1), is("[scope] does not contain the required setting [not_key]")); } + public void testExtractRequiredPositiveIntegerLessThanOrEqualToMax_ReturnsValueWhenValueIsLessThanMax() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", 1)); + var parsedInt = 
extractRequiredPositiveIntegerLessThanOrEqualToMax(map, "key", 5, "scope", validation); + + assertThat(validation.validationErrors(), hasSize(1)); + assertNotNull(parsedInt); + assertThat(parsedInt, is(1)); + assertTrue(map.isEmpty()); + } + + public void testExtractRequiredPositiveIntegerLessThanOrEqualToMax_ReturnsValueWhenValueIsEqualToMax() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", 5)); + var parsedInt = extractRequiredPositiveIntegerLessThanOrEqualToMax(map, "key", 5, "scope", validation); + + assertThat(validation.validationErrors(), hasSize(1)); + assertNotNull(parsedInt); + assertThat(parsedInt, is(5)); + assertTrue(map.isEmpty()); + } + + public void testExtractRequiredPositiveIntegerLessThanOrEqualToMax_AddsErrorForNegativeValue() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", -1)); + var parsedInt = extractRequiredPositiveIntegerLessThanOrEqualToMax(map, "key", 5, "scope", validation); + + assertThat(validation.validationErrors(), hasSize(2)); + assertNull(parsedInt); + assertTrue(map.isEmpty()); + assertThat(validation.validationErrors().get(1), is("[scope] Invalid value [-1]. 
[key] must be a positive integer"));
+    }
+
+    public void testExtractRequiredPositiveIntegerLessThanOrEqualToMax_AddsErrorWhenKeyIsMissing() {
+        var validation = new ValidationException();
+        validation.addValidationError("previous error");
+        Map<String, Object> map = modifiableMap(Map.of("key", -1));
+        var parsedInt = extractRequiredPositiveIntegerLessThanOrEqualToMax(map, "not_key", 5, "scope", validation);
+
+        assertThat(validation.validationErrors(), hasSize(2));
+        assertNull(parsedInt);
+        assertThat(validation.validationErrors().get(1), is("[scope] does not contain the required setting [not_key]"));
+    }
+
+    public void testExtractRequiredPositiveIntegerLessThanOrEqualToMax_AddsErrorWhenValueIsGreaterThanMax() {
+        var validation = new ValidationException();
+        validation.addValidationError("previous error");
+        Map<String, Object> map = modifiableMap(Map.of("key", 6));
+        // Was a copy-paste of the missing-key test: it extracted "not_key" and asserted the
+        // missing-setting error, so the greater-than-max path was never exercised.
+        var parsedInt = extractRequiredPositiveIntegerLessThanOrEqualToMax(map, "key", 5, "scope", validation);
+
+        assertThat(validation.validationErrors(), hasSize(2));
+        assertNull(parsedInt);
+        assertTrue(map.isEmpty());
+        // NOTE(review): confirm the exact message text against ServiceUtils' max-violation error.
+        assertThat(validation.validationErrors().get(1), is("[scope] Invalid value [6]. [key] must be a positive integer less than or equal to [5]"));
+    }
+
     public void testExtractOptionalEnum_ReturnsNull_WhenFieldDoesNotExist() {
         var validation = new ValidationException();
         Map<String, Object> map = modifiableMap(Map.of("key", "value"));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index dbc365f3d6919..9ea6b61fa53db 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -18,10 +18,10 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import
org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -29,8 +29,10 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -56,8 +58,10 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; @@ -340,6 +344,90 @@ public void 
testParseRequestConfig_MovesModel() throws IOException { } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", null, null), + getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + assertNull(embeddingsModel.getServiceSettings().organizationId()); + assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); + assertNull(embeddingsModel.getTaskSettings().user()); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, exception -> fail("Unexpected 
exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", null, null), + getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + assertNull(embeddingsModel.getServiceSettings().organizationId()); + assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); + assertNull(embeddingsModel.getTaskSettings().user()); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap("model", null, null), getTaskSettingsMap(null), getSecretSettingsMap("secret")), + Set.of(), + modelVerificationListener + ); + } + } + public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( @@ -417,6 +505,95 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWi } } + public void 
testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("model", null, null, null, null, true), + getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + assertNull(embeddingsModel.getServiceSettings().organizationId()); + assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); + assertNull(embeddingsModel.getTaskSettings().user()); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("model", null, null, null, null, true), + getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + + var 
embeddingsModel = (OpenAiEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + assertNull(embeddingsModel.getServiceSettings().organizationId()); + assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); + assertNull(embeddingsModel.getTaskSettings().user()); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("model", null, null, null, null, true), + getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + assertNull(embeddingsModel.getServiceSettings().organizationId()); + assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); + assertNull(embeddingsModel.getTaskSettings().user()); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( @@ -612,6 +789,77 @@ public void 
testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUr } } + public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutChunkingSettingsWhenChunkingSettingsFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("model", null, null, null, null, true), + getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + assertNull(embeddingsModel.getServiceSettings().organizationId()); + assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); + assertNull(embeddingsModel.getTaskSettings().user()); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("model", null, null, null, null, true), + getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + 
assertNull(embeddingsModel.getServiceSettings().organizationId()); + assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); + assertNull(embeddingsModel.getTaskSettings().user()); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("model", null, null, null, null, true), + getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().uri()); + assertNull(embeddingsModel.getServiceSettings().organizationId()); + assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); + assertNull(embeddingsModel.getTaskSettings().user()); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( @@ -1262,6 +1510,32 @@ public void testMoveModelFromTaskToServiceSettings_AlreadyMoved() { } public void testChunkedInfer_Batches() throws IOException { + var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user"); + testChunkedInfer(model); + } + + public void 
testChunkedInfer_ChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var model = OpenAiEmbeddingsModelTests.createModel( + getUrl(webServer), + "org", + "secret", + "model", + "user", + ChunkingSettingsTests.createRandomChunkingSettings() + ); + + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsNotSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", (ChunkingSettings) null); + + testChunkedInfer(model); + } + + private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { @@ -1297,7 +1571,6 @@ public void testChunkedInfer_Batches() throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, @@ -1343,18 +1616,4 @@ public void testChunkedInfer_Batches() throws IOException { private OpenAiService createOpenAiService() { return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } - - private Map getRequestConfigMap( - Map serviceSettings, - Map taskSettings, - Map secretSettings - ) { - var builtServiceSettings = new HashMap<>(); - builtServiceSettings.putAll(serviceSettings); - builtServiceSettings.putAll(secretSettings); - - return new HashMap<>( - Map.of(ModelConfigurations.SERVICE_SETTINGS, 
builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) - ); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java index 86b7f4421954d..0e9179792b92b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -61,6 +62,26 @@ public static OpenAiEmbeddingsModel createModel( "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, null, false, null), new OpenAiEmbeddingsTaskSettings(user), + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static OpenAiEmbeddingsModel createModel( + String url, + @Nullable String org, + String apiKey, + String modelName, + @Nullable String user, + ChunkingSettings chunkingSettings + ) { + return new OpenAiEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, null, false, null), + new OpenAiEmbeddingsTaskSettings(user), + chunkingSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -78,6 +99,7 @@ public static OpenAiEmbeddingsModel createModel( "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, 
null, false, null), new OpenAiEmbeddingsTaskSettings(user), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -96,6 +118,7 @@ public static OpenAiEmbeddingsModel createModel( "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, tokenLimit, false, null), new OpenAiEmbeddingsTaskSettings(user), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -115,6 +138,7 @@ public static OpenAiEmbeddingsModel createModel( "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, false, null), new OpenAiEmbeddingsTaskSettings(user), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -136,6 +160,7 @@ public static OpenAiEmbeddingsModel createModel( "service", new OpenAiEmbeddingsServiceSettings(modelName, url, org, similarityMeasure, dimensions, tokenLimit, dimensionsSetByUser, null), new OpenAiEmbeddingsTaskSettings(user), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); }