[ML] add windowing support for text_classification (#83989)
This commit adds initial windowing support for text_classification tasks.

Specifically, a user can now set a non-negative `span` value that controls how many tokens overlap between the
sub-sequences created when tokenization windows a long text sequence.

The default value, `span: -1`, indicates that no windowing should take place.
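For illustration only (this helper is not part of the commit), here is a minimal Java sketch of how overlapping
windows could be derived from a tokenized sequence, assuming `span` is the token overlap and is smaller than
`max_sequence_length`; a real tokenizer would also need to budget for special tokens, which the sketch ignores:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class WindowingSketch {
    /**
     * Splits tokenIds into windows of at most maxSequenceLength tokens, where
     * consecutive windows overlap by `span` tokens. A span of -1 disables
     * windowing and returns the whole sequence as a single window.
     * Assumes 0 <= span < maxSequenceLength when windowing is enabled.
     */
    static List<int[]> window(int[] tokenIds, int maxSequenceLength, int span) {
        List<int[]> windows = new ArrayList<>();
        if (span == -1 || tokenIds.length <= maxSequenceLength) {
            windows.add(tokenIds);
            return windows;
        }
        // Each new window starts `maxSequenceLength - span` tokens after the previous one.
        int stride = maxSequenceLength - span;
        for (int start = 0; start < tokenIds.length; start += stride) {
            int end = Math.min(start + maxSequenceLength, tokenIds.length);
            windows.add(Arrays.copyOfRange(tokenIds, start, end));
            if (end == tokenIds.length) {
                break;
            }
        }
        return windows;
    }
}
```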
benwtrent committed Mar 1, 2022
1 parent beb7c9e commit 45deac4
Showing 48 changed files with 634 additions and 179 deletions.
4 changes: 4 additions & 0 deletions docs/reference/ingest/processors/inference.asciidoc
@@ -197,6 +197,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
[%collapsible%open]
=======

`span`::::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-span]

`truncate`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
11 changes: 11 additions & 0 deletions docs/reference/ml/ml-shared.asciidoc
@@ -951,6 +951,17 @@ Specifies if the tokenization lower cases the text sequence when building the
tokens.
end::inference-config-nlp-tokenization-bert-do-lower-case[]

tag::inference-config-nlp-tokenization-bert-span[]
When `truncate` is `none`, you can partition longer text sequences
for inference. The value indicates how many tokens overlap between each
subsequence.
+
The default value is `-1`, indicating no windowing or spanning occurs.
+
NOTE: When your typical input is just slightly larger than `max_sequence_length`, it may be best to simply truncate;
there will be very little information in the second subsequence.
end::inference-config-nlp-tokenization-bert-span[]

tag::inference-config-nlp-tokenization-bert-truncate[]
Indicates how tokens are truncated when they exceed `max_sequence_length`.
The default value is `first`.
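To make the note about `span` concrete: with a `max_sequence_length` of 512 and a `span` of 128, a 600-token input
yields windows covering tokens [0, 512) and [384, 600), so the second window contributes only 88 previously unseen
tokens. Below is a hedged usage example of the hypothetical `WindowingSketch.window` helper sketched earlier; it is
not part of this change:

```java
public class WindowingSketchDemo {
    public static void main(String[] args) {
        int[] tokenIds = new int[600]; // stand-in for a 600-token tokenized input
        // Window with max_sequence_length = 512 and span (overlap) = 128.
        java.util.List<int[]> windows = WindowingSketch.window(tokenIds, 512, 128);
        for (int[] w : windows) {
            System.out.println("window length: " + w.length);
        }
        // Prints 512 and 216: the second window repeats 128 tokens from the
        // first window and adds only the final 88 tokens of the input.
    }
}
```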
(next changed file; path not shown)
@@ -446,6 +446,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
`span`::::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-span]
`truncate`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
@@ -469,6 +473,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
`span`::::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-span]
`truncate`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
(next changed file; path not shown)
@@ -667,6 +667,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]

`span`::::
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-span]

`truncate`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
(next changed file; path not shown)
@@ -29,7 +29,8 @@ public static ConstructingObjectParser<BertTokenization, Void> createParser(bool
(Boolean) a[0],
(Boolean) a[1],
(Integer) a[2],
a[3] == null ? null : Truncate.fromString((String) a[3])
a[3] == null ? null : Truncate.fromString((String) a[3]),
(Integer) a[4]
)
);
Tokenization.declareCommonFields(parser);
@@ -47,9 +48,10 @@ public BertTokenization(
@Nullable Boolean doLowerCase,
@Nullable Boolean withSpecialTokens,
@Nullable Integer maxSequenceLength,
@Nullable Truncate truncate
@Nullable Truncate truncate,
@Nullable Integer span
) {
super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate);
super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate, span);
}

public BertTokenization(StreamInput in) throws IOException {
(next changed file; path not shown)
@@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
@@ -18,32 +19,41 @@

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class BertTokenizationUpdate implements TokenizationUpdate {

public static final ParseField NAME = BertTokenization.NAME;

public static ConstructingObjectParser<BertTokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
"bert_tokenization_update",
a -> new BertTokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]))
a -> new BertTokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]), (Integer) a[1])
);

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
}

public static BertTokenizationUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final Tokenization.Truncate truncate;
private final Integer span;

public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate) {
public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
}

public BertTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
}

@Override
@@ -64,19 +74,25 @@ public Tokenization apply(Tokenization originalConfig) {
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
this.truncate
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
);
}

@Override
public boolean isNoop() {
return truncate == null;
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}
@@ -89,6 +105,9 @@ public String getWriteableName() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

@Override
@@ -101,11 +120,11 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BertTokenizationUpdate that = (BertTokenizationUpdate) o;
return truncate == that.truncate;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate);
return Objects.hash(truncate, span);
}
}
(next changed file; path not shown)
@@ -75,6 +75,13 @@ public FillMaskConfig(
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_RESULTS : numTopClasses;
this.resultsField = resultsField;
if (this.tokenization.span != -1) {
throw ExceptionsHelper.badRequestException(
"[{}] does not support windowing long text sequences; configured span [{}]",
NAME,
this.tokenization.span
);
}
}

public FillMaskConfig(StreamInput in) throws IOException {
(next changed file; path not shown)
@@ -29,7 +29,8 @@ public static ConstructingObjectParser<MPNetTokenization, Void> createParser(boo
(Boolean) a[0],
(Boolean) a[1],
(Integer) a[2],
a[3] == null ? null : Truncate.fromString((String) a[3])
a[3] == null ? null : Truncate.fromString((String) a[3]),
(Integer) a[4]
)
);
Tokenization.declareCommonFields(parser);
@@ -47,9 +48,10 @@ public MPNetTokenization(
@Nullable Boolean doLowerCase,
@Nullable Boolean withSpecialTokens,
@Nullable Integer maxSequenceLength,
@Nullable Truncate truncate
@Nullable Truncate truncate,
@Nullable Integer span
) {
super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate);
super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate, span);
}

public MPNetTokenization(StreamInput in) throws IOException {
(next changed file; path not shown)
@@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
@@ -18,32 +19,41 @@

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class MPNetTokenizationUpdate implements TokenizationUpdate {

public static final ParseField NAME = MPNetTokenization.NAME;

public static ConstructingObjectParser<MPNetTokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
"mpnet_tokenization_update",
a -> new MPNetTokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]))
a -> new MPNetTokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]), (Integer) a[1])
);

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
}

public static MPNetTokenizationUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final Tokenization.Truncate truncate;
private final Integer span;

public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate) {
public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
}

public MPNetTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
}

@Override
@@ -64,19 +74,25 @@ public Tokenization apply(Tokenization originalConfig) {
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
this.truncate
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
);
}

@Override
public boolean isNoop() {
return truncate == null;
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}
@@ -89,6 +105,9 @@ public String getWriteableName() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

@Override
@@ -101,11 +120,11 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MPNetTokenizationUpdate that = (MPNetTokenizationUpdate) o;
return truncate == that.truncate;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate);
return Objects.hash(truncate, span);
}
}
(next changed file; path not shown)
@@ -81,6 +81,13 @@ public NerConfig(
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
this.resultsField = resultsField;
if (this.tokenization.span != -1) {
throw ExceptionsHelper.badRequestException(
"[{}] does not support windowing long text sequences; configured span [{}]",
NAME,
this.tokenization.span
);
}
}

public NerConfig(StreamInput in) throws IOException {
(next changed file; path not shown)
@@ -19,7 +19,8 @@

import java.io.IOException;
import java.util.Map;
import java.util.function.Function;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

public abstract class NlpConfigUpdate implements InferenceConfigUpdate, NamedXContentObject {
@@ -31,15 +32,15 @@ public static TokenizationUpdate tokenizationFromMap(Map<String, Object> map) {
return null;
}

Map<String, Function<Tokenization.Truncate, TokenizationUpdate>> knownTokenizers = Map.of(
Map<String, BiFunction<Tokenization.Truncate, Integer, TokenizationUpdate>> knownTokenizers = Map.of(
BertTokenization.NAME.getPreferredName(),
BertTokenizationUpdate::new,
MPNetTokenization.NAME.getPreferredName(),
MPNetTokenizationUpdate::new
);

Map<String, Object> tokenizationConfig = null;
Function<Tokenization.Truncate, TokenizationUpdate> updater = null;
BiFunction<Tokenization.Truncate, Integer, TokenizationUpdate> updater = null;
for (var tokenizerType : knownTokenizers.keySet()) {
tokenizationConfig = (Map<String, Object>) tokenization.remove(tokenizerType);
if (tokenizationConfig != null) {
@@ -55,11 +56,17 @@ public static TokenizationUpdate tokenizationFromMap(Map<String, Object> map) {
tokenization.keySet()
);
}
Object truncate = tokenizationConfig.remove("truncate");
if (truncate == null) {
if (tokenizationConfig == null) {
return null;
}
return updater.apply(Tokenization.Truncate.fromString(truncate.toString()));
Tokenization.Truncate truncate = Optional.ofNullable(tokenizationConfig.remove("truncate"))
.map(t -> Tokenization.Truncate.fromString(t.toString()))
.orElse(null);
Integer span = (Integer) Optional.ofNullable(tokenizationConfig.remove("span")).orElse(null);
if (truncate == null && span == null) {
return null;
}
return updater.apply(truncate, span);
}

protected final TokenizationUpdate tokenizationUpdate;
