Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_TELEMETRY_MEMORY_ADDED = def(8_748_00_0);
public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE = def(8_749_00_0);
public static final TransportVersion SEMANTIC_TEXT_SEARCH_INFERENCE_ID = def(8_750_00_0);
public static final TransportVersion ML_INFERENCE_CHUNKING_SETTINGS = def(8_751_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;

public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable {
ChunkingStrategy getChunkingStrategy();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

import org.elasticsearch.common.Strings;

import java.util.EnumSet;

public enum ChunkingStrategy {
WORD("word"),
SENTENCE("sentence");

private final String chunkingStrategy;

ChunkingStrategy(String strategy) {
this.chunkingStrategy = strategy;
}

@Override
public String toString() {
return chunkingStrategy;
}

public static ChunkingStrategy fromString(String strategy) {
return EnumSet.allOf(ChunkingStrategy.class)
.stream()
.filter(cs -> cs.chunkingStrategy.equals(strategy))
.findFirst()
.orElseThrow(() -> new IllegalArgumentException(Strings.format("Invalid chunkingStrategy %s", strategy)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class ModelConfigurations implements ToFilteredXContentObject, VersionedN
public static final String SERVICE = "service";
public static final String SERVICE_SETTINGS = "service_settings";
public static final String TASK_SETTINGS = "task_settings";
public static final String CHUNKING_SETTINGS = "chunking_settings";
private static final String NAME = "inference_model";

public static ModelConfigurations of(Model model, TaskSettings taskSettings) {
Expand All @@ -40,7 +41,8 @@ public static ModelConfigurations of(Model model, TaskSettings taskSettings) {
model.getConfigurations().getTaskType(),
model.getConfigurations().getService(),
model.getServiceSettings(),
taskSettings
taskSettings,
model.getConfigurations().getChunkingSettings()
);
}

Expand All @@ -53,7 +55,8 @@ public static ModelConfigurations of(Model model, ServiceSettings serviceSetting
model.getConfigurations().getTaskType(),
model.getConfigurations().getService(),
serviceSettings,
model.getTaskSettings()
model.getTaskSettings(),
model.getConfigurations().getChunkingSettings()
);
}

Expand All @@ -62,6 +65,7 @@ public static ModelConfigurations of(Model model, ServiceSettings serviceSetting
private final String service;
private final ServiceSettings serviceSettings;
private final TaskSettings taskSettings;
private final ChunkingSettings chunkingSettings;

/**
* Allows no task settings to be defined. This will default to the {@link EmptyTaskSettings} object.
Expand All @@ -82,6 +86,23 @@ public ModelConfigurations(
this.service = Objects.requireNonNull(service);
this.serviceSettings = Objects.requireNonNull(serviceSettings);
this.taskSettings = Objects.requireNonNull(taskSettings);
this.chunkingSettings = null;
}

public ModelConfigurations(
String inferenceEntityId,
TaskType taskType,
String service,
ServiceSettings serviceSettings,
TaskSettings taskSettings,
ChunkingSettings chunkingSettings
) {
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
this.taskType = Objects.requireNonNull(taskType);
this.service = Objects.requireNonNull(service);
this.serviceSettings = Objects.requireNonNull(serviceSettings);
this.taskSettings = Objects.requireNonNull(taskSettings);
this.chunkingSettings = chunkingSettings;
}

public ModelConfigurations(StreamInput in) throws IOException {
Expand All @@ -90,6 +111,9 @@ public ModelConfigurations(StreamInput in) throws IOException {
this.service = in.readString();
this.serviceSettings = in.readNamedWriteable(ServiceSettings.class);
this.taskSettings = in.readNamedWriteable(TaskSettings.class);
this.chunkingSettings = in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS)
? in.readOptionalNamedWriteable(ChunkingSettings.class)
: null;
}

@Override
Expand All @@ -99,6 +123,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(service);
out.writeNamedWriteable(serviceSettings);
out.writeNamedWriteable(taskSettings);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS)) {
out.writeOptionalNamedWriteable(chunkingSettings);
}
}

public String getInferenceEntityId() {
Expand All @@ -121,6 +148,10 @@ public TaskSettings getTaskSettings() {
return taskSettings;
}

public ChunkingSettings getChunkingSettings() {
return chunkingSettings;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand All @@ -133,6 +164,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(SERVICE, service);
builder.field(SERVICE_SETTINGS, serviceSettings);
builder.field(TASK_SETTINGS, taskSettings);
if (chunkingSettings != null) {
builder.field(CHUNKING_SETTINGS, chunkingSettings);
}
builder.endObject();
return builder;
}
Expand All @@ -149,6 +183,9 @@ public XContentBuilder toFilteredXContent(XContentBuilder builder, Params params
builder.field(SERVICE, service);
builder.field(SERVICE_SETTINGS, serviceSettings.getFilteredXContentObject());
builder.field(TASK_SETTINGS, taskSettings);
if (chunkingSettings != null) {
builder.field(CHUNKING_SETTINGS, chunkingSettings);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
*/
public enum FeatureFlag {
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null);
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
CHUNKING_SETTINGS_ENABLED("es.inference_chunking_settings_feature_flag_enabled=true", Version.fromString("8.16.0"), null);

public final String systemProperty;
public final Version from;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.util.FeatureFlag;

/**
* chunking_settings feature flag. When the feature is complete, this flag will be removed.
*/
public class ChunkingSettingsFeatureFlag {

private ChunkingSettingsFeatureFlag() {}

private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_chunking_settings");

public static boolean isEnabled() {
return FEATURE_FLAG.isEnabled();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
Expand All @@ -29,6 +29,7 @@ public class OpenAIServiceMixedIT extends BaseMixedTestCase {

private static final String OPEN_AI_EMBEDDINGS_ADDED = "8.12.0";
private static final String OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED = "8.13.0";
private static final String OPEN_AI_EMBEDDINGS_CHUNKING_SETTINGS_ADDED = "8.16.0";
private static final String OPEN_AI_COMPLETIONS_ADDED = "8.14.0";
private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";

Expand All @@ -50,6 +51,7 @@ public static void shutdown() {
openAiChatCompletionsServer.close();
}

@AwaitsFix(bugUrl = "Backport #112074 to 8.16")
@SuppressWarnings("unchecked")
public void testOpenAiEmbeddings() throws IOException {
var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
Expand All @@ -64,7 +66,23 @@ public void testOpenAiEmbeddings() throws IOException {
String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig();
// queue a response as PUT will call the service
openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);

try {
put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);
} catch (Exception e) {
if (getOldClusterTestVersion().before(OPEN_AI_EMBEDDINGS_CHUNKING_SETTINGS_ADDED)) {
// Chunking settings were added in 8.16.0. if the version is before that, an exception will be thrown if the index mapping
// was created based on a mapping from an old node
assertThat(
e.getMessage(),
containsString(
"One or more nodes in your cluster does not support chunking_settings. "
+ "Please update all nodes in your cluster to use chunking_settings."
)
);
return;
}
}

var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
assertThat(configs, hasSize(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ private InferenceIndex() {}

public static final String INDEX_NAME = ".inference";
public static final String INDEX_PATTERN = INDEX_NAME + "*";
public static final String INDEX_ALIAS = ".inference-alias";

// Increment this version number when the mappings change
private static final int INDEX_MAPPING_VERSION = 1;
private static final int INDEX_MAPPING_VERSION = 2;

public static Settings settings() {
return Settings.builder()
Expand Down Expand Up @@ -84,6 +85,50 @@ public static XContentBuilder mappings() {
.startObject("properties")
.endObject()
.endObject()
.startObject("chunking_settings")
.field("dynamic", "false")
.startObject("properties")
.startObject("strategy")
.field("type", "keyword")
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();
} catch (IOException e) {
throw new UncheckedIOException("Failed to build mappings for index " + INDEX_NAME, e);
}
}

public static XContentBuilder mappingsV1() {
try {
return jsonBuilder().startObject()
.startObject(SINGLE_MAPPING_NAME)
.startObject("_meta")
.field(SystemIndexDescriptor.VERSION_META_KEY, 1)
.endObject()
.field("dynamic", "strict")
.startObject("properties")
.startObject("model_id")
.field("type", "keyword")
.endObject()
.startObject("task_type")
.field("type", "keyword")
.endObject()
.startObject("service")
.field("type", "keyword")
.endObject()
.startObject("service_settings")
.field("dynamic", "false")
.startObject("properties")
.endObject()
.endObject()
.startObject("task_settings")
.field("dynamic", "false")
.startObject("properties")
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InferenceResults;
Expand All @@ -26,6 +27,8 @@
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;
Expand Down Expand Up @@ -108,6 +111,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));

addChunkingSettingsNamedWriteables(namedWriteables);

// Empty default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new));

Expand Down Expand Up @@ -444,6 +449,19 @@ private static void addChunkedInferenceResultsNamedWriteables(List<NamedWriteabl
);
}

private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ChunkingSettings.class,
SentenceBoundaryChunkingSettings.NAME,
SentenceBoundaryChunkingSettings::new
)
);
}

private static void addInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, SparseEmbeddingResults.NAME, SparseEmbeddingResults::new)
Expand Down
Loading