This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
3 changes: 1 addition & 2 deletions examples/pytorch/question-answering/run_qa.py
@@ -69,7 +69,7 @@ class ModelArguments:
default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
)
distill_hardness: Optional[float] = field(
default=1.0, metadata={"help": "Proportion of loss coming from teacher model."}
default=0.5, metadata={"help": "Proportion of loss coming from teacher model."}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -217,7 +217,6 @@ def __post_init__(self):
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
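For context, distill_hardness is the fraction of the total training loss taken from the teacher, so lowering the default from 1.0 to 0.5 weights the task loss and the distillation loss equally. The sketch below only illustrates that convention; the actual combination happens inside SparseML's loss update during training, not in this diff, and the function name and tensors here are hypothetical.

# Illustrative sketch, not code from this PR: how a hardness of 0.5 is
# conventionally used to mix the task loss with a temperature-scaled
# distillation loss.
import torch.nn.functional as F

def blended_loss(task_loss, student_logits, teacher_logits,
                 distill_hardness=0.5, distill_temperature=2.0):
    # Soften both distributions with the temperature before comparing them.
    distill_loss = F.kl_div(
        F.log_softmax(student_logits / distill_temperature, dim=-1),
        F.softmax(teacher_logits / distill_temperature, dim=-1),
        reduction="batchmean",
    ) * (distill_temperature ** 2)
    # distill_hardness is the proportion of the loss coming from the teacher.
    return distill_hardness * distill_loss + (1.0 - distill_hardness) * task_loss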
58 changes: 55 additions & 3 deletions examples/pytorch/text-classification/run_glue.py
@@ -27,6 +27,7 @@
from datasets import load_dataset, load_metric

import transformers
from sparseml_utils import GLUEModuleExporter, SparseMLGLUETrainer
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
@@ -35,11 +36,11 @@
EvalPrediction,
HfArgumentParser,
PretrainedConfig,
Trainer,
TrainingArguments,
default_data_collator,
set_seed,
)
from transformers.sparse import export_model, load_recipe, preprocess_state_dict
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version

@@ -72,6 +73,19 @@ class DataTrainingArguments:
the command line.
"""

recipe: Optional[str] = field(
default=None,
metadata={
"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
"for more information"
},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "The filename and path where the ONNX model will be output"}
)
num_exported_samples: Optional[int] = field(
default=20, metadata={"help": "Number of samples to export, defaults to 20"}
)
task_name: Optional[str] = field(
default=None,
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
@@ -155,6 +169,9 @@ class ModelArguments:
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
distill_teacher: Optional[str] = field(
default=None, metadata={"help": "Teacher model which needs to be a trained text classification model"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
@@ -305,6 +322,12 @@ def main():
#
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.

# Load and preprocess the state dict if the model already exists (in this case we are continuing to
# train or evaluating the model). The preprocessing step restores the names of parameters changed by
# the QAT process
state_dict = preprocess_state_dict(model_args.model_name_or_path)

config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
@@ -327,8 +350,19 @@
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
state_dict=state_dict,
)

teacher_model = None
if model_args.distill_teacher is not None:
teacher_model = AutoModelForSequenceClassification.from_pretrained(
model_args.distill_teacher,
from_tf=bool(".ckpt" in model_args.distill_teacher),
cache_dir=model_args.cache_dir,
)
teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
params = sum([np.prod(p.size()) for p in teacher_model_parameters])
logger.info("Teacher Model has %s parameters", params)
# Preprocessing the datasets
if data_args.task_name is not None:
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
@@ -445,17 +479,29 @@ def compute_metrics(p: EvalPrediction):
else:
data_collator = None

# Load possible existing recipe and new one passed in through command argument
existing_recipe = load_recipe(model_args.model_name_or_path)
new_recipe = data_args.recipe

# Initialize our Trainer
trainer = Trainer(
trainer = SparseMLGLUETrainer(
model_args.model_name_or_path,
[existing_recipe, new_recipe],
teacher=teacher_model,
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)

# Apply recipes to the model. This is necessary because sparsification methods such as QAT modify the
# model graph with their own learnable parameters, which also need to be restored/loaded into the model.
trainer.apply_recipes()

# Training
if training_args.do_train:
checkpoint = None
@@ -536,6 +582,12 @@ def compute_metrics(p: EvalPrediction):

trainer.push_to_hub(**kwargs)

if data_args.onnx_export_path:
logger.info("*** Export to ONNX ***")
eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
exporter = GLUEModuleExporter(model, output_dir=data_args.onnx_export_path)
export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)


def _mp_fn(index):
# For xla_spawn (TPUs)
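Since run_glue.py parses its dataclasses with HfArgumentParser, the new DataTrainingArguments fields above become command-line flags automatically. A small standalone sketch of that mapping follows; the recipe path, output directory, and sample count are placeholder values.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

# Stripped-down copy of the new fields, just to show the flags they produce.
@dataclass
class DemoArgs:
    recipe: Optional[str] = field(default=None)
    onnx_export_path: Optional[str] = field(default=None)
    num_exported_samples: Optional[int] = field(default=20)

parser = HfArgumentParser(DemoArgs)
(demo_args,) = parser.parse_args_into_dataclasses(
    args=["--recipe", "recipe.yaml", "--onnx_export_path", "onnx_out", "--num_exported_samples", "32"]
)
print(demo_args.recipe, demo_args.onnx_export_path, demo_args.num_exported_samples)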
60 changes: 60 additions & 0 deletions examples/pytorch/text-classification/sparseml_utils.py
@@ -0,0 +1,60 @@
from typing import Any

import numpy

from sparseml.pytorch.utils import ModuleExporter
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.sparse import SparseMLTrainer


class SparseMLGLUETrainer(SparseMLTrainer):
"""
GLUE trainer with SparseML integration

:param recipe: recipe for model sparsification
:param teacher: teacher model for distillation
:param distill_hardness: proportion of the loss coming from the teacher targets (between 0 and 1)
:param distill_temperature: temperature for distillation
:param args, kwargs: arguments passed into parent class
"""

def compute_loss(self, model, inputs, return_outputs=False):
"""
Computing loss using teacher/student distillation
"""
if not self.recipes or self.teacher is None:
return super().compute_loss(model, inputs, return_outputs=return_outputs)

student_outputs = model(**inputs)
loss = student_outputs["loss"]

teacher_input_keys = ["input_ids", "token_type_ids", "attention_mask"]
teacher_inputs = {k: inputs[k] for k in teacher_input_keys}

steps_in_epoch = -1 # Unused
loss = self.manager.loss_update(
loss,
model,
self.optimizer,
self.state.epoch,
steps_in_epoch,
global_step=self.state.global_step,
student_outputs=student_outputs,
teacher_inputs=teacher_inputs,
)
return (loss, student_outputs) if return_outputs else loss


class GLUEModuleExporter(ModuleExporter):
"""
Module exporter class for Sequence Classification
"""

@classmethod
def get_output_names(self, out: Any):
if not isinstance(out, SequenceClassifierOutput):
raise ValueError(f"Expected SequenceClassifierOutput, got {type(out)}")
expected = ["logits"]
if numpy.any([name for name in expected if name not in out]):
raise ValueError("Expected output names not found in model output")
return expected
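A minimal, hypothetical usage sketch of the exporter outside the training script, assuming SparseML's ModuleExporter.export_onnx(sample_batch=...) API; the model name and output directory are placeholders, and the exact signature may differ between SparseML versions.

from transformers import AutoModelForSequenceClassification, AutoTokenizer

from sparseml_utils import GLUEModuleExporter

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Build a single sample batch and export it through the inherited ModuleExporter API.
sample_batch = dict(tokenizer("A sample sentence.", return_tensors="pt"))
exporter = GLUEModuleExporter(model, output_dir="onnx_out")
exporter.export_onnx(sample_batch=sample_batch)  # assumed ModuleExporter method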