Fix GPU usage by removing device from the Transformers class wrapper to use the device/device_map directly exposed by HF Transformers in kwargs
#569
## The issue

Using the `device` parameter already provided by the `Transformers` wrapper doesn't work as intended: it first loads the model (by default on CPU) and only then pushes the model to the `device`. This is inefficient and hasn't been working at all on my end.

Example code snippet for the above:
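A minimal sketch of the old behavior, assuming the wrapper loads the model via `AutoModelForCausalLM.from_pretrained` and only then moves it with `.to()` (the model name is illustrative):

```python
from transformers import AutoModelForCausalLM

# What the wrapper's `device` parameter effectively did: the full model is
# materialized on CPU first, then every tensor is copied over to the GPU.
model = AutoModelForCausalLM.from_pretrained("gpt2")  # loaded on CPU
model = model.to("cuda:0")  # pushed to the device only afterwards
```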
## The fix

Removing the `device` parameter altogether fixes this and removes the confusion of having a `device` parameter that isn't actually used by HF Transformers directly but only by the wrapper itself. It also doesn't affect anything else, since `self.device = self.model_obj.device` already reads the device from the model directly, and the tests still pass.

It is now only necessary to pass `device_map` from HF Transformers directly through `kwargs`, and the model is loaded correctly.

Example code snippet for the above:
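A minimal sketch of the fixed usage, assuming the wrapper simply forwards `kwargs` to `from_pretrained` (`device_map` requires `accelerate`; the model name is illustrative):

```python
from transformers import AutoModelForCausalLM

# `device_map` is handled by HF Transformers itself, so the weights are
# placed on the target device(s) at load time, with no post-hoc .to() copy.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
print(model.device)  # the wrapper now reads its device from the model
```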