[ML] Replace objects with primitives in Text Embedding Results classes (#108161)

* tests pass

* Update docs/changelog/108161.yaml

* precommit

* merge

* remove unnecessary comments

* fix syntax error in test-service-plugin

* create Embedding.of to handle conversion from List of objects

* Update docs/changelog/108161.yaml

* Update 108161.yaml

* fix merge conflicts

* Update docs/changelog/108161.yaml

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
maxhniebergall and elasticmachine committed May 17, 2024
1 parent b3a902e commit ca2ce0e
Showing 37 changed files with 385 additions and 269 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/108161.yaml
@@ -0,0 +1,5 @@
pr: 108161
summary: Refactor TextEmbeddingResults to use primitives rather than objects
area: Machine Learning
type: bug
issues: []
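
A note on the pattern this changelog entry describes: the refactor swaps boxed List&lt;Byte&gt;/List&lt;Float&gt; record components for primitive byte[]/float[] arrays, and because a record's generated equals/hashCode compare array components by reference, each affected record in the diffs below also gains explicit Arrays-based overrides. The following is a minimal, self-contained sketch of that pattern, not the actual Elasticsearch classes; the of(...) factory only mirrors the "create Embedding.of" bullet from the commit message, and its exact shape here is assumed.

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

// Illustrative stand-in for the refactored result records: the embedding is a
// primitive float[] rather than a List<Float>, so equals/hashCode/toString are
// overridden with the java.util.Arrays helpers.
public record Embedding(String matchedText, float[] values) {

    // Hypothetical factory in the spirit of the "Embedding.of" commit bullet:
    // converts a boxed List<Float> (e.g. from a parsed response) into primitives.
    public static Embedding of(String matchedText, List<Float> boxed) {
        float[] values = new float[boxed.size()];
        for (int i = 0; i < boxed.size(); i++) {
            values[i] = boxed.get(i);
        }
        return new Embedding(matchedText, values);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Embedding that = (Embedding) o;
        return Objects.equals(matchedText, that.matchedText) && Arrays.equals(values, that.values);
    }

    @Override
    public int hashCode() {
        return 31 * Objects.hash(matchedText) + Arrays.hashCode(values);
    }

    @Override
    public String toString() {
        return matchedText + " -> " + Arrays.toString(values);
    }
}
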
ChunkedTextEmbeddingByteResults.java
@@ -19,10 +19,10 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.Objects;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;

@@ -47,7 +47,7 @@ public static List<ChunkedInferenceServiceResults> of(List<String> inputs, TextE
return results;
}

public static ChunkedTextEmbeddingByteResults of(String input, List<Byte> byteEmbeddings) {
public static ChunkedTextEmbeddingByteResults of(String input, byte[] byteEmbeddings) {
return new ChunkedTextEmbeddingByteResults(List.of(new EmbeddingChunk(input, byteEmbeddings)), false);
}

@@ -84,7 +84,7 @@ public List<? extends InferenceResults> transformToLegacyFormat() {

@Override
public Map<String, Object> asMap() {
return Map.of(FIELD_NAME, chunks.stream().map(EmbeddingChunk::asMap).collect(Collectors.toList()));
return Map.of(FIELD_NAME, chunks);
}

@Override
@@ -96,16 +96,29 @@ public List<EmbeddingChunk> getChunks() {
return chunks;
}

public record EmbeddingChunk(String matchedText, List<Byte> embedding) implements Writeable, ToXContentObject {
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ChunkedTextEmbeddingByteResults that = (ChunkedTextEmbeddingByteResults) o;
return isTruncated == that.isTruncated && Objects.equals(chunks, that.chunks);
}

@Override
public int hashCode() {
return Objects.hash(chunks, isTruncated);
}

public record EmbeddingChunk(String matchedText, byte[] embedding) implements Writeable, ToXContentObject {

public EmbeddingChunk(StreamInput in) throws IOException {
this(in.readString(), in.readCollectionAsImmutableList(StreamInput::readByte));
this(in.readString(), in.readByteArray());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(matchedText);
out.writeCollection(embedding, StreamOutput::writeByte);
out.writeByteArray(embedding);
}

@Override
@@ -114,7 +127,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(ChunkedNlpInferenceResults.TEXT, matchedText);

builder.startArray(ChunkedNlpInferenceResults.INFERENCE);
for (Byte value : embedding) {
for (byte value : embedding) {
builder.value(value);
}
builder.endArray();
@@ -123,16 +136,24 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public Map<String, Object> asMap() {
var map = new HashMap<String, Object>();
map.put(ChunkedNlpInferenceResults.TEXT, matchedText);
map.put(ChunkedNlpInferenceResults.INFERENCE, embedding);
return map;
}

@Override
public String toString() {
return Strings.toString(this);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
EmbeddingChunk that = (EmbeddingChunk) o;
return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding);
}

@Override
public int hashCode() {
int result = Objects.hash(matchedText);
result = 31 * result + Arrays.hashCode(embedding);
return result;
}
}
}
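
The wire-format change above replaces out.writeCollection(embedding, StreamOutput::writeByte) with out.writeByteArray(embedding), and the matching read goes through in.readByteArray() instead of readCollectionAsImmutableList. Below is a rough round-trip sketch of that serialization, assuming Elasticsearch's BytesStreamOutput is on the classpath; it is test-style illustration, not code from this commit.

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;

import java.io.IOException;
import java.util.Arrays;

public class EmbeddingWireRoundTrip {
    public static void main(String[] args) throws IOException {
        byte[] embedding = new byte[] { 12, -7, 33 };

        // Serialize the primitive array directly; no per-element boxing as with
        // the previous writeCollection(embedding, StreamOutput::writeByte) form.
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            out.writeString("some matched text");
            out.writeByteArray(embedding);

            try (StreamInput in = out.bytes().streamInput()) {
                String matchedText = in.readString();
                byte[] roundTripped = in.readByteArray();
                assert matchedText.equals("some matched text");
                assert Arrays.equals(embedding, roundTripped);
            }
        }
    }
}
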
ChunkedTextEmbeddingFloatResults.java
@@ -19,10 +19,10 @@
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;

import java.io.IOException;
import java.util.HashMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.Objects;

public record ChunkedTextEmbeddingFloatResults(List<EmbeddingChunk> chunks) implements ChunkedInferenceServiceResults {

@@ -61,7 +61,7 @@ public List<? extends InferenceResults> transformToLegacyFormat() {

@Override
public Map<String, Object> asMap() {
return Map.of(FIELD_NAME, chunks.stream().map(EmbeddingChunk::asMap).collect(Collectors.toList()));
return Map.of(FIELD_NAME, chunks);
}

@Override
@@ -73,16 +73,29 @@ public List<EmbeddingChunk> getChunks() {
return chunks;
}

public record EmbeddingChunk(String matchedText, List<Float> embedding) implements Writeable, ToXContentObject {
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ChunkedTextEmbeddingFloatResults that = (ChunkedTextEmbeddingFloatResults) o;
return Objects.equals(chunks, that.chunks);
}

@Override
public int hashCode() {
return Objects.hash(chunks);
}

public record EmbeddingChunk(String matchedText, float[] embedding) implements Writeable, ToXContentObject {

public EmbeddingChunk(StreamInput in) throws IOException {
this(in.readString(), in.readCollectionAsImmutableList(StreamInput::readFloat));
this(in.readString(), in.readFloatArray());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(matchedText);
out.writeCollection(embedding, StreamOutput::writeFloat);
out.writeFloatArray(embedding);
}

@Override
@@ -91,7 +104,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(ChunkedNlpInferenceResults.TEXT, matchedText);

builder.startArray(ChunkedNlpInferenceResults.INFERENCE);
for (Float value : embedding) {
for (float value : embedding) {
builder.value(value);
}
builder.endArray();
@@ -100,17 +113,25 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public Map<String, Object> asMap() {
var map = new HashMap<String, Object>();
map.put(ChunkedNlpInferenceResults.TEXT, matchedText);
map.put(ChunkedNlpInferenceResults.INFERENCE, embedding);
return map;
}

@Override
public String toString() {
return Strings.toString(this);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
EmbeddingChunk that = (EmbeddingChunk) o;
return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding);
}

@Override
public int hashCode() {
int result = Objects.hash(matchedText);
result = 31 * result + Arrays.hashCode(embedding);
return result;
}
}

}
ChunkedTextEmbeddingResults.java
@@ -18,7 +18,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;

@@ -50,8 +50,8 @@ public static List<ChunkedInferenceServiceResults> of(List<String> inputs, TextE
return results;
}

public static ChunkedTextEmbeddingResults of(String input, List<Float> floatEmbeddings) {
double[] doubleEmbeddings = floatEmbeddings.stream().mapToDouble(ChunkedTextEmbeddingResults::floatToDouble).toArray();
public static ChunkedTextEmbeddingResults of(String input, float[] floatEmbeddings) {
double[] doubleEmbeddings = IntStream.range(0, floatEmbeddings.length).mapToDouble(i -> floatEmbeddings[i]).toArray();

return new ChunkedTextEmbeddingResults(
List.of(
@@ -115,12 +115,7 @@ public List<? extends InferenceResults> transformToLegacyFormat() {

@Override
public Map<String, Object> asMap() {
return Map.of(
FIELD_NAME,
chunks.stream()
.map(org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk::asMap)
.collect(Collectors.toList())
);
return Map.of(FIELD_NAME, chunks);
}

@Override
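
The of(String input, float[] floatEmbeddings) factory above widens each float to a double before building the legacy ml chunk. Pulled out on its own, the conversion is just the IntStream mapping below; the wrapper class is purely illustrative.

import java.util.stream.IntStream;

public class FloatWidening {
    public static double[] toDoubles(float[] floatEmbeddings) {
        // Streams have no float specialization, so index over the array and widen element-wise.
        return IntStream.range(0, floatEmbeddings.length).mapToDouble(i -> floatEmbeddings[i]).toArray();
    }

    public static void main(String[] args) {
        double[] doubles = toDoubles(new float[] { 0.1f, 0.2f });
        System.out.println(doubles.length + " values, first = " + doubles[0]);
    }
}
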
LegacyTextEmbeddingResults.java
@@ -3,6 +3,8 @@
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/

package org.elasticsearch.xpack.core.inference.results;
@@ -17,10 +19,11 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.Objects;

/**
* Writes a text embedding result in the following json format
@@ -80,15 +83,15 @@ public String getResultsField() {
@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(getResultsField(), embeddings.stream().map(Embedding::asMap).collect(Collectors.toList()));
map.put(getResultsField(), embeddings);

return map;
}

@Override
public Map<String, Object> asMap(String outputField) {
Map<String, Object> map = new LinkedHashMap<>();
map.put(outputField, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList()));
map.put(outputField, embeddings);

return map;
}
@@ -98,28 +101,41 @@ public Object predictedValue() {
throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LegacyTextEmbeddingResults that = (LegacyTextEmbeddingResults) o;
return Objects.equals(embeddings, that.embeddings);
}

@Override
public int hashCode() {
return Objects.hash(embeddings);
}

public TextEmbeddingResults transformToTextEmbeddingResults() {
return new TextEmbeddingResults(this);
}

public record Embedding(List<Float> values) implements Writeable, ToXContentObject {
public record Embedding(float[] values) implements Writeable, ToXContentObject {
public static final String EMBEDDING = "embedding";

public Embedding(StreamInput in) throws IOException {
this(in.readCollectionAsImmutableList(StreamInput::readFloat));
this(in.readFloatArray());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(values, StreamOutput::writeFloat);
out.writeFloatArray(values);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

builder.startArray(EMBEDDING);
for (Float value : values) {
for (float value : values) {
builder.value(value);
}
builder.endArray();
@@ -133,8 +149,17 @@ public String toString() {
return Strings.toString(this);
}

public Map<String, Object> asMap() {
return Map.of(EMBEDDING, values);
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Embedding embedding = (Embedding) o;
return Arrays.equals(values, embedding.values);
}

@Override
public int hashCode() {
return Arrays.hashCode(values);
}
}
}