Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task #6644

Merged: 7 commits merged into huggingface:master on Aug 31, 2020

Conversation

@mojave-pku (Contributor)

Add DataCollatorForNextSentencePrediction and TextDatasetForNextSentencePrediction to support the MLM and next sentence prediction objectives together.
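
For readers skimming the thread, here is a rough usage sketch of how the two new classes might be wired together with BertForPreTraining and Trainer. It is not code from this PR: the file path, block size, output directory, and training arguments are placeholders, the constructor arguments are inferred from the diff quoted later in the review, and it assumes the collator's output keys line up with what BertForPreTraining expects.

```python
from transformers import (
    BertForPreTraining,
    BertTokenizer,
    DataCollatorForNextSentencePrediction,
    TextDatasetForNextSentencePrediction,
    Trainer,
    TrainingArguments,
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForPreTraining.from_pretrained("bert-base-uncased")

# Plain-text corpus, presumably in the original BERT pre-training format:
# one sentence per line, documents separated by blank lines. The path is a placeholder.
dataset = TextDatasetForNextSentencePrediction(
    tokenizer=tokenizer,
    file_path="corpus.txt",
    block_size=128,
)

# Values mirror the defaults visible in the review diff (nsp_probability=0.5, mlm_probability=0.15).
data_collator = DataCollatorForNextSentencePrediction(
    tokenizer=tokenizer,
    nsp_probability=0.5,
    mlm_probability=0.15,
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="nsp_mlm_out", per_device_train_batch_size=8),
    data_collator=data_collator,
    train_dataset=dataset,
)
trainer.train()
```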

codecov bot commented Aug 21, 2020

Codecov Report

Merging #6644 into master will decrease coverage by 0.27%.
The diff coverage is 12.40%.

@@            Coverage Diff             @@
##           master    #6644      +/-   ##
==========================================
- Coverage   79.64%   79.36%   -0.28%     
==========================================
  Files         157      156       -1     
  Lines       28564    28384     -180     
==========================================
- Hits        22750    22528     -222     
- Misses       5814     5856      +42     
Impacted Files Coverage Δ
src/transformers/__init__.py 99.28% <ø> (ø)
...rc/transformers/data/datasets/language_modeling.py 56.97% <10.81%> (-34.86%) ⬇️
src/transformers/data/data_collator.py 57.14% <12.12%> (-32.57%) ⬇️
src/transformers/data/datasets/__init__.py 100.00% <100.00%> (ø)
src/transformers/tokenization_marian.py 66.66% <0.00%> (-32.50%) ⬇️
src/transformers/tokenization_reformer.py 81.66% <0.00%> (-13.34%) ⬇️
src/transformers/tokenization_xlm_roberta.py 84.52% <0.00%> (-10.72%) ⬇️
src/transformers/benchmark/benchmark_tf.py 65.03% <0.00%> (-0.49%) ⬇️
src/transformers/training_args.py 91.26% <0.00%> (-0.41%) ⬇️
src/transformers/benchmark/benchmark.py 81.88% <0.00%> (-0.29%) ⬇️
... and 133 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 41aa2b4...ec89daf.

@choidongyeon (Contributor)

Hey so I have a PR out for the same task: #6376

I'm mostly just writing this comment so that I can keep track of what the reviewers have to say and what happens with the NSP task.

@sgugger (Collaborator) left a comment

This looks good to me, just one nit before we can merge.

nsp_probability: float = 0.5
mlm_probability: float = 0.15

def __call__(self, examples: List[List[List[int]]]) -> Dict[str, torch.Tensor]:
@sgugger (Collaborator)

Could we support dict as well since the nlp library will return that?
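
(For illustration only, not code from the PR: a dataset built with the nlp library typically yields one dict per example, so the collator receives inputs of roughly the following shape instead of bare lists of token ids. The token ids and document structure below are made up.)

```python
# Each example is one "document": a dict whose "input_ids" holds a list of tokenized sentences.
dict_examples = [
    {"input_ids": [[101, 7592, 2088, 102], [101, 2129, 2024, 2017, 102]]},
    {"input_ids": [[101, 1996, 4937, 2938, 102]]},
]

# The same documents in the collator's original List[List[List[int]]] form.
list_examples = [e["input_ids"] for e in dict_examples]
```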

@mojave-pku (Contributor, Author)

Hi @sgugger! I added dict input support (like DataCollatorForLanguageModeling) per your suggestion, but now there is a conflict in src/transformers/__init__.py. Should I resolve it, or leave it to you?


def __call__(self, examples: List[Union[List[List[int]], Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
    if isinstance(examples[0], (dict, BatchEncoding)):
        # dict / BatchEncoding inputs: keep only the token ids for each document
        examples = [e["input_ids"] for e in examples]
@sgugger (Collaborator)

You'd need to grab the token_type_ids and the labels too I think.

@mojave-pku (Contributor, Author) commented Aug 25, 2020

Sorry, I'm a little confused.
Are the labels you mentioned the NSP/MLM labels, or labels for a specific task?
None of the data collators in this file grab the token_type_ids or labels; they just take the examples out of the dict and do nothing else.
The segment_ids are generated in self.create_examples_from_document.
Thank you~
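
(To illustrate the distinction above, here is a rough sketch of the kind of features create_examples_from_document ends up producing for one sentence pair. The tokenizer helper methods are standard, but the surrounding code and sentences are illustrative, not the PR's actual implementation.)

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Two tokenized sentences (no special tokens yet) and whether B was sampled at random.
tokens_a = tokenizer.encode("The cat sat on the mat.", add_special_tokens=False)
tokens_b = tokenizer.encode("It fell asleep in the sun.", add_special_tokens=False)
is_random_next = False

# [CLS] A [SEP] B [SEP], the matching segment ids ("token_type_ids"), and the NSP label.
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
next_sentence_label = 1 if is_random_next else 0
```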

@sgugger (Collaborator)

Ah sorry, I was reading this wrong.

@mojave-pku (Contributor, Author)

aha, ok~ :-)

@sgugger (Collaborator) commented Aug 24, 2020

I can take care of the final merge once this is all good and @LysandreJik has approved; it's due to a new version of isort.

@LysandreJik (Member)

Could we add a test for this? I just merged master in to make sure it has the latest changes.

@mojave-pku closed this on Aug 28, 2020
@mojave-pku reopened this on Aug 28, 2020
@mojave-pku (Contributor, Author)

After @LysandreJik merged the master branch, many files needed to be reformatted.
To clearly show the code I modified, I did not include the changes that make style produced in other files in those commits, so check_code_quality will not pass.

@LysandreJik (Member) left a comment

Cool, thanks for adding the test!

@LysandreJik merged commit 2de7ee0 into huggingface:master on Aug 31, 2020
sgugger added a commit that referenced this pull request Aug 31, 2020
* Only access loss tensor every logging_steps

* tensor.item() was being called every step. This must not be done
for XLA:TPU tensors as it's terrible for performance causing TPU<>CPU
communication at each step. On RoBERTa MLM for example, it reduces step
time by 30%, should be larger for smaller step time models/tasks.
* Train batch size was not correct in case a user uses the
`per_gpu_train_batch_size` flag
* Avg reduce loss across eval shards

* Fix style (#6803)

* t5 model should make decoder_attention_mask (#6800)

* [s2s] Test hub configs in self-scheduled CI (#6809)

* [s2s] round runtime in run_eval (#6798)

* Pegasus finetune script: add --adafactor (#6811)

* [bart] rename self-attention -> attention (#6708)

* [tests] fix typos in inputs (#6818)

* Fixed open in colab link (#6825)

* Add model card for singbert lite. Update widget for singbert and singbert-large. (#6827)

* BR_BERTo model card (#6793)

* clearly indicate shuffle=False (#6312)

* Clarify shuffle

* clarify shuffle

Co-authored-by: Kevin Canwen Xu <canwenxu@126.com>

* [s2s README] Add more dataset download instructions (#6737)

* Style

* Patch logging issue

* Set default logging level to `WARNING` instead of `INFO`

* TF Flaubert w/ pre-norm (#6841)

* Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (#6644)

* add datacollator and dataset for next sentence prediction task

* bug fix (numbers of special tokens & truncate sequences)

* bug fix (+ dict inputs support for data collator)

* add padding for nsp data collator; renamed cached files to avoid conflict.

* add test for nsp data collator

* Style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>

* Fix in Adafactor docstrings (#6845)

* Fix resuming training for Windows (#6847)

* Only access loss tensor every logging_steps

* tensor.item() was being called every step. This must not be done
for XLA:TPU tensors as it's terrible for performance causing TPU<>CPU
communication at each step. On RoBERTa MLM for example, it reduces step
time by 30%, should be larger for smaller step time models/tasks.
* Train batch size was not correct in case a user uses the
`per_gpu_train_batch_size` flag
* Avg reduce loss across eval shards

* comments

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Thomas Ashish Cherian <6967017+PandaWhoCodes@users.noreply.github.com>
Co-authored-by: Zane Lim <zyuanlim@gmail.com>
Co-authored-by: Rodolfo De Nadai <rdenadai@gmail.com>
Co-authored-by: xujiaze13 <37360975+xujiaze13@users.noreply.github.com>
Co-authored-by: Kevin Canwen Xu <canwenxu@126.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Huang Lianzhe <hlz@pku.edu.cn>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
stas00 pushed a commit to stas00/transformers that referenced this pull request Sep 1, 2020
stas00 added a commit to stas00/transformers that referenced this pull request Sep 1, 2020
Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020