Workaround for performance bug in PyTorch with subclassed tensors #3683
Resolves #3682 by adding a `CastToTensor` callback as a workaround for a performance bug in PyTorch with subclassed tensors.

Fixes 4cff258 by testing whether `learn.xb` and `learn.yb` are tuples and applying the cast to `Tensor` if they are. (Unless I made a mistake in testing, `b[:i]` and `b[i:]` are tuples.)

Unlike 4cff258, this PR adds the workaround as a callback, so callbacks which use the input tensor type can still do so before the input is cast to `Tensor` for the training performance increase.

It also allows turning off the `Tensor` casting by removing the callback, should it break a workflow. Anyone who does this is encouraged to implement their own custom callback which casts to `Tensor` to keep the free training performance increase.

Currently `CastToTensor.order` is set right before `MixedPrecision`.
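The core of such a callback might look like the sketch below. This is an illustrative stand-in, not the PR's actual implementation: `_SubTensor` and `cast_to_tensor` are hypothetical names, with `_SubTensor` standing in for a subclass like fastai's `TensorBase`. The key points from the description are reflected: batches may arrive as tuples (`b[:i]` and `b[i:]`), so the cast recurses into tuples, and subclassed tensors are downcast to plain `torch.Tensor` via `as_subclass`.

```python
import torch


class _SubTensor(torch.Tensor):
    """Hypothetical tensor subclass, standing in for e.g. fastai's TensorBase."""
    pass


def cast_to_tensor(x):
    # Batches may be tuples, so apply the cast element-wise.
    if isinstance(x, tuple):
        return tuple(cast_to_tensor(t) for t in x)
    # Downcast tensor subclasses to plain torch.Tensor; leave everything else alone.
    if isinstance(x, torch.Tensor) and type(x) is not torch.Tensor:
        return x.as_subclass(torch.Tensor)
    return x


# A tuple batch of subclassed tensors is cast element-wise.
xb = (torch.ones(2).as_subclass(_SubTensor), torch.zeros(2).as_subclass(_SubTensor))
cast = cast_to_tensor(xb)
```

In a real callback this cast would run on `learn.xb`/`learn.yb` before the batch reaches the model, which is why ordering it just before `MixedPrecision` lets earlier callbacks still see the original subclassed tensor type.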