-
Notifications
You must be signed in to change notification settings - Fork 24.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] add support for distilbert pytorch models (#76679)
This commit adds support for distilbert pytorch models. While the tokenization itself is exactly the same as bert, the parameters sent to the model are different. DistilBERT does not require the segment mask or positional IDs to be sent. Only the input mask and token ids. But, since the effective output of the tokenization sent to the model is different, I opted to consider it as a unique tokenizer, inheriting from our bert implementation. The API now looks like: for BERT models ```js "inference_config": { "ner": { "vocabulary": {/*...*/}, "tokenization": { "bert": {/*...*/} } } } ``` For DistilBERT models ```js "inference_config": { "ner": { "vocabulary": {/*...*/}, "tokenization": { "distil_bert": {/*...*/} } } } ```
- Loading branch information
Showing
36 changed files
with
820 additions
and
269 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
...rc/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenization.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel; | ||
|
||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.common.xcontent.ParseField; | ||
import org.elasticsearch.common.xcontent.XContentBuilder; | ||
import org.elasticsearch.common.xcontent.XContentParser; | ||
import org.elasticsearch.core.Nullable; | ||
|
||
import java.io.IOException; | ||
|
||
public class BertTokenization extends Tokenization { | ||
|
||
public static final ParseField NAME = new ParseField("bert"); | ||
|
||
public static ConstructingObjectParser<BertTokenization, Void> createParser(boolean ignoreUnknownFields) { | ||
ConstructingObjectParser<BertTokenization, Void> parser = new ConstructingObjectParser<>( | ||
"bert_tokenization", | ||
ignoreUnknownFields, | ||
a -> new BertTokenization((Boolean) a[0], (Boolean) a[1], (Integer) a[2]) | ||
); | ||
Tokenization.declareCommonFields(parser); | ||
return parser; | ||
} | ||
|
||
private static final ConstructingObjectParser<BertTokenization, Void> LENIENT_PARSER = createParser(true); | ||
private static final ConstructingObjectParser<BertTokenization, Void> STRICT_PARSER = createParser(false); | ||
|
||
public static BertTokenization fromXContent(XContentParser parser, boolean lenient) { | ||
return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); | ||
} | ||
|
||
public BertTokenization(@Nullable Boolean doLowerCase, @Nullable Boolean withSpecialTokens, @Nullable Integer maxSequenceLength) { | ||
super(doLowerCase, withSpecialTokens, maxSequenceLength); | ||
} | ||
|
||
public BertTokenization(StreamInput in) throws IOException { | ||
super(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
} | ||
|
||
XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { | ||
return builder; | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return NAME.getPreferredName(); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME.getPreferredName(); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (o == null || getClass() != o.getClass()) return false; | ||
return super.equals(o); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return super.hashCode(); | ||
} | ||
} |
82 changes: 82 additions & 0 deletions
82
...n/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenization.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel; | ||
|
||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.common.xcontent.ParseField; | ||
import org.elasticsearch.common.xcontent.XContentBuilder; | ||
import org.elasticsearch.common.xcontent.XContentParser; | ||
import org.elasticsearch.core.Nullable; | ||
|
||
import java.io.IOException; | ||
|
||
public class DistilBertTokenization extends Tokenization { | ||
|
||
public static final ParseField NAME = new ParseField("distil_bert"); | ||
|
||
public static ConstructingObjectParser<DistilBertTokenization, Void> createParser(boolean ignoreUnknownFields) { | ||
ConstructingObjectParser<DistilBertTokenization, Void> parser = new ConstructingObjectParser<>( | ||
"distil_bert_tokenization", | ||
ignoreUnknownFields, | ||
a -> new DistilBertTokenization((Boolean) a[0], (Boolean) a[1], (Integer) a[2]) | ||
); | ||
Tokenization.declareCommonFields(parser); | ||
return parser; | ||
} | ||
|
||
private static final ConstructingObjectParser<DistilBertTokenization, Void> LENIENT_PARSER = createParser(true); | ||
private static final ConstructingObjectParser<DistilBertTokenization, Void> STRICT_PARSER = createParser(false); | ||
|
||
public static DistilBertTokenization fromXContent(XContentParser parser, boolean lenient) { | ||
return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); | ||
} | ||
|
||
public DistilBertTokenization( | ||
@Nullable Boolean doLowerCase, | ||
@Nullable Boolean withSpecialTokens, | ||
@Nullable Integer maxSequenceLength | ||
) { | ||
super(doLowerCase, withSpecialTokens, maxSequenceLength); | ||
} | ||
|
||
public DistilBertTokenization(StreamInput in) throws IOException { | ||
super(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
} | ||
|
||
XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { | ||
return builder; | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return NAME.getPreferredName(); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME.getPreferredName(); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (o == null || getClass() != o.getClass()) return false; | ||
return super.equals(o); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return super.hashCode(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.