diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertMaskedLMMasker.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertMaskedLMMasker.java
index 35e6639bfe08..87f91a2904a7 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertMaskedLMMasker.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertMaskedLMMasker.java
@@ -8,7 +8,15 @@
import java.util.Random;
/**
- * Created by Alex on 03/04/2019.
+ * A standard/default {@link BertSequenceMasker}. Implements masking as per the BERT paper:
+ * https://arxiv.org/abs/1810.04805
+ * That is, each token is chosen to be masked independently with some probability "maskProb".
+ * For tokens that are masked, 3 possibilities:
+ * 1. They are replaced with the mask token (such as "[MASK]") in the input, with probability "maskTokenProb"
+ * 2. They are replaced with a random word from the vocabulary, with probability "randomTokenProb"
+ * 3. They are left unmodified with probability 1.0 - maskTokenProb - randomTokenProb
+ *
+ * @author Alex Black
*/
public class BertMaskedLMMasker implements BertSequenceMasker {
public static final double DEFAULT_MASK_PROB = 0.15;
@@ -18,18 +26,28 @@ public class BertMaskedLMMasker implements BertSequenceMasker {
protected final Random r;
protected final double maskProb;
protected final double maskTokenProb;
- protected final double randomWordProb;
+ protected final double randomTokenProb;
+ /**
+ * Create a BertMaskedLMMasker with all default probabilities
+ */
public BertMaskedLMMasker(){
this(new Random(), DEFAULT_MASK_PROB, DEFAULT_MASK_TOKEN_PROB, DEFAULT_RANDOM_WORD_PROB);
}
- public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomWordProb){
+ /**
+ * See: {@link BertMaskedLMMasker} for details.
+ * @param r Random number generator
+ * @param maskProb Probability of masking each token
+ * @param maskTokenProb Probability of replacing a selected token with the mask token
+ * @param randomTokenProb Probability of replacing a selected token with a random token
+ */
+ public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomTokenProb){
Preconditions.checkArgument(maskProb > 0 && maskProb < 1, "Probability must be beteen 0 and 1, got %s", maskProb);
this.r = r;
this.maskProb = maskProb;
this.maskTokenProb = maskTokenProb;
- this.randomWordProb = randomWordProb;
+ this.randomTokenProb = randomTokenProb;
}
@Override
@@ -42,7 +60,7 @@ public Pair,boolean[]> maskSequence(List input, String mask
double d = r.nextDouble();
if(d < maskTokenProb){
out.add(maskToken);
- } else if(d < maskTokenProb + randomWordProb){
+ } else if(d < maskTokenProb + randomTokenProb){
//Randomly select a token...
String random = vocabWords.get(r.nextInt(vocabWords.size()));
out.add(random);
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertSequenceMasker.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertSequenceMasker.java
index e4b9e81e51ee..81b006d1341d 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertSequenceMasker.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertSequenceMasker.java
@@ -1,13 +1,25 @@
package org.deeplearning4j.iterator.bert;
-import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.List;
-
+/**
+ * Interface used to customize how masking should be performed with {@link org.deeplearning4j.iterator.BertIterator}
+ * when doing unsupervised training
+ *
+ * @author Alex Black
+ */
public interface BertSequenceMasker {
+ /**
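+ * Mask the given input sequence of tokens, returning the masked tokens and a per-token "was masked" flag array.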
+ * @param input Input sequence of tokens
+ * @param maskToken Token to use for masking - usually something like "[MASK]"
+ * @param vocabWords Vocabulary, as a list
+ * @return Pair: The new input tokens (after masking out), along with a boolean[] for whether the token is
+ * masked or not (same length as number of tokens). boolean[i] is true if token i was masked.
+ */
Pair,boolean[]> maskSequence(List input, String maskToken, List vocabWords);
}