Skip to content

Commit

Permalink
[ML] Allow NLP truncate option to be updated when span is set (#91224)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 2, 2022
1 parent ecae222 commit defa765
Show file tree
Hide file tree
Showing 9 changed files with 407 additions and 189 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/91224.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 91224
summary: Allow NLP truncate option to be updated when span is set
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public abstract class AbstractTokenizationUpdate implements TokenizationUpdate {

private final Tokenization.Truncate truncate;
private final Integer span;

protected static void declareCommonParserFields(ConstructingObjectParser<? extends AbstractTokenizationUpdate, Void> parser) {
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
}

public AbstractTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
}

public AbstractTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
}

@Override
public boolean isNoop() {
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

public Integer getSpan() {
return span;
}

public Tokenization.Truncate getTruncate() {
return truncate;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o instanceof AbstractTokenizationUpdate == false) {
return false;
}
AbstractTokenizationUpdate that = (AbstractTokenizationUpdate) o;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate, span);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class BertTokenizationUpdate implements TokenizationUpdate {
public class BertTokenizationUpdate extends AbstractTokenizationUpdate {

public static final ParseField NAME = BertTokenization.NAME;

Expand All @@ -31,29 +27,19 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
);

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
declareCommonParserFields(PARSER);
}

public static BertTokenizationUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final Tokenization.Truncate truncate;
private final Integer span;

public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
super(truncate, span);
}

public BertTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
super(in);
}

@Override
Expand All @@ -66,65 +52,41 @@ public Tokenization apply(Tokenization originalConfig) {
);
}

Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());

if (isNoop()) {
return originalConfig;
}

if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
// When truncate value is incompatible with span wipe out
// the existing span setting to avoid an invalid combination of settings.
// This avoids the user have to set span to the special unset value
return new BertTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
getTruncate(),
null
);
}

return new BertTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
Optional.ofNullable(getTruncate()).orElse(originalConfig.getTruncate()),
Optional.ofNullable(getSpan()).orElse(originalConfig.getSpan())
);
}

@Override
public boolean isNoop() {
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return BertTokenization.NAME.getPreferredName();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

@Override
public String getName() {
return BertTokenization.NAME.getPreferredName();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BertTokenizationUpdate that = (BertTokenizationUpdate) o;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate, span);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class MPNetTokenizationUpdate implements TokenizationUpdate {
public class MPNetTokenizationUpdate extends AbstractTokenizationUpdate {

public static final ParseField NAME = MPNetTokenization.NAME;

Expand All @@ -31,29 +27,19 @@ public class MPNetTokenizationUpdate implements TokenizationUpdate {
);

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
declareCommonParserFields(PARSER);
}

public static MPNetTokenizationUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final Tokenization.Truncate truncate;
private final Integer span;

public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
super(truncate, span);
}

public MPNetTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
super(in);
}

@Override
Expand All @@ -70,61 +56,35 @@ public Tokenization apply(Tokenization originalConfig) {
return originalConfig;
}

if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
// When truncate value is incompatible with span wipe out
// the existing span setting to avoid an invalid combination of settings.
// This avoids the user have to set span to the special unset value
return new MPNetTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
getTruncate(),
null
);
}

return new MPNetTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan())
);
}

@Override
public boolean isNoop() {
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return MPNetTokenization.NAME.getPreferredName();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

@Override
public String getName() {
return MPNetTokenization.NAME.getPreferredName();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MPNetTokenizationUpdate that = (MPNetTokenizationUpdate) o;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate, span);
}
}

0 comments on commit defa765

Please sign in to comment.