[ML] Track token positions and use source string to tag NER entities (#81275)

By recording the positions of the tokens in the original source string, the entity labels are constructed from the original text, preserving the case and accented characters that would otherwise be stripped during normalisation.
davidkyle committed Dec 7, 2021
1 parent 20a131c commit d7117f2
Showing 15 changed files with 632 additions and 343 deletions.
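The core idea of the change, condensed: each token keeps its start and end offset into the original source string, so entity text can be recovered with a substring over the raw input instead of by re-joining normalised sub-tokens. A minimal sketch of that idea (the Token record and names here are hypothetical illustrations, not the actual Elasticsearch classes):

import java.util.List;

public class OffsetTaggingSketch {

    // A token that remembers where it came from in the source text.
    record Token(String normalised, int startPos, int endPos) {}

    // The entity text is the span from the first token's start offset to the
    // last token's end offset, taken from the unmodified source string.
    static String entityText(String source, List<Token> entityTokens) {
        int start = entityTokens.get(0).startPos();
        int end = entityTokens.get(entityTokens.size() - 1).endPos();
        return source.substring(start, end);
    }

    public static void main(String[] args) {
        String source = "Sarah lives in São Paulo";
        // The tokenizer may lower-case and strip accents internally, but the
        // offsets still point into the original text.
        List<Token> entity = List.of(new Token("sao", 15, 18), new Token("paulo", 19, 24));
        System.out.println(entityText(source, entity)); // prints "São Paulo"
    }
}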
@@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@@ -163,6 +164,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

@Override
public String toString() {
return Strings.toString(this);
}

public Map<String, Object> toMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put("entity", entity);
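For context on the file above: the added toString delegates to org.elasticsearch.common.Strings.toString, which renders any ToXContent implementation as JSON, so inference results print their JSON form when logged. A rough sketch of the pattern (simplified class; the XContent package names are assumed from the 7.x layout and moved in later versions):

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;

public class EntityResult implements ToXContentObject {
    private final String entity;

    public EntityResult(String entity) {
        this.entity = entity;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.startObject().field("entity", entity).endObject();
    }

    @Override
    public String toString() {
        // Strings.toString serialises any ToXContent object, so the JSON
        // rendering logic is written once, in toXContent.
        return Strings.toString(this);
    }
}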
@@ -34,7 +34,7 @@ public BertRequestBuilder(BertTokenizer tokenizer) {

@Override
public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
if (tokenizer.getPadToken().isEmpty()) {
if (tokenizer.getPadTokenId().isEmpty()) {
throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary");
}

@@ -46,10 +46,10 @@ public NlpTask.Request buildRequest(List<String> inputs, String requestId, Token

@Override
public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
if (tokenizer.getPadToken().isEmpty()) {
if (tokenizer.getPadTokenId().isEmpty()) {
throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary");
}
return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadToken().getAsInt(), requestId));
return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId));
}

static BytesReference jsonRequest(TokenizationResult tokenization, int padToken, String requestId) throws IOException {
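The rename from getPadToken to getPadTokenId in the file above reflects that the request builder needs the pad token's vocabulary id rather than its string form, and the OptionalInt return makes a missing vocabulary entry explicit. A small sketch of the guard under an assumed tokenizer interface (hypothetical names):

import java.util.OptionalInt;

// Hypothetical stand-in for the tokenizer: the pad token id is an OptionalInt
// that is empty when the vocabulary does not contain the token.
interface PaddingAware {
    OptionalInt getPadTokenId();
}

class PadGuard {
    static int requirePadTokenId(PaddingAware tokenizer) {
        OptionalInt padId = tokenizer.getPadTokenId();
        if (padId.isEmpty()) {
            throw new IllegalStateException("The input tokenizer does not have a [PAD] token in its vocabulary");
        }
        return padId.getAsInt(); // safe after the isEmpty() check
    }
}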
@@ -14,7 +14,6 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

@@ -27,10 +26,10 @@

public class FillMaskProcessor implements NlpTask.Processor {

private final NlpTask.RequestBuilder requestBuilder;
private final NlpTokenizer tokenizer;

FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) {
this.requestBuilder = tokenizer.requestBuilder();
this.tokenizer = tokenizer;
}

@Override
@@ -39,22 +38,23 @@ public void validateInputs(List<String> inputs) {
throw new IllegalArgumentException("input request is empty");
}

final String mask = tokenizer.getMaskToken();
for (String input : inputs) {
int maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN);
int maskIndex = input.indexOf(mask);
if (maskIndex < 0) {
throw new IllegalArgumentException("no " + BertTokenizer.MASK_TOKEN + " token could be found");
throw new IllegalArgumentException("no " + mask + " token could be found");
}

maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN, maskIndex + BertTokenizer.MASK_TOKEN.length());
maskIndex = input.indexOf(mask, maskIndex + mask.length());
if (maskIndex > 0) {
throw new IllegalArgumentException("only one " + BertTokenizer.MASK_TOKEN + " token should exist in the input");
throw new IllegalArgumentException("only one " + mask + " token should exist in the input");
}
}
}

@Override
public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
return requestBuilder;
return tokenizer.requestBuilder();
}

@Override
@@ -64,25 +64,55 @@ public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
return (tokenization, result) -> processResult(
tokenization,
result,
tokenizer,
fillMaskConfig.getNumTopClasses(),
fillMaskConfig.getResultsField()
);
} else {
return (tokenization, result) -> processResult(tokenization, result, FillMaskConfig.DEFAULT_NUM_RESULTS, DEFAULT_RESULTS_FIELD);
return (tokenization, result) -> processResult(
tokenization,
result,
tokenizer,
FillMaskConfig.DEFAULT_NUM_RESULTS,
DEFAULT_RESULTS_FIELD
);
}
}

static InferenceResults processResult(
TokenizationResult tokenization,
PyTorchResult pyTorchResult,
NlpTokenizer tokenizer,
int numResults,
String resultsField
) {
if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) {
if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
return new WarningInferenceResults("No valid tokens for inference");
}

int maskTokenIndex = Arrays.asList(tokenization.getTokenizations().get(0).getTokens()).indexOf(BertTokenizer.MASK_TOKEN);
if (tokenizer.getMaskTokenId().isEmpty()) {
return new WarningInferenceResults(
"The token id for the mask token {} is not known in the tokenizer. Check the vocabulary contains the mask token",
tokenizer.getMaskToken()
);
}

int maskTokenIndex = -1;
int maskTokenId = tokenizer.getMaskTokenId().getAsInt();
for (int i = 0; i < tokenization.getTokenizations().get(0).getTokenIds().length; i++) {
if (tokenization.getTokenizations().get(0).getTokenIds()[i] == maskTokenId) {
maskTokenIndex = i;
break;
}
}
if (maskTokenIndex == -1) {
return new WarningInferenceResults(
"mask token id [{}] not found in the tokenization {}",
maskTokenId,
Arrays.asList(tokenization.getTokenizations().get(0).getTokenIds())
);
}

// TODO - process all results in the batch
double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);

@@ -103,7 +133,7 @@ static InferenceResults processResult(
tokenization.getTokenizations()
.get(0)
.getInput()
.replace(BertTokenizer.MASK_TOKEN, tokenization.getFromVocab(scoreAndIndices[0].index)),
.replace(tokenizer.getMaskToken(), tokenization.getFromVocab(scoreAndIndices[0].index)),
results,
Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
scoreAndIndices[0].score,
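The fill-mask processor above now locates the mask position by scanning the token ids for the tokenizer's mask token id instead of string-matching a hard-coded [MASK] literal against normalised tokens, so tokenizers with a different mask literal also work. An equivalent compact sketch of the lookup (plain arrays, hypothetical class name):

import java.util.OptionalInt;
import java.util.stream.IntStream;

class MaskLookup {
    // Index of the first token whose id equals maskTokenId, or empty when
    // the sequence contains no mask token.
    static OptionalInt findMaskIndex(int[] tokenIds, int maskTokenId) {
        return IntStream.range(0, tokenIds.length)
            .filter(i -> tokenIds[i] == maskTokenId)
            .findFirst();
    }
}

The explicit for-loop in the diff behaves identically; an OptionalInt simply makes the not-found case, which the diff reports as a warning, explicit in the return type.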
@@ -16,6 +16,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

@@ -193,7 +194,7 @@ static class NerResultProcessor implements NlpTask.ResultProcessor {

@Override
public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) {
if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
return new WarningInferenceResults("no valid tokens to build result");
}
// TODO - process all results in the batch
@@ -213,6 +214,7 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe
? tokenization.getTokenizations().get(0).getInput().toLowerCase(Locale.ROOT)
: tokenization.getTokenizations().get(0).getInput()
);

return new NerResults(
resultsField,
buildAnnotatedText(tokenization.getTokenizations().get(0).getInput(), entities),
@@ -230,23 +232,17 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe
static List<TaggedToken> tagTokens(TokenizationResult.Tokenization tokenization, double[][] scores, IobTag[] iobMap) {
List<TaggedToken> taggedTokens = new ArrayList<>();
int startTokenIndex = 0;
while (startTokenIndex < tokenization.getTokens().length) {
while (startTokenIndex < tokenization.getTokenIds().length) {
int inputMapping = tokenization.getTokenMap()[startTokenIndex];
if (inputMapping < 0) {
// This token does not map to a token in the input (special tokens)
startTokenIndex++;
continue;
}
int endTokenIndex = startTokenIndex;
StringBuilder word = new StringBuilder(tokenization.getTokens()[startTokenIndex]);
while (endTokenIndex < tokenization.getTokens().length - 1
while (endTokenIndex < tokenization.getTokenMap().length - 1
&& tokenization.getTokenMap()[endTokenIndex + 1] == inputMapping) {
endTokenIndex++;
// TODO Here we try to get rid of the continuation hashes at the beginning of sub-tokens.
// It is probably more correct to implement detokenization on the tokenizer
// that does reverse lookup based on token IDs.
String endTokenWord = tokenization.getTokens()[endTokenIndex].substring(2);
word.append(endTokenWord);
}
double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length);
for (int i = startTokenIndex + 1; i <= endTokenIndex; i++) {
@@ -262,7 +258,7 @@ static List<TaggedToken> tagTokens(TokenizationResult.Tokenization tokenization,
}
int maxScoreIndex = NlpHelpers.argmax(avgScores);
double score = avgScores[maxScoreIndex];
taggedTokens.add(new TaggedToken(word.toString(), iobMap[maxScoreIndex], score));
taggedTokens.add(new TaggedToken(tokenization.getTokens().get(inputMapping), iobMap[maxScoreIndex], score));
startTokenIndex = endTokenIndex + 1;
}
return taggedTokens;
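To make the tagTokens flow above concrete: sub-word pieces that share a tokenMap entry belong to one input word, their per-label scores are averaged, and the word is now looked up through that map entry instead of being re-assembled from "##" continuation pieces. A simplified, self-contained sketch under assumed types (hypothetical names):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class SubTokenAveraging {
    static List<String> tagWords(List<String> inputWords, int[] tokenMap, double[][] scores, String[] labels) {
        List<String> tagged = new ArrayList<>();
        int start = 0;
        while (start < tokenMap.length) {
            int word = tokenMap[start];
            if (word < 0) { // special token ([CLS], [SEP], ...) with no source word
                start++;
                continue;
            }
            int end = start;
            while (end < tokenMap.length - 1 && tokenMap[end + 1] == word) {
                end++; // extend over sub-word pieces of the same input word
            }
            double[] avg = Arrays.copyOf(scores[start], labels.length);
            for (int i = start + 1; i <= end; i++) {
                for (int j = 0; j < labels.length; j++) {
                    avg[j] += scores[i][j];
                }
            }
            for (int j = 0; j < labels.length; j++) {
                avg[j] /= (end - start + 1);
            }
            tagged.add(inputWords.get(word) + "/" + labels[argmax(avg)]);
            start = end + 1;
        }
        return tagged;
    }

    static int argmax(double[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) {
                best = i;
            }
        }
        return best;
    }
}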
@@ -283,58 +279,63 @@ static List<NerResults.EntityGroup> groupTaggedTokens(List<TaggedToken> tokens,
}
List<NerResults.EntityGroup> entities = new ArrayList<>();
int startTokenIndex = 0;
int startFindInSeq = 0;
while (startTokenIndex < tokens.size()) {
TaggedToken token = tokens.get(startTokenIndex);
if (token.tag.getEntity() == Entity.NONE) {
startTokenIndex++;
continue;
}
StringBuilder entityWord = new StringBuilder(token.word);
int endTokenIndex = startTokenIndex + 1;
double scoreSum = token.score;
while (endTokenIndex < tokens.size()) {
TaggedToken endToken = tokens.get(endTokenIndex);
if (endToken.tag.isBeginning() || endToken.tag.getEntity() != token.tag.getEntity()) {
break;
}
// TODO Here we add a space between tokens.
// It is probably more correct to implement detokenization on the tokenizer
// that does reverse lookup based on token IDs.
entityWord.append(" ").append(endToken.word);
scoreSum += endToken.score;
endTokenIndex++;
}
String entity = entityWord.toString();
int i = inputSeq.indexOf(entity, startFindInSeq);

int startPos = token.token.getStartPos();
int endPos = tokens.get(endTokenIndex - 1).token.getEndPos();
String entity = inputSeq.substring(startPos, endPos);
entities.add(
new NerResults.EntityGroup(
entity,
token.tag.getEntity().toString(),
scoreSum / (endTokenIndex - startTokenIndex),
i,
i == -1 ? -1 : i + entity.length()
startPos,
endPos
)
);
startTokenIndex = endTokenIndex;
if (i != -1) {
startFindInSeq = i + entity.length();
}
}

return entities;
}

static class TaggedToken {
private final String word;
private final DelimitedToken token;
private final IobTag tag;
private final double score;

TaggedToken(String word, IobTag tag, double score) {
this.word = word;
TaggedToken(DelimitedToken token, IobTag tag, double score) {
this.token = token;
this.tag = tag;
this.score = score;
}

@Override
public String toString() {
return new StringBuilder("{").append("token:")
.append(token)
.append(", ")
.append(tag)
.append(", ")
.append(score)
.append("}")
.toString();
}
}
}
}
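The decisive fix sits in groupTaggedTokens above: rather than concatenating token strings with spaces and searching for the result in the input, which could miss or mismatch once normalisation changed case or accents, the entity value is now inputSeq.substring(startPos, endPos) using the offsets of the first and last tagged tokens. A reduced sketch of the grouping step (hypothetical simplified types; the real code uses IobTag and DelimitedToken):

import java.util.ArrayList;
import java.util.List;

class EntityGroupingSketch {

    record Tagged(int startPos, int endPos, String entity, boolean beginning, double score) {}

    record Group(String text, String entity, double score, int startPos, int endPos) {}

    static List<Group> group(String inputSeq, List<Tagged> tokens) {
        List<Group> groups = new ArrayList<>();
        int i = 0;
        while (i < tokens.size()) {
            Tagged first = tokens.get(i);
            if (first.entity() == null) { // outside any entity (the "O" tag)
                i++;
                continue;
            }
            int j = i + 1;
            double scoreSum = first.score();
            // Extend while the tag continues the same entity and is not a new "B-" tag.
            while (j < tokens.size()
                && tokens.get(j).beginning() == false
                && first.entity().equals(tokens.get(j).entity())) {
                scoreSum += tokens.get(j).score();
                j++;
            }
            int startPos = first.startPos();
            int endPos = tokens.get(j - 1).endPos();
            // Exact offsets: no indexOf search, no -1 sentinel, and the
            // substring keeps the source casing and accents.
            groups.add(new Group(inputSeq.substring(startPos, endPos), first.entity(), scoreSum / (j - i), startPos, endPos));
            i = j;
        }
        return groups;
    }
}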
