diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 2fad54fffe12..15f1ba036f4c 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -65,12 +65,6 @@ class ModelArguments:
     distill_teacher: Optional[str] = field(
         default=None, metadata={"help": "Teacher model which needs to be a trained QA model"}
     )
-    distill_temperature: Optional[float] = field(
-        default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
-    )
-    distill_hardness: Optional[float] = field(
-        default=0.5, metadata={"help": "Proportion of loss coming from teacher model."}
-    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -592,8 +586,6 @@ def compute_metrics(p: EvalPrediction):
         model_args.model_name_or_path,
         [existing_recipe, new_recipe],
         teacher=teacher_model,
-        distill_hardness=model_args.distill_hardness,
-        distill_temperature=model_args.distill_temperature,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
diff --git a/examples/pytorch/question-answering/sparseml_utils.py b/examples/pytorch/question-answering/sparseml_utils.py
index 0f84c562e6af..5ffec49dce33 100644
--- a/examples/pytorch/question-answering/sparseml_utils.py
+++ b/examples/pytorch/question-answering/sparseml_utils.py
@@ -1,8 +1,6 @@
 from typing import Any
 
 import numpy
-import torch
-import torch.nn.functional as F
 
 from sparseml.pytorch.utils import ModuleExporter
 from trainer_qa import QuestionAnsweringTrainer
@@ -28,46 +26,24 @@ def compute_loss(self, model, inputs, return_outputs=False):
         if not self.recipes or self.teacher is None:
             return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
-        outputs = model(**inputs)
-        if self.teacher is None:
-            loss = outputs["loss"]
-        else:
-            input_device = inputs["input_ids"].device
-            self.teacher = self.teacher.to(input_device)
-            start_logits_student = outputs["start_logits"]
-            end_logits_student = outputs["end_logits"]
-            start_logits_label = inputs["start_positions"]
-            end_logits_label = inputs["end_positions"]
-            with torch.no_grad():
-                teacher_output = self.teacher(
-                    input_ids=inputs["input_ids"],
-                    token_type_ids=inputs["token_type_ids"],
-                    attention_mask=inputs["attention_mask"],
-                )
-            start_logits_teacher = teacher_output["start_logits"]
-            end_logits_teacher = teacher_output["end_logits"]
-            loss_start = (
-                F.kl_div(
-                    input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
-                    target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1),
-                    reduction="batchmean",
-                )
-                * (self.distill_temperature ** 2)
-            )
-            loss_end = (
-                F.kl_div(
-                    input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1),
-                    target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1),
-                    reduction="batchmean",
-                )
-                * (self.distill_temperature ** 2)
-            )
-            teacher_loss = (loss_start + loss_end) / 2.0
-            loss_start = self.criterion(start_logits_student, start_logits_label)
-            loss_end = self.criterion(end_logits_student, end_logits_label)
-            label_loss = (loss_start + loss_end) / 2.0
-            loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
-        return (loss, outputs) if return_outputs else loss
+        student_outputs = model(**inputs)
+        loss = student_outputs["loss"]
+
+        teacher_input_keys = ["input_ids", "token_type_ids", "attention_mask"]
+        teacher_inputs = {k: inputs[k] for k in teacher_input_keys}
+
+        steps_in_epoch = -1  # Unused
+        loss = self.manager.loss_update(
+            loss,
+            model,
+            self.optimizer,
+            self.state.epoch,
+            steps_in_epoch,
+            global_step=self.state.global_step,
+            student_outputs=student_outputs,
+            teacher_inputs=teacher_inputs,
+        )
+        return (loss, student_outputs) if return_outputs else loss
 
 
 class QuestionAnsweringModuleExporter(ModuleExporter):
diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index d617855d0e7a..e06f0a3766e6 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -27,15 +27,13 @@ class SparseMLTrainer(Trainer):
     :param args, kwargs: arguments passed into parent class
     """
 
-    def __init__(
-        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
-    ):
+    def __init__(self, model_name_or_path, recipes, teacher=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_name_or_path = str(model_name_or_path)
         self.recipes = [recipe for recipe in recipes if recipe]
         self.teacher = teacher
-        self.distill_hardness = distill_hardness
-        self.distill_temperature = distill_temperature
+        if self.teacher is not None:
+            self.teacher.eval()
         self.criterion = torch.nn.CrossEntropyLoss()
 
         manager = None
@@ -58,7 +56,7 @@ def apply_recipes(self, epoch=0.0):
         """
         if self.manager is not None:
             org_state_dict = self.model.state_dict()
-            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            self.manager.initialize(self.model, epoch=epoch, distillation_teacher=self.teacher, loggers=self.loggers)
             new_state_dict = self.model.state_dict()
             new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
 