
[s2s trainer] fix DP mode #8823

Merged - 8 commits - Nov 30, 2020

Conversation

stas00
Contributor

@stas00 stas00 commented Nov 27, 2020

This PR:

@patrickvonplaten, @sgugger

Contributor

@patrickvonplaten patrickvonplaten left a comment


Very clean! Thanks

Review comments (outdated, resolved):
examples/seq2seq/seq2seq_trainer.py
examples/seq2seq/test_finetune_trainer.py
@stas00
Contributor Author

stas00 commented Nov 30, 2020

Moving the discussion out of the review commentary, since it disappears as soon as it's resolved - it's best to discuss this in the normal comments, as this is what this PR is trying to solve.


Oh, I see - thank you for catching that. So I didn't solve the actual problem; I just got lucky in sweeping it under the carpet.

The problem is that distributed=... is wrong here - it is currently coded to expect DDP, not DP, when distributed==True. DP doesn't have get_world_size()/etc., so it fails. Should that arg be called ddp instead of distributed? In any case, the correct solution is then:

                self.train_dataset.make_sortish_sampler(
                    self.args.per_device_train_batch_size, distributed=self.args.local_rank != -1)

Or should it be re-coded to handle DP too? I don't know the original intention - should the sortish sampler be supported under DP or not?

We need to know whether to:

  1. recode make_sortish_sampler to support DP (it can't use get_world_size()/etc.), or
  2. recode make_sortish_sampler to rename its distributed arg to ddp, so that it only does the special case for DDP.

And somewhat unrelated to the actual bug, I'd like to repeat the request at #8822 - let's have a simple flag so that downstream code knows which mode it is running under, rather than checking ranks and n_gpus, which is very confusing and error-prone.
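As an illustration of the kind of flag that request is about, here is a minimal sketch; parallel_mode() is a hypothetical helper, and args is assumed to carry local_rank and n_gpu the way TrainingArguments does:

def parallel_mode(args):
    """Return "ddp", "dp" or "single" for the current run (hypothetical helper)."""
    if args.local_rank != -1:
        # the distributed launcher sets local_rank, so this is DistributedDataParallel
        return "ddp"
    if args.n_gpu > 1:
        # several GPUs driven by a single process -> DataParallel
        return "dp"
    return "single"

With something like this, downstream code could test parallel_mode(args) == "ddp" instead of reasoning about ranks and GPU counts.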

@stas00
Contributor Author

stas00 commented Nov 30, 2020

Here is where the problem happens with dp:

class DistributedSortishSampler(Sampler):
    """Copied from torch DistributedSampler"""

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()

So dist.is_available() returns True under DP, but dist.get_world_size() fails, since it only works under DDP and requires torch.distributed.init_process_group(), which doesn't get called under DP.
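A minimal sketch of a safer guard, assuming plain torch.distributed (dist.is_initialized() is the check that distinguishes DDP from DP here):

import torch.distributed as dist

def safe_world_size() -> int:
    # Under DP the package is importable (is_available() is True), but
    # init_process_group() was never called, so get_world_size() would raise.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1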

@sgugger
Collaborator

sgugger commented Nov 30, 2020

In DataParallel mode, you don't need to do anything to your dataloader (only in DistributedDataParallel do you need to split the batches across the various processes somehow), so you should make a regular dataloader in that case.
In general, the only proper way to detect whether you are in distributed data parallel is to check local_rank != -1, as torch.distributed can give you false information there. I agree it would all be much easier if the training arguments contained something that directly gives the distributed environment.
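A minimal sketch of that rule, assuming a plain PyTorch dataset and the local_rank convention described above (build_train_dataloader and its arguments are placeholders, not an existing API):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_train_dataloader(dataset, batch_size, local_rank):
    if local_rank != -1:
        # DistributedDataParallel: each process must see its own shard of the data
        sampler = DistributedSampler(dataset)
    else:
        # single GPU or DataParallel: a regular dataloader is enough
        sampler = None
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=(sampler is None))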

@stas00
Contributor Author

stas00 commented Nov 30, 2020

In DataParallel mode, you don't need to do anything to your dataloader (only in DistributedDataParallel do you need to split the batches across the various processes somehow), so you should make a regular dataloader in that case.

Great, so should we then change the signature to make it clear that DDP is wanted, and not any distributed mode:

- def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
+ def make_sortish_sampler(self, batch_size, ddp=False, shuffle=True, **kwargs):

and adjust the invocations accordingly?
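For illustration, the call site quoted earlier in this thread would then read something like this (hypothetical - only if the rename goes through):

                self.train_dataset.make_sortish_sampler(
                    self.args.per_device_train_batch_size, ddp=self.args.local_rank != -1)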

In general, the only proper way to detect whether you are in distributed data parallel is to check local_rank != -1, as torch.distributed can give you false information there. I agree it would all be much easier if the training arguments contained something that directly gives the distributed environment.

Great. Should we create a feature request for that?

@sgugger
Collaborator

sgugger commented Nov 30, 2020

I think there is a misunderstanding about the terminology: DataParallel is not distributed; distributed means launching several processes with the same script. The torch.distributed package does not return anything useful for DataParallel, and ddp stands for DistributedDataParallel, so leaving that argument as distributed seems better to me.

Great. Should we create a feature request for that?

We can do that, yes.

@stas00
Contributor Author

stas00 commented Nov 30, 2020

If you stick to the specific implementation, yes, DDP is the only distributed mode. But logically it doesn't make sense: DP is just as distributed as DDP; it just doesn't use torch.distributed. So it's not a very clear distinction and will lead to this kind of confusion all over.

As an example, if you look at this function's usage pattern, it's mostly dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1), which clearly (and erroneously) implies any multi-GPU mode.

@sgugger
Collaborator

sgugger commented Nov 30, 2020

I disagree, in the sense that code using PyTorch should stick with the PyTorch naming conventions. They chose to have a non-distributed DataParallel, so we should honor that in our naming as well. In DistributedDataParallel you have to use a DistributedSampler (but not in DataParallel), etc. Those are all parallel modes (as you're training with multiple GPUs), but only one is distributed.

@stas00
Contributor Author

stas00 commented Nov 30, 2020

That is a reasonable choice to follow. I'm only flagging how this leads to coding errors when a developer assumes that n_gpu > 1 implies DDP. So perhaps some extra support is needed there.

@sgugger
Collaborator

sgugger commented Nov 30, 2020

Let's see how it goes once we add the "distributed_env" to TrainingArguments!

@stas00
Contributor Author

stas00 commented Nov 30, 2020

@sgugger, please kindly review at your convenience - I have addressed all the issues you raised, so all should be good; the CI failures are unrelated. Thank you!

Collaborator

@sgugger sgugger left a comment


Perfect, thanks a lot for humoring me and my annoying comments :-)

@stas00
Contributor Author

stas00 commented Nov 30, 2020

Perfect, thanks a lot for humoring me and my annoying comments :-)

On the contrary, your comments were excellent and to the point.

I was just slow to get your point of view, since in my mind, if we solve a problem on multiple GPUs, it's distributed across multiple GPUs, regardless of how it's implemented. But here distributed means distributed across multiple processes. Different semantics.

@stas00 stas00 merged commit 7f34d75 into huggingface:master Nov 30, 2020
@stas00
Contributor Author

stas00 commented Nov 30, 2020

So this is probably wrong too:

# examples/seq2seq/finetune.py:  
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)

But that's code based on PL.

@patil-suraj, maybe you could have a look when you start working on this one? I suspect it should do a different check for distributed mode rather than checking the number of GPUs. Let me know if you'd prefer that I open a separate issue.
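For reference, a hedged sketch of what a corrected check could look like if it keyed off torch.distributed state rather than the GPU count (how PL exposes its own distributed flag isn't confirmed here):

import torch.distributed as dist

# hypothetical correction: DP also has gpus > 1, so key the sampler on actual DDP state
is_ddp = dist.is_available() and dist.is_initialized()
sampler = dataset.make_sortish_sampler(batch_size, distributed=is_ddp)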

@sgugger
Collaborator

sgugger commented Nov 30, 2020

Dunno how PL works.

@stas00
Contributor Author

stas00 commented Nov 30, 2020

Let's see how it goes once we add the "distributed_env" to TrainingArguments!

Added a feature request: #8858

@rabeehk

rabeehk commented Dec 1, 2020

Thank you, HuggingFace team and @stas00 - I cannot express how much I appreciate your efforts.

stas00 added a commit to stas00/transformers that referenced this pull request Dec 5, 2020
* fix DP case on multi-gpu

* make executable

* test all 3 modes

* use the correct check for distributed

* dp doesn't need a special case

* restore original name

* cleanup