awd-lstm pretrained=False crashes in distributed setup #2148

Closed
stas00 opened this issue Jun 6, 2019 · 3 comments

stas00 commented Jun 6, 2019

Let's start with the workaround that fixes the issue:

--- a/fastai/distributed.py
+++ b/fastai/distributed.py
@@ -29,7 +29,7 @@ class DistributedTrainer(LearnerCallback):
         return old_dl,new_dl,sampler

     def on_train_begin(self, **kwargs):
-        self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id)
+        self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id, find_unused_parameters=True)
         shuffle = self.data.train_dl.init_kwargs['shuffle'] if hasattr(self.data.train_dl, 'init_kwargs') else True
         self.old_train_dl,self.data.train_dl,self.train_sampler = self._change_dl(self.data.train_dl, shuffle)
         if hasattr(self.data, 'valid_dl') and self.data.valid_dl is not None:

I added find_unused_parameters=True to DistributedDataParallel, but I have no idea what it does or whether it camouflages some other problem.

So language_model_learner(data_lm, AWD_LSTM, pretrained=False) runs fine in single-GPU mode, but crashes in distributed mode with:

Traceback (most recent call last):                                                                                                                 
  File "./mimic_lm_distr.py", line 69, in <module>
    seed:       Param("Random seed", int)=42,
  File "/mnt/nvme1/fast.ai-1/br/fastai/master/fastai/script.py", line 40, in call_parse
    func(**args.__dict__)
  File "./mimic_lm_distr.py", line 105, in main
    learn.fit_one_cycle(10, slice(1e-2), moms=moms)
  File "/mnt/nvme1/fast.ai-1/br/fastai/master/fastai/train.py", line 22, in fit_one_cycle
    learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
  File "/mnt/nvme1/fast.ai-1/br/fastai/master/fastai/basic_train.py", line 200, in fit
    fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
  File "/mnt/nvme1/fast.ai-1/br/fastai/master/fastai/basic_train.py", line 101, in fit
    loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
  File "/mnt/nvme1/fast.ai-1/br/fastai/master/fastai/basic_train.py", line 26, in loss_batch
    out = model(*xb)
  File "/home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 494, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 401, in forward
    self.reducer.prepare_for_backward([])
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable). (prepare_for_backward at /opt/conda/conda-bld/pytorch-nightly_1559452046329/work/torch/csrc/distributed/c10d/reducer.cpp:410)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7fe928977265 in /home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: c10d::Reducer::prepare_for_backward(std::vector<torch::autograd::Variable, std::allocator<torch::autograd::Variable> > const&) + 0x61b (0x7fe9577cca1b in /home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #2: <unknown function> + 0x7116d8 (0x7fe9577c26d8 in /home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #3: <unknown function> + 0x216716 (0x7fe9572c7716 in /home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #4: _PyMethodDef_RawFastCallKeywords + 0x264 (0x55da2c35a6e4 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #5: _PyCFunction_FastCallKeywords + 0x21 (0x55da2c35a801 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #6: _PyEval_EvalFrameDefault + 0x537e (0x55da2c3b67ae in /home/stas/anaconda3/envs/fastai/bin/python)
frame #7: _PyEval_EvalCodeWithName + 0x2f9 (0x55da2c2f74f9 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #8: _PyFunction_FastCallDict + 0x1d5 (0x55da2c2f85d5 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #9: _PyObject_Call_Prepend + 0x63 (0x55da2c30fc43 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #10: PyObject_Call + 0x6e (0x55da2c30495e in /home/stas/anaconda3/envs/fastai/bin/python)
frame #11: _PyEval_EvalFrameDefault + 0x1e20 (0x55da2c3b3250 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #12: _PyEval_EvalCodeWithName + 0x2f9 (0x55da2c2f74f9 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #13: _PyFunction_FastCallDict + 0x1d5 (0x55da2c2f85d5 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #14: _PyObject_Call_Prepend + 0x63 (0x55da2c30fc43 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #15: <unknown function> + 0x17116a (0x55da2c35216a in /home/stas/anaconda3/envs/fastai/bin/python)
frame #16: PyObject_Call + 0x6e (0x55da2c30495e in /home/stas/anaconda3/envs/fastai/bin/python)
frame #17: _PyEval_EvalFrameDefault + 0x1e20 (0x55da2c3b3250 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #18: _PyEval_EvalCodeWithName + 0x2f9 (0x55da2c2f74f9 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #19: _PyFunction_FastCallKeywords + 0x325 (0x55da2c3599c5 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x416 (0x55da2c3b1846 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #21: _PyEval_EvalCodeWithName + 0x2f9 (0x55da2c2f74f9 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #22: _PyFunction_FastCallKeywords + 0x387 (0x55da2c359a27 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #23: _PyEval_EvalFrameDefault + 0x14ce (0x55da2c3b28fe in /home/stas/anaconda3/envs/fastai/bin/python)
frame #24: _PyEval_EvalCodeWithName + 0xbb9 (0x55da2c2f7db9 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #25: _PyFunction_FastCallKeywords + 0x387 (0x55da2c359a27 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #26: _PyEval_EvalFrameDefault + 0x14ce (0x55da2c3b28fe in /home/stas/anaconda3/envs/fastai/bin/python)
frame #27: _PyEval_EvalCodeWithName + 0x2f9 (0x55da2c2f74f9 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #28: _PyFunction_FastCallKeywords + 0x387 (0x55da2c359a27 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #29: _PyEval_EvalFrameDefault + 0x14ce (0x55da2c3b28fe in /home/stas/anaconda3/envs/fastai/bin/python)
frame #30: _PyEval_EvalCodeWithName + 0x2f9 (0x55da2c2f74f9 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #31: _PyFunction_FastCallDict + 0x400 (0x55da2c2f8800 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #32: _PyEval_EvalFrameDefault + 0x1e20 (0x55da2c3b3250 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #33: _PyFunction_FastCallKeywords + 0xfb (0x55da2c35979b in /home/stas/anaconda3/envs/fastai/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x416 (0x55da2c3b1846 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #35: _PyEval_EvalCodeWithName + 0x2f9 (0x55da2c2f74f9 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #36: PyEval_EvalCodeEx + 0x44 (0x55da2c2f83c4 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #37: PyEval_EvalCode + 0x1c (0x55da2c2f83ec in /home/stas/anaconda3/envs/fastai/bin/python)
frame #38: <unknown function> + 0x22f874 (0x55da2c410874 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #39: PyRun_FileExFlags + 0xa1 (0x55da2c41ab81 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #40: PyRun_SimpleFileExFlags + 0x1c3 (0x55da2c41ad73 in /home/stas/anaconda3/envs/fastai/bin/python)
frame #41: <unknown function> + 0x23ae5f (0x55da2c41be5f in /home/stas/anaconda3/envs/fastai/bin/python)
frame #42: _Py_UnixMain + 0x3c (0x55da2c41bf7c in /home/stas/anaconda3/envs/fastai/bin/python)
frame #43: __libc_start_main + 0xe7 (0x7fe9714d9b97 in /lib/x86_64-linux-gnu/libc.so.6)
frame #44: <unknown function> + 0x1e0122 (0x55da2c3c1122 in /home/stas/anaconda3/envs/fastai/bin/python)

I followed the instructions in the RuntimeError message and added find_unused_parameters=True as it suggested. But instead of reporting unused params, it just worked: the distributed training ran fine.

I don't know anything about this argument yet; I hope perhaps you do. If not, I will investigate tomorrow.

Note that this problem doesn't exist with pretrained=True.

This is all with git master.

I can also make a script to reproduce the problem if it helps. It's really just the standard LM from the lesson, but with pretrained=False and run with python -m torch.distributed.launch --nproc_per_node=2 ./script.py.

Thank you!

P.S. It looks like the find_unused_parameters argument was added some time in pytorch 1.2.0.dev2 (i.e. it is not in 1.0.1.post2).
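
Since the argument is missing from older PyTorch builds, a patch that passes it unconditionally would presumably fail there with a TypeError. Here is a minimal sketch of one way to guard against that, assuming the constructor's signature can be inspected. supported_kwargs, OldDDP and NewDDP are hypothetical stand-ins made up for illustration, not fastai or PyTorch names:

```python
import inspect

def supported_kwargs(callable_obj, **kwargs):
    """Keep only the keyword arguments that callable_obj's signature accepts."""
    params = inspect.signature(callable_obj).parameters
    # A **kwargs catch-all accepts anything, so nothing needs filtering.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}

# Hypothetical stand-ins for an older and a newer DistributedDataParallel.
class OldDDP:
    def __init__(self, module, device_ids=None, output_device=None):
        self.module = module

class NewDDP:
    def __init__(self, module, device_ids=None, output_device=None,
                 find_unused_parameters=False):
        self.module = module
        self.find_unused_parameters = find_unused_parameters

# The extra argument is dropped for the old class and kept for the new one.
old_extra = supported_kwargs(OldDDP, find_unused_parameters=True)
new_extra = supported_kwargs(NewDDP, find_unused_parameters=True)
old_model = OldDDP("model", **old_extra)
new_model = NewDDP("model", **new_extra)
```

With the real class, the same idea would be to call supported_kwargs(DistributedDataParallel, find_unused_parameters=True) and splat the result into the constructor call in on_train_begin.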


env:

=== Software ===
python        : 3.7.3
fastai        : 1.0.53.dev0
fastprogress  : 0.1.21
torch         : 1.2.0.dev20190602
nvidia driver : 418.56
torch cuda    : 10.0.130 / is available
torch cudnn   : 7501 / is enabled

=== Hardware ===
nvidia gpus   : 2
torch devices : 2
  - gpu0      : 12212MB | GeForce GTX TITAN X
  - gpu1      : 12212MB | GeForce GTX TITAN X

=== Environment ===
platform      : Linux-4.15.0-51-generic-x86_64-with-debian-buster-sid
distro        : Ubuntu 18.04 bionic
conda env     : fastai
python        : /home/stas/anaconda3/envs/fastai/bin/python
sys.path      : /mnt/nvme1/fast.ai-1/br/fastai/master
/home/stas/anaconda3/envs/fastai/lib/python37.zip
/home/stas/anaconda3/envs/fastai/lib/python3.7
/home/stas/anaconda3/envs/fastai/lib/python3.7/lib-dynload
/home/stas/.local/lib/python3.7/site-packages
/home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages
/mnt/nvme1/fast.ai-1/br/stas00/ipyexperiments
/mnt/nvme1/fast.ai-1/br/fastai/master
/mnt/nvme1/fast.ai-1/br/fastai/fastprogress

sgugger commented Jun 6, 2019

It may be linked to the hack around weight dropout but I can't be sure.

sgugger closed this as completed Jun 6, 2019
sgugger reopened this Jun 6, 2019

stas00 commented Jun 8, 2019

I had a chance to look at this new argument, find_unused_parameters. The name looks misleading, as it actually works around the situation where some params don't participate in the loss calculation and "brings them into the fold": https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L195

So I guess that, instead of the developer having to hunt down some "lost" params, it finds a way to involve all params.
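
To make that behaviour concrete, here is a toy pure-Python model of the bookkeeping. It is only a sketch under the assumption that the reducer waits for a gradient from every registered param and that find_unused_parameters=True marks params unreachable from the forward outputs as ready up front. ToyReducer and run are hypothetical names, and this is not PyTorch's implementation:

```python
class ToyReducer:
    """Toy stand-in for a gradient reducer (hypothetical, not PyTorch code)."""

    def __init__(self, param_names, find_unused_parameters=False):
        self.param_names = set(param_names)
        self.find_unused = find_unused_parameters
        self.pending = set()  # params still waiting for their gradient

    def prepare_for_backward(self, used_in_forward):
        # Called at the start of each iteration, before the backward pass.
        if self.pending:
            raise RuntimeError(
                "Expected to have finished reduction in the prior iteration")
        self.pending = set(self.param_names)
        if self.find_unused:
            # Params that never reach the outputs are marked ready right away.
            self.pending -= self.param_names - set(used_in_forward)

    def backward(self, used_in_forward):
        # Every param that took part in the loss reports its gradient.
        self.pending -= set(used_in_forward)


def run(reducer, used_in_forward, iterations=2):
    """Run a few iterations and report 'ok' or the error message."""
    try:
        for _ in range(iterations):
            reducer.prepare_for_backward(used_in_forward)
            reducer.backward(used_in_forward)
    except RuntimeError as err:
        return str(err)
    return "ok"


names = ["encoder.weight", "decoder.weight", "orphan.weight"]
used = ["encoder.weight", "decoder.weight"]  # orphan.weight never gets a grad

strict_result = run(ToyReducer(names), used)
lenient_result = run(ToyReducer(names, find_unused_parameters=True), used)
```

In this toy model the strict reducer gets through the first iteration and only fails when the second one starts, which is consistent with the traceback above pointing at prepare_for_backward inside forward.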

Now I'm going to look specifically at WeightDropout as you suggested might be the culprit.
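
One way to hunt for such params in single-GPU mode is sketched below, under the assumption that a param left out of the loss still has grad set to None after backward() on a freshly built model. Toy and params_without_grad are hypothetical names, and this toy model is not AWD_LSTM or WeightDropout:

```python
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # registered, but never called in forward

    def forward(self, x):
        return self.used(x)

def params_without_grad(model, loss):
    """Run backward once and list trainable params that received no gradient."""
    loss.backward()
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]

model = Toy()
loss = model(torch.randn(3, 4)).sum()
missing = params_without_grad(model, loss)
print(missing)  # ['unused.weight', 'unused.bias']
```

Running the same helper on the real learner's model after one batch should list whichever params never reach the loss, if there are any.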

But regardless of the outcome, it looks like my "fix" is a legit workaround for the time being.

I will post more as I gain more understanding.


stas00 commented Jun 12, 2019

I know you don't like outstanding issues hanging around, so I will close this one for now and update it if I discover something new to add.

stas00 closed this as completed Jun 12, 2019