Commit dc6d529

BERT iterator polishing

AlexDBlack committed Apr 4, 2019
1 parent 4b552f3 commit dc6d529
Showing 2 changed files with 46 additions and 11 deletions.
BertIterator.java
@@ -45,31 +45,61 @@
* (b) Supervised - For sequence classification (i.e., 1 label per sequence, typically used for fine tuning)<br>
* The task can be specified using {@link Task}.
* <br>
- * Example for unsupervised training:<br>
+ * <b>Example for unsupervised training:</b><br>
* <pre>
* {@code
*
* BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab);
* BertIterator b = BertIterator.builder()
* .tokenizer(t)
* .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
* .minibatchSize(2)
* .sentenceProvider(<sentence provider here>)
* .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
* .vocabMap(t.getVocab())
* .task(BertIterator.Task.UNSUPERVISED)
* .masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5))
* .unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
* .maskToken("[MASK]")
* .build();
* }
* </pre>
* <br>
* <b>Example for supervised (sequence classification - one label per sequence) training:</b><br>
* <pre>
* {@code
* BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab);
* BertIterator b = BertIterator.builder()
* .tokenizer(t)
* .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
* .minibatchSize(2)
* .sentenceProvider(new TestSentenceProvider())
* .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
* .vocabMap(t.getVocab())
* .task(BertIterator.Task.SEQ_CLASSIFICATION)
* .build();
* }
* </pre>
*
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
* <br>
* <u><b>{@link LengthHandling} configuration:</b></u><br>
* Determines how to handle variable-length sequence situations.<br>
* <b>FIXED_LENGTH</b>: Always trim longer sequences to the specified length, and always pad shorter sequences to the specified length.<br>
- * <b>ANY_LENGTH</b>: Output length is determined by the length of the longest sequence in the minibatch<br>
+ * <b>ANY_LENGTH</b>: Output length is determined by the length of the longest sequence in the minibatch. Shorter sequences within the
+ * minibatch are zero padded and masked.<br>
* <b>CLIP_ONLY</b>: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in
- * a minibatch is shorter than the specified maximum, no padding will occur.<br>
+ * a minibatch is shorter than the specified maximum, no padding will occur. Sequences shorter than the maximum
+ * (within the current minibatch) will be zero padded and masked.<br>
*<br><br>
* <u><b>{@link FeatureArrays} configuration:</b></u><br>
- * Determines what arrays should be included.
- * <b>INDICES_MASK</b>: Indices array and mask array only, no segment ID array<br>
- * <b>INDICES_MASK</b>: Indices array, mask array and segment ID array (which is all 0s for single segment tasks)<br>
+ * Determines what arrays should be included.<br>
+ * <b>INDICES_MASK</b>: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).<br>
+ * <b>INDICES_MASK_SEGMENTID</b>: Indices array, mask array and segment ID array (which is all 0s for single segment tasks). Returns
+ * 2 feature arrays (indices, segment ID) and 1 feature mask array (plus labels).<br>
* <br>
* <u><b>{@link UnsupervisedLabelFormat} configuration:</b></u><br>
* Only relevant when the task is set to {@link Task#UNSUPERVISED}. Determines the format of the labels:<br>
- * <b>RANK2_IDX</b>: return int32 [minibatch, numTokens] array with entries being class numbers<br>
- * <b>RANK3_NCL</b>: return float32 [minibatch, numClasses, numTokens] array with 1-hot entries along dimension 1<br>
+ * <b>RANK2_IDX</b>: return int32 [minibatch, numTokens] array with entries being class numbers. Example use case: with sparse softmax loss functions.<br>
+ * <b>RANK3_NCL</b>: return float32 [minibatch, numClasses, numTokens] array with 1-hot entries along dimension 1. Example use case: RnnOutputLayer, RnnLossLayer<br>
* <b>RANK3_NLC</b>: return float32 [minibatch, numTokens, numClasses] array with 1-hot entries along dimension 2<br>
* <br>
*/
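
For illustration, a minimal sketch of consuming an iterator configured as in the unsupervised example above (it assumes the variable b from that example, i.e. FIXED_LENGTH of 16, minibatch size 2, FeatureArrays.INDICES_MASK and UnsupervisedLabelFormat.RANK2_IDX; the shape comments follow the descriptions in this Javadoc):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.api.MultiDataSet;

    // BertIterator is a MultiDataSetIterator; one call yields one minibatch
    MultiDataSet mds = b.next();

    INDArray tokenIndices = mds.getFeatures(0);          // int32, shape [2, 16]: token indices into the vocabulary
    INDArray featureMask  = mds.getFeaturesMaskArray(0); // shape [2, 16]: 1 = real token, 0 = padding
    INDArray labels       = mds.getLabels(0);            // RANK2_IDX: int32, shape [2, 16], entries are class numbers
    INDArray labelMask    = mds.getLabelsMaskArray(0);   // 1 only at positions selected for masked-LM prediction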
@@ -423,7 +453,8 @@ public Builder featureArrays(FeatureArrays featureArrays){
}

/**
- * Provide the vocabulary as a map. Keys are
+ * Provide the vocabulary as a map. Keys are the words in the vocabulary, and values are the indices of those
+ * words. Indices should be in the range 0 to vocabMap.size()-1, inclusive.<br>
* If using {@link BertWordPieceTokenizerFactory},
* this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()}
*/
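
To make the documented contract concrete, a small hand-built vocabulary that satisfies it (token strings are illustrative only; in practice the map would come from BertWordPieceTokenizerFactory#getVocab() as noted above):

    import java.util.LinkedHashMap;
    import java.util.Map;

    // Values must cover the range 0 to vocabMap.size()-1 inclusive
    Map<String, Integer> vocab = new LinkedHashMap<>();
    vocab.put("[PAD]", 0);
    vocab.put("[UNK]", 1);
    vocab.put("[MASK]", 2);
    vocab.put("hello", 3);
    vocab.put("world", 4);

    // Then passed to the builder in place of t.getVocab():
    // BertIterator.builder()...vocabMap(vocab)...build();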
BertMaskedLMMasker.java
@@ -44,6 +44,10 @@ public BertMaskedLMMasker(){
*/
public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomTokenProb){
Preconditions.checkArgument(maskProb > 0 && maskProb < 1, "Probability must be between 0 and 1, got %s", maskProb);
+ Preconditions.checkState(maskTokenProb >= 0 && maskTokenProb <= 1.0, "Mask token probability must be between 0 and 1, got %s", maskTokenProb);
+ Preconditions.checkState(randomTokenProb >= 0 && randomTokenProb <= 1.0, "Random token probability must be between 0 and 1, got %s", randomTokenProb);
+ Preconditions.checkState(maskTokenProb + randomTokenProb <= 1.0, "Sum of maskTokenProb (%s) and randomTokenProb (%s) must be <= 1.0, got sum is %s",
+         maskTokenProb, randomTokenProb, (maskTokenProb + randomTokenProb));
this.r = r;
this.maskProb = maskProb;
this.maskTokenProb = maskTokenProb;
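
A minimal sketch of a construction that passes the checks added above (the values are illustrative; the 15%/80%/10% split mirrors the masking scheme described in the original BERT paper):

    import java.util.Random;

    // maskProb must be strictly between 0 and 1; maskTokenProb and
    // randomTokenProb must each be in [0, 1] and sum to at most 1.0.
    // Selected tokens not covered by either probability are left unchanged.
    BertMaskedLMMasker masker = new BertMaskedLMMasker(new Random(12345),
            0.15,   // select 15% of tokens for prediction
            0.8,    // of the selected tokens: 80% replaced with the mask token
            0.1);   // 10% replaced with a random token; the remaining 10% left as-is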
