/
MxBertQATranslator.java
195 lines (172 loc) · 6.35 KB
/
MxBertQATranslator.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.mxnet.zoo.nlp.qa;
import ai.djl.Model;
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertToken;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.modality.nlp.translator.QATranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.Utils;
import com.google.gson.annotations.SerializedName;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * The translator for the MXNet BERT question-answering model.
 *
 * <p>The question and paragraph are lower-cased, encoded with a {@link BertTokenizer} up to {@code
 * seqLength} tokens and mapped to vocabulary indices. The model output is split into start and end
 * logits, and the answer is rebuilt from the tokens between the two arg-max positions.
 *
 * <p>The tokens of the most recent input are kept in an instance field between {@link
 * #processInput} and {@link #processOutput}, so one instance is not safe for concurrent use.
 */
public class MxBertQATranslator extends QATranslator {

    /** Tokens of the most recently processed input; read back by {@link #processOutput}. */
    private List<String> tokens;

    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;
    private final int seqLength;

    MxBertQATranslator(Builder builder) {
        super(builder);
        seqLength = builder.seqLength;
    }

    /** {@inheritDoc} */
    @Override
    public void prepare(TranslatorContext ctx) throws IOException {
        Model model = ctx.getModel();
        // The vocabulary is read from the model's "vocab.json" artifact; VocabParser extracts its
        // "idx_to_token" list. Tokens not found in it map to "[UNK]".
        vocabulary =
                DefaultVocabulary.builder()
                        .addFromCustomizedFile(
                                model.getArtifact("vocab.json"), VocabParser::parseToken)
                        .optUnknownToken("[UNK]")
                        .build();
        tokenizer = new BertTokenizer();
    }

    /** {@inheritDoc} */
    @Override
    public Batchifier getBatchifier() {
        // MXNet BertQA model doesn't support batch. See NoBatchifyTranslator.
        return null;
    }

    /** {@inheritDoc} */
    @Override
    public NDList processInput(TranslatorContext ctx, QAInput input) {
        // Locale.ROOT makes lower-casing independent of the JVM default locale (e.g. Turkish
        // dotless i), which could otherwise yield tokens that are absent from the vocabulary.
        BertToken token =
                tokenizer.encode(
                        input.getQuestion().toLowerCase(Locale.ROOT),
                        input.getParagraph().toLowerCase(Locale.ROOT),
                        seqLength);
        tokens = token.getTokens();
        List<Long> indices =
                tokens.stream().map(vocabulary::getIndex).collect(Collectors.toList());
        float[] indexesFloat = Utils.toFloatArray(indices);
        float[] types = Utils.toFloatArray(token.getTokenTypes());
        int validLength = token.getValidLength();

        NDManager manager = ctx.getNDManager();
        NDArray data0 = manager.create(indexesFloat);
        data0.setName("data0");
        NDArray data1 = manager.create(types);
        data1.setName("data1");
        // Use a one-element array instead of a scalar: the MXNet BERT model was trained with
        // MXNet 1.5.0, which is not compatible with MXNet NumPy scalars.
        NDArray data2 = manager.create(new float[] {validLength});
        data2.setName("data2");
        return new NDList(data0, data1, data2);
    }

    /** {@inheritDoc} */
    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        NDArray array = list.singletonOrThrow();
        // The single output is split into two parts along axis 2: start logits, then end logits.
        NDList output = array.split(2, 2);
        NDArray startLogits = output.get(0).reshape(new Shape(1, -1));
        NDArray endLogits = output.get(1).reshape(new Shape(1, -1));
        int startIdx = (int) startLogits.argMax(1).getLong();
        int endIdx = (int) endLogits.argMax(1).getLong();
        // Start and end are chosen independently, so the end may precede the start; swap them
        // rather than letting subList throw IllegalArgumentException (fromIndex > toIndex).
        if (startIdx > endIdx) {
            int tmp = startIdx;
            startIdx = endIdx;
            endIdx = tmp;
        }
        return tokenizer.buildSentence(tokens.subList(startIdx, endIdx + 1));
    }

    /**
     * Creates a builder to build a {@code MxBertQATranslator}.
     *
     * <p>{@code seqLength} must be set via {@link Builder#setSeqLength} before building.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder to build a {@code MxBertQATranslator}.
     *
     * <p>The {@code seqLength} argument defaults to 384 when it is not present.
     *
     * @param arguments the models' arguments
     * @return a new builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configure(arguments);
        builder.setSeqLength(ArgumentsUtil.intValue(arguments, "seqLength", 384));
        return builder;
    }

    /** The builder for Bert QA translator. */
    public static class Builder extends BaseBuilder<Builder> {

        private int seqLength;

        /**
         * Sets the max length of the sequence to do the padding.
         *
         * @param seqLength the length of the sequence, must be positive
         * @return builder
         */
        public Builder setSeqLength(int seqLength) {
            this.seqLength = seqLength;
            return self();
        }

        /**
         * Returns the builder.
         *
         * @return the builder
         */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Builds the translator.
         *
         * @return the new translator
         * @throws IllegalArgumentException if {@code seqLength} is not positive
         */
        protected MxBertQATranslator build() {
            // Reject negative values too, matching the "> 0" contract in the message.
            if (seqLength <= 0) {
                throw new IllegalArgumentException("You must specify a seqLength with value > 0");
            }
            return new MxBertQATranslator(this);
        }
    }

    /** Gson target used to read the {@code idx_to_token} list from {@code vocab.json}. */
    private static final class VocabParser {

        @SerializedName("idx_to_token")
        List<String> idx2token;

        /**
         * Reads the token list from a vocabulary JSON file.
         *
         * @param url the location of the vocabulary JSON
         * @return the tokens in index order as stored in {@code idx_to_token}
         * @throws IllegalArgumentException if the URL cannot be read or the JSON has no {@code
         *     idx_to_token} list
         */
        public static List<String> parseToken(URL url) {
            try (InputStream is = new BufferedInputStream(url.openStream());
                    Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
                VocabParser parser = JsonUtils.GSON.fromJson(reader, VocabParser.class);
                // Gson returns null for an empty document, and a missing key leaves the field
                // null; fail here instead of with an obscure NullPointerException later.
                if (parser == null || parser.idx2token == null) {
                    throw new IllegalArgumentException(
                            "Missing idx_to_token in vocabulary: " + url);
                }
                return parser.idx2token;
            } catch (IOException e) {
                throw new IllegalArgumentException("Invalid url: " + url, e);
            }
        }
    }
}