Conversation
Thanks! Would it be possible to make this small change? I want to confirm that the error is really because of the `expand` rather than because of `new_zeros`.
torchbiggraph/losses.py (Outdated)

@@ -147,15 +147,15 @@ def forward(
         if weight is not None:
             loss_per_sample = F.cross_entropy(
                 scores,
-                pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
+                torch.zeros((num_pos, ), dtype=torch.long, device=scores.device),
Suggested change:
-                torch.zeros((num_pos, ), dtype=torch.long, device=scores.device),
+                pos_scores.new_zeros((num_pos, ), dtype=torch.long),
torchbiggraph/losses.py (Outdated)

                 reduction="none",
             )
             match_shape(weight, num_pos)
             loss_per_sample = loss_per_sample * weight
         else:
             loss_per_sample = F.cross_entropy(
                 scores,
-                pos_scores.new_zeros((), dtype=torch.long).expand(num_pos),
+                torch.zeros((num_pos, ), dtype=torch.long, device=scores.device),
Suggested change:
-                torch.zeros((num_pos, ), dtype=torch.long, device=scores.device),
+                pos_scores.new_zeros((num_pos, ), dtype=torch.long),
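For reference, the two target constructions being debated should be numerically interchangeable; a quick sketch (with made-up shapes, and `pos_scores` stubbed out) to check that both forms feed `F.cross_entropy` the same class-0 targets:

```python
import torch
import torch.nn.functional as F

num_pos, num_classes = 4, 7          # made-up sizes for illustration
scores = torch.randn(num_pos, num_classes)
pos_scores = scores[:, 0]            # stand-in for the real pos_scores

# Variant from the patch: a freshly allocated target tensor.
loss_a = F.cross_entropy(
    scores,
    torch.zeros((num_pos,), dtype=torch.long, device=scores.device),
)

# Suggested variant: new_zeros inherits dtype/device handling from pos_scores.
loss_b = F.cross_entropy(
    scores,
    pos_scores.new_zeros((num_pos,), dtype=torch.long),
)

assert torch.allclose(loss_a, loss_b)
```

The only behavioral difference is where the device comes from: `new_zeros` picks it up from `pos_scores` automatically, which is why it was suggested.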
This change is undoing a perf optimization: the code before was carefully written to allocate only one element in memory, and then adapt the tensor's indexing so that it looks like that element is replicated many times. The new code, however, allocates multiple elements in memory. I don't know if the impact is significant, but it would certainly be better to avoid it. If there has been a regression in PyTorch between 1.9 and 1.10, that should be reported upstream and ideally fixed there, instead of being worked around in our code.
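A small sketch (sizes invented) of the optimization being described: `expand` turns a single allocated scalar into a stride-0 view, so the tensor merely looks like `num_pos` zeros without any per-element allocation:

```python
import torch

num_pos = 10_000                     # hypothetical number of positive edges
pos_scores = torch.randn(num_pos)

# Old code: one scalar element allocated, then viewed as num_pos zeros.
expanded = pos_scores.new_zeros((), dtype=torch.long).expand(num_pos)

# New code: num_pos elements actually allocated and zero-filled.
materialized = torch.zeros((num_pos,), dtype=torch.long)

assert expanded.shape == materialized.shape
assert torch.equal(expanded, materialized)     # same values either way
assert expanded.stride() == (0,)     # zero stride: every index aliases one element
assert materialized.stride() == (1,) # real storage, one element per index
```

The stride-0 view is what broke `F.cross_entropy` on PyTorch 1.10, which is the regression discussed below.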
@lw I agree in principle, but the timeframe on which the upstream bug will be fixed is unclear. I've also observed almost no performance degradation in practice with the above change: with 16 A100 GPUs I'm able to train at 46m edges/second with pytorch 1.9 and at a comparable rate with this change.

@adamlerer I did find that the issue is down to the `expand` rather than `new_zeros`.

@lw with the above in mind, what is the perf difference between the `expand` trick and allocating the full tensor?

Anyway, I'll report this bug upstream to pytorch as well.
It's more of a memory gain (and, secondarily, the time gain of allocating and filling in that memory). But admittedly it's probably barely significant, since if I remember correctly we're talking about allocating 4 extra bytes for each input edge of a batch (rather than just 4 bytes once for the entire batch). If you've confirmed that the `expand` is indeed the cause, then this change seems reasonable.
I'll go ahead and change to `pos_scores.new_zeros((num_pos, ), dtype=torch.long)` as suggested. In the meantime, here's the upstream issue:
@lw you're right that there's a (minor) performance implication to this, but I think it's overshadowed by the benefit of PyTorch 1.10 compatibility (even if this is fixed in a future version, we ideally want compatibility across a range of pytorch versions).
@adamlerer has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
It is indeed fixed in a point release (which everyone should be using instead of 1.10): https://github.com/pytorch/pytorch/pull/69617/files#diff-13bb6989036251a34e2a9f7bd28349761c20950940cd22300a72ffc860296ed0R285. However, it's fixed in a way that undoes your perf optimization inside the cross-entropy implementation, by calling
I just wanted to check in -- it looks like this was imported into the Meta internal repo -- did it get mainlined externally? If so, should we close this PR? And if not, should I make any changes? |
What is "point release", @ngimel?
Sorry about the delay @tmarkovich , this PR should merge today. |
1.10.1 or 1.10.2 (it's fixed in both). |
Types of changes
Motivation and Context / Related issue
This fixes issue #245.
How Has This Been Tested (if it applies)
I've confirmed that I've been able to train on PyTorch 1.10 using this fix.
Checklist