
Flax Masked Language Modeling training example #8728

Merged: 36 commits merged into master from flax-lm-example on Dec 9, 2020

Conversation
Conversation

mfuntowicz (Member) commented Nov 23, 2020

Include a training example running with the Flax/JAX framework. (cc @avital @marcvanzee)

TODOs:

  • Make the collator work with NumPy/JAX arrays
  • Make sure the training actually works at a larger scale
  • Make it possible to train from scratch
  • Support TPU (bfloat16)
  • Support GPU AMP (float16)
  • Improve overall UX

# Model forward
# TODO: Remove this PyTorch -> NumPy conversion by making the collator return NumPy/JAX arrays directly
model_inputs = {var_name: tensor.numpy() for var_name, tensor in model_inputs.items()}
loss, optimizer = training_step(optimizer, model_inputs)
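As a sketch of what removing this workaround could look like: a thin wrapper that calls the existing PyTorch data collator and converts its output to NumPy. The name numpy_collate is hypothetical and not part of the PR; it assumes the collator returns a dict of torch tensors.

def numpy_collate(collator, samples):
    # Hypothetical helper (not in this PR): run the existing PyTorch data
    # collator, then convert every torch.Tensor in the batch to a NumPy array
    # so the result can be fed straight to the Flax/JAX training_step.
    batch = collator(samples)
    return {name: tensor.numpy() for name, tensor in batch.items()}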
Member

This is really cool!

Can you also easily have a learning rate schedule?

Member Author

Yep, sure. I didn't put it in at first, focusing on making things clear and almost a no-brainer 😄.

Will look at it soon 👍

Member Author

Now included 👍
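
For illustration, a warmup-then-linear-decay schedule in JAX could look roughly like the sketch below; the function name and exact shape are assumptions, not necessarily what the PR adds.

import jax.numpy as jnp

def linear_warmup_then_decay(step, base_lr, warmup_steps, total_steps):
    # Illustrative schedule only: ramp linearly up to base_lr over
    # warmup_steps, then decay linearly back to zero at total_steps.
    warmup = base_lr * jnp.minimum(1.0, step / warmup_steps)
    decay = base_lr * jnp.maximum(0.0, (total_steps - step) / (total_steps - warmup_steps))
    return jnp.where(step < warmup_steps, warmup, decay)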

for epoch in track(range(int(training_args.num_train_epochs)), description="Training..."):
    samples_idx = np.random.choice(len(tokenized_datasets["train"]), (training_args.train_batch_size,))
    samples = [tokenized_datasets["train"][idx.item()] for idx in samples_idx]
    model_inputs = data_collator(samples)
Member

What's your first impression of having a Flax Trainer with a similar API to the PyTorch Trainer at some point, @sgugger?

Collaborator

Doesn't look like it's going to be too hard to build.
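
As a rough idea of the shape such a class could take (names and signature below are purely illustrative, not an actual transformers API):

class FlaxTrainer:
    # Illustrative skeleton only: a Trainer-like wrapper that would own the
    # sampling / collating / optimizer loop the script currently writes by hand.
    def __init__(self, model, args, train_dataset, data_collator):
        self.model = model
        self.args = args
        self.train_dataset = train_dataset
        self.data_collator = data_collator

    def train(self):
        raise NotImplementedError("the training loop from the script would go here")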

@sgugger (Collaborator) left a comment

Argh, you actually took the one example that is a bit flaky (we merged DataCollatorForWholeWordMasking a bit too fast, and the data preprocessing part of that script needs to be completely rewritten since it only works for BERT for now).
Could you do the same with the run_mlm script instead? That one won't change :-)


Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
Contributor

Suggested change
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
# You can also adapt this script to your own masked language modeling task. Pointers for this are left as comments.

# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
Contributor

I find it quite useful to have a comment at the top of a binary with an example command line that lets users run the code directly. What do you think?
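
For example, such a header comment could look roughly like this (the flag values below are illustrative, not the script's exact interface):

# Example invocation (illustrative values only):
#
#   python run_mlm_flax.py \
#       --model_name_or_path bert-base-cased \
#       --dataset_name wikitext \
#       --dataset_config_name wikitext-2-raw-v1 \
#       --output_dir /tmp/test-mlm-flax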

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
Contributor

It seems we will never fine-tune with this code, right? At least it looks like the model is always FlaxBertForMaskedLM, which has a pre-training objective.

Comment on lines +141 to +168
if self.train_file is not None:
    extension = self.train_file.split(".")[-1]
    assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
if self.validation_file is not None:
    extension = self.validation_file.split(".")[-1]
    assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
Contributor

Maybe create some inner function check_file_extension to avoid code duplication?
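
One possible shape for that helper, sketched from the quoted lines above (the name check_file_extension is the reviewer's suggestion; the exact signature is an assumption, and the calls are meant to live in the dataclass's __post_init__):

def check_file_extension(path, arg_name):
    # Validate that the file has one of the supported extensions, reusing the
    # same error message for both arguments.
    extension = path.split(".")[-1]
    assert extension in ["csv", "json", "txt"], f"`{arg_name}` should be a csv, a json or a txt file."

if self.train_file is not None:
    check_file_extension(self.train_file, "train_file")
if self.validation_file is not None:
    check_file_extension(self.validation_file, "validation_file")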

src/transformers/models/bert/modeling_flax_bert.py (outdated review thread, resolved)
# return -jnp.mean(jnp.sum(one_hot(labels, config.vocab_size) * logits, axis=-1), axis=-1)
#

def cross_entropy(logits, targets, label_smoothing=0.0):
Contributor

Maybe factor this function out of training_step to make it easier to read?
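
A factored-out version could look roughly like the sketch below, following the commented-out expression above. It assumes logits are log-probabilities over the vocabulary and targets are integer token ids; vocab_size is passed explicitly here for self-containedness, so this is illustrative rather than the PR's final code.

import jax.numpy as jnp
from jax.nn import one_hot

def cross_entropy(logits, targets, vocab_size, label_smoothing=0.0):
    # Optional label smoothing spreads a little probability mass over the
    # non-target classes, then we take the mean negative log-likelihood.
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / (vocab_size - 1)
    targets_one_hot = one_hot(targets, vocab_size)
    soft_targets = targets_one_hot * confidence + (1.0 - targets_one_hot) * low_confidence
    loss = -jnp.sum(soft_targets * logits, axis=-1)
    return jnp.mean(loss)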

examples/language-modeling/run_mlm_flax.py (outdated review thread, resolved)
mfuntowicz and others added 21 commits December 8, 2020 23:37
mfuntowicz marked this pull request as ready for review on December 8, 2020 22:50
mfuntowicz and others added 7 commits December 8, 2020 23:58
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
mfuntowicz merged commit 7562714 into master on Dec 9, 2020
mfuntowicz deleted the flax-lm-example branch on December 9, 2020 16:13