[ML] Multiple items in a single inference request (#75759)
Inference requests can be batched by adding more rows to the input tensor.
These batch calls are more performant than making multiple calls to forward()
with a single input when all the inputs are of a similar length. The expected
input is now a 2D array of tokens plus 2D arrays of supporting arguments;
the output is a 3D array.
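
For illustration, a minimal sketch of the shapes involved (token ids and dimensions are made up, not taken from the commit):

// Batched input: one row of token ids per document.
int[][] tokens = {
    {101, 2054, 2003, 102},
    {101, 7592, 2088, 102}
};
// Batched output: [document][token position][score].
double[][][] inference;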
davidkyle committed Sep 1, 2021
1 parent 5d48fdc commit 7a28310
Showing 22 changed files with 515 additions and 208 deletions.
@@ -68,8 +68,8 @@ public static Request parseRequest(String deploymentId, XContentParser parser) {
return builder.build();
}

private String deploymentId;
private List<Map<String, Object>> docs;
private final String deploymentId;
private final List<Map<String, Object>> docs;

public Request(String deploymentId, List<Map<String, Object>> docs) {
this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, DEPLOYMENT_ID);
@@ -50,4 +50,58 @@ public static <T> List<List<T>> parseArrayOfArrays(String fieldName, CheckedFunc
}
return values;
}

/**
* Parses a 3 dimensional array of doubles.
*
* @param fieldName the field name
* @param parser the outer parser
* @return The 3D array of doubles
* @throws IOException If parsing fails
*/
public static double[][][] parse3DArrayOfDoubles(String fieldName, XContentParser parser) throws IOException {
if (parser.currentToken() != XContentParser.Token.START_ARRAY) {
throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
}
List<List<List<Double>>> values = new ArrayList<>();
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
if (parser.currentToken() != XContentParser.Token.START_ARRAY) {
throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
}

List<List<Double>> innerList = new ArrayList<>();

while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
if (parser.currentToken() != XContentParser.Token.START_ARRAY) {
throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
}

List<Double> innerInner = new ArrayList<>();
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
if (parser.currentToken() != XContentParser.Token.VALUE_NUMBER) {
throw new IllegalStateException("expected non-null numerical value but got [" + parser.currentToken() + "] " +
"for [" + fieldName + "]");
}
innerInner.add(parser.doubleValue());
}
innerList.add(innerInner);
}
values.add(innerList);
}

// Assumes a rectangular result: every row has the length of the first row.
double[][][] val = new double[values.size()][values.get(0).size()][values.get(0).get(0).size()];

for (int i = 0; i < val.length; i++) {
for (int j = 0; j < val[0].length; j++) {
double[] doubles = values.get(i).get(j).stream().mapToDouble(d -> d).toArray();
System.arraycopy(doubles, 0, val[i][j], 0, doubles.length);
}
}

return val;
}
}
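
As a usage sketch (illustrative test-style code, not part of this diff; createParser uses the three-argument 7.x signature):

String json = "[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]";
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
        NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)) {
    parser.nextToken(); // position the parser on the outer START_ARRAY
    double[][][] values = MlParserUtils.parse3DArrayOfDoubles("inference", parser);
    assert values.length == 2 && values[1][0][1] == 6.0;
}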
@@ -18,10 +18,12 @@
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
@@ -109,7 +111,8 @@ public void setLogging() throws IOException {
"{" +
"\"transient\" : {\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.allocation\" : \"TRACE\",\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\"\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\",\n" +
" \"logger.org.elasticsearch.xpack.ml.process.logging\" : \"TRACE\"\n" +
" }" +
"}");
client().performRequest(loggingSettings);
@@ -124,7 +127,8 @@ public void cleanup() throws Exception {
"{" +
"\"transient\" : {\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.allocation\" :null,\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : null\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : null,\n" +
" \"logger.org.elasticsearch.xpack.ml.process.logging\" : null\n" +
" }" +
"}");
client().performRequest(loggingSettings);
@@ -133,7 +137,6 @@ public void cleanup() throws Exception {
waitForPendingTasks(adminClient());
}

@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
public void testEvaluate() throws IOException, InterruptedException {
String modelId = "test_evaluate";
createModelStoreIndex();
@@ -168,7 +171,6 @@ public void testEvaluate() throws IOException, InterruptedException {
}

@SuppressWarnings("unchecked")
@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
public void testLiveDeploymentStats() throws IOException {
String modelA = "model_a";

@@ -193,7 +195,6 @@ public void testLiveDeploymentStats() throws IOException {
}

@SuppressWarnings("unchecked")
@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
public void testGetDeploymentStats_WithWildcard() throws IOException {

{
@@ -262,7 +263,6 @@ public void testGetDeploymentStats_WithWildcard() throws IOException {
}

@SuppressWarnings("unchecked")
@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
public void testGetDeploymentStats_WithStartedStoppedDeployments() throws IOException {
putVocabulary(List.of("once", "twice"));
String modelFoo = "foo";
@@ -367,7 +367,10 @@ private void createModelStoreIndex() throws IOException {
}

private void putVocabulary(List<String> vocabulary) throws IOException {
String quotedWords = vocabulary.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));
List<String> vocabularyWithPad = new ArrayList<>();
vocabularyWithPad.add(BertTokenizer.PAD_TOKEN);
vocabularyWithPad.addAll(vocabulary);
String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));

Request request = new Request("PUT", "/" + VOCAB_INDEX + "/_doc/test_vocab");
request.setJsonEntity("{ " +
@@ -49,6 +49,8 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
@@ -232,7 +234,10 @@ public void onFailure(Exception e) {
@Override
protected void doRun() {
try {
String text = NlpTask.extractInput(processContext.modelInput.get(), doc);
// The request builder expects a list of inputs which are then batched.
// TODO batching was implemented for expected use-cases such as zero-shot
// classification but is not used here.
List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
processor.validateInputs(text);
NlpTask.Request request = processor.getRequestBuilder().buildRequest(text, requestId);
@@ -7,21 +7,20 @@

package org.elasticsearch.xpack.ml.inference.deployment;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
@@ -37,21 +36,13 @@ public class PyTorchResult implements ToXContentObject, Writeable {
private static final ParseField TIME_MS = new ParseField("time_ms");

public static final ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>("pytorch_result",
a -> new PyTorchResult((String) a[0], (double[][]) a[1], (Long) a[2], (String) a[3]));
a -> new PyTorchResult((String) a[0], (double[][][]) a[1], (Long) a[2], (String) a[3]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID);
PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> {
List<List<Double>> listOfListOfDoubles = MlParserUtils.parseArrayOfArrays(
INFERENCE.getPreferredName(), XContentParser::doubleValue, p);
double[][] primitiveDoubles = new double[listOfListOfDoubles.size()][];
for (int i = 0; i < listOfListOfDoubles.size(); i++) {
List<Double> row = listOfListOfDoubles.get(i);
primitiveDoubles[i] = row.stream().mapToDouble(d -> d).toArray();
}
return primitiveDoubles;
},
(p, c) ->
MlParserUtils.parse3DArrayOfDoubles(INFERENCE.getPreferredName(), p),
INFERENCE,
ObjectParser.ValueType.VALUE_ARRAY
);
@@ -64,12 +55,12 @@ public static PyTorchResult fromXContent(XContentParser parser) throws IOExcepti
}

private final String requestId;
private final double[][] inference;
private final double[][][] inference;
private final Long timeMs;
private final String error;

public PyTorchResult(String requestId,
@Nullable double[][] inference,
@Nullable double[][][] inference,
@Nullable Long timeMs,
@Nullable String error) {
this.requestId = Objects.requireNonNull(requestId);
@@ -82,7 +73,7 @@ public PyTorchResult(StreamInput in) throws IOException {
requestId = in.readString();
boolean hasInference = in.readBoolean();
if (hasInference) {
inference = in.readArray(StreamInput::readDoubleArray, double[][]::new);
inference = in.readArray(in2 -> in2.readArray(StreamInput::readDoubleArray, double[][]::new), double[][][]::new);
} else {
inference = null;
}
@@ -102,7 +93,7 @@ public String getError() {
return error;
}

public double[][] getInferenceResult() {
public double[][][] getInferenceResult() {
return inference;
}

@@ -115,7 +106,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.field(REQUEST_ID.getPreferredName(), requestId);
if (inference != null) {
builder.field(INFERENCE.getPreferredName(), inference);
builder.startArray(INFERENCE.getPreferredName());
for (int i = 0; i < inference.length; i++) {
builder.startArray();
for (int j = 0; j < inference[i].length; j++) {
builder.startArray();
for (int k = 0; k < inference[i][j].length; k++) {
builder.value(inference[i][j][k]);
}
builder.endArray();
}
builder.endArray();
}
builder.endArray();
}
if (timeMs != null) {
builder.field(TIME_MS.getPreferredName(), timeMs);
@@ -134,7 +138,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeArray(StreamOutput::writeDoubleArray, inference);
out.writeArray(
(out2, arr) -> out2.writeArray(StreamOutput::writeDoubleArray, arr),
inference);
}
out.writeOptionalLong(timeMs);
out.writeOptionalString(error);
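
A minimal round-trip sketch of the nested stream serialization above (test-style code using BytesStreamOutput; not part of this diff):

double[][][] original = {{{1.0, 2.0}, {3.0, 4.0}}};
try (BytesStreamOutput out = new BytesStreamOutput()) {
    // The outer writeArray delegates each 2D slice to an inner writeArray.
    out.writeArray((o, arr) -> o.writeArray(StreamOutput::writeDoubleArray, arr), original);
    try (StreamInput in = out.bytes().streamInput()) {
        double[][][] copy = in.readArray(
            in2 -> in2.readArray(StreamInput::readDoubleArray, double[][]::new),
            double[][][]::new);
        assert Arrays.deepEquals(original, copy);
    }
}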
@@ -14,7 +14,7 @@
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

public class BertRequestBuilder implements NlpTask.RequestBuilder {

@@ -31,30 +31,33 @@ public BertRequestBuilder(BertTokenizer tokenizer) {
}

@Override
public NlpTask.Request buildRequest(String input, String requestId) throws IOException {
TokenizationResult tokenization = tokenizer.tokenize(input);
return new NlpTask.Request(tokenization, jsonRequest(tokenization.getTokenIds(), requestId));
public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
if (tokenizer.getPadToken().isEmpty()) {
throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN +
" token in its vocabulary");
}

TokenizationResult tokenization = tokenizer.tokenize(inputs);
return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadToken().getAsInt(), requestId));
}

static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {
static BytesReference jsonRequest(TokenizationResult tokenization,
int padToken,
String requestId) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.field(REQUEST_ID, requestId);
builder.array(TOKENS, tokens);

int[] inputMask = new int[tokens.length];
Arrays.fill(inputMask, 1);
int[] segmentMask = new int[tokens.length];
Arrays.fill(segmentMask, 0);
int[] positionalIds = new int[tokens.length];
Arrays.setAll(positionalIds, i -> i);

builder.array(ARG1, inputMask);
builder.array(ARG2, segmentMask);
builder.array(ARG3, positionalIds);

NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
int batchSize = tokenization.getTokenizations().size();
NlpTask.RequestBuilder.writeNonPaddedArguments(ARG2, batchSize, tokenization.getLongestSequenceLength(), i -> 0, builder);
NlpTask.RequestBuilder.writeNonPaddedArguments(ARG3, batchSize, tokenization.getLongestSequenceLength(), i -> i, builder);
builder.endObject();

// BytesReference.bytes closes the builder
return BytesReference.bytes(builder);
}


}
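
For a concrete sense of the request body, a hypothetical two-document batch where the second document is one token shorter might serialize as follows, assuming the field constants resolve to request_id/tokens/arg_1/arg_2/arg_3 and the PAD token id is 0 (all values illustrative):

{
  "request_id": "r1",
  "tokens": [[101, 2054, 2003, 102], [101, 7592, 102, 0]],
  "arg_1":  [[1, 1, 1, 1], [1, 1, 1, 0]],
  "arg_2":  [[0, 0, 0, 0], [0, 0, 0, 0]],
  "arg_3":  [[0, 1, 2, 3], [0, 1, 2, 3]]
}

The attention mask (arg_1) falls to 0 over the padded tail because writePaddedTokens fills short rows with the pad id, while arg_2 and arg_3 are written at the longest sequence length for every row.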
@@ -14,7 +14,7 @@
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

public class DistilBertRequestBuilder implements NlpTask.RequestBuilder {

@@ -29,21 +29,24 @@ public DistilBertRequestBuilder(BertTokenizer tokenizer) {
}

@Override
public NlpTask.Request buildRequest(String input, String requestId) throws IOException {
TokenizationResult result = tokenizer.tokenize(input);
return new NlpTask.Request(result, jsonRequest(result.getTokenIds(), requestId));
public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
if (tokenizer.getPadToken().isEmpty()) {
throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN +
" token in its vocabulary");
}

TokenizationResult result = tokenizer.tokenize(inputs);
return new NlpTask.Request(result, jsonRequest(result, tokenizer.getPadToken().getAsInt(), requestId));
}

static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {
static BytesReference jsonRequest(TokenizationResult tokenization,
int padToken,
String requestId) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.field(REQUEST_ID, requestId);
builder.array(TOKENS, tokens);

int[] inputMask = new int[tokens.length];
Arrays.fill(inputMask, 1);

builder.array(ARG1, inputMask);
NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
builder.endObject();

// BytesReference.bytes closes the builder