run_mlm.py: CUDA error: device-side assert triggered, THCTensorIndex #10832

Closed
matteomedioli opened this issue Mar 21, 2021 · 2 comments

matteomedioli commented Mar 21, 2021

Environment info

  • transformers version: 4.4.2
  • Platform: Linux
  • Python version: Python 3.4.9
  • PyTorch version (GPU?): 1.6.0+cu101
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes
  • GPU details: 4 GPUs V100 16GB

Information

I am using BERT and RoBERTa. I'm trying to train them from scratch on the Wikipedia dataset, using your run_mlm example and your wikipedia dataset (20200501.en).
Before switching to the distributed setup, training got stuck on the first optimization step. Without the distributed setup I either got stuck on the first optimization steps or received the reported error; with the distributed setup I always receive the reported error.

The problem arises when using:

  • the official example scripts: run_mlm.py

The tasks I am working on are:

  • an official GLUE/SQuAD task: MLM training from scratch for BERT and RoBERTa
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

export CUDA_LAUNCH_BLOCKING=1
export TOKENIZERS_PARALLELISM=true
export OMP_NUM_THREADS=32

source /data/medioli/env/bin/activate

python3 -m torch.distributed.launch \
--nproc_per_node 4 run_mlm.py \
--dataset_name wikipedia \
--tokenizer_name roberta-base \
--model_type roberta \
--dataset_config_name 20200501.en \
--do_train \
--do_eval \
--learning_rate 1e-5 \
--num_train_epochs 5 \
--save_steps 5000 \
--output_dir /data/medioli/models/mlm/wikipedia_roberta_5ep_1e5_lbl \
--line_by_line \
--use_fast_tokenizer \
--logging_dir /data/medioli/models/mlm/wikipedia_roberta_5ep_1e5_lbl/runs \
--cache_dir /data/medioli/datasets/wikipedia/ \
--overwrite_output_dir

Errors and Output

Many errors like this:

/pytorch/aten/src/THC/THCTensorIndex.cu:272: indexSelectLargeIndex: block: [372,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

Then:

Traceback (most recent call last):
  File "/data/medioli/transformers/examples/language-modeling/run_mlm.py", line 491, in <module>
    main()
  File "/data/medioli/transformers/examples/language-modeling/run_mlm.py", line 457, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/data/medioli/env/lib/python3.6/site-packages/transformers/trainer.py", line 1053, in train
    tr_loss += self.training_step(model, inputs)
  File "/data/medioli/env/lib/python3.6/site-packages/transformers/trainer.py", line 1443, in training_step
    loss = self.compute_loss(model, inputs)
  File "/data/medioli/env/lib/python3.6/site-packages/transformers/trainer.py", line 1475, in compute_loss
    outputs = model(**inputs)
  File "/data/medioli/env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/medioli/env/lib64/python3.6/site-packages/torch/nn/parallel/distributed.py", line 511, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/data/medioli/env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/medioli/env/lib/python3.6/site-packages/transformers/models/roberta/modeling_roberta.py", line 1057, in forward
    return_dict=return_dict,
  File "/data/medioli/env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/medioli/env/lib/python3.6/site-packages/transformers/models/roberta/modeling_roberta.py", line 810, in forward
    past_key_values_length=past_key_values_length,
  File "/data/medioli/env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/medioli/env/lib/python3.6/site-packages/transformers/models/roberta/modeling_roberta.py", line 123, in forward
    embeddings += position_embeddings
RuntimeError: CUDA error: device-side assert triggered
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: device-side assert triggered
Exception raised from create_event_internal at /pytorch/c10/cuda/CUDACachingAllocator.cpp:687 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7fa4517ed1e2 in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libc10.so)
frame #1: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0xad2 (0x7fa451a3bf92 in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::TensorImpl::release_resources() + 0x4d (0x7fa4517db9cd in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libc10.so)
frame #3: std::vector<c10d::Reducer::Bucket, std::allocator<c10d::Reducer::Bucket> >::~vector() + 0x25a (0x7fa427f8489a in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #4: c10d::Reducer::~Reducer() + 0x28a (0x7fa427f79b1a in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #5: std::_Sp_counted_ptr<c10d::Reducer*, (__gnu_cxx::_Lock_policy)2>::_M_dispose() + 0x12 (0x7fa427f593c2 in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #6: std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() + 0x46 (0x7fa4277577a6 in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #7: <unknown function> + 0xa6b08b (0x7fa427f5a08b in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x273c00 (0x7fa427762c00 in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x274e4e (0x7fa427763e4e in /data/medioli/env/lib64/python3.6/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #22: main + 0x16e (0x400a3e in /data/medioli/env/bin/python3)
frame #23: __libc_start_main + 0xf5 (0x7fa48f4903d5 in /lib64/libc.so.6)
frame #24: /data/medioli/env/bin/python3() [0x400b02]

Related discussion on the PyTorch forums: https://discuss.pytorch.org/t/solved-assertion-srcindex-srcselectdimsize-failed-on-gpu-for-torch-cat/1804/22

Who can help

Models: @LysandreJik

sgugger (Collaborator) commented Mar 22, 2021

Hi there! The problem is a bit complex and linked to the way RoBERTa is implemented in Transformers, with a small hack: the pretrained roberta-base checkpoint has 512 + 2 = 514 position embeddings, not 512, because the position indices are offset by the padding index. When you run your command, the model is randomly initialized with 512 position embeddings (the default in the config), but you still use it with the roberta-base tokenizer, which produces sequences whose position indices go up to 513 and therefore require 514 position embeddings. The resulting out-of-range index is what throws the "device-side assert triggered" error.
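
You can see the mismatch directly; this is just a quick check (assuming transformers is installed and the roberta-base config can be downloaded), not part of the fix:

from transformers import RobertaConfig

# Default RobertaConfig, which is what a from-scratch run with only --model_type roberta ends up using
print(RobertaConfig().max_position_embeddings)                                # 512
# Config matching the pretrained roberta-base checkpoint and its tokenizer
print(RobertaConfig.from_pretrained("roberta-base").max_position_embeddings)  # 514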

To fix this, you need to either use another tokenizer, or prepare your randomly initialized model like this:

from transformers import RobertaForMaskedLM, RobertaConfig

model = RobertaForMaskedLM(RobertaConfig(max_position_embeddings=514))
model.save_pretrained("model_dir")

Then pass model_dir to --model_name_or_path when launching your script; see the sketch below.
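
For instance, a trimmed version of the command above, with --model_type roberta replaced by the saved directory (the training hyperparameters from the original command can be added back unchanged):

python3 -m torch.distributed.launch \
--nproc_per_node 4 run_mlm.py \
--model_name_or_path model_dir \
--tokenizer_name roberta-base \
--dataset_name wikipedia \
--dataset_config_name 20200501.en \
--line_by_line \
--do_train \
--do_eval \
--output_dir /data/medioli/models/mlm/wikipedia_roberta_5ep_1e5_lbl \
--overwrite_output_dir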

You can also tweak the script directly and set max_position_embeddings=514 at the line where the config is created for from-scratch training.
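
For example, a minimal sketch of that tweak, assuming your checkout of run_mlm.py creates the fresh config via CONFIG_MAPPING when neither --config_name nor --model_name_or_path is given:

# In run_mlm.py, in the branch that instantiates a new config from scratch
# (the exact surrounding code depends on your version of the example)
config = CONFIG_MAPPING[model_args.model_type]()
# 512 usable positions + 2 for RoBERTa's padding-index offset
config.max_position_embeddings = 514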

@matteomedioli (Author)

Thank you! Now it works! :)
