Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround for performance bug in PyTorch with subclassed tensors #3683

Merged
merged 3 commits into from
Jun 10, 2022

Conversation

warner-benjamin
Copy link
Collaborator

@warner-benjamin warner-benjamin commented Jun 10, 2022

Resolves #3682 by adding CastToTensor callback as a workaround for performance bug in PyTorch with subclassed tensors.

Fixes 4cff258 by testing if learn.xb & learn.yb are tuples and applying the cast to Tensor if they are. (Unless I made a mistake in testing, b[:i] and b[i:] are tuples).

Unlike 4cff258, this PR adds the workaround as a callback so callbacks which use the input tensor type still can before it is casted to Tensor for the training performance increase.

It also allows turning off the Tensor casting by removing the callback should it ruin a workflow. Although anyone who does this is encouraged to reimplement their own custom callback which casts to Tensor to get the free training performance increase.

Currently CastToTensor.order is right before MixedPrecision.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@warner-benjamin
Copy link
Collaborator Author

I should have fixed the sync error, but if a nbdev _all_ isn't on one line, it won't be added to the module's __all__.

@jph00
Copy link
Member

jph00 commented Jun 10, 2022

Much better - thanks!

@jph00 jph00 merged commit 94edfc5 into fastai:master Jun 10, 2022
@jph00 jph00 added the bug label Jun 10, 2022
@warner-benjamin warner-benjamin deleted the subclass_speed_fix branch October 3, 2022 05:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

workaround pytorch subclass performance bug
2 participants