Skip to content

Commit

Permalink
More polishing, javadoc
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Apr 3, 2019
1 parent 56cd14d commit 1d06db3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
Expand Up @@ -8,7 +8,15 @@
import java.util.Random;

/**
* Created by Alex on 03/04/2019.
* A standard/default {@link BertSequenceMasker}. Implements masking as per the BERT paper:
* <a href="https://arxiv.org/abs/1810.04805">https://arxiv.org/abs/1810.04805</a>
* That is, each token is chosen to be masked independently with some probability "maskProb".
* For tokens that are masked, 3 possibilities:<br>
* 1. They are replaced with the mask token (such as "[MASK]") in the input, with probability "maskTokenProb"<br>
* 2. They are replaced with a random word from the vocabulary, with probability "randomTokenProb"<br>
* 3. They are left unmodified with probability 1.0 - maskTokenProb - randomTokenProb<br>
*
* @author Alex Black
*/
public class BertMaskedLMMasker implements BertSequenceMasker {
public static final double DEFAULT_MASK_PROB = 0.15;
Expand All @@ -18,18 +26,28 @@ public class BertMaskedLMMasker implements BertSequenceMasker {
protected final Random r;
protected final double maskProb;
protected final double maskTokenProb;
protected final double randomWordProb;
protected final double randomTokenProb;

/**
* Create a BertMaskedLMMasker with all default probabilities
*/
public BertMaskedLMMasker(){
this(new Random(), DEFAULT_MASK_PROB, DEFAULT_MASK_TOKEN_PROB, DEFAULT_RANDOM_WORD_PROB);
}

public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomWordProb){
/**
* See: {@link BertMaskedLMMasker} for details.
* @param r Random number generator
* @param maskProb Probability of masking each token
* @param maskTokenProb Probability of replacing a selected token with the mask token
* @param randomTokenProb Probability of replacing a selected token with a random token
*/
public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomTokenProb){
Preconditions.checkArgument(maskProb > 0 && maskProb < 1, "Probability must be beteen 0 and 1, got %s", maskProb);
this.r = r;
this.maskProb = maskProb;
this.maskTokenProb = maskTokenProb;
this.randomWordProb = randomWordProb;
this.randomTokenProb = randomTokenProb;
}

@Override
Expand All @@ -42,7 +60,7 @@ public Pair<List<String>,boolean[]> maskSequence(List<String> input, String mask
double d = r.nextDouble();
if(d < maskTokenProb){
out.add(maskToken);
} else if(d < maskTokenProb + randomWordProb){
} else if(d < maskTokenProb + randomTokenProb){
//Randomly select a token...
String random = vocabWords.get(r.nextInt(vocabWords.size()));
out.add(random);
Expand Down
@@ -1,13 +1,25 @@
package org.deeplearning4j.iterator.bert;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;

import java.util.List;


/**
* Interface used to customize how masking should be performed with {@link org.deeplearning4j.iterator.BertIterator}
* when doing unsupervised training
*
* @author Alex Black
*/
public interface BertSequenceMasker {

/**
* Apply masking to a sequence of tokens, returning the (possibly modified) tokens together with a
* per-token flag indicating which positions were masked.
*
* @param input Input sequence of tokens
* @param maskToken Token to use for masking - usually something like "[MASK]"
* @param vocabWords Vocabulary, as a list. Implementations may use this, for example, to draw random
* replacement tokens
* @return Pair: The new input tokens (after masking out), along with a boolean[] for whether the token is
* masked or not (same length as number of tokens). boolean[i] is true if token i was masked.
*/
Pair<List<String>,boolean[]> maskSequence(List<String> input, String maskToken, List<String> vocabWords);

}

0 comments on commit 1d06db3

Please sign in to comment.