diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertMaskedLMMasker.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertMaskedLMMasker.java index 35e6639bfe08..87f91a2904a7 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertMaskedLMMasker.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertMaskedLMMasker.java @@ -8,7 +8,15 @@ import java.util.Random; /** - * Created by Alex on 03/04/2019. + * A standard/default {@link BertSequenceMasker}. Implements masking as per the BERT paper: + * https://arxiv.org/abs/1810.04805 + * That is, each token is chosen to be masked independently with some probability "maskProb". + * For tokens that are masked, 3 possibilities:
+ * 1. They are replaced with the mask token (such as "[MASK]") in the input, with probability "maskTokenProb"
+ * 2. They are replaced with a random word from the vocabulary, with probability "randomTokenProb"
+ * 3. They are are left unmodified with probability 1.0 - maskTokenProb - randomTokenProb
+ * + * @author Alex Black */ public class BertMaskedLMMasker implements BertSequenceMasker { public static final double DEFAULT_MASK_PROB = 0.15; @@ -18,18 +26,28 @@ public class BertMaskedLMMasker implements BertSequenceMasker { protected final Random r; protected final double maskProb; protected final double maskTokenProb; - protected final double randomWordProb; + protected final double randomTokenProb; + /** + * Create a BertMaskedLMMasker with all default probabilities + */ public BertMaskedLMMasker(){ this(new Random(), DEFAULT_MASK_PROB, DEFAULT_MASK_TOKEN_PROB, DEFAULT_RANDOM_WORD_PROB); } - public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomWordProb){ + /** + * See: {@link BertMaskedLMMasker} for details. + * @param r Random number generator + * @param maskProb Probability of masking each token + * @param maskTokenProb Probability of replacing a selected token with the mask token + * @param randomTokenProb Probability of replacing a selected token with a random token + */ + public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomTokenProb){ Preconditions.checkArgument(maskProb > 0 && maskProb < 1, "Probability must be beteen 0 and 1, got %s", maskProb); this.r = r; this.maskProb = maskProb; this.maskTokenProb = maskTokenProb; - this.randomWordProb = randomWordProb; + this.randomTokenProb = randomTokenProb; } @Override @@ -42,7 +60,7 @@ public Pair,boolean[]> maskSequence(List input, String mask double d = r.nextDouble(); if(d < maskTokenProb){ out.add(maskToken); - } else if(d < maskTokenProb + randomWordProb){ + } else if(d < maskTokenProb + randomTokenProb){ //Randomly select a token... String random = vocabWords.get(r.nextInt(vocabWords.size())); out.add(random); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertSequenceMasker.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertSequenceMasker.java index e4b9e81e51ee..81b006d1341d 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertSequenceMasker.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/bert/BertSequenceMasker.java @@ -1,13 +1,25 @@ package org.deeplearning4j.iterator.bert; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import java.util.List; - +/** + * Interface used to customize how masking should be performed with {@link org.deeplearning4j.iterator.BertIterator} + * when doing unsupervised training + * + * @author Alex Black + */ public interface BertSequenceMasker { + /** + * + * @param input Input sequence of tokens + * @param maskToken Token to use for masking - usually something like "[MASK]" + * @param vocabWords Vocabulary, as a list + * @return Pair: The new input tokens (after masking out), along with a boolean[] for whether the token is + * masked or not (same length as number of tokens). boolean[i] is true if token i was masked. + */ Pair,boolean[]> maskSequence(List input, String maskToken, List vocabWords); }