[ML] refactoring internal tokenization logic for NLP (#83835)
This simplifies the internal logic used to pass tokenization results around and streamlines building the request sent to the model.

This helps lay some of the groundwork for windowing, since collapsing request building and token results will be required (a single sequence could result in a batch request).

Additionally, many of the IntelliJ warnings are addressed and the code is modernized (e.g. by taking advantage of records).
benwtrent committed Feb 15, 2022
1 parent ad44b88 commit ac3d0be
Showing 27 changed files with 542 additions and 662 deletions.
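
Note: the call-site changes in the hunks below (request.processInput becoming request.processInput(), request.tokenization becoming request.tokenization()) follow from NlpTask.Request being converted to a Java record, whose accessors are methods named after the components. A minimal self-contained sketch of that conversion, with the Elasticsearch field types (TokenizationResult, BytesReference) simplified to String purely for illustration:

// Sketch only: the real NlpTask.Request uses Elasticsearch types, not String.
record Request(String tokenization, String processInput) {}

class RecordAccessDemo {
    public static void main(String[] args) {
        Request request = new Request("token ids", "process input");
        // A record generates the constructor, equals/hashCode, toString, and
        // one accessor method per component, so field reads become method calls:
        System.out.println(request.processInput());   // was: request.processInput
        System.out.println(request.tokenization());   // was: request.tokenization
    }
}
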
@@ -29,6 +29,7 @@
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@@ -201,7 +202,11 @@ Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
try (
    InputStream stream = hit.getSourceRef().streamInput();
    XContentParser parser = XContentFactory.xContent(XContentType.JSON)
-       .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)
+       .createParser(
+           XContentParserConfiguration.EMPTY.withRegistry(xContentRegistry)
+               .withDeprecationHandler(LoggingDeprecationHandler.INSTANCE),
+           stream
+       )
) {
    return Vocabulary.createParser(true).apply(parser, null);
} catch (IOException e) {
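
For reference, the applied hunk reads as below; everything here comes from the diff itself, with the configuration pulled into a local variable as a stylistic variant (hit, xContentRegistry, and the surrounding method are assumed from the excerpt's context):

// The parser options are now carried by a single XContentParserConfiguration
// instead of being passed as separate createParser arguments.
XContentParserConfiguration config = XContentParserConfiguration.EMPTY
    .withRegistry(xContentRegistry)                              // NamedXContentRegistry
    .withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

try (
    InputStream stream = hit.getSourceRef().streamInput();
    XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(config, stream)
) {
    return Vocabulary.createParser(true).apply(parser, null);
}
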
@@ -374,8 +379,8 @@ protected void doRun() throws Exception {
NlpConfig nlpConfig = (NlpConfig) config;
NlpTask.Request request = processor.getRequestBuilder(nlpConfig)
    .buildRequest(text, requestIdStr, nlpConfig.getTokenization().getTruncate());
-logger.debug(() -> "Inference Request " + request.processInput.utf8ToString());
-if (request.tokenization.anyTruncated()) {
+logger.debug(() -> "Inference Request " + request.processInput().utf8ToString());
+if (request.tokenization().anyTruncated()) {
    logger.debug("[{}] [{}] input truncated", modelId, requestId);
}
processContext.getResultProcessor()
@@ -385,14 +390,14 @@ protected void doRun() throws Exception {
    inferenceResult -> processResult(
        inferenceResult,
        processContext,
-       request.tokenization,
+       request.tokenization(),
        processor.getResultProcessor((NlpConfig) config),
        this
    ),
    this::onFailure
    )
);
-processContext.process.get().writeInferenceRequest(request.processInput);
+processContext.process.get().writeInferenceRequest(request.processInput());
} catch (IOException e) {
    logger.error(new ParameterizedMessage("[{}] error writing to inference process", processContext.task.getModelId()), e);
    onFailure(ExceptionsHelper.serverError("Error writing to inference process", e));
@@ -448,8 +453,8 @@ class ProcessContext {
private volatile Instant startTime;
private volatile Integer inferenceThreads;
private volatile Integer modelThreads;
-private AtomicInteger rejectedExecutionCount = new AtomicInteger();
-private AtomicInteger timeoutCount = new AtomicInteger();
+private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
+private final AtomicInteger timeoutCount = new AtomicInteger();

ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
this.task = Objects.requireNonNull(task);

This file was deleted.

@@ -23,20 +23,14 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import java.util.OptionalInt;

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;

-public class FillMaskProcessor implements NlpTask.Processor {
-
-    private final NlpTokenizer tokenizer;
+public class FillMaskProcessor extends NlpTask.Processor {

    FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) {
-       this.tokenizer = tokenizer;
-   }
-
-   @Override
-   public void close() {
-       tokenizer.close();
+       super(tokenizer);
    }

    @Override
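
The change to super(tokenizer) and the removal of the close() override suggest that NlpTask.Processor changed from an interface to an abstract base class that owns the tokenizer and its teardown. The base class itself is not shown in this excerpt; the following is a hypothetical sketch of that shape only, with NlpTokenizer stubbed so it stands alone:

// Stub standing in for the real NlpTokenizer, which is not shown here.
interface NlpTokenizer extends AutoCloseable {
    @Override
    void close();
}

// Hypothetical sketch: the real NlpTask.Processor may differ in details
// not visible in this excerpt.
abstract class Processor implements AutoCloseable {
    protected final NlpTokenizer tokenizer;

    Processor(NlpTokenizer tokenizer) {
        this.tokenizer = tokenizer;
    }

    @Override
    public void close() {
        tokenizer.close();  // shared teardown; FillMaskProcessor no longer overrides close()
    }
}
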
@@ -97,7 +91,7 @@ static InferenceResults processResult(
int numResults,
String resultsField
) {
-if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
+if (tokenization.isEmpty()) {
throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR);
}

@@ -108,25 +102,20 @@
);
}

-       int maskTokenIndex = -1;
        int maskTokenId = tokenizer.getMaskTokenId().getAsInt();
-       for (int i = 0; i < tokenization.getTokenizations().get(0).getTokenIds().length; i++) {
-           if (tokenization.getTokenizations().get(0).getTokenIds()[i] == maskTokenId) {
-               maskTokenIndex = i;
-               break;
-           }
-       }
-       if (maskTokenIndex == -1) {
+       OptionalInt maskTokenIndex = tokenization.getTokenization(0).getTokenIndex(maskTokenId);
+       if (maskTokenIndex.isEmpty()) {
            throw new ElasticsearchStatusException(
-               "mask token id [{}] not found in the tokenization {}",
+               "mask token id [{}] not found in the tokenization",
                RestStatus.INTERNAL_SERVER_ERROR,
-               maskTokenId,
-               List.of(tokenization.getTokenizations().get(0).getTokenIds())
+               maskTokenId
            );
        }

        // TODO - process all results in the batch
-       double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
+       double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(
+           pyTorchResult.getInferenceResult()[0][maskTokenIndex.getAsInt()]
+       );

NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(
// We need at least one to record the result
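
The manual scan above is replaced by a getTokenIndex helper returning OptionalInt, which folds the loop and the -1 sentinel into the tokenization result itself. The helper's implementation is not shown in this excerpt; a hypothetical standalone equivalent over a token-id array would look something like this:

import java.util.OptionalInt;

class TokenIndexDemo {
    // Hypothetical equivalent of getTokenIndex: first index of tokenId, if any.
    static OptionalInt getTokenIndex(int[] tokenIds, int tokenId) {
        for (int i = 0; i < tokenIds.length; i++) {
            if (tokenIds[i] == tokenId) {
                return OptionalInt.of(i);   // replaces maskTokenIndex = i; break;
            }
        }
        return OptionalInt.empty();         // replaces the -1 sentinel
    }

    public static void main(String[] args) {
        System.out.println(getTokenIndex(new int[] {5, 103, 7}, 103)); // OptionalInt[1]
    }
}
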
@@ -142,10 +131,7 @@ }
}
return new FillMaskResults(
tokenization.getFromVocab(scoreAndIndices[0].index),
-           tokenization.getTokenizations()
-               .get(0)
-               .getInput()
-               .replace(tokenizer.getMaskToken(), tokenization.getFromVocab(scoreAndIndices[0].index)),
+           tokenization.getTokenization(0).input().replace(tokenizer.getMaskToken(), tokenization.getFromVocab(scoreAndIndices[0].index)),
results,
Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
scoreAndIndices[0].score,
Expand Down

This file was deleted.
