
Commit

topk
zheyuye committed Jul 1, 2020
1 parent 34ee884 commit a5a6475
Showing 2 changed files with 18 additions and 8 deletions.
21 changes: 14 additions & 7 deletions scripts/pretraining/pretraining_utils.py
@@ -494,12 +494,15 @@ def dynamic_masking(self, F, input_ids, valid_lengths):
num_masked_position = F.np.maximum(
1, F.np.minimum(N, round(valid_lengths * self._mask_prob)))

# The categorical distribution takes normalized probabilities as input
# softmax is used here instead of log_softmax
# Get the masking probability of each position
sample_probs = F.npx.softmax(
self._proposal_distribution * valid_candidates, axis=-1) # (B, L)
masked_positions = F.npx.random.categorical(
sample_probs, shape=N, dtype=np.int32)
sample_probs = F.npx.stop_gradient(sample_probs)
gumbels = F.np.random.gumbel(F.np.zeros_like(sample_probs))
# Following the official repo's approach to avoid duplicate positions
# by using top-k sampling, see https://github.com/google-research/electra/issues/41
masked_positions = F.npx.topk(
sample_probs + gumbels, k=N, axis=-1, ret_typ='indices', dtype=np.int32)

masked_weights = F.npx.sequence_mask(
F.np.ones_like(masked_positions),
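
The change above swaps independent categorical draws for a top-k over Gumbel-perturbed scores. In its canonical form, the trick adds i.i.d. Gumbel(0, 1) noise to the (log-)probabilities and takes the top-k indices, which yields k distinct positions sampled without replacement, so no duplicates can occur. A minimal NumPy sketch of that trick; the names, shapes, and use of log-probabilities are illustrative, not the exact GluonNLP tensors:

import numpy as np

def gumbel_top_k(logits, k, rng):
    """Draw k distinct indices per row, proportional to softmax(logits)."""
    gumbels = rng.gumbel(size=logits.shape)                    # i.i.d. Gumbel(0, 1) noise
    return np.argsort(-(logits + gumbels), axis=-1)[..., :k]   # top-k == sampling without replacement

rng = np.random.default_rng(0)
logits = np.log([[0.1, 0.2, 0.3, 0.4],
                 [0.7, 0.1, 0.1, 0.1]])
print(gumbel_top_k(logits, k=2, rng=rng))                      # two distinct positions per row
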
@@ -513,14 +516,18 @@ def dynamic_masking(self, F, input_ids, valid_lengths):
unmasked_tokens = select_vectors_by_position(
F, input_ids, masked_positions) * masked_weights
masked_weights = masked_weights.astype(np.float32)

replaced_positions = (
F.np.random.uniform(
F.np.zeros_like(masked_positions),
F.np.ones_like(masked_positions)) > self._mask_prob) * masked_positions
# deal with multiple zeros
filled = F.np.where(replaced_positions, self.vocab.mask_id, masked_positions)
# dealing with multiple zero values in replaced_positions, which would cause [CLS] to be replaced
filled = F.np.where(replaced_positions, self.vocab.mask_id, self.vocab.cls_id).astype(np.int32)
# Masking token by replacing with [MASK]
masked_input_ids = updated_vectors_by_position(F, input_ids, filled, replaced_positions)

# Note: masked_positions is likely to contain multiple zero values if the number of masked
# positions has not reached the maximum. However, this case rarely occurs since valid_length
# is almost always equal to max_seq_length
masked_input = self.MaskedInput(input_ids=masked_input_ids,
masks=length_masks,
unmasked_tokens=unmasked_tokens,
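
A plain-NumPy toy of the [CLS] fix above: every zero entry in replaced_positions scatters its filler value onto index 0, so filling the unused slots with cls_id instead of an arbitrary value makes the collision on index 0 a no-op. The scatter helper and token ids below are hypothetical, not the GluonNLP updated_vectors_by_position API:

import numpy as np

CLS_ID, MASK_ID = 101, 103                                   # hypothetical vocab ids
input_ids = np.array([CLS_ID, 2054, 2003, 1996, 3007, 102])  # toy sequence

def scatter_update(seq, values, positions):
    out = seq.copy()
    out[positions] = values          # duplicate indices collapse to a single write
    return out

replaced_positions = np.array([2, 4, 0, 0])  # two real positions, two unused zeros
# Naive filler: the unused zeros write [MASK] onto index 0, clobbering [CLS]
old = scatter_update(input_ids, np.full(4, MASK_ID), replaced_positions)
# Filler from the fix: unused slots write cls_id, so index 0 keeps [CLS]
new = scatter_update(input_ids, np.where(replaced_positions > 0, MASK_ID, CLS_ID), replaced_positions)
print(old)   # [103 2054 103 1996 103 102] -> [CLS] overwritten by [MASK]
print(new)   # [101 2054 103 1996 103 102] -> [CLS] preserved, positions 2 and 4 masked
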
5 changes: 4 additions & 1 deletion src/gluonnlp/models/electra.py
@@ -814,6 +814,8 @@ def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, log
Returns
-------
corrupted_tokens
Shape (batch_size, )
fake_data
Shape (batch_size, seq_length)
labels
@@ -839,7 +841,8 @@ def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, log
inputs, corrupted_tokens, masked_positions)
updates_mask = add_vectors_by_position(F, F.np.zeros_like(inputs),
F.np.ones_like(masked_positions), masked_positions)
# Dealing with duplicate positions
# Dealing with multiple zeros in masked_positions, which
# otherwise accumulate to a value greater than one at the first index [CLS]
updates_mask = F.np.minimum(updates_mask, 1)
labels = updates_mask * F.np.not_equal(fake_data, original_data)
return corrupted_tokens, fake_data, labels
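
A plain-NumPy toy of the clip added here, with np.add.at standing in for add_vectors_by_position and illustrative values: scatter-adding a 1 at every entry of masked_positions makes the duplicate zeros pile up on index 0 ([CLS]), so the mask can exceed 1 there, and taking the minimum with 1 restores a binary mask.

import numpy as np

masked_positions = np.array([2, 4, 0, 0])      # two padded "zero" slots
updates_mask = np.zeros(6, dtype=np.int64)
np.add.at(updates_mask, masked_positions, 1)   # scatter-add a 1 per position
print(updates_mask)                  # [2 0 1 0 1 0] -> index 0 counted twice
print(np.minimum(updates_mask, 1))   # [1 0 1 0 1 0] -> binary again
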
