8 changes: 0 additions & 8 deletions examples/pytorch/question-answering/run_qa.py
@@ -65,12 +65,6 @@ class ModelArguments:
     distill_teacher: Optional[str] = field(
         default=None, metadata={"help": "Teacher model which needs to be a trained QA model"}
     )
-    distill_temperature: Optional[float] = field(
-        default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
-    )
-    distill_hardness: Optional[float] = field(
-        default=0.5, metadata={"help": "Proportion of loss coming from teacher model."}
-    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -592,8 +586,6 @@ def compute_metrics(p: EvalPrediction):
         model_args.model_name_or_path,
         [existing_recipe, new_recipe],
         teacher=teacher_model,
-        distill_hardness=model_args.distill_hardness,
-        distill_temperature=model_args.distill_temperature,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
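With the distill_temperature and distill_hardness flags gone from ModelArguments, the distillation hyperparameters presumably move into the SparseML recipe that is passed to the trainer. A minimal sketch of what such a recipe string might contain, assuming SparseML's DistillationModifier and its hardness/temperature/distill_output_keys fields (verify against the installed sparseml version):

```python
# Hypothetical recipe fragment -- modifier and field names are assumptions
# based on sparseml.pytorch.optim.modifier_distillation.DistillationModifier.
new_recipe = """
modifiers:
    - !DistillationModifier
        start_epoch: 0.0
        hardness: 0.5
        temperature: 2.0
        distill_output_keys: [start_logits, end_logits]
"""
```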
60 changes: 18 additions & 42 deletions examples/pytorch/question-answering/sparseml_utils.py
@@ -1,8 +1,6 @@
 from typing import Any

 import numpy
-import torch
-import torch.nn.functional as F

 from sparseml.pytorch.utils import ModuleExporter
 from trainer_qa import QuestionAnsweringTrainer
@@ -28,46 +26,24 @@ def compute_loss(self, model, inputs, return_outputs=False):
         if not self.recipes or self.teacher is None:
             return super().compute_loss(model, inputs, return_outputs=return_outputs)

-        outputs = model(**inputs)
-        if self.teacher is None:
-            loss = outputs["loss"]
-        else:
-            input_device = inputs["input_ids"].device
-            self.teacher = self.teacher.to(input_device)
-            start_logits_student = outputs["start_logits"]
-            end_logits_student = outputs["end_logits"]
-            start_logits_label = inputs["start_positions"]
-            end_logits_label = inputs["end_positions"]
-            with torch.no_grad():
-                teacher_output = self.teacher(
-                    input_ids=inputs["input_ids"],
-                    token_type_ids=inputs["token_type_ids"],
-                    attention_mask=inputs["attention_mask"],
-                )
-            start_logits_teacher = teacher_output["start_logits"]
-            end_logits_teacher = teacher_output["end_logits"]
-            loss_start = (
-                F.kl_div(
-                    input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
-                    target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1),
-                    reduction="batchmean",
-                )
-                * (self.distill_temperature ** 2)
-            )
-            loss_end = (
-                F.kl_div(
-                    input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1),
-                    target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1),
-                    reduction="batchmean",
-                )
-                * (self.distill_temperature ** 2)
-            )
-            teacher_loss = (loss_start + loss_end) / 2.0
-            loss_start = self.criterion(start_logits_student, start_logits_label)
-            loss_end = self.criterion(end_logits_student, end_logits_label)
-            label_loss = (loss_start + loss_end) / 2.0
-            loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
-        return (loss, outputs) if return_outputs else loss
+        student_outputs = model(**inputs)
+        loss = student_outputs["loss"]
+
+        teacher_input_keys = ["input_ids", "token_type_ids", "attention_mask"]
+        teacher_inputs = {k: inputs[k] for k in teacher_input_keys}
+
+        steps_in_epoch = -1  # Unused
+        loss = self.manager.loss_update(
+            loss,
+            model,
+            self.optimizer,
+            self.state.epoch,
+            steps_in_epoch,
+            global_step=self.state.global_step,
+            student_outputs=student_outputs,
+            teacher_inputs=teacher_inputs,
+        )
+        return (loss, student_outputs) if return_outputs else loss


 class QuestionAnsweringModuleExporter(ModuleExporter):
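For reference, the math that manager.loss_update now performs on the trainer's behalf is the distillation blend deleted above: a temperature-softened KL divergence against the teacher's logits, mixed with the hard-label cross-entropy by a hardness factor. A self-contained sketch of that computation (function and variable names are ours, not SparseML's):

```python
import torch.nn.functional as F


def blended_distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, hardness=0.5):
    """Mix a softened KL term against the teacher with the hard-label loss."""
    # Soften both distributions; the temperature**2 factor keeps gradient
    # magnitudes comparable across temperatures (Hinton et al., 2015).
    kd_loss = F.kl_div(
        input=F.log_softmax(student_logits / temperature, dim=-1),
        target=F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    label_loss = F.cross_entropy(student_logits, labels)
    return hardness * kd_loss + (1.0 - hardness) * label_loss
```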
10 changes: 4 additions & 6 deletions src/transformers/sparse.py
@@ -27,15 +27,13 @@ class SparseMLTrainer(Trainer):
     :param args, kwargs: arguments passed into parent class
     """

-    def __init__(
-        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
-    ):
+    def __init__(self, model_name_or_path, recipes, teacher=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_name_or_path = str(model_name_or_path)
         self.recipes = [recipe for recipe in recipes if recipe]
         self.teacher = teacher
-        self.distill_hardness = distill_hardness
-        self.distill_temperature = distill_temperature
+        if self.teacher is not None:
+            self.teacher.eval()
         self.criterion = torch.nn.CrossEntropyLoss()

         manager = None
@@ -58,7 +56,7 @@ def apply_recipes(self, epoch=0.0):
"""
if self.manager is not None:
org_state_dict = self.model.state_dict()
self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
self.manager.initialize(self.model, epoch=epoch, distillation_teacher=self.teacher, loggers=self.loggers)
new_state_dict = self.model.state_dict()
new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]

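Because manager.initialize now receives distillation_teacher, a distillation modifier in the recipe can wire the teacher in itself. A rough usage sketch under these assumptions (checkpoint names and the recipe path are placeholders, and the dataset/argument setup from run_qa.py is elided):

```python
from transformers import AutoModelForQuestionAnswering, TrainingArguments
from transformers.sparse import SparseMLTrainer

# Placeholder checkpoints -- any trained QA model can serve as the teacher.
teacher = AutoModelForQuestionAnswering.from_pretrained("path/to/teacher-qa-model")
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")

trainer = SparseMLTrainer(
    "bert-base-uncased",   # model_name_or_path
    ["recipe.yaml"],       # recipes now own the distillation hyperparameters
    teacher=teacher,       # switched to eval() mode in __init__
    model=model,
    args=TrainingArguments(output_dir="./sparse_qa"),
)
trainer.apply_recipes()  # manager.initialize(..., distillation_teacher=self.teacher)
```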