Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nn.DataParallel compatibility in PyTorch 1.5 #4300

Merged
merged 8 commits into from
May 19, 2020
Merged

Conversation

julien-c
Copy link
Member

@julien-c julien-c commented May 12, 2020

As reported in #3936

PyTorch 1.5 changes the handling of nn.Parameters in DataParallel replicas. (pytorch/pytorch#33907). The replicas now don't have parameters() anymore.

This PR updates our self.device and self.dtype helpers to mimic nn.Module's parameters() helper but for attributes, i.e. recursively look for an attribute of type Tensor.

Reformer and TransfoXL seem to be doing fancy things based on the module's Parameters so I didn't attempt to fix them.

Finally I'm introducing a multigpu CI flag. CI does not currently run on multiple GPUs so remember to run it locally.


Also pinging @ngimel the author of the change in PyTorch, to check if I'm doing something stupid here.

@julien-c julien-c changed the title Test case for #3936 Fix nn.DataParallel compatibility in PyTorch 1.5 May 13, 2020
@julien-c julien-c marked this pull request as ready for review May 13, 2020 23:36
@codecov-io
Copy link

codecov-io commented May 13, 2020

Codecov Report

Merging #4300 into master will decrease coverage by 0.00%.
The diff coverage is 91.66%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4300      +/-   ##
==========================================
- Coverage   78.21%   78.21%   -0.01%     
==========================================
  Files         120      120              
  Lines       20038    20040       +2     
==========================================
+ Hits        15673    15674       +1     
- Misses       4365     4366       +1     
Impacted Files Coverage Δ
src/transformers/modeling_xlnet.py 75.73% <75.00%> (ø)
src/transformers/modeling_albert.py 77.20% <100.00%> (ø)
src/transformers/modeling_bert.py 88.82% <100.00%> (ø)
src/transformers/modeling_t5.py 83.66% <100.00%> (ø)
src/transformers/modeling_utils.py 90.98% <100.00%> (+0.02%) ⬆️
src/transformers/file_utils.py 73.44% <0.00%> (-0.42%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7cb203f...7eef4f5. Read the comment docs.

@julien-c
Copy link
Member Author

Update: Also asked about this on shared PyTorch Slack channel

@ngimel
Copy link

ngimel commented May 14, 2020

I don't see anything obviously wrong, but I'm not very familiar with your codebase. You are looking for tensor attributes - are you sure that you don't have any other tensor attributes that don't correspond to former parameters?
Finally, we also have _former_parameters in pytorch, introduced here pytorch/pytorch#36523. You may find it useful, but we can't guarantee that it will remain as a stable API.

@@ -110,11 +110,31 @@ def reset_memory_hooks_state(self):

@property
def device(self) -> device:
return next(self.parameters()).device
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a great idea to have this as a property here! :-)

@@ -41,7 +41,7 @@ class CTRLModelTester(object):
def __init__(
self,
parent,
batch_size=13,
batch_size=14,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need an even number of batch_size ? Or is it because it's an unlucky number? :D

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes for some (but not all for some reason) of the models, the batch size seems to need to be a multiple of the number of DataParallel replicas.

I didn’t investigate too much as to why.

Copy link
Contributor

@sshleifer sshleifer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

tests/utils.py Outdated Show resolved Hide resolved
Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

tests/utils.py Outdated Show resolved Hide resolved
Copy link
Member

@thomwolf thomwolf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@julien-c julien-c merged commit 4c06893 into master May 19, 2020
@julien-c julien-c deleted the tests_multigpu branch May 19, 2020 00:34
@blacksph3re
Copy link

I think you forgot GPT-2:
transformers/modeling_gpt2.py", line 464, in forward attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility StopIteration

@julien-c
Copy link
Member Author

Yes, if I remember correctly I didn't try to remove all the calls to next(self.parameters()) in this PR – do you want to open a PR to fix this?

guhur added a commit to guhur/transformers that referenced this pull request Oct 9, 2020
LysandreJik pushed a commit that referenced this pull request Oct 9, 2020
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
RuntimeRacer added a commit to RuntimeRacer/Real-Time-Voice-Cloning that referenced this pull request Dec 24, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants