This repository has been archived by the owner on Mar 14, 2024. It is now read-only.

Fixing bug introduced by PyTorch 1.10 #246

Closed
wants to merge 2 commits

Conversation

@tmarkovich (Contributor)

Types of changes

  • Docs change / refactoring / dependency upgrade
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Motivation and Context / Related issue

This fixes issue #245.

How Has This Been Tested (if it applies)

I've confirmed that I've been able to train on PyTorch 1.10 using this fix.

Checklist

  • The documentation is up-to-date with the changes I made.
  • I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
  • All tests passed, and additional code has been covered with new tests.

@facebook-github-bot added the "CLA Signed" label on Jan 19, 2022

@adamlerer (Contributor) left a comment


Thanks! Would it be possible to make this small change? I want to confirm that the error is really because of the expand rather than because of new_zeros.

@@ -147,15 +147,15 @@ def forward(
         if weight is not None:
             loss_per_sample = F.cross_entropy(
                 scores,
-                pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
+                torch.zeros((num_pos, ), dtype=torch.long, device=scores.device),

Suggested change:
-                torch.zeros((num_pos, ), dtype=torch.long, device=scores.device),
+                pos_scores.new_zeros((num_pos, ), dtype=torch.long),

                 reduction="none",
             )
             match_shape(weight, num_pos)
             loss_per_sample = loss_per_sample * weight
         else:
             loss_per_sample = F.cross_entropy(
                 scores,
-                pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
+                torch.zeros((num_pos, ), dtype=torch.long, device=scores.device),

Suggested change:
-                torch.zeros((num_pos, ), dtype=torch.long, device=scores.device),
+                pos_scores.new_zeros((num_pos, ), dtype=torch.long),

@lw (Contributor) commented on Jan 20, 2022

This change is undoing a perf optimization: the code before was carefully written to allocate only one element in memory and then adapt the tensor's indexing so that it looks like that element is replicated many times. The new code, however, allocates multiple elements in memory. I don't know if the impact is significant, but it would certainly be better to avoid it.

If there has been a regression in PyTorch between 1.9 and 1.10, it should be reported upstream and ideally fixed there, instead of being worked around in our code.
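(To make the optimization lw describes concrete, here is a minimal sketch, not part of the PR, with illustrative names and sizes. It shows that the expand trick backs the whole target with a single stored element, while the torch.zeros call materializes num_pos elements.)

    import torch

    num_pos = 1_000_000
    pos_scores = torch.randn(num_pos)

    # Old code: allocate a single zero, then view it as a length-num_pos tensor.
    expanded = pos_scores.new_zeros((), dtype=torch.long).expand(num_pos)

    # New code in this PR: allocate and zero-fill num_pos elements.
    materialized = torch.zeros((num_pos,), dtype=torch.long)

    print(expanded.stride())              # (0,)  every index reads the same element
    print(expanded.storage().size())      # 1
    print(materialized.storage().size())  # 1000000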

@tmarkovich (Contributor, Author)

@lw I agree in principle, but it's not clear on what timeframe the upstream bug will be fixed. I've also observed almost no performance degradation in practice with the above change: with 16 A100 GPUs I'm able to train at 46M edges/second, either with PyTorch 1.9 and new_zeros or with PyTorch 1.10 and the above change.

@adamlerer I did find that the issue is down to the expand and not the difference between zeros and new_zeros.
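(For reference, a minimal way to check that distinction, not taken from the PR and using made-up scores and shapes, is to pass each target variant to F.cross_entropy directly; on an affected PyTorch 1.10.0 install the expanded target reportedly fails, per issue #245, while the materialized variants work.)

    import torch
    import torch.nn.functional as F

    num_pos, num_classes = 8, 5
    scores = torch.randn(num_pos, num_classes)
    pos_scores = scores[:, 0]

    targets = {
        "new_zeros(()) + expand": pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
        "new_zeros((num_pos,))": pos_scores.new_zeros((num_pos,), dtype=torch.long),
        "torch.zeros((num_pos,))": torch.zeros((num_pos,), dtype=torch.long, device=scores.device),
    }

    for name, target in targets.items():
        try:
            loss = F.cross_entropy(scores, target, reduction="none")
            print(f"{name}: ok, loss shape {tuple(loss.shape)}")
        except RuntimeError as err:
            print(f"{name}: failed: {err}")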

@lw with the above in mind, what is the perf difference between new_zeros((), ...).expand(npos) and new_zeros((npos, ))?

[Screenshot attached: Screen Shot 2022-01-20 at 9:34:47 AM]

Anyway, I'll report this bug upstream to PyTorch as well.

@lw (Contributor) commented on Jan 20, 2022

> what is the perf difference between new_zeros((), ...).expand(npos) and new_zeros((npos, ))?

It's more of a memory gain (and, secondarily, the time gain to allocate/fill in that memory). But admittedly it's probably barely significant, since if I remember correctly we're talking about allocating 4 extra bytes for each input edge of a batch (rather than just 4 bytes once for the entire batch).

If you've confirmed that the expand is indeed the root cause of the bug and that there's no observable difference after this fix, then I'm fine with it. Thanks for reporting the bug upstream anyways!

@tmarkovich (Contributor, Author)

I'll go ahead and change it to new_zeros((npos, )), and then file an issue (linking the upstream one) to revert this fix once the upstream bug is fixed.

In the meantime, here's the upstream issue:
pytorch/pytorch#71550

@adamlerer (Contributor)

@lw you're right that there's a (minor) performance implication to this, but I think it's overshadowed by the benefit of PyTorch 1.10 compatibility (even if this is fixed in a future version, we ideally want compatibility across a range of PyTorch versions).

@facebook-github-bot (Contributor)

@adamlerer has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@ngimel commented on Jan 20, 2022

It is indeed fixed in the point release (which everyone should be using instead of 1.10): https://github.com/pytorch/pytorch/pull/69617/files#diff-13bb6989036251a34e2a9f7bd28349761c20950940cd22300a72ffc860296ed0R285. However, it's fixed in a way that undoes your perf optimization inside the cross-entropy implementation, by calling contiguous() on your expanded tensor.
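(For context, calling contiguous() on an expanded tensor materializes the replicated elements, which is roughly what the linked upstream fix does inside the cross-entropy implementation; the sketch below is my own illustration, not the upstream code.)

    import torch

    # A stride-0 view: one stored element presented as a million zeros.
    target = torch.zeros((), dtype=torch.long).expand(1_000_000)
    print(target.storage().size())               # 1

    # Making it contiguous copies out every element, so the memory saving
    # from the expand trick is lost at this point.
    print(target.contiguous().storage().size())  # 1000000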

@tmarkovich (Contributor, Author)

I just wanted to check in: it looks like this was imported into the Meta internal repo. Did it get mainlined externally? If so, should we close this PR? And if not, should I make any changes?

@adamlerer (Contributor)

> It is indeed fixed in the point release (which everyone should be using instead of 1.10): https://github.com/pytorch/pytorch/pull/69617/files#diff-13bb6989036251a34e2a9f7bd28349761c20950940cd22300a72ffc860296ed0R285. However, it's fixed in a way that undoes your perf optimization inside the cross-entropy implementation, by calling contiguous() on your expanded tensor.

What is "point release" @ngimel ?

@adamlerer (Contributor)

Sorry about the delay, @tmarkovich; this PR should merge today.

@ngimel commented on Feb 7, 2022

What is "point release" @ngimel ?

1.10.1 or 1.10.2 (it's fixed in both).
