
Only access loss tensor every logging_steps #6802

Merged: 23 commits merged into huggingface:master on Aug 31, 2020

Conversation

jysohn23 (Collaborator)

  • tensor.item() was being called every step. This must not be done for XLA:TPU tensors: it forces TPU<>CPU communication at each step, which is terrible for performance. On RoBERTa MLM, for example, this change reduces step time by 30%, and the gain should be larger for models/tasks with shorter step times. A minimal sketch of the pattern follows this list.
  • The train batch size was not correct when a user passes the per_gpu_train_batch_size flag.
  • Average-reduce the loss across eval shards.
  • Log TPU debug metrics before the last epoch break.
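
For illustration, a minimal sketch of the pattern this change moves to (not the actual Trainer code; the train_loop function and its arguments are hypothetical). The running loss stays a device tensor, and .item() is called only when a log line is actually emitted, so XLA never has to synchronize with the CPU on ordinary steps.

import torch

def train_loop(model, optimizer, dataloader, logging_steps=100):
    """Hypothetical loop illustrating the logging_steps pattern."""
    device = next(model.parameters()).device
    tr_loss = torch.tensor(0.0, device=device)  # running loss stays on device

    for step, batch in enumerate(dataloader, start=1):
        loss = model(**batch).loss  # assumes a HF-style output with a .loss field
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # No .item() here: that would force a TPU<>CPU transfer on every step.
        tr_loss += loss.detach()

        if step % logging_steps == 0:
            # Materialize the scalar only when it is actually logged.
            print(f"step {step}: average loss {(tr_loss / step).item():.4f}")

    # On TPU, an eval loss can similarly be average-reduced across shards before
    # logging (sketch only, not the exact Trainer call):
    #   import torch_xla.core.xla_model as xm
    #   eval_loss = xm.mesh_reduce("eval_loss", eval_loss, lambda vals: sum(vals) / len(vals))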

@jysohn23 requested a review from @julien-c on August 28, 2020 at 20:35
@julien-c requested a review from @sgugger on August 28, 2020 at 20:39
- logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
+ logger.info(" Instantaneous batch size per device = %d", train_dataloader.batch_size)
@JetRunner (Contributor) commented on Aug 30, 2020

Why is this change necessary? The prompt is "Instantaneous batch size per device". Also, this change seems to create a logging inconsistency IMO.

jysohn23 (Collaborator, Author)

  1. If the user passes --per_gpu_train_batch_size (the legacy flag for per_device_train_batch_size), the logged value is simply incorrect.
  2. That batch size is per device.

Collaborator

No, the batch size is not per device. It is on TPUs, but not if a user has a DataParallel model.
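
To make the disagreement concrete, a small illustrative sketch (variable names are hypothetical, not the Trainer's actual attributes): with torch.nn.DataParallel a single DataLoader feeds all GPUs and each batch is split across them, so the loader's batch size is the per-device size times the GPU count; on TPU, where each core gets its own loader, the two numbers coincide.

# Illustrative arithmetic only; names are hypothetical.
per_device_train_batch_size = 8
n_gpu = 4  # e.g. a single-node DataParallel run

# With DataParallel one DataLoader serves all GPUs, so its batch size is the
# total that later gets split across devices:
dataloader_batch_size = per_device_train_batch_size * max(1, n_gpu)
print(dataloader_batch_size)  # 32

# Logging dataloader_batch_size as "per device" would over-report by a factor
# of n_gpu on multi-GPU; on TPU (one loader per core) it equals the per-device size.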

@LysandreJik (Member) left a comment

LGTM, thanks @jysohn23

@LysandreJik (Member)

The style issue will be solved with the merge, @sgugger.

LysandreJik and others added 4 commits August 31, 2020 09:37
* Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (huggingface#6644)

* add datacollator and dataset for next sentence prediction task

* bug fix (numbers of special tokens & truncate sequences)

* bug fix (+ dict inputs support for data collator)

* add padding for nsp data collator; renamed cached files to avoid conflict.

* add test for nsp data collator

* Style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
@sgugger (Collaborator) left a comment

Thanks for fixing this! I'd remove the first change in the logs though.


sgugger and others added 4 commits August 31, 2020 10:52
@jysohn23 (Collaborator, Author)

> Thanks for fixing this! I'd remove the first change in the logs though.

Thanks for the review! Done.

@julien-c added the "model card" label (Related to pretrained model cards) on Aug 31, 2020
codecov bot commented on Aug 31, 2020

Codecov Report

Merging #6802 into master will decrease coverage by 0.42%.
The diff coverage is 89.47%.


@@            Coverage Diff             @@
##           master    #6802      +/-   ##
==========================================
- Coverage   80.32%   79.89%   -0.43%     
==========================================
  Files         157      157              
  Lines       28589    28739     +150     
==========================================
- Hits        22963    22960       -3     
- Misses       5626     5779     +153     
Impacted Files Coverage Δ
src/transformers/__init__.py 99.28% <ø> (ø)
src/transformers/file_utils.py 82.66% <ø> (+0.25%) ⬆️
src/transformers/modeling_tf_flaubert.py 88.34% <ø> (+63.80%) ⬆️
src/transformers/optimization.py 82.28% <ø> (ø)
src/transformers/tokenization_t5.py 95.28% <ø> (-0.05%) ⬇️
src/transformers/trainer.py 53.23% <46.66%> (-0.43%) ⬇️
...rc/transformers/data/datasets/language_modeling.py 90.69% <89.18%> (-1.14%) ⬇️
src/transformers/data/data_collator.py 91.90% <94.59%> (+2.19%) ⬆️
src/transformers/configuration_pegasus.py 100.00% <100.00%> (ø)
src/transformers/data/datasets/__init__.py 100.00% <100.00%> (ø)
... and 25 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9336086...2b981cd.

@sgugger merged commit 02d09c8 into huggingface:master on Aug 31, 2020
stas00 added a commit to stas00/transformers that referenced this pull request Sep 1, 2020
* Only access loss tensor every logging_steps

* tensor.item() was being called every step. This must not be done
for XLA:TPU tensors: it forces TPU<>CPU communication at each step,
which is terrible for performance. On RoBERTa MLM, for example, it reduces
step time by 30%; the gain should be larger for models/tasks with shorter step times.
* Train batch size was not correct when a user uses the
`per_gpu_train_batch_size` flag
* Average-reduce loss across eval shards

* Fix style (huggingface#6803)

* t5 model should make decoder_attention_mask (huggingface#6800)

* [s2s] Test hub configs in self-scheduled CI (huggingface#6809)

* [s2s] round runtime in run_eval (huggingface#6798)

* Pegasus finetune script: add --adafactor (huggingface#6811)

* [bart] rename self-attention -> attention (huggingface#6708)

* [tests] fix typos in inputs (huggingface#6818)

* Fixed open in colab link (huggingface#6825)

* Add model card for singbert lite. Update widget for singbert and singbert-large. (huggingface#6827)

* BR_BERTo model card (huggingface#6793)

* clearly indicate shuffle=False (huggingface#6312)

* Clarify shuffle

* clarify shuffle

Co-authored-by: Kevin Canwen Xu <canwenxu@126.com>

* [s2s README] Add more dataset download instructions (huggingface#6737)

* Style

* Patch logging issue

* Set default logging level to `WARNING` instead of `INFO`

* TF Flaubert w/ pre-norm (huggingface#6841)

* Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (huggingface#6644)

* add datacollator and dataset for next sentence prediction task

* bug fix (numbers of special tokens & truncate sequences)

* bug fix (+ dict inputs support for data collator)

* add padding for nsp data collator; renamed cached files to avoid conflict.

* add test for nsp data collator

* Style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>

* Fix in Adafactor docstrings (huggingface#6845)

* Fix resuming training for Windows (huggingface#6847)

* Only access loss tensor every logging_steps

* tensor.item() was being called every step. This must not be done
for XLA:TPU tensors: it forces TPU<>CPU communication at each step,
which is terrible for performance. On RoBERTa MLM, for example, it reduces
step time by 30%; the gain should be larger for models/tasks with shorter step times.
* Train batch size was not correct when a user uses the
`per_gpu_train_batch_size` flag
* Average-reduce loss across eval shards

* comments

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Thomas Ashish Cherian <6967017+PandaWhoCodes@users.noreply.github.com>
Co-authored-by: Zane Lim <zyuanlim@gmail.com>
Co-authored-by: Rodolfo De Nadai <rdenadai@gmail.com>
Co-authored-by: xujiaze13 <37360975+xujiaze13@users.noreply.github.com>
Co-authored-by: Kevin Canwen Xu <canwenxu@126.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Huang Lianzhe <hlz@pku.edu.cn>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020