
Fix GPU usage by removing device from Transformers class wrapper to use the device/device_map directly exposed by HF Transformers in kwargs #569

Merged (1 commit) · Jan 4, 2024

Conversation

@cpcdoy (Contributor) commented on Jan 3, 2024

The issue

Using the device parameter already provided by the Transformers wrapper doesn't work as intended: it first loads the model (on CPU by default) and only then moves it to the target device. This is inefficient, and on my end it hasn't worked at all.

Example code snippet for the above:

from guidance import models

# This is confusing since the user can think this is given directly to the wrapped HF transformers class
gpt = models.Transformers('gpt2', device='cuda') 
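
For context, here is a minimal sketch of the difference using the HF transformers API directly (illustrative only, not guidance's actual internals):

from transformers import AutoModelForCausalLM

# What the old wrapper effectively did: materialize the weights on CPU,
# then copy them over to the GPU in a second pass.
model = AutoModelForCausalLM.from_pretrained('gpt2')
model = model.to('cuda')

# What HF Transformers supports natively: place the weights while loading.
model = AutoModelForCausalLM.from_pretrained('gpt2', device_map='auto')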

The fix

Removing the device parameter altogether fixes this and removes the confusion of having a device parameter that isn't actually passed to HF Transformers but is only used by the wrapper itself.

This doesn't affect anything else, since self.device = self.model_obj.device already reads the device from the loaded model directly, and the tests seem to pass.

It is now only necessary to pass device_map through kwargs to HF Transformers directly, and the model is loaded correctly.

Example code snippet for the above:

from guidance import models

gpt = models.Transformers('gpt2', device_map='auto') # `auto` or any other device map
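
Under the hood, the wrapper now simply forwards kwargs to from_pretrained and reads the device back from the loaded model. A simplified sketch of that flow (load_model is a hypothetical stand-in for the wrapper's loading code, not the actual guidance source):

from transformers import AutoModelForCausalLM

def load_model(model_name, **kwargs):
    # kwargs (e.g. device_map, torch_dtype) go straight to HF Transformers
    model_obj = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    # the wrapper then records the device the model actually landed on
    return model_obj, model_obj.device

model_obj, device = load_model('gpt2', device_map='auto')
print(device)  # e.g. cuda:0 when a GPU is available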

@cpcdoy changed the title from "Remove device from Transformers class wrapper to use the device/device_map directly exposed by HF Transformers in kwargs" to "Fix GPU usage by removing device from Transformers class wrapper to use the device/device_map directly exposed by HF Transformers in kwargs" on Jan 3, 2024
@slundberg (Contributor) commented:

Thanks @cpcdoy! This is a great fix. Once the unit tests run I'll go ahead and merge.

@slundberg merged commit e706971 into guidance-ai:main on Jan 4, 2024
4 checks passed
@cpcdoy mentioned this pull request on Jan 6, 2024
@cpcdoy (Contributor, Author) commented on Jan 6, 2024

Fixes #536

@aadityabhatia commented:

I tried the following:

mistral = models.TransformersChat("mistralai/Mistral-7B-Instruct-v0.2", device=0)
mistral = models.Transformers("mistralai/Mistral-7B-Instruct-v0.2", device=0)

Both of those statements resulted in:

TypeError: MistralForCausalLM.__init__() got an unexpected keyword argument 'device'

Is that expected? I'm running the latest version from this repo: pip install git+https://github.com/guidance-ai/guidance
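
This error is consistent with the change in this PR: device is no longer intercepted by the wrapper, so it falls through kwargs to MistralForCausalLM.__init__, which doesn't accept it. Assuming the post-merge behavior described above, the placement argument should instead be HF's own device_map; a sketch:

from guidance import models

# after this PR, pass HF Transformers' device_map instead of `device`
mistral = models.Transformers('mistralai/Mistral-7B-Instruct-v0.2', device_map='auto')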
