-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use Reader's device by default #1208
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there's some overlap with #1182 where another util method get_device
was introduced.
Can you please check if we can actually get rid of one of them?
Well, the functionality of
which internally calls:
What do you think? |
Yep, seems like they are quite redundant. The only difference: one is returning a torch.device object, the other one a string. |
By using initialize_device_settings() instead of the get_device() method, the returned device is always a torch.device object. haystack/haystack/generator/transformers.py Line 132 in afee4f3
"The torch.device argument in functions can generally be substituted with a string." https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device In rare cases, FARM checks the given device against strings "cpu" or "gpu" but it then extracts the type of the device as follows: Let's use initialize_device_settings() whenever we need to know the device and let's use device.type if we need to check the device against a string. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for verifying! The changes look good, but can you please verify that the get_device()
method is not used anywhere else and if that's the case delete it?
It's not used. I deleted it now. |
Proposed changes:
init()
method now runsinitialize_device_settings(use_cuda=self.use_gpu)
to get Reader's deviceeval()
,eval_from_file()
, andcalibrate_confidence_scores()
but can be overwritten when calling these methodsNote that
train()
also callsinitialize_device_settings(use_cuda=use_gpu,use_amp=use_amp)
as before becausetrain()
can be initialized with other a different value foruse_gpu
than the Reader itself.closes #1137