explicitly move tensors to one device #2054
Conversation
* Explicitly move tensors to one device as torch.cat no longer moves tensors (pytorch/pytorch#35045).
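The pattern this PR applies can be sketched as follows. This is a minimal illustration, not GPyTorch code; `cat_on_device` is a hypothetical helper name, and the example runs entirely on CPU, where `.to(device)` is a no-op.

```python
import torch

def cat_on_device(tensors, dim=0, device=None):
    # Hypothetical helper: pick a target device (default: the first
    # tensor's device) and move every tensor there explicitly, since
    # torch.cat no longer moves tensors across devices implicitly.
    device = device if device is not None else tensors[0].device
    return torch.cat([t.to(device) for t in tensors], dim=dim)

a = torch.ones(2, 3)
b = torch.zeros(2, 3)
print(cat_on_device([a, b]).shape)  # torch.Size([4, 3])
```

On a multi-GPU setup the same call would gather the pieces onto one device before concatenating, which is the behavior the PR restores explicitly.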
Thanks, this lgtm
hmm let's make the linter happy also
If the matrix is a CatLazyTensor, the returned matrix/row can be on different devices. So move the delazified object to the matrix's output device.
@@ -76,7 +76,8 @@ def apply_permutation(
     right_permutation = torch.arange(matrix.size(-1), device=matrix.device)

     # Apply permutations
-    return delazify(matrix.__getitem__((*batch_idx, left_permutation.unsqueeze(-1), right_permutation.unsqueeze(-2))))
+    res = delazify(matrix.__getitem__((*batch_idx, left_permutation.unsqueeze(-1), right_permutation.unsqueeze(-2))))
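The device fix behind this diff can be sketched without GPyTorch: after materializing a lazily concatenated matrix, the result may live on a different device than the matrix's reported one, so it is moved explicitly. `to_output_device` is a hypothetical helper, using only plain torch tensors as stand-ins for the delazified result.

```python
import torch

def to_output_device(res, matrix):
    # Sketch of the fix: if the delazified result came back on a
    # different device than the (possibly CatLazyTensor) matrix
    # reports, move it explicitly to the matrix's output device.
    if res.device != matrix.device:
        res = res.to(matrix.device)
    return res

res = torch.ones(2, 2)
matrix = torch.eye(2)
res = to_output_device(res, matrix)
```

On a single device the call is a no-op; the move only happens when the concatenated blocks straddle devices.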
this line is too long also
Sorry about that, still getting used to pull requests and pre-commit. I don't think the problem was with this line in permutation.py (that one should be 118 characters long). I think the issue was with my first commit: I resolved the changes but did not actually push a fix.
fixed line too long flake8 error
Hmm, I am not sure about the "build" checks: it seems like these aren't actually running? Are they duplicates of the "run test suite" test? I can't merge without those and I'm not sure how to run them. @gpleiss maybe you have an idea?
Explicitly move tensors to one device before calling torch.cat, as torch.cat no longer moves tensors silently (pytorch/pytorch#35407). This solves one of the issues in #2053, which arises when the partition size is set to 0.