How to perform batch inference? #26061
Comments
See #24575
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey there @ryanshrott @NielsRogge 👋 I've added a short section to our basic LLM tutorial page on how to do batched generation in this PR. Taken from the updated guide, here's an example:
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained(
... "mistralai/Mistral-7B-v0.1", device_map="auto", load_in_4bit=True
... )
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")
>>> tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
>>> model_inputs = tokenizer(
... ["A list of colors: red, blue", "Portugal is"], return_tensors="pt", padding=True
... ).to("cuda")
>>> generated_ids = model.generate(**model_inputs)
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
['A list of colors: red, blue, green, yellow, orange, purple, pink,',
'Portugal is a country in southwestern Europe, on the Iber']
@gante Thanks. Is this faster than running them in a loop?
@ryanshrott Yes, much faster when measured in throughput! The caveat is that it requires slightly more memory from your hardware, and each call has slightly higher latency.
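As a rough way to see the difference, here is a sketch of a timing comparison. It is not from the original thread: it assumes the model and tokenizer from the example above are already loaded, and the prompt list and max_new_tokens value are placeholders.

import time

prompts = ["A list of colors: red, blue", "Portugal is", "The capital of France is"]

# One generate call per prompt (looped, no batching)
start = time.perf_counter()
for prompt in prompts:
    single_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    model.generate(**single_inputs, max_new_tokens=32)
loop_time = time.perf_counter() - start

# A single batched generate call (relies on the pad token and left padding set above)
start = time.perf_counter()
batch_inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")
model.generate(**batch_inputs, max_new_tokens=32)
batch_time = time.perf_counter() - start

print(f"looped: {loop_time:.2f}s  batched: {batch_time:.2f}s")

The batched call should finish in roughly the time of the longest single generation, at the cost of holding all sequences in memory at once.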
Feature request
I want to pass a list of texts to model.generate. Right now I only pass a single string:
text = "hey there"
inputs = tokenizer(text, return_tensors="pt").to(0)
out = model.generate(**inputs, max_new_tokens=184)
print(tokenizer.decode(out[0], skip_special_tokens=True))
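Something along these lines is the desired usage (a sketch only; the texts are placeholders, and it assumes the tokenizer has a pad token and left padding set, as in the example above):

texts = ["hey there", "how are you"]
# assumes tokenizer.pad_token and padding_side="left" are already configured
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
out = model.generate(**inputs, max_new_tokens=184)
print(tokenizer.batch_decode(out, skip_special_tokens=True))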
Motivation
I want to do batch inference.
Your contribution
Testing