[ML] add new truncate parameter to tokenization (#79515)
This commit adds a new `truncate` parameter to tokenization.

Valid values are:
first: truncate only the first sequence (if two are provided)
second: truncate only the second sequence (if two are provided)
none: do not truncate; an error is thrown when the sequences are too long
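For illustration only (not part of this commit's diff), a minimal sketch of passing the new parameter through the constructor added below; null arguments fall back to the defaults declared in `Tokenization` (do_lower_case=false, with_special_tokens=true, max_sequence_length=512, truncate=first):

```java
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

// Sketch only: constructing a BertTokenization with the new fourth argument.
Tokenization config = new BertTokenization(
    null,                           // doLowerCase: use the default (false)
    null,                           // withSpecialTokens: use the default (true)
    512,                            // maxSequenceLength
    Tokenization.Truncate.SECOND    // truncate the second sequence when input is too long
);
```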
benwtrent committed Oct 21, 2021
1 parent c369288 commit 24c659e
Showing 12 changed files with 263 additions and 68 deletions.
@@ -25,7 +25,12 @@ public static ConstructingObjectParser<BertTokenization, Void> createParser(bool
ConstructingObjectParser<BertTokenization, Void> parser = new ConstructingObjectParser<>(
"bert_tokenization",
ignoreUnknownFields,
a -> new BertTokenization((Boolean) a[0], (Boolean) a[1], (Integer) a[2])
a -> new BertTokenization(
(Boolean) a[0],
(Boolean) a[1],
(Integer) a[2],
a[3] == null ? null : Truncate.fromString((String)a[3])
)
);
Tokenization.declareCommonFields(parser);
return parser;
@@ -38,8 +43,13 @@ public static BertTokenization fromXContent(XContentParser parser, boolean lenie
return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
}

public BertTokenization(@Nullable Boolean doLowerCase, @Nullable Boolean withSpecialTokens, @Nullable Integer maxSequenceLength) {
super(doLowerCase, withSpecialTokens, maxSequenceLength);
public BertTokenization(
@Nullable Boolean doLowerCase,
@Nullable Boolean withSpecialTokens,
@Nullable Integer maxSequenceLength,
@Nullable Truncate truncate
) {
super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate);
}

public BertTokenization(StreamInput in) throws IOException {
@@ -17,54 +17,82 @@
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

public abstract class Tokenization implements NamedXContentObject, NamedWriteable {

public enum Truncate {
FIRST,
SECOND,
NONE;

public static Truncate fromString(String value) {
return valueOf(value.toUpperCase(Locale.ROOT));
}

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}

//TODO add global params like never_split, bos_token, eos_token, mask_token, tokenize_chinese_chars, strip_accents, etc.
public static final ParseField DO_LOWER_CASE = new ParseField("do_lower_case");
public static final ParseField WITH_SPECIAL_TOKENS = new ParseField("with_special_tokens");
public static final ParseField MAX_SEQUENCE_LENGTH = new ParseField("max_sequence_length");
public static final ParseField TRUNCATE = new ParseField("truncate");

private static final int DEFAULT_MAX_SEQUENCE_LENGTH = 512;
private static final boolean DEFAULT_DO_LOWER_CASE = false;
private static final boolean DEFAULT_WITH_SPECIAL_TOKENS = true;
private static final Truncate DEFAULT_TRUNCATION = Truncate.FIRST;

static <T extends Tokenization> void declareCommonFields(ConstructingObjectParser<T, ?> parser) {
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DO_LOWER_CASE);
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), WITH_SPECIAL_TOKENS);
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_SEQUENCE_LENGTH);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), TRUNCATE);
}

public static BertTokenization createDefault() {
return new BertTokenization(null, null, null);
return new BertTokenization(null, null, null, Truncate.FIRST);
}

protected final boolean doLowerCase;
protected final boolean withSpecialTokens;
protected final int maxSequenceLength;

Tokenization(@Nullable Boolean doLowerCase, @Nullable Boolean withSpecialTokens, @Nullable Integer maxSequenceLength) {
protected final Truncate truncate;

Tokenization(
@Nullable Boolean doLowerCase,
@Nullable Boolean withSpecialTokens,
@Nullable Integer maxSequenceLength,
@Nullable Truncate truncate
) {
if (maxSequenceLength != null && maxSequenceLength <= 0) {
throw new IllegalArgumentException("[" + MAX_SEQUENCE_LENGTH.getPreferredName() + "] must be positive");
}
this.doLowerCase = Optional.ofNullable(doLowerCase).orElse(DEFAULT_DO_LOWER_CASE);
this.withSpecialTokens = Optional.ofNullable(withSpecialTokens).orElse(DEFAULT_WITH_SPECIAL_TOKENS);
this.maxSequenceLength = Optional.ofNullable(maxSequenceLength).orElse(DEFAULT_MAX_SEQUENCE_LENGTH);
this.truncate = Optional.ofNullable(truncate).orElse(DEFAULT_TRUNCATION);
}

public Tokenization(StreamInput in) throws IOException {
this.doLowerCase = in.readBoolean();
this.withSpecialTokens = in.readBoolean();
this.maxSequenceLength = in.readVInt();
this.truncate = in.readEnum(Truncate.class);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(doLowerCase);
out.writeBoolean(withSpecialTokens);
out.writeVInt(maxSequenceLength);
out.writeEnum(truncate);
}

abstract XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException;
@@ -75,6 +103,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(DO_LOWER_CASE.getPreferredName(), doLowerCase);
builder.field(WITH_SPECIAL_TOKENS.getPreferredName(), withSpecialTokens);
builder.field(MAX_SEQUENCE_LENGTH.getPreferredName(), maxSequenceLength);
builder.field(TRUNCATE.getPreferredName(), truncate.toString());
builder = doXContentBody(builder, params);
builder.endObject();
return builder;
@@ -87,12 +116,13 @@ public boolean equals(Object o) {
Tokenization that = (Tokenization) o;
return doLowerCase == that.doLowerCase
&& withSpecialTokens == that.withSpecialTokens
&& truncate == that.truncate
&& maxSequenceLength == that.maxSequenceLength;
}

@Override
public int hashCode() {
return Objects.hash(doLowerCase, withSpecialTokens, maxSequenceLength);
return Objects.hash(doLowerCase, truncate, withSpecialTokens, maxSequenceLength);
}

public boolean doLowerCase() {
@@ -106,4 +136,8 @@ public boolean withSpecialTokens() {
public int maxSequenceLength() {
return maxSequenceLength;
}

public Truncate getTruncate() {
return truncate;
}
}
@@ -48,7 +48,8 @@ public static BertTokenization createRandom() {
return new BertTokenization(
randomBoolean() ? null : randomBoolean(),
randomBoolean() ? null : randomBoolean(),
randomBoolean() ? null : randomIntBetween(1, 1024)
randomBoolean() ? null : randomIntBetween(1, 1024),
randomBoolean() ? null : randomFrom(Tokenization.Truncate.values())
);
}
}
@@ -26,6 +26,7 @@
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@@ -276,7 +277,9 @@ private void putAndStartModelDeployment(String modelId, long memoryUse, Allocati
new PutTrainedModelAction.Request(
TrainedModelConfig.builder()
.setModelType(TrainedModelType.PYTORCH)
.setInferenceConfig(new PassThroughConfig(null, new BertTokenization(null, false, null), null))
.setInferenceConfig(
new PassThroughConfig(null, new BertTokenization(null, false, null, Tokenization.Truncate.NONE), null)
)
.setModelId(modelId)
.build(),
false
@@ -13,7 +13,6 @@
import org.elasticsearch.action.ingest.DeletePipelineRequest;
import org.elasticsearch.action.ingest.PutPipelineAction;
import org.elasticsearch.action.ingest.PutPipelineRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.xcontent.XContentType;
@@ -31,10 +30,9 @@
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
@@ -200,7 +198,7 @@ void createModelDeployment() {
.setInferenceConfig(
new PassThroughConfig(
null,
new BertTokenization(null, false, null),
new BertTokenization(null, false, null, Tokenization.Truncate.NONE),
null
)
)
@@ -16,6 +16,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.junit.Before;
@@ -72,7 +73,7 @@ public void testPutTrainedModelAndDefinition() {
new VocabularyConfig(
InferenceIndexConstants.nativeDefinitionStore()
),
new BertTokenization(null, false, null),
new BertTokenization(null, false, null, Tokenization.Truncate.NONE),
null
)
)
@@ -52,6 +52,7 @@ public class BertTokenizer implements NlpTokenizer {
private final boolean doTokenizeCjKChars;
private final boolean doStripAccents;
private final boolean withSpecialTokens;
private final Tokenization.Truncate truncate;
private final Set<String> neverSplit;
private final int maxSequenceLength;
private final NlpTask.RequestBuilder requestBuilder;
@@ -62,6 +63,7 @@ protected BertTokenizer(List<String> originalVocab,
boolean doTokenizeCjKChars,
boolean doStripAccents,
boolean withSpecialTokens,
Tokenization.Truncate truncate,
int maxSequenceLength,
Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
Set<String> neverSplit) {
@@ -72,6 +74,7 @@ protected BertTokenizer(List<String> originalVocab,
this.doTokenizeCjKChars = doTokenizeCjKChars;
this.doStripAccents = doStripAccents;
this.withSpecialTokens = withSpecialTokens;
this.truncate = truncate;
this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
this.maxSequenceLength = maxSequenceLength;
this.requestBuilder = requestBuilderFactory.apply(this);
@@ -113,6 +116,21 @@ public TokenizationResult.Tokenization tokenize(String seq) {
List<WordPieceTokenizer.TokenAndId> wordPieceTokens = innerResult.v1();
List<Integer> tokenPositionMap = innerResult.v2();
int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size();
if (numTokens > maxSequenceLength) {
switch (truncate) {
case FIRST:
case SECOND:
wordPieceTokens = wordPieceTokens.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength);
break;
case NONE:
throw ExceptionsHelper.badRequestException(
"Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
numTokens,
maxSequenceLength
);
}
numTokens = maxSequenceLength;
}
String[] tokens = new String[numTokens];
int[] tokenIds = new int[numTokens];
int[] tokenMap = new int[numTokens];
@@ -128,7 +146,7 @@ public TokenizationResult.Tokenization tokenize(String seq) {
for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokens) {
tokens[i] = tokenAndId.getToken();
tokenIds[i] = tokenAndId.getId();
tokenMap[i] = tokenPositionMap.get(i-decrementHandler);
tokenMap[i] = tokenPositionMap.get(i - decrementHandler);
i++;
}

@@ -138,13 +156,6 @@ public TokenizationResult.Tokenization tokenize(String seq) {
tokenMap[i] = SPECIAL_TOKEN_POSITION;
}

if (tokenIds.length > maxSequenceLength) {
throw ExceptionsHelper.badRequestException(
"Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
tokenIds.length,
maxSequenceLength
);
}
return new TokenizationResult.Tokenization(seq, tokens, tokenIds, tokenMap);
}

@@ -161,6 +172,44 @@ public TokenizationResult.Tokenization tokenize(String seq1, String seq2) {
}
// [CLS] seq1 [SEP] seq2 [SEP]
int numTokens = wordPieceTokenSeq1s.size() + wordPieceTokenSeq2s.size() + 3;

if (numTokens > maxSequenceLength) {
switch (truncate) {
case FIRST:
if (wordPieceTokenSeq2s.size() > maxSequenceLength - 3) {
throw ExceptionsHelper.badRequestException(
"Attempting truncation [{}] but input is too large for the second sequence. " +
"The tokenized input length [{}] exceeds the maximum sequence length [{}], " +
"when taking special tokens into account",
truncate.toString(),
wordPieceTokenSeq2s.size(),
maxSequenceLength - 3
);
}
wordPieceTokenSeq1s = wordPieceTokenSeq1s.subList(0, maxSequenceLength - 3 - wordPieceTokenSeq2s.size());
break;
case SECOND:
if (wordPieceTokenSeq1s.size() > maxSequenceLength - 3) {
throw ExceptionsHelper.badRequestException(
"Attempting truncation [{}] but input is too large for the first sequence. " +
"The tokenized input length [{}] exceeds the maximum sequence length [{}], " +
"when taking special tokens into account",
truncate.toString(),
wordPieceTokenSeq1s.size(),
maxSequenceLength - 3
);
}
wordPieceTokenSeq2s = wordPieceTokenSeq2s.subList(0, maxSequenceLength - 3 - wordPieceTokenSeq1s.size());
break;
case NONE:
throw ExceptionsHelper.badRequestException(
"Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
numTokens,
maxSequenceLength
);
}
numTokens = maxSequenceLength;
}
String[] tokens = new String[numTokens];
int[] tokenIds = new int[numTokens];
int[] tokenMap = new int[numTokens];
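Aside, not part of the diff: in the two-sequence case the usable budget is maxSequenceLength - 3, because the layout [CLS] seq1 [SEP] seq2 [SEP] spends three positions on special tokens. A hypothetical worked example of the arithmetic the FIRST branch above performs:

```java
// Hypothetical numbers; assumes Truncate.FIRST and two input sequences.
int maxSequenceLength = 512;
int seq1Tokens = 700;                    // word pieces in the first sequence
int seq2Tokens = 100;                    // word pieces in the second sequence
int budget = maxSequenceLength - 3;      // 509 positions left for real tokens
int keptFromSeq1 = budget - seq2Tokens;  // 409: the first sequence is cut to fit
// If seq2Tokens alone exceeded the budget, the FIRST branch would throw instead,
// since truncating the first sequence could never make the pair fit.
```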
@@ -247,6 +296,7 @@ public static class Builder {
protected boolean doLowerCase = false;
protected boolean doTokenizeCjKChars = true;
protected boolean withSpecialTokens = true;
protected Tokenization.Truncate truncate = Tokenization.Truncate.FIRST;
protected int maxSequenceLength;
protected Boolean doStripAccents = null;
protected Set<String> neverSplit;
@@ -258,6 +308,7 @@ protected Builder(List<String> vocab, Tokenization tokenization) {
this.doLowerCase = tokenization.doLowerCase();
this.withSpecialTokens = tokenization.withSpecialTokens();
this.maxSequenceLength = tokenization.maxSequenceLength();
this.truncate = tokenization.getTruncate();
}

private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
Expand Down Expand Up @@ -308,6 +359,11 @@ public Builder setRequestBuilderFactory(Function<BertTokenizer, NlpTask.RequestB
return this;
}

public Builder setTruncate(Tokenization.Truncate truncate) {
this.truncate = truncate;
return this;
}

public BertTokenizer build() {
// if not set strip accents defaults to the value of doLowerCase
if (doStripAccents == null) {
Expand All @@ -325,6 +381,7 @@ public BertTokenizer build() {
doTokenizeCjKChars,
doStripAccents,
withSpecialTokens,
truncate,
maxSequenceLength,
requestBuilderFactory,
neverSplit
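Finally, a hedged usage sketch (not code from the commit): the Builder copies the truncate setting from the Tokenization config, and the new setTruncate method overrides it; the builder entry point shown here is assumed rather than taken from the diff.

```java
import java.util.List;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

// Sketch only: BertTokenizer.builder(...) is an assumed entry point; setTruncate(...)
// is the method added in this commit, and Tokenization.createDefault() comes from
// the diff above (it now defaults to Truncate.FIRST).
List<String> vocabulary = List.of("[PAD]", "[UNK]", "[CLS]", "[SEP]", "elastic", "##search");
BertTokenizer tokenizer = BertTokenizer.builder(vocabulary, Tokenization.createDefault()) // assumed factory
    .setTruncate(Tokenization.Truncate.NONE) // fail loudly instead of silently truncating
    .build();
```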
