Only put tensors on a device #5223

Merged (2 commits) on Jun 23, 2020
Conversation

sgugger (Collaborator) commented Jun 23, 2020

Fix Trainer when users have inputs containing non-tensor values.
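For context, here is a minimal sketch of the failure mode and the fix; the batch contents and device handling below are illustrative assumptions, not the exact Trainer code:

    import torch

    # Hypothetical batch illustrating the issue: not every value is a tensor,
    # so blindly calling .to(device) on every value raises an AttributeError.
    batch = {
        "input_ids": torch.tensor([[101, 2023, 102]]),
        "attention_mask": torch.tensor([[1, 1, 1]]),
        "example_id": "abc-123",  # plain string: has no .to() method
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The fix: move only torch.Tensor values, pass everything else through unchanged.
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}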

codecov bot commented Jun 23, 2020

Codecov Report

Merging #5223 into master will decrease coverage by 0.03%.
The diff coverage is 60.00%.


@@            Coverage Diff             @@
##           master    #5223      +/-   ##
==========================================
- Coverage   77.98%   77.95%   -0.04%     
==========================================
  Files         138      138              
  Lines       23839    23841       +2     
==========================================
- Hits        18592    18586       -6     
- Misses       5247     5255       +8     
Impacted Files                          Coverage Δ
src/transformers/trainer.py             39.62% <60.00%> (+0.04%) ⬆️
src/transformers/modeling_openai.py     79.51% <0.00%> (-1.39%) ⬇️
src/transformers/file_utils.py          76.42% <0.00%> (-0.39%) ⬇️
src/transformers/modeling_tf_utils.py   85.86% <0.00%> (-0.30%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Comment on lines 576 to 577
inputs = {k: v.to(self.args.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

julien-c (Member) commented Jun 23, 2020
I find this personally slightly hard to read.

Maybe:

for k, v in inputs.items():
    if isinstance(v, torch.Tensor):
        inputs[k] = v.to(self.args.device)


what do you think?


(otherwise, LGTM – we might want to document this in the function's type signature – or not)

sgugger (Collaborator, Author) commented Jun 23, 2020

Usually, using a comprehension is faster, but the batches don't have a lot of keys so it probably doesn't matter.
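A rough way to sanity-check that claim (illustrative only; the batch contents are made up and the numbers depend on hardware):

    import timeit
    import torch

    inputs = {"input_ids": torch.ones(2, 8), "attention_mask": torch.ones(2, 8), "id": "x"}
    device = torch.device("cpu")

    def comprehension():
        # one-liner from the PR diff
        return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    def explicit_loop():
        # suggested alternative: mutate a copy in an explicit loop
        out = dict(inputs)
        for k, v in out.items():
            if isinstance(v, torch.Tensor):
                out[k] = v.to(device)
        return out

    print(timeit.timeit(comprehension, number=10_000))
    print(timeit.timeit(explicit_loop, number=10_000))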


sgugger merged commit 9022ef0 into huggingface:master on Jun 23, 2020
sgugger deleted the small_fix branch on June 23, 2020 at 21:30