Add zero-shot classification #121
Conversation
Awesome! I think ideally we should follow the same format as image/text classification, so we use:

```elixir
{labels, hypothesis} = Enum.unzip(labels_and_hypothesis)

all_inputs =
  # remaining tokenizer options elided in the quoted diff
  Bumblebee.apply_tokenizer(tokenizer, Enum.zip(prompts, hypothesis))
```
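For context, a hedged sketch of where labels_and_hypothesis could come from; the template string is an assumption, since zero-shot NLI classifiers commonly render each label into a hypothesis such as "This example is {label}.":

```elixir
labels = ["cooking", "traveling", "dancing"]

# Hypothetical construction: each candidate label becomes an NLI hypothesis.
labels_and_hypothesis =
  Enum.map(labels, fn label -> {label, "This example is #{label}."} end)

{labels, hypothesis} = Enum.unzip(labels_and_hypothesis)
```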
One thing to note is that here a single input becomes multiple inputs in the batch (one per {prompt, hypothesis} pair). We could group them by adding an additional leading axis, but the number of hypotheses may be different for every input in the batch, so we can't really do that. Probably just documenting it is the way to go? (See the sketch below.)
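A minimal sketch of that expansion, with illustrative values (the hypothesis strings here are assumptions, not necessarily what the serving produces):

```elixir
# One serving input paired with 3 labels expands into 3 model inputs, so the
# model batch grows by a factor of each input's label count.
prompts = ["one day I will see the world"]

hypothesis = [
  "This example is cooking.",
  "This example is traveling.",
  "This example is dancing."
]

pairs = for prompt <- prompts, hyp <- hypothesis, do: {prompt, hyp}
length(pairs)
#=> 3 model inputs for a single serving input
```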
For Stable Diffusion, if someone sets num_images_per_prompt: 2, the number of sub-inputs is fixed, so we treat it as a single member of the batch (by adding a leading axis). I wonder if we should instead treat num_images_per_prompt: 2 as two inputs, so for a batch size of 1 we would generate each image separately. This way, if someone sets 4 images and a batch size of 1, it would take more time but wouldn't blow up memory, and they could set the batch size to 4 to generate them all at once. This decoupling gives more control to the user.

@seanmor5 @josevalim thoughts?
FTR we decided to stick to the current approach, that is, :batch_size always refers to the serving input. It may be inflated by options such as num_images_per_prompt, or by the number of labels in this case.
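Put concretely (a hedged illustration with made-up numbers, not code from this PR):

```elixir
# :batch_size counts serving inputs; the tensors reaching the model can be larger.
batch_size = 2   # two prompts in the serving batch
num_labels = 3   # three candidate labels per prompt
batch_size * num_labels
#=> 6 sequences actually go through the model
```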
force-pushed from af1d257 to ab2bf84
@seanmor5 I fixed the post-processing to apply softmax per batch on the entailment label. I also adjusted the assertions based on this:

```python
from transformers import pipeline

p = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    candidate_labels=["cooking", "traveling", "dancing"],
)
p("one day I will see the world")
```
Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
force-pushed from b543247 to f8d44e5
I will add tests tomorrow once I get the multi-prompt case down, but right now it seems to be working fine.
Btw, is there a reason we hid the documentation for TokenClassification?