Skip to content

Commit

Permalink
[ML] Create default word based chunker (#107303)
Browse files Browse the repository at this point in the history
WordBoundaryChunker uses ICU4J to split text at word boundaries,
creating chunks from long inputs. The chunk size and overlap
parameters are measured in words. The chunk text is then processed
in batches depending on the inference service's supported batch size.
  • Loading branch information
davidkyle committed Apr 15, 2024
1 parent 17d6bc2 commit ecc406e
Show file tree
Hide file tree
Showing 18 changed files with 1,190 additions and 157 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/107303.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 107303
summary: Create default word based chunker
area: Machine Learning
type: feature
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Chunked text embedding results in which every embedding is a list of floats.
 * Each element of {@code chunks} pairs a piece of matched text with the
 * embedding associated with it.
 *
 * @param chunks the embedding chunks held by this result
 */
public record ChunkedTextEmbeddingFloatResults(List<EmbeddingChunk> chunks) implements ChunkedInferenceServiceResults {

    /** Name returned by {@link #getWriteableName()}. */
    public static final String NAME = "chunked_text_embedding_service_float_results";
    /** Field name under which the chunks are rendered by {@link #toXContent} and {@link #asMap()}. */
    public static final String FIELD_NAME = "text_embedding_float_chunk";

    /**
     * Deserializing constructor: reads the list of chunks from the stream.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public ChunkedTextEmbeddingFloatResults(StreamInput in) throws IOException {
        this(in.readCollectionAsList(EmbeddingChunk::new));
    }

    /**
     * Writes the chunks as an array named {@link #FIELD_NAME}, one element per chunk.
     *
     * @return the builder that was passed in
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        // TODO add isTruncated flag
        builder.startArray(FIELD_NAME);
        for (EmbeddingChunk chunk : chunks) {
            chunk.toXContent(builder, params);
        }
        return builder.endArray();
    }

    /** Serializes the chunk list to the stream. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(chunks);
    }

    /**
     * Not supported for chunked results.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action");
    }

    /**
     * Not supported for chunked results.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<? extends InferenceResults> transformToLegacyFormat() {
        throw new UnsupportedOperationException("Chunked results are not returned in the legacy format");
    }

    /**
     * Map view of this result: a single entry keyed by {@link #FIELD_NAME}
     * whose value is the list of per-chunk maps.
     */
    @Override
    public Map<String, Object> asMap() {
        List<Map<String, Object>> chunkMaps = chunks.stream().map(chunk -> chunk.asMap()).collect(Collectors.toList());
        return Map.of(FIELD_NAME, chunkMaps);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** Returns the same list as the record accessor {@code chunks()}. */
    public List<EmbeddingChunk> getChunks() {
        return this.chunks;
    }

    /**
     * A single chunk: the matched text together with its float embedding.
     *
     * @param matchedText the text of the chunk
     * @param embedding   the embedding values for the chunk
     */
    public record EmbeddingChunk(String matchedText, List<Float> embedding) implements Writeable, ToXContentObject {

        /**
         * Deserializing constructor: reads the text first, then the float values.
         *
         * @param in the stream to read from
         * @throws IOException if reading from the stream fails
         */
        public EmbeddingChunk(StreamInput in) throws IOException {
            this(in.readString(), in.readCollectionAsImmutableList(StreamInput::readFloat));
        }

        /** Writes the text followed by the float values, mirroring the read order of the stream constructor. */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(matchedText);
            out.writeCollection(embedding, StreamOutput::writeFloat);
        }

        /**
         * Renders the chunk as an object holding the text field and an array of embedding values.
         *
         * @return the builder that was passed in
         */
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            builder.startObject().field(ChunkedNlpInferenceResults.TEXT, matchedText);

            builder.startArray(ChunkedNlpInferenceResults.INFERENCE);
            for (Float component : embedding) {
                builder.value(component);
            }

            return builder.endArray().endObject();
        }

        /** Returns a new mutable map holding the text and the embedding under the same keys used by {@link #toXContent}. */
        public Map<String, Object> asMap() {
            Map<String, Object> fields = new HashMap<>();
            fields.put(ChunkedNlpInferenceResults.TEXT, matchedText);
            fields.put(ChunkedNlpInferenceResults.INFERENCE, embedding);
            return fields;
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }

}
2 changes: 2 additions & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ dependencies {
compileOnly project(path: xpackModule('core'))
testImplementation(testArtifact(project(xpackModule('core'))))
testImplementation project(':modules:reindex')

api "com.ibm.icu:icu4j:${versions.icu4j}"
}
33 changes: 33 additions & 0 deletions x-pack/plugin/inference/licenses/icu4j-LICENSE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
ICU License - ICU 1.8.1 and later

COPYRIGHT AND PERMISSION NOTICE

Copyright (c) 1995-2012 International Business Machines Corporation and others

All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, and/or sell copies of the
Software, and to permit persons to whom the Software is furnished to do so,
provided that the above copyright notice(s) and this permission notice appear
in all copies of the Software and that both the above copyright notice(s) and
this permission notice appear in supporting documentation.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE BE
LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR
ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

Except as contained in this notice, the name of a copyright holder shall not
be used in advertising or otherwise to promote the sale, use or other
dealings in this Software without prior written authorization of the
copyright holder.

All trademarks and registered trademarks mentioned herein are the property of
their respective owners.
3 changes: 3 additions & 0 deletions x-pack/plugin/inference/licenses/icu4j-NOTICE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ICU4J, (under lucene/analysis/icu) is licensed under an MIT style license
(modules/analysis/icu/lib/icu4j-LICENSE-BSD_LIKE.txt) and Copyright (c) 1995-2012
International Business Machines Corporation and others
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
requires org.apache.httpcomponents.httpasyncclient;
requires org.apache.httpcomponents.httpcore.nio;
requires org.apache.lucene.core;
requires com.ibm.icu;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
Expand Down Expand Up @@ -105,6 +106,13 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
ChunkedTextEmbeddingResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingFloatResults.NAME,
ChunkedTextEmbeddingFloatResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
Expand Down

0 comments on commit ecc406e

Please sign in to comment.