From 78324ea0671e1bdb0ad476ea3ee6fa6a8120a660 Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Sun, 3 Oct 2021 23:46:33 -0400
Subject: [PATCH 1/2] Use distillation modifier from SparseML

---
 examples/pytorch/question-answering/run_qa.py |  8 ---
 .../question-answering/sparseml_utils.py      | 63 +++++++------------
 src/transformers/sparse.py                    | 10 ++-
 3 files changed, 27 insertions(+), 54 deletions(-)

diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 2fb4ce3523a0..70d7e45f8107 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -65,12 +65,6 @@ class ModelArguments:
     distill_teacher: Optional[str] = field(
         default=None, metadata={"help": "Teacher model which needs to be a trained QA model"}
     )
-    distill_temperature: Optional[float] = field(
-        default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
-    )
-    distill_hardness: Optional[float] = field(
-        default=1.0, metadata={"help": "Proportion of loss coming from teacher model."}
-    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -593,8 +587,6 @@ def compute_metrics(p: EvalPrediction):
         model_args.model_name_or_path,
         [existing_recipe, new_recipe],
         teacher=teacher_model,
-        distill_hardness=model_args.distill_hardness,
-        distill_temperature=model_args.distill_temperature,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
index 0f84c562e6af..2e1129931527 100644
--- a/examples/pytorch/question-answering/sparseml_utils.py
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn.functional as F
 
-from sparseml.pytorch.utils import ModuleExporter
+from sparseml.pytorch.utils import ModuleExporter, device_of
 from trainer_qa import QuestionAnsweringTrainer
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
 from transformers.sparse import SparseMLTrainer
@@ -28,46 +28,29 @@ def compute_loss(self, model, inputs, return_outputs=False):
         if not self.recipes or self.teacher is None:
             return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
-        outputs = model(**inputs)
-        if self.teacher is None:
-            loss = outputs["loss"]
-        else:
-            input_device = inputs["input_ids"].device
-            self.teacher = self.teacher.to(input_device)
-            start_logits_student = outputs["start_logits"]
-            end_logits_student = outputs["end_logits"]
-            start_logits_label = inputs["start_positions"]
-            end_logits_label = inputs["end_positions"]
-            with torch.no_grad():
-                teacher_output = self.teacher(
-                    input_ids=inputs["input_ids"],
-                    token_type_ids=inputs["token_type_ids"],
-                    attention_mask=inputs["attention_mask"],
-                )
-            start_logits_teacher = teacher_output["start_logits"]
-            end_logits_teacher = teacher_output["end_logits"]
-            loss_start = (
-                F.kl_div(
-                    input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
-                    target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1),
-                    reduction="batchmean",
-                )
-                * (self.distill_temperature ** 2)
-            )
-            loss_end = (
-                F.kl_div(
-                    input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1),
-                    target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1),
-                    reduction="batchmean",
-                )
-                * (self.distill_temperature ** 2)
-            )
-            teacher_loss = (loss_start + loss_end) / 2.0
-            loss_start = self.criterion(start_logits_student, start_logits_label)
-            loss_end = self.criterion(end_logits_student, end_logits_label)
-            label_loss = (loss_start + loss_end) / 2.0
-            loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
-        return (loss, outputs) if return_outputs else loss
+        student_outputs = model(**inputs)
+        loss = student_outputs["loss"]
+
+        target_device = device_of(inputs)
+        self.teacher.to(target_device)
+        with torch.no_grad():
+            teacher_outputs = self.teacher(
+                input_ids=inputs["input_ids"],
+                token_type_ids=inputs["token_type_ids"],
+                attention_mask=inputs["attention_mask"],
+            )
+        steps_in_epoch = -1  # Unused
+        loss = self.manager.loss_update(
+            loss,
+            model,
+            self.optimizer,
+            self.state.epoch,
+            steps_in_epoch,
+            global_step=self.state.global_step,
+            student_outputs=student_outputs,
+            teacher_outputs=teacher_outputs,
+        )
+        return (loss, student_outputs) if return_outputs else loss
 
 
 class QuestionAnsweringModuleExporter(ModuleExporter):
diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index a2a6c2497d87..405dd0783fdf 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -26,15 +26,13 @@ class SparseMLTrainer(Trainer):
     :param args, kwargs: arguments passed into parent class
     """
 
-    def __init__(
-        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
-    ):
+    def __init__(self, model_name_or_path, recipes, teacher=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_name_or_path = str(model_name_or_path)
         self.recipes = [recipe for recipe in recipes if recipe]
         self.teacher = teacher
-        self.distill_hardness = distill_hardness
-        self.distill_temperature = distill_temperature
+        if self.teacher is not None:
+            self.teacher.eval()
         self.criterion = torch.nn.CrossEntropyLoss()
 
         manager = None
@@ -57,7 +55,7 @@ def apply_recipes(self, epoch=0.0):
         """
         if self.manager is not None:
             org_state_dict = self.model.state_dict()
-            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            self.manager.initialize(self.model, epoch=epoch, distillation_teacher=self.teacher, loggers=self.loggers)
             new_state_dict = self.model.state_dict()
             new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]

From 6cacf3464bf2eabb66315076a3e8bb06a601cc36 Mon Sep 17 00:00:00 2001
From: Tuan Nguyen
Date: Wed, 6 Oct 2021 13:35:23 -0400
Subject: [PATCH 2/2] Move teacher model's logic to modifier

---
 .../question-answering/sparseml_utils.py | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
index 2e1129931527..5ffec49dce33 100644
--- a/examples/pytorch/question-answering/sparseml_utils.py
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -1,10 +1,8 @@
 from typing import Any
 
 import numpy
-import torch
-import torch.nn.functional as F
 
-from sparseml.pytorch.utils import ModuleExporter, device_of
+from sparseml.pytorch.utils import ModuleExporter
 from trainer_qa import QuestionAnsweringTrainer
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
 from transformers.sparse import SparseMLTrainer
@@ -31,14 +29,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
         student_outputs = model(**inputs)
         loss = student_outputs["loss"]
 
-        target_device = device_of(inputs)
-        self.teacher.to(target_device)
-        with torch.no_grad():
-            teacher_outputs = self.teacher(
-                input_ids=inputs["input_ids"],
-                token_type_ids=inputs["token_type_ids"],
-                attention_mask=inputs["attention_mask"],
-            )
+        teacher_input_keys = ["input_ids", "token_type_ids", "attention_mask"]
+        teacher_inputs = {k: inputs[k] for k in teacher_input_keys}
+
         steps_in_epoch = -1  # Unused
         loss = self.manager.loss_update(
             loss,
             model,
             self.optimizer,
             self.state.epoch,
             steps_in_epoch,
             global_step=self.state.global_step,
             student_outputs=student_outputs,
-            teacher_outputs=teacher_outputs,
+            teacher_inputs=teacher_inputs,
         )
         return (loss, student_outputs) if return_outputs else loss
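
For reference, the distillation loss that the inline code removed in PATCH 1/2 used to compute, and that the SparseML distillation modifier is now expected to apply through manager.loss_update, is the usual blend of hard-label cross-entropy and temperature-scaled KL divergence against the teacher. A minimal standalone sketch in plain PyTorch follows; the function and variable names are illustrative, not part of the patches:

    import torch.nn.functional as F

    def qa_distillation_loss(
        start_logits, end_logits,          # student logits
        t_start_logits, t_end_logits,      # teacher logits
        start_positions, end_positions,    # gold labels
        hardness=0.5, temperature=2.0,
    ):
        def soft(student, teacher):
            # KL between temperature-softened distributions, rescaled by T^2 so
            # gradient magnitudes stay comparable across temperatures.
            return F.kl_div(
                input=F.log_softmax(student / temperature, dim=-1),
                target=F.softmax(teacher / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature ** 2)

        # Soft (teacher) term, averaged over the start and end logits.
        teacher_loss = (soft(start_logits, t_start_logits) + soft(end_logits, t_end_logits)) / 2.0
        # Hard (label) term against the gold start/end positions.
        label_loss = (
            F.cross_entropy(start_logits, start_positions)
            + F.cross_entropy(end_logits, end_positions)
        ) / 2.0
        # hardness is the proportion of the loss taken from the teacher term.
        return (1 - hardness) * label_loss + hardness * teacher_loss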
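
With these patches, hardness and temperature are no longer trainer arguments; they are expected to come from the SparseML recipe(s) passed to SparseMLTrainer, and manager.initialize(..., distillation_teacher=...) wires the teacher into the recipe's distillation modifier. As a rough sketch only, a recipe entry might look like the fragment below; the modifier name and field names are an assumption about the SparseML recipe format, are not defined by these patches, and should be checked against the installed SparseML version:

    # Assumption: illustrative recipe fragment (YAML carried as a Python string),
    # not taken from the patches.
    distillation_recipe = """
    modifiers:
        - !DistillationModifier
            start_epoch: 0.0
            hardness: 0.5
            temperature: 2.0
            distill_output_keys: [start_logits, end_logits]
    """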