Script for distilling zero-shot classifier to more efficient student #10244

Merged
merged 18 commits into huggingface:master from zero-shot-distillation on Feb 18, 2021

Conversation

@joeddav (Contributor) commented Feb 17, 2021

This PR introduces a script that provides a way to improve the speed and memory performance of a zero-shot classifier by training a more efficient student model from the zero-shot teacher's predictions over an unlabeled dataset.

For a given sequence, the zero-shot classification pipeline requires each possible label to be fed through the large NLI model separately. This slows inference considerably, particularly for tasks with a large number of classes K.
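For reference, a standard zero-shot pipeline call looks roughly like the following (the input text and candidate labels are only illustrative); each candidate label becomes its own premise/hypothesis pair and requires a separate forward pass through the NLI model:

from transformers import pipeline

# Each of the K candidate labels costs one forward pass through the large
# NLI model, so latency grows roughly linearly with K.
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
output = classifier(
    "Who are you voting for in 2020?",  # illustrative input
    candidate_labels=["politics", "sports", "science"],  # illustrative class names
)
print(output["labels"], output["scores"])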

Given (1) an unlabeled corpus and (2) a set of candidate class names, this script allows a user to train a standard classification head with K output dimensions. The script generates a softmax distribution for the provided data & class names, and a student classifier is then fine-tuned on these proxy labels. The resulting student model can be used for classifying novel text instances over these K classes with an order-of-magnitude boost in inference speed in addition to decreased memory usage.
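Conceptually, the student is trained to match the teacher's softmax outputs as soft targets. A minimal sketch of that objective, assuming plain soft-label cross-entropy (the script's actual loss and details such as temperature may differ):

import torch.nn.functional as F

def soft_label_loss(student_logits, teacher_probs):
    # Cross-entropy between the student's predicted distribution and the
    # teacher's softmax distribution over the K class names (proxy labels).
    log_probs = F.log_softmax(student_logits, dim=-1)
    return -(teacher_probs * log_probs).sum(dim=-1).mean()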

A teacher NLI model can be distilled to a student model by running distill_classifier.py like so:

python distill_classifier.py \
--data_file unlabeled_data.txt \
--class_names_file class_names.txt \
--output_dir ./distilled_model

A number of other args are provided as well, such as --teacher_name_or_path and --student_name_or_path for specifying the pre-trained teacher and student models to be used (roberta-large-mnli and distilbert-base-uncased by default), and --hypothesis_template for customizing the hypothesis template used by the teacher zero-shot model. Training is implemented via Trainer, so any TrainingArguments can be specified as well.
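For example, a more fully specified invocation might look like this (the flag values are illustrative; --per_device_train_batch_size and --fp16 come from the standard TrainingArguments):

python distill_classifier.py \
--data_file unlabeled_data.txt \
--class_names_file class_names.txt \
--teacher_name_or_path roberta-large-mnli \
--student_name_or_path distilbert-base-uncased \
--hypothesis_template "This text is about {}." \
--per_device_train_batch_size 32 \
--fp16 \
--output_dir ./distilled_model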

The resulting model can then be used trivially in a text classification pipeline, or in any other way a standard sequence classification model can:

from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline

model = AutoModelForSequenceClassification.from_pretrained("./distilled_model")
tokenizer = AutoTokenizer.from_pretrained("./distilled_model")
distilled_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
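A quick usage sketch (the input text is a placeholder and the comment only shows the expected output format; actual label names come from class_names.txt):

preds = distilled_classifier("Who are you voting for in 2020?")
# e.g. [{'label': 'politics', 'score': 0.98}]  # illustrative output

The student scores all K classes in a single forward pass, which is where the speed and memory savings come from.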

See the included README.md for more details and examples.

Soon I'll introduce a similar script for self-training an NLI model on unlabeled data alone, boosting its performance; the resulting model can then be distilled with this script like any other NLI model.

Update: I've also just added a link to a working Colab notebook demo.

@joeddav joeddav added the Distillation (related to model distillation) and Examples (related to examples in general) labels Feb 17, 2021
@LysandreJik (Member) left a comment

Fantastic that you're using the Trainer for that. Pinging Sylvain for review.

joeddav and others added 2 commits February 18, 2021 10:45
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@sgugger (Collaborator) left a comment

Great new example! After reading everything, I'm not sure it supports distributed training, so the PR should either be amended to support it or clearly state in the README that it does not (same for TPUs).

@joeddav (Contributor, Author) commented Feb 18, 2021

@LysandreJik cool thanks for the feedback.

@sgugger Thanks, I added fp16 support for the teacher predictions. It will also now throw an error if someone tries to run it with distributed training or TPUs, and I added a note about that in the README as well. It can still do multi-GPU, and will do so automatically if multiple GPUs are available on the machine; it just can't do multi-node.

@sgugger (Collaborator) commented Feb 18, 2021

Yes, I meant distributed multi-GPU. I did see that it will use all GPUs available on the machine, however :-)

Comment on lines +256 to +259
if training_args.local_rank != -1:
raise ValueError("Distributed training is not currently supported.")
if training_args.tpu_num_cores is not None:
raise ValueError("TPU acceleration is not currently supported.")

Great!

@joeddav joeddav merged commit c6fe175 into huggingface:master Feb 18, 2021
@joeddav joeddav deleted the zero-shot-distillation branch February 18, 2021 22:08