Deal with the duplicated positions in generator #41

Closed

zheyuye opened this issue Apr 14, 2020 · 2 comments

zheyuye commented Apr 14, 2020

Here, the corrupted tokens are produced by the generator as fake data. I can understand why we should deal with duplicated positions and only apply the replacement once. However, I am confused by the implementation below, which takes the average of the corrupted token ids at duplicated positions. What is the intuition behind it?

# Handle duplicated positions: divide the scattered updates by the
# per-position count so each position keeps its original value.
if sequence.dtype == tf.float32:
  updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
  updates /= tf.maximum(1.0, updates_mask_3d)

clarkkev (Collaborator) commented Apr 14, 2020

scatter_nd sums up values with the same index, but we want to just pick a single value per index. If the values for an index are the same (e.g., we are just scattering a mask tensor of 1s), then dividing the summed values by the number of occurrences at that index fixes the issue. That's what is implemented in the code you linked to.
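To make that concrete, here is a minimal sketch (mine, not from the thread) of how tf.scatter_nd sums contributions at repeated indices, and how dividing by the per-index count recovers the intended value when the duplicated values are equal:

import tensorflow as tf

positions = tf.constant([[1], [3], [3]])    # position 3 appears twice
updates = tf.constant([5.0, 7.0, 7.0])      # duplicates carry the same value
counts = tf.scatter_nd(positions, tf.ones_like(updates), shape=[6])
summed = tf.scatter_nd(positions, updates, shape=[6])  # position 3 becomes 14
recovered = summed / tf.maximum(1.0, counts)           # back to [0, 5, 0, 7, 0, 0]
print(recovered.numpy())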

However, this does NOT fully fix having duplicate mask positions. It does stop there from being errors due to overflowing above the vocab size, but if (1) the same position is sampled twice and (2) the generator samples different tokens for the position then the replaced token will be a "random" token obtained by averaging the two sampled token ids. I didn't bother fixing this issue when developing ELECTRA because this occurs for a pretty small fraction of masked positions. But there actually is an easy fix: replace the sampling step (masked_lm_positions = tf.random.categorical... in pretrain_helpers.mask) with

# Gumbel-top-k: add Gumbel noise to the logits and take the top-N indices,
# which samples N distinct positions (no replacement).
masked_lm_positions = tf.math.top_k(
    sample_logits + sample_gumbel(
        modeling.get_shape_list(sample_logits)), N)[1]

This will do sampling without replacement when picking mask positions.
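For illustration only, here is a self-contained sketch (mine, not the ELECTRA code) of the Gumbel-top-k trick this fix relies on; sample_gumbel below is a local stand-in for the helper referenced above, and the logits are made up:

import tensorflow as tf

def sample_gumbel(shape, eps=1e-9):
  # Stand-in Gumbel-noise helper: -log(-log(U)) for U ~ Uniform(0, 1).
  u = tf.random.uniform(shape, minval=0.0, maxval=1.0)
  return -tf.math.log(-tf.math.log(u + eps) + eps)

sample_logits = tf.math.log([[0.1, 0.2, 0.3, 0.4],
                             [0.4, 0.3, 0.2, 0.1]])  # [batch, seq_len]
N = 2  # number of positions to mask per example
masked_lm_positions = tf.math.top_k(
    sample_logits + sample_gumbel(tf.shape(sample_logits)), N)[1]
print(masked_lm_positions.numpy())  # N distinct position indices per row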

zheyuye (Author) commented Apr 15, 2020

That is clear. Thank you for the explanation.
